In [8]:
"""NER.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1aTjGzC0bWjYVrB16D-EgdJTH7H6D0fYk

#Adapted by Dean Shumbanhete
Much of the code used in this initial experiment was adapted from the code offered by Crash Course Computer Science, an online learning community. Original code can be found in the caption to the video found here:
https://www.youtube.com/watch?v=oi0JXuL19TA

Purpose of this is to try out some NLP models with Shona to see if dominant tools work well when transferring the same capabilities to Shona plain text. Original training and test data was taken from MasakhaneNER 2.0 (https://github.com/masakhane-io/masakhane-ner/tree/main/MasakhaNER2.0). This data format was already tagged and reduced to entity level entries and their associate tags which were used to train the model.

At first I had to flatten the data back into continuous text.
"""

# ---------- 0.  House-keeping ----------
!pip -q install torchmetrics  # only extra dependency
import os, json, random, math, time, datetime
from pathlib import Path
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.distributions.normal import Normal
from torch.distributions.kl import kl_divergence
from tqdm import tqdm
import matplotlib.pyplot as plt

#--------------0. Globals-----------------------------
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ---------- 1.  Device -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---------- 2.  Data loading (from checkpoints)----------
def read_split(path: str) -> Tuple[List[List[str]], List[List[str]]]:
    """Return (sentences, ner_tags) lists."""
    lines = Path(path).read_text(encoding="utf-8").splitlines()
    sents, tags, cur_sent, cur_tags = [], [], [], []
    for ln in lines:
        ln = ln.strip()
        if not ln:
            if cur_sent:
                sents.append(cur_sent); tags.append(cur_tags)
                cur_sent, cur_tags = [], []
        else:
            tok, tag = ln.split()
            cur_sent.append(tok); cur_tags.append(tag)
    if cur_sent:                        # last example
        sents.append(cur_sent); tags.append(cur_tags)
    return sents, tags
# Load the checkpoint
def run_from_saved():
    try:
        checkpoint_path = "/kaggle/working/checkpoints/best.pt"
        checkpoint = torch.load(checkpoint_path, map_location=device)
        
        # Load the vocabulary and other metadata
        meta_path = "/kaggle/working/checkpoints/meta.json"
        with open(meta_path, "r") as f:
            meta = json.load(f)
        vocab_size = len(meta["itos"])
        vocab = meta["itos"]
        stoi = meta["stoi"]
        
        # Create an instance of the model
        model = BayesianLSTMMLM(vocab_size, cfg["embed_dim"], cfg["hidden_dim"]).to(device)
        
        # Load the state dictionary into the model
        model.load_state_dict(checkpoint)
        model.eval()  # Set the model to evaluation mode
        
        print("Model loaded successfully from checkpoint.")
        return True
    except FileNotFoundError:
        return False

# ---------- 3.  Build vocabulary ----------
class Vocab:
    def __init__(self, sentences: List[List[str]], min_freq: int):
        counter = {}
        for s in sentences:
            for w in s:
                counter[w] = counter.get(w, 0) + 1
        self.itos = ["<PAD>", "<UNK>", "<MASK>"] + \
                    [w for w, c in counter.items() if c >= min_freq]
        self.stoi = {w:i for i,w in enumerate(self.itos)}
    def encode(self, sent): return [self.stoi.get(w, self.stoi["<UNK>"]) for w in sent]


# ---------- 4.  Configure some global variables into a dictionary for ease of use and readability----------
cfg = dict(
    data_root      = "/kaggle/input/maskhaner/pytorch/default/4/",   # change if local
    vocab_min_freq = 2,
    max_len        = 64,
    batch_size     = 32,
    embed_dim      = 128,
    hidden_dim     = 256,
    lr             = 1e-3,
    mlm_prob       = 0.15,
    epochs         = 150,
    ensemble_masks = 5,
    patience       = 5,
    checkpoint_dir = "/kaggle/working/checkpoints",
    seed           = SEED,
)
Path(cfg["checkpoint_dir"]).mkdir(exist_ok=True, parents=True)
if not (run_from_saved()):
        train_loader = DataLoader(MLMDataset(train_sents),
                                      batch_size=cfg["batch_size"],
                                      shuffle=True,
                                      collate_fn=collate_mlm)
        dev_loader   = DataLoader(MLMDataset(dev_sents),
                                      batch_size=cfg["batch_size"],
                                      shuffle=False,
                                      collate_fn=collate_mlm)
        model = BayesianLSTMMLM(len(vocab.itos), cfg["embed_dim"], cfg["hidden_dim"]).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
        best_val, patience_left = 1e9, cfg["patience"]
        train_sents, train_tags = read_split(Path(cfg["data_root"]) / "train.txt")
        test_sents , test_tags  = read_split(Path(cfg["data_root"]) / "test.txt")
        dev_sents  , dev_tags   = read_split(Path(cfg["data_root"]) / "dev.txt")
        vocab = Vocab(train_sents, cfg["vocab_min_freq"])
        PAD, UNK, MASK = vocab.stoi["<PAD>"], vocab.stoi["<UNK>"], vocab.stoi["<MASK>"]


