In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
import pandas as pd
import typing as T
import string
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
import pickle
from sklearn.metrics import classification_report

In [2]:
def get_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return torch.device('mps')
    else:
        return torch.device('cpu')

device = get_device()
device

device(type='mps')

In [3]:
def get_data():
    train = pd.read_json("./data/subtaskA_train_monolingual.jsonl", lines=True)
    dev = pd.read_json("./data/subtaskA_dev_monolingual.jsonl", lines=True)
    return train, dev

train, dev = get_data()

In [4]:
WHITESPACE = "<WS>"
PUNCTUATION = "<PUNCT>"
DIGIT = "<DIGIT>"
UNK = "<UNK>"
SENT_TERMINATE = "<SENT_TERMINATE>"

BOS = "<BOS>"
EOS = "<EOS>"
PAD = "<PAD>"

def map_char(char: str):
    sentence_ending = [".", "!", "?"]
    if char.isspace():
        return WHITESPACE
    if char in sentence_ending:
        return SENT_TERMINATE
    if char in string.punctuation:
        return PUNCTUATION
    if char in string.digits:
        return DIGIT
    if char not in string.printable:
        return UNK
    return char

def build_vocab(train_set: pd.DataFrame):
    vocab = set()
    for _, series in train_set.iterrows():
        text: str = series["text"]
        tokens: T.List[str] = [*text.lower().strip()]
        tokens = [map_char(token) for token in tokens]
        for token in tokens:
            vocab.add(token)
    vocab = list(vocab)
    
    vocab.append(BOS)
    vocab.append(EOS)
    vocab.append(PAD)
    
    word2idx = {
        word: idx for idx, word in enumerate(vocab)
    }
    idx2word = {
        idx: word for idx, word in enumerate(vocab)
    }
    return word2idx, idx2word, vocab

def get_vocab():
    fp = "./data/charlm_vocab.pkl"
    try:
        with open(fp, "rb") as f:
            return pickle.load(f)
    except:
        train, _ = get_data()
        res = build_vocab(train)
        with open(fp, "wb") as f:
            pickle.dump(res, f)
        return res
    
word2idx, idx2word, vocab = get_vocab()

In [5]:
def get_text_tokens(text:str):
    tokens: T.List[str] = [*text.lower().strip()]
    tokens = [map_char(token) for token in tokens]
    return tokens

def tokenize(texts: T.List[str], max_len=None, add_special_tokens=True):
    tokenized_texts = [get_text_tokens(t) for t in texts]
    
    longest_len = max([len(t) for t in tokenized_texts])
    if (max_len < longest_len):
        longest_len = max_len
    tokenized_texts = [t[:longest_len] for t in tokenized_texts]
    
    tokens, attentions = [], []
    for tokenized_text in tokenized_texts:
        
        pad_amount = longest_len - len(tokenized_text)
        if add_special_tokens:
            tokenized_text = [BOS] + tokenized_text + [EOS]
        
        tokenized_text += [PAD] * (pad_amount)
        tokens.append([word2idx[token] for token in tokenized_text])
        attentions.append([1 if token != PAD else 0 for token in tokenized_text])
    return torch.tensor(tokens, device=device), torch.tensor(attentions, device=device)

def decode(tokens: T.List[T.List[int]]):
    return [[idx2word[token] for token in tokenized_text] for tokenized_text in tokens]

In [6]:
class TaskA_Dataset(Dataset):
    def __init__(self, split="train") -> None:
        if (split == "train"):
            self.data = pd.read_json("./data/subtaskA_train_monolingual.jsonl", lines=True)
        else:
            self.data = pd.read_json("./data/subtaskA_dev_monolingual.jsonl", lines=True)
        
    
    def __len__(self):
        return len(self.data) 
    
    def __getitem__(self, index):
        item = self.data.iloc[index]
        text, label, _id = item["text"], item["label"], item["id"]
        return text, label, _id
    
        

MAX char ~ 200.000
MEAN char ~ 5000

