In [13]:
import math
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


In [15]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)

class EmbeddingsWithPositionalEncoding(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len=max_len, dropout=dropout)
        self.d_model = d_model

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # token_ids: (B, T)
        x = self.token_embedding(token_ids) * math.sqrt(self.d_model)
        x = self.pos_enc(x)
        return x  # (B, T, d_model)

In [16]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, attn_dropout: float = 0.0):
        super().__init__()
        self.dropout = nn.Dropout(attn_dropout) if attn_dropout > 0 else nn.Identity()

    def forward(self,
                Q: torch.Tensor,
                K: torch.Tensor,
                V: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                return_attn: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Q,K,V: (B, H, T_q, d_k), (B, H, T_k, d_k), (B, H, T_k, d_k)
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)  # (B, H, T_q, T_k)

        if mask is not None:
            # mask: broadcastable to (B, 1, T_q, T_k) or (B, H, T_q, T_k)
            if mask.dtype != torch.bool:
                mask = mask.to(torch.bool)
            scores = scores.masked_fill(~mask, -1e9)

        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, V)  # (B, H, T_q, d_k)
        return (out, attn) if return_attn else (out, None)

In [17]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, attn_dropout: float = 0.0, proj_dropout: float = 0.0):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.in_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.attn_fn = ScaledDotProductAttention(attn_dropout)
        self.proj_dropout = nn.Dropout(proj_dropout) if proj_dropout > 0 else nn.Identity()

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, d_model) -> (B, H, T, d_k)
        B, T, _ = x.shape
        return x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)

    def _combine_heads(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, H, T, d_k) -> (B, T, d_model)
        B, H, T, d_k = x.shape
        return x.transpose(1, 2).contiguous().view(B, T, H * d_k)

    def forward(self, x_q: torch.Tensor, x_kv: Optional[torch.Tensor] = None,
                mask: Optional[torch.Tensor] = None, return_attn: bool = False):
        # x_q: (B, T_q, d_model); x_kv: (B, T_k, d_model) or None
        if x_kv is None:
            x_kv = x_q

        if x_q is x_kv:
            qkv = self.in_proj(x_q)  # (B, T, 3*d_model)
            q, k, v = qkv.split(self.d_model, dim=-1)
        else:
            q = self.in_proj(x_q)[:, :, :self.d_model]
            kv = self.in_proj(x_kv)[:, :, self.d_model:]
            k, v = kv.split(self.d_model, dim=-1)

        Q = self._split_heads(q)  # (B, H, T_q, d_k)
        K = self._split_heads(k)  # (B, H, T_k, d_k)
        V = self._split_heads(v)  # (B, H, T_k, d_k)

        if mask is not None:
            if mask.dtype != torch.bool:
                mask = mask.to(torch.bool)
            mask = mask.to(Q.device)

        attn_out, attn_weights = self.attn_fn(Q, K, V, mask=mask, return_attn=return_attn)  # (B, H, T_q, d_k)
        combined = self._combine_heads(attn_out)  # (B, T_q, d_model)
        out = self.out_proj(combined)
        out = self.proj_dropout(out)
        return (out, attn_weights) if return_attn else (out, None)


In [18]:
def make_padding_mask(pad_mask: torch.Tensor) -> torch.Tensor:
    if pad_mask.dtype != torch.bool:
        pad_mask = pad_mask != 0
    return pad_mask.unsqueeze(1).unsqueeze(2)  # (B,1,1,T)

def make_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=device))
    return causal.unsqueeze(0).unsqueeze(0)  # (1,1,T,T)

def combine_masks(*masks: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    out = masks[0]
    for m in masks[1:]:
        out = out & m
    return out

In [19]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, attn_dropout=dropout, proj_dropout=dropout)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop1 = nn.Dropout(dropout)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, src_mask: Optional[torch.Tensor] = None):
        attn_out, _ = self.self_attn(x, x, mask=src_mask)
        x = x + self.drop1(attn_out)
        x = self.norm1(x)
        ffn_out = self.ffn(x)
        x = x + self.drop2(ffn_out)
        x = self.norm2(x)
        return x

