In [None]:
import os
import torch
from datasets import load_dataset
from peft import LoraConfig, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoModel
from trl import SFTConfig, SFTTrainer

In [None]:
# Quantization settings handed to the model loader: load the weights in 8-bit.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# Hub id shared by the tokenizer and the model.
repo_id = "Qwen/Qwen3-8B-Base"
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Load the quantized model onto the first CUDA device.
# NOTE(review): use_cache=False presumably turns off the generation KV cache,
# which is not needed while training — confirm.
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    quantization_config=bnb_config,
    torch_dtype='auto',
    device_map='cuda:0',
    use_cache=False,
)
print(f'Model memory footprint: {model.get_memory_footprint()/1e9} GB')

In [None]:
# Prepare the quantized model for k-bit fine-tuning.
model = prepare_model_for_kbit_training(model)

# LoRA adapter settings: rank 8, alpha 16, no dropout and no bias terms,
# attached to the attention and MLP projection layers named below.
attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
mlp_projections = ["gate_proj", "up_proj", "down_proj"]
peft_config = LoraConfig(
    task_type='CAUSAL_LM',
    r=8,
    lora_alpha=16,
    lora_dropout=0,
    bias='none',
    target_modules=attention_projections + mlp_projections,
)

# Show the prepared model as the cell output.
model

In [None]:
# Sanity check before fine-tuning: send a single user turn through the chat
# template and print the full decoded sequence (prompt plus up to 100 new tokens).
messages = [{"role": "user", "content": "Who are you?"}]
chat_template_kwargs = dict(
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
inputs = tokenizer.apply_chat_template(messages, **chat_template_kwargs).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=100)
decoded_sequences = tokenizer.batch_decode(outputs)
print(decoded_sequences[0])

In [None]:
import re

# Matches a whole `<ser>...</ser>` block, including any newlines inside it.
_SER_BLOCK_RE = re.compile(r'<ser>.*?</ser>', flags=re.DOTALL)
# Matches `<think>...</think>` plus the text that follows it, up to (but not
# including) the newline that precedes the next `<think>` or the end of the text.
_THINK_BLOCK_RE = re.compile(r'(<think>.*?</think>.*?)(?=\n<think>|$)', re.DOTALL)


def format_dataset(example: dict) -> dict:
    """Convert one raw dataset row into the chat ``messages`` format.

    Args:
        example: Mapping with at least the keys ``'instruction'`` (the user
            prompt) and ``'output'`` (the target text, possibly containing
            ``<ser>`` and ``<think>`` blocks).

    Returns:
        ``{'messages': [...]}`` where the first message is the user turn and
        every following message is an assistant turn. Each
        ``<think>...</think>`` block together with its trailing explanation
        becomes its own assistant message. When the output contains no
        ``<think>`` block, the whole cleaned output is used as a single
        assistant message so the sample still has a training target.
    """
    instruction = example['instruction']

    # Remove ser blocks before splitting the answer
    cleaned_output = _SER_BLOCK_RE.sub('', example['output'])

    messages = [{"role": "user", "content": instruction}]

    think_blocks = _THINK_BLOCK_RE.findall(cleaned_output)
    if think_blocks:
        messages.extend(
            {"role": "assistant", "content": block} for block in think_blocks
        )
    elif cleaned_output.strip():
        # No thinking block at all: keep the plain answer instead of emitting
        # a conversation that has no assistant turn.
        messages.append({"role": "assistant", "content": cleaned_output.strip()})

    return {'messages': messages}

In [None]:
# Load the training split, convert every row to the chat `messages` format and
# drop the original columns so only `messages` is left.
raw_dataset = load_dataset("HelpingAI/Intermediate-Thinking-130k", split='train')
source_columns = ['instruction', 'input', 'output', 'conversation']
dataset = raw_dataset.map(format_dataset).remove_columns(source_columns)

In [None]:
# Render the first converted conversation through the chat template (as text,
# not token ids) to eyeball what the model will be trained on.
rendered_sample = tokenizer.apply_chat_template(dataset[0]['messages'], tokenize=False)
print(rendered_sample)

In [None]:
sft_config = SFTConfig(
    ## GROUP 1: Memory usage
    # These arguments will squeeze the most out of your GPU's RAM
    # Checkpointing
    gradient_checkpointing=True,    # this saves a LOT of memory
    # Set this to avoid exceptions in newer versions of PyTorch
    gradient_checkpointing_kwargs={'use_reentrant': False},
    # Gradient Accumulation / Batch size
    # Gradients are accumulated over 32 micro-batches before each optimizer
    # update, so the effective batch is 32 x 4 = 128 sequences per device
    gradient_accumulation_steps=32,
    # The (micro) batch size processed in a single forward/backward pass
    per_device_train_batch_size=4,
    # Upper bound on the tokenized sequence length
    max_length=1024,

    ## GROUP 2: Dataset-related
    # Dataset
    # Packing is turned off here: every conversation stays its own sequence
    # (packing would concatenate samples so that no padding is needed)
    packing=False,

    ## GROUP 3: These are typical training parameters
    learning_rate=5e-5,
    lr_scheduler_type='linear',
    warmup_ratio=0.2,
    # Hard cap on the number of optimizer steps. A positive max_steps overrides
    # num_train_epochs, so no epoch count is configured.
    max_steps=100,

    # Optimizer
    # The default optimizer is kept: an 8-bit Adam optimizer
    # (optim='adamw_8bit') doesn't help much if you're using LoRA!

    dataloader_num_workers=8,
    dataset_num_proc=8,

    ## GROUP 4: Logging parameters
    logging_steps=1,
    log_level='info',
    logging_dir='./logs',
    output_dir='./qwen3_adapter',
    report_to='none',
)

