In [1]:
import os

# Run from the repository root so `config`, `src` and relative paths such as
# ./data and saved/ resolve when the notebook is opened from notebooks/.
if os.getcwd().endswith('notebooks'):
    os.chdir('..')

import argparse
import multiprocessing as mp
import random
from random import randint, shuffle
from random import random as rand

import numpy as np
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from torch.utils.data import Dataset, DataLoader

import src.models
import src.optim
import src.tokenization
import src.train
from src.data import seek_random_offset, SentPairDataset, Pipeline, Preprocess4Pretrain, seq_collate
from src.pretrain import Discriminator, Generator
from src.utils import set_seeds, get_device

from config import CONFIG as args

# Seed the RNGs the data pipeline may draw from (Python `random`, NumPy, torch) so the
# randomly masked examples inspected below are the same on every run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Training config, discriminator model config, and generator model config.
cfg = src.train.Config.from_json(args.train_cfg)
model_cfg = src.models.Config.from_json(args.model_cfg)
generator_cfg = src.models.Config.from_json(args.generator_cfg)

tokenizer = src.tokenization.FullTokenizer(vocab_file=args.vocab, do_lower_case=True)


def tokenize(x):
    """Convert `x` to unicode with the tokenizer's helper, then split it into tokens."""
    return tokenizer.tokenize(tokenizer.convert_to_unicode(x))


# Pretraining preprocessing (masking parameters come from the CONFIG args).
pipeline = [Preprocess4Pretrain(args.max_pred,
                                args.mask_prob,
                                list(tokenizer.vocab.keys()),
                                tokenizer.convert_tokens_to_ids,
                                model_cfg.max_len,
                                args.mask_alpha,
                                args.mask_beta,
                                args.max_gram)]
data_iter = DataLoader(SentPairDataset(args.data_file,
                                       cfg.batch_size,
                                       tokenize,
                                       model_cfg.max_len,
                                       pipeline=pipeline),
                       batch_size=cfg.batch_size,
                       collate_fn=seq_collate,
                       num_workers=mp.cpu_count())

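Each batch yielded by `data_iter` unpacks into eight fields. The snippet below is for reference only and is not executed here; running it requires `from tqdm.auto import tqdm`.
```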
for batch in tqdm(data_iter):
    input_ids, segment_ids, input_mask, masked_ids, masked_pos, masked_weights, is_next, original_ids = batch
```

In [121]:
# Build a dataset over ./data/wiki.test.tokens with the same pretraining pipeline
# (batch size 16, max length 400), then report its size and the field count of one example.
SPD = SentPairDataset(
    './data/wiki.test.tokens',
    16,
    tokenize,
    400,
    pipeline=pipeline,
)
n_sentences = len(SPD)
print(f"number_of_test_senteces: {n_sentences}")

one_obs = SPD[0]
n_fields = len(one_obs)
print(f"number of outputs per datapoint: {n_fields}")

number_of_test_senteces: 4358
number of outputs per datapoint: 8


In [122]:
# Unpack the eight fields of the first example.
(input_ids, segment_ids, input_mask,
 masked_ids, masked_pos, masked_weights,
 is_next, original_ids) = one_obs

### Suspicious feature #1
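Something looks off with `input_ids` and `masked_ids`: the `[MASK]` token id `103` appears frequently in `input_ids`, whereas we would have expected it to appear in `masked_ids`. Decoding both sequences below shows that `masked_ids` holds the original text, while `input_ids` holds the text with `[MASK]` tokens. (NOTE: in BERT-style masking the model input usually contains the `[MASK]` tokens and the masked ids are the prediction targets, so this may be the intended layout — to be confirmed against `Preprocess4Pretrain`.)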

In [67]:
# Look up the id of the [MASK] token, then show the first 15 (input_id, masked_id) pairs.
mask_token_id = tokenizer.convert_tokens_to_ids(['[MASK]'])
print(f"token_id of MASK token: {mask_token_id}\n")
pairs = list(zip(input_ids, masked_ids))
pairs[:15]

token_id of MASK token: [103]



[(101, 101),
 (102, 102),
 (2728, 2728),
 (1026, 1026),
 (4895, 4895),
 (2243, 2243),
 (1028, 1028),
 (2003, 2003),
 (2019, 2019),
 (2394, 2394),
 (2143, 2143),
 (103, 1010),
 (103, 2547),
 (1998, 1998),
 (3004, 3004)]

In [64]:
from src.tokenization import convert_ids_to_tokens
# Decode `masked_ids` back to text. Build the id -> token list once; calling
# list(tokenizer.vocab) inside the comprehension rebuilt the whole vocabulary list for
# every token. NOTE(review): this relies on vocab iteration order matching token ids.
id_to_token = list(tokenizer.vocab)
' '.join(id_to_token[token_id] for token_id in masked_ids)

'[CLS] [SEP] robert < un ##k > is an english film , television and theatre actor . he had a guest @ - @ starring role on the television series the bill in 2000 . this was followed by a starring role in the play heron ##s written by simon stephens , which was performed in 2001 at the royal court theatre . he had a guest role in the television series judge john [SEP]'

In [65]:
# Decode `input_ids` back to text. Build the id -> token list once instead of once per token.
# NOTE(review): this relies on vocab iteration order matching token ids.
id_to_token = list(tokenizer.vocab)
' '.join(id_to_token[token_id] for token_id in input_ids)

'[CLS] [SEP] robert < un ##k > is an english film [MASK] [MASK] and theatre actor [MASK] he had a guest [MASK] - @ starring role on the [MASK] series the bill in [MASK] . this was followed by a starring role in the play heron ##s written by simon stephens [MASK] [MASK] [MASK] performed in 2001 at the royal court theatre [MASK] he had a guest role in the television [MASK] judge john [SEP]'

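### Suspicious feature #2

The first segment appears to be empty. `[CLS]` (101) is immediately followed by `[SEP]` (102), and both have segment id 0. Every following token shown has segment id 1, so the text is apparently treated entirely as the second sentence.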

In [73]:
# First 8 (token id, segment id) pairs.
list(zip(input_ids[:8], segment_ids[:8]))

[(101, 0),
 (102, 0),
 (2728, 1),
 (1026, 1),
 (4895, 1),
 (2243, 1),
 (1028, 1),
 (2003, 1)]

### The backward step
The backward step involves two losses, `generator_loss` and `discriminator_loss`, which are summed before backpropagation.

### 1. Generator loss
`(input_ids, segment_ids, input_mask, masked_pos)` -> `model` -> `logits_lm, logits_clsf`

In [76]:
from src.pretrain import Generator
generator_cfg = src.models.Config.from_json(args.generator_cfg)
generator = Generator(generator_cfg)
# map_location='cpu': this walkthrough feeds CPU tensors, and mapping the checkpoint to CPU
# lets weights saved from a GPU run load on a machine without CUDA.
generator.load_state_dict(torch.load(os.path.join('saved', 'generator.pt'), map_location='cpu'))

Next, we step through the generator's forward pass to see how it produces the logits.

In [160]:
# Re-unpack the example so this cell gives the same result however often it is re-run.
input_ids, segment_ids, input_mask, masked_ids, masked_pos, masked_weights, is_next, original_ids = one_obs

# Transformer hidden states; .view(1, -1) adds a batch dimension of 1.
h = generator.transformer(torch.tensor(input_ids).view(1, -1),
                          torch.tensor(segment_ids).view(1, -1),
                          torch.tensor(input_mask).view(1, -1))
print(f"h.shape: {h.shape}")

# Pooled representation: hidden state at position 0 ([CLS]) passed through fc + activ1.
pooled_h = generator.activ1(generator.fc(h[:, 0]))
print("\n running the first hidden state of the CLS token through proj kernel")
print(f"pooled_h.shape: {pooled_h.shape}")

# Select the hidden states at the positions listed in `masked_pos`.
# torch.gather acts as multi-dimensional indexing here: the positions are reshaped to
# (1, n, 1) and expanded along the hidden dimension to (1, n, hidden), so that
#     h_masked[0, i, :] == h[0, masked_pos[i], :]
# In the output below, the index tensor has 75 entries for a 75-token input and starts
# 0, 1, 2, 3, 4. That suggests `masked_pos` lists every position, in which case the
# gather returns `h` unchanged — TODO confirm against Preprocess4Pretrain.
# The expanded index gets its own name so that `masked_pos` keeps the raw positions.
masked_pos_idx = torch.tensor(masked_pos).view(1, -1)[:, :, None].expand(-1, -1, h.size(-1))
h_masked = torch.gather(h, 1, masked_pos_idx)
print(f"\n masked_pos.shape: {masked_pos_idx.shape}")
print("masked_pos sample")
display(masked_pos_idx[:, :5, :5])
h_masked = generator.norm(generator.activ2(generator.linear(h_masked)))

# Logits over the vocabulary for each selected position.
logits_lm = generator.decoder2(generator.decoder1(h_masked)) + generator.decoder_bias
print(f"\n logits_lm.shape: {logits_lm.shape}")

# Sentence order prediction logits from the pooled [CLS] representation.
logits_clsf = generator.classifier(pooled_h)
print(f"\n logits_clsf.shape: {logits_clsf.shape}")

# Loss functions: per-token cross-entropy for MLM, mean-reduced cross-entropy for SOP.
cross_ent = nn.CrossEntropyLoss(reduction='none')
sent_cross_ent = nn.CrossEntropyLoss()

# MLM loss per position against `masked_ids`. CrossEntropyLoss expects classes on dim 1,
# hence the transpose to (1, vocab, n). Shown here unreduced and unweighted.
loss_lm = cross_ent(logits_lm.transpose(1, 2), torch.tensor(masked_ids).view(1, -1))
display(loss_lm.round())

# Sentence order prediction loss.
loss_sop = sent_cross_ent(logits_clsf, torch.tensor([int(is_next)]))

h.shape: torch.Size([1, 75, 64])

 running the first hidden state of the CLS token through proj kernel
pooled_h.shape: torch.Size([1, 64])

 masked_pos.shape: torch.Size([1, 75, 64])
masked_pos sample


tensor([[[0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2],
         [3, 3, 3, 3, 3],
         [4, 4, 4, 4, 4]]])


 logits_lm.shape: torch.Size([1, 75, 30522])

 logits_clsf.shape: torch.Size([1, 2])


### 2. Discriminator loss

In [196]:
# Keep a separate handle on the generator's MLM logits (`logits_lm` from the forward-pass
# cell); it is used below to build the discriminator input.
generate_logits = logits_lm

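Before the output reaches the discriminator, it is turned into a corrupted input sequence. Each position takes the generator's most likely (argmax) token, and every position that was not masked is reset to its original token.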

In [203]:
# Build the discriminator input from the generator output, following the src:
#     batch[3] = torch.argmax(generate_logits, dim=2).detach()
# Note the unpacking order: this cell binds field 0 of one_obs to `masked_ids` and field 3
# to `input_ids`, i.e. the names are swapped relative to the generator cell.
masked_ids, segment_ids, input_mask, input_ids, _, _, is_next, original_ids = one_obs

# Start from the generator's most likely token at every position (detached: no gradient
# flows back into the generator through this input).
input_ids = torch.argmax(generate_logits, dim=2).detach()

original_ids_t = torch.tensor(original_ids)
# Positions where the `masked_ids` sequence differs from the original tokens.
masked_label = torch.tensor(masked_ids).long() != original_ids_t
non_masked_label = ~masked_label
# Unmasked positions are reset to the original token; masked positions keep the prediction.
input_ids[non_masked_label.view(1, -1)] = original_ids_t[non_masked_label]

# 1.0 where the discriminator input differs from the original token, else 0.0.
# A plain tensor replaces the deprecated torch.autograd.Variable wrapper (same result).
is_replaced = (input_ids.long() != original_ids_t.view(1, -1).long()).float()
# is_replaced = is_replaced.cuda()
# is_replaced = is_replaced.cuda()

Then we run the replaced token ids through the discriminator.

In [234]:
from src.pretrain import Discriminator
discriminator = Discriminator(model_cfg)
# map_location='cpu': this walkthrough feeds CPU tensors, and mapping the checkpoint to CPU
# lets weights saved from a GPU run load on a machine without CUDA.
discriminator.load_state_dict(torch.load(os.path.join('saved', 'discriminator.pt'), map_location='cpu'))

<All keys matched successfully>

In [244]:
# get the token logits
h = discriminator.transformer(torch.tensor(input_ids).view(1,-1), 
                              torch.tensor(segment_ids).view(1,-1), 
                              torch.tensor(input_mask).view(1,-1) )
logits = discriminator.discriminator(h)

# Sentence order prediction
cls_h = discriminator.activ1(discriminator.fc(h[:, 0]))
logits_clsf = discriminator.classifier(cls_h)

  


Finally, we calculate the loss.

In [254]:
# Replaced-token detection: per-position binary cross-entropy between the discriminator
# logits and the `is_replaced` targets, then averaged over positions.
d_bce_loss = nn.BCEWithLogitsLoss(reduction='none')
# Separate name so the generator's `logits_lm` is not overwritten.
logits_rtd = logits.squeeze(-1)
loss_lm = d_bce_loss(logits_rtd, is_replaced)
loss_lm = loss_lm.mean()

# Sentence order prediction loss on the discriminator's [CLS] logits.
loss_sop = sent_cross_ent(logits_clsf, torch.tensor([int(is_next)]))

### 3. Finally, the backward step

In [None]:
# Combine the losses from the cells above into the quantity a training step would
# backpropagate.

# Discriminator loss: the replaced-token-detection loss (`loss_lm` from the previous cell),
# weighted by `ratio`, plus the discriminator's sentence-order loss (`loss_sop`).
ratio = 50
d_loss = loss_lm * ratio + loss_sop

# Generator loss: the generator cell's `loss_lm`/`loss_sop` have since been overwritten by
# the discriminator cell. Recompute the generator's MLM loss from the saved logits
# (`generate_logits`) and the targets that cell used (field 3 of `one_obs`).
# NOTE(review): the generator's sentence-order loss and any `masked_weights` weighting are
# not included here — TODO confirm against the training code in src.
g_targets = torch.tensor(one_obs[3]).view(1, -1)
g_loss = cross_ent(generate_logits.transpose(1, 2), g_targets).mean()

# sum of generator and discriminator loss
total_loss = g_loss + d_loss
loss_sum = 0.0  # single example here; in a training loop this accumulates over batches
loss_sum += total_loss.item()

# run the backward step
# total_loss.backward()
# self.optimizer.step()