In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, dataloader
import torch.nn.functional as F
import json
import pickle
import re
from typing import List, Union
from tqdm import tqdm
import math
import os
import numpy as np


In [2]:
def get_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return torch.device('mps')
    else:
        return torch.device('cpu')

device = get_device()
device

device(type='mps')

In [3]:
PAD = "<PAD>"
UNK = "<UNK>"

with open("wiki2vec/dump_t.pkl", "rb") as f:
    word2idx, idx2word, embeddings = pickle.load(f)
    
WHITESPACE = "<WS>"
word2idx[WHITESPACE] = len(word2idx)
idx2word[word2idx[WHITESPACE]] = WHITESPACE
dim = embeddings.shape[1]
ws_emb = torch.randn(1, dim)
embeddings = torch.cat([embeddings, ws_emb], dim=0)
embeddings.shape

torch.Size([4530033, 300])

In [15]:
class TaskC(Dataset):
    def __init__(self, split="train", enrich=False, max_len=2000, word2idx=None):
        super().__init__()
        self.split = split
        self.enrich = enrich
        self.max_len = max_len
        self.word2idx = word2idx
        self.data, self.perplexities = self.load_data()
    
    def load_data(self):
        f = "data/subtaskC_train.jsonl" if self.split == "train" else "data/subtaskC_dev.jsonl"
        ppl_f = "data/ppl_train.json" if self.split == "train" else "data/ppl_dev.json"
        data = []
        
        with open(f, "r", encoding="utf-8") as f:
            for line in f:
                parsed = json.loads(line)
                data.append(parsed)
        with open(ppl_f, "r", encoding="utf-8") as f:
            ppl_data = json.load(f)
                
        
        return data, ppl_data
    
    def _clean(self, word):
        # all non alphanumeric
        replace = re.compile(r"[^a-zA-Z0-9\-]")
        word = replace.sub("", word)
        return word.lower().strip()
    
    def _is_whitespace(self, word):
        pat = re.compile(r"^\s*$")
        return pat.match(word) is not None

    def tokenize(self, text) -> List[int]:
        tokens = text.split(" ")
        ids = []
        for token in tokens:
            token = self._clean(token)
            if self._is_whitespace(token):
                ids.append(self.word2idx[WHITESPACE])
            elif token in self.word2idx:
                ids.append(self.word2idx[token])
            else:
                ids.append(self.word2idx[UNK])
        if len(ids) > self.max_len:
            ids = ids[:self.max_len]
        elif len(ids) < self.max_len:
            ids = ids + ([self.word2idx[PAD]] * (self.max_len - len(ids)))
        return torch.tensor(ids, dtype=torch.long)
    
    def get_perplexities(self, idx):
        ppl = self.perplexities[idx]
        
        if len(ppl) > self.max_len:
            ppl = ppl[:self.max_len]
        elif len(ppl) < self.max_len:
            last = ppl[-1]
            ppl = ppl + ([last] * (self.max_len - len(ppl)))
        return torch.tensor(ppl)
    
    def get_label_vector(self, boundary:int):
        base = [0] * self.max_len
        if boundary > -1:
            for i in range(boundary, self.max_len):
                base[i] = 1
        return torch.tensor(base, dtype=torch.long)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        data = self.data[index]
        text = self.tokenize(data["text"])
        ppl = self.get_perplexities(index)
        label = self.get_label_vector(data["label"])
        return text, ppl, label

