In [None]:
"""
Written By Dean Shumbanhete. Much of the code was taken verbatim from the seminal implementation "The Annotated Transformer" by Harvard NLP
This is a seminal resource which implements the original "Attention Is All You Need" paper.
This model will be used as the baseline for comparing the effect of bootstrap techniques on NLP tasks. In this notebook it is trained as a masked language model (MLM).

The encoder maps an input sequence of symbol representations (x1, …, xn) to a sequence of continuous representations z = (z1, …, zn).
Given z, the decoder then generates an output sequence (y1, …, ym) of symbols one element at a time. At each step the model is auto-regressive,
consuming the previously generated symbols as additional input when generating the next.
"""
!pip install accelerate -q

import numpy as np
import math
import pandas as pd
import seaborn as sns
sns.set_context(context="talk")
from torch.utils.data import DataLoader, Dataset 
import matplotlib.pyplot as plt
from pathlib import Path
import copy
from collections import Counter, defaultdict
import json 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import time
from tqdm import tqdm
import random
import os
from accelerate import Accelerator



#----------------------1. Housekeeping------------------------------------
# NOTE: CUDA_LAUNCH_BLOCKING=1 makes every CUDA kernel launch synchronous. It gives
# accurate stack traces for device-side errors but slows training down; unset it
# once debugging is finished.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Seed every RNG the script uses so runs are reproducible.
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Prefer deterministic cuDNN kernels over autotuned (possibly faster) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Hyper-parameters shared by the data pipeline, the model and the training loop.
cfg = dict(
    data_root      = "/kaggle/input/maskhaner/pytorch/default/1/",   # change if needed
    vocab_min_freq = 0,      # tokens seen fewer times are dropped; 0 keeps every token
    max_len        = 64,     # sequences are truncated / right-padded to this length
    batch_size     = 32,
    embed_dim      = 128,    # used as the model width d_model
    hidden_dim     = 256,    # used as the feed-forward inner width d_ff
    lr             = 1e-3,   # NOTE(review): not read by the visible code (the Noam schedule sets the lr)
    mlm_prob       = 0.15,   # probability that a token is selected for masking
    epochs         = 100,
    ensemble_masks = 5,      # K corrupted copies of each sentence per batch
    patience       = 5,
    checkpoint_dir = "/kaggle/working/checkpoints",
    seed           = SEED,
)

