In [1]:
%load_ext autoreload
%autoreload 2

### Markdown
- Understand gathering in last layer
- What is my baseline?
- What is the expected output?
- Are there valid metrics?

In [2]:
from pathlib import Path
import torch
import pickle
from transformer.datasets import get_specialized_vocabulary, GrammarDataset
from torch.utils.data import Dataset, DataLoader
from dotted_dict import DottedDict
import torch.optim as optim
import torch.nn as nn
from transformer.utils import count_parameters
from tqdm import tqdm
import pandas as pd

In [3]:
config = DottedDict()
config.n_vis = 16
config.batch_size = 512     
config.pred_min = 1      # min number of masked tokens [MSK]
config.pred_max = 1      # max number of masked tokens
config.pred_freq = 0.15  # number of mask tokens = pred_freq * d_l
config.d_model = 8       # embed. dimension of tokens and positions
config.d_k = 256           
config.d_q = 256
config.d_v = config.d_model
config.d_ff = 4 * config.d_model
config.n_heads = 16       # number of attention heads
config.d_sentence = 32          # number of tokens in sentence
config.n_layers = 8
config.device = "cuda:0"
config.p_data = Path("data") / "grammar-00.pkl"
config.n_epochs = 10
config.lr = 0.001
#
config.freqs = DottedDict()
config.freqs.print_valid_preds = 318 * 1

### Load Data

In [4]:
with open(config.p_data, "rb") as file:
    data = pickle.load(file)

In [5]:
data_train = data["data_train"]
data_valid = data["data_valid"]

In [6]:
tok_dict = get_specialized_vocabulary(data["vocabulary"])
print(tok_dict)

{'[PAD]': 0, '[MSK]': 1, '[CLS]': 2, 'b': 3, ')': 4, 'c': 5, '[': 6, 'a': 7, ']': 8, '(': 9, 'e': 10, 'd': 11}


In [7]:
len(tok_dict)

12

In [8]:
ds_train = GrammarDataset(data["data_train"], tok_dict, d_sentence=config.d_sentence)
ds_valid = GrammarDataset(data["data_valid"], tok_dict, d_sentence=config.d_sentence)

In [9]:
dl_train = DataLoader(ds_train, batch_size=config.batch_size, shuffle=True, num_workers=8)
dl_valid = DataLoader(ds_valid, batch_size=config.batch_size, shuffle=False, num_workers=8)

### Create Model

In [10]:
import torch
import torch.nn as nn
from transformer.layers import Embedding, AttentionEncoder
from transformer.utils import get_attn_mask

In [11]:
class BERT(nn.Module):
    def __init__(
        self, d_vocab: int, d_model: int, d_sentence: int,
        n_layers, n_heads, d_k, d_v, d_ff
    ):
        super(BERT, self).__init__()
        #
        self.d_vocab = d_vocab
        self.d_model = d_model
        self.d_sentence = d_sentence
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.d_ff = d_ff
        #
        assert self.d_v == self.d_model # not optimal but hey ...
        
        # Input Embeddings
        self.embedding = Embedding(d_vocab, d_model, d_sentence)
        
        # Attention Layers
        self.layers = []
        for _ in range(n_layers):
            layer = AttentionEncoder(d_model, d_k, d_v, n_heads, d_ff)
            self.layers.append(layer)
        self.layers = nn.ModuleList(self.layers)
        
        # Output Head
        self.norm = nn.LayerNorm(d_model)
        self.linear = nn.Linear(d_model, d_model)
        self.gelu = torch.nn.GELU()
        
        # Output Decoder
        #  = inverse Embedding
        # There might be a better solution
        self.decoder = nn.Linear(d_model, d_vocab)
        self.decoder.weight = self.embedding.tok_emb.weight
        self.decoder_bias = nn.Parameter(torch.zeros(d_vocab))
    
    
    def forward(self, input_ids, input_mask_pos):
        mask = get_attn_mask(input_ids)
        out = self.embedding(input_ids)
        for layer in self.layers:
            out, attn = layer(out, mask)
        
        # [b, max_pred, d_model]
        masked_pos = input_mask_pos[:, :, None].expand(-1, -1, out.size(-1))
        h_masked = torch.gather(out, 1, masked_pos)
        h_masked = self.norm(self.gelu(self.linear(h_masked)))
        #
        logits = self.decoder(h_masked) + self.decoder_bias
        
        return logits

