# In [None]:
# sentiment_minimal.py
import torch, torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from collections import Counter

# 1) tiny toy data — each sentence sits next to its label (1 = positive, 0 = negative)
#    so the two lists cannot drift out of alignment when examples are edited.
train_texts, train_labels = (list(column) for column in zip(*[
    ("i love this movie", 1),
    ("this film is great", 1),
    ("i hate this movie", 0),
    ("this film is bad", 0),
    ("awesome acting", 1),
    ("terrible plot", 0),
]))
test_texts, test_labels = (list(column) for column in zip(*[
    ("great movie", 1),
    ("bad acting", 0),
    ("i love it", 1),
    ("i hate it", 0),
]))

# 2) tokenizer: lowercase, then split on runs of whitespace
def tok(s):
    """Return the lowercase, whitespace-separated tokens of *s* (empty list for "")."""
    lowered = s.lower()
    return lowered.split()

# 3) vocab built from the training split only, so words seen only in test map to <unk>.
#    Index 0 is <pad>, index 1 is <unk>, then words in first-seen order.
PAD, UNK = "<pad>", "<unk>"
cnt = Counter(word for text in train_texts for word in tok(text))
itos = [PAD, UNK, *cnt]
stoi = {word: idx for idx, word in enumerate(itos)}
pad_id = stoi[PAD]
unk_id = stoi[UNK]

# 4) numericalize: tokens missing from the vocab fall back to <unk>
def to_ids(s):
    """Convert sentence *s* into a 1-D int64 tensor of vocabulary ids (OOV -> unk_id)."""
    ids = [stoi.get(token, unk_id) for token in tok(s)]
    return torch.tensor(ids, dtype=torch.long)

# 5) datasets: a list of variable-length id tensors per split, plus int64 label tensors
train_X = list(map(to_ids, train_texts))
test_X = list(map(to_ids, test_texts))
train_y = torch.tensor(train_labels, dtype=torch.long)
test_y = torch.tensor(test_labels, dtype=torch.long)

# 6) collate: pad a batch of variable-length id tensors and report their true lengths
def collate(batch, padding_value=None):
    """Collate ``(ids_tensor, label)`` pairs into padded batch tensors.

    Args:
        batch: sequence of ``(seq, label)`` pairs, where ``seq`` is a 1-D
            tensor of token ids and ``label`` is a class index.
        padding_value: id used to fill short sequences. Defaults to the
            module-level ``pad_id`` so ``DataLoader(collate_fn=collate)``
            keeps working unchanged.

    Returns:
        ``(padded, lens, labels)``:
        ``padded`` is ``(B, T_max)``, right-padded with ``padding_value``;
        ``lens`` is ``(B,)`` int64 holding each sequence's real length;
        ``labels`` is ``(B,)`` int64.
    """
    if padding_value is None:
        padding_value = pad_id
    seqs, labels = zip(*batch)
    lens = torch.tensor([len(s) for s in seqs], dtype=torch.long)
    padded = pad_sequence(seqs, batch_first=True, padding_value=padding_value)
    # CrossEntropyLoss expects int64 class indices; pin the dtype rather than
    # inferring it from whatever label type the dataset happens to yield.
    return padded, lens, torch.tensor(labels, dtype=torch.long)

# 7) dataloaders: pair each id tensor with its Python-int label; only the
#    training split is shuffled. Both use `collate` to pad each batch.
train_ds, test_ds = (
    list(zip(X, y.tolist()))
    for X, y in ((train_X, train_y), (test_X, test_y))
)
train_dl, test_dl = (
    torch.utils.data.DataLoader(ds, batch_size=2, shuffle=shuffle, collate_fn=collate)
    for ds, shuffle in ((train_ds, True), (test_ds, False))
)



# # pad once (batch_first=True → shape (N, T))
# pad_id = stoi["<pad>"]
# train_X_pad = pad_sequence(train_X, batch_first=True, padding_value=pad_id)
# test_X_pad  = pad_sequence(test_X,  batch_first=True, padding_value=pad_id)

# # lengths and masks (optional)
# train_lens = torch.tensor([len(s) for s in train_X])
# test_lens  = torch.tensor([len(s) for s in test_X])
# train_mask = (train_X_pad != pad_id)  # (N, T) bool
# test_mask  = (test_X_pad  != pad_id)