In [None]:
# Build the supervised fine-tuning trainer. The training data is the formatted
# `dataset` created above (the previous `ds` name was never defined, so the
# cell failed with a NameError on a fresh kernel). Passing `peft_config` asks
# the trainer to attach the LoRA adapter to `model`.
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=sft_config,
    train_dataset=dataset,
    peft_config=peft_config,
)

In [None]:
# Pull one collated batch from the training dataloader for inspection.
dl = trainer.get_train_dataloader()
batch_iterator = iter(dl)
batch = next(batch_iterator)

In [None]:
# Show the token ids and the labels of the first sequence in the batch side by side.
first_input_ids = batch['input_ids'][0]
first_labels = batch['labels'][0]
first_input_ids, first_labels

In [None]:
# Run the fine-tuning loop; the value returned by train() is shown as the cell output.
train_output = trainer.train()
train_output

In [None]:
import json
import os

from transformers.trainer_utils import get_last_checkpoint

# Locate the most recent checkpoint instead of hard-coding a step number: the
# run above is capped at max_steps=100, so a `checkpoint-500` directory is not
# produced by this notebook.
checkpoint_dir = get_last_checkpoint('qwen3_adapter')
if checkpoint_dir is None:
    raise FileNotFoundError("No checkpoint-* directory found in 'qwen3_adapter'")

with open(os.path.join(checkpoint_dir, 'trainer_state.json'), 'r') as file:
    data = json.load(file)

# Collect the training loss and learning rate of every logged training step.
loss, lr = [], []
for step in data['log_history']:
    # Records without both keys (e.g. evaluation or summary entries) are not
    # training steps; indexing them would raise a KeyError.
    if 'loss' not in step or 'learning_rate' not in step:
        continue
    loss.append(step['loss'])
    lr.append(step['learning_rate'])

In [None]:
import matplotlib.pyplot as plt

# Training loss for each logged step.
fig, ax = plt.subplots()
ax.plot(loss)
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
plt.show()

In [None]:
# Learning-rate value recorded at each logged step.
fig, ax = plt.subplots()
ax.plot(lr)
ax.set_xlabel('Step')
ax.set_ylabel('Learning rate')
plt.show()

In [None]:
from peft import PeftModel
from transformers.trainer_utils import get_last_checkpoint

# The adapter was trained on top of Qwen3-8B-Base, so it has to be attached to
# that same base model; a different model size has incompatible weight shapes.
repo_id = "Qwen/Qwen3-8B-Base"

# Use the latest checkpoint written to the training output directory rather
# than a hard-coded step number (training is capped at max_steps=100).
adapter_dir = get_last_checkpoint('qwen3_adapter')
if adapter_dir is None:
    raise FileNotFoundError("No checkpoint-* directory found in 'qwen3_adapter'")

# NOTE(review): this loads a second, unquantized copy of the base model onto
# cuda:0. If GPU memory is tight, release the training model/trainer first or
# run this cell in a fresh kernel.
model = AutoModelForCausalLM.from_pretrained(repo_id,
                                             device_map='cuda:0',
                                             torch_dtype='auto')
# NOTE(review): torch_dtype is forwarded as an extra keyword argument; confirm
# it has the intended effect for adapter loading.
peft_model = PeftModel.from_pretrained(
    model, adapter_dir, torch_dtype=torch.float16
)

In [None]:
# Ask the fine-tuned model the same question as before training and print only
# the newly generated tokens (the prompt tokens are sliced off).
tokenizer = AutoTokenizer.from_pretrained(repo_id)
messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

# Generate through the PeftModel wrapper so it is explicit that the LoRA
# adapter takes part in the forward pass.
outputs = peft_model.generate(**inputs, max_new_tokens=100)
prompt_length = inputs["input_ids"].shape[-1]
print(tokenizer.decode(outputs[0][prompt_length:]))

In [None]:
# Display the token ids the chat template produced for the prompt.
inputs['input_ids']

In [None]:
# Decode the first prompt sequence back to text to check what the chat template emitted.
prompt_token_ids = inputs['input_ids'][0]
tokenizer.decode(prompt_token_ids)

In [None]:
# Display the list of loss values collected from the trainer state's log history.
loss