In [None]:
import torch
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

In [None]:
repo_id = "HuggingFaceTB/SmolLM-135M-Instruct"

# Load the checkpoint onto the first CUDA device; torch_dtype='auto' keeps the
# dtype recorded in the checkpoint's config instead of upcasting to float32.
load_kwargs = dict(device_map='cuda:0', torch_dtype='auto')
model = AutoModelForCausalLM.from_pretrained(repo_id, **load_kwargs)

In [None]:
# Memory used by the loaded model (get_memory_footprint returns bytes; shown in GB = 1e9 bytes).
print(f'Model memory footprint: {model.get_memory_footprint()/1e9} GB')

# Total parameter count, labeled and formatted like the trainable-parameter count
# so the two numbers can be compared directly.
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,} ({total_params/1e9:.3f} B)")

In [None]:
# Only tensors with requires_grad=True receive gradients and are updated by the optimizer.
trainable_params = sum(map(lambda t: t.numel(), filter(lambda t: t.requires_grad, model.parameters())))
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
# Download/prepare the dataset with 8 worker processes and keep only its 'train' split.
ds = load_dataset("voidful/reasoning_gemini_300k", num_proc=8)['train']

In [None]:
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Single-turn prompt used to see how the model answers before any fine-tuning.
messages = [{"role": "user", "content": "Top 10 attractions in paris"}]

encoded = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,   # append the assistant-turn header so the model answers
    tokenize=True,
    return_dict=True,             # BatchEncoding with input_ids + attention_mask
    return_tensors="pt",
)
inputs = encoded.to(model.device)

prompt_len = inputs["input_ids"].shape[-1]
outputs = model.generate(**inputs, max_new_tokens=100)
# generate() returns prompt + continuation; decode only the newly generated tokens.
print(tokenizer.decode(outputs[0][prompt_len:]))

In [None]:
# Manual forward + backward pass on the chat prompt, so the next cell can inspect
# which dtype the gradients end up in.
import torch.nn as nn

# return_dict=False pins the return type to a plain (1, seq_len) tensor of token ids
# (one conversation), rather than relying on the library default.
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=False,
    return_tensors="pt",
).to(model.device)

outputs = model(input_ids)

# Causal LM objective: logits at position t predict token t+1, so drop the last
# logit and the first label before comparing (next-token prediction).
criterion = nn.CrossEntropyLoss()
shift_logits = outputs.logits[0, :-1]   # (seq_len - 1, vocab)
shift_labels = input_ids[0, 1:]         # (seq_len - 1,)
loss = criterion(shift_logits, shift_labels)
loss.backward()

In [None]:
# Report the dtype of every parameter's gradient produced by the manual backward pass.
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name}: grad type = {param.grad.dtype}")
    else:
        print(f"{name}: No gradient computed.")

# Free the demo gradients: each one is a tensor shaped like its parameter, and they
# would otherwise stay allocated (and attached to the model) when training starts.
# Consequence: re-running this cell alone reports "No gradient computed." everywhere.
model.zero_grad(set_to_none=True)

In [None]:
def format_dataset(example):
    """Turn one raw row into a chat-style ``messages`` record.

    The user turn is the row's ``message``; the assistant turn is the
    ``reasoning`` wrapped in ``<think>...</think>`` followed by the ``answer``.
    """
    user_turn = {"role": "user", "content": example['message']}
    assistant_text = '<think>' + example['reasoning'] + '</think>' + example['answer']
    assistant_turn = {"role": "assistant", "content": assistant_text}
    return {'messages': [user_turn, assistant_turn]}
# Convert every row to chat format and drop the raw source columns in the same pass.
ds = ds.map(format_dataset, remove_columns=['message', 'reasoning', 'answer'])

In [None]:
sft_config = SFTConfig(
    ## GROUP 1: Memory usage
    # These arguments control how much GPU memory training needs.
    # Gradient checkpointing recomputes activations during the backward pass instead of
    # storing them; enabling it saves a LOT of memory at some compute cost. It is OFF here.
    gradient_checkpointing=False,
    # Gradient accumulation / batch size:
    # effective batch per optimizer update = per_device_train_batch_size (1)
    #   * gradient_accumulation_steps (2) = 2 examples, i.e. 2x the micro-batch.
    gradient_accumulation_steps=2,
    # Micro-batch size: examples per forward/backward pass on the device
    per_device_train_batch_size=1,
    # Tokenized examples longer than this are truncated
    max_length=256,
    # Stop after this many optimizer steps; a positive max_steps takes precedence
    # over num_train_epochs below
    max_steps=50,
    # bfloat16 mixed precision (needs a bf16-capable GPU)
    bf16=True,

    ## GROUP 2: Dataset-related
    # Packing would concatenate several examples into one sequence to avoid padding;
    # it is disabled here, so every example is processed as its own sequence.
    packing=False,
    # Worker processes for dataset preprocessing and for the DataLoader
    dataset_num_proc=8,
    dataloader_num_workers=8,
    # Extra throughput metrics in the training logs
    include_tokens_per_second=True,
    include_num_input_tokens_seen=True,

    ## GROUP 3: Typical training parameters
    # Has no effect while max_steps > 0 (see above)
    num_train_epochs=1,
    learning_rate=2e-4,
    # Optimizer: not set, so the TrainingArguments default is used

    ## GROUP 4: Logging / output
    logging_steps=1,
    logging_dir='./logs',
    # Full fine-tune of SmolLM-135M (no PEFT/LoRA adapter is created in this notebook),
    # so the output directory is named after the model actually being trained.
    output_dir='./smollm-135m-sft',
    report_to='none'
)

In [None]:
# Combine the model, its tokenizer, the training config and the chat-formatted
# dataset into TRL's supervised fine-tuning trainer.
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=ds,
    processing_class=tokenizer,
)

In [None]:
# Run the training loop configured by sft_config; the returned object is shown as the cell output.
train_output = trainer.train()
train_output

In [None]:
# Check dtypes of optimizer states
for i, (key, state) in enumerate(trainer.optimizer.state.items()):
    print(f"\nParameter {i}:")
    for state_key, value in state.items():
        if isinstance(value, torch.Tensor):
            print(f"  {state_key}: dtype = {value.dtype}")

In [None]:
# Per tracked parameter: bytes held by the parameter itself and by each tensor in
# its optimizer state (computed as numel * element_size).
for idx, (tracked_param, param_state) in enumerate(trainer.optimizer.state.items()):
    n_elem, elem_bytes = tracked_param.numel(), tracked_param.element_size()
    print(f"\nParameter {idx}:")
    print(f"  - Model parameter: {n_elem} elements * {elem_bytes} bytes = {n_elem * elem_bytes} bytes")

    for state_name, tensor in param_state.items():
        if not isinstance(tensor, torch.Tensor):
            continue
        numel, esize = tensor.numel(), tensor.element_size()
        print(f"  - {state_name}: {numel} elements * {esize} bytes = {numel * esize} bytes")


In [None]:
# Display the model's repr (its module hierarchy) as the cell output.
model