# # use default collate (no collate_fn)
# train_ds = TensorDataset(train_X_pad, train_lens, train_mask, train_y)
# test_ds  = TensorDataset(test_X_pad,  test_lens,  test_mask,  test_y)
# train_dl = DataLoader(train_ds, batch_size=2, shuffle=True)
# test_dl  = DataLoader(test_ds,  batch_size=2, shuffle=False)



# 8) model: Embedding → masked mean over real tokens → Linear
class MeanEmbClassifier(nn.Module):
    """Bag-of-embeddings classifier: average token embeddings, then a linear layer.

    Args:
        vocab_size: number of rows in the embedding table.
        emb_dim: embedding width E.
        num_classes: number of output logits per example.
        pad_idx: id of the padding token; passed to ``nn.Embedding`` as
            ``padding_idx``.
    """

    def __init__(self, vocab_size, emb_dim=64, num_classes=2, pad_idx=0):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.fc = nn.Linear(emb_dim, num_classes)

    def forward(self, x, lens):
        """Return logits of shape ``(B, num_classes)``.

        Args:
            x: ``(B, T)`` integer tensor of token ids, assumed right-padded
               (as produced by ``pad_sequence(..., batch_first=True)``).
            lens: ``(B,)`` tensor with the number of real tokens per row.
        """
        e = self.emb(x)  # (B, T, E)
        # Mask out positions >= length explicitly instead of relying on the
        # padding row of the embedding table staying zero (it would not if,
        # e.g., pretrained vectors were copied over it).
        positions = torch.arange(x.size(1), device=x.device)
        mask = (positions.unsqueeze(0) < lens.unsqueeze(1)).unsqueeze(-1)  # (B, T, 1)
        summed = (e * mask.to(e.dtype)).sum(dim=1)  # (B, E)
        # Clamp so an empty sequence yields a zero vector instead of 0/0 = NaN.
        denom = lens.clamp(min=1).unsqueeze(1).to(e.dtype)  # (B, 1), broadcasts over E
        return self.fc(summed / denom)
    

# How the mean in forward() is computed:
#
# 1. e.sum(dim=1)
#    • dim=1 is the sequence-length axis T, so this adds up all token embeddings in a sentence.
#    • The shape goes from (B, T, E) to (B, E).
#    • Example: if T=10, the 10 token embeddings collapse into one 64-d vector per sentence (their sum).
#
# 2. lens.unsqueeze(1)
#    • lens holds the sequence lengths: how many real tokens each sample has, not counting <pad>.
#      Example: lens = tensor([7, 5, 9]), shape (B,).
#    • .unsqueeze(1) reshapes it to (B, 1) so the division broadcasts across all E embedding dimensions.
#
# Padding positions add nothing to the sum, so dividing by lens (rather than by T) gives the mean over real tokens only.


# Use the GPU when one is available; the model (and later each batch) is moved there.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
model = MeanEmbClassifier(
    len(itos),
    emb_dim=64,
    num_classes=2,
    pad_idx=pad_id,
).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()  # default reduction='mean' over the batch

# 9) train: 10 epochs; print the per-sample mean loss and accuracy after each epoch
for epoch in range(10):
    model.train()
    tot = 0.0
    correct = total = 0
    for xb, lens, yb in train_dl:
        xb, lens, yb = (t.to(device) for t in (xb, lens, yb))

        logits = model(xb, lens)
        loss = loss_fn(logits, yb)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # loss is a batch mean, so weight it by the batch size; dividing by
        # `total` later gives an exact per-sample average even if a batch is smaller.
        tot += loss.item() * xb.size(0)
        total += xb.size(0)
        pred = logits.argmax(dim=1)
        correct += (pred == yb).sum().item()
    print(f"epoch {epoch+1}: loss={tot/total:.4f}, acc={correct/total:.3f}")

# 10) eval: accuracy on the held-out split, with dropout/batch-norm in eval mode and gradients off
model.eval()
correct = total = 0
with torch.no_grad():
    for xb, lens, yb in test_dl:
        xb, lens, yb = (t.to(device) for t in (xb, lens, yb))
        pred = model(xb, lens).argmax(dim=1)
        total += xb.size(0)
        correct += (pred == yb).sum().item()
print(f"test acc: {correct/total:.3f}")