In [3]:
import re
import math
import random
import numpy as np
from collections import Counter
from typing import List, Tuple, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader

In [4]:
device = torch.device("mps")
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x1133d0f90>

# 1. toy data (sentiment)

In [5]:
corpus = [
    "this movie is great and excellent",
    "fantastic film wonderful direction",
    "good plot amazing soundtrack",
    "touching story strong performances",
    "brilliant engaging narrative",
    "bad pacing awful movie",
    "boring film dull characters",
    "terrible editing horrible dialogue",
    "predictable script poor scenes",
    "unwatchable messy scenes weak plot",
]
labels = np.array([1,1,1,1,1, 0,0,0,0,0], dtype=np.int64)  # 1=POSITIVE, 0=NEGATIVE

# 2. Tokenization & Vocab

In [6]:
def tok(s: str) -> List[str]:
    return re.findall(r"[a-z]+", s.lower())

In [7]:
PAD, UNK = "<pad>", "<unk>"

In [8]:
def build_vocab(texts: List[str], min_freq=1) -> Tuple[Dict[str, int], Dict[int, str]]:
    count = Counter(t for s in texts for t in tok(s))
    itos = [PAD, UNK] + [w for w, f in count.items() if f >= min_freq]
    stoi = {s:i for i, s in enumerate(itos)}
    return stoi, {i:w for w, i in stoi.items()}

In [9]:
stoi, itos = build_vocab(corpus, min_freq=1)
pad_id, unk_id = stoi[PAD], stoi[UNK]

In [10]:
def encode(s: str) -> torch.Tensor:
    ids = [stoi.get(w, unk_id) for w in tok(s)]
    return torch.tensor(ids, dtype=torch.long)

# 3. Load Pretrained embeddings

if embedding exists in keyed_vector, pretrained one, then use pretrained, otherwise stay normalized initialization

In [11]:
def build_embedding_matrix(stoi: Dict[str,int], dim: int,
                           keyed_vectors=None, freeze: bool=True) -> Tuple[nn.Embedding, int]:
    """
    keyed_vectors: gensim KeyedVectors-like object (optional). Must expose:
      - .key_to_index dict (word->idx), .vector_size, and __getitem__(word)->np.ndarray
    """
    V = len(stoi)
    W = np.random.normal(scale=0.01, size=(V, dim)).astype(np.float32)
    W[pad_id] = 0.0

    if keyed_vectors is not None:
        if getattr(keyed_vectors, "vector_size", dim) != dim:
            print(f"[Emb] Dim mismatch: kv={keyed_vectors.vector_size}, requested={dim}. Using random init.")
        else:
            hit = 0
            for w, i in stoi.items():
                if w in getattr(keyed_vectors, "key_to_index", {}):
                    W[i] = keyed_vectors[w]
                    hit += 1

            print(f"[Emb] Loaded {hit} / {V} pretrained vectors")

    emb = nn.Embedding(V, dim, padding_idx=pad_id)
    emb.weight.data.copy_(torch.from_numpy(W))
    emb.weight.requires_grad = not freeze
    return emb, V


kv = None

# 4. Dataset / DataLoader with padding & lengths

In [12]:
class TextClsDataset(Dataset):
    def __init__(self, texts: List[str], lables: np.ndarray):
        self.texts = texts
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, i):
        ids = encode(self.texts[i])
        return ids, len(ids), self.labels[i]

In [20]:
def collate_fn(batch):
    #batch: list of (ids, length, label)
    ids, lens, ys = zip(*batch)
    lens = torch.tensor(lens, dtype=torch.long)
    padded = pad_sequence(ids, batch_first=True, padding_value=pad_id) # [B, T] batch_first is put batch size in first dimension
    ys = torch.stack(ys)
    #sort by length desc for pack_padded_sequence
    lens, sort_idx = lens.sort(descending=True)
    padded = padded.index_select(0, sort_idx)
    ys = ys.index_select(0, sort_idx)
    return padded, lens, ys

In [17]:
#split train/valid
idx = np.arange(len(corpus))
np.random.shuffle(idx)
split = int(0.8 * len(idx))
tr_idx, va_idx = idx[:split], idx[split:]
train_ds = TextClsDataset([corpus[i] for i in tr_idx], labels[tr_idx])
valid_ds = TextClsDataset([corpus[i] for i in va_idx], labels[va_idx])

