In [None]:
import os
import torch
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoModel
from trl import SFTConfig, SFTTrainer

In [None]:
# Quantization settings for loading the base model: 4-bit NF4 weights with
# double quantization, using float32 as the compute dtype.
bnb_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=True,
   bnb_4bit_compute_dtype=torch.float32
)

# Checkpoint to fine-tune; the commented line is an alternative, larger model.
repo_id = "Qwen/Qwen3-0.6B-Base"
#repo_id = "Qwen/Qwen3-4B-Instruct-2507"

# Load the quantized model onto the first CUDA device.
model = AutoModelForCausalLM.from_pretrained(repo_id,
                                             device_map='cuda:0', 
                                             quantization_config=bnb_config,
                                             torch_dtype='auto')

# get_memory_footprint() reports bytes, so the GB figure needs a 1e9 divisor
# (dividing by 1e6 would yield megabytes while still being labelled "GB").
print(f'Model memory footprint: {model.get_memory_footprint()/1e9} GB')

In [None]:
# Print model.get_memory_footprint() scaled down by 1e9 and labelled as GB.
print(f'Model memory footprint: {model.get_memory_footprint()/1e9} GB')

In [None]:
# Add up the element count of every parameter tensor and show it in billions.
# NOTE(review): for 4-bit quantized weights numel() may reflect the packed
# storage rather than the logical weight count — confirm before quoting it.
param_sizes = [weight.numel() for weight in model.parameters()]
total_params = sum(param_sizes)
print(total_params/1e9)

In [None]:
from collections import defaultdict

# Tally parameter counts for parameters that require gradients, grouped by the
# first three dot-separated components of each parameter name.
layer_params = defaultdict(int)
for name, param in model.named_parameters():
    # Parameters that do not require gradients are left out of the tally.
    if not param.requires_grad:
        continue
    name_parts = name.split(".")
    layer_name = ".".join(name_parts[:3])  # Adjust depth as needed
    layer_params[layer_name] += param.numel()

# Report one line per group, ordered by group name.
for layer_name, param_count in sorted(layer_params.items(), key=lambda item: item[0]):
    print(f"{layer_name:<60} {param_count:,} parameters")

In [None]:
# Run PEFT's preparation helper on the quantized model before training.
model = prepare_model_for_kbit_training(model)

# LoRA adapter settings (rank 8, alpha 16, 5% dropout, no bias terms).
# Qwen3 uses separate attention projections (q/k/v/o) and separate MLP
# projections (gate/up/down). Fused names such as 'qkv_proj' or 'gate_up_proj'
# come from other architectures and match no module here, which would silently
# restrict LoRA to 'o_proj' and 'down_proj' only.
peft_config = LoraConfig(r = 8,
                    lora_alpha=16,
                    bias='none',
                    lora_dropout=0.05,
                    task_type='CAUSAL_LM',
                    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj',
                                    'gate_proj', 'up_proj', 'down_proj'],
)
# The adapter is intentionally not attached here: peft_config is passed to
# SFTTrainer further down, so get_peft_model() is not called in this cell.

In [None]:
# Display the model's repr as the cell output to inspect its module layout.
model

In [None]:
# Count only the parameters that currently require gradients, in billions.
# Note that this reuses (and overwrites) the earlier total_params name.
trainable = (weight for weight in model.parameters() if weight.requires_grad)
total_params = sum(weight.numel() for weight in trainable)
print(total_params/1e9)

In [None]:
# Load the reasoning dataset with 8 worker processes and keep its 'train' split.
ds = load_dataset("voidful/reasoning_gemini_300k", num_proc=8)['train']

In [None]:
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Smoke test: push a one-turn conversation through the chat template and let
# the model generate a short continuation.
messages = [{"role": "user", "content": "Who are you?"}]
template_kwargs = dict(
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
inputs = tokenizer.apply_chat_template(messages, **template_kwargs).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=100)

