# RNN, LSTM, GRU (+ Bi-variants) with Attention & FC: A PyTorch notebook
  
This notebook collects the topics we discussed: what recurrence is, the `(batch, seq_len, input_size)` tensor layout, the LSTM cell state, why you need an `fc` head, additive (Bahdanau-style) attention, and runnable PyTorch code for RNN/LSTM/GRU with attention, plus BiLSTM/BiGRU versions.  
The toy dataset is intentionally tiny so it runs in seconds while still demonstrating tensor shapes, the training loop, and attention visualization.


## What you'll find here (quick tour)

1. Quick conceptual recap (recurrence, embeddings, RNN outputs vs predictions).  
2. A tiny synthetic dataset (token sequences) so we can show end-to-end code without downloading anything.  
3. An **Attention** module (additive / Bahdanau-style).  
4. One model class, `BaseRNNClassifier`, that covers RNN, LSTM, GRU, BiLSTM, and BiGRU through its `rnn_type` and `bidir` arguments, with an `fc` head and optional attention.  
5. Training & evaluation helpers (includes F1 score).  
6. A small training run on the toy data and attention visualization on an example.  
7. Notes: when to use each architecture, gotchas, and practical tips.


---
## Recap

- **Recurrence**: the network applies the same cell (with the same weights) at every time step and feeds the previous hidden state into the next step (the loop). RNN, LSTM, and GRU all work this way; they differ in what happens inside the cell.  
- **Embeddings**: `nn.Embedding` converts token ids → dense vectors. These are *representations*, not final answers.  
- **RNN outputs**: RNN layers produce **hidden states** (representations). They **do not** directly provide class logits — that's the job of an output head (typically a fully connected `nn.Linear`, commonly named `fc`).  
- **LSTM cell state**: `C_t` is the long-term memory "conveyor belt"; `h_t` is the short-term state and the cell's output at step `t`. The forget, input, and output gates decide what is erased from, written to, and read out of `C_t`.  
- **Attention**: a learned weighting over time steps (or elements) that tells the model *which parts of the sequence to focus on* when building a single summary representation. Adding attention to an RNN often helps on long or noisy sequences and yields weights you can inspect.
---
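
The recurrence and LSTM-cell bullets above can be checked directly in code. This sketch uses random weights and inputs (no training): part (a) unrolls an `nn.RNNCell` step by step and compares it with `nn.RNN`; part (b) writes one LSTM step by hand from the gate equations and compares it with `nn.LSTMCell`. The helper `lstm_step` is a hypothetical name introduced here. The gate order it assumes (input, forget, candidate, output) is the stacking PyTorch documents for `weight_ih`/`weight_hh`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, input_size, hidden = 2, 5, 3, 4
x = torch.randn(batch, seq_len, input_size)   # (batch, seq_len, input_size)

# (a) Recurrence: nn.RNN applies one cell, with one set of weights, at every time step.
rnn = nn.RNN(input_size, hidden, batch_first=True)
cell = nn.RNNCell(input_size, hidden)
with torch.no_grad():
    # give the cell exactly the same weights as layer 0 of the RNN
    cell.weight_ih.copy_(rnn.weight_ih_l0)
    cell.weight_hh.copy_(rnn.weight_hh_l0)
    cell.bias_ih.copy_(rnn.bias_ih_l0)
    cell.bias_hh.copy_(rnn.bias_hh_l0)

    h = torch.zeros(batch, hidden)            # h_0
    steps = []
    for t in range(seq_len):
        h = cell(x[:, t], h)                  # h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)
        steps.append(h)
    manual_out = torch.stack(steps, dim=1)    # (batch, seq_len, hidden)
    rnn_out, _ = rnn(x)                       # same shape, computed by the built-in layer

print('unrolled loop matches nn.RNN:', torch.allclose(manual_out, rnn_out, atol=1e-5))

# (b) One LSTM step written out from the gate equations.
def lstm_step(x_t, h, c, cell):
    # all four gate blocks in one matmul, stacked as (input, forget, candidate, output)
    gates = x_t @ cell.weight_ih.T + cell.bias_ih + h @ cell.weight_hh.T + cell.bias_hh
    i, f, g, o = gates.chunk(4, dim=1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)                  # candidate content to write
    c_new = f * c + i * g              # cell state: keep part of the old memory, add new content
    h_new = o * torch.tanh(c_new)      # hidden state: gated, squashed view of the cell state
    return h_new, c_new

lstm_cell = nn.LSTMCell(input_size, hidden)
h0, c0 = torch.zeros(batch, hidden), torch.zeros(batch, hidden)
with torch.no_grad():
    h_manual, c_manual = lstm_step(x[:, 0], h0, c0, lstm_cell)
    h_ref, c_ref = lstm_cell(x[:, 0], (h0, c0))
print('manual LSTM step matches nn.LSTMCell:',
      torch.allclose(h_manual, h_ref, atol=1e-5) and torch.allclose(c_manual, c_ref, atol=1e-5))
```

The only difference between `nn.RNN`, `nn.GRU`, and `nn.LSTM` is the body of the per-step function; the loop in (a) is identical for all three.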


In [None]:
# Minimal imports used by the notebook
import random, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# Try sklearn for metrics; fall back to a small pure-Python implementation if absent
try:
    from sklearn.metrics import f1_score
    _HAS_SK = True
except ImportError:
    _HAS_SK = False
    def f1_score(y_true, y_pred, average='macro'):
        """Macro or micro F1 for single-label (binary or multiclass) predictions."""
        pairs = list(zip(y_true, y_pred))
        if average == 'micro':
            # for single-label data, micro-F1 equals accuracy (global TP / N)
            return sum(1 for yt, yp in pairs if yt == yp) / len(pairs)
        if average != 'macro':
            raise ValueError(f"fallback f1_score supports only 'macro' and 'micro', got {average!r}")
        # macro: unweighted mean of per-class F1 over labels seen in y_true or y_pred
        labels = sorted(set(y_true) | set(y_pred))
        f1s = []
        for lab in labels:
            tp = sum(1 for yt, yp in pairs if yt == lab and yp == lab)
            fp = sum(1 for yt, yp in pairs if yt != lab and yp == lab)
            fn = sum(1 for yt, yp in pairs if yt == lab and yp != lab)
            prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            f1s.append(2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0)
        return sum(f1s) / len(f1s)
print('torch:', torch.__version__, 'sklearn available?', _HAS_SK)


## Tiny synthetic dataset (text-like sequences)

We'll make fixed-length sequences of token ids from a small vocabulary. The task is intentionally simple so the focus stays on the models and the data flow:
- every sequence has length `seq_len`, so no padding is needed
- the vocabulary is small (e.g. 20 ids); id 0 is reserved for padding, so tokens are drawn from `1 .. vocab-1`
- the label is 1 if the special token `magic_token` appears anywhere in the sequence, else 0 (binary classification)

Attention is a natural fit here: ideally it puts most of its weight on the time step that holds the magic token.
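
For the label rule above to hold, the filler tokens must exclude `magic_token`. If filler ids were drawn uniformly from all of `1 .. vocab-1` (19 ids for `vocab=20`), the probability that a length-20 "negative" sequence contains token 7 by accident would be `1 - (18/19)**20`, roughly two thirds. Most label-0 examples would then actually contain the magic token. The sketch below checks this exactly and by simulation; the simulation setup (trial count, seed) is illustrative only.

```python
import random

vocab, seq_len, magic_token = 20, 20, 7
n_ids = vocab - 1  # ids 1..vocab-1; 0 is reserved for padding

# exact probability that a uniformly random sequence contains the magic token at least once
p_exact = 1 - ((n_ids - 1) / n_ids) ** seq_len

# Monte Carlo check with filler drawn from ALL non-padding ids (magic token included)
rng = random.Random(0)
trials = 10_000
p_sim = sum(
    magic_token in [rng.randrange(1, vocab) for _ in range(seq_len)]
    for _ in range(trials)
) / trials

# drawing filler only from ids that exclude the magic token removes the accidental hits
filler = [t for t in range(1, vocab) if t != magic_token]
p_fixed = sum(
    magic_token in [rng.choice(filler) for _ in range(seq_len)]
    for _ in range(trials)
) / trials

print(f'exact={p_exact:.3f}  simulated={p_sim:.3f}  filtered filler={p_fixed:.3f}')
```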


In [None]:
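class TinySeqDataset(Dataset):
    """Fixed-length token sequences; label 1 iff `magic_token` occurs in the sequence."""
    def __init__(self, n_samples=2000, seq_len=20, vocab=20, magic_token=7, p_magic=0.2, seed=42):
        self.samples = []
        rng = random.Random(seed)
        # filler ids: 1..vocab-1 (0 is reserved for padding) WITHOUT the magic token,
        # otherwise many label-0 sequences would contain it by chance
        filler = [t for t in range(1, vocab) if t != magic_token]
        for _ in range(n_samples):
            seq = [rng.choice(filler) for _ in range(seq_len)]
            if rng.random() < p_magic:
                pos = rng.randrange(0, seq_len)
                seq[pos] = magic_token
                label = 1
            else:
                label = 0
            self.samples.append((torch.tensor(seq, dtype=torch.long), label))
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx): return self.samples[idx]

def collate(batch):
    # all sequences share seq_len, so they stack directly into (batch, seq_len)
    xs, ys = zip(*batch)
    xs = torch.stack(xs, dim=0)
    ys = torch.tensor(ys, dtype=torch.long)
    return xs, ys

# quick datasets; each split gets its own seed (with one shared seed, val and test
# would be exact copies of the first training samples)
train_ds = TinySeqDataset(n_samples=800, seq_len=20, vocab=20, magic_token=7, p_magic=0.35, seed=0)
val_ds   = TinySeqDataset(n_samples=200, seq_len=20, vocab=20, magic_token=7, p_magic=0.35, seed=1)
test_ds  = TinySeqDataset(n_samples=200, seq_len=20, vocab=20, magic_token=7, p_magic=0.35, seed=2)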

train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, collate_fn=collate)
val_dl   = DataLoader(val_ds,   batch_size=128, shuffle=False, collate_fn=collate)
test_dl  = DataLoader(test_ds,  batch_size=128, shuffle=False, collate_fn=collate)

print('dataset sizes:', len(train_ds), len(val_ds), len(test_ds))
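
A quick shape check ties this back to `batch/seq_len/input_size`. A batch from these loaders is a `(batch, seq_len)` tensor of ids. The embedding turns it into `(batch, seq_len, emb_dim)`, and `emb_dim` is what PyTorch calls `input_size` for the recurrent layer. The sketch below uses random ids with this notebook's sizes (20-token vocab, 32-dim embeddings, 64 hidden units). Note that the final hidden state `h_n` is *not* batch-first, even with `batch_first=True`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, vocab, emb_dim, hidden = 4, 20, 20, 32, 64

x = torch.randint(1, vocab, (batch, seq_len))        # token ids: (batch, seq_len)
emb = nn.Embedding(vocab, emb_dim, padding_idx=0)
gru = nn.GRU(input_size=emb_dim, hidden_size=hidden, batch_first=True)

e = emb(x)            # (batch, seq_len, emb_dim); emb_dim is the GRU's input_size
out, h_n = gru(e)     # out: (batch, seq_len, hidden), one hidden state per time step
                      # h_n: (num_layers * directions, batch, hidden), NOT batch-first

print('ids', tuple(x.shape))
print('emb', tuple(e.shape))
print('out', tuple(out.shape))
print('h_n', tuple(h_n.shape))
# for a 1-layer unidirectional GRU, the final hidden state is the last output step
print('h_n[-1] == out[:, -1]:', torch.allclose(h_n[-1], out[:, -1]))
```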

## Additive (Bahdanau-style) Attention module

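This module takes the RNN outputs across time, `(batch, seq_len, hidden)`, scores every time step with a small feed-forward layer, and produces:
- a context vector `(batch, hidden)`: the attention-weighted sum of the per-step hidden states
- attention weights `(batch, seq_len)` for interpretation

Unlike Bahdanau's original encoder-decoder attention, there is no decoder state acting as the query. The learned vector `v` plays that role, which is why this variant is often called attention pooling.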
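
Written out, the module below computes the following. Here `W` and `b` are the weight and bias of `self.project`, `v` is the weight of `self.v`, and `h_t` is the RNN output at step `t` of `T`:

$$
e_t = \mathbf{v}^\top \tanh\left(W \mathbf{h}_t + \mathbf{b}\right), \qquad
\alpha_t = \frac{\exp(e_t)}{\sum_{k=1}^{T} \exp(e_k)}, \qquad
\mathbf{c} = \sum_{t=1}^{T} \alpha_t \, \mathbf{h}_t
$$

The weights $\alpha_t$ are non-negative and sum to 1, so the context $\mathbf{c}$ is a convex combination of the per-step hidden states and has the same size as a single $\mathbf{h}_t$.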


In [None]:
class AdditiveAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.project = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1, bias=False)
    def forward(self, h):  # h: (batch, seq_len, hidden)
        # energy: (batch, seq_len, hidden)
        energy = torch.tanh(self.project(h))
        # scores: (batch, seq_len, 1)
        scores = self.v(energy)
        # weights: (batch, seq_len)
        weights = F.softmax(scores.squeeze(-1), dim=1)
        # context: (batch, hidden)
        context = torch.sum(h * weights.unsqueeze(-1), dim=1)
        return context, weights


## Models: RNN / LSTM / GRU (+ bi- variants) with attention + FC head

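Each model has:
- an embedding layer (`nn.Embedding`): tokens → vectors
- a recurrent layer: `nn.RNN`, `nn.LSTM`, or `nn.GRU` (`batch_first=True`), optionally bidirectional
- optional attention (on by default; the demo below uses it to show the full flow)
- a fully connected `fc` layer producing class logits (2 classes here)

All five variants come from one class, `BaseRNNClassifier`. Pick the layer with `rnn_type='rnn'`, `'lstm'`, or `'gru'`, and set `bidir=True` for BiLSTM/BiGRU.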
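
The three recurrent layers differ mainly in how many weight blocks each cell computes. A vanilla RNN has one `(W_ih, W_hh, b_ih, b_hh)` set, a GRU has three (reset gate, update gate, candidate), and an LSTM has four (input, forget, and output gates plus the candidate). A parameter count with this notebook's sizes (32-dim embeddings in, 64 hidden units) shows the resulting 1 : 3 : 4 ratio, and shows that a bidirectional layer doubles it. `n_params` is a small hypothetical helper, not a PyTorch function.

```python
import torch.nn as nn

emb_dim, hidden = 32, 64

def n_params(module):
    # total number of trainable scalars in a module
    return sum(p.numel() for p in module.parameters())

layers = {
    'RNN':    nn.RNN(emb_dim, hidden, batch_first=True),
    'GRU':    nn.GRU(emb_dim, hidden, batch_first=True),
    'LSTM':   nn.LSTM(emb_dim, hidden, batch_first=True),
    'BiGRU':  nn.GRU(emb_dim, hidden, batch_first=True, bidirectional=True),
    'BiLSTM': nn.LSTM(emb_dim, hidden, batch_first=True, bidirectional=True),
}
counts = {name: n_params(layer) for name, layer in layers.items()}

# one block = input weights + recurrent weights + two bias vectors
block = hidden * emb_dim + hidden * hidden + 2 * hidden
for name, c in counts.items():
    print(f'{name:7s} {c:6d} params  ({c / block:.0f} x one RNN block)')
```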


In [None]:
class BaseRNNClassifier(nn.Module):
    def __init__(self, rnn_type='gru', vocab=20, emb_dim=32, hidden=64, num_layers=1, bidir=False, use_attention=True, num_classes=2):
        super().__init__()
        self.emb = nn.Embedding(vocab, emb_dim, padding_idx=0)
        self.use_attention = use_attention
        self.hidden = hidden
        self.bidir = bidir
        if rnn_type.lower() == 'gru':
            self.rnn = nn.GRU(emb_dim, hidden, num_layers=num_layers, batch_first=True, bidirectional=bidir)
        elif rnn_type.lower() == 'lstm':
            self.rnn = nn.LSTM(emb_dim, hidden, num_layers=num_layers, batch_first=True, bidirectional=bidir)
        elif rnn_type.lower() == 'rnn':
            self.rnn = nn.RNN(emb_dim, hidden, num_layers=num_layers, batch_first=True, bidirectional=bidir)
        else:
            raise ValueError('unknown rnn type')
        rnn_output_dim = hidden * (2 if bidir else 1)
        # attention pools over all time steps; without it we classify from the final hidden state(s)
        self.attn = AdditiveAttention(rnn_output_dim) if use_attention else None
        self.fc = nn.Linear(rnn_output_dim, num_classes)  # output head: representation -> class logits

    def forward(self, x):
        # x: (batch, seq_len)
        emb = self.emb(x)  # (batch, seq_len, emb_dim)
        rnn_out, hidden = self.rnn(emb)  # rnn_out: (batch, seq_len, hidden * directions)
        if self.use_attention:
            context, weights = self.attn(rnn_out)  # (batch, hidden*dir), (batch, seq_len)
            logits = self.fc(context)
            return logits, weights
        else:
            # no masking here, so "last" means the final step of the fixed-length input
            # LSTM returns (h_n, c_n); RNN and GRU return h_n only
            h_n = hidden[0] if isinstance(hidden, tuple) else hidden
            # h_n: (num_layers * directions, batch, hidden); the last layer's states come last
            if self.bidir:
                # h_n[-2]: last layer, forward direction (after reading the final step)
                # h_n[-1]: last layer, backward direction (after reading back to step 0)
                last = torch.cat([h_n[-2], h_n[-1]], dim=1)
            else:
                last = h_n[-1]
            logits = self.fc(last)
            # no attention weights exist in this branch
            return logits, None
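
The non-attention branch relies on PyTorch's `h_n` layout. States are stacked layer by layer, with forward before backward inside each layer, so for a bidirectional stack `h_n[-2]` and `h_n[-1]` are the last layer's forward and backward final states. The sketch below checks this against the per-step outputs of a 2-layer BiLSTM with random inputs (the sizes are arbitrary). The forward final state equals the first half of the *last* step's output. The backward final state equals the second half of the *first* step's output, because the backward pass ends at step 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, emb_dim, hidden, num_layers = 3, 7, 8, 5, 2
x = torch.randn(batch, seq_len, emb_dim)
lstm = nn.LSTM(emb_dim, hidden, num_layers=num_layers, batch_first=True, bidirectional=True)

with torch.no_grad():
    out, (h_n, c_n) = lstm(x)   # out: (batch, seq_len, 2*hidden); h_n: (num_layers*2, batch, hidden)

# view h_n as (num_layers, directions, batch, hidden) to make the layout explicit
h_layers = h_n.view(num_layers, 2, batch, hidden)
fwd_last, bwd_last = h_layers[-1, 0], h_layers[-1, 1]

forward_ok = torch.allclose(fwd_last, out[:, -1, :hidden]) and torch.equal(fwd_last, h_n[-2])
backward_ok = torch.allclose(bwd_last, out[:, 0, hidden:]) and torch.equal(bwd_last, h_n[-1])
print('h_n shape:', tuple(h_n.shape))
print('forward final == out[:, -1, :hidden] == h_n[-2]:', forward_ok)
print('backward final == out[:, 0, hidden:] == h_n[-1]:', backward_ok)

# what the classifier feeds to fc when use_attention=False and bidir=True
last = torch.cat([h_n[-2], h_n[-1]], dim=1)   # (batch, 2*hidden)
print('fc input shape:', tuple(last.shape))
```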


## Training & eval helpers (small, clear functions)

We'll use a simple training loop and compute accuracy + macro F1 on validation data.
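
Why report macro F1 alongside accuracy? Here is a small worked example with made-up labels (90 negatives, 10 positives) and a degenerate model that always predicts 0. Accuracy looks good. Macro F1 averages the per-class F1 scores, so the class the model never predicts pulls it down. `macro_f1` is a tiny illustrative helper, not a library function.

```python
def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over the labels seen in y_true or y_pred."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = []
    for lab in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == lab and p == lab)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != lab and p == lab)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == lab and p != lab)
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        scores.append(2 * prec * rec / (prec + rec) if prec + rec else 0.0)
    return sum(scores) / len(scores)

y_true = [0] * 90 + [1] * 10   # imbalanced: 10% positives
y_pred = [0] * 100             # "always predict the majority class"

acc = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
f1 = macro_f1(y_true, y_pred)
print(f'accuracy={acc:.3f}  macro_f1={f1:.3f}')
# class 0: precision 0.9, recall 1.0 -> F1 = 18/19 (about 0.947)
# class 1: never predicted -> F1 = 0, so macro F1 = 9/19 (about 0.474) while accuracy is 0.900
```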


In [None]:
def train_epoch(model, dl, optim, criterion, device):
    model.train()
    total_loss = 0.0
    for x, y in dl:
        x = x.to(device); y = y.to(device)
        optim.zero_grad()
        logits, _ = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optim.step()
        total_loss += loss.item() * x.size(0)
    return total_loss / len(dl.dataset)

def evaluate(model, dl, device):
    model.eval()
    ys, ypreds = [], []
    with torch.no_grad():
        for x, y in dl:
            x = x.to(device)
            logits, _ = model(x)
            preds = torch.argmax(logits, dim=1).cpu().tolist()
            ys.extend(y.tolist())
            ypreds.extend(preds)
    acc = sum(1 for a,b in zip(ys, ypreds) if a==b) / len(ys)
    f1 = f1_score(ys, ypreds, average='macro')
    return acc, f1


## Quick training run (GRU + Attention) — tiny run to demonstrate flow

We train for a few epochs on the tiny synthetic dataset so you can see loss dropping and get attention visualizations. This is just a demonstration; real-world runs need more data and tuning.


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BaseRNNClassifier(rnn_type='gru', vocab=20, emb_dim=32, hidden=64, bidir=False, use_attention=True, num_classes=2).to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(1, 6):
    t0 = time.time()
    train_loss = train_epoch(model, train_dl, optim, criterion, device)
    val_acc, val_f1 = evaluate(model, val_dl, device)
    print(f'Epoch {epoch} train_loss={train_loss:.4f} val_acc={val_acc:.4f} val_f1={val_f1:.4f} time={(time.time()-t0):.2f}s')


## Visualizing attention weights for a sample

We pick one sample from the test set, run the model, and plot attention weights across the sequence. This shows what the attention mechanism focused on.


In [None]:
# get a batch and visualize the first sample
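model.eval()
x_batch, y_batch = next(iter(test_dl))
x0 = x_batch[0:1].to(device)
with torch.no_grad():  # inference only; outputs don't track gradients, so .numpy() works below
    logits, weights = model(x0)
if weights is None:
    raise ValueError('the attention plot needs a model built with use_attention=True')
pred = logits.argmax(dim=1).item()
weights = weights[0].cpu().numpy()  # (seq_len,)
seq = x0[0].cpu().numpy()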

print('true label:', y_batch[0].item(), 'pred:', pred)
plt.figure(figsize=(8,2))
plt.plot(weights, marker='o')
plt.title('Attention weights across time (higher = more attended)')
plt.xlabel('time step')
plt.ylabel('attention weight')
plt.grid(True)
plt.show()

# print sequence with weights for inspection
print('tokens:', seq.tolist())
print('attn  :', [float(f'{w:.3f}') for w in weights.tolist()])

---
## Notes, tips, and practical takeaways

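- **Why `fc`?** RNN/LSTM/GRU layers output vector representations of size `hidden` (or `2*hidden` if bidirectional). `fc` maps that representation to the task space (class logits, regression values, vocabulary logits). Without it, the output size would not match the number of classes, so the loss has nothing to compare against the labels.  
- **Attention is cheap and gives inspectable weights** for many sequence classification tasks; treat the weights as a hint about what the model used, not a full explanation. Additive (Bahdanau-style) scoring, as used here, is simple and works with any hidden size. Dot-product scoring drops the extra projection and `tanh`, so it is slightly cheaper, but the query and keys must have the same dimension.  
- **Bidirectional RNNs (BiLSTM / BiGRU)** read the sequence forward and backward and concatenate the two outputs at each step. They give richer per-step features but need the whole sequence up front, so they don't suit streaming.  
- **GRU vs LSTM vs RNN**: GRU is a lighter alternative to LSTM (two gates, update and reset, instead of three, and no separate cell state) and is often a good default. LSTM's separate cell state helps retain information over longer spans. Vanilla RNNs rarely match LSTM/GRU on long sequences because of vanishing gradients.  
- **`nn.MultiheadAttention`**: usually a better fit when you drop recurrence altogether (Transformers) or need dense pairwise interactions between time steps. Its cost is quadratic in sequence length, and on top of a recurrent layer it is often redundant.  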
- **Metrics**: prefer macro-F1 for imbalanced classes. Accuracy is misleading for skewed data.  
- **Next steps**: add masking for variable-length sequences, compare BiGRU/BiLSTM, add dropout, tune hidden sizes, train on a real dataset (UCI HAR or an NLP dataset) and visualize attention across classes.
---
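
For the masking idea under **Next steps**, one common approach is to set the attention scores of padding positions to `-inf` before the softmax, so padded steps get exactly zero weight. This is a sketch, not the only option; `torch.nn.utils.rnn.pack_padded_sequence` is another. It assumes, as this notebook's embedding does, that id 0 is padding. `masked_attention_pool` and its arguments are hypothetical names, and the random `scores` stand in for the `self.v(energy)` output of `AdditiveAttention`.

```python
import torch
import torch.nn.functional as F

def masked_attention_pool(h, scores, token_ids, pad_id=0):
    """Attention pooling that ignores padded time steps.

    h:         (batch, seq_len, hidden)  RNN outputs
    scores:    (batch, seq_len)          unnormalized attention scores
    token_ids: (batch, seq_len)          input ids; positions equal to pad_id are masked
    Note: a row made entirely of padding would give NaN weights, so every
    sequence is assumed to contain at least one real token.
    """
    mask = token_ids != pad_id                               # True for real tokens
    scores = scores.masked_fill(~mask, float('-inf'))       # padded steps get exp(-inf) = 0
    weights = F.softmax(scores, dim=1)                       # (batch, seq_len), rows sum to 1
    context = torch.bmm(weights.unsqueeze(1), h).squeeze(1)  # (batch, hidden), weighted sum over time
    return context, weights

torch.manual_seed(0)
token_ids = torch.tensor([[5, 3, 7, 0, 0],    # length 3, two padding steps
                          [2, 7, 4, 9, 1]])   # length 5, no padding
h = torch.randn(2, 5, 6)
scores = torch.randn(2, 5)

context, weights = masked_attention_pool(h, scores, token_ids)
print(weights)
print('weights on padding:', weights[0, 3:].tolist())
print('rows sum to 1:', torch.allclose(weights.sum(dim=1), torch.ones(2)))
```

The same mask can be passed into `AdditiveAttention.forward` by applying `masked_fill` to `scores.squeeze(-1)` just before its softmax; with variable-length batches, the non-attention branch would also need each sequence's true last step instead of `h_n[-1]`.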