train_dl = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn)
valid_dl = DataLoader(valid_ds, batch_size=4, shuffle=False, collate_fn=collate_fn)
#eg batch_size=4, DataLoader will pick batch_indices = [i1, i2, i3, i4]
#for every i, use train_ds.__getitem__(i) to get 4 [(ids, len(ids), labels[i])]
#then the collate_fn put those into a list as a batch
#each time during training loop using DataLoader, will get returned value from the collate_fn on a batch

# 5. RNN classifier (supports LSTM/GRU, uni/bi-directional)

In [22]:
class RNNClassifier(nn.Module):
    def __init__(self, emb: nn.Embedding, hidden: int = 64,
                 rnn_type: str = "lstm", bidirectional: bool = False,
                 num_layers: int = 1, dropout: float = 0.0, num_classes: int = 2):
        super().__init__()
        self.emb = emb
        self.embed_dim = emb.embedding_dim
        self.hidden = hidden
        self.bidirectional = bidirectional
        self.num_dirs = 2 if bidirectional else 1
        self.rnn_type = rnn_type.lower()
        rnn_cls = nn.LSTM if self.rnn_type == "lstm" else nn.GRU

        self.rnn = rnn_cls(
            input_size=self.embed_dim, #word vector size dimension
            hidden_size=hidden, #H RNN hidden state dimension H
            num_layers=num_layers, #number of RNN stacking layers
            dropout=dropout if num_layers > 1 else 0.0, #only if only num_layers>1 use dropout between layers
            bidirectional=bidirectional,
            batch_first=True, #shape will be [B, T, *]
        )
        self.fc = nn.Linear(hidden * self.num_dirs, num_classes)

    def forward(self, x, lengths):
        #x: [B, T] int otkens
        #lengths: [B]
        emb = self.emb(x) #[B, T, D]
        packed = pack_padded_sequence(emb, lengths.cpu(), batch_first=True, enforce_sorted=True)
        out_packed, h = self.rnn(packed)

        if self.rnn_type == "lstm":
            h_n = h[0]
        else:
            h_n = h

        last = torch.cat([h_n[-1]] if self.num_dirs == 1 else [h_n[-2], h_n[-1]], dim=-1)
        logits = self.fc(last) # [B,C]
        return logits

x: [B, T]

emb: [B, T, D]

packed: PackedSequence

h_n: [L * dirs, B, H]

last: [B, H]（单向）或 [B, 2H]（双向）

logits: [B, C]

In [23]:
EMB_DIM = 100
emb_layer, V = build_embedding_matrix(stoi, EMB_DIM, keyed_vectors=kv, freeze=True)
model = RNNClassifier(emb=emb_layer, hidden=64, rnn_type="lstm", bidirectional=False, num_classes=2)

In [31]:
model = model.to(device)

# 6. Train / validate with gradient clipping + early stopping

In [32]:
def run_epoch(dl, model, opt=None, clip_norm: float=1.0):
    is_train = opt is not None
    model.train(is_train)
    total, correct, n = 0.0, 0, 0
    for x, lengths, y in dl:
        x, lengths, y = x.to(device), lengths.to(device), y.to(device)
        logits = model(x, lengths)
        loss = F.cross_entropy(logits, y)
        if is_train:
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm)
            opt.step()

        total += float(loss) * y.size(0)
        pred = logits.argmax(1)
        correct += int((pred == y).sum())
        n += y.size(0)
    return total / n, correct / n

In [33]:
opt = torch.optim.Adam(model.parameters(), lr=3e-3)
best_val, best_state, patience, bad = float("inf"), None, 5, 0

In [34]:
for ep in range(1, 51):
    tr_loss, tr_acc = run_epoch(train_dl, model, opt)
    va_loss, va_acc = run_epoch(valid_dl, model, opt=None)
    print(f"[ep {ep:02d}] train loss={tr_loss:.3f} acc={tr_acc:.2f} | valid loss={va_loss:.3f} acc={va_acc:.2f}")
    if va_loss < best_val - 1e-4:
        best_val, best_state, bad = va_loss, {k: v.cpu().clone() for k,v in model.state_dict().items()}, 0
    else:
        bad += 1
        if bad >= patience:
            print(f"[EarlyStopping] no improve {patience} epochs, stop.")
            break


if best_state:
    model.load_state_dict(best_state)
    model.to(device)

