In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input directory and print the full path of every file found.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for path in (os.path.join(root, name) for name in names):
        print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
from datasets import load_dataset

# Load the TinyStories dataset from the Hugging Face Hub. Later cells read its
# "train" and "validation" splits through the global `ds`.
ds = load_dataset("roneneldan/TinyStories")

In [None]:
pip install ftfy

In [None]:
from datasets import load_dataset
import re
from ftfy import fix_text

def clean_and_normalize_text(text: str) -> str:
    """Repair mojibake in ``text`` and normalise a few leftover artefacts.

    The text is first passed through ``ftfy.fix_text``. Any known mojibake
    sequences that survive are then replaced from a fixed table.

    The table is applied longest-pattern-first. Several short patterns are
    prefixes of longer ones (``'â€'`` is a prefix of ``'â€™'``, ``'â€“'``,
    ``'â€¦'`` ...; ``'Â'`` is a prefix of ``'Â«'``), so applying the short
    pattern first would destroy the longer match and leave garbage such as
    ``'"™'`` behind.

    Args:
        text: Raw input string.

    Returns:
        The cleaned string. Whitespace is left untouched.
    """
    # Step 1: Use ftfy for general mojibake fixing
    text = fix_text(text)

    # Step 2: Replace known leftover mojibake manually
    replacements = {
        # Common double quotes (mojibake)
        'â€œ': '"',      # “
        'â€\x9d': '"',   # ” (third byte is an invisible C1 control character)
        'â€': '"',       # leftover quote fragments (applied after all longer keys)
        'Â«': '"',       # «
        'Â»': '"',       # »

        # Common single quotes / apostrophes
        'â€˜': "'",  # ‘
        'â€™': "'",  # ’
        'â€²': "'",  # ′ (prime used instead of apostrophe sometimes)
        'â€³': '"',  # ″ double prime → double quote

        # En dash / em dash
        'â€“': '-',  # –
        'â€”': '-',  # —
        'âˆ’': '-',  # − (minus symbol)

        # Ellipsis
        'â€¦': '...',  # …

        # Currency symbols (guess replacements)
        'â‚¬': '€',  # euro
        'â‚£': '£',  # pound
        'â‚¥': '¥',  # yen

        # Bullets / middots
        'â€¢': '•',  # bullet
        'Â·': '·',   # middle dot

        # Accented letter fixes (when partial ftfy fails)
        'Ã©': 'é',
        'Ã¨': 'è',
        'Ã¢': 'â',
        'Ã´': 'ô',
        'Ãà': 'à',
        'Ã\xa0': 'à',  # UTF-8 bytes of 'à' read as Latin-1: 'Ã' + no-break space
        'Ãª': 'ê',
        'Ã«': 'ë',
        'Ã¹': 'ù',
        'Ã¼': 'ü',
        'Ã¶': 'ö',
        'Ã„': 'Ä',
        'Ãœ': 'Ü',
        'Ã–': 'Ö',
        'ÃŸ': 'ß',
        'Ã±': 'ñ',

        # Occasionally seen garbage characters
        'Â': '',     # stray non-breaking space marker
        'âˆ†': '',   # delta artifact
        'âˆž': '∞',  # infinity sign
        'â„¢': '™',  # trademark
        'âš«': '',   # stray symbol
        'âœ”': '✔',  # checkmark

        # Daggers and unknown artifacts
        '†': ' ',
        '‡': ' ',
        '�': ''      # replacement char for unknown glyph
    }

    # Longest keys first; sorted() is stable, so keys of equal length keep
    # the order in which they are listed above.
    ordered = sorted(replacements.items(), key=lambda pair: len(pair[0]), reverse=True)
    for bad, good in ordered:
        text = text.replace(bad, good)

    # # Step 3: Normalize repeated/mismatched quotes (optional formatting cleanup)
    # text = re.sub(r'\s+"', '"', text)          # Remove extra leading spaces before quotes
    # text = re.sub(r'"\s+', '" ', text)         # Ensure a space after closing quotes when needed

    # # Step 4: Compact multiple spaces
    # text = re.sub(r'\s{2,}', ' ', text)

    # # Step 5: Strip leading/trailing spaces
    # text = text.strip()

    return text
# Spot-check the cleaner on one known story. The split is indexed directly
# instead of iterating over hundreds of thousands of records to reach it.
SAMPLE_INDEX = 529932
SAMPLE_PREFIX = "Once upon a time there was a small, humble dog named Mittens. She loved to go for walks in the woods,"

if SAMPLE_INDEX < len(ds["train"]):
    text = clean_and_normalize_text(ds["train"][SAMPLE_INDEX]["text"])  # Apply ftfy here

    if SAMPLE_PREFIX in text:
        print("Found at index:", SAMPLE_INDEX)
        print(text)
else:
    print(f"Sample index {SAMPLE_INDEX} is outside the train split.")


In [None]:
import os
import random
import math
import gzip
from pathlib import Path
from collections import Counter


import numpy as np
import pandas as pd


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


from datasets import load_dataset
import spacy
# from gensim.models import KeyedVectors


import matplotlib.pyplot as plt
from tqdm.auto import tqdm



# One seed for every RNG used in this notebook (Python, NumPy, PyTorch).
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print('Device:', DEVICE)




In [None]:
import os
import pickle
import spacy
from datasets import load_dataset, Dataset
from tqdm import tqdm
import multiprocessing as mp
import re
from ftfy import fix_text
import json



# Load the small English spaCy pipeline with the "ner" and "parser" components
# disabled, then keep a handle on its tokenizer; `tokenize_text` below calls
# this global `tokenizer` directly.
print("Loading spaCy model 'en_core_web_sm'...")
nlp = spacy.load("en_core_web_sm", disable=["ner", "parser"])
tokenizer = nlp.tokenizer
print("spaCy tokenizer ready.")




def tokenize_text(text):
    """Split one text into a list of token strings.

    The input is coerced with ``str()`` and run through the module-level
    spaCy ``tokenizer``; tokens flagged as whitespace are dropped.
    """
    doc = tokenizer(str(text))
    return [token.text for token in doc if not token.is_space]


# Directory under which each tokenized split is saved (one sub-folder per split).
tokenized_ds_path = "./tokenized_hf_dataset"
# Only referenced by the commented-out vocabulary-building code further down this cell.
final_combined_path = "./combined_tokenized_texts_300dim_v2.pkl"

print("\n Tokenization Begins ...")

def process_batch(batch):
    """Clean and tokenize every entry of ``batch["text"]``.

    Each raw text goes through ``clean_and_normalize_text`` and then
    ``tokenize_text``. Returns ``{"tokens": [...]}`` with one token list
    per input text, in the same order.
    """
    token_lists = []
    for raw_text in batch["text"]:
        token_lists.append(tokenize_text(clean_and_normalize_text(raw_text)))
    return {"tokens": token_lists}