#-----------------------2. Architecture------------------------------------------
class EncoderDecoder(nn.Module):
    """Generic encoder-decoder container.

    Bundles the encoder stack, the decoder stack, the source/target embedding
    modules and the output generator. ``forward`` returns the decoder's output
    states; the generator is only stored as an attribute and is NOT applied by
    ``forward`` -- callers apply it themselves.
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        # Assignment order fixes the submodule registration order and therefore
        # the order of model.parameters(); keep it stable.
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Encode ``src`` and decode ``tgt`` against the resulting memory."""
        memory = self.encode(src, src_mask)
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        """Embed ``src`` and run it through the encoder under ``src_mask``."""
        embedded_src = self.src_embed(src)
        return self.encoder(embedded_src, src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed ``tgt`` and run the decoder over it and the encoder ``memory``."""
        embedded_tgt = self.tgt_embed(tgt)
        return self.decoder(embedded_tgt, memory, src_mask, tgt_mask)
# --------‐ Vocabulary & Data Reading --------

class Generator(nn.Module):
    """Project model states onto the vocabulary and return log-probabilities.

    Args:
        d_model: width of the incoming states (last dimension).
        vocab: number of output classes.
    """

    def __init__(self, d_model, vocab):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        """Return log_softmax over the last axis of ``proj(x)``."""
        scores = self.proj(x)
        return scores.log_softmax(dim=-1)

def read_split(path: str):
    """Read a whitespace-tokenized UTF-8 text file.

    Every non-blank line becomes one sentence, split on runs of whitespace;
    lines that are empty after stripping are skipped.

    Args:
        path: file path (a ``str`` or anything ``pathlib.Path`` accepts).

    Returns:
        list[list[str]]: one token list per non-blank line, in file order.
    """
    text = Path(path).read_text(encoding="utf-8")
    return [line.split() for line in text.splitlines() if line.strip()]

class Vocab:
    """Token <-> index mapping built from tokenized sentences.

    Ids 0, 1 and 2 are reserved for ``<PAD>``, ``<MASK>`` and ``<UNK>``; corpus
    tokens follow in first-seen order. Tokens occurring fewer than ``min_freq``
    times are dropped (``min_freq=0`` keeps every token).

    Attributes:
        itos: list mapping id -> token.
        stoi: dict mapping token -> id.
        pad_index, mask_index, unk_index: ids of the special tokens.
    """

    def __init__(self, sentences, min_freq: int):
        specials = ["<PAD>", "<MASK>", "<UNK>"]
        reserved = set(specials)
        counts = Counter(w for s in sentences for w in s)
        # Skip corpus tokens that spell a special token. Otherwise they would be
        # appended a second time and stoi would point the special at the later
        # duplicate, breaking the fixed 0/1/2 layout (random-token sampling in
        # collate_mlm draws ids from 3 upward).
        self.itos = specials + [
            w for w, c in counts.items() if c >= min_freq and w not in reserved
        ]
        self.stoi = {w: i for i, w in enumerate(self.itos)}
        self.pad_index = self.stoi["<PAD>"]
        self.mask_index = self.stoi["<MASK>"]
        self.unk_index = self.stoi["<UNK>"]

    def encode(self, sent):
        """Map a token list to ids; out-of-vocabulary tokens map to ``unk_index``."""
        return [self.stoi.get(w, self.unk_index) for w in sent]

# --------‐ Dataset and Collate fn for MLM with Ensemble Masks --------

class MLMDataset(Dataset):
    """Map-style dataset of fixed-length token-id tensors for masked-LM training.

    Each item is the sentence encoded with ``vocab``, truncated to
    ``cfg["max_len"]`` and right-padded with ``vocab.pad_index`` so that every
    item has the same length and a batch can be stacked into one tensor.
    """

    def __init__(self, vocab, sentences):
        self.vocab = vocab
        self.sentences = sentences

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        """Return a 1-D LongTensor of length ``cfg["max_len"]`` for sentence ``idx``."""
        limit = cfg["max_len"]
        ids = self.vocab.encode(self.sentences[idx])[:limit]
        # Pad on the right; the multiplier is 0 when the sentence already fills the limit.
        ids += [self.vocab.pad_index] * (limit - len(ids))
        return torch.tensor(ids, dtype=torch.long)

def collate_mlm(batch, vocab):
    """Collate padded id tensors into K differently-corrupted masked-LM inputs.

    Args:
        batch: list of 1-D LongTensors that all have the same length T
            (MLMDataset pads every item to cfg["max_len"]).
        vocab: Vocab providing the <PAD> and <MASK> ids; ids 0-2 are specials.

    Returns:
        x_ens: LongTensor [B, K, T], K = cfg["ensemble_masks"] corrupted copies.
        labels: LongTensor [B, T]; the original id at selected positions and
            -100 (the loss's ignore_index) everywhere else.
        src_mask: BoolTensor [B, 1, K, T], True on non-PAD positions of x_ens.

    Each position is selected with probability cfg["mlm_prob"]; PAD positions
    are never selected. The selected set is shared by all K copies (hence one
    shared ``labels``). Within it, ~80% become <MASK>, ~10% become a random
    non-special token (re-drawn for every copy) and ~10% keep the original token.
    """
    x_raw = torch.stack(batch)  # [B, T]
    B, T = x_raw.shape
    K = cfg["ensemble_masks"]
    p = cfg["mlm_prob"]
    PAD, MASK = vocab.stoi["<PAD>"], vocab.stoi["<MASK>"]
    n_special = 3  # <PAD>, <MASK>, <UNK> occupy ids 0-2

    # One random draw per position decides both selection and corruption type.
    prob = torch.rand_like(x_raw, dtype=torch.float)
    not_pad = x_raw != PAD
    selected = (prob < p) & not_pad

    labels = x_raw.clone()
    labels[~selected] = -100  # positions not selected are ignored by the loss

    to_mask = (prob < 0.8 * p) & not_pad
    to_random = (prob >= 0.8 * p) & (prob < 0.9 * p) & not_pad
    # randint(low=3, high=3) would raise when the vocab holds only specials;
    # in that case the "random token" positions simply keep their original id.
    can_sample = len(vocab.itos) > n_special

    x_ens = []
    for _ in range(K):
        x = x_raw.clone()
        x[to_mask] = MASK
        if can_sample:
            rand_tokens = torch.randint(low=n_special, high=len(vocab.itos), size=(B, T), device=x.device)
            x[to_random] = rand_tokens[to_random]
        x_ens.append(x)

    x_ens = torch.stack(x_ens, dim=1)  # [B, K, T]
    src_mask = (x_ens != PAD).unsqueeze(1)  # [B, 1, K, T]
    return x_ens.to(device), labels.to(device), src_mask.to(device)

# --------‐ Model Definition --------

def clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``.

    The copies share no parameters with each other or with ``module``, but start
    with identical values.
    """
    stack = nn.ModuleList()
    for _ in range(N):
        stack.append(copy.deepcopy(module))
    return stack

class Encoder(nn.Module):
    """Stack of ``N`` deep copies of ``layer`` followed by a final LayerNorm.

    ``layer`` must expose ``.size`` (the model width), which sizes the final norm.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        """Pass ``x`` through every layer with the same ``mask``, then normalise."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable gain and bias.

    Computes ``a_2 * (x - mean) / (std + eps) + b_2`` with mean/std taken over the
    last axis. ``Tensor.std`` uses the unbiased estimator by default and ``eps`` is
    added to the std (not the variance), so results differ slightly from
    ``nn.LayerNorm``.

    Args:
        features: size of the last dimension.
        eps: small constant added to the std for numerical stability.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # gain
        self.b_2 = nn.Parameter(torch.zeros(features))  # bias
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        denom = x.std(dim=-1, keepdim=True) + self.eps
        return self.a_2 * centered / denom + self.b_2

class SublayerConnection(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``.

    The LayerNorm is applied to the sub-layer's input (not after the residual
    sum), and dropout is applied to the sub-layer's output before it is added back.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` (any callable) inside the residual connection."""
        normed = self.norm(x)
        update = self.dropout(sublayer(normed))
        return x + update

class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a position-wise feed-forward.

    Each of the two sub-layers is wrapped in its own SublayerConnection
    (pre-norm residual with dropout). ``size`` is the model width; it is stored
    so Encoder can size its final LayerNorm.
    """

    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        """Self-attend over ``x`` under ``mask``, then apply the feed-forward net."""
        attn_block, ff_block = self.sublayer
        x = attn_block(x, lambda h: self.self_attn(h, h, h, mask))
        return ff_block(x, self.feed_forward)

class Decoder(nn.Module):
    """Stack of ``N`` deep copies of a decoder ``layer`` plus a final LayerNorm.

    The encoder output ``memory`` and both masks are passed unchanged to every
    layer; only the running hidden state ``x`` changes from layer to layer.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        out = x
        for dec_layer in self.layers:
            out = dec_layer(out, memory, src_mask, tgt_mask)
        return self.norm(out)



class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, source (encoder-memory) attention, feed-forward.

    Each of the three sub-layers is wrapped in its own SublayerConnection
    (pre-norm residual with dropout).
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Self-attend under ``tgt_mask``, attend to ``memory`` under ``src_mask``, then feed-forward."""
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        x = self.sublayer[2](x, self.feed_forward)
        return x

    @staticmethod
    def subsequent_mask(size):
        """Return a ``[1, size, size]`` bool mask that is True on and below the diagonal.

        Row i allows positions j <= i, i.e. a causal mask. It is static so it can be
        called on the class or on an instance; without ``@staticmethod`` an instance
        call would bind the layer itself to ``size``. (Not referenced elsewhere in
        the visible code.)
        """
        attn_shape = (1, size, size)
        future = torch.triu(torch.ones(attn_shape), diagonal=1)
        return future == 0


"""
Attention

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.

The authors call their particular attention “Scaled Dot-Product Attention”. The input consists of queries and keys of dimension d_k,
and values of dimension d_v. We compute the dot products of the query with all keys, divide each by √d_k, and apply a softmax function to obtain the weights on the values: Attention(Q, K, V) = softmax(Q·Kᵀ / √d_k)·V. Apart from dot-product attention, there is also additive attention, which computes the compatibility function using a feed-forward network with a single hidden layer.

Dot-product attention is much faster and more space-efficient in practice, since it can be implemented using highly optimized matrix multiplication code.

While for small values of d_k the two mechanisms perform similarly, additive attention outperforms dot-product attention without scaling for larger values of d_k; the 1/√d_k scaling factor is introduced to counteract this.

Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions:
MultiHead(Q, K, V) = Concat(head_1, …, head_h)·W^O, where head_i = Attention(Q·W_i^Q, K·W_i^K, V·W_i^V),
and the projections are parameter matrices W_i^Q ∈ ℝ^(d_model × d_k), W_i^K ∈ ℝ^(d_model × d_k), W_i^V ∈ ℝ^(d_model × d_v) and W^O ∈ ℝ^(h·d_v × d_model).
In the paper the authors employ h = 8 parallel attention layers, or heads. For each of these they use
d_k = d_v = d_model / h = 64 (in this notebook embed_dim = 128 and h = 8, so d_k = 16). Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.

In addition to the attention sub-layers, each layer contains a fully connected feed-forward network, which is applied to each position separately and identically: FFN(x) = max(0, x·W_1 + b_1)·W_2 + b_2.
This is analogous to two convolutions with kernel size 1. In the paper the dimensionality of input and output is d_model = 512, and the inner layer has dimensionality d_ff = 2048.

"""

def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        query, key, value: tensors whose last dimension is the feature size and
            whose second-to-last dimension is the sequence axis; ``d_k`` is the
            last dimension of ``query``.
        mask: optional tensor broadcastable to the score matrix; positions where
            ``mask == 0`` receive a score of -1e4 before the softmax. This is a
            finite value, presumably chosen so it stays representable in fp16.
        dropout: optional module applied to the attention weights.

    Returns:
        (output, weights): the weighted sum of ``value`` and the (post-dropout)
        attention weights.
    """
    scale = math.sqrt(query.size(-1))
    scores = query @ key.transpose(-2, -1) / scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e4)
    weights = scores.softmax(dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return weights @ value, weights

class MultiHeadedAttention(nn.Module):
    """Multi-head attention: ``h`` scaled dot-product attentions run in parallel.

    Args:
        h: number of heads.
        d_model: model width; must be divisible by ``h`` so that every head gets
            ``d_k = d_model // h`` features.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``h``.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        # Fail at construction instead of with an opaque .view() error in forward.
        if d_model % h != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by the number of heads h ({h})")
        self.h = h
        self.d_k = d_model // h
        # Three input projections (Q, K, V) plus the output projection.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None  # attention weights (after dropout) from the latest forward
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        query/key/value are laid out as [batch, seq, d_model] (fixed by the
        ``.view(nbatches, -1, h, d_k)`` below). ``mask`` gets a new axis at dim 1
        so one mask is shared by all heads; the callers here pass [batch, 1, seq].
        Returns a tensor of shape [batch, seq_q, d_model].
        """
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        # Project, then split d_model into h heads of size d_k: [batch, h, seq, d_k].
        query, key, value = [
            lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for lin, x in zip(self.linears[:-1], (query, key, value))
        ]
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        # Merge the heads back into d_model and apply the output projection.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

class PositionWiseFeedForward(nn.Module):
    """Position-wise feed-forward network: ``w_2(dropout(relu(w_1(x))))``.

    The same two linear layers are applied independently at every position,
    expanding ``d_model`` to ``d_ff`` and projecting back to ``d_model``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)

class Embeddings(nn.Module):
    """Learned token embeddings multiplied by ``sqrt(d_model)``.

    Args:
        vocab_size: number of rows in the lookup table.
        d_model: embedding width.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.lut = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        """Look up integer ids ``x`` and scale the vectors by sqrt(d_model)."""
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    The table is precomputed for ``max_len`` positions and registered as a
    buffer, so it follows the module across devices and is saved in
    ``state_dict``. Odd ``d_model`` is supported.

    Args:
        d_model: feature size of the input.
        dropout: dropout probability applied after adding the encodings.
        max_len: longest sequence that can be encoded.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-(math.log(10000.0) / d_model)))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer odd column than even column, so
        # slice div_term to match instead of raising a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        """x: [batch, seq, d_model] with seq <= max_len; returns the same shape."""
        # Buffers never require grad, so the deprecated Variable wrapper is unnecessary.
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