[ep 01] train loss=0.683 acc=0.62 | valid loss=0.611 acc=1.00
[ep 02] train loss=0.671 acc=0.62 | valid loss=0.592 acc=1.00
[ep 03] train loss=0.665 acc=0.62 | valid loss=0.568 acc=1.00
[ep 04] train loss=0.659 acc=0.62 | valid loss=0.544 acc=1.00
[ep 05] train loss=0.657 acc=0.62 | valid loss=0.513 acc=1.00
[ep 06] train loss=0.650 acc=0.62 | valid loss=0.483 acc=1.00
[ep 07] train loss=0.644 acc=0.62 | valid loss=0.454 acc=1.00
[ep 08] train loss=0.639 acc=0.62 | valid loss=0.428 acc=1.00
[ep 09] train loss=0.629 acc=0.62 | valid loss=0.420 acc=1.00
[ep 10] train loss=0.623 acc=0.62 | valid loss=0.413 acc=1.00
[ep 11] train loss=0.616 acc=0.62 | valid loss=0.411 acc=1.00
[ep 12] train loss=0.607 acc=0.62 | valid loss=0.408 acc=1.00
[ep 13] train loss=0.596 acc=0.62 | valid loss=0.412 acc=1.00
[ep 14] train loss=0.583 acc=0.62 | valid loss=0.417 acc=1.00
[ep 15] train loss=0.569 acc=0.62 | valid loss=0.411 acc=1.00
[ep 16] train loss=0.564 acc=0.62 | valid loss=0.427 acc=1.00
[ep 17] 

In [35]:
# Quick inference
label_names = {0: "NEGATIVE", 1: "POSITIVE"}
def predict(texts: List[str]) -> List[Tuple[str, str, np.ndarray]]:
    model.eval()
    with torch.no_grad():
        batch = [(encode(s), len(tok(s))) for s in texts]
        ids, lens = zip(*batch)
        lens = torch.tensor(lens, dtype=torch.long)
        padded = pad_sequence(ids, batch_first=True, padding_value=pad_id)
        # sort
        lens, sort_idx = lens.sort(descending=True)
        padded = padded.index_select(0, sort_idx).to(device)
        logits = model(padded, lens.to(device))
        probs = torch.softmax(logits, dim=1).cpu().numpy()
        # restore order
        inv = torch.argsort(sort_idx)
        probs = probs[inv]
        preds = probs.argmax(1)
        return [(t, label_names[int(p)], pr) for t,p,pr in zip(texts, preds, probs)]

print("\n[Predict demo]")
for t, lab, pr in predict([
    "this film is wonderful and touching",
    "awful boring movie with dull characters"
]):
    print(f"  {t!r} -> {lab}  probs={np.round(pr, 3)}")


[Predict demo]
  'this film is wonderful and touching' -> POSITIVE  probs=[0.171 0.829]
  'awful boring movie with dull characters' -> POSITIVE  probs=[0.176 0.824]


# 7. Word-level LM with LSTM

In [36]:
lm_text = (
    "the movie is good but the pacing is slow . "
    "the soundtrack is wonderful and the acting is great . "
    "however the script is weak and the plot is boring . "
) * 20

lm_vocab = sorted(set(tok(lm_text)))
lm_stoi = {w:i for i, w in enumerate(lm_vocab)}
lm_itos = {i:w for i, w in lm_stoi.items()}

In [38]:
def encode_words(words):
    return torch.tensor([lm_stoi[w] for w in words], dtype=torch.long)
lm_ids = encode_words(tok(lm_text))

In [39]:
lm_ids