# The worker count does not depend on the split, so it is computed once.
num_proc = min(4, os.cpu_count() or 1)

# Clean + tokenize each available split and save it under tokenized_ds_path/<split>.
for split in ("train", "validation"):
    if split in ds:
        print(f"\nProcessing '{split}' split...")

        tokenized_split = ds[split].map(
            process_batch,
            batched=True,
            batch_size=1000,
            num_proc=num_proc,
            desc=f"Cleaning & Tokenizing {split}"
        )
        split_save_path = os.path.join(tokenized_ds_path, split)
        tokenized_split.save_to_disk(split_save_path)
        print(f"Finished and saved '{split}' to '{split_save_path}'.")
    else:
        print(f"⚠️ Split '{split}' not found in dataset — skipping.")

# Commented section was used to make vocabulary from train and val both

# if os.path.exists(final_combined_path):
#     print(f"\n Found final combined file at '{final_combined_path}'. Nothing to do.")
# else:



#     print(f"\n Streaming tokenized splits to a single .pkl file: '{final_combined_path}'...")
    
#     total_lines = 0
#     with open(final_combined_path, "wb") as f_out:
#         for split in ["train", "validation"]:
#             split_save_path = os.path.join(tokenized_ds_path, split)
#             if not os.path.exists(split_save_path):
#                 continue
            
#             print(f"  Streaming from '{split}' split...")
#             tokenized_dataset = Dataset.load_from_disk(split_save_path)

#             for example in tqdm(tokenized_dataset, desc=f"  Pickling {split}"):
#                 tokens = example['tokens']
#                 if tokens:
#                     pickle.dump(tokens, f_out)
#                     total_lines += 1

#     # Save metadata for the vocab builder's progress bar
#     # with open(metadata_path, "w") as f_meta:
#     #     json.dump({"total_lines": total_lines}, f_meta)

#     print("\n Streaming complete.")
#     print(f"Final combined file '{final_combined_path}' created with {total_lines:,} tokenized sentences.")



In [None]:
## Training Begins

In [None]:
%%writefile train_ddp_v8.py
import os
import math
import pickle
from pathlib import Path
from functools import partial
from collections import Counter

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader, DistributedSampler

from torch.nn.utils.rnn import pad_sequence

import wandb

# Config

# Special-token ids (must agree with the saved vocabulary)
PAD_ID = 0
UNK_ID = 1
SOS_ID = 2
EOS_ID = 3

# Model shape
CONTEXT_LEN = 64
HIDDEN_DIM = 300
NUM_HEADS = 5
NUM_LAYERS = 3
DROPOUT = 0.1

# Optimisation
BATCH_SIZE = 128
LR = 3e-4
EPOCHS = 10

# Dataset 
class ChunkedSequenceDataset(Dataset):
    """Map-style dataset of fixed-size token-id chunks for next-token training.

    Every record's ``"tokens"`` list is converted to ids, wrapped in
    ``sos_id`` / ``eos_id`` and cut into windows of at most
    ``context_len + 1`` ids. The collate function later splits each window
    into ``context_len`` inputs and the same number of shifted targets.

    Consecutive windows overlap by one id (stride ``context_len``). With a
    stride of ``context_len + 1`` the first id of every window after the
    first would never appear as a prediction target.

    Args:
        hf_dataset: Iterable of dict-like records holding a ``"tokens"`` list.
        word2idx: Mapping from token string to integer id.
        context_len: Number of input positions per training example.
        sos_id: Id prepended to every record.
        eos_id: Id appended to every record.
        unk_id: Id used for tokens missing from ``word2idx``.
    """

    def __init__(self, hf_dataset, word2idx, context_len, sos_id, eos_id, unk_id):
        self.word2idx = word2idx
        self.context_len = context_len
        self.sos_id = sos_id
        self.eos_id = eos_id
        self.unk_id = unk_id

        print(f"Preparing sequences (context_len={context_len})...")
        self.chunks = self._prepare_chunks(hf_dataset)
        print(f"Created {len(self.chunks):,} chunks total.")

    def _tokens_to_ids(self, tokens):
        """Return ``[sos] + ids(tokens) + [eos]``, unknown tokens mapped to ``unk_id``."""
        return [self.sos_id] + [self.word2idx.get(tok, self.unk_id) for tok in tokens] + [self.eos_id]

    def _chunk_ids(self, ids):
        """Cut ``ids`` into windows of ``context_len + 1`` that overlap by one id."""
        window = self.context_len + 1
        chunks = []
        # Stop at len(ids) - 1 so every window holds at least one (input, target) pair.
        for start in range(0, max(1, len(ids) - 1), self.context_len):
            chunk = ids[start : start + window]
            if len(chunk) > 1:
                chunks.append(chunk)
        return chunks

    def _prepare_chunks(self, dataset):
        """Convert every non-empty record of ``dataset`` into id chunks."""
        all_chunks = []
        for item in tqdm(dataset, desc="Converting to chunks"):
            tokens = item.get("tokens", None)
            if not tokens:
                continue
            all_chunks.extend(self._chunk_ids(self._tokens_to_ids(tokens)))
        return all_chunks

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        # Each chunk is already a list[int]
        return torch.tensor(self.chunks[idx], dtype=torch.long)


# Collate Function — Handles padding + input/target split
def collate_batch(batch, pad_id, context_len):
    """Turn a list of 1-D id tensors into fixed-width model inputs.

    Every sequence is split into inputs (all ids but the last) and targets
    (all ids but the first), each cut to ``context_len``. Both sides are
    then right-padded with ``pad_id`` to exactly ``context_len`` columns.

    Returns:
        ``(inputs, targets, attention_mask)``, each of shape
        ``(len(batch), context_len)``. The mask is 1 where the input id
        differs from ``pad_id`` and 0 elsewhere.
    """
    inputs = [seq[:-1][:context_len] for seq in batch]
    targets = [seq[1:][:context_len] for seq in batch]

    def _to_context_width(seqs):
        # Pad to the longest sequence first, then force exactly context_len columns.
        padded = pad_sequence(seqs, batch_first=True, padding_value=pad_id)
        missing = context_len - padded.size(1)
        if missing > 0:
            filler = torch.full((padded.size(0), missing), pad_id, dtype=torch.long)
            padded = torch.cat([padded, filler], dim=1)
        elif missing < 0:
            padded = padded[:, :context_len]
        return padded

    padded_inputs = _to_context_width(inputs)
    padded_targets = _to_context_width(targets)

    # Attention mask: 1 = real token, 0 = pad
    attention_mask = (padded_inputs != pad_id).long()

    return padded_inputs, padded_targets, attention_mask


# Layers 
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale and shift.

    Args:
        dim: Size of the last dimension of the input.
        eps: Constant added to the variance before taking the square root.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        # x: (..., dim); statistics are taken over the last axis only.
        mu = x.mean(dim=-1, keepdim=True)
        centered = x - mu
        variance = (centered ** 2).mean(dim=-1, keepdim=True)  # biased estimate
        normalized = centered / torch.sqrt(variance + self.eps)
        return self.weight * normalized + self.bias


