## Model

In [None]:
%%capture
!pip install -q transformers

In [None]:
import re
from collections import defaultdict
from random import random

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ConstantLR, LinearLR, OneCycleLR, ReduceLROnPlateau
from transformers import (BartModel, BartTokenizerFast, BertForMaskedLM, BertForTokenClassification,
                          BertTokenizerFast, BertForSequenceClassification, BertModel,
                          RobertaForMaskedLM, RobertaForTokenClassification, RobertaTokenizerFast,
                          RobertaModel, RobertaForSequenceClassification, RobertaForCausalLM)
from transformers import DataCollatorForLanguageModeling

In [None]:
class Model(nn.Module):
    """Prompt-based metaphor classifier built on RoBERTa's masked-LM head.

    Every input sentence is extended with a cloze prompt made of two separator
    tokens followed by ``The word <target> is a <mask>.``. The masked-LM head is
    trained to fill the mask with the vocabulary id stored in ``self.metaphor``
    or ``self.literal``.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # RoBERTa tokenizer and masked-LM model that read the sentence plus prompt.
        self.roberta_tokenizer = RobertaTokenizerFast.from_pretrained('roberta-large')
        self.roberta_model = RobertaForMaskedLM.from_pretrained('roberta-large')

        # Freeze the word-embedding matrix so that fine-tuning does not shift only
        # the embeddings of the words that happen to occur in the training data.
        for param in self.roberta_model.roberta.embeddings.word_embeddings.parameters():
            param.requires_grad = False

        # Collator for a random-masking objective; forward() currently does not use it.
        self.data_collator = DataCollatorForLanguageModeling(self.roberta_tokenizer, mlm=True, mlm_probability=0.15)

        # Vocabulary ids of the first sub-token of the two answer words.
        self.metaphor = self.roberta_tokenizer.encode(' metaphor', add_special_tokens=False)[0]
        self.literal = self.roberta_tokenizer.encode(' literal', add_special_tokens=False)[0]

    def forward(self, sentences, word_indices, add_random_mask=False, labels=None, predict=False):
        """Run the masked-LM over the prompted sentences.

        Args:
            sentences: iterable of raw sentence strings.
            word_indices: for each sentence, the 0-based index (counting
                whitespace-separated words) of the word to classify.
            add_random_mask: accepted for backward compatibility; ignored.
            labels: optional per-sentence answer ids. An entry equal to
                ``self.metaphor`` trains the mask towards ``self.metaphor``;
                any other value trains it towards ``self.literal``.
            predict: when ``True`` (identity check), return a probability
                distribution for the first mask token of each sample instead of
                the full logits.

        Returns:
            ``(output, loss)``. ``loss`` is the masked-LM loss when ``labels``
            is given and ``None`` otherwise. ``output`` is the full logits
            tensor, or, when ``predict`` is ``True``, the softmax over the
            vocabulary at the first mask position of every sample that contains
            a mask token.
        """
        tokenizer = self.roberta_tokenizer
        prompt_prefix = f'{tokenizer.sep_token}{tokenizer.sep_token}'

        samples = []
        for sentence, word_index in zip(sentences, word_indices):
            # Words are maximal runs of non-whitespace characters.
            target_word = re.findall(r'\S+', sentence)[word_index]
            # The answer slot is always the mask token, for training and prediction alike.
            samples.append(f'{sentence}{prompt_prefix}The word {target_word} is a {tokenizer.mask_token}.')

        tokenized = tokenizer(samples, padding=True, return_tensors='pt').to(self.device)
        input_ids = tokenized['input_ids']
        attention_mask = tokenized['attention_mask']
        # Boolean map of mask-token positions, computed once instead of looping
        # over individual tensor elements in Python.
        is_mask = input_ids == tokenizer.mask_token_id

        if labels is not None:
            # One answer per sample, shaped (batch, 1) so it broadcasts over the sequence.
            per_sample = torch.as_tensor(labels, device=input_ids.device)[:input_ids.shape[0]].reshape(-1, 1)
            answers = torch.where(per_sample == self.metaphor,
                                  torch.full_like(input_ids, self.metaphor),
                                  torch.full_like(input_ids, self.literal))
            # -100 marks positions that the loss must ignore (every non-mask token).
            sentence_labels = torch.where(is_mask, answers, torch.full_like(input_ids, -100))
            sentence_outputs = self.roberta_model(input_ids=input_ids,
                                                  attention_mask=attention_mask,
                                                  labels=sentence_labels)
            loss = sentence_outputs.loss
        else:
            sentence_outputs = self.roberta_model(input_ids=input_ids, attention_mask=attention_mask)
            loss = None

        if predict is True:
            # Keep only the first mask token of each row; rows without any mask
            # are reported and contribute no output row.
            first_mask = is_mask & (is_mask.long().cumsum(dim=1) == 1)
            for _ in range(int((~is_mask.any(dim=1)).sum())):
                print('Did not find mask!')

            return torch.softmax(sentence_outputs.logits[first_mask], dim=-1), loss

        return sentence_outputs.logits, loss

## You should not need to change things below this line, unless you change the type of model output above! Just run the cells below to test.

## Dataset

In [None]:
import pandas as pd
from torch.utils.data import Dataset, DataLoader

In [None]:
class MelbertDataset(Dataset):
    """Metaphor-detection samples loaded from a delimited text file.

    The file must provide the columns ``sentence``, ``label`` and ``w_index``.
    When the path contains ``'train'``, the rows with ``label == 1`` are repeated
    cyclically (and truncated) so that there are as many of them as rows with
    ``label == 0``; the result is then shuffled with a fixed seed.
    """

    def __init__(self, csv_path='https://www.dropbox.com/s/1j2c13i4wlz647k/train.tsv?dl=1'):
        """Read ``csv_path`` (tab-separated if the path contains ``'tsv'``, else
        comma-separated), balance it if it is a training file, and print a summary."""
        separator = '\t' if 'tsv' in csv_path else ','
        self.df = pd.read_csv(csv_path, sep=separator, quotechar='`', quoting=3, doublequote=False)
        if 'train' in csv_path:
            neg_df = self.df[self.df['label'] == 0]
            pos_df = self.df[self.df['label'] == 1]
            if len(pos_df) > 0:
                # Ceiling division: enough whole copies of the positives to reach
                # the number of negatives. DataFrame.append no longer exists in
                # pandas 2.x, hence pd.concat.
                repeats = max(1, -(-len(neg_df) // len(pos_df)))
                pos_df = pd.concat([pos_df] * repeats, ignore_index=True).iloc[:len(neg_df)]
                balanced = pd.concat([neg_df, pos_df], ignore_index=True)
            else:
                # Nothing to oversample from; keep the negatives instead of looping forever.
                balanced = neg_df.reset_index(drop=True)
            # reset_index() (without drop) is kept so the frame has the same columns as before.
            self.df = balanced.reset_index()
            self.df = self.df.sample(frac=1., random_state=1234)
        print(self.df.describe())

    def __getitem__(self, item):
        """Return ``(sentence, label, w_index, label)`` for the row at position ``item``."""
        row = self.df.iloc[item]
        return row['sentence'], row['label'], row['w_index'], row['label']

    def __len__(self):
        """Number of rows after optional balancing."""
        return len(self.df)

In [None]:
# Build the dataset from the default (training) URL; the constructor prints a summary of the frame.
ds = MelbertDataset()

             level_0          label        w_index
count  281966.000000  281966.000000  281966.000000
mean   140982.500000       0.500000      13.233748
std     81396.717339       0.500001      12.281684
min         0.000000       0.000000       0.000000
25%     70491.250000       0.000000       4.000000
50%    140982.500000       0.500000      10.000000
75%    211473.750000       1.000000      19.000000
max    281965.000000       1.000000     108.000000


In [None]:
# Peek at one sample: (sentence, label, w_index, label).
ds[52]

('But in every other conceivable way that action was entirely wrong .',
 0,
 9,
 0)

## Train

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import classification_report

In [None]:
def train(model, epochs=3, train_batch_size=32, valid_batch_size=4, lr=2e-5):
    """Fine-tune ``model`` on the training split and evaluate it after every epoch.

    Args:
        model: module exposing ``roberta_tokenizer`` and ``device`` and whose call
            returns ``(output, loss)``.
        epochs: number of passes over the training split.
        train_batch_size: batch size of the training loader.
        valid_batch_size: batch size of the evaluation loader.
        lr: learning rate passed to AdamW.

    After each epoch the model is evaluated on the ``test.tsv`` split, a
    classification report is printed, and the weights are kept as the best
    checkpoint if the mean evaluation loss improved; if it got worse, the best
    weights are restored. The model object is then saved to
    ``model_<epoch>_<mean evaluation loss>.pt``.

    NOTE(review): checkpoint selection uses the test split, so the reported
    scores are optimistic; a separate validation file would be preferable.
    """
    # print('seed ==', torch.random.initial_seed())
    # seed == 15019865508169630983
    torch.random.manual_seed(15019865508169630983)

    model.train()

    dataset = MelbertDataset(csv_path='https://www.dropbox.com/s/1j2c13i4wlz647k/train.tsv?dl=1')
    dataloader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True)
    valid_dataset = MelbertDataset(csv_path='https://www.dropbox.com/s/tpqjd5xt8cb3p9e/test.tsv?dl=1')
    valid_dataloader = DataLoader(valid_dataset, batch_size=valid_batch_size)

    optimizer = optim.AdamW(model.parameters(), lr=lr)

    def snapshot_weights():
        # state_dict() hands out references to the live tensors, so an explicit
        # copy is needed; it is kept on the CPU to avoid doubling GPU memory.
        return {name: tensor.detach().to('cpu', copy=True)
                for name, tensor in model.state_dict().items()}

    best_state_dict = None
    best_valid_loss = None

    metaphor = model.roberta_tokenizer.encode(' metaphor', add_special_tokens=False)[0]
    literal = model.roberta_tokenizer.encode(' literal', add_special_tokens=False)[0]

    for epoch in range(epochs):
        running_loss = 0.
        # Train on the training split (not on the evaluation loader).
        for b, batch in enumerate(dataloader):
            X, y, w_indices, _ = batch
            y = [metaphor if y_ == 1 else literal for y_ in y]

            output, loss = model(X, w_indices,
                                 labels=torch.tensor(y).to(model.device),
                                 add_random_mask=True)
            running_loss += loss.item()

            print(f'Epoch {epoch} Batch {b} Loss {loss.item()} Running loss: {running_loss / (b + 1)}')

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # validate
        model.eval()
        y_true = []
        y_pred = []
        valid_running_loss = 0.
        valid_batches = 0
        for b, batch in enumerate(valid_dataloader):
            X, y, w_indices, _ = batch
            y = [metaphor if y_ == 1 else literal for y_ in y]
            y_true.extend(y)

            with torch.no_grad():
                output, valid_loss = model(X, w_indices,
                                           labels=torch.tensor(y).to(model.device),
                                           add_random_mask=False,
                                           predict=True)
                y_pred.extend(torch.argmax(output, dim=-1).flatten().cpu().tolist())

            valid_running_loss += valid_loss.item()
            valid_batches = b + 1

            print(f'VALID Epoch {epoch} Batch {b} Loss {valid_loss.item()} Running loss: {valid_running_loss / (b + 1)}')

        # Any predicted token other than the metaphor id counts as "no metaphor".
        y_true = [1 if yt == metaphor else 0 for yt in y_true]
        y_pred = [1 if yt == metaphor else 0 for yt in y_pred]
        print(classification_report(np.array(y_true),
                                    np.array(y_pred),
                                    target_names=['no_metaphor', 'metaphor'],
                                    digits=5))

        # Select the checkpoint on the mean evaluation loss of this epoch.
        mean_valid_loss = valid_running_loss / max(valid_batches, 1)
        if (best_valid_loss is None) or (mean_valid_loss < best_valid_loss):
            best_state_dict = snapshot_weights()
            best_valid_loss = mean_valid_loss
        elif mean_valid_loss > best_valid_loss:
            # This epoch made things worse: continue from the best weights so far.
            model.load_state_dict(best_state_dict)

        # Saved after the possible rollback, so the file holds the weights that
        # training continues from; the file name carries this epoch's loss.
        torch.save(model, f'model_{epoch}_{mean_valid_loss}.pt')

        model.train()

In [None]:
# Build the model, move it to the device it selected for itself, and start training.
model = Model()
model = model.to(model.device)
# train_batch_size=4 * 16 evaluates to 64.
train(model, epochs=3, train_batch_size=4 * 16, valid_batch_size=4, lr=1e-5)

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/482 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

             level_0          label        w_index
count  281966.000000  281966.000000  281966.000000
mean   140982.500000       0.500000      13.233748
std     81396.717339       0.500001      12.281684
min         0.000000       0.000000       0.000000
25%     70491.250000       0.000000       4.000000
50%    140982.500000       0.500000      10.000000
75%    211473.750000       1.000000      19.000000
max    281965.000000       1.000000     108.000000
              label       w_index
count  22196.000000  22196.000000
mean       0.179402     13.605289
std        0.383697     13.172939
min        0.000000      0.000000
25%        0.000000      4.000000
50%        0.000000     10.000000
75%        0.000000     19.000000
max        1.000000    114.000000
Epoch 0 Batch 0 Loss 6.040916919708252 Running loss: 6.040916919708252
Epoch 0 Batch 1 Loss 5.229048728942871 Running loss: 5.6349828243255615
Epoch 0 Batch 2 Loss 0.8456472158432007 Running loss: 4.038537621498108
Epoch 0 Batch 3 Loss

KeyboardInterrupt: ignored

# IGNORE WHAT FOLLOWS FOR TESTING (WAS INQUIRING INTO ATTENTION PATTERNS FOR METAPHOR)

In [None]:
## Experimentation

In [None]:
# Load the RoBERTa-large tokenizer and the bare encoder for the attention experiments.
# NOTE: this rebinds the global name `model`, replacing the `Model` instance created in the training cell.
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-large')
model = RobertaModel.from_pretrained('roberta-large')

In [None]:
# Inspect how the tokenizer encodes the two words; presumably the tuple is treated as one text pair (confirm against the tokenizer docs).
tokenizer([('trope', 'metaphor')])

In [None]:
# Alternative (longer) example sentence, kept for reference:
# tokenized = tokenizer('Latest corporate unbundler reveals laid-back approach : Roland Franklin , who is leading a 697m pound break-up bid for DRG , talks to Frank Kane', return_tensors='pt')
tokenized = tokenizer('The jury went out of the room to the adjacent room.', return_tensors='pt')
# Forward pass without gradient tracking, asking the model to also return its attention tensors.
with torch.no_grad():
    outputs = model(**tokenized, output_attentions=True)

In [None]:
# Shape of every attention tensor returned by the model.
[a.shape for a in outputs.attentions]

In [None]:
# Shape of the first element of the last attention tensor.
outputs.attentions[-1][0].shape

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [None]:
# Put one attention matrix into a DataFrame whose rows and columns are labelled with the decoded input tokens.
# NOTE(review): the indexing here is [0][-1][0], whereas the shape probe above uses [-1][0];
# confirm which entry of `outputs.attentions` is meant to be plotted.
df = pd.DataFrame(data=outputs.attentions[0][-1][0].tolist(),
                  columns=[tokenizer.decode([t]) for t in tokenized['input_ids'][0]])
# Use the same decoded tokens as the row index.
df['labels'] = [tokenizer.decode([t]) for t in tokenized['input_ids'][0]]
df = df.set_index('labels')

In [None]:
# Display the token-by-token attention table.
df

In [None]:
# Large canvas so that the per-cell annotations stay legible.
plt.figure(figsize=(24, 18))
# Heatmap of the attention table with each value printed to two decimals.
sns.heatmap(df, annot=True, fmt=".2f")