Ideas and script adapted from https://www.philschmid.de/fine-tune-llms-in-2025

In [None]:
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig
from trl import SFTTrainer, ModelConfig, SFTConfig, get_peft_config
from datasets import load_dataset
# much faster kernel
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Arguments and Configuration

In [None]:
# Base checkpoint on the Hugging Face Hub and the local training file.
model_name = 'meta-llama/Llama-3.2-1B-Instruct'
dataset_name = 'ai-abstract-dataset.jsonl.xz'
# Run directory: runs/<model short name>-<dataset stem>. The "-" keeps the two
# names distinguishable instead of gluing them into one token.
output_dir = f"runs/{model_name.split('/')[-1]}-{dataset_name.split('.')[0]}"

In [None]:
# Model, LoRA-adapter and quantization settings, grouped by concern.
# Keyword order is irrelevant to ModelConfig; only the values matter.
model_args = ModelConfig(
    # --- base checkpoint ---
    model_name_or_path=model_name,
    model_revision='main',
    trust_remote_code=False,
    torch_dtype='bfloat16',
    attn_implementation='flash_attention_2',
    # --- LoRA adapter ---
    use_peft=True,
    lora_task_type='CAUSAL_LM',
    lora_r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    lora_target_modules='all-linear',
    # presumably forwarded to PEFT as modules_to_save (trained in full and
    # stored with the adapter) -- confirm against the TRL version in use
    lora_modules_to_save=['lm_head', 'embed_tokens'],
    use_rslora=False,
    # --- bitsandbytes quantization ---
    load_in_8bit=False,
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    use_bnb_nested_quant=False,
)

In [None]:
# Supervised fine-tuning hyperparameters.
training_args = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=1,
    bf16=True,
    packing=True,
    max_length=1024,
    # Per device: 8 sequences per micro-batch x 2 accumulation steps
    # = 16 sequences per optimizer step.
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    # Trade compute for memory; non-reentrant checkpointing variant.
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    learning_rate=2.0e-4,
    # The plain "constant" scheduler has no warmup phase, so warmup_ratio would
    # be silently ignored. "constant_with_warmup" ramps the LR up over the
    # first warmup_ratio of steps and then holds it at learning_rate.
    lr_scheduler_type="constant_with_warmup",
    warmup_ratio=0.1,
    use_liger_kernel=True,
)

# Load dataset

In [None]:
# Read the local JSON Lines file (xz-compressed; presumably decompressed by
# `datasets` on load -- confirm) and keep only its "train" split.
train_dataset = load_dataset(
    "json",
    data_files=dataset_name,
    split="train",
)

# Cell output: one-line summary of the row count and column schema.
f'Dataset with {len(train_dataset)} samples and the following features: {train_dataset.features}'

# Tokenizer

In [None]:
# Load the tokenizer from the same checkpoint and revision as the model, under
# the same remote-code policy, so tokenizer and weights cannot drift apart
# when model_args is edited.
tokenizer = AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    revision=model_args.model_revision,
    trust_remote_code=model_args.trust_remote_code,
)

# Batching requires a pad token; fall back to EOS when the checkpoint has none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Model

In [None]:
# define model kwargs
# Keyword arguments forwarded to `from_pretrained` when the model is loaded.
model_kwargs = dict(
    # Hub revision (branch, tag or commit) to load, taken from model_args.
    revision=model_args.model_revision,
    # Whether to run custom modeling code shipped with the checkpoint; this is
    # what allows fine-tuning architectures not built into transformers.
    trust_remote_code=model_args.trust_remote_code,
    # Attention backend chosen in model_args.
    attn_implementation=model_args.attn_implementation,
    # dtype for the loaded weights, as chosen in model_args.
    dtype=model_args.torch_dtype,
    # The KV cache is not needed for training and conflicts with gradient
    # checkpointing, so turn it off whenever checkpointing is enabled.
    use_cache=not training_args.gradient_checkpointing,
    # Reduce peak CPU RAM while the checkpoint is being loaded.
    low_cpu_mem_usage=True,
)

In [None]:
# Derive the bitsandbytes settings from model_args, so the quantization that is
# actually applied matches what model_args declares. Previously the values were
# hardcoded, so use_bnb_nested_quant and the load_in_* flags were ignored.
if model_args.load_in_4bit:
    model_kwargs['quantization_config'] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
        bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
        # Compute and storage both use the model's load dtype.
        bnb_4bit_compute_dtype=model_kwargs['dtype'],
        bnb_4bit_quant_storage=model_kwargs['dtype'],
    )
elif model_args.load_in_8bit:
    model_kwargs['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)

# LoRA adapter config built from model_args (None when use_peft is False).
peft_config = get_peft_config(model_args)

In [None]:
# Load the causal LM through Liger's auto class so its optimized kernels are
# used; every loading option (revision, dtype, quantization, ...) comes from
# model_kwargs.
model = AutoLigerKernelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    **model_kwargs,
)

# Trainer

In [None]:
# Wire model, data, tokenizer and (optional) LoRA config into the SFT trainer.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    peft_config=peft_config,
)
# print_trainable_parameters only exists on PEFT-wrapped models. Skip it for a
# full fine-tune (use_peft=False, peft_config None) instead of raising
# AttributeError.
if hasattr(trainer.model, "print_trainable_parameters"):
    trainer.model.print_trainable_parameters()

## Training loop

In [None]:
# Run training; the returned object carries the run's metrics dict.
train_result = trainer.train()

# Add the dataset size to that same metrics dict (in place), then persist the
# metrics and the trainer state.
metrics = train_result.metrics
metrics.update(train_samples=len(train_dataset))
trainer.save_metrics('train', metrics)
trainer.save_state()

# Cell output: show the final metrics.
metrics

# Save model

In [None]:
# Training ran with the KV cache disabled; switch it back on so the saved
# config is set up for fast generation.
trainer.model.config.use_cache = True

# Write the trained weights and the tokenizer into the same run directory so
# they can be reloaded together.
trainer.save_model(output_dir=training_args.output_dir)
tokenizer.save_pretrained(save_directory=training_args.output_dir)

f"saved to {training_args.output_dir}"