def sinusoidal_positional_encoding(max_len, d_model):
    """Build the fixed sine/cosine positional-encoding table.

    Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
    with ``w_k = exp(-2k * ln(10000) / d_model)``.

    Args:
        max_len: Number of positions (rows).
        d_model: Embedding width (columns). Odd widths are supported: the
            last sine column then has no cosine partner.

    Returns:
        A float32 tensor of shape ``(max_len, d_model)``.
    """
    pe = np.zeros((max_len, d_model), dtype=np.float32)
    position = np.arange(0, max_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
    angles = position * div_term  # (max_len, ceil(d_model / 2))
    pe[:, 0::2] = np.sin(angles)
    # There are only floor(d_model / 2) odd columns; trimming the angle table
    # avoids a shape-mismatch error when d_model is odd.
    pe[:, 1::2] = np.cos(angles[:, : d_model // 2])
    return torch.from_numpy(pe)  # (max_len, d_model)


class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention with optional key-padding masking.

    Args:
        d_model: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.scale = math.sqrt(self.head_dim)

        # Bias-free projections for queries, keys, values and the output.
        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, b, t):
        """Reshape ``(b, t, d_model)`` to ``(b, heads, t, head_dim)``."""
        return projected.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, padding_mask=None, output_attentions: bool = False):
        """Attend over ``x`` of shape ``(b, t, d_model)``.

        ``padding_mask`` has shape ``(b, t)``. A bool mask marks padded
        keys with True; any other dtype is read as an attention mask in
        which 0 marks padding. ``output_attentions`` is accepted but not
        consulted: the attention weights are always returned.

        Returns:
            ``(out, attn)`` with ``out`` of shape ``(b, t, d_model)`` and
            ``attn`` of shape ``(b, heads, t, t)``.
        """
        b, t, _ = x.size()

        q = self._split_heads(self.wq(x), b, t)
        k = self._split_heads(self.wk(x), b, t)
        v = self._split_heads(self.wv(x), b, t)

        # Scores are promoted to float32 so masking + softmax cannot overflow in fp16.
        scores = (torch.matmul(q, k.transpose(-2, -1)) / self.scale).float()

        # visible[..., i, j] is True when query i may look at key j (j <= i).
        visible = torch.tril(torch.ones((t, t), dtype=torch.bool, device=x.device))
        visible = visible[None, None, :, :]  # (1,1,t,t)

        if padding_mask is not None:
            if padding_mask.dtype == torch.bool:
                is_pad = padding_mask
            else:
                is_pad = padding_mask == 0
            visible = visible & (~is_pad)[:, None, None, :]  # drop padded keys
        visible = visible.expand(b, self.num_heads, t, t)

        scores = scores.masked_fill(~visible, torch.finfo(scores.dtype).min)
        weights = self.dropout(torch.softmax(scores, dim=-1)).to(v.dtype)

        mixed = torch.matmul(weights, v)  # (b, heads, t, head_dim)
        merged = mixed.transpose(1, 2).contiguous().view(b, t, self.d_model)

        return self.wo(merged), weights



class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Input and output width.
        hidden_dim: Inner width; a falsy value selects ``4 * d_model``.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, d_model, hidden_dim=None, dropout=0.1):
        super().__init__()
        inner_dim = hidden_dim if hidden_dim else 4 * d_model
        self.fc1 = nn.Linear(d_model, inner_dim)
        self.fc2 = nn.Linear(inner_dim, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        hidden = self.activation(self.fc1(x))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)


