Skip to content

Modernize deps (drop torchtext, HF datasets) and address open issues#31

Merged
keon merged 3 commits into
masterfrom
modernize-deps-and-fixes
May 18, 2026
Merged

Modernize deps (drop torchtext, HF datasets) and address open issues#31
keon merged 3 commits into
masterfrom
modernize-deps-and-fixes

Conversation

@keon
Copy link
Copy Markdown
Owner

@keon keon commented May 18, 2026

Summary

  • Drop deprecated torchtext (legacy Field / Multi30k.splits was removed in 0.9+, and the package itself was deprecated entirely in 2024). Load Multi30k via HuggingFace datasets instead.
  • Switch additive attention from relu to tanh to match Bahdanau et al. 2014 — the README already claims this architecture.
  • Add greedy inference mode: Seq2Seq.forward(src, trg=None, max_len, sos) now works without a target.
  • Remove deprecated torch.autograd.Variable; auto-detect CUDA / Apple MPS / CPU in train.py.
  • Pin modern versions in requirements.txt; update README to new spacy model names.

Issues addressed

Test plan

Verified end-to-end on CPU in a fresh venv:

  • pip install -r requirements.txt succeeds against current PyPI.
  • python -m spacy download {de_core_news_sm,en_core_web_sm} succeeds.
  • utils.load_dataset(8) builds vocabs (7853 DE / 9797 EN) and yields 3625 / 127 / 125 train/val/test batches from bentrevett/multi30k.
  • 3 training steps run; loss drops 9.20 → 9.15 → 9.10.
  • Seq2Seq.forward(src, trg=None, max_len=20, sos=2) produces predictions with no target.
  • Eval loop runs cleanly over val batches.

🤖 Generated with Claude Code

keon and others added 3 commits May 17, 2026 23:59
- Drop deprecated torchtext; load Multi30k via HuggingFace `datasets`
  (closes #10, #11, #12).
- Use tanh in additive attention to match Bahdanau et al. 2014
  (closes #15, #28).
- Add inference mode: `Seq2Seq.forward(src, trg=None, max_len, sos)`
  performs greedy decoding without a target (closes #22).
- Remove deprecated `torch.autograd.Variable`; auto-detect
  CUDA / Apple MPS / CPU in train.py.
- Pin modern versions in requirements.txt (torch>=2.0, datasets>=2.14,
  spacy>=3.7) and update README install instructions to new spacy
  model names.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Addresses #14: 100 epochs at fixed lr on 30k-example Multi30k drives
the train loss to ~0 while val plateaus. ReduceLROnPlateau halves the
LR after 2 stagnant val epochs; early stopping breaks after `-patience`
epochs without val improvement (default 5).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This was referenced May 18, 2026
@keon keon merged commit 7c243a0 into master May 18, 2026
keon added a commit that referenced this pull request May 18, 2026
Address #31 review: Bahdanau s_0 bridge, safe inference, single best.pt
pull Bot pushed a commit to vishalbelsare/seq2seq that referenced this pull request May 18, 2026
…e inference

Addresses code review feedback on PR keon#31:

- HIGH: drop `inplace=True` from decoder embedding dropout (autograd footgun)
- HIGH: write `outputs[0]` one-hot for the start token so callers can do
  `outputs.argmax(-1)` without getting a UNK at position 0
- MEDIUM: replace `hidden[:n_layers]` slice (which silently kept only the
  encoder's *forward* direction) with a Bahdanau §A.2.2 bridge:
  s_0 = tanh(W_s · ←h_1) from the encoder's last backward state
- MEDIUM: save a single `best.pt` instead of per-epoch `seq2seq_{e}.pt`
- MEDIUM: pre-tokenize via `dataset.map` so spaCy doesn't re-run every epoch
- LOW: simplify attention `v` init (drop redundant `torch.rand`)

Re-verified on CPU: 500 batches, train 9.19 → 4.84, val 4.93
(vs 4.97 before the bridge — slight improvement, otherwise same dynamics).
`outputs[0].argmax()` correctly returns SOS for all batch rows.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

don't have the inference mode? about the way to calculate attention weight The Pytorch version?

1 participant