In [7]:
class CharLM(nn.Module):
    def __init__(self, vocab_size=None, emb_size=8, hidden_size=1024, num_layers=1) -> None:
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_size)
        self.lstm = nn.LSTM(
            hidden_size=hidden_size,
            input_size=emb_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.lstm2lm = nn.Linear(hidden_size, vocab_size)
        self.lstm2class = nn.Linear(hidden_size, 2)
    
    def forward(self, input_ids):
        embedded = self.emb(input_ids)
        out, _ = self.lstm(embedded)
        lm_out = self.lstm2lm(out)
        lm_out = F.log_softmax(lm_out, dim=-1)
        classification_out = self.lstm2class(torch.mean(out, dim=1))
        classification_out = F.log_softmax(classification_out, dim=-1)
        return lm_out, classification_out

def collate_fn(data):
    labels = [i[1] for i in data]
    texts = [i[0] for i in data]
    ids = [i[2] for  i in data]
    max_len = 10000
    input_ids, attentions = tokenize(texts, max_len=max_len)
    return input_ids, attentions, torch.tensor(labels, device=device), torch.tensor(ids, device=device)

In [8]:
def evaluate(model, dataset):
    dev_dataloader = DataLoader(dataset, shuffle=False, batch_size=10, collate_fn=collate_fn)
    y_pred = []
    y_gold = []
    with torch.no_grad():
        for input_ids, attentions, labels, _ in dev_dataloader:
            _, out = model(input_ids)
            for i in range(out.shape[0]):
                pred = torch.argmax(out[i]).item()
                y_pred.append(pred)
                y_gold.append(labels[i].item())
    
    print(classification_report(y_gold, y_pred))
    

In [9]:
from math import isnan


def train(model=None, optimizer=None, dataloader=None, n_epochs=5):
    lm_criterion = nn.NLLLoss(reduction="mean")
    cl_criterion = nn.NLLLoss(reduction="sum")
    
    losses = []
    
    with tqdm(total=len(dataloader)) as pbar:
        for i in range(1, n_epochs+1):
            pbar.set_description(f"Epoch {i}")
            for input_ids, attentions, labels, text_ids in dataloader:
                optimizer.zero_grad()
                
                lm_out, classifier_out = model(input_ids)
                loss = torch.tensor(0, dtype=torch.float32, device=device)
                
                # ------------------
                # LM loss
                # ------------------
                
                for i in range(input_ids.shape[0]):
                    if labels[i].item() == 0:
                        # only train LM on human texts
                        y_pred = lm_out[i, :-1]
                        y_gold = input_ids[i, 1:]
                        loss_update = lm_criterion(y_pred, y_gold)
                        if isnan(loss_update.item()):
                            print(text_ids[i].item())
                            print(input_ids)
                            print(y_pred, y_gold)
                            return
                        loss += loss_update
                
                # ------------------
                # Classifier loss
                # ------------------
                loss_update = cl_criterion(classifier_out, labels)
                loss += loss_update
                
                # ------------------
                # Backprop
                # ------------------
                
                losses.append(loss.item())
                loss.backward()
                optimizer.step()
                pbar.update(1)
                pbar.set_postfix({"LOSS": sum(losses)/len(losses)})
                
            evaluate(model, TaskA_Dataset(split="dev"))
    
    return losses

In [10]:
model = CharLM(
    vocab_size=len(vocab),
    hidden_size=126
)
model.to(device)
evaluate(model, TaskA_Dataset(split="dev"))
optimizer = AdamW(model.parameters(), lr=0.005)
ds = TaskA_Dataset(split="train")
loader = DataLoader(
    ds,
    shuffle=True,
    batch_size=4,
    collate_fn=collate_fn
)

losses = train(model=model, optimizer=optimizer, dataloader=loader)

              precision    recall  f1-score   support

           0       0.57      0.13      0.21      2500
           1       0.51      0.90      0.65      2500

    accuracy                           0.52      5000
   macro avg       0.54      0.52      0.43      5000
weighted avg       0.54      0.52      0.43      5000



  from .autonotebook import tqdm as notebook_tqdm
Epoch 1:   0%|          | 47/29940 [00:25<4:26:31,  1.87it/s, LOSS=8.69]


KeyboardInterrupt: 