In [5]:
class BiLSTM_Labeller(nn.Module):
    def __init__(
        self, 
        pretrained_embeddings = None,
        hidden_size = 256,
        num_layers = 1,
        feature_vector_dim = 3
    ) -> None:
        
        super().__init__()
        self.hidden_size = hidden_size
        
        self.emb = nn.Embedding.from_pretrained(pretrained_embeddings)
        self.emb.cpu()
        self.input_size = self.emb.weight.shape[1] + feature_vector_dim
        
        self.lstm = nn.LSTM(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
        )
        self.lstm.to(device)
        # multiply by 2 because of bidirectional
        self.classifier_head = nn.Linear(self.hidden_size * 2, 2)
        self.classifier_head.to(device)
        
    def forward(self, input_ids, feature_vectors):
        inputs: torch.Tensor = self.emb(input_ids)
        inputs = torch.cat((inputs, feature_vectors), dim=-1)
        inputs = inputs.to(device)
        outputs, _ = self.lstm(inputs)
        logits = self.classifier_head(outputs)
        predicted = F.log_softmax(logits, dim=-1)
        return predicted
    
    def predict(self, input_ids, feature_vectors) -> List[int]:
        inputs = torch.cat((input_ids, feature_vectors), dim=-1)
        inputs = inputs.to(device)
        predicted = self.forward(inputs)
        predicted = torch.argmax(predicted, dim=-1)
        p = predicted.cpu().numpy().tolist()
        r = []
        for item in p:
            try:
                r.append(item.index(1))
            except ValueError:
                r.append(len(item) // 2)
        return r
    
model = BiLSTM_Labeller(
    pretrained_embeddings=embeddings
)



In [6]:
def get_start_position(pred: torch.Tensor) -> int:
    pred_np: np.ndarray = pred.detach().numpy()
    result = []
    for labels in pred_np:
        try:
            result.append(labels.tolist().index(1))
        except ValueError:
            result.append(len(labels) // 2)
    return result

In [17]:
taskC_dev = TaskC(split="dev", word2idx=word2idx)
dev_dataloader = dataloader.DataLoader(taskC_dev, batch_size=32, shuffle=False)

def eval(model):
    pred = []
    gold = []
    with torch.no_grad():
        for input_ids, perplexities, labels in dev_dataloader:
            labels = get_start_position(labels)
            predicted = model.predict(input_ids, perplexities)
            pred.extend(predicted)
            gold.extend(labels)
    distances = [abs(pred[i] - gold[i]) for i in range(len(pred))]
    print("MEAN ABSOLUTE DISTANCE:", sum(distances) / len(distances))
    return pred, gold



In [12]:
def make_checkpoint(model, optimizer, epoch):
    try:
        os.mkdir("checkpoints")
    except FileExistsError:
        pass
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }
    torch.save(checkpoint, f"checkpoints/checkpoint_{epoch}.pt")

In [13]:
def train(model, dataloader, optimizer, criterion, epochs=10):
    model.train()
    
    for epoch in range(1, epochs+1): 
        losses = []
        with tqdm(dataloader, total=len(dataloader)) as pbar:
            for batch in dataloader:
                optimizer.zero_grad()
                input_ids, perplexities, labels = batch
                # do not put inputs to cuda, because emb is on cpu
                labels = labels.to(device)
                
                outputs = model(input_ids, perplexities)
                # put outputs in shape (batch_size * seq_len, 2)
                outputs = outputs.reshape(-1, 2)
                # put labels in shape (batch_size * seq_len)
                labels = labels.reshape(-1)
                loss = criterion(outputs, labels)
                losses.append(loss.item())
                loss.backward()
                optimizer.step()
                pbar.update(1)
        print(f"\n---- EVALUATING MODEL at epoch {epoch} ----")
        print("LOSS:", sum(losses) / len(losses))
        eval(model)
        print("------------------------------------------\n")
        # make_checkpoint(model, optimizer, epoch)

In [16]:
loader = dataloader.DataLoader(
    TaskC(word2idx=word2idx, split="train"),
    batch_size=32,
    shuffle=True,
)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

train(model, loader, optimizer, criterion)

100%|██████████| 115/115 [02:06<00:00,  1.10s/it]


---- EVALUATING MODEL at epoch 1 ----
LOSS: 0.07050031822012819





TypeError: unsupported operand type(s) for -: 'list' and 'int'