class Encoder(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, n_layers: int, n_heads: int, d_ff: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.embed = EmbeddingsWithPositionalEncoding(vocab_size, d_model, max_len=max_len, dropout=dropout)
        self.layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, src_ids: torch.Tensor, src_mask: Optional[torch.Tensor] = None):
        x = self.embed(src_ids)  # (B, T, d_model)
        for layer in self.layers:
            x = layer(x, src_mask)
        x = self.final_norm(x)
        return x  # (B, T, d_model)

In [20]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, attn_dropout=dropout, proj_dropout=dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, attn_dropout=dropout, proj_dropout=dropout)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.drop1 = nn.Dropout(dropout)
        self.drop2 = nn.Dropout(dropout)
        self.drop3 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, enc_output: torch.Tensor,
                tgt_mask: Optional[torch.Tensor] = None, memory_mask: Optional[torch.Tensor] = None):
        self_attn_out, _ = self.self_attn(x, x, mask=tgt_mask)
        x = x + self.drop1(self_attn_out)
        x = self.norm1(x)

        cross_attn_out, _ = self.cross_attn(x, enc_output, mask=memory_mask)
        x = x + self.drop2(cross_attn_out)
        x = self.norm2(x)

        ffn_out = self.ffn(x)
        x = x + self.drop3(ffn_out)
        x = self.norm3(x)
        return x

class Decoder(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, n_layers: int, n_heads: int, d_ff: int, max_len: int = 5000, dropout: float = 0.3, tie_embeddings: bool = False):
        super().__init__()
        self.embed = EmbeddingsWithPositionalEncoding(vocab_size, d_model, max_len=max_len, dropout=dropout)
        self.layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.final_norm = nn.LayerNorm(d_model)
        self.output_proj = nn.Linear(d_model, vocab_size)
        if tie_embeddings:
            self.output_proj.weight = self.embed.token_embedding.weight
        self.d_model = d_model

    def forward(self, tgt_ids: torch.Tensor, enc_output: torch.Tensor,
                tgt_mask: Optional[torch.Tensor] = None, memory_mask: Optional[torch.Tensor] = None):
        x = self.embed(tgt_ids)  # (B,T,d_model)
        for layer in self.layers:
            x = layer(x, enc_output, tgt_mask=tgt_mask, memory_mask=memory_mask)
        x = self.final_norm(x)
        logits = self.output_proj(x)  # (B,T,vocab)
        return logits

In [21]:
class Seq2SeqTransformer(nn.Module):
    def __init__(self, src_vocab: int, tgt_vocab: int, src_pad_idx: int, tgt_pad_idx: int,
                 d_model: int = 512, n_layers: int = 6, n_heads: int = 8, d_ff: int = 2048, max_len: int = 5000, dropout: float = 0.1, tie_embeddings: bool = False):
        super().__init__()
        self.encoder = Encoder(src_vocab, d_model, n_layers, n_heads, d_ff, max_len=max_len, dropout=dropout)
        self.decoder = Decoder(tgt_vocab, d_model, n_layers, n_heads, d_ff, max_len=max_len, dropout=dropout, tie_embeddings=tie_embeddings)
        self.src_pad_idx = src_pad_idx
        self.tgt_pad_idx = tgt_pad_idx

    def make_src_mask(self, src_ids: torch.Tensor) -> torch.Tensor:
        # src_ids: (B, S)
        return (src_ids != self.src_pad_idx).unsqueeze(1).unsqueeze(2)  # (B,1,1,S)

    def make_tgt_mask(self, tgt_ids: torch.Tensor) -> torch.Tensor:
        batch_size, tgt_len = tgt_ids.size()
        pad_mask = (tgt_ids != self.tgt_pad_idx).unsqueeze(1).unsqueeze(3)  # (B,1,T,1)
        causal = make_causal_mask(tgt_len, device=tgt_ids.device)  # (1,1,T,T)
        return pad_mask & causal  # (B,1,T,T)

    def forward(self, src_ids: torch.Tensor, tgt_ids: torch.Tensor):
        src_mask = self.make_src_mask(src_ids)  # (B,1,1,S)
        tgt_mask = self.make_tgt_mask(tgt_ids)  # (B,1,T,T)
        memory_mask = src_mask

        enc_output = self.encoder(src_ids, src_mask)
        logits = self.decoder(tgt_ids, enc_output, tgt_mask=tgt_mask, memory_mask=memory_mask)  # (B,T,V)
        return logits

