TinyGPT is a small GPT-style language model trained on the TinyStories dataset using JAX and Flax NNX.
The project is implemented in the notebook miniGPT.ipynb and covers the full pipeline:
- model architecture
- data processing
- training with checkpointing
- text generation (inference)
This notebook builds a decoder-only transformer from scratch with:
- token + positional embeddings
- causal self-attention
- stacked transformer blocks (pre-layer norm)
- linear language modeling head
The model is trained autoregressively (next-token prediction) on TinyStories text split by <|endoftext|>.
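As a rough illustration of next-token prediction (a sketch, not the notebook's exact code), the targets are the input tokens shifted left by one position, and the loss is the mean cross-entropy between the logits and those targets. The helper names `make_inputs_and_targets` and `next_token_loss` are hypothetical, and the notebook may build its targets differently (for example by padding instead of slicing).

```python
import jax
import jax.numpy as jnp


def make_inputs_and_targets(tokens):
    """Split a (batch, T + 1) token array into inputs and shifted targets."""
    inputs = tokens[:, :-1]   # positions 0 .. T-1
    targets = tokens[:, 1:]   # positions 1 .. T: the "next" token for each input
    return inputs, targets


def next_token_loss(logits, targets):
    """Mean cross-entropy between (batch, T, vocab) logits and (batch, T) integer targets."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick out the log-probability assigned to each correct next token.
    picked = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return -picked.mean()


tokens = jnp.array([[5, 6, 7, 8]])
inputs, targets = make_inputs_and_targets(tokens)
# Uniform logits over a 10-token vocabulary give a loss of log(10).
loss = next_token_loss(jnp.zeros((1, 3, 10)), targets)
```

An untrained model with near-uniform predictions should start close to log(vocabulary size), which is a handy sanity check for the first training steps.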
- Sequence length: 128
- Embedding dimension: 256
- Attention heads: 8
- Feed-forward dimension: 1024
- Transformer blocks: 6
- Tokenizer: tiktoken GPT-2 encoding
- Epochs: 3
- Batch size: 32
- Max stories loaded: 100000
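One way to keep these settings together is a small config object. The sketch below uses the values listed above; the class name `TinyGPTConfig` and all field names except `maxlen` are illustrative and may not match the notebook's variables.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TinyGPTConfig:
    # Values taken from the hyperparameter list; field names are illustrative.
    maxlen: int = 128            # sequence length
    embed_dim: int = 256         # embedding dimension
    num_heads: int = 8           # attention heads
    feed_forward_dim: int = 1024
    num_blocks: int = 6          # transformer blocks
    num_epochs: int = 3
    batch_size: int = 32
    max_stories: int = 100_000

    @property
    def head_dim(self) -> int:
        # Each head attends over an equal slice of the embedding.
        assert self.embed_dim % self.num_heads == 0
        return self.embed_dim // self.num_heads


config = TinyGPTConfig()
# Upper bound on steps per epoch, assuming one training example per loaded story.
steps_per_epoch = config.max_stories // config.batch_size
```

With these numbers each attention head works on 32 dimensions, and an epoch is at most 3125 steps under the one-example-per-story assumption.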
- Python
- JAX
- Flax NNX
- Optax
- Orbax (checkpoints)
- tiktoken
- grain (data loading)
The notebook is organized into four parts:
- Model Architecture
- Data Loading
- Training
- Inference
- TokenEmbedding: combines token and learned positional embeddings
- causal_attention_mask: creates a lower-triangular mask for autoregressive attention
- TransformerBlock: pre-LN attention + MLP with residuals
- miniGPTModel: full GPT-style stack and vocabulary projection
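To make the lower-triangular mask concrete, here is a minimal sketch of how such a mask can be built and applied to attention scores. The functions `causal_mask` and `masked_attention_weights` are stand-ins written for this example, not the notebook's `causal_attention_mask`, whose exact shape and dtype may differ.

```python
import jax
import jax.numpy as jnp


def causal_mask(seq_len):
    """Boolean (seq_len, seq_len) mask: True where query i may attend to key j (j <= i)."""
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))


def masked_attention_weights(scores, mask):
    """Softmax over keys after setting disallowed (future) positions to -inf."""
    scores = jnp.where(mask, scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1)


mask = causal_mask(4)
weights = masked_attention_weights(jnp.zeros((4, 4)), mask)
# With equal scores, row i spreads its weight evenly over positions 0..i.
```

Because the diagonal is always allowed, every row keeps at least one finite score, so the softmax stays well defined.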
- Reads TinyStories training text file
- Splits stories using <|endoftext|> delimiter
- Tokenizes with GPT-2 tokenizer
- Truncates/pads to fixed length (maxlen)
- Uses grain DataLoader with batching
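The split, tokenize, and truncate/pad steps above can be sketched in plain Python. This is an illustration under assumptions: the helper names are hypothetical, padding is on the right with a `pad_id` of 0, and the tokenizer is passed in as a function so the demo can use a toy character-level stand-in rather than the GPT-2 encoding.

```python
from typing import Callable, List

EOT = "<|endoftext|>"


def split_stories(text: str, max_stories: int) -> List[str]:
    """Split raw text on the delimiter, dropping empty pieces and surrounding whitespace."""
    stories = [s.strip() for s in text.split(EOT)]
    return [s for s in stories if s][:max_stories]


def pad_or_truncate(ids: List[int], maxlen: int, pad_id: int = 0) -> List[int]:
    """Cut to maxlen tokens, then right-pad with pad_id up to maxlen."""
    ids = ids[:maxlen]
    return ids + [pad_id] * (maxlen - len(ids))


def encode_stories(text: str, encode: Callable[[str], List[int]],
                   maxlen: int, max_stories: int) -> List[List[int]]:
    """Turn raw text into fixed-length token id lists, one per story."""
    return [pad_or_truncate(encode(s), maxlen) for s in split_stories(text, max_stories)]


# Toy stand-in tokenizer: one id per character (the notebook uses GPT-2 tokens).
toy_encode = lambda s: [ord(c) for c in s]
raw = "Hi<|endoftext|>\nA longer story<|endoftext|>\n"
batch = encode_stories(raw, toy_encode, maxlen=4, max_stories=10)
```

Every row ends up with exactly `maxlen` ids, which is what lets the loader stack stories into rectangular batches.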
- Loss: cross-entropy with integer labels (optax.softmax_cross_entropy_with_integer_labels)
- LR schedule: warmup + cosine decay
- Optimizer: AdamW (optax.adamw)
- JIT-compiled train step with @nnx.jit
- Checkpoints saved every 100 steps and at each epoch end via Orbax
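For readers unfamiliar with the warmup + cosine decay shape, the sketch below computes it by hand in plain Python. It is not the notebook's code: the function name and all numeric values are illustrative, and the notebook may instead rely on a ready-made Optax schedule such as `optax.warmup_cosine_decay_schedule`.

```python
import math


def warmup_cosine_lr(step, peak_lr, warmup_steps, total_steps, end_lr=0.0):
    """Linear warmup from 0 to peak_lr, then cosine decay down to end_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    # Fraction of the decay phase completed, clipped to at most 1.
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_lr + (peak_lr - end_lr) * cosine


# Illustrative numbers only; the notebook's actual learning-rate values are not listed here.
lrs = [warmup_cosine_lr(s, peak_lr=1e-3, warmup_steps=100, total_steps=1000)
       for s in (0, 50, 100, 550, 1000)]
```

The rate climbs linearly to its peak at the end of warmup, passes half the peak midway through the decay phase, and reaches `end_lr` at the final step.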
- Restores a saved checkpoint
- Generates text autoregressively with temperature sampling
- Stops when <|endoftext|> is generated or max tokens reached
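The generation loop described above can be sketched as follows. This is a minimal illustration, not the notebook's `generate`: the name `sample_tokens`, its signature, and the toy stand-in model are assumptions made for the example. A real model would return logits over the full GPT-2 vocabulary.

```python
import jax
import jax.numpy as jnp


def sample_tokens(next_token_logits, prompt_ids, eot_id, max_new_tokens,
                  temperature=1.0, seed=0):
    """Sample tokens one at a time until eot_id appears or max_new_tokens is reached."""
    key = jax.random.PRNGKey(seed)
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = next_token_logits(jnp.array(ids))  # (vocab,) logits for the next token
        key, subkey = jax.random.split(key)
        # Temperature < 1 sharpens the distribution; > 1 flattens it.
        next_id = int(jax.random.categorical(subkey, logits / temperature))
        ids.append(next_id)
        if next_id == eot_id:
            break
    return ids


def toy_logits(ids):
    """Toy stand-in model: prefers token 1 until three tokens exist, then token 0."""
    favoured = 0 if ids.shape[0] >= 3 else 1
    return jnp.where(jnp.arange(5) == favoured, 1e4, 0.0)


out = sample_tokens(toy_logits, prompt_ids=[2], eot_id=0, max_new_tokens=10)
```

Splitting the PRNG key on every step gives each sampled token fresh randomness while keeping the whole run reproducible from a single seed.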
Install dependencies:
pip install jax flax optax orbax-checkpoint tiktoken grain
If you are running in Colab (as in the notebook), mount Google Drive:
from google.colab import drive
drive.mount('/content/drive')
Then update paths in the notebook to point to:
- TinyStories training text file
- checkpoint output directory
Open and run tinyStoriesGPT.ipynb top-to-bottom:
- Install/import dependencies
- Configure dataset and checkpoint paths
- Build model
- Load data
- Train and save checkpoints
- Restore checkpoint
- Generate stories from prompts
The loader expects a plain text file containing stories separated by <|endoftext|>.
Example:
Story one text...<|endoftext|>
Story two text...<|endoftext|>
Checkpoints are written to the configured checkpoint_dir with names such as:
- epoch_1_step_100
- epoch_1_final
To run inference, restore one checkpoint into the model state and call generate(...).
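The naming pattern and save cadence can be illustrated with a small helper. This is a guess at the scheme based only on the two example names: it assumes the step counter restarts at 1 in each epoch, which the notebook may not do, and `checkpoint_names` is a hypothetical function.

```python
def checkpoint_names(num_epochs, steps_per_epoch, save_every=100):
    """Yield checkpoint names in the order a training loop would write them."""
    for epoch in range(1, num_epochs + 1):
        for step in range(1, steps_per_epoch + 1):
            if step % save_every == 0:
                yield f"epoch_{epoch}_step_{step}"
        # One extra checkpoint at the end of every epoch.
        yield f"epoch_{epoch}_final"


names = list(checkpoint_names(num_epochs=2, steps_per_epoch=250))
```

Listing the expected names this way makes it easier to pick which checkpoint to restore for inference.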
- The notebook currently uses Google Drive paths and Colab mounting.
- If you run locally, replace those paths with local filesystem paths.
- Hyperparameters are intentionally small to keep the model lightweight.
- Paarth Sharma