In [149]:
import torch
from conllu import parse
import socket
from torch import nn
import torch.nn.functional as F
import transformers
from torch.utils.data import Dataset, DataLoader
from collections import Counter
from transformers import BertModel, BertTokenizer
from transformers import AdamW
import tqdm

In [150]:
def collate_fn(batch):
    longest_y = max([y.size(0) for X, y in batch])
    X = [X for X,y in batch]
    y = torch.stack([F.pad(y, (0, longest_y - y.size(0)), value=-1) for X, y in batch])
    return X, y

In [151]:
def build_mask(tokenizer, ids):
    tok_sents = [tokenizer.convert_ids_to_tokens(i) for i in ids]
    mask = []
    for sentence in tok_sents:
        current = []
        for n, token in enumerate(sentence):
            if token in tokenizer.all_special_tokens or token.startswith('##'):
                continue
            else:
                current.append(n)
        mask.append(current)
        
    mask = tokenizer.pad({'input_ids': mask}, return_tensors='pt')['input_ids']
    return mask

In [152]:
TRAIN_PATH = "./data/train.conllu"
DEV_PATH = "./data/dev.conllu"
local_path = 'models/216/'

In [153]:
class NERdata(Dataset):
    
    def __init__(self, data, label_vocab=None):
        self.data = data
        self.label_vocab = label_vocab if label_vocab else list(set([token["misc"]["name"] for text in data for token in text]))
        self.label_vocab.extend(['@UNK'])
        self.label_indexer = {i: n for n, i in enumerate(self.label_vocab)}
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        x = self.data[idx]
        text = [token["form"] for token in x]
        label = [token["misc"]["name"] for token in x]
    
        Y = torch.LongTensor([self.label_indexer[i] if i in self.label_vocab else self.label_indexer['@UNK'] for i in label] )
        
        return text, Y
        
        

In [154]:
class NERmodel(nn.Module):
    def __init__(self, num_labels):
        super().__init__()
        self._bert = BertModel.from_pretrained('ltgoslo/norbert')
        # for param in self._bert.parameters():
        #     param.requires_grad = False
            
        self._head = nn.Linear(768, num_labels)
        
    def forward(self, batch, mask):
        b = self._bert(batch)
        pooler = b.last_hidden_state[:, mask].diagonal().permute(2,0,1)
        return self._head(pooler)

In [162]:
train_dataset = NERdata(parse(open(TRAIN_PATH, "r").read())[0:36*2])
val_dataset = NERdata(parse(open(DEV_PATH, "r").read())[0:36*2], label_vocab=train_dataset.label_vocab)
train_loader = DataLoader(train_dataset, batch_size=36, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=36, collate_fn=collate_fn)

In [163]:
model = NERmodel(len(train_dataset.label_vocab))
tokenizer = transformers.BertTokenizer.from_pretrained('ltgoslo/norbert', do_basic_tokenize=False)
criterion = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = AdamW(model.parameters(), lr=1e-3)

In [164]:
for epoch in range(5):
    model.train()
    for X, y in tqdm.tqdm(train_loader):
        optimizer.zero_grad()
        X = tokenizer(X, is_split_into_words=True, return_tensors='pt', padding=True)['input_ids']
        batch_mask = build_mask(tokenizer, X)
        y_pred = model(X, batch_mask).permute(0, 2, 1)
        loss = criterion(y_pred, y.squeeze())
        loss.backward()
        optimizer.step()

    model.eval()
    correct, total = 0, 0
    for X, y in val_loader:
        X = tokenizer(X, is_split_into_words=True, return_tensors='pt', padding=True)['input_ids']
        batch_mask = build_mask(tokenizer, X)
        y_pred = model(X, batch_mask).permute(0, 2, 1)
        #correct += (y_pred.argmax(dim=1) == y.squeeze()).nonzero().size(0)
        total += y.size(0)

    #print(f"epoch: {epoch}; loss: {loss.item()}; val. acc = {correct / total}")
    print(f"epoch: {epoch}; loss: {loss.item()};")

100%|██████████| 2/2 [00:45<00:00, 22.61s/it]
  0%|          | 0/2 [00:00<?, ?it/s]

epoch: 0; loss: 0.48966094851493835;


100%|██████████| 2/2 [00:41<00:00, 20.99s/it]
  0%|          | 0/2 [00:00<?, ?it/s]

epoch: 1; loss: 0.42910173535346985;


100%|██████████| 2/2 [00:31<00:00, 15.82s/it]
  0%|          | 0/2 [00:00<?, ?it/s]

epoch: 2; loss: 0.5230061411857605;


100%|██████████| 2/2 [00:31<00:00, 15.89s/it]
  0%|          | 0/2 [00:00<?, ?it/s]

epoch: 3; loss: 0.8582348227500916;


100%|██████████| 2/2 [00:28<00:00, 14.42s/it]


epoch: 4; loss: 0.8054361939430237;