tensor([14,  8,  7,  4,  3, 14,  9,  7, 12, 14, 13,  7, 16,  1, 14,  0,  7,  5,
         6, 14, 11,  7, 15,  1, 14, 10,  7,  2, 14,  8,  7,  4,  3, 14,  9,  7,
        12, 14, 13,  7, 16,  1, 14,  0,  7,  5,  6, 14, 11,  7, 15,  1, 14, 10,
         7,  2, 14,  8,  7,  4,  3, 14,  9,  7, 12, 14, 13,  7, 16,  1, 14,  0,
         7,  5,  6, 14, 11,  7, 15,  1, 14, 10,  7,  2, 14,  8,  7,  4,  3, 14,
         9,  7, 12, 14, 13,  7, 16,  1, 14,  0,  7,  5,  6, 14, 11,  7, 15,  1,
        14, 10,  7,  2, 14,  8,  7,  4,  3, 14,  9,  7, 12, 14, 13,  7, 16,  1,
        14,  0,  7,  5,  6, 14, 11,  7, 15,  1, 14, 10,  7,  2, 14,  8,  7,  4,
         3, 14,  9,  7, 12, 14, 13,  7, 16,  1, 14,  0,  7,  5,  6, 14, 11,  7,
        15,  1, 14, 10,  7,  2, 14,  8,  7,  4,  3, 14,  9,  7, 12, 14, 13,  7,
        16,  1, 14,  0,  7,  5,  6, 14, 11,  7, 15,  1, 14, 10,  7,  2, 14,  8,
         7,  4,  3, 14,  9,  7, 12, 14, 13,  7, 16,  1, 14,  0,  7,  5,  6, 14,
        11,  7, 15,  1, 14, 10,  7,  2, 

In [40]:
LM_B, LM_T = 16, 12

In [51]:
def lm_batch():
    starts = torch.randint(0, len(lm_ids) - LM_T - 1, (LM_B,))
    X = torch.stack([lm_ids[i:i+LM_T] for i in starts])
    Y = torch.stack([lm_ids[i + 1 : i +LM_T+1] for i in starts])
    return X.to(device), Y.to(device)

In [54]:
class WordLMLSTM(nn.Module):
    def __init__(self, vocab_size, d_emb=64, hidden=128, num_layers=1):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_emb)
        self.rnn = nn.LSTM(d_emb, hidden, num_layers=num_layers, batch_first=True)
        self.proj = nn.Linear(hidden, vocab_size)

    def forward(self, x):
        e = self.emb(x) # [B, T, d]
        out, _ = self.rnn(e) # [B, T, H]
        logits = self.proj(out) # [B, T, V]
        return logits

In [55]:
lm_model = WordLMLSTM(vocab_size=len(lm_vocab)).to(device)
lm_opt = torch.optim.Adam(lm_model.parameters(), lr=3e-3)
lm_loss_fn = nn.CrossEntropyLoss()

In [56]:
def lm_train(steps=300):
    lm_model.train()
    for step in range(1, steps + 1):
        x, y = lm_batch()
        logits = lm_model(x)
        loss = lm_loss_fn(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        lm_opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(lm_model.parameters(), 1.0)
        lm_opt.step()
        if step & 30 == 0:
            ppl = math.exp(float(loss.detach().cpu()))
            print(f"[LM] step {step} loss={loss.item():.3f} ppl={ppl:.2f}")

In [65]:
lm_itos_list = [None] * len(lm_stoi)
for w, i in lm_stoi.items():
    lm_itos_list[i] = w

In [72]:
def lm_generate(prompt: List[str], max_new_tokens=20, temperature=1.0, top_k: Optional[int]=None):
    lm_model.eval()
    V = len(lm_itos_list)
    with torch.no_grad():
        ctx = [lm_stoi.get(w, 0) for w in prompt][-LM_T:] #use 1 TO LM_T words as the promt
        for _ in range(max_new_tokens):
            x = torch.tensor([ctx[-LM_T:]], dtype=torch.long, device=device)  # [1, T]
            logits = lm_model(x)[:, -1, :] / max(temperature, 1e-8)
            if top_k is not None:
                v, i = torch.topk(logits, k=min(top_k, V), dim=-1) #clamp k to V
                probs = F.softmax(v, dim=-1)
                nxt = i[0, torch.multinomial(probs[0], 1).item()].item()
            else:
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs[0], 1).item()

            if not (0 <= nxt < V):
                nxt = 0
            ctx.append(int(nxt))
        return " ".join(lm_itos_list[i] if 0 <= i < V else "<unk>" for i in ctx)

In [74]:
print("\n[LM] training tiny word-level LSTM for a few steps…")
lm_train(steps=300)
print("[LM] sample:", lm_generate(["the"], max_new_tokens=20, temperature=0.8, top_k=5))


[LM] training tiny word-level LSTM for a few steps…
[LM] step 1 loss=0.061 ppl=1.06
[LM] step 32 loss=0.069 ppl=1.07
[LM] step 33 loss=0.041 ppl=1.04
[LM] step 64 loss=0.097 ppl=1.10
[LM] step 65 loss=0.037 ppl=1.04
[LM] step 96 loss=0.055 ppl=1.06
[LM] step 97 loss=0.061 ppl=1.06
[LM] step 128 loss=0.068 ppl=1.07
[LM] step 129 loss=0.052 ppl=1.05
[LM] step 160 loss=0.075 ppl=1.08
[LM] step 161 loss=0.036 ppl=1.04
[LM] step 192 loss=0.068 ppl=1.07
[LM] step 193 loss=0.048 ppl=1.05
[LM] step 224 loss=0.071 ppl=1.07
[LM] step 225 loss=0.077 ppl=1.08
[LM] step 256 loss=0.075 ppl=1.08
[LM] step 257 loss=0.064 ppl=1.07
[LM] step 288 loss=0.051 ppl=1.05
[LM] step 289 loss=0.104 ppl=1.11
[LM] sample: the pacing is slow the soundtrack is wonderful and the acting is great however the script is weak and the plot