class DecoderLayer(nn.Module):
    """Pre-LayerNorm decoder block: self-attention, then feed-forward.

    Both sub-layers are wrapped in a residual connection with dropout on
    the sub-layer output.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadSelfAttention(d_model, num_heads, dropout)
        self.ln1 = LayerNorm(d_model)
        self.ff = FeedForward(d_model, hidden_dim=4 * d_model, dropout=dropout)
        self.ln2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None,output_attentions=False):
        """Return ``(x, attn_weights)``; ``output_attentions`` is accepted but not consulted."""
        # Attention sub-layer: normalise, attend, add back onto the residual stream.
        attn_out, attn_weights = self.self_attn(self.ln1(x), padding_mask=key_padding_mask)
        x = x + self.dropout(attn_out)

        # Feed-forward sub-layer, same residual pattern.
        x = x + self.dropout(self.ff(self.ln2(x)))

        return x, attn_weights


class DecoderOnlyTransformer(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings plus fixed sinusoidal position encodings are passed
    through ``num_layers`` decoder blocks, a final LayerNorm and a bias-free
    projection to vocabulary logits.

    Args:
        vocab_size: Number of rows of the embedding and of the output layer.
        d_model: Model width.
        num_layers: Number of ``DecoderLayer`` blocks.
        num_heads: Attention heads per block.
        max_len: Number of positions in the positional-encoding table.
        dropout: Dropout probability passed to each block.
        embedding_weights: Optional NumPy array copied into the token embedding.
        freeze_embeddings: When True and ``embedding_weights`` is given, the
            embedding is excluded from gradient updates.
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, max_len, dropout=0.1, embedding_weights=None, freeze_embeddings=False):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        if embedding_weights is not None:
            pretrained = torch.from_numpy(embedding_weights)
            self.token_embedding.weight.data.copy_(pretrained)
            if freeze_embeddings:
                self.token_embedding.weight.requires_grad = False

        # Registered as a buffer: stored with the model, never trained.
        self.register_buffer("position_encoding", sinusoidal_positional_encoding(max_len, d_model))

        blocks = [DecoderLayer(d_model, num_heads, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.final_ln = LayerNorm(d_model)
        self.output_linear = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None,output_attentions=False):
        """Compute logits for ``input_ids`` of shape ``(batch, seq_len)``.

        ``attention_mask`` has shape ``(batch, seq_len)`` with 1 for real
        tokens and 0 for padding.

        Returns:
            ``(logits, all_attentions)``. ``all_attentions`` holds one
            ``(b, heads, t, t)`` tensor per layer when ``output_attentions``
            is truthy and is an empty list otherwise.
        """
        b, t = input_ids.size()
        positions = self.position_encoding[:t, :].unsqueeze(0).expand(b, -1, -1)  # (b, t, d_model)
        x = self.token_embedding(input_ids) + positions

        # The layers expect True where a key is padding.
        key_padding_mask = None if attention_mask is None else (attention_mask == 0)

        all_attentions = []
        collect = bool(output_attentions)
        for layer in self.layers:
            x, layer_attn = layer(x, key_padding_mask=key_padding_mask, output_attentions=collect)
            if collect:
                all_attentions.append(layer_attn)

        logits = self.output_linear(self.final_ln(x))
        return logits, all_attentions



# Training / Main
def main():
    """Train the decoder-only transformer with DistributedDataParallel.

    Meant to be launched with ``torchrun`` (one process per GPU). Every
    process loads the vocabulary, the embedding matrix and the tokenized
    splits, builds its own shard of the data through ``DistributedSampler``
    and trains under CUDA mixed precision.

    Reported train/validation losses are averaged over all ranks. Logging to
    wandb, checkpoint writing and plotting happen on rank 0 only, so the
    processes never write to the same file concurrently.
    """
    # ---------------- Distributed setup ----------------
    dist.init_process_group("nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    is_main = rank == 0
    print(f"Device: {device} (rank {rank} of {world_size})")

    def global_mean(total, count):
        """Return sum(total) / sum(count) over all ranks (collective call)."""
        stats = torch.tensor([total, count], dtype=torch.float64, device=device)
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)
        return (stats[0] / stats[1].clamp(min=1.0)).item()

    tokenized_ds_path = "tokenized_hf_dataset"
    VOCAB_SAVE_PATH = "/kaggle/input/300dim-utils/vocab_300dim.pkl"
    EMBEDDING_MATRIX_SAVE_PATH = "/kaggle/input/300dim-utils/embedding_matrix_300dim.pkl"

    # NOTE(review): pickle.load can execute arbitrary code; only point these
    # paths at files you created yourself.
    # Load vocab
    print(f"\n Found saved vocabulary file at '{VOCAB_SAVE_PATH}'. Loading...")
    with open(VOCAB_SAVE_PATH, "rb") as f:
        vocab_data = pickle.load(f)
    word2idx = vocab_data['word2idx']
    print(f"Loaded vocabulary with {len(word2idx):,} tokens.")
    VOCAB_SIZE = len(word2idx)

    # Load embedding matrix
    print(f"\n Found saved embedding matrix file at '{EMBEDDING_MATRIX_SAVE_PATH}'. Loading...")
    with open(EMBEDDING_MATRIX_SAVE_PATH, "rb") as f:
        embedding_matrix = pickle.load(f)
    print(f"Loaded embedding matrix with shape {embedding_matrix.shape}.")
    if tuple(embedding_matrix.shape) != (VOCAB_SIZE, HIDDEN_DIM):
        raise ValueError(
            f"Embedding matrix shape {tuple(embedding_matrix.shape)} does not match "
            f"(vocab_size, hidden_dim) = {(VOCAB_SIZE, HIDDEN_DIM)}."
        )

    print(" Loading tokenized dataset...")
    from datasets import load_from_disk
    train_ds = load_from_disk(os.path.join(tokenized_ds_path, "train"))
    val_ds = load_from_disk(os.path.join(tokenized_ds_path, "validation"))

    # Build PyTorch Datasets
    print("Building Dataset")
    train_dataset = ChunkedSequenceDataset(train_ds, word2idx, CONTEXT_LEN, SOS_ID, EOS_ID, UNK_ID)
    val_dataset = ChunkedSequenceDataset(val_ds, word2idx, CONTEXT_LEN, SOS_ID, EOS_ID, UNK_ID)

    # DataLoaders
    collate_fn = partial(collate_batch, pad_id=PAD_ID, context_len=CONTEXT_LEN)

    print("Building Sampler")
    train_sampler = DistributedSampler(train_dataset, shuffle=True)
    val_sampler = DistributedSampler(val_dataset, shuffle=False)

    print("Building Loader")
    SAVE_DIR = "./checkpoints"
    os.makedirs(SAVE_DIR, exist_ok=True)
    USE_WANDB = True

    os.environ["WANDB_MODE"] = "offline"

    # Init wandb only on rank 0
    use_wandb = USE_WANDB and is_main
    if use_wandb:
        wandb.require("service")
        wandb.init(
            project="transformer-training",
            mode="offline",
            config={
                "epochs": EPOCHS,
                "lr": LR,
                "batch_size": BATCH_SIZE,
                "hidden_dim": HIDDEN_DIM,
                "num_layers": NUM_LAYERS,
                "num_heads": NUM_HEADS
            }
        )

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              sampler=train_sampler, collate_fn=collate_fn,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                            sampler=val_sampler, collate_fn=collate_fn,
                            num_workers=4, pin_memory=True)

    # Model
    print("Starting Model")
    model = DecoderOnlyTransformer(
        vocab_size=VOCAB_SIZE,
        d_model=HIDDEN_DIM,
        num_layers=NUM_LAYERS,
        num_heads=NUM_HEADS,
        max_len=CONTEXT_LEN,
        dropout=DROPOUT,
        embedding_weights=embedding_matrix,
        freeze_embeddings=True
    ).to(device)

    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    pad_idx = PAD_ID
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
    # Frozen parameters (the embedding) are kept out of the optimizer.
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR)
    scaler = torch.cuda.amp.GradScaler()

    train_losses, val_losses, perplexities = [], [], []

    # =============== TRAINING LOOP ===============
    print("Training Begins....")
    # Path to checkpoint
    CKPT_PATH = "./epoch_10.pt"  # checkpoint to resume from, if it exists
    START_EPOCH = 1  # default if no checkpoint

    if os.path.exists(CKPT_PATH):
        print(f"Loading checkpoint from {CKPT_PATH} ...")
        checkpoint = torch.load(CKPT_PATH, map_location=device)

        # Restore model parameters
        model.module.load_state_dict(checkpoint["model_state_dict"])

        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scaler.load_state_dict(checkpoint["scaler_state_dict"])
        train_losses = checkpoint.get("train_losses", [])
        val_losses = checkpoint.get("val_losses", [])
        perplexities = checkpoint.get("perplexities", [])

        # Continue from  epoch
        START_EPOCH = checkpoint["epoch"] + 1

        print(f" Resumed from epoch {checkpoint['epoch']} — continuing from epoch {START_EPOCH}")
    else:
        print(" No checkpoint found — starting training from scratch")

    for epoch in range(START_EPOCH, EPOCHS + 1):
        # Without set_epoch the sampler repeats the same shuffle every epoch.
        train_sampler.set_epoch(epoch)
        torch.cuda.reset_peak_memory_stats(device)
        model.train()
        running_loss = 0.0
        total_steps = 0

        pbar = tqdm(train_loader, total=len(train_loader),
                    desc=f"Epoch {epoch} train", disable=not is_main)

        for input_ids, target_ids, _ in pbar:
            input_ids = input_ids.to(device, non_blocking=True)
            target_ids = target_ids.to(device, non_blocking=True)
            attention_mask = (input_ids != pad_idx).long()

            with torch.cuda.amp.autocast():
                logits, _ = model(input_ids, attention_mask=attention_mask)
                logits_flat = logits.view(-1, VOCAB_SIZE)
                targets_flat = target_ids.view(-1)
                loss = criterion(logits_flat, targets_flat)

            scaler.scale(loss).backward()
            # Gradients must be unscaled before clipping by norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

            running_loss += loss.item()
            total_steps += 1
            pbar.set_postfix({'train_loss': running_loss / total_steps})

        train_loss = global_mean(running_loss, total_steps)
        train_losses.append(train_loss)

        # =============== VALIDATION ===============
        model.eval()
        val_loss_sum = 0.0
        val_steps = 0
        with torch.no_grad():
            val_bar = tqdm(val_loader, desc=f"Epoch {epoch} val", disable=not is_main)
            for input_ids, target_ids, _ in val_bar:
                input_ids = input_ids.to(device, non_blocking=True)
                target_ids = target_ids.to(device, non_blocking=True)
                attention_mask = (input_ids != pad_idx).long()

                with torch.cuda.amp.autocast():
                    # Attention maps are not needed for the loss, so they are not requested.
                    logits, _ = model(input_ids, attention_mask=attention_mask)
                    logits_flat = logits.view(-1, VOCAB_SIZE)
                    targets_flat = target_ids.view(-1)
                    loss = criterion(logits_flat, targets_flat)

                val_loss_sum += loss.item()
                val_steps += 1

        val_loss = global_mean(val_loss_sum, val_steps)
        val_losses.append(val_loss)
        perplexity = math.exp(val_loss)
        perplexities.append(perplexity)

        peak_mem = torch.cuda.max_memory_allocated(device) / (1024 ** 3)

        if is_main:
            print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, PPL={perplexity:.2f}, PeakMem={peak_mem:.2f} GB")

        if use_wandb:
            wandb.log({
                "epoch": epoch,
                "train_loss": train_loss,
                "val_loss": val_loss,
                "perplexity": perplexity,
                "peak_gpu_memory_gb": peak_mem
            })

        # =============== CHECKPOINTING ===============
        # One writer only: every rank holds the same weights after the step.
        if is_main:
            ckpt_path = os.path.join(SAVE_DIR, f"epoch_{epoch}.pt")
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.module.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scaler_state_dict": scaler.state_dict(),
                "train_losses": train_losses,
                "val_losses": val_losses,
                "perplexities": perplexities
            }, ckpt_path)

    # =============== CURVES (rank 0 only) ===============
    if is_main:
        last_epoch = max(EPOCHS, START_EPOCH - 1)

        def epoch_axis(history):
            # The history may be shorter than EPOCHS (e.g. resumed without
            # stored losses); align it with the most recent epochs.
            return range(last_epoch - len(history) + 1, last_epoch + 1)

        fig = plt.figure(figsize=(10, 6))
        plt.plot(epoch_axis(train_losses), train_losses, label="Train Loss", marker='o')
        plt.plot(epoch_axis(val_losses), val_losses, label="Validation Loss", marker='o')
        plt.title("Training and Validation Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(SAVE_DIR, "loss_curve.png"))
        plt.close(fig)

        fig = plt.figure(figsize=(8, 5))
        plt.plot(epoch_axis(perplexities), perplexities, label="Perplexity", color="purple", marker='o')
        plt.title("Validation Perplexity Over Epochs")
        plt.xlabel("Epoch")
        plt.ylabel("Perplexity")
        plt.grid(True)
        plt.savefig(os.path.join(SAVE_DIR, "perplexity_curve.png"))
        plt.close(fig)

    if use_wandb:
        wandb.finish()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()

In [None]:
!torchrun --standalone --nproc_per_node=2 train_ddp_v8.py

In [None]:
## Inference

In [None]:
## Same code as the training script: that script ran in separate torchrun/DDP processes, so its definitions do not exist in this kernel and everything is defined again here.

import os
import math
import pickle
from pathlib import Path
from functools import partial
from collections import Counter

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader, DistributedSampler

from torch.nn.utils.rnn import pad_sequence

import wandb

# Config
# Special-token ids (must agree with the saved vocabulary)
PAD_ID = 0
UNK_ID = 1
SOS_ID = 2
EOS_ID = 3

# Model shape
CONTEXT_LEN = 64
HIDDEN_DIM = 300
NUM_HEADS = 5
NUM_LAYERS = 3
DROPOUT = 0.1

# Optimisation
BATCH_SIZE = 128
LR = 3e-4
EPOCHS = 10

class ChunkedSequenceDataset(Dataset):
    """Map-style dataset of fixed-size token-id chunks for next-token training.

    Every record's ``"tokens"`` list is converted to ids, wrapped in
    ``sos_id`` / ``eos_id`` and cut into windows of at most
    ``context_len + 1`` ids. The collate function later splits each window
    into ``context_len`` inputs and the same number of shifted targets.

    Consecutive windows overlap by one id (stride ``context_len``). With a
    stride of ``context_len + 1`` the first id of every window after the
    first would never appear as a prediction target.

    Args:
        hf_dataset: Iterable of dict-like records holding a ``"tokens"`` list.
        word2idx: Mapping from token string to integer id.
        context_len: Number of input positions per training example.
        sos_id: Id prepended to every record.
        eos_id: Id appended to every record.
        unk_id: Id used for tokens missing from ``word2idx``.
    """

    def __init__(self, hf_dataset, word2idx, context_len, sos_id, eos_id, unk_id):
        self.word2idx = word2idx
        self.context_len = context_len
        self.sos_id = sos_id
        self.eos_id = eos_id
        self.unk_id = unk_id

        print(f"Preparing sequences (context_len={context_len})...")
        self.chunks = self._prepare_chunks(hf_dataset)
        print(f"Created {len(self.chunks):,} chunks total.")

    def _tokens_to_ids(self, tokens):
        """Return ``[sos] + ids(tokens) + [eos]``, unknown tokens mapped to ``unk_id``."""
        return [self.sos_id] + [self.word2idx.get(tok, self.unk_id) for tok in tokens] + [self.eos_id]

    def _chunk_ids(self, ids):
        """Cut ``ids`` into windows of ``context_len + 1`` that overlap by one id."""
        window = self.context_len + 1
        chunks = []
        # Stop at len(ids) - 1 so every window holds at least one (input, target) pair.
        for start in range(0, max(1, len(ids) - 1), self.context_len):
            chunk = ids[start : start + window]
            if len(chunk) > 1:
                chunks.append(chunk)
        return chunks

    def _prepare_chunks(self, dataset):
        """Convert every non-empty record of ``dataset`` into id chunks."""
        all_chunks = []
        for item in tqdm(dataset, desc="Converting to chunks"):
            tokens = item.get("tokens", None)
            if not tokens:
                continue
            all_chunks.extend(self._chunk_ids(self._tokens_to_ids(tokens)))
        return all_chunks

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        # Each chunk is already a list[int]
        return torch.tensor(self.chunks[idx], dtype=torch.long)


# Collate Function — Handles padding + input/target split
def collate_batch(batch, pad_id, context_len):
    """Turn a list of 1-D id tensors into fixed-width model inputs.

    Every sequence is split into inputs (all ids but the last) and targets
    (all ids but the first), each cut to ``context_len``. Both sides are
    then right-padded with ``pad_id`` to exactly ``context_len`` columns.

    Returns:
        ``(inputs, targets, attention_mask)``, each of shape
        ``(len(batch), context_len)``. The mask is 1 where the input id
        differs from ``pad_id`` and 0 elsewhere.
    """
    inputs = [seq[:-1][:context_len] for seq in batch]
    targets = [seq[1:][:context_len] for seq in batch]

    def _to_context_width(seqs):
        # Pad to the longest sequence first, then force exactly context_len columns.
        padded = pad_sequence(seqs, batch_first=True, padding_value=pad_id)
        missing = context_len - padded.size(1)
        if missing > 0:
            filler = torch.full((padded.size(0), missing), pad_id, dtype=torch.long)
            padded = torch.cat([padded, filler], dim=1)
        elif missing < 0:
            padded = padded[:, :context_len]
        return padded

    padded_inputs = _to_context_width(inputs)
    padded_targets = _to_context_width(targets)

    # Attention mask: 1 = real token, 0 = pad
    attention_mask = (padded_inputs != pad_id).long()

    return padded_inputs, padded_targets, attention_mask


# Layers 
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale and shift.

    Args:
        dim: Size of the last dimension of the input.
        eps: Constant added to the variance before taking the square root.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        # x: (..., dim); statistics are taken over the last axis only.
        mu = x.mean(dim=-1, keepdim=True)
        centered = x - mu
        variance = (centered ** 2).mean(dim=-1, keepdim=True)  # biased estimate
        normalized = centered / torch.sqrt(variance + self.eps)
        return self.weight * normalized + self.bias


def sinusoidal_positional_encoding(max_len, d_model):
    """Build the fixed sine/cosine positional-encoding table.

    Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
    with ``w_k = exp(-2k * ln(10000) / d_model)``.

    Args:
        max_len: Number of positions (rows).
        d_model: Embedding width (columns). Odd widths are supported: the
            last sine column then has no cosine partner.

    Returns:
        A float32 tensor of shape ``(max_len, d_model)``.
    """
    pe = np.zeros((max_len, d_model), dtype=np.float32)
    position = np.arange(0, max_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
    angles = position * div_term  # (max_len, ceil(d_model / 2))
    pe[:, 0::2] = np.sin(angles)
    # There are only floor(d_model / 2) odd columns; trimming the angle table
    # avoids a shape-mismatch error when d_model is odd.
    pe[:, 1::2] = np.cos(angles[:, : d_model // 2])
    return torch.from_numpy(pe)  # (max_len, d_model)


class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention with optional key-padding masking.

    Args:
        d_model: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.scale = math.sqrt(self.head_dim)

        # Bias-free projections for queries, keys, values and the output.
        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, b, t):
        """Reshape ``(b, t, d_model)`` to ``(b, heads, t, head_dim)``."""
        return projected.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, padding_mask=None, output_attentions: bool = False):
        """Attend over ``x`` of shape ``(b, t, d_model)``.

        ``padding_mask`` has shape ``(b, t)``. A bool mask marks padded
        keys with True; any other dtype is read as an attention mask in
        which 0 marks padding. ``output_attentions`` is accepted but not
        consulted: the attention weights are always returned.

        Returns:
            ``(out, attn)`` with ``out`` of shape ``(b, t, d_model)`` and
            ``attn`` of shape ``(b, heads, t, t)``.
        """
        b, t, _ = x.size()

        q = self._split_heads(self.wq(x), b, t)
        k = self._split_heads(self.wk(x), b, t)
        v = self._split_heads(self.wv(x), b, t)

        # Scores are promoted to float32 so masking + softmax cannot overflow in fp16.
        scores = (torch.matmul(q, k.transpose(-2, -1)) / self.scale).float()

        # visible[..., i, j] is True when query i may look at key j (j <= i).
        visible = torch.tril(torch.ones((t, t), dtype=torch.bool, device=x.device))
        visible = visible[None, None, :, :]  # (1,1,t,t)

        if padding_mask is not None:
            if padding_mask.dtype == torch.bool:
                is_pad = padding_mask
            else:
                is_pad = padding_mask == 0
            visible = visible & (~is_pad)[:, None, None, :]  # drop padded keys
        visible = visible.expand(b, self.num_heads, t, t)

        scores = scores.masked_fill(~visible, torch.finfo(scores.dtype).min)
        weights = self.dropout(torch.softmax(scores, dim=-1)).to(v.dtype)

        mixed = torch.matmul(weights, v)  # (b, heads, t, head_dim)
        merged = mixed.transpose(1, 2).contiguous().view(b, t, self.d_model)

        return self.wo(merged), weights



class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Size of the input and output features.
        hidden_dim: Width of the inner layer; falls back to ``4 * d_model``
            when omitted (or otherwise falsy).
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, d_model, hidden_dim=None, dropout=0.1):
        super().__init__()
        if not hidden_dim:
            hidden_dim = 4 * d_model
        # Attribute names and creation order are kept stable: they determine
        # the state_dict keys and the order in which weights are initialised.
        self.fc1 = nn.Linear(d_model, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        hidden = self.fc1(x)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.fc2(hidden)


class DecoderLayer(nn.Module):
    """One pre-LayerNorm decoder block: self-attention then feed-forward,
    each wrapped in a residual connection with dropout on the branch output.

    Args:
        d_model: Feature size of the residual stream.
        num_heads: Number of attention heads.
        dropout: Dropout probability shared by attention, feed-forward and
            the residual branches.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        # Sub-module names and creation order are preserved so checkpoints
        # and seeded initialisation stay compatible.
        self.self_attn = MultiHeadSelfAttention(d_model, num_heads, dropout)
        self.ln1 = LayerNorm(d_model)
        self.ff = FeedForward(d_model, hidden_dim=4 * d_model, dropout=dropout)
        self.ln2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None, output_attentions=False):
        """Run the block on ``x``.

        ``key_padding_mask`` is forwarded to the attention module as its
        ``padding_mask``. ``output_attentions`` is accepted but not consulted:
        the attention weights are always returned.

        Returns:
            Tuple ``(x, attn_weights)``.
        """
        # Attention branch: normalise first, then add back onto the residual.
        attn_out, attn_weights = self.self_attn(self.ln1(x), padding_mask=key_padding_mask)
        x = x + self.dropout(attn_out)

        # Feed-forward branch, same pre-LN residual pattern.
        x = x + self.dropout(self.ff(self.ln2(x)))

        return x, attn_weights


