In [None]:
%load_ext autoreload
%autoreload 2

### TODO
- Understand the gather over the masked positions after the last encoder layer
- Try a different loss?
- The padding mask looks strange — investigate
- What is my baseline?
- What is the expected output?

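### Resources
- BERT-pytorch reference implementation: https://github.com/codertimo/BERT-pytorch/
- The Annotated Transformer (Harvard NLP): https://nlp.seas.harvard.edu/2018/04/03/attention.html
- The Illustrated Transformer (Jay Alammar): https://jalammar.github.io/illustrated-transformer/
- How to code BERT using PyTorch (neptune.ai tutorial): https://neptune.ai/blog/how-to-code-bert-using-pytorch-tutorial
- BERT paper (Devlin et al., 2018): https://arxiv.org/abs/1810.04805
- Unmasking BERT: transformer model performance (neptune.ai): https://neptune.ai/blog/unmasking-bert-transformer-model-performance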

In [None]:
from pathlib import Path
import torch
import pickle
from transformer.datasets import get_specialized_vocabulary, GrammarDataset
from torch.utils.data import Dataset, DataLoader
from dotted_dict import DottedDict
import torch.optim as optim
import torch.nn as nn
from transformer.utils import count_parameters
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# All tunable hyper-parameters of this notebook live in this one cell.
config = DottedDict()
config.seed = 0                   # torch RNG seed for reproducible runs
config.n_vis = 16                 # number of validation samples shown during training
config.batch_size = 512
config.pred_min = 1               # min number of masked tokens [MSK]
config.pred_max = 1               # max number of masked tokens
config.pred_freq = 0.15           # number of mask tokens = pred_freq * d_l
config.d_model = 8                # embed. dimension of tokens and positions
config.d_k = 64
config.d_q = 64                   # NOTE(review): not passed to BERT below
config.d_v = config.d_model       # BERT asserts d_v == d_model
config.d_ff = 3 * config.d_model
config.n_heads = 4                # number of attention heads
config.d_sentence = 32            # number of tokens in sentence
config.n_layers = 2
# Fall back to CPU so the notebook still runs on machines without a GPU.
config.device = "cuda:0" if torch.cuda.is_available() else "cpu"
config.p_data = Path("data") / "grammar-00.pkl"
config.n_epochs = 50
config.lr = 0.001
#
config.freqs = DottedDict()
config.freqs.print_valid_preds = 318 * 1   # steps
config.freqs.eval = 1                      # epochs

# Seed torch early, before weights are initialised and loaders are iterated.
torch.manual_seed(config.seed)

### Load Data

In [None]:
# NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
data = pickle.loads(config.p_data.read_bytes())

In [None]:
# Raw train / validation splits as stored in the pickle.
data_train, data_valid = data["data_train"], data["data_valid"]

In [None]:
# Vocabulary handed to the datasets; its length is the model's vocabulary size.
vocabulary = data["vocabulary"]
tok_dict = get_specialized_vocabulary(vocabulary)
print(tok_dict)

In [None]:
# Vocabulary size; this is the `d_vocab` passed to BERT further down.
n_tokens = len(tok_dict)
n_tokens

In [None]:
# Wrap each raw split in a GrammarDataset sharing the same vocabulary and length.
ds_train, ds_valid = (
    GrammarDataset(split, tok_dict, d_sentence=config.d_sentence)
    for split in (data["data_train"], data["data_valid"])
)

In [None]:
# Training batches are reshuffled each epoch; validation order stays fixed.
shared_loader_kwargs = dict(batch_size=config.batch_size, num_workers=8)
dl_train = DataLoader(ds_train, shuffle=True, **shared_loader_kwargs)
dl_valid = DataLoader(ds_valid, shuffle=False, **shared_loader_kwargs)

### Create Model

In [None]:
import torch
import torch.nn as nn
from transformer.layers import Embedding, AttentionEncoder
from transformer.utils import get_attn_mask

In [None]:
class BERT(nn.Module):
    """Minimal BERT-style masked-language model.

    Token ids are embedded, passed through ``n_layers`` ``AttentionEncoder``
    layers, and the hidden states at the masked positions are projected back
    onto the vocabulary. The decoder weight is tied to the token embedding.

    Args:
        d_vocab: vocabulary size (number of output classes).
        d_model: embedding / hidden dimension.
        d_sentence: number of tokens per sentence, handed to ``Embedding``.
        n_layers: number of ``AttentionEncoder`` layers.
        n_heads: attention heads per layer.
        d_k: key dimension handed to each ``AttentionEncoder``.
        d_v: value dimension; must equal ``d_model``.
        d_ff: inner size handed to each ``AttentionEncoder``
            (presumably the feed-forward width -- TODO confirm).
    """

    def __init__(
        self, d_vocab: int, d_model: int, d_sentence: int,
        n_layers, n_heads, d_k, d_v, d_ff
    ):
        super().__init__()
        #
        self.d_vocab = d_vocab
        self.d_model = d_model
        self.d_sentence = d_sentence
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.d_ff = d_ff
        #
        assert self.d_v == self.d_model, "d_v must equal d_model"  # not optimal but hey ...

        # Input Embeddings
        self.embedding = Embedding(d_vocab, d_model, d_sentence)

        # Attention Layers
        self.layers = nn.ModuleList(
            [AttentionEncoder(d_model, d_k, d_v, n_heads, d_ff) for _ in range(n_layers)]
        )

        # Output Head
        self.norm = nn.LayerNorm(d_model)
        self.linear = nn.Linear(d_model, d_model)
        self.gelu = nn.GELU()

        # Output Decoder = inverse Embedding: its weight is tied to the token
        # embedding. bias=False so that `decoder_bias` is the single output
        # bias instead of two redundant bias vectors being summed.
        self.decoder = nn.Linear(d_model, d_vocab, bias=False)
        self.decoder.weight = self.embedding.tok_emb.weight
        self.decoder_bias = nn.Parameter(torch.zeros(d_vocab))

    def forward(self, input_ids, input_mask_pos):
        """Return vocabulary logits for the masked positions.

        Args:
            input_ids: token ids, presumably (batch, d_sentence) -- TODO confirm.
            input_mask_pos: sequence indices of the masked tokens,
                shape (batch, max_pred).

        Returns:
            Logits of shape (batch, max_pred, d_vocab).
        """
        mask = get_attn_mask(input_ids)
        out = self.embedding(input_ids)
        for layer in self.layers:
            out, _attn = layer(out, mask)

        # Gather the hidden state of every masked position: [b, max_pred, d_model]
        masked_pos = input_mask_pos[:, :, None].expand(-1, -1, out.size(-1))
        h_masked = torch.gather(out, 1, masked_pos)
        h_masked = self.norm(self.gelu(self.linear(h_masked)))

        # [b, max_pred, d_vocab]
        return self.decoder(h_masked) + self.decoder_bias

In [None]:
# Build the model from the hyper-parameters in `config`; the vocabulary size
# comes from `tok_dict`.
model_kwargs = {
    name: getattr(config, name)
    for name in ("d_model", "d_sentence", "n_layers", "n_heads", "d_k", "d_v", "d_ff")
}
model = BERT(d_vocab=len(tok_dict), **model_kwargs)

In [None]:
# Move parameters and buffers to the configured device (Module.to returns the module).
model = model.to(device=config.device)

In [None]:
# Number of parameters, with thousands separators.
print(f"#Params: {count_parameters(model):,}")

In [None]:
# Mean cross-entropy over the masked-token logits, optimised with Adam.
criterion = nn.CrossEntropyLoss(reduction="mean")
optimizer = optim.Adam(params=model.parameters(), lr=config.lr)

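### Visualization data
A fixed batch of `config.n_vis` validation samples, used to inspect the model's predictions before, during and after training.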

In [None]:
# Take one validation batch and keep its first `config.n_vis` samples.
# Reuse it instead of spinning up a second loader iterator for the same batch.
vis_data = next(iter(dl_valid))
tok_list_vis, mask_idcs_vis, mask_toks_vis = (t[:config.n_vis] for t in vis_data)

In [None]:
# Predictions of the untrained model on the visualisation batch. Inference
# only: eval mode and no autograd graph.
model.eval()
with torch.no_grad():
    logits = model(tok_list_vis.to(config.device), mask_idcs_vis.to(config.device))
preds_vis = logits.argmax(axis=2).cpu()

In [None]:
def get_verbose_output(tok_list, mask_toks, preds, grammar_ds, skip_ids=(0, 2)):
    """Turn a batch of token ids, labels and predictions into readable strings.

    Args:
        tok_list: input token ids, one row per sample.
        mask_toks: ground-truth token id(s) at the masked position(s), one
            row (or scalar) per sample.
        preds: predicted token id(s), same layout as ``mask_toks``.
        grammar_ds: any object with an ``idx_dict`` mapping token id -> token string.
        skip_ids: token ids left out of the rendered input sentence. Default
            ``(0, 2)``; presumably special tokens such as padding -- TODO
            confirm against the vocabulary.

    Returns:
        ``(sentences, labels, predictions)``: three lists of strings with one
        entry per sample. If a sample has several masked positions, their
        tokens are concatenated without separator (like the sentence).
    """
    idx_dict = grammar_ds.idx_dict
    skip = set(skip_ids)

    def _decode(ids):
        # Flatten so a scalar, a single masked slot and several slots all work.
        return "".join(idx_dict[i] for i in ids.reshape(-1).tolist())

    all_sentences = []
    all_labels = []
    all_predictions = []
    for idx in range(preds.size(0)):
        all_sentences.append(
            "".join(idx_dict[i] for i in tok_list[idx].tolist() if i not in skip)
        )
        all_labels.append(_decode(mask_toks[idx]))
        all_predictions.append(_decode(preds[idx]))

    return all_sentences, all_labels, all_predictions

In [None]:
# Tabulate the untrained model's predictions against the labels.
all_sentence, all_labels, all_preds = get_verbose_output(
    tok_list_vis, mask_toks_vis, preds_vis, ds_valid
)
right_pred = [pred == label for pred, label in zip(all_preds, all_labels)]
df = pd.DataFrame(
    dict(input=all_sentence, label=all_labels, pred=all_preds, match=right_pred)
)
print(df)

### Train

In [None]:
# Train for up to `config.n_epochs` epochs. Histories hold one mean value per
# epoch: `all_train_losses` for every epoch, `all_valid_losses` / `all_accs`
# only for evaluated epochs. Stops early once validation accuracy >= 0.99.
global_step = 0
model = model.to(config.device)
#
all_accs = []
all_train_losses = []
all_valid_losses = []
#
for epoch in range(config.n_epochs):
    # TRAIN LOOP
    model.train()
    step, losses = 0, 0.
    p_bar = tqdm(dl_train, desc=f"Train {epoch}")
    for tok_list, mask_idcs, mask_toks in p_bar:
        tok_list = tok_list.to(config.device)
        mask_toks = mask_toks.to(config.device)
        mask_idcs = mask_idcs.to(config.device)
        optimizer.zero_grad()
        logits = model(tok_list, mask_idcs)
        # CrossEntropyLoss wants the class axis second: [b, d_vocab, max_pred]
        loss = criterion(logits.transpose(1, 2), mask_toks)  # for masked LM
        loss.backward()
        optimizer.step()
        step += 1
        global_step += 1
        losses += loss.item()
        p_bar.set_postfix({'loss': losses / step})

        if global_step % config.freqs.print_valid_preds == 0:
            # Inspect the fixed visualisation batch in eval mode, then resume training.
            model.eval()
            with torch.no_grad():
                logits = model(tok_list_vis.to(config.device), mask_idcs_vis.to(config.device))
                preds_vis = logits.argmax(axis=2).cpu()
            model.train()
            all_sentence, all_labels, all_preds = get_verbose_output(tok_list_vis, mask_toks_vis, preds_vis, ds_valid)
            right_pred = [p == l for p, l in zip(all_preds, all_labels)]
            df = pd.DataFrame({'input': all_sentence, 'label': all_labels, 'pred': all_preds, 'match': right_pred})
            print(df)
    all_train_losses.append(losses / max(step, 1))

    # EVAL LOOP
    if epoch % config.freqs.eval == 0:
        model.eval()
        losses, accs, step = 0., 0., 0
        p_bar = tqdm(dl_valid, desc=f"Eval {epoch}")
        with torch.no_grad():
            for tok_list, mask_idcs, mask_toks in p_bar:
                tok_list = tok_list.to(config.device)
                mask_toks = mask_toks.to(config.device)
                mask_idcs = mask_idcs.to(config.device)
                logits = model(tok_list, mask_idcs)
                loss = criterion(logits.transpose(1, 2), mask_toks)  # for masked LM
                preds = logits.argmax(axis=2)
                # Fraction of correctly predicted masked slots in the batch (equals
                # per-sentence accuracy with one masked slot per sample).
                # NOTE(review): padded, unused slots would be counted too -- confirm.
                acc = (preds == mask_toks).float().mean()
                #
                losses += loss.item()
                accs += acc.item()
                step += 1
                p_bar.set_postfix({'loss': losses / step, 'acc': accs / step})
        n_batches = max(step, 1)
        all_valid_losses.append(losses / n_batches)
        all_accs.append(accs / n_batches)
        # Early stop -- only evaluated epochs have a fresh validation accuracy.
        if accs / n_batches >= 0.99:
            print("Solved")
            break

### Inspect Attention

In [None]:
# Predictions of the trained model on the visualisation batch.
model.eval()
with torch.no_grad():
    logits = model(tok_list_vis.to(config.device), mask_idcs_vis.to(config.device))
    preds_vis = logits.argmax(axis=2).cpu()
all_sentence, all_labels, all_preds = get_verbose_output(
    tok_list_vis, mask_toks_vis, preds_vis, ds_valid
)
right_pred = [pred == label for pred, label in zip(all_preds, all_labels)]
df = pd.DataFrame(
    dict(input=all_sentence, label=all_labels, pred=all_preds, match=right_pred)
)
print(df)

In [None]:
# Re-run the embedding and encoder stack by hand to collect each layer's
# attention weights for the visualisation batch.
input_ids = tok_list_vis.to(config.device)
mask = get_attn_mask(input_ids)
#
attentions = []
model.eval()  # don't depend on the mode an earlier cell left the model in
with torch.no_grad():
    emb = out = model.embedding(input_ids)
    for layer in model.layers:
        out, attn = layer(out, mask)
        attentions.append(attn.cpu())

In [None]:
# Plot the attention map of every head in every layer for one sample.
sample_idx = 0
sentence = all_sentence[sample_idx]
print(sentence)
# attention[sample_idx] is presumably (n_heads, L, L) -- TODO confirm against
# AttentionEncoder; squeeze() drops the head axis when there is a single head.
atts = [attention[sample_idx].squeeze() for attention in attentions]
#
xtick_labels = [ds_valid.idx_dict[i] for i in tok_list_vis[sample_idx].tolist()]
tick_positions = list(range(len(xtick_labels)))

n_rows = len(atts)
n_cols = config.n_heads
#
plt_scale = 12
#
figsize = (n_cols * plt_scale, n_rows * plt_scale)
# squeeze=False keeps `axes` 2-D even with a single layer or a single head.
fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
fig.suptitle(sentence)
for row_idx in range(n_rows):
    for col_idx in range(n_cols):
        # With a single head the squeeze above already removed the head axis.
        attn = atts[row_idx] if n_cols == 1 else atts[row_idx][col_idx]
        ax = axes[row_idx, col_idx]
        ax.imshow(attn)
        ax.set_title(f"layer {row_idx}, head {col_idx}")
        # Fix the tick positions first, then attach one label per token.
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(xtick_labels)
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(xtick_labels)
plt.show()