Train the reverse-engineered SmolLM2 model from scratch for 5000 steps on custom data

In [1]:
!python train.py --steps 5000 --save checkpoint_model.pt

Using device: cuda
  _C._set_float32_matmul_precision(precision)
tokenizer_config.json: 3.69kB [00:00, 15.8MB/s]
vocab.json: 801kB [00:00, 34.9MB/s]
merges.txt: 466kB [00:00, 89.5MB/s]
tokenizer.json: 2.10MB [00:00, 133MB/s]
special_tokens_map.json: 100% 831/831 [00:00<00:00, 6.09MB/s]
Loaded 341094 tokens
333 batches per full pass

Training from step 0 → 5000
LR schedule enabled (max=0.0003, min=2.9999999999999997e-05)
Using max_steps=5000 for LR

step 0 | loss 11.2806 | lr 0.000003 | tok/s    721.1
step 100 | loss 5.9812 | lr 0.000300 | tok/s   2813.2
step 200 | loss 5.1232 | lr 0.000300 | tok/s   2675.3
step 300 | loss 4.1185 | lr 0.000299 | tok/s   2698.2
step 400 | loss 4.4375 | lr 0.000298 | tok/s   2656.1
step 500 | loss 4.8015 | lr 0.000296 | tok/s   2518.2
step 600 | loss 4.7621 | lr 0.000293 | tok/s   2611.3
step 700 | loss 4.3069 | lr 0.000290 | tok/s   2640.1
step 800 | loss 4.7454 | lr 0.000287 | tok/s   2600.9
step 900 | loss 4.5716 | lr 0.000283 | tok/s   2629.8
step 100
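The `lr` column in the log above looks like a linear warmup followed by cosine decay from `max=0.0003` down to `min=0.00003`. The sketch below is one schedule that reproduces those logged values. It is not taken from `train.py`: the function name `get_lr` and the 100-step warmup are assumptions inferred from the log, while `max_lr`, `min_lr` and `max_steps` come from the printed settings.

```python
import math

def get_lr(step, max_lr=3e-4, min_lr=3e-5, warmup_steps=100, max_steps=5000):
    """Hypothetical warmup + cosine schedule; warmup_steps=100 is an assumption."""
    # Linear warmup: step 0 already gets a small non-zero learning rate.
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    # Past the schedule horizon, hold the learning rate at its floor.
    if step >= max_steps:
        return min_lr
    # Cosine decay from max_lr to min_lr between warmup_steps and max_steps.
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + coeff * (max_lr - min_lr)

for step in (0, 100, 300, 900, 5000):
    print(f"step {step:4d} | lr {get_lr(step):.6f}")
```

Under these assumptions the sketch gives 0.000003 at step 0, 0.000300 at step 100, 0.000299 at step 300 and 0.000283 at step 900, in line with the log. Because the schedule horizon is `max_steps=5000`, any step at or beyond 5000 stays at the minimum of 0.000030.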

Resume training from the checkpoint saved above and continue for 100 more steps

In [7]:
!python train.py --steps 5100 --ckpt checkpoint_model.pt --save checkpoint_model2.pt

Using device: cuda
  _C._set_float32_matmul_precision(precision)
Loaded 341094 tokens
333 batches per full pass
Loaded checkpoint 'checkpoint_model.pt' at step=5000, loss=0.3861677050590515
Resuming from step 5000

Training from step 5000 → 5100
LR schedule enabled (max=0.0003, min=2.9999999999999997e-05)
Using max_steps=5000 for LR

step 5000 | loss 0.3418 | lr 0.000030 | tok/s    935.2

Saved checkpoint: checkpoint_step_5000.pt

Saved checkpoint: checkpoint_model2.pt

Final loss: 0.2695
Model saved → checkpoint_model2.pt
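Resuming works because the checkpoint stores more than the weights: the log shows a step counter and a loss being restored, and the loading cell in the next section reads the weights from a `"model"` key. A minimal sketch of that save/load round trip follows. The helper names `save_checkpoint` and `load_checkpoint`, the `"optimizer"` key, and the tiny `nn.Linear` stand-in model are all assumptions for illustration, not the actual contents of `train.py`.

```python
import os
import tempfile

import torch
import torch.nn as nn

def save_checkpoint(path, model, optimizer, step, loss):
    # Hypothetical layout: weights under "model", plus optimizer state and progress.
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "step": step,
            "loss": loss,
        },
        path,
    )

def load_checkpoint(path, model, optimizer=None, device="cpu"):
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])
    # Restore optimizer state only if it was saved and an optimizer was passed in.
    if optimizer is not None and "optimizer" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt.get("step", 0), ckpt.get("loss")

# Tiny stand-in model so the sketch runs without the real SmolLM class.
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
path = os.path.join(tempfile.mkdtemp(), "demo_checkpoint.pt")
save_checkpoint(path, model, optimizer, step=5000, loss=0.386)

# A fresh model and optimizer pick up where the saved run stopped.
restored = nn.Linear(4, 2)
restored_opt = torch.optim.AdamW(restored.parameters(), lr=3e-4)
start_step, last_loss = load_checkpoint(path, restored, restored_opt)
print(f"Resuming from step {start_step}, loss={last_loss}")
```

Restoring the optimizer state along with the weights keeps AdamW's moment estimates, so the resumed run does not start its updates from a cold optimizer.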


Test the trained model with a prompt

In [23]:
# Load the trained model checkpoint
import torch
from model import SmolLM, SmolConfig
from transformers import AutoTokenizer

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")

# Initialize the model with config
cfg = SmolConfig()
model = SmolLM(cfg)

# Load checkpoint: use the latest checkpoint from training
checkpoint_path = "checkpoint_model2.pt"  # or "checkpoint_model.pt" for the 5000-step checkpoint
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# If checkpoint contains state_dict, load it; otherwise it's the state_dict directly
if isinstance(checkpoint, dict) and "model" in checkpoint:
    model.load_state_dict(checkpoint["model"])
else:
    model.load_state_dict(checkpoint)

# Move to GPU if available, otherwise stay on CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

print(f"Model loaded from {checkpoint_path}")
print(f"Model device: {device}")

Model loaded from checkpoint_model2.pt
Model device: cuda


In [24]:
# Generate text from a prompt
def generate(prompt: str, max_new_tokens: int = 50, temperature: float = 0.7, top_k: int = 50):
    """
    Generate text based on a prompt using the loaded model.

    Args:
        prompt: Input text prompt
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (higher = more random)
        top_k: Top-k sampling parameter

    Returns:
        Generated text
    """
    # Tokenize the prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Initialize past_key_values for caching
    past_key_values = None
    generated_ids = input_ids.clone()

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Get model output
            if past_key_values is None:
                # First pass: use full input
                logits, past_key_values = model(input_ids, use_cache=True)
            else:
                # Subsequent passes: use only the last token
                logits, past_key_values = model(input_ids[:, -1:], past_key_values=past_key_values, use_cache=True)

            # Get logits for the last token
            next_logits = logits[:, -1, :]

            # Apply temperature
            next_logits = next_logits / temperature

            # Top-k sampling
            if top_k > 0:
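                # Clamp k to the vocabulary size so torch.topk cannot fail on a large top_k
                k = min(top_k, next_logits.size(-1))
                values, indices = torch.topk(next_logits, k, dim=-1)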
                next_logits = torch.full_like(next_logits, float('-inf'))
                next_logits.scatter_(-1, indices, values)

            # Sample next token
            probs = torch.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Stop if we generate the end-of-sequence token
            if next_token.item() == tokenizer.eos_token_id:
                break

            # Append to generated sequence
            generated_ids = torch.cat([generated_ids, next_token], dim=1)
            input_ids = next_token

    # Decode generated tokens
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text



In [26]:
# Test the generation
prompt = "First Citizen"
print(f"Prompt: {prompt}")
print("\nGenerated response:")
response = generate(prompt, max_new_tokens=100, temperature=0.8)
print(response)


Prompt: First Citizen

Generated response:
First Citizen:
What in a thousand love that which al lie,
Than too goodly over York, upon thy awhile,
For God, fair thy fair comfort to seek
Than see a king with thy bed woo!
Not that thyak but for tremble speed
With ear that hath thisAgainst the name of this land
Forts, though thou be no haste for thy thy life,
And from my knave: thy one, yet in thee