In [12]:
model = BERT(d_vocab=len(tok_dict),
             d_model=config.d_model,
             d_sentence=config.d_sentence,
             n_layers=config.n_layers,
             n_heads=config.n_heads,
             d_k=config.d_k,
             d_v=config.d_v,
             d_ff=config.d_ff)

In [13]:
model = model.to(config.device)

In [14]:
print("#Params: {:,}".format(count_parameters(model)))

#Params: 612,320


In [15]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.lr)

### Vis data

In [16]:
vis_data = next(iter(dl_valid))
tok_list_vis, mask_idcs_vis, mask_toks_vis = next(iter(dl_valid))
tok_list_vis = tok_list_vis[:config.n_vis]
mask_idcs_vis = mask_idcs_vis[:config.n_vis]
mask_toks_vis = mask_toks_vis[:config.n_vis]

In [17]:
logits = model(tok_list_vis.to(config.device), mask_idcs_vis.to(config.device))
preds_vis = logits.argmax(axis=2)

In [18]:
def get_verbose_output(tok_list, mask_toks, preds, ds):
    #
    all_sentences = []
    all_labels = []
    all_predictions = []
    #
    for idx in range(preds.size(0)):
        sentence = [ds.idx_dict[tok_id.item()] for tok_id in tok_list[idx] if tok_id.item() not in (0, 2)]
        sentence = "".join(sentence)
        all_sentences.append(sentence)
        #
        label = ds.idx_dict[mask_toks[idx].item()]
        pred = ds.idx_dict[preds[idx].item()]

        all_labels.append(label)
        all_predictions.append(pred)

    return all_sentences, all_labels, all_predictions

In [19]:
all_sentence, all_labels, all_preds = get_verbose_output(tok_list_vis, mask_toks_vis, preds_vis, ds_valid)
df = pd.DataFrame({'input': all_sentence, 'label': all_labels, 'pred': all_preds})
print(df)

                                  input label pred
