In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
import pandas as pd
import typing as T
import string
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
import pickle
from sklearn.metrics import classification_report
import os

In [28]:
torch.manual_seed(42)

<torch._C.Generator at 0x106cc6ff0>

In [29]:
def get_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return torch.device('mps')
    else:
        return torch.device('cpu')

device = get_device()
device

device(type='mps')

In [30]:
def get_data():
    train = pd.read_json("./data/subtaskA_train_monolingual.jsonl", lines=True)
    dev = pd.read_json("./data/subtaskA_dev_monolingual.jsonl", lines=True)
    return train, dev

train, dev = get_data()

In [31]:
WHITESPACE = "<WS>"
PUNCTUATION = "<PUNCT>"
DIGIT = "<DIGIT>"
UNK = "<UNK>"
SENT_TERMINATE = "<SENT_TERMINATE>"

BOS = "<BOS>"
EOS = "<EOS>"
PAD = "<PAD>"

def map_char(char: str):
    sentence_ending = [".", "!", "?"]
    if char.isspace():
        return WHITESPACE
    if char in sentence_ending:
        return SENT_TERMINATE
    if char in string.punctuation:
        return PUNCTUATION
    if char in string.digits:
        return DIGIT
    if char not in string.printable:
        return UNK
    return char

def build_vocab(train_set: pd.DataFrame):
    vocab = set()
    for _, series in train_set.iterrows():
        text: str = series["text"]
        tokens: T.List[str] = [*text.lower().strip()]
        tokens = [map_char(token) for token in tokens]
        for token in tokens:
            vocab.add(token)
    vocab = list(vocab)
    
    vocab.append(BOS)
    vocab.append(EOS)
    vocab.append(PAD)
    
    word2idx = {
        word: idx for idx, word in enumerate(vocab)
    }
    idx2word = {
        idx: word for idx, word in enumerate(vocab)
    }
    return word2idx, idx2word, vocab

def get_vocab():
    fp = "./data/charlm_vocab.pkl"
    try:
        with open(fp, "rb") as f:
            return pickle.load(f)
    except:
        train, _ = get_data()
        res = build_vocab(train)
        with open(fp, "wb") as f:
            pickle.dump(res, f)
        return res
    
word2idx, idx2word, vocab = get_vocab()

In [32]:
def get_text_tokens(text:str):
    tokens: T.List[str] = [*text.lower().strip()]
    tokens = [map_char(token) for token in tokens]
    return tokens

def tokenize(texts: T.List[str], max_len=None, add_special_tokens=True):
    tokenized_texts = [get_text_tokens(t) for t in texts]
    
    longest_len = max([len(t) for t in tokenized_texts])
    if (max_len < longest_len):
        longest_len = max_len
    tokenized_texts = [t[:longest_len] for t in tokenized_texts]
    
    tokens, attentions = [], []
    for tokenized_text in tokenized_texts:
        
        pad_amount = longest_len - len(tokenized_text)
        if add_special_tokens:
            tokenized_text = [BOS] + tokenized_text + [EOS]
        
        tokenized_text += [PAD] * (pad_amount)
        tokens.append([word2idx.get(token, UNK) for token in tokenized_text])
        attentions.append([1 if token != PAD else 0 for token in tokenized_text])
    return torch.tensor(tokens, device=device), torch.tensor(attentions, device=device)

def decode(tokens: T.List[T.List[int]]):
    return [[idx2word[token] for token in tokenized_text] for tokenized_text in tokens]

In [33]:
class TaskA_Dataset(Dataset):
    def __init__(self, split="train") -> None:
        if (split == "train"):
            self.data = pd.read_json("./data/subtaskA_train_monolingual.jsonl", lines=True)
        else:
            self.data = pd.read_json("./data/subtaskA_dev_monolingual.jsonl", lines=True)
        
    
    def __len__(self):
        return len(self.data) 
    
    def __getitem__(self, index):
        item = self.data.iloc[index]
        text, label, _id = item["text"], item["label"], item["id"]
        return text, label, _id
    
        

