A tiny repo that “decapitates” a causal LM (Qwen3-0.6B-Base) and fine-tunes a new 2-way classification head for spam vs ham. It trains on three CSV files, writes versioned checkpoints, and includes a simple CLI for inference.
This project uses uv for dependency management and execution.
- Install deps: `uv sync`
- Train: `uv run python train_qwen_classifier.py`
- Classify some text (uses the `best` checkpoint symlink): `uv run python run_qwen_classifier.py "free $$$ click here NOW"` # -> Ham: 03.12% || Spam: 96.88%
- You can also point at a specific checkpoint directory name: `uv run python run_qwen_classifier.py "hello there, just checking in" 20251024Z223015`
- `create_model.py` — Loads `Qwen/Qwen3-0.6B-Base`, then replaces `lm_head` with a fresh `torch.nn.Linear(..., out_features=2)` (binary classes). Uses `dtype="auto"` and `device_map="auto"` so the model tries to place itself on available GPU/CPU automatically (see the sketch after this file list).
- `train_qwen_classifier.py` — Minimal trainer:
  - Builds datasets from the CSV files (see Data format below).
  - Tokenizes with the Qwen tokenizer and pads/truncates to a fixed max length (computed from the longest training sample unless you pass `max_length` to the dataset ctor).
  - Uses cross-entropy on the last-token logits for the two classes.
  - AdamW (lr `5e-5`, weight decay `0.1`), batch size `8`, `num_epochs=5`.
  - Prints loss every 50 steps (evaluated over 5 mini-batches) and accuracy per epoch (again over a small eval window).
  - Saves a timestamped checkpoint at each eval; updates the `checkpoints/best` symlink when validation loss improves.
- `persistence.py` — Checkpoint I/O:
  - Saves a `safetensors` state dict to `checkpoints/YYYYMMDDZHHMMSS/model.safetensors` + `meta.json`.
  - Maintains the `checkpoints/best` symlink to the best-so-far run.
  - `load_model(checkpoint)` restores weights into a freshly constructed model.
- `run_qwen_classifier.py` — Tiny CLI using `click`; prints Ham and Spam probabilities for an input string using a chosen checkpoint (default: `best`).
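To make the decapitation step concrete, here is a minimal sketch of what `create_model.py` might look like, assuming the stock `transformers` loading API; the actual file may differ in details:

```python
# Sketch only: load the base causal LM and swap its vocab-sized lm_head
# for a fresh 2-logit head (0 = ham, 1 = spam).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "Qwen/Qwen3-0.6B-Base"

def create_model():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        dtype="auto",       # "torch_dtype" on older transformers versions
        device_map="auto",  # place layers on available GPU/CPU automatically
    )
    hidden = model.config.hidden_size
    # Replace the language-modeling head with a 2-class linear layer.
    model.lm_head = torch.nn.Linear(hidden, 2, bias=False).to(
        device=model.device, dtype=model.dtype
    )
    return model, tokenizer
```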
Place three CSV files at the repo root:
- `classification-train.csv`
- `classification-validation.csv`
- `classification-test.csv`
Each must have the columns:
- `Label` — integer class id (`0` for ham, `1` for spam)
- `Text` — the raw input string
Example (classification-train.csv):
Label,Text
0,"Hello! Are we still on for lunch?"
1,"CONGRATULATIONS! You've won a FREE cruise. Click now!"
0,"Invoice attached. Thanks."
Notes
- By default, the training set’s longest tokenized sample length becomes the fixed sequence length. Validation and test are padded/truncated to that length for consistency.
- Padding uses the tokenizer's pad token id (see the sketch below).
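A rough sketch of how such a CSV could be turned into fixed-length token tensors; the repo's actual `SpamDataset` may differ in constructor and details:

```python
# Hypothetical sketch of the CSV -> tensors path used for training.
import pandas as pd
import torch
from torch.utils.data import Dataset

class SpamDataset(Dataset):
    def __init__(self, csv_path, tokenizer, max_length=None):
        df = pd.read_csv(csv_path)
        self.labels = df["Label"].tolist()
        encoded = [tokenizer.encode(text) for text in df["Text"]]
        # Fixed sequence length: longest sample unless the caller overrides it.
        self.max_length = max_length or max(len(ids) for ids in encoded)
        pad_id = tokenizer.pad_token_id
        # Truncate long samples, right-pad short ones with the pad token id.
        self.input_ids = [
            ids[: self.max_length] + [pad_id] * (self.max_length - len(ids))
            for ids in encoded
        ]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return torch.tensor(self.input_ids[idx]), torch.tensor(self.labels[idx])
```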
During and after training you’ll have:
checkpoints/
20251024Z223015/
model.safetensors
meta.json # {"epoch": ..., "train_loss": ..., "val_loss": ...}
20251024Z231742/
model.safetensors
meta.json
best -> 20251024Z231742 # symlink to the lowest val loss so far
Use the directory name (or `best`) with the CLI.
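Saving follows the layout above. A minimal sketch of the save path (the helper name and arguments here are illustrative, not necessarily `persistence.py`'s actual API):

```python
# Sketch: write safetensors weights + meta.json into a timestamped directory
# and point the "best" symlink at the lowest-validation-loss run so far.
import json
from datetime import datetime, timezone
from pathlib import Path
from safetensors.torch import save_file

def save_checkpoint(model, meta, checkpoints_dir="checkpoints", is_best=False):
    stamp = datetime.now(timezone.utc).strftime("%Y%m%dZ%H%M%S")
    run_dir = Path(checkpoints_dir) / stamp
    run_dir.mkdir(parents=True, exist_ok=True)

    save_file(model.state_dict(), run_dir / "model.safetensors")
    (run_dir / "meta.json").write_text(json.dumps(meta))

    if is_best:
        best = Path(checkpoints_dir) / "best"
        if best.is_symlink() or best.exists():
            best.unlink()
        best.symlink_to(stamp)  # relative symlink to the sibling run directory
    return run_dir
```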
Managed by uv. Typical runtime stack:
- `torch`
- `transformers`
- `pandas`
- `safetensors`
- `click`
- (Optional but recommended) `accelerate` — improves `device_map="auto"` placement

Install via `uv sync`. All commands should be run with `uv run`, e.g.:
uv run python train_qwen_classifier.py
uv run python run_qwen_classifier.py "some text"

- Model — Start from `Qwen/Qwen3-0.6B-Base` (causal LM). Replace `lm_head` with a 2-logit linear layer (ham/spam).
- Labeling scheme — For each input, forward once and read the last token's logits. Apply softmax to get class probabilities.
- Training loop — Simple supervised fine-tuning with cross-entropy on those logits versus integer labels. Mini evals every 50 steps; best-val symlink updated on improvement (a sketch of one step follows below).
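Condensed into code, one step of that loop might look like this (a sketch under the assumptions above; the real trainer adds eval windows, accuracy reporting, and checkpointing):

```python
# Sketch of the last-token cross-entropy objective and the inference path.
import torch

def train_step(model, optimizer, input_ids, labels):
    optimizer.zero_grad()
    logits = model(input_ids).logits   # (batch, seq_len, 2) after decapitation
    loss = torch.nn.functional.cross_entropy(logits[:, -1, :], labels)
    loss.backward()
    optimizer.step()
    return loss.item()

# optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)

@torch.no_grad()
def classify(model, tokenizer, text):
    ids = torch.tensor([tokenizer.encode(text)], device=model.device)
    probs = torch.softmax(model(ids).logits[:, -1, :], dim=-1)[0]
    return {"ham": probs[0].item(), "spam": probs[1].item()}
```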
- GPU vs CPU: `device_map="auto"` will try the GPU if available. If you only have a CPU and see CUDA errors, ensure your PyTorch install matches your platform or disable CUDA via the relevant env vars (a quick check is sketched below).
- Model downloads: Hugging Face will download `Qwen/Qwen3-0.6B-Base` on first run; ensure you have internet access or the model already cached.
- Sequence length: Very long inputs will be truncated to the training max length; very short inputs are padded. If desired, set a hard `max_length` in `SpamDataset`.
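A quick sanity check that PyTorch actually sees your GPU (plain PyTorch, nothing repo-specific):

```python
# Verify CUDA availability before training; falls back to CPU otherwise.
import torch

print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
```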
See the LICENSE file.