0                      (ab)[a[MSK]babb]     a    b
1                  ((ab))[MSK]((baba)))     (    b
2              (baabababbababaabb[MSK])     a    b
3   [ccbaabbae](([MSK][bbabaa]]))(cabd)     [    b
4   [[(ccababababe)][(ab[MSK]]((abab))]     )    b
5              ((ab[MSK]a)[([[cbad]])])     b    b
6                    [MSK]cbad)[abbaab]     (    b
7          (([cbbaabbb[MSK]aabbaabad]))     a    b
8           ([ab][MSK][(ba)]([ccababe])     )    b
9        [[MSK]ccaababbaabbe)(ccababe)]     (    b
10                 [MSK]([aabb]))(baba)     (    b
11           (cbb[MSK]babaaabbbabaabad)     a    b
12           ((((ba))))([[[(ba)[MSK]]])     ]    b
13        [(((ccabababbababaab[MSK])))]     e    b
14             [(cabd)([b[MSK]bababa])]     a    b
15         ([cabbaabd])[cb[MSK]abaabad]     b    b


### Train

In [20]:
global_step = 0.
model = model.to(config.device)
for epoch in range(config.n_epochs):
    step, losses = 0, 0
    p_bar = tqdm(dl_train, desc=f"Train {epoch}")
    for tok_list, mask_idcs, mask_toks in p_bar:
        tok_list = tok_list.to(config.device)
        mask_toks = mask_toks.to(config.device)
        mask_idcs = mask_idcs.to(config.device)
        optimizer.zero_grad()
        logits = model(tok_list, mask_idcs)
        loss = criterion(logits.transpose(1, 2), mask_toks) # for masked LM
        loss.backward()
        optimizer.step()
        step += 1
        global_step +=1
        losses += loss.item()
        p_bar.set_postfix({'loss': losses / step})
        
        if global_step % config.freqs.print_valid_preds == 0:
            with torch.no_grad():
                logits = model(tok_list_vis.to(config.device), mask_idcs_vis.to(config.device))
                preds_vis = logits.argmax(axis=2).cpu()
            all_sentence, all_labels, all_preds = get_verbose_output(tok_list_vis, mask_toks_vis, preds_vis, ds_valid)
            df = pd.DataFrame({'input': all_sentence, 'label': all_labels, 'pred': all_preds})
            print(df)

Train 0: 100%|██████████| 318/318 [00:56<00:00,  5.59it/s, loss=2.07]
Train 1:   0%|          | 0/318 [00:00<?, ?it/s]

                                  input label pred
0                      (ab)[a[MSK]babb]     a    a
1                  ((ab))[MSK]((baba)))     (    a
2              (baabababbababaabb[MSK])     a    a
3   [ccbaabbae](([MSK][bbabaa]]))(cabd)     [    a
4   [[(ccababababe)][(ab[MSK]]((abab))]     )    a
5              ((ab[MSK]a)[([[cbad]])])     b    a
6                    [MSK]cbad)[abbaab]     (    (
7          (([cbbaabbb[MSK]aabbaabad]))     a    a
8           ([ab][MSK][(ba)]([ccababe])     )    a
9        [[MSK]ccaababbaabbe)(ccababe)]     (    [
10                 [MSK]([aabb]))(baba)     (    [
11           (cbb[MSK]babaaabbbabaabad)     a    b
12           ((((ba))))([[[(ba)[MSK]]])     ]    ]
13        [(((ccabababbababaab[MSK])))]     e    ]
14             [(cabd)([b[MSK]bababa])]     a    )
15         ([cabbaabd])[cb[MSK]abaabad]     b    a


Train 1: 100%|██████████| 318/318 [00:57<00:00,  5.51it/s, loss=1.42]
Train 2:   0%|          | 0/318 [00:00<?, ?it/s]

                                  input label pred
0                      (ab)[a[MSK]babb]     a    )
1                  ((ab))[MSK]((baba)))     (    [
2              (baabababbababaabb[MSK])     a    a
3   [ccbaabbae](([MSK][bbabaa]]))(cabd)     [    [
4   [[(ccababababe)][(ab[MSK]]((abab))]     )    a
5              ((ab[MSK]a)[([[cbad]])])     b    b
6                    [MSK]cbad)[abbaab]     (    [
7          (([cbbaabbb[MSK]aabbaabad]))     a    a
8           ([ab][MSK][(ba)]([ccababe])     )    ]
9        [[MSK]ccaababbaabbe)(ccababe)]     (    [
10                 [MSK]([aabb]))(baba)     (    [
11           (cbb[MSK]babaaabbbabaabad)     a    a
12           ((((ba))))([[[(ba)[MSK]]])     ]    ]
13        [(((ccabababbababaab[MSK])))]     e    b
14             [(cabd)([b[MSK]bababa])]     a    a
15         ([cabbaabd])[cb[MSK]abaabad]     b    b


Train 2: 100%|██████████| 318/318 [00:57<00:00,  5.49it/s, loss=1.01]
Train 3:   0%|          | 0/318 [00:00<?, ?it/s]

                                  input label pred
0                      (ab)[a[MSK]babb]     a    a
1                  ((ab))[MSK]((baba)))     (    [
2              (baabababbababaabb[MSK])     a    a
3   [ccbaabbae](([MSK][bbabaa]]))(cabd)     [    b
4   [[(ccababababe)][(ab[MSK]]((abab))]     )    )
5              ((ab[MSK]a)[([[cbad]])])     b    b
6                    [MSK]cbad)[abbaab]     (    (
7          (([cbbaabbb[MSK]aabbaabad]))     a    a
8           ([ab][MSK][(ba)]([ccababe])     )    )
9        [[MSK]ccaababbaabbe)(ccababe)]     (    [
10                 [MSK]([aabb]))(baba)     (    (
11           (cbb[MSK]babaaabbbabaabad)     a    a
12           ((((ba))))([[[(ba)[MSK]]])     ]    ]
13        [(((ccabababbababaab[MSK])))]     e    b
14             [(cabd)([b[MSK]bababa])]     a    a
15         ([cabbaabd])[cb[MSK]abaabad]     b    b


Train 3: 100%|██████████| 318/318 [00:58<00:00,  5.47it/s, loss=0.738]
Train 4:   0%|          | 0/318 [00:00<?, ?it/s]

                                  input label pred
0                      (ab)[a[MSK]babb]     a    a
1                  ((ab))[MSK]((baba)))     (    (
2              (baabababbababaabb[MSK])     a    a
3   [ccbaabbae](([MSK][bbabaa]]))(cabd)     [    [
4   [[(ccababababe)][(ab[MSK]]((abab))]     )    ]
5              ((ab[MSK]a)[([[cbad]])])     b    b
6                    [MSK]cbad)[abbaab]     (    (
7          (([cbbaabbb[MSK]aabbaabad]))     a    a
8           ([ab][MSK][(ba)]([ccababe])     )    ]
9        [[MSK]ccaababbaabbe)(ccababe)]     (    [
10                 [MSK]([aabb]))(baba)     (    (
11           (cbb[MSK]babaaabbbabaabad)     a    a
12           ((((ba))))([[[(ba)[MSK]]])     ]    ]
13        [(((ccabababbababaab[MSK])))]     e    e
14             [(cabd)([b[MSK]bababa])]     a    a
15         ([cabbaabd])[cb[MSK]abaabad]     b    b


Train 4: 100%|██████████| 318/318 [00:57<00:00,  5.49it/s, loss=0.53] 
Train 5:   0%|          | 0/318 [00:00<?, ?it/s]

                                  input label pred
0                      (ab)[a[MSK]babb]     a    a
1                  ((ab))[MSK]((baba)))     (    (
2              (baabababbababaabb[MSK])     a    a
3   [ccbaabbae](([MSK][bbabaa]]))(cabd)     [    e
4   [[(ccababababe)][(ab[MSK]]((abab))]     )    ]
5              ((ab[MSK]a)[([[cbad]])])     b    b
6                    [MSK]cbad)[abbaab]     (    (
7          (([cbbaabbb[MSK]aabbaabad]))     a    a
8           ([ab][MSK][(ba)]([ccababe])     )    )
9        [[MSK]ccaababbaabbe)(ccababe)]     (    (
10                 [MSK]([aabb]))(baba)     (    (
11           (cbb[MSK]babaaabbbabaabad)     a    a
12           ((((ba))))([[[(ba)[MSK]]])     ]    ]
13        [(((ccabababbababaab[MSK])))]     e    e
14             [(cabd)([b[MSK]bababa])]     a    a
15         ([cabbaabd])[cb[MSK]abaabad]     b    b


Train 5: 100%|██████████| 318/318 [00:57<00:00,  5.50it/s, loss=0.417]
Train 6:   0%|          | 0/318 [00:00<?, ?it/s]

                                  input label pred
0                      (ab)[a[MSK]babb]     a    a
1                  ((ab))[MSK]((baba)))     (    (
2              (baabababbababaabb[MSK])     a    a
3   [ccbaabbae](([MSK][bbabaa]]))(cabd)     [    e
4   [[(ccababababe)][(ab[MSK]]((abab))]     )    )
5              ((ab[MSK]a)[([[cbad]])])     b    b
6                    [MSK]cbad)[abbaab]     (    (
7          (([cbbaabbb[MSK]aabbaabad]))     a    a
8           ([ab][MSK][(ba)]([ccababe])     )    )
9        [[MSK]ccaababbaabbe)(ccababe)]     (    (
10                 [MSK]([aabb]))(baba)     (    (
11           (cbb[MSK]babaaabbbabaabad)     a    a
12           ((((ba))))([[[(ba)[MSK]]])     ]    ]
13        [(((ccabababbababaab[MSK])))]     e    b
14             [(cabd)([b[MSK]bababa])]     a    a
15         ([cabbaabd])[cb[MSK]abaabad]     b    b


Train 6: 100%|██████████| 318/318 [00:57<00:00,  5.49it/s, loss=0.333]
Train 7:   0%|          | 0/318 [00:00<?, ?it/s]

                                  input label pred
0                      (ab)[a[MSK]babb]     a    a
1                  ((ab))[MSK]((baba)))     (    (
2              (baabababbababaabb[MSK])     a    a
3   [ccbaabbae](([MSK][bbabaa]]))(cabd)     [    [
4   [[(ccababababe)][(ab[MSK]]((abab))]     )    )
5              ((ab[MSK]a)[([[cbad]])])     b    b
6                    [MSK]cbad)[abbaab]     (    (
7          (([cbbaabbb[MSK]aabbaabad]))     a    a
8           ([ab][MSK][(ba)]([ccababe])     )    )
9        [[MSK]ccaababbaabbe)(ccababe)]     (    (
10                 [MSK]([aabb]))(baba)     (    (
11           (cbb[MSK]babaaabbbabaabad)     a    a
12           ((((ba))))([[[(ba)[MSK]]])     ]    ]
13        [(((ccabababbababaab[MSK])))]     e    e
14             [(cabd)([b[MSK]bababa])]     a    a
15         ([cabbaabd])[cb[MSK]abaabad]     b    b


Train 7: 100%|██████████| 318/318 [00:57<00:00,  5.55it/s, loss=0.266]
Train 8:   0%|          | 0/318 [00:00<?, ?it/s]

                                  input label pred
0                      (ab)[a[MSK]babb]     a    a
1                  ((ab))[MSK]((baba)))     (    (
2              (baabababbababaabb[MSK])     a    a
3   [ccbaabbae](([MSK][bbabaa]]))(cabd)     [    e
4   [[(ccababababe)][(ab[MSK]]((abab))]     )    )
5              ((ab[MSK]a)[([[cbad]])])     b    b
6                    [MSK]cbad)[abbaab]     (    (
7          (([cbbaabbb[MSK]aabbaabad]))     a    a
8           ([ab][MSK][(ba)]([ccababe])     )    )
9        [[MSK]ccaababbaabbe)(ccababe)]     (    (
10                 [MSK]([aabb]))(baba)     (    (
11           (cbb[MSK]babaaabbbabaabad)     a    a
12           ((((ba))))([[[(ba)[MSK]]])     ]    ]
13        [(((ccabababbababaab[MSK])))]     e    e
14             [(cabd)([b[MSK]bababa])]     a    a
15         ([cabbaabd])[cb[MSK]abaabad]     b    b


Train 8: 100%|██████████| 318/318 [00:57<00:00,  5.50it/s, loss=0.216]
Train 9:   0%|          | 0/318 [00:00<?, ?it/s]

                                  input label pred
0                      (ab)[a[MSK]babb]     a    a
1                  ((ab))[MSK]((baba)))     (    (
2              (baabababbababaabb[MSK])     a    a
3   [ccbaabbae](([MSK][bbabaa]]))(cabd)     [    e
4   [[(ccababababe)][(ab[MSK]]((abab))]     )    )
5              ((ab[MSK]a)[([[cbad]])])     b    b
6                    [MSK]cbad)[abbaab]     (    (
7          (([cbbaabbb[MSK]aabbaabad]))     a    a
8           ([ab][MSK][(ba)]([ccababe])     )    )
9        [[MSK]ccaababbaabbe)(ccababe)]     (    (
10                 [MSK]([aabb]))(baba)     (    (
11           (cbb[MSK]babaaabbbabaabad)     a    a
12           ((((ba))))([[[(ba)[MSK]]])     ]    ]
13        [(((ccabababbababaab[MSK])))]     e    e
14             [(cabd)([b[MSK]bababa])]     a    a
15         ([cabbaabd])[cb[MSK]abaabad]     b    b


Train 9: 100%|██████████| 318/318 [00:57<00:00,  5.50it/s, loss=0.193]

                                  input label pred
0                      (ab)[a[MSK]babb]     a    a
1                  ((ab))[MSK]((baba)))     (    (
2              (baabababbababaabb[MSK])     a    a
3   [ccbaabbae](([MSK][bbabaa]]))(cabd)     [    e
4   [[(ccababababe)][(ab[MSK]]((abab))]     )    )
5              ((ab[MSK]a)[([[cbad]])])     b    b
6                    [MSK]cbad)[abbaab]     (    (
7          (([cbbaabbb[MSK]aabbaabad]))     a    a
8           ([ab][MSK][(ba)]([ccababe])     )    )
9        [[MSK]ccaababbaabbe)(ccababe)]     (    (
10                 [MSK]([aabb]))(baba)     (    (
11           (cbb[MSK]babaaabbbabaabad)     a    a
12           ((((ba))))([[[(ba)[MSK]]])     ]    ]
13        [(((ccabababbababaab[MSK])))]     e    e
14             [(cabd)([b[MSK]bababa])]     a    a
15         ([cabbaabd])[cb[MSK]abaabad]     b    b