In [34]:
class CharLM(nn.Module):
    def __init__(self, vocab_size=None, emb_size=8, hidden_size=1024, num_layers=1) -> None:
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_size)
        self.lstm = nn.LSTM(
            hidden_size=hidden_size,
            input_size=emb_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.classifier_head = nn.Linear(hidden_size, 2)

    
    def forward(self, input_ids, attention):
        embedded = self.emb(input_ids)
        out, _ = self.lstm(embedded)
        out = out[:, -1, :]
        pred = self.classifier_head(out)
        pred = F.log_softmax(pred, dim=1)
        return pred
        

def collate_fn(data):
    labels = [i[1] for i in data]
    texts = [i[0] for i in data]
    ids = [i[2] for  i in data]
    max_len = 10_000
    input_ids, attentions = tokenize(texts, max_len=max_len)
    return input_ids, attentions, torch.tensor(labels, device=device), torch.tensor(ids, device=device)

In [35]:
def evaluate(model, dataset):
    dev_dataloader = DataLoader(dataset, shuffle=False, batch_size=10, collate_fn=collate_fn)
    y_pred = []
    y_gold = []
    with torch.no_grad():
        for input_ids, attentions, labels, _ in dev_dataloader:
            out = model(input_ids, attentions)
            for i in range(out.shape[0]):
                pred = torch.argmax(out[i]).item()
                y_pred.append(pred)
                y_gold.append(labels[i].item())
    
    print(classification_report(y_gold, y_pred))
    

In [36]:
def make_checkpoint(model, optimizer, epoch, prefix="classifier"):
    try:
        os.mkdir("checkpoints")
    except FileExistsError:
        pass
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }
    torch.save(checkpoint, f"checkpoints/{prefix}_{epoch}.pt")

In [37]:
def train(model=None, optimizer=None, dataloader=None, n_epochs=5, checkpoint_prefix=None, start_epoch=1):
    criterion = nn.NLLLoss(reduction="mean")
    
    
    for epoch in range(start_epoch, n_epochs):
        with tqdm(total=len(dataloader)) as pbar:
            pbar.set_description(f"Epoch {epoch}")
            losses = []
            for input_ids, attentions, labels, _ in dataloader:
                optimizer.zero_grad()
                
                classifier_out = model(input_ids, attentions)
                
                # ------------------
                # Classifier loss
                # ------------------
                
                loss = criterion(classifier_out, labels)
                
                if torch.isnan(loss):
                    print("LOSS IS NAN")
                    continue
                                
                # ------------------
                # Backprop
                # ------------------
                
                losses.append(loss.item())
                loss.backward()
                optimizer.step()
                pbar.update(1)

            print("LOSS", sum(losses) / len(losses))
            evaluate(model, TaskA_Dataset(split="dev"))
            make_checkpoint(model, optimizer, epoch, prefix=checkpoint_prefix)
    
    return losses

In [38]:
model = CharLM(
    vocab_size=len(vocab),
    hidden_size=256,
    num_layers=2
)
model.to(device)
start_epoch = 1
prefix="charclass_256_2"
optimizer = AdamW(model.parameters(), lr=0.001)
ds = TaskA_Dataset(split="train")
loader = DataLoader(
    ds,
    shuffle=True,
    batch_size=8,
    collate_fn=collate_fn
)

CP = None

if CP:
    checkpoint_data = torch.load(CP)
    model.load_state_dict(checkpoint_data["model"])
    optimizer.load_state_dict(checkpoint_data["optimizer"])
    start_epoch = checkpoint_data["epoch"] + 1
    print("-------------------------")
    print("CHECKPOINT MODEL EVAL")
    print("-------------------------")
    # evaluate(model, TaskA_Dataset(split="dev"))
    print()

losses = train(
    model=model, 
    optimizer=optimizer, 
    dataloader=loader, 
    n_epochs=5, 
    checkpoint_prefix=prefix,
    start_epoch=start_epoch
)

Epoch 1: 100%|██████████| 14970/14970 [8:05:47<00:00,  3.72s/it]  

LOSS 0.6671743715275266


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch 1: 100%|██████████| 14970/14970 [8:09:32<00:00,  1.96s/it]


              precision    recall  f1-score   support

           0       0.50      1.00      0.67      2500
           1       0.00      0.00      0.00      2500

    accuracy                           0.50      5000
   macro avg       0.25      0.50      0.33      5000
weighted avg       0.25      0.50      0.33      5000



Epoch 2:  12%|█▏        | 1849/14970 [48:21<5:43:07,  1.57s/it] 


KeyboardInterrupt: 