class DecoderOnlyTransformer(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings plus a fixed sinusoidal positional table feed a stack of
    ``DecoderLayer`` blocks, followed by a final LayerNorm and a bias-free
    linear projection to vocabulary logits.

    Args:
        vocab_size: Number of rows in the embedding table / output logits.
        d_model: Embedding and hidden size.
        num_layers: Number of decoder layers.
        num_heads: Attention heads per layer.
        max_len: Number of positions in the positional-encoding table; inputs
            longer than this are rejected in ``forward``.
        dropout: Dropout probability passed to every layer.
        embedding_weights: Optional pretrained embedding matrix of shape
            (vocab_size, d_model). Anything ``torch.as_tensor`` accepts
            (NumPy array, tensor, nested list) is allowed.
        freeze_embeddings: If True *and* ``embedding_weights`` is given, the
            embedding table is excluded from gradient updates.

    Raises:
        ValueError: If ``embedding_weights`` does not have shape
            (vocab_size, d_model).
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, max_len, dropout=0.1, embedding_weights=None, freeze_embeddings=False):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        if embedding_weights is not None:
            pretrained = torch.as_tensor(embedding_weights)
            # copy_ would silently broadcast a wrongly shaped matrix, so check first.
            if tuple(pretrained.shape) != (vocab_size, d_model):
                raise ValueError(
                    f"embedding_weights has shape {tuple(pretrained.shape)}, "
                    f"expected ({vocab_size}, {d_model})"
                )
            with torch.no_grad():
                self.token_embedding.weight.copy_(pretrained)
            if freeze_embeddings:
                self.token_embedding.weight.requires_grad_(False)

        pe = sinusoidal_positional_encoding(max_len, d_model)  # (max_len, d_model)
        self.register_buffer("position_encoding", pe)  # persistent buffer

        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, dropout) for _ in range(num_layers)])
        self.final_ln = LayerNorm(d_model)
        self.output_linear = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, output_attentions=False):
        """Compute next-token logits.

        Args:
            input_ids: Token ids of shape (batch, seq_len).
            attention_mask: Optional (batch, seq_len) mask with 1 for real
                tokens and 0 for padding.
            output_attentions: If True, collect each layer's attention weights.

        Returns:
            Tuple ``(logits, all_attentions)``: logits are
            (batch, seq_len, vocab_size); ``all_attentions`` is a list with
            one (b, heads, t, t) tensor per layer, or an empty list when
            ``output_attentions`` is False.

        Raises:
            ValueError: If seq_len exceeds the positional-encoding table.
        """
        b, t = input_ids.size()
        max_positions = self.position_encoding.size(0)
        if t > max_positions:
            # Without this check the addition below fails with an opaque broadcast error.
            raise ValueError(
                f"Sequence length {t} exceeds the positional-encoding table size {max_positions}"
            )

        tok_emb = self.token_embedding(input_ids)  # (b, t, d_model)
        pos_emb = self.position_encoding[:t, :].unsqueeze(0)  # (1, t, d_model), broadcasts over batch
        x = tok_emb + pos_emb  # (b, t, d_model)

        # Layers expect True where a key position is padding.
        key_padding_mask = None if attention_mask is None else (attention_mask == 0)

        all_attentions = []
        for layer in self.layers:
            x, layer_attn = layer(x, key_padding_mask=key_padding_mask, output_attentions=output_attentions)
            if output_attentions:
                all_attentions.append(layer_attn)  # (b, heads, t, t)

        x = self.final_ln(x)
        logits = self.output_linear(x)
        return logits, all_attentions




