## BERT Self-Attention Model 

In [7]:
import pytorch_lightning as pl
import torch
from torch import FloatTensor, LongTensor
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple

from layers import SelfAttention

from sklearn.metrics import accuracy_score, f1_score

from transformers import BertModel, BertTokenizer, DistilBertTokenizer,\
    DistilBertModel

from typing import List

def get_lens(batch: torch.Tensor) -> torch.Tensor:
    """Count the strictly positive entries in each row of a 2-D batch.

    The notebook pads token-id sequences with 0, so for such a batch the
    count is the unpadded length of every sequence.

    Args:
        batch: 2-D tensor laid out as (rows, positions).

    Returns:
        A 1-D CPU integer tensor holding one count per row.
    """
    # Detach and move to CPU first so the result never lives on the GPU or
    # participates in autograd, matching what callers received before.
    batch = batch.detach().cpu()
    # Pure-torch count: no dependency on numpy being imported in this cell.
    return (batch > 0).sum(dim=1)

class BertAttentionClassifier(pl.LightningModule):
    """Sequence classifier: BERT encoder -> linear + ReLU -> self-attention -> linear.

    The encoder weights are loaded from the local ``bb_lm_ft/`` directory.
    Training minimises cross-entropy; validation and test additionally
    report accuracy and micro-averaged F1 per batch and average them over
    all batches.
    """

    def __init__(self,
                 num_classes: int):
        """Build the network.

        Args:
            num_classes: Size of the output layer (number of target labels).
        """
        super().__init__()

        self.bert = BertModel.from_pretrained('bb_lm_ft/')
        self.num_classes = num_classes
        # Project BERT's hidden size down to 256 before attention pooling.
        self.linear1 = nn.Linear(self.bert.config.hidden_size, 256)
        self.self_attention = SelfAttention(256, batch_first=True, non_linearity="tanh")
        self.out = nn.Linear(256, num_classes)

    def forward(self,
                input_ids: torch.Tensor,
                sent_lengths: List[int]):
        """Compute class logits for a batch of padded token-id sequences.

        Args:
            input_ids: Token ids passed straight to BERT.
            sent_lengths: Per-sequence lengths forwarded to the
                self-attention layer (callers in this file pass the result
                of ``get_lens``).

        Returns:
            ``(logits, attn)`` where ``logits`` is the output of the final
            linear layer and ``attn`` is the second element of BERT's output.
        """
        # NOTE(review): no attention_mask is passed, so BERT also attends to
        # padding positions. Adding one would change what the saved
        # checkpoint was trained on -- confirm and retrain before changing.
        bert_outputs = self.bert(input_ids=input_ids)
        # Index instead of unpacking exactly two values: the output grows
        # when hidden states / attentions are enabled, and dict-like output
        # objects do not unpack into their values.
        h = bert_outputs[0]
        # NOTE(review): despite the name, element 1 of BertModel's output is
        # documented as the pooled output, not attention weights -- confirm.
        attn = bert_outputs[1]
        linear1 = F.relu(self.linear1(h))
        attention, _ = self.self_attention(linear1, sent_lengths)
        logits = self.out(attention)
        return logits, attn

    def training_step(self, batch, batch_idx):
        """Run one optimisation step and return the loss plus its log entry."""
        input_ids, labels = batch
        sent_lengths = get_lens(input_ids)

        y_hat, _ = self(input_ids, sent_lengths)
        loss = F.cross_entropy(y_hat, labels)

        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def _evaluate_batch(self, batch):
        """Score one held-out batch: loss, accuracy and micro-F1.

        Shared by ``validation_step`` and ``test_step``, which were
        identical copies of this logic.
        """
        input_ids, labels = batch
        sent_lengths = get_lens(input_ids)

        y_hat, _ = self(input_ids, sent_lengths)
        loss = F.cross_entropy(y_hat, labels)

        # Highest-scoring class per row; sklearn metrics need CPU data.
        _, predicted = torch.max(y_hat, dim=1)
        predicted = predicted.cpu()
        labels = labels.cpu()

        val_acc = torch.tensor(accuracy_score(labels, predicted))
        val_f1 = torch.tensor(f1_score(labels, predicted, average='micro'))

        return {'val_loss': loss, 'val_acc': val_acc, 'val_f1': val_f1}

    @staticmethod
    def _summarise_batches(outputs):
        """Average the per-batch metrics returned by ``_evaluate_batch``.

        Shared by ``validation_end`` and ``test_end``. The mean is taken
        over batches, not over individual examples.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_val_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        avg_val_f1 = torch.stack([x['val_f1'] for x in outputs]).mean()

        tensorboard_logs = {'val_loss': avg_loss, 'avg_val_acc': avg_val_acc, 'avg_val_f1': avg_val_f1}

        return {'avg_val_loss': avg_loss, 'avg_val_f1': avg_val_f1, 'progress_bar': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """Return loss, accuracy and micro-F1 for one validation batch."""
        return self._evaluate_batch(batch)

    def validation_end(self, outputs):
        """Aggregate the validation batch metrics for the progress bar."""
        return self._summarise_batches(outputs)

    def test_step(self, batch, batch_idx):
        """Return loss, accuracy and micro-F1 for one test batch.

        The result keys keep their ``val_`` prefix, exactly as before.
        """
        return self._evaluate_batch(batch)

    def test_end(self, outputs):
        """Aggregate the test batch metrics (keys keep their ``val_`` prefix)."""
        return self._summarise_batches(outputs)

    def configure_optimizers(self):
        """Adam over every parameter that currently requires gradients."""
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad],
                                lr=2e-05, eps=1e-08)

    # The loaders below are module-level globals that must be defined before
    # the trainer asks for them.

    @pl.data_loader
    def train_dataloader(self):
        """Return the global ``train_dataloader_``."""
        return train_dataloader_

    @pl.data_loader
    def val_dataloader(self):
        """Return the global ``val_dataloader_``."""
        return val_dataloader_

    @pl.data_loader
    def test_dataloader(self):
        """Return the global ``test_dataloader_``."""
        # NOTE(review): ``test_dataloader_`` is not defined in the visible
        # cells (only ``dev_dataloader_`` is) -- confirm before testing.
        return test_dataloader_

    

## Train Model

In [8]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np
from imblearn.over_sampling import RandomOverSampler
from sklearn.metrics import classification_report
from transformers import BertTokenizer, DistilBertTokenizer
from typing import List
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import TensorDataset, RandomSampler, DataLoader

# Load the task-2 data; the "source" column separates training rows from
# the rest.
dat = pd.read_csv("data/task_2_data.csv")
le = LabelEncoder()

# .copy() gives `train` its own data, so the column assignments below do
# not write into a slice of `dat` (pandas' SettingWithCopy hazard).
train = dat[dat["source"]=="train"].copy()
dev = dat[dat["source"]!="train"]

# Fit the encoder on the training labels and encode them in a single pass.
train["encoded_label"] = le.fit_transform(train["label"])
# Whitespace-delimited word count of each text.
train["num_words"] = train["text"].apply(lambda x: len(x.split()))

random_seed = 1956
tokenizer = BertTokenizer.from_pretrained('bb_lm_ft/')

# Hold out 15% of the training rows for validation, stratified by label.
train, val = train_test_split(train, test_size=.15,
                              stratify=train["encoded_label"],
                              random_state=random_seed)

INFO:transformers.tokenization_utils:Model name 'bb_lm_ft/' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base-finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased). Assuming 'bb_lm_ft/' is a path, a model identifier, or url to a directory containing tokenizer files.
INFO:transformers.tokenization_utils:Didn't find file bb_lm_ft/added_tokens.json. We won't load it.
INFO:transformers.tokenization_utils:loading file bb_lm_ft/vocab.txt
INFO:transformers.tokenization_utils:loading file None
INFO:transformers.tokenization_utils

In [5]:
BATCH_SIZE = 24

# Training split: encode each text to token ids, then pad with 0 so all
# rows share one length.
train_token_ids = [torch.tensor(tokenizer.encode(doc)) for doc in train["text"]]
X_train = pad_sequence(train_token_ids, padding_value=0, batch_first=True)
y_train = torch.tensor(train["encoded_label"].tolist())

# Validation split: same encoding and padding.
val_token_ids = [torch.tensor(tokenizer.encode(doc)) for doc in val["text"]]
X_val = pad_sequence(val_token_ids, padding_value=0, batch_first=True)
y_val = torch.tensor(val["encoded_label"].tolist())

# Randomly oversample the training split only, then convert the resampled
# features and labels back to tensors.
ros = RandomOverSampler(random_state=random_seed)
resampled_features, resampled_labels = ros.fit_resample(X_train, y_train)

X_train_resampled = torch.tensor(resampled_features)
y_train_resampled = torch.tensor(resampled_labels)

In [9]:
# Training loader: oversampled tensors drawn in random order.
train_dataset = TensorDataset(X_train_resampled, y_train_resampled)
train_sampler = RandomSampler(train_dataset)
train_dataloader_ = DataLoader(train_dataset,
                               batch_size=BATCH_SIZE,
                               sampler=train_sampler)

# Validation loader: also drawn in random order.
val_dataset = TensorDataset(X_val, y_val)
val_sampler = RandomSampler(val_dataset)
val_dataloader_ = DataLoader(val_dataset,
                             batch_size=BATCH_SIZE,
                             sampler=val_sampler)

# Dev loader: token ids only (no label tensor), iterated in dataset order.
dev_token_ids = [torch.tensor(tokenizer.encode(doc)) for doc in dev["text"]]
dev_ids = pad_sequence(dev_token_ids, padding_value=0, batch_first=True)

dev_dataset = TensorDataset(dev_ids)
dev_dataloader_ = DataLoader(dev_dataset, batch_size=BATCH_SIZE)


In [46]:
# Sanity check: print the tensor shapes of the first training batch only.
for ids, labels in train_dataloader_:
    print(ids.shape, labels.shape)
    break

torch.Size([24, 192]) torch.Size([24])


In [50]:
# model = BertAttentionClassifier(num_classes=14)

# trainer = pl.Trainer(gpus=1, 
#                      max_epochs=1,
#                      default_save_path=".pl_bert_sa_logs/")
# trainer.fit(model)

# print("finished training")
# print("saving model")
# torch.save(model.state_dict(), "bb_sa_lm.pt")

INFO:transformers.configuration_utils:loading configuration file bb_lm_ft/config.json
INFO:transformers.configuration_utils:Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "do_sample": false,
  "eos_token_ids": 0,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-12,
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_beams": 1,
  "num_hidden_layers": 12,
  "num_labels": 2,
  "num_return_sequences": 1,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "pad_token_id": 0,
  "pruned_heads": {},
  "repetition_pen

## Generate Predictions

In [15]:
from utils import generate_t2_submission
from tqdm import tqdm

model = BertAttentionClassifier(num_classes=14)
# map_location="cpu" lets a checkpoint saved from a GPU run load on a
# machine without CUDA; the model is used on CPU below anyway.
model.load_state_dict(torch.load("bb_sa_lm.pt", map_location="cpu"))
model.eval()
model.to("cpu")

all_preds = []
# Inference only: no_grad stops autograd from recording a graph for every
# batch, which saves memory and time.
with torch.no_grad():
    for batch in tqdm(dev_dataloader_):
        # The dev dataset holds a single tensor per example: the token ids.
        i = batch[0]
        sl = get_lens(i)

        preds, _ = model(i, sl)

        # Index of the highest-scoring class for each row.
        a, y_hat = torch.max(preds, dim=1)
        y_hat = y_hat.cpu()

        # Collect plain Python ints rather than 0-d tensors so the list can
        # be handed directly to the label encoder afterwards.
        all_preds.extend(y_hat.tolist())
print("finished")

INFO:transformers.configuration_utils:loading configuration file bb_lm_ft/config.json
INFO:transformers.configuration_utils:Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "do_sample": false,
  "eos_token_ids": 0,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-12,
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_beams": 1,
  "num_hidden_layers": 12,
  "num_labels": 2,
  "num_return_sequences": 1,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "pad_token_id": 0,
  "pruned_heads": {},
  "repetition_pen

In [27]:
from typing import List

def generate_t2_sub(preds: List[str],
                    template_path: str = "data/dev-task-TC-template.out") -> List[str]:
    """Fill the TC submission template with predicted labels.

    Line ``i`` of the template has every ``"?"`` replaced by the
    whitespace-stripped ``preds[i]``. Predictions beyond the number of
    template lines are ignored.

    Args:
        preds: One predicted label per template line, in template order.
        template_path: Template file to read. Defaults to the dev-set
            template used by this notebook.

    Returns:
        The template lines with placeholders filled in; each line keeps
        whatever line ending it had in the file.

    Raises:
        IndexError: If there are fewer predictions than template lines.
    """
    with open(template_path, "r") as f:
        lines = f.readlines()

    # Fail before doing any work, with a message that names the mismatch,
    # instead of a bare index error part-way through the loop.
    if len(preds) < len(lines):
        raise IndexError(
            f"got {len(preds)} predictions for {len(lines)} template lines"
        )

    return [line.replace("?", pred.strip()) for line, pred in zip(lines, preds)]

# Turn the predicted class indices back into the original label strings.
preds = le.inverse_transform(all_preds)

# Substitute the labels into the submission template, one per line.
lines = generate_t2_sub(preds)

# Strip surrounding whitespace so every record ends in exactly one newline.
with open("submissions/bb_sa_lm_preds_t2.txt", "w") as f:
    f.writelines(line.strip() + "\n" for line in lines)

In [21]:
preds

tensor([[-3.6668, -0.9791, -2.4421, -3.1846, -2.0220, -0.9879,  0.6122, -3.1605,
          7.8311,  0.3629,  1.0154, -1.6361, -1.6982, -3.3624],
        [-3.7051, -2.2085, -2.3973, -3.4145, -3.2419, -1.3998,  4.3762, -2.2037,
          5.4715,  2.0114,  0.6773, -2.4220, -2.7661, -2.4518],
        [-2.3218, -2.0301, -3.3197, -2.5521, -2.4829, -1.5881, -0.5952, -1.7798,
          1.9647,  7.0162,  2.1267, -1.9072, -3.2837, -1.6825],
        [-3.9650, -1.1671, -2.3636, -3.0723, -2.7098, -1.6617,  2.2528, -3.0903,
          7.4692,  0.9137,  0.7633, -1.6863, -2.5879, -3.0991],
        [-3.3977, -1.7206, -3.5181, -3.6897, -3.2311, -1.0395, -0.0298, -2.4877,
          6.3027,  2.2526,  3.9654, -1.3684, -2.7997, -3.3794],
        [-2.7979, -2.2817, -3.1292, -3.1799, -2.7487, -0.2284,  0.5271, -2.6144,
          4.4566,  5.5592,  1.7827, -2.8251, -2.8130, -2.8017],
        [-1.6488, -1.0604, -2.9936, -1.9140, -1.5754, -0.4984, -1.1813,  7.6305,
         -0.9371, -0.7844,  1.8205, -0.3798, -2.1