<a href="https://colab.research.google.com/github/myutman/NLP/blob/master/HW4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install the third-party packages imported in the next cell (no versions are pinned).
!pip install transformers
!pip install wandb



In [2]:
import transformers
from transformers import BertTokenizer, BertModel

import torch
import torch.nn as nn
import torch.optim as optim
from torch.functional import F

from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

import re

import wandb
wandb.init(project="my-project")

W&B Run: https://app.wandb.ai/myutman/my-project/runs/4zu3z05z

In [0]:
# Run on the first CUDA GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [36]:
# Mount Google Drive at /content/gdrive; later cells read and write files under that path.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
class MLP(nn.Module):
    """Small feed-forward network with 32-unit sigmoid layers and 0.5 dropout.

    The stack is: one ``nin -> 32`` block, then ``n_hidden - 2`` further
    ``32 -> 32`` blocks (none when ``n_hidden <= 2``), then a plain
    ``32 -> nout`` linear layer. Each block is Linear, Sigmoid, Dropout(0.5).
    """

    def __init__(self, nin, nout, n_hidden):
        super().__init__()

        def block(in_features):
            # One hidden block; the width (32) and dropout rate (0.5) are fixed.
            return [nn.Linear(in_features, 32), nn.Sigmoid(), nn.Dropout(0.5)]

        stack = block(nin)
        for _ in range(n_hidden - 2):
            stack += block(32)
        stack.append(nn.Linear(32, nout))

        # Kept under the attribute name `layers` so parameter names stay `layers.<idx>.*`.
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.layers(x)