# ---------- 5.  Masked-LM ensemble wrapper ----------
class MLMDataset(Dataset):
    def __init__(self, sentences):
        self.data = [vocab.encode(s) for s in sentences]

    def __len__(self): return len(self.data)

    def __getitem__(self, idx):
        seq = self.data[idx][:cfg["max_len"]]
        padded = seq + [PAD]*(cfg["max_len"] - len(seq))
        return torch.tensor(padded, dtype=torch.long)
# Create ensemble dataset
def collate_mlm(batch):
    x_raw = torch.stack(batch)                       # [B, T]
    B, T = x_raw.shape
    K = cfg["ensemble_masks"]

    # one shared target
    labels = x_raw.clone()
    prob = torch.rand_like(labels, dtype=torch.float)
    mask = prob < cfg["mlm_prob"]
    labels[~mask] = -100

    # K different masked versions
    x_ens = []
    for _ in range(K):
        x = x_raw.clone()
        # 80% MASK, 10% random, 10% keep (per mask)
        rand = prob < 0.8*cfg["mlm_prob"]
        x[rand] = MASK
        rand = (prob >= 0.8*cfg["mlm_prob"]) & (prob < 0.9*cfg["mlm_prob"])
        x[rand] = torch.randint_like(x[rand], low=3, high=len(vocab.itos))
        # remaining 10% leave unchanged
        x_ens.append(x)

    x_ens = torch.stack(x_ens, dim=1)   # [B, K, T]
    return x_ens.to(device), labels.to(device)



# ---------- 6.  Bayesian Linear layer is a variational layer as in Blundell, C., Cornebise, J., Kavukcuoglu, K. and Wierstra, D., 2015, June. Weight uncertainty in neural network. In International conference on machine learning (pp. 1613-1622). PMLR.----------
    
    
class BayesianLinear(nn.Module):
    def __init__(self, in_f, out_f, prior_sigma=1.0):
        super().__init__()
        self.in_f, self.out_f, self.prior_sigma = in_f, out_f, prior_sigma
        # this is meant to simulate the rho(ρ) = log(1+exp(σ)), trick
        self.w_mu  = nn.Parameter(torch.zeros(out_f, in_f))
        self.w_rho = nn.Parameter(torch.ones(out_f, in_f)*-3)
        self.b_mu  = nn.Parameter(torch.zeros(out_f))
        self.b_rho = nn.Parameter(torch.ones(out_f)*-3)
        self.reset_parameters()

    def reset_parameters(self):
        std = 1/math.sqrt(self.in_f)
        nn.init.uniform_(self.w_mu, -std, std)
        nn.init.uniform_(self.b_mu, -std, std)

    def forward(self, x):
        w_sigma = F.softplus(self.w_rho)
        b_sigma = F.softplus(self.b_rho)
        eps_w = torch.randn_like(self.w_mu)
        eps_b = torch.randn_like(self.b_mu)
        w = self.w_mu + w_sigma * eps_w
        b = self.b_mu + b_sigma * eps_b
        return F.linear(x, w, b)

    def kl_loss(self):
        prior = Normal(0, self.prior_sigma)
        w_post = Normal(self.w_mu, F.softplus(self.w_rho))
        b_post = Normal(self.b_mu, F.softplus(self.b_rho))
        return kl_divergence(w_post, prior).sum() + kl_divergence(b_post, prior).sum() #bring in noise to help with the use of the prior to determin the posterior

# ---------- 7.  Bayesian LSTM model implementation where much is borrowed from Zhu, L. and Laptev, N., 2017, November. Deep and confident prediction for time series at uber. In 2017 IEEE international conference on data mining workshops (ICDMW) (pp. 103-110). IEEE.----------
class BayesianLSTMMLM(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD)
        self.lstm  = EnsembleLSTM(embed_dim, hidden_dim)
        self.fc    = BayesianLinear(hidden_dim, vocab_size)

    def forward(self, x):                # x : [B, K, T]
        emb = self.embed(x)              # [B, K, T, E]
        out = self.lstm(emb)             # [B, T, H]
        return self.fc(out)              # [B, T, V]

    def kl_loss(self):
        return self.fc.kl_loss()
   
    # LSTM that receives [B, K, T, E] and returns [B, T, H]
    # by concatenating outputs along the hidden dim and then
    # projecting back to hidden_dim.
    
class EnsembleLSTM(nn.Module):
    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # projection from K*hidden_dim back to hidden_dim
        self.merge = nn.Linear(hidden_dim * cfg["ensemble_masks"], hidden_dim)

    def forward(self, x_emb):
        # x_emb : [B, K, T, E]
        B, K, T, E = x_emb.shape
        x_emb = x_emb.view(B*K, T, E)          # treat as larger batch
        out, _ = self.lstm(x_emb)              # [B*K, T, H]
        out = out.view(B, K, T, -1)            # [B, K, T, H]
        out = out.permute(0, 2, 1, 3).contiguous()  # [B, T, K, H]
        out = out.view(B, T, -1)               # [B, T, K*H]
        return torch.tanh(self.merge(out))     # [B, T, H]

# ---------- 8.  Training ----------
    
def run_epoch(loader, training=False):
    total_loss, total_acc, n_tok = 0., 0., 0
    if training: model.train()
    else:        model.eval()
    torch.set_grad_enabled(training)
    for x, y in loader:
        if training: optimizer.zero_grad()
        logits = model(x)
        loss_ce = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-100)
        if training:
            loss_kl = model.kl_loss() / len(train_loader.dataset)  # scale KL
            loss = loss_ce + loss_kl
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        else:
            loss = loss_ce
        total_loss += loss_ce.item() * y.numel()
        preds = logits.argmax(-1)
        mask = y != -100
        total_acc += (preds[mask] == y[mask]).sum().item()
        n_tok  += mask.sum().item()
    return total_loss/n_tok, total_acc/n_tok  # avg CE, accuracy
    
#--------------9. Evaluate and generate sample generations----------------------------------
def generate(prompt_tokens, max_new=30, temperature=0.5):
    ids = vocab.encode(prompt_tokens)
    K = cfg["ensemble_masks"]

    for _ in range(max_new):
        # build 4-D tensor [1, K, T]
        x = torch.tensor(ids[-cfg["max_len"]:], dtype=torch.long, device=device)
        x = x.unsqueeze(0).expand(K, -1).unsqueeze(0)  # [1, K, T]

        with torch.no_grad():
            logits = model(x)              # [1, T, V]
            # average over the ensemble dimension (K) after softmax for stability
            probs = F.softmax(logits / temperature, dim=-1)  # [1, T, V]
            ens_prob = probs.mean(1)                           # [1, V]
            tok = torch.multinomial(ens_prob, 1).item()

        ids.append(tok)

    return " ".join(vocab.itos[i] for i in ids)


def main():        
    print("Vocab size:", len(vocab.itos))
   
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    best_val, patience_left = 1e9, cfg["patience"]
    for epoch in range(1, cfg["epochs"]+1):
        train_loss, train_acc = run_epoch(train_loader, training=True)
        val_loss, val_acc     = run_epoch(dev_loader, training=False)
        print(f"Ep {epoch:02d}  "
              f"train_ppl={math.exp(train_loss):6.2f}  "
              f"train_acc={train_acc*100:5.1f}%  |  "
              f"val_ppl={math.exp(val_loss):6.2f}  "
              f"val_acc={val_acc*100:5.1f}%")
        # Early stopping
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), Path(cfg["checkpoint_dir"])/"best.pt")
            patience_left = cfg["patience"]
        else:
            patience_left -= 1
            if patience_left == 0:
                print("Early stopping.")
                break
    
    # ---------- 10.  Generation demo ----------
    model.load_state_dict(torch.load(Path(cfg["checkpoint_dir"])/"best.pt"))
    model.eval()
        
    print(f"Train: {len(train_sents)}  Dev: {len(dev_sents)}  Test: {len(test_sents)}")
    example = generate(["Moro", "Vakuru"], max_new=25)
    print("Generation:", example)
    
    
    # ---------- 11.  Save artefacts ----------
    meta = {
        "cfg": cfg,
        "itos": vocab.itos,
        "stoi": vocab.stoi,
    }
    json.dump(meta, open(Path(cfg["checkpoint_dir"])/"meta.json", "w"))
    print("Saved artefacts to", cfg["checkpoint_dir"])
if __name__ == "__main__":
    main()

Device: cuda
Model loaded successfully from checkpoint.
Vocab size: 9616
Ep 01  train_ppl=1603590.61  train_acc= 71.7%  |  val_ppl=4502089.63  val_acc= 72.2%
Ep 02  train_ppl=6929381.04  train_acc= 70.9%  |  val_ppl=4236470.76  val_acc= 71.8%
Ep 03  train_ppl=7595246.39  train_acc= 70.9%  |  val_ppl=9247161.54  val_acc= 71.5%
Ep 04  train_ppl=7320650.78  train_acc= 71.2%  |  val_ppl=6382112.91  val_acc= 72.2%
Ep 05  train_ppl=8533019.58  train_acc= 70.9%  |  val_ppl=15890548.99  val_acc= 71.2%
Ep 06  train_ppl=9199290.80  train_acc= 71.1%  |  val_ppl=8998381.31  val_acc= 72.3%
Ep 07  train_ppl=13163552.51  train_acc= 70.7%  |  val_ppl=7422958.54  val_acc= 72.9%
Early stopping.
Train: 6207  Dev: 887  Test: 1773
Generation: <UNK> Vakuru <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> vanosvika <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK>
Saved artefacts to /kaggle/working/checkpoints