def make_m(src_vocab_size, tgt_vocab_size, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """Build an encoder-decoder Transformer from hyper-parameters.

    Args:
        src_vocab_size, tgt_vocab_size: vocabulary sizes of the two embeddings;
            the generator projects onto ``tgt_vocab_size`` classes.
        N: number of encoder layers and of decoder layers.
        d_model: model width.
        d_ff: inner width of the position-wise feed-forward networks.
        h: number of attention heads.
        dropout: dropout probability used throughout.

    Returns:
        EncoderDecoder whose parameters of rank >= 2 are re-initialised with
        Xavier/Glorot uniform; biases and LayerNorm vectors keep their defaults.
    """
    dup = copy.deepcopy
    # Prototypes; every use below receives its own deep copy.
    attn = MultiHeadedAttention(h, d_model, dropout)
    ff = PositionWiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)

    encoder = Encoder(EncoderLayer(d_model, dup(attn), dup(ff), dropout), N)
    decoder = Decoder(DecoderLayer(d_model, dup(attn), dup(attn), dup(ff), dropout), N)
    src_embed = nn.Sequential(Embeddings(src_vocab_size, d_model), dup(position))
    tgt_embed = nn.Sequential(Embeddings(tgt_vocab_size, d_model), dup(position))
    generator = Generator(d_model, tgt_vocab_size)
    model = EncoderDecoder(encoder, decoder, src_embed, tgt_embed, generator)

    weight_matrices = (p for p in model.parameters() if p.dim() > 1)
    for weight in weight_matrices:
        nn.init.xavier_uniform_(weight)
    return model

  
class NoamOpt:
    """Optimizer wrapper implementing the "Noam" learning-rate schedule.

        lr(step) = factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5)

    This is a linear warm-up for ``warmup`` steps followed by inverse-square-root
    decay. The wrapped optimizer's learning rate is overwritten before every step.

    Args:
        model_size: model width d_model.
        factor: overall scale of the schedule.
        warmup: number of warm-up steps.
        optimizer: the torch optimizer to drive.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def rate(self, step=None):
        """Return the learning rate at ``step`` (defaults to the current step).

        Step 0, i.e. before any update, returns 0.0 instead of raising
        ZeroDivisionError from ``0 ** -0.5``.
        """
        if step is None:
            step = self._step
        if step <= 0:
            return 0.0
        return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def step(self):
        """Advance the schedule, write the new lr into every param group, then step."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def zero_grad(self, *args, **kwargs):
        """Forward to the wrapped optimizer's ``zero_grad``."""
        self.optimizer.zero_grad(*args, **kwargs)

# --------‐ Loss & LossCompute modifications for MLM ---------------------------------------------------------------------

class LabelSmoothing(nn.Module):
    """Label-smoothed KL-divergence loss, summed over all rows.

    The target distribution puts ``1 - smoothing`` on the true class and spreads
    ``smoothing / (size - 2)`` over every other class except ``padding_idx``.
    Rows whose target is ``padding_idx`` get an all-zero distribution and
    therefore contribute nothing.

    Args:
        size: number of classes; must be greater than 2.
        padding_idx: class id treated as padding.
        smoothing: probability mass moved away from the true class.

    forward(x, target):
        x: log-probabilities [N, size]; target: class ids [N].
        Returns the summed KL divergence. The last target distribution is kept
        in ``self.true_dist`` for inspection.
    """

    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        # reduction="sum" is the non-deprecated spelling of size_average=False.
        self.criterion = nn.KLDivLoss(reduction="sum")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    def forward(self, x, target):
        true_dist = torch.full_like(x.detach(), self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        # Zero every row whose target is padding (a boolean mask works for any
        # number of such rows, including none).
        true_dist[target == self.padding_idx] = 0
        self.true_dist = true_dist
        return self.criterion(x, true_dist)

class LossCompute:
    """Masked-LM loss: generator -> NLL over selected positions -> optional update.

    Args:
        generator: callable mapping model states to log-probabilities over the
            vocabulary (e.g. ``Generator``).
        opt: optional optimizer wrapper exposing ``.optimizer`` (a torch
            optimizer) and ``.step()``. When given, every call also performs one
            training update.
    """

    def __init__(self, generator, opt=None):
        self.generator = generator
        self.opt = opt

    def __call__(self, out, labels):
        """Compute the mean loss over non-ignored positions and optionally update.

        Args:
            out: model output states [N, T, d_model]; the generator is applied here.
            labels: LongTensor [N, T] with -100 at positions excluded from the loss.

        Returns:
            (loss, masked): ``loss`` is a Python float, the mean negative
            log-likelihood over the ``masked`` non-ignored positions. If no
            position is selected, returns ``(0.0, 0)`` and skips the update: a
            mean over zero elements is NaN, and back-propagating it would corrupt
            the weights.
        """
        log_probs = self.generator(out)  # [N, T, V]
        V = log_probs.size(-1)
        log_probs_flat = log_probs.reshape(-1, V)  # [N*T, V]
        labels_flat = labels.reshape(-1)           # [N*T]

        masked = int((labels_flat != -100).sum().item())
        if masked == 0:
            return 0.0, 0

        # The generator already returns log-probabilities, so use nll_loss;
        # cross_entropy would apply a second (redundant) log_softmax.
        loss = F.nll_loss(log_probs_flat, labels_flat, ignore_index=-100)

        if self.opt is not None:
            loss.backward()
            # Step BEFORE clearing gradients: zeroing between backward() and
            # step() discards them, so the optimizer would never update weights.
            self.opt.step()
            self.opt.optimizer.zero_grad()

        return loss.item(), masked

# --------‐ Training / Evaluation Loop ---------------------------------------------------------------------

def run_mlm(data_loader, model, loss_compute, vocab, training=True):
    """Run one masked-LM pass (training or evaluation) over ``data_loader``.

    Each batch is an ``(x_ens, labels, src_mask)`` triple, where:
    - ``x_ens`` is [B, K, T] and holds K corrupted copies of every sentence.
    - ``labels`` is [B, T] with -100 at unselected positions.
    - the batch's ``src_mask`` is ignored; a [B*K, 1, T] padding mask is rebuilt here.
    The K copies are flattened into the batch axis, and labels are repeated to match.

    Args:
        data_loader: iterable of ``(x_ens, labels, src_mask)`` batches.
        model: called as ``model(src, tgt, src_mask, tgt_mask)``; must return
            states that ``loss_compute`` accepts.
        loss_compute: LossCompute; performs the optimizer update itself when it
            was built with an optimizer.
        vocab: Vocab, used for ``pad_index``.
        training: True for train mode with autograd enabled; False for eval
            mode with autograd disabled.

    Returns:
        (avg_loss, perplexity): loss averaged over all selected tokens and
        exp(avg_loss). Both are inf when no token was selected. Perplexity is
        also inf if exp overflows.
    """
    model.train(training)  # equivalent to model.train() / model.eval()
    desc = "Training" if training else "Validation"
    total_loss = 0.0
    total_masked = 0

    # Scope the grad mode to this pass so it does not leak into the caller.
    with torch.set_grad_enabled(training):
        for x_ens, labels, _ in tqdm(data_loader, desc=desc, leave=False):
            B, K, T = x_ens.size()
            x_input = x_ens.reshape(B * K, T)  # copy k of sentence b -> row b*K + k
            # attention() drops positions where the mask == 0, so real tokens
            # must be True and padding False.
            src_mask_input = (x_input != vocab.pad_index).unsqueeze(1)  # [B*K, 1, T]
            out = model(x_input, x_input, src_mask_input, src_mask_input)  # [B*K, T, d_model]
            # repeat_interleave matches the row order produced by reshape above.
            labels_rep = labels.to(device).repeat_interleave(K, dim=0)  # [B*K, T]

            loss_val, masked_count = loss_compute(out, labels_rep)
            total_loss += loss_val * masked_count
            total_masked += masked_count

    if total_masked == 0:
        return float("inf"), float("inf")
    avg_loss = total_loss / total_masked
    try:
        perplexity = math.exp(avg_loss)
    except OverflowError:
        perplexity = float("inf")
    return avg_loss, perplexity

# --------‐---------------- Main--------------------------------------------------------------------

def main():
    """Train the masked-LM baseline and checkpoint the best model by validation loss.

    Steps:
    1. Read ``train.txt`` / ``dev.txt`` from ``cfg["data_root"]`` and build the
       vocabulary from the training split.
    2. Train a 2-layer encoder-decoder with the Noam learning-rate schedule.
    3. Save the Accelerate state whenever the validation loss improves.
    4. Stop early after ``cfg["patience"]`` epochs without improvement.
    """
    from functools import partial

    accelerator = Accelerator(mixed_precision="fp16")
    # NOTE(review): NoamOpt is not a torch.optim.Optimizer, so accelerator.prepare()
    # presumably returns it unwrapped, and LossCompute calls loss.backward()
    # directly. fp16 gradient scaling is therefore likely not applied; confirm, or
    # switch to accelerator.backward() with a prepared optimizer.

    data_root = Path(cfg["data_root"])
    train_sents = read_split(data_root / "train.txt")
    dev_sents = read_split(data_root / "dev.txt")

    vocab = Vocab(train_sents, cfg["vocab_min_freq"])
    V = len(vocab.itos)
    print("Vocab size:", V)

    # partial of a module-level function (unlike a lambda) stays picklable
    # if DataLoader worker processes are ever enabled.
    collate = partial(collate_mlm, vocab=vocab)
    train_loader = DataLoader(MLMDataset(vocab, train_sents),
                              batch_size=cfg["batch_size"],
                              shuffle=True,
                              collate_fn=collate)
    dev_loader = DataLoader(MLMDataset(vocab, dev_sents),
                            batch_size=cfg["batch_size"],
                            shuffle=False,
                            collate_fn=collate)

    model = make_m(V, V, N=2, d_model=cfg["embed_dim"], d_ff=cfg["hidden_dim"], dropout=0.1)
    model_opt = NoamOpt(model.src_embed[0].d_model, 1, 400,
                        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

    model, model_opt, train_loader, dev_loader = accelerator.prepare(
        model, model_opt, train_loader, dev_loader
    )

    # prepare() may wrap the model (e.g. in DistributedDataParallel), which hides
    # submodule attributes, so fetch the generator from the underlying model.
    generator = accelerator.unwrap_model(model).generator
    train_loss_compute = LossCompute(generator, opt=model_opt)
    eval_loss_compute = LossCompute(generator, opt=None)

    ckpt_dir = Path(cfg["checkpoint_dir"])
    best_val_loss = float("inf")
    epochs_without_improvement = 0
    print("Running Training...")
    for epoch in range(1, cfg["epochs"] + 1):
        start = time.time()

        train_avg_loss, train_perplexity = run_mlm(train_loader, model, train_loss_compute, vocab, training=True)
        val_avg_loss, val_perplexity = run_mlm(dev_loader, model, eval_loss_compute, vocab, training=False)

        elapsed = time.time() - start
        # run_mlm already returns (mean loss per masked token, perplexity).
        print(f"Epoch {epoch:02d} | train_loss_per_masked={train_avg_loss:6.3f} | train_perplexity={train_perplexity:8.2f} | val_loss_per_masked={val_avg_loss:6.3f} | val_perplexity={val_perplexity:8.2f} | time={elapsed:.1f}s")

        if val_avg_loss < best_val_loss:
            best_val_loss = val_avg_loss
            epochs_without_improvement = 0
            os.makedirs(ckpt_dir, exist_ok=True)
            accelerator.save_state(str(ckpt_dir / f"best_mlm_epoch{epoch}"))
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= cfg["patience"]:
                print(f"Early stopping: no validation improvement for {cfg['patience']} epochs.")
                break

if __name__ == "__main__":
    main()



Device: cuda
Vocab size: 24937
Running Training...


Training:  12%|█▏        | 521/4289 [01:43<12:42,  4.94it/s]