In [None]:
import torch, pickle, os
from datasets import load_from_disk

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): pickle.load / torch.load can execute arbitrary code embedded in
# the file. Only load artifacts that you produced yourself or otherwise trust.

# === Load vocabulary ===
with open("/kaggle/input/300dim-utils/vocab_300dim.pkl", "rb") as f:
    vocab_data = pickle.load(f)
word2idx = vocab_data['word2idx']
idx2word = vocab_data['idx2word']

VOCAB_SIZE = len(word2idx)
print(f"Loaded vocab of size {VOCAB_SIZE}")

# === Load embedding matrix ===
with open("/kaggle/input/300dim-utils/embedding_matrix_300dim.pkl", "rb") as f:
    embedding_matrix = pickle.load(f)
print(f"Loaded embedding matrix shape: {embedding_matrix.shape}")

# === Initialize model ===
# HIDDEN_DIM, NUM_LAYERS, NUM_HEADS, CONTEXT_LEN and DROPOUT come from an
# earlier configuration cell and must match the values used for training.
model = DecoderOnlyTransformer(
    vocab_size=VOCAB_SIZE,
    d_model=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    max_len=CONTEXT_LEN,
    dropout=DROPOUT,
    embedding_weights=embedding_matrix,
    freeze_embeddings=True
).to(DEVICE)