class QAModel(nn.Module):
    """Span-extraction QA model: a frozen BERT encoder with two per-token MLP heads.

    One head scores every position as the answer start (``fr``), the other
    as the answer end (``to``). Each head's scores are turned into a
    distribution over sequence positions with a softmax.
    """

    def __init__(self):
        super(QAModel, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-multilingual-cased')
        # The encoder starts frozen so that only the two heads are trainable.
        for par in self.bert.parameters():
            par.requires_grad = False

        self.mlp_fr = MLP(768, 1, 3)
        self.softmax_fr = nn.Softmax(dim=-1)

        self.mlp_to = MLP(768, 1, 3)
        self.softmax_to = nn.Softmax(dim=-1)

    def forward(self, x, mask):
        """Return ``(fr_out, to_out)``, each of shape (batch, seq_len).

        Args:
            x: token-id tensor of shape (batch, seq_len).
            mask: tensor passed to BERT as ``attention_mask``.
                NOTE(review): ``prepare_sample`` builds a mask that is 0 on
                the question tokens, so BERT may not be attending to the
                question at all -- confirm this is intended.

        Returns:
            Two tensors whose rows sum to 1: start-position and
            end-position probabilities over the sequence positions.
        """
        # Index the first output rather than tuple-unpacking, so this works
        # whether the encoder returns a plain tuple or an output object.
        states = self.bert(x, attention_mask=mask)[0]

        # Each head emits one score per token: (batch, seq_len, 1) -> (batch, seq_len).
        # squeeze(-1) avoids hard-coding the sequence length.
        fr_vec = self.mlp_fr(states).squeeze(-1)
        fr_out = self.softmax_fr(fr_vec)

        # The end-position scores must come from the dedicated `mlp_to` head.
        to_vec = self.mlp_to(states).squeeze(-1)
        to_out = self.softmax_to(to_vec)

        return fr_out, to_out

# Added inside every log() so that a zero probability cannot produce -inf.
EPS = 1e-9

def J(output_froms, output_tos, froms, tos):
    """Batch loss over predicted start/end position distributions.

    Args:
        output_froms, output_tos: (batch, positions) probabilities for the
            answer start / end.
        froms, tos: (batch,) integer target positions; -1 marks a row whose
            answer boundary is not in the window.

    For a row with a target the term is ``-log(p[target] + EPS)``; for a row
    marked -1 it is the mean of ``-log(p + EPS)`` over all positions. The four
    per-side sums are added and divided by the batch size.
    """
    def _side_terms(probs, targets):
        rows = torch.arange(len(probs))
        # A -1 target indexes the last column here; that value is zeroed by
        # the indicator it is multiplied with, so it never reaches the loss.
        picked = - torch.log(probs[rows, targets] + EPS)
        spread = - torch.log(probs + EPS).mean(dim=-1)
        with_target = picked * (targets != -1).float()
        without_target = spread * (targets == -1).float()
        return with_target, without_target

    fr_probs, no_fr_probs = _side_terms(output_froms, froms)
    to_probs, no_to_probs = _side_terms(output_tos, tos)

    total = fr_probs.sum() + to_probs.sum() + no_fr_probs.sum() + no_to_probs.sum()
    return total / len(froms)




In [0]:
# Tokenizer for the same checkpoint name that QAModel passes to BertModel.from_pretrained.
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

In [0]:
# Training table; prepare_dataset reads its 'paragraph', 'question' and 'answer' columns.
df = pd.read_csv('/content/gdrive/My Drive/NLP/train_qa.csv')

In [0]:
def prepare_sample(text, quest, ans):
    """Convert one (paragraph, question, answer) triple into fixed-length model inputs.

    The paragraph is cut into overlapping windows so that
    ``[CLS] question [SEP] window [SEP] [PAD]...`` is exactly ``max_len``
    tokens long. One row is produced per window.

    Args:
        text: paragraph string.
        quest: question string.
        ans: answer string. The inference cell passes ``''``; an empty answer
            matches at every offset, so the function still returns windows
            (with labels that carry no meaning).

    Returns:
        ``(X, froms, tos, masks)`` -- parallel lists with one entry per window:
        token ids, answer-start index, answer-end index (both relative to the
        full ``max_len`` sequence, ``-1`` when that boundary is outside the
        window) and a 0/1 mask. Returns ``(None, None, None, None)`` when the
        answer is not found in the paragraph, when the paragraph has no
        tokens, or when the question leaves no room for paragraph tokens.
    """
    X = []
    froms = []
    tos = []
    masks = []

    # Lowercase and keep only runs of word characters, joined by single spaces.
    text = ' '.join(re.findall(r'\w+', text.lower()))
    quest = ' '.join(re.findall(r'\w+', quest.lower()))
    ans = ' '.join(re.findall(r'\w+', ans.lower()))

    # [1:-1] drops the first and last ids returned by encode() -- presumably
    # the special tokens it adds; they are re-inserted by hand below.
    text_tokens = tokenizer.encode(text)[1:-1]
    quest_tokens = tokenizer.encode(quest)[1:-1]
    ans_tokens = tokenizer.encode(ans)[1:-1]

    # Find the answer inside the paragraph tokens. The "+ 1" lets a match that
    # ends on the very last paragraph token be found. There is no early exit,
    # so the last occurrence wins.
    fr = -1
    to = -1
    for i in range(len(text_tokens) - len(ans_tokens) + 1):
        if text_tokens[i:i + len(ans_tokens)] == ans_tokens:
            fr = i
            to = i + len(ans_tokens) - 1

    if fr == -1 or not text_tokens:
        return None, None, None, None

    # Paragraph tokens per window: max_len minus the question and the three
    # special tokens ([CLS] and two [SEP]).
    window = max_len - len(quest_tokens) - 3
    if window <= 0:
        # The question alone fills the sequence; no valid window can be built.
        return None, None, None, None

    # [l, r] is the inclusive range of paragraph offsets the current window covers.
    l = 0
    r = window - 1
    if len(text_tokens) > window:
        # Slide by a third of the window; never by 0, which would loop forever.
        step = max(1, window // 3)
    else:
        step = len(text_tokens)
    while len(text_tokens) > 0:
        cnt = min(len(text_tokens), window)
        tokens = [tokenizer.cls_token_id] + quest_tokens + [tokenizer.sep_token_id] + text_tokens[:cnt] + [tokenizer.sep_token_id] + [tokenizer.pad_token_id] * (window - cnt)
        # 1 only on paragraph tokens: [CLS], the question, [SEP] and padding get 0.
        # NOTE(review): QAModel.forward passes this as BERT's attention_mask,
        # which would hide the question from the encoder -- confirm intent.
        mask = [0 if (i < len(quest_tokens) + 1) or (tokens[i] == tokenizer.pad_token_id) or (tokens[i] == tokenizer.sep_token_id) else 1 for i in range(max_len)]
        X.append(tokens)
        masks.append(mask)
        # The first paragraph token sits at index len(quest_tokens) + 2.
        if fr >= l and fr <= r:
            froms.append(len(quest_tokens) + 2 + fr - l)
        else:
            froms.append(-1)
        if to >= l and to <= r:
            tos.append(len(quest_tokens) + 2 + to - l)
        else:
            tos.append(-1)

        text_tokens = text_tokens[step:]
        l += step
        r += step
    return X, froms, tos, masks

In [0]:
# Fixed length (in tokens) of every sequence that prepare_sample builds.
max_len = 256

def colate_fn(data):
    """Collate a batch of ``(X, fr, to, mask)`` samples into four tensors.

    ``data`` is a sequence of 4-tuples; the result is the tuple
    ``(X, fr, to, mask)`` where each element stacks the corresponding
    field of every sample.
    """
    token_rows, start_labels, end_labels, mask_rows = zip(*data)
    batch = (
        torch.tensor(token_rows),
        torch.tensor(start_labels),
        torch.tensor(end_labels),
        torch.tensor(mask_rows),
    )
    return batch

def prepare_dataset(df, batch_size = 16, random_state = 0):
    """Build train/test DataLoaders from a QA dataframe.

    Args:
        df: dataframe with 'paragraph', 'question' and 'answer' columns.
        batch_size: batch size used by both loaders.
        random_state: seed forwarded to ``train_test_split``. With a fixed
            seed, repeated calls on the same dataframe produce the same
            80/20 partition regardless of ``batch_size``, so a later call
            does not move earlier training windows into the test split.

    Returns:
        ``(train_data, test_data)`` DataLoaders that yield
        ``(X, fr, to, mask)`` tensors via ``colate_fn``.

    Rows for which ``prepare_sample`` returns ``None`` are skipped.
    NOTE(review): the split is made per window, so windows cut from one
    paragraph can land on both sides of the split.
    """
    X = []
    froms = []
    tos = []
    masks = []
    for text, quest, ans in zip(df['paragraph'], df['question'], df['answer']):
        x, fr, to, mask = prepare_sample(text, quest, ans)
        if x is None:
            continue
        X.extend(x)
        froms.extend(fr)
        tos.extend(to)
        masks.extend(mask)

    X_train, X_test, froms_train, froms_test, tos_train, tos_test, mask_train, mask_test = train_test_split(
        X, froms, tos, masks, test_size=0.2, random_state=random_state)
    train_data = torch.utils.data.DataLoader(list(zip(X_train, froms_train, tos_train, mask_train)), batch_size=batch_size, collate_fn=colate_fn)
    test_data = torch.utils.data.DataLoader(list(zip(X_test, froms_test, tos_test, mask_test)), batch_size=batch_size, collate_fn=colate_fn)
    return train_data, test_data

In [10]:
# Build the train/test DataLoaders with prepare_dataset's default batch size.
train_data, test_data = prepare_dataset(df)

Token indices sequence length is longer than the specified maximum sequence length for this model (513 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (793 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (539 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (515 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (766 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for thi

In [11]:
# Instantiate the model, move it to the selected device and switch it to training mode.
model = QAModel()
model.to(device)
model.train()

QAModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
   

In [0]:
#Head training
# Only parameters with requires_grad=True are optimised; QAModel.__init__
# freezes the BERT encoder, so at this stage that is the two MLP heads.

num_epochs = 1
adam = optim.Adam(filter(lambda par: par.requires_grad, model.parameters()), lr=1e-3)

# Mean loss per epoch, one entry per epoch.
train_losses = []
test_losses = []

for i in range(num_epochs):
    # Training pass: dropout active.
    model.train()
    losses = []
    for X, fr, to, mask in train_data:
        adam.zero_grad()
        out_fr, out_to = model(X.to(device), mask.to(device))

        # The model outputs are already on `device`; only the labels need moving.
        loss = J(out_fr, out_to, fr.to(device), to.to(device))
        loss_value = loss.item()
        wandb.log({'epoch': i, 'train-loss': loss_value})
        losses.append(loss_value)

        loss.backward()
        adam.step()

    train_losses.append(np.mean(losses))

    # Evaluation pass: switch dropout off so the test loss is deterministic
    # and comparable between epochs.
    model.eval()
    losses = []
    with torch.no_grad():
        for X, fr, to, mask in test_data:
            out_fr, out_to = model(X.to(device), mask.to(device))

            loss = J(out_fr, out_to, fr.to(device), to.to(device))
            loss_value = loss.item()
            losses.append(loss_value)
            wandb.log({'epoch': i, 'test-loss': loss_value})

    test_losses.append(np.mean(losses))

# Leave the model in training mode, the state it was in before this cell.
model.train()

In [0]:
# Rebuild the loaders with batch size 2 for full-model fine-tuning.
# NOTE(review): this re-runs train_test_split; unless that split is seeded,
# windows used for head training can end up in this test split.
train_data, test_data = prepare_dataset(df, 2)

# Unfreeze everything, including the BERT encoder.
for par in model.parameters():
    par.requires_grad = True

#Fine tuning
num_epochs = 3
adam = optim.Adam(filter(lambda par: par.requires_grad, model.parameters()), lr=5e-5)

# Mean loss per epoch, one entry per epoch.
train_losses = []
test_losses = []

for i in range(num_epochs):
    # Training pass: dropout active.
    model.train()
    losses = []
    for X, fr, to, mask in train_data:
        adam.zero_grad()
        out_fr, out_to = model(X.to(device), mask.to(device))

        # The model outputs are already on `device`; only the labels need moving.
        loss = J(out_fr, out_to, fr.to(device), to.to(device))
        loss_value = loss.item()
        wandb.log({'epoch': i, 'train-loss': loss_value})
        losses.append(loss_value)

        loss.backward()
        adam.step()

    train_losses.append(np.mean(losses))

    # Evaluation pass: switch dropout off so the test loss is deterministic
    # and comparable between epochs.
    model.eval()
    losses = []
    with torch.no_grad():
        for X, fr, to, mask in test_data:
            out_fr, out_to = model(X.to(device), mask.to(device))

            loss = J(out_fr, out_to, fr.to(device), to.to(device))
            loss_value = loss.item()
            losses.append(loss_value)
            wandb.log({'epoch': i, 'test-loss': loss_value})

    test_losses.append(np.mean(losses))

# Leave the model in training mode, the state it was in before this cell.
model.train()

In [39]:
# Inference: for every (paragraph, question) row write "<question_id>\t<answer>"
# to output.txt. The answer is decoded from the window whose best start and
# end probabilities have the largest product.
test_df = pd.read_csv('/content/gdrive/My Drive/NLP/dataset_1.txt', delimiter='\t')

model.eval()
with torch.no_grad():
    # Explicit UTF-8 so non-ASCII answers are written the same on every platform.
    with open('/content/gdrive/My Drive/NLP/output.txt', 'w', encoding='utf-8') as output:
        for qid, text, quest in zip(test_df['question_id'], test_df['paragraph'], test_df['question']):
            # Reset per question so an answer from the previous row can never leak through.
            ans = ''

            # An empty answer string makes prepare_sample return windows only;
            # the labels it returns are ignored here.
            X, _, _, mask = prepare_sample(text, quest, '')

            # prepare_sample yields None when no window can be built (for
            # example a paragraph with no word tokens); fall back below.
            if X:
                X = torch.tensor(X).to(device)
                mask = torch.tensor(mask).to(device)
                f, t = model(X, mask)

                mx = 0

                f = f.cpu().numpy()
                t = t.cpu().numpy()
                X = X.cpu().numpy()

                for i in range(t.shape[0]):
                    score = np.max(f[i]) * np.max(t[i])
                    if score > mx:
                        mx = score
                        fr = np.argmax(f[i])
                        to = np.argmax(t[i])

                        # NOTE(review): fr/to are unconstrained argmaxes, so
                        # `to < fr` yields an empty span and positions outside
                        # the paragraph are possible.
                        ans = re.sub('#', '', tokenizer.decode(X[i][fr:to + 1]))

            # An empty prediction falls back to echoing the question.
            if ans.isspace() or ans == '':
                ans = quest

            output.write(f'{qid}\t{ans}\n')



Token indices sequence length is longer than the specified maximum sequence length for this model (623 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (562 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (511 > 512). Running this sequence through the model will result in indexing errors


In [0]:
# Debug cell: decode the predicted span for the last sample left over from the inference loop.
# NOTE(review): f.argmax(-1) / t.argmax(-1) give one index per window while X is flattened,
# so this presumably only makes sense when that sample produced a single window -- confirm.
tokenizer.decode(torch.tensor(X).reshape(-1)[f.argmax(-1):t.argmax(-1) + 1].cpu().numpy())

In [0]:
# Debug cell: rebuild the un-windowed "[CLS] question [SEP] paragraph [SEP]" ids for the
# `text` / `quest` left over from the inference loop, using the same normalisation as
# prepare_sample. Raw strings keep `\w` from being read as a string escape sequence.
text = ' '.join(re.findall(r'\w+', text.lower()))
quest = ' '.join(re.findall(r'\w+', quest.lower()))
text_tokens = tokenizer.encode(text)[1:-1]
quest_tokens = tokenizer.encode(quest)[1:-1]
kek = [tokenizer.cls_token_id] + quest_tokens + [tokenizer.sep_token_id] + text_tokens + [tokenizer.sep_token_id]

In [0]:
# Debug cell: decode the slice of `kek` between the argmax of the leftover `f` and `t`.
# NOTE(review): `f` and `t` hold one row per window, so argmax(-1) presumably only gives
# usable slice bounds when there is a single window -- confirm.
tokenizer.decode(kek[f.argmax(-1):t.argmax(-1) + 1])

In [0]:
# Shell out to nvidia-smi to show the GPU attached to this runtime and its memory usage.
!nvidia-smi

Mon Dec  9 17:56:23 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.36       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   49C    P8    10W /  70W |      0MiB / 15079MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru