<a href="https://colab.research.google.com/github/kevalshah90/llms/blob/main/transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import math                    # needed for positional encoding math functions
import torch
import torch.nn as nn         # high-level neural network modules
import torch.optim as optim   # for the Adam optimizer

# Run on GPU if available, otherwise CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ⚙️ Hyperparameters
SEQ_LEN = 10     # length of input/output sequences
VOCAB_SIZE = 50  # number of unique tokens (IDs 0..VOCAB_SIZE-1)
D_MODEL = 32     # embedding/hidden dimension size
NHEAD = 4        # number of attention heads
NUM_LAYERS = 2   # number of transformer blocks
BATCH_SIZE = 16  # number of sequences per training batch
EPOCHS = 5       # how many training passes to run

In [5]:
class PositionalEncoding(nn.Module):
    """
    Adds fixed sinusoidal positional vectors per token position to embeddings,
    so the Transformer can understand sequence order (sin(pos/10000^(i/d_model))).
    """
    def __init__(self, d_model, max_len=SEQ_LEN):
        super().__init__()
        pe = torch.zeros(max_len, d_model, device=DEVICE)
        position = torch.arange(0, max_len, dtype=torch.float, device=DEVICE).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, device=DEVICE).float()
                             * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(1)

    def forward(self, x):
        # x: (seq_len, batch_size, d_model)
        return x + self.pe[:x.size(0)]

In [11]:
class DemoTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, seq_len):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc   = PositionalEncoding(d_model)
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=4*d_model,
            dropout=0.1,
            batch_first=False
        )
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.seq_len   = seq_len       # store it internally
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, tgt):
        # src: (BATCH_SIZE, T_src), tgt: (BATCH_SIZE, T_tgt)
        src = self.token_emb(src).transpose(0,1)  # → (T_src, B, d_model)
        src = self.pos_enc(src)
        tgt = self.token_emb(tgt).transpose(0,1)  # → (T_tgt, B, d_model)
        tgt = self.pos_enc(tgt)

        T_tgt = tgt.size(0)  # actual target length
        # mask prevents each position seeing future positions
        tgt_mask = self.transformer.generate_square_subsequent_mask(T_tgt).to(tgt.device)

        out = self.transformer(src, tgt, tgt_mask=tgt_mask)
        out = self.fc_out(out)        # (T_tgt, B, vocab_size)
        return out.transpose(0,1)     # → (B, T_tgt, vocab_size)

In [12]:
def train_demo():
    model = DemoTransformer(VOCAB_SIZE, D_MODEL, NHEAD, NUM_LAYERS, SEQ_LEN).to(DEVICE)
    # CrossEntropyLoss combines log-softmax + negative log likelihood in one efficient op   [oai_citation:7‡Reddit](https://www.reddit.com/r/MLQuestions/comments/ledtzw/torchnncrossentropyloss/?utm_source=chatgpt.com)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    print("Training on synthetic data…")
    model.train()
    for epoch in range(1, EPOCHS + 1):
        # Generate random token sequences in [0, VOCAB_SIZE)
        src = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=DEVICE)
        tgt_input = src[:, :-1]  # decoder input (everything except last token)
        tgt_output = src[:, 1:]  # target next-token labels

        logits = model(src, tgt_input)   # (BATCH, seq_len-1, vocab_size)

        # Prepare for CrossEntropyLoss:
        # - logits must be shape (N, C) where N = BATCH*(seq_len-1), C = vocab_size
        # - targets must be shape (N,) with integer class indices
        loss = criterion(
            logits.reshape(-1, VOCAB_SIZE),
            tgt_output.reshape(-1)
        )
        # Internally, CrossEntropyLoss does:
        #   1. LogSoftmax on logits → log-probabilities
        #   2. Picks log(p_true) for each position
        #   3. Takes negative and averages over samples 🡪 hence lower loss = better performance  [oai_citation:8‡Zero To Mastery](https://zerotomastery.io/blog/pytorch-loss-functions/?utm_source=chatgpt.com) [oai_citation:9‡PyTorch Forums](https://discuss.pytorch.org/t/how-does-nn-crossentropyloss-work-under-the-hood/126431?utm_source=chatgpt.com)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"  epoch {epoch}/{EPOCHS}, loss = {loss.item():.4f}")

    return model

def greedy_decode(model):
    """
    Demonstrates step-by-step generation by greedy decoding:
    at each timestep, we feed the previously generated tokens into the decoder.
    """
    model.eval()
    with torch.no_grad():
        # Seed sequence: random
        input_seq = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN), device=DEVICE)
        output_seq = [input_seq[0, 0].item()]  # start with the first token from seed

        for t in range(1, SEQ_LEN):
            src = input_seq[:, :t]  # feed all previous tokens as src
            tgt = torch.tensor([output_seq + [0] * (SEQ_LEN - len(output_seq))],
                               dtype=torch.long, device=DEVICE)
            logits = model(src, tgt)  # shape: (1, seq_len, vocab_size)
            next_tok = logits[0, t - 1].argmax().item()
            output_seq.append(next_tok)

        print("Input (first seed):", input_seq[0].tolist())
        print("Greedy output:", output_seq)

if __name__ == "__main__":
    m = train_demo()
    greedy_decode(m)

Training on synthetic data…
  epoch 1/5, loss = 4.2954
  epoch 2/5, loss = 4.2355
  epoch 3/5, loss = 4.2458
  epoch 4/5, loss = 4.1537
  epoch 5/5, loss = 4.0078
Input (first seed): [42, 36, 29, 26, 46, 29, 24, 18, 34, 30]
Greedy output: [42, 23, 23, 23, 23, 25, 25, 25, 25, 39]