# === Load checkpoint (non-DDP) ===
ckpt = torch.load("/kaggle/working/checkpoints/epoch_10.pt", map_location=DEVICE)
# strict=False tolerates key mismatches, so report them explicitly; otherwise
# a mismatched checkpoint would leave parts of the model randomly initialised
# without any visible sign.
load_result = model.load_state_dict(ckpt["model_state_dict"], strict=False)
if load_result.missing_keys:
    print(f"WARNING: parameters missing from checkpoint (left at init values): {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"WARNING: checkpoint entries not used by the model: {load_result.unexpected_keys}")
model.eval()

print(f" Loaded model from epoch {ckpt['epoch']}")


In [None]:
!pip install evaluate


In [None]:
# Confirm that the `evaluate` package imports in this kernel and show which
# version is in use.
import evaluate
print("Evaluate version:", evaluate.__version__)


In [None]:



import random
import math
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import evaluate  # you already used evaluate.load("bleu")
bleu = evaluate.load("bleu")


@torch.no_grad()
def generate(model, prompt_text, tokenizer, word2idx, idx2word,
             max_new_tokens=50, temperature=1.0, top_k=50, eos_token="[EOS]",
             max_total_len=64):
    """
    Auto-regressive top-k sampling for DecoderOnlyTransformer.

    Works in id-space and does NOT re-tokenize generated output.

    Args:
        model: Module called as ``model(input_ids, attention_mask=...)`` and
            returning ``(logits, _)`` with logits of shape (1, seq_len, vocab).
            It is switched to eval mode.
        prompt_text: Prompt string.
        tokenizer: Callable returning tokens exposing ``.text`` and
            ``.is_space`` (e.g. a spaCy pipeline).
        word2idx: Token -> id mapping. "[UNK]", "[PAD]" and "[SOS]" are looked
            up with a fallback id of 0.
        idx2word: Id -> token mapping.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Logit divisor. Values <= 0 select greedy (argmax)
            decoding instead of dividing by zero.
        top_k: Sample from the k most likely tokens. ``None`` or a value <= 0
            samples from the full vocabulary.
        eos_token: Token string that ends generation once produced.
        max_total_len: Generation stops once prompt + generated tokens reach
            this length.

    Returns:
        list[str]: The generated tokens only (prompt excluded). When the EOS
        token is produced it is included as the last element.
    """
    model.eval()
    device = next(model.parameters()).device

    unk_id = word2idx.get("[UNK]", 0)
    pad_id = word2idx.get("[PAD]", 0)

    # Tokenize prompt with the same tokenizer that was used to build the vocab.
    prompt_ids = [word2idx.get(tok.text, unk_id)
                  for tok in tokenizer(prompt_text) if not tok.is_space]
    if not prompt_ids:
        # Avoid feeding an empty sequence to the model.
        prompt_ids = [word2idx.get("[SOS]", 0)]

    greedy = float(temperature) <= 0.0
    use_full_vocab = top_k is None or int(top_k) <= 0

    all_ids = list(prompt_ids)
    generated_ids = []

    for _ in range(max_new_tokens):
        # Enforce the maximum total length (position-encoding limit).
        if len(all_ids) >= max_total_len:
            print(f"Reached max_total_len={max_total_len}. Stopping generation.")
            break

        input_ids = torch.tensor([all_ids], dtype=torch.long, device=device)  # (1, seq_len)
        attention_mask = (input_ids != pad_id).long()  # 1: real, 0: pad

        logits, _ = model(input_ids, attention_mask=attention_mask)  # (1, seq_len, vocab)
        next_token_logits = logits[:, -1, :]  # (1, vocab)

        if greedy:
            next_token = int(next_token_logits.argmax(dim=-1).item())
        else:
            if temperature != 1.0:
                next_token_logits = next_token_logits / float(temperature)

            vocab = next_token_logits.size(-1)
            k = vocab if use_full_vocab else min(int(top_k), vocab)
            topk_vals, topk_idx = torch.topk(next_token_logits, k=k, dim=-1)  # (1, k)
            probs = F.softmax(topk_vals, dim=-1)  # (1, k)
            choice = torch.multinomial(probs[0], num_samples=1).item()  # index into the top-k
            next_token = int(topk_idx[0, choice].item())  # actual vocab id

        all_ids.append(next_token)
        generated_ids.append(next_token)

        # Stop on EOS.
        if idx2word.get(next_token, "") == eos_token:
            break

    return [idx2word.get(i, "[UNK]") for i in generated_ids]


In [None]:
import time
import math
import torch
import torch.nn.functional as F
import numpy as np
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import evaluate

bleu = evaluate.load("bleu")

# ---------- Beam search generator ----------
@torch.no_grad()
def beam_search_generate(
        model, prompt_text, tokenizer, word2idx, idx2word,
        beam_size=5, max_new_tokens=50, eos_token="[EOS]",
        max_total_len=64, length_penalty=1.0):
    """Beam-search decoding for DecoderOnlyTransformer.

    Beams are ranked by ``logprob_sum / len(ids) ** length_penalty`` where
    ``len(ids)`` counts prompt and generated tokens. The same score is used
    both for pruning and for choosing the final hypothesis.

    Args:
        model: Module called as ``model(input_ids, attention_mask=...)`` and
            returning ``(logits, _)``. It is switched to eval mode.
        prompt_text: Prompt string.
        tokenizer: Callable returning tokens exposing ``.text`` and
            ``.is_space``.
        word2idx: Token -> id mapping.
        idx2word: Id -> token mapping.
        beam_size: Number of hypotheses kept per step (must be >= 1).
        max_new_tokens: Maximum number of expansion steps.
        eos_token: Token string that marks a hypothesis as finished.
        max_total_len: Hypotheses reaching this total length are frozen.
        length_penalty: Exponent of the length normalisation; 0 disables it.

    Returns:
        list[str]: Generated tokens of the best hypothesis, without the prompt
        and with "[SOS]", "[PAD]", "[EOS]" and ``eos_token`` removed.

    Raises:
        ValueError: If ``beam_size`` is smaller than 1.
    """
    if beam_size < 1:
        raise ValueError(f"beam_size must be >= 1, got {beam_size}")

    model.eval()
    device = next(model.parameters()).device

    unk_id = word2idx.get("[UNK]", 0)
    prompt_ids = [word2idx.get(tok.text, unk_id)
                  for tok in tokenizer(prompt_text) if not tok.is_space]
    if not prompt_ids:
        prompt_ids = [word2idx.get("[SOS]", 0)]

    eos_id = word2idx.get(eos_token, None)
    pad_id = word2idx.get("[PAD]", 0)

    def normalized(ids, raw_score):
        # Length-normalised score; the floor guards against division by ~0.
        return raw_score / max(len(ids) ** length_penalty, 1e-6)

    def rank(beam):
        return normalized(beam[0], beam[1])

    # Beam = (token_ids, raw_logprob_sum, finished_flag)
    beams = [(prompt_ids.copy(), 0.0, False)]

    for _ in range(max_new_tokens):
        # Stop once every surviving beam has finished.
        if all(done for _, _, done in beams):
            break

        candidates = []
        for ids, score, finished in beams:
            # Finished or length-capped beams are carried over unchanged.
            if finished or len(ids) >= max_total_len:
                candidates.append((ids, score, True))
                continue

            input_ids = torch.tensor([ids], dtype=torch.long, device=device)
            attn_mask = (input_ids != pad_id).long()

            logits, _ = model(input_ids, attention_mask=attn_mask)
            log_probs = F.log_softmax(logits[:, -1, :], dim=-1)[0]

            # Bound k by the real logits width, not by len(idx2word).
            k = min(beam_size, log_probs.size(-1))
            vals, idxs = torch.topk(log_probs, k=k)

            for logp, tok_id in zip(vals.tolist(), idxs.tolist()):
                candidates.append((ids + [tok_id], score + logp, tok_id == eos_id))

        # Keep the best `beam_size` hypotheses by normalised score.
        beams = sorted(candidates, key=rank, reverse=True)[:beam_size]

    # Prefer finished hypotheses; rank them with the same normalised score
    # used during pruning so that length_penalty also affects the final pick.
    pool = [beam for beam in beams if beam[2]] or beams
    best_ids, _, _ = max(pool, key=rank)

    special_tokens = {"[SOS]", "[PAD]", "[EOS]", eos_token}
    gen_tokens = [idx2word.get(i, "[UNK]") for i in best_ids[len(prompt_ids):]]
    return [tok for tok in gen_tokens if tok not in special_tokens]



In [None]:
prompt = "Once upon a time, there was an island"

# Beam-search decoding of the prompt.
out = beam_search_generate(
    model=model,
    prompt_text=prompt,
    tokenizer=tokenizer,
    word2idx=word2idx,
    idx2word=idx2word,
    beam_size=5,
    max_new_tokens=40,
)
print("Generated:", " ".join(out))

# Top-k sampling for the same prompt.
generated_text = generate(
    model, prompt, tokenizer, word2idx, idx2word,
    max_new_tokens=40, temperature=0.90, top_k=50,
)
print(f"Prompt: {prompt}")
print(f"Generated: {' '.join(generated_text)}")