In [22]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")

if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})

PAD_ID = tokenizer.pad_token_id
VOCAB_SIZE = len(tokenizer)

dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

MAX_LEN = 128
BATCH_SIZE = 32

def tokenize_batch(examples):
    out = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=MAX_LEN)
    return {"input_ids": out["input_ids"]}

tokenized = dataset.map(tokenize_batch, batched=True, remove_columns=["text"])

tokenized.set_format(type="torch", columns=["input_ids"])

train_loader = DataLoader(tokenized["train"], batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(tokenized["validation"], batch_size=BATCH_SIZE, shuffle=False)

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

In [23]:
def evaluate(model: nn.Module, data_loader: DataLoader, criterion, device):
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for batch in data_loader:
            # batch is a dict-like row from datasets with "input_ids" tensor shaped (B, L)
            input_ids = batch["input_ids"].to(device)
            tgt_input = input_ids[:, :-1]
            tgt_output = input_ids[:, 1:].contiguous().view(-1)

            logits = model(input_ids, tgt_input)  # (B, T-1, V)
            logits = logits.view(-1, logits.size(-1))
            loss = criterion(logits, tgt_output)  # summed loss over non-ignored tokens

            non_pad = (tgt_output != PAD_ID).sum().item()
            if non_pad == 0:
                continue

            total_loss += loss.item()
            total_tokens += non_pad

    avg_loss = total_loss / total_tokens
    return avg_loss, math.exp(avg_loss)


def train_loop(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader,
               epochs: int = 6, lr: float = 3e-4, device: torch.device = device, prev_val: float = float('inf'), count: int = 0):
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID, label_smoothing=0.1, reduction='sum')  # sum so we can divide by token count

    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        running_tokens = 0
        for step, batch in enumerate(train_loader, start=1):
            input_ids = batch["input_ids"].to(device)
            tgt_input = input_ids[:, :-1]
            tgt_output = input_ids[:, 1:].contiguous().view(-1)

            optimizer.zero_grad()
            logits = model(input_ids, tgt_input)  # (B, T-1, V)
            logits_flat = logits.view(-1, logits.size(-1))
            loss = criterion(logits_flat, tgt_output)  # summed loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            non_pad = (tgt_output != PAD_ID).sum().item()
            if non_pad > 0:
                running_loss += loss.item()
                running_tokens += non_pad

            if step % 200 == 0:
                cur_avg = (running_loss / running_tokens) if running_tokens > 0 else float("nan")
                print(f"Epoch {epoch} Step {step} | partial train_loss_tokavg={cur_avg:.6f}")

        train_loss = running_loss / running_tokens
        train_ppl = math.exp(train_loss)

        val_loss, val_ppl = evaluate(model, val_loader, criterion, device)
        print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} PPL: {train_ppl:.2f} | Val Loss: {val_loss:.4f} PPL: {val_ppl:.2f}")
        if prev_val > val_loss: prev_val = val_loss
        elif prev_val == val_loss:
          count += 1
          if count > 2: print("same val_loss repeated for 3 epochs!!!"); break
        else: print("Overfitting!!!"); break