# Slice off the prompt tokens so that only the newly generated part is printed.
prompt_length = inputs["input_ids"].shape[-1]
generated_ids = outputs[0][prompt_length:]
print(tokenizer.decode(generated_ids))

In [None]:
def format_dataset(example):
    """Turn one raw record into a two-turn chat conversation.

    Args:
        example: mapping with 'message', 'reasoning' and 'answer' entries.

    Returns:
        A dict with a single 'messages' key holding a user turn (the message)
        followed by an assistant turn whose content is the reasoning wrapped
        in <think>...</think> tags immediately followed by the answer.
    """
    user_turn = {"role": "user", "content": example['message']}
    assistant_text = '<think>' + example['reasoning'] + '</think>' + example['answer']
    assistant_turn = {"role": "assistant", "content": assistant_text}
    return {'messages': [user_turn, assistant_turn]}

In [None]:
# Convert every record to chat format and drop the raw source columns in the
# same pass. The guard keeps this cell safe to re-run: once 'messages' exists
# the raw columns are gone, so mapping again would fail inside format_dataset.
# To rebuild after editing format_dataset, re-run the dataset loading cell first.
raw_columns = ['message', 'reasoning', 'answer']
if 'messages' not in ds.column_names:
    ds = ds.map(format_dataset, num_proc=8, remove_columns=raw_columns)

In [None]:
# Render the first converted conversation through the tokenizer's chat template
# as plain text (no tokenization) to check the formatting by eye.
first_conversation = ds[0]['messages']
print(tokenizer.apply_chat_template(first_conversation, tokenize=False))

In [None]:
sft_config = SFTConfig(
    ## GROUP 1: Memory usage
    # These arguments will squeeze the most out of your GPU's RAM
    # Checkpointing
    gradient_checkpointing=True,    # this saves a LOT of memory
    # Set this to avoid exceptions in newer versions of PyTorch
    gradient_checkpointing_kwargs={'use_reentrant': False}, 
    # Gradient Accumulation / Batch size
    # Gradients are accumulated over 8 micro-batches before each optimizer
    # update, so the effective batch per device is 8 x 2 = 16 sequences
    gradient_accumulation_steps=8,  
    # The (micro) batch size processed in a single forward/backward pass
    per_device_train_batch_size=2, 
    # Upper bound on the tokenized sequence length
    max_length=1024,
    
    ## GROUP 2: Dataset-related
    # Dataset
    # Packing (which removes the need for padding) is disabled here
    packing=False,

    ## GROUP 3: These are typical training parameters
    # Training length is set by max_steps below. num_train_epochs is not set
    # because a positive max_steps overrides it.
    learning_rate=5e-5,
    lr_scheduler_type='linear',
    warmup_ratio=0.2,

    # Optimizer
    # 8-bit Adam optimizer - doesn't help much if you're using LoRA!
    optim='paged_adamw_8bit',  
    # Total number of optimizer update steps to run
    max_steps=500,     

    dataloader_num_workers=8,
    dataset_num_proc=8,
    
    ## GROUP 4: Logging parameters
    logging_steps=10,
    logging_dir='./logs',
    output_dir='./qwen3_adapter',
    report_to='none',
)

In [None]:
# Gather the trainer's inputs first, then build it in one call. The LoRA
# settings travel as peft_config rather than being applied to the model here.
trainer_kwargs = dict(
    model=model,
    args=sft_config,
    train_dataset=ds,
    processing_class=tokenizer,
    peft_config=peft_config,
)
trainer = SFTTrainer(**trainer_kwargs)

In [None]:
# Fetch the trainer's training dataloader and pull a single batch from it so
# the collated tensors can be inspected before training starts.
dl = trainer.get_train_dataloader()
batch = next(iter(dl))

In [None]:
# Show the first entry of the batch's input_ids next to the first entry of its
# labels, to compare them by eye.
batch['input_ids'][0], batch['labels'][0]

In [None]:
# Start fine-tuning; the value returned by train() is shown as the cell output.
trainer.train()