In [24]:
D_MODEL = 512
N_LAYERS = 4
N_HEADS = 8
D_FF = 2048

model = Seq2SeqTransformer(
    src_vocab=VOCAB_SIZE,
    tgt_vocab=VOCAB_SIZE,
    src_pad_idx=PAD_ID,
    tgt_pad_idx=PAD_ID,
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    d_ff=D_FF,
    max_len=MAX_LEN,
    dropout=0.1,
    tie_embeddings=False
).to(device)

batch = next(iter(train_loader))
input_ids = batch["input_ids"]
print("Smoke batch shapes:", input_ids.shape)

train_loop(model, train_loader, val_loader, epochs=10, lr=3e-4, device=device)

Smoke batch shapes: torch.Size([32, 128])
Epoch 1 Step 200 | partial train_loss_tokavg=7.688379
Epoch 1 Step 400 | partial train_loss_tokavg=7.395458
Epoch 1 Step 600 | partial train_loss_tokavg=7.205550
Epoch 1 Step 800 | partial train_loss_tokavg=7.064069
Epoch 1 Step 1000 | partial train_loss_tokavg=6.946462
Epoch 01 | Train Loss: 6.8726 PPL: 965.44 | Val Loss: 6.4795 PPL: 651.67
Epoch 2 Step 200 | partial train_loss_tokavg=6.079474
Epoch 2 Step 400 | partial train_loss_tokavg=6.047121
Epoch 2 Step 600 | partial train_loss_tokavg=6.005047
Epoch 2 Step 800 | partial train_loss_tokavg=5.970900
Epoch 2 Step 1000 | partial train_loss_tokavg=5.933210
Epoch 02 | Train Loss: 5.9103 PPL: 368.83 | Val Loss: 6.0728 PPL: 433.91
Epoch 3 Step 200 | partial train_loss_tokavg=5.402452
Epoch 3 Step 400 | partial train_loss_tokavg=5.391234
Epoch 3 Step 600 | partial train_loss_tokavg=5.389658
Epoch 3 Step 800 | partial train_loss_tokavg=5.380993
Epoch 3 Step 1000 | partial train_loss_tokavg=5.364440

In [25]:
@torch.no_grad()
def generate_text(model: nn.Module, tokenizer, prompt: str, max_len: int = 50, temperature: float = 2.0, top_k: int = 50, device: torch.device = device):
    model.eval()

    src_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    enc_output = model.encoder(src_ids)
    generated = torch.tensor([[tokenizer.pad_token_id]], device=device)

    for _ in range(max_len):
        tgt_mask = model.make_tgt_mask(generated)
        memory_mask = model.make_src_mask(src_ids)

        logits = model.decoder(generated, enc_output, tgt_mask=tgt_mask, memory_mask=memory_mask)
        next_token_logits = logits[0, -1, :] / temperature

        last_token = generated[0, -1]
        next_token_logits[last_token] -= 1.0

        top_vals, top_idx = torch.topk(next_token_logits, top_k)
        probs = F.softmax(top_vals, dim=-1)
        next_token = top_idx[torch.multinomial(probs, num_samples=1)]

        generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)

        if next_token.item() == tokenizer.eos_token_id:
            break

    output_text = tokenizer.decode(generated[0, 1:], skip_special_tokens=True)
    return output_text

In [35]:
prompt = "In the middle of the desert, a lone traveler found"

generated_text = generate_text(model, tokenizer, prompt, max_len=100, temperature=1.5, top_k=50, device=device)

print("Generated text:\n", generated_text)

Generated text:
  Officer of a foundation of 150 million copies built a, representing determining a project in the case of the present the cause of bank of 150 tons of sustributaries and the allegedly the water a combined effort which witnessed by the border of a the basin found the,000 iron pl Such cause for the sector of the see the most accurate grant the famine of basin the use of the Crown Point Edward Coke Ridge basin below made bank of 500 . Like Cats the search engine the famine the most disassembly of the
