# Step 0: Prepare Environment

### Restart the kernel after this step, then continue with Step 1.



In [None]:
%pip install  --upgrade \
  "torch>=2.4.0" \
  tensorboard \
  "transformers>=4.51.3" \
  datasets \
  "accelerate==1.4.0" \
  "evaluate==0.4.3" \
  "bitsandbytes==0.45.3" \
  "trl==0.21.0" \
  "peft==0.14.0" \
  "protobuf==3.20.3" \
  sentencepiece \
  evaluate \
  sacrebleu \
  rouge-score

# Step 1: Enter your Hugging Face access token to be able to use Gemma

In [None]:
from huggingface_hub import login

# Interactive login: prompts for a Hugging Face access token, which the Step 1
# heading notes is required to use Gemma.
login()

# Step 2: Create a combined dataset

In [None]:
from datasets import load_dataset, Dataset, concatenate_datasets
import random

# Fix random seed for reproducibility of the dataset sampling.
# Only Python's global `random` module is seeded here (it drives the shuffle in
# sample_split); torch / CUDA RNGs are not seeded by this cell.
random.seed(22)

## Step 2.1: Load datasets

In [None]:
# Load the three source datasets from the Hugging Face Hub.
# Step 2.2 samples alpaca and tulu from their "train" split and
# ultrachat_200k from its "train_sft" split.
alpaca = load_dataset("tatsu-lab/alpaca")
tulu = load_dataset("allenai/tulu-v2-sft-mixture")
ultra = load_dataset("HuggingFaceH4/ultrachat_200k")

## Step 2.2: Sample and split datasets

In [None]:
# Samples and splits datasets
def sample_split(dataset, train_size=5000, test_size=2000):
  """Randomly draw disjoint train and test subsets from ``dataset``.

  All row positions are shuffled with the global ``random`` module; the first
  ``train_size`` positions form the training subset and the next
  ``test_size`` positions form the test subset. If the dataset has fewer than
  ``train_size + test_size`` rows, the subsets are simply shorter.

  Args:
    dataset: object supporting ``len()`` and ``.select(indices)``
      (e.g. a ``datasets.Dataset``).
    train_size: number of rows for the training subset.
    test_size: number of rows for the test subset.

  Returns:
    ``(train_split, test_split)`` as returned by ``dataset.select``.
  """
  order = list(range(len(dataset)))
  random.shuffle(order)

  cut = train_size + test_size
  train_rows, test_rows = order[:train_size], order[train_size:cut]

  return dataset.select(train_rows), dataset.select(test_rows)


# Draw disjoint train/test samples (5,000 / 2,000 rows, the sample_split
# defaults) from each source. The three calls share the seeded global RNG, so
# each sample depends on this call order as well as on the seed.
alpaca_train, alpaca_test = sample_split(alpaca["train"])
tulu_train, tulu_test = sample_split(tulu["train"])
ultrachat_train, ultrachat_test = sample_split(ultra["train_sft"])

## Step 2.3: Format datasets as system, instruction, input, and response

In [None]:
# Format alpaca sample as system, instruction, input, and response
def format_alpaca(row):
  """Convert one Alpaca row into the shared single-turn ``chat`` schema.

  Alpaca rows carry ``instruction``, ``input`` and ``output`` fields and no
  system prompt, so ``system`` is always empty. ``input`` is stripped of
  surrounding whitespace; ``instruction`` and ``output`` are kept verbatim.

  Returns:
    ``{"chat": [turn]}`` where ``turn`` has keys ``system``, ``instruction``,
    ``input`` and ``response`` (in that order).
  """
  turn = dict(
      system="",
      instruction=row["instruction"],
      input=row["input"].strip(),
      response=row["output"],
  )
  return {"chat": [turn]}


# Format tulu sample as system, instruction, input, and response
def format_tulu(row):
  """Convert one Tulu conversation into the shared single-turn ``chat`` schema.

  Only the final exchange is kept: the last assistant message becomes
  ``response`` and the user message immediately preceding it becomes
  ``instruction``. The last system message (if any) becomes ``system``;
  ``input`` is always empty.

  The response is paired with the user turn it answers. A conversation that
  ends on a user message therefore does not produce a mismatched pair. A
  conversation with no user or no assistant message yields empty strings
  instead of raising ``UnboundLocalError``.

  Args:
    row: mapping with a ``messages`` list of ``{"role", "content"}`` dicts.

  Returns:
    ``{"chat": [turn]}`` with keys ``system``, ``instruction``, ``input``,
    ``response``.
  """
  system = ""
  instruction = ""
  response = ""
  pending_instruction = ""
  for message in row["messages"]:
    role = message["role"]
    if role == "system":
      system = message["content"]
    elif role == "user":
      pending_instruction = message["content"]
    elif role == "assistant":
      # Commit the (user, assistant) pair only once the answer arrives.
      instruction = pending_instruction
      response = message["content"]

  return {"chat": [{"system": system, "instruction": instruction, "input": "", "response": response}]}


# Format ultrachat sample as system, instruction, input, and response
def format_ultrachat(row):
  """Convert one UltraChat conversation into the shared single-turn ``chat`` schema.

  Only the final exchange is kept: the last assistant message becomes
  ``response`` and the user message immediately preceding it becomes
  ``instruction``. ``system`` and ``input`` are always empty. Messages with
  any other role are skipped.

  The response is paired with the user turn it answers. A conversation that
  ends on a user message therefore does not produce a mismatched pair. A
  conversation with no user or no assistant message yields empty strings
  instead of raising ``UnboundLocalError``.

  Args:
    row: mapping with a ``messages`` list of ``{"role", "content"}`` dicts.

  Returns:
    ``{"chat": [turn]}`` with keys ``system``, ``instruction``, ``input``,
    ``response``.
  """
  instruction = ""
  response = ""
  pending_instruction = ""
  for message in row["messages"]:
    role = message["role"]
    if role == "user":
      pending_instruction = message["content"]
    elif role == "assistant":
      # Commit the (user, assistant) pair only once the answer arrives.
      instruction = pending_instruction
      response = message["content"]
    # Other roles are ignored silently; printing per row inside Dataset.map
    # would only flood the cell output.

  return {"chat": [{"system": "", "instruction": instruction, "input": "", "response": response}]}


# Format and map training samples into dataset format.
# remove_columns=<all source columns> leaves only the "chat" column returned by
# the format_* functions, so the three sources share one schema and can be
# concatenated in Step 2.4.
alpaca_train_mapped = alpaca_train.map(format_alpaca, remove_columns=alpaca_train.column_names)
tulu_train_mapped= tulu_train.map(format_tulu, remove_columns=tulu_train.column_names)
ultrachat_train_mapped= ultrachat_train.map(format_ultrachat, remove_columns=ultrachat_train.column_names)

# Format and map test samples into dataset format (same single "chat" column).
alpaca_test_mapped = alpaca_test.map(format_alpaca, remove_columns=alpaca_test.column_names)
tulu_test_mapped= tulu_test.map(format_tulu, remove_columns=tulu_test.column_names)
ultrachat_test_mapped= ultrachat_test.map(format_ultrachat, remove_columns=ultrachat_test.column_names)

## Step 2.4: Combine datasets

In [None]:
# Combine sample datasets as train and test.
# Rows are stacked source by source (alpaca, then tulu, then ultrachat); no
# shuffling happens in this cell.
combined_train = concatenate_datasets([alpaca_train_mapped,  tulu_train_mapped, ultrachat_train_mapped])
combined_test = concatenate_datasets([alpaca_test_mapped,  tulu_test_mapped, ultrachat_test_mapped])

## Step 2.5: Format chats as role:user/assistant and content

In [None]:
# Format chat as role:user/assistant and content
def format_chat(row):
  """Flatten the ``chat`` turns of a row into alternating user/assistant ``messages``.

  For each turn, the user message is the instruction with an optional prefix
  and suffix:
    * prefix: the stripped ``system`` text plus a newline, if non-empty;
    * suffix: a newline plus the stripped ``input`` text, if non-empty.
  The instruction itself is not stripped. The assistant message is the
  turn's ``response`` verbatim.

  Returns:
    ``{"messages": [...]}`` with two entries per input turn.
  """
  def _user_text(turn):
    system_text = turn["system"].strip()
    extra_input = turn["input"].strip()
    prefix = f"{system_text}\n" if system_text else ""
    suffix = f"\n{extra_input}" if extra_input else ""
    return prefix + turn["instruction"] + suffix

  messages = []
  for turn in row["chat"]:
    messages.extend([
        {"role": "user", "content": _user_text(turn)},
        {"role": "assistant", "content": turn["response"]},
    ])
  return {"messages": messages}


# Format chat to a generic format and map to dataset format.
# Result: a single "messages" column of alternating user/assistant dicts. The
# train split is passed to SFTTrainer in Steps 4 and 6; the test split feeds
# the evaluations.
combined_train_formatted= combined_train.map(format_chat, remove_columns=combined_train.column_names)
combined_test_formatted= combined_test.map(format_chat, remove_columns=combined_test.column_names)

# Step 3: Evaluate base model: google/gemma-3-1b-it

In [None]:
import evaluate
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from tqdm import tqdm

import os
# Ask PyTorch's CUDA caching allocator to use expandable segments, to reduce
# fragmentation during batched generation.
# NOTE(review): this variable is presumably read when CUDA is first
# initialised in the process, so it only helps if CUDA has not been used yet
# in this kernel — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

## Step 3.1: Format dataset to gemma format

In [None]:
# Format message to gemma format
def build_gemma_chat_prompt(messages):
  """Render a conversation as a Gemma chat prompt that ends at the model's turn.

  The last ``user`` message is the query being evaluated. Everything before it
  is rendered as conversation history. Everything after it (the reference
  assistant answer) is omitted. The prompt ends with ``<start_of_turn>model\\n``
  so generation continues as the model's reply.

  Each turn is rendered as ``<start_of_turn>{role}\\n{content}<end_of_turn>\\n``
  with role ``user`` for user messages and ``model`` for every other role.
  Content is stripped, mirroring the ``| trim`` in Gemma's chat template
  (NOTE(review): confirm against ``tokenizer.chat_template``). No ``<bos>`` is
  added here; presumably the tokenizer adds it when the prompt is encoded.

  Args:
    messages: list of ``{"role": ..., "content": str}`` dicts.

  Returns:
    The prompt string.

  Raises:
    ValueError: if ``messages`` contains no user message.
  """
  # Locate the final user query; it is rendered exactly once.
  last_user_idx = None
  for idx in range(len(messages) - 1, -1, -1):
    if messages[idx]["role"] == "user":
      last_user_idx = idx
      break
  if last_user_idx is None:
    raise ValueError("No user message found")

  def _turn(role, content):
    return f"<start_of_turn>{role}\n{content.strip()}<end_of_turn>\n"

  # History is every message *before* the final query. Iterating
  # messages[:-1] and then re-appending the last user message would
  # duplicate the query for a plain [user, assistant] row.
  parts = [
      _turn("user" if m["role"] == "user" else "model", m["content"])
      for m in messages[:last_user_idx]
  ]
  parts.append(_turn("user", messages[last_user_idx]["content"]))
  parts.append("<start_of_turn>model\n")
  return "".join(parts)


# Retrieve reference
def get_reference(messages):
  """Return the content of the last assistant message in ``messages``.

  This is the gold answer each prediction is scored against.

  Raises:
    ValueError: if no message has role ``"assistant"``.
  """
  match = next((m for m in reversed(messages) if m["role"] == "assistant"), None)
  if match is None:
    raise ValueError("No assistant reference found")
  return match["content"]


# Generate prompts: one Gemma-formatted prompt per test conversation.
all_prompts = [build_gemma_chat_prompt(item["messages"]) for item in combined_test_formatted]
# Generate references: the gold assistant answer for each conversation,
# index-aligned with all_prompts.
all_refs = [get_reference(item["messages"]) for item in combined_test_formatted]

## Step 3.2: Load model and tokenizer

In [None]:
# Load model and tokenizer for the baseline (not fine-tuned) evaluation.
model_id = "google/gemma-3-1b-it"

# 4-bit NF4 weight quantization with double quantization; compute dtype is
# fixed to float16 for this baseline.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# device_map="auto" lets transformers choose device placement.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Switch to eval mode for inference (disables dropout).
model.eval()

## Step 3.3: Generate predictions

In [None]:
# Generate predictions
def fast_generate(batch, max_new_tokens=64, use_cache=False):
  """Greedily generate a continuation for every prompt in ``batch``.

  Uses the module-level ``tokenizer`` and ``model``; inputs are moved to
  ``"cuda"``.

  Args:
    batch: list of prompt strings (e.g. from ``build_gemma_chat_prompt``).
    max_new_tokens: maximum number of tokens generated per prompt.
    use_cache: forwarded to ``model.generate``. Defaults to ``False`` (the
      previous behaviour); ``True`` reuses past key/values and is usually
      much faster.

  Returns:
    List of generated strings, aligned with ``batch``. The prompt is removed,
    special tokens are skipped and surrounding whitespace is stripped.
  """
  # Decoder-only batch generation needs LEFT padding so every row ends exactly
  # where generation starts. It also needs LEFT truncation so an over-long
  # prompt keeps its trailing "<start_of_turn>model\n" marker; right
  # truncation would cut it off. The tokenizer's settings are restored
  # afterwards so other code sharing it (e.g. a trainer) is unaffected.
  saved_sides = (tokenizer.padding_side, tokenizer.truncation_side)
  tokenizer.padding_side = "left"
  tokenizer.truncation_side = "left"
  try:
    inputs = tokenizer(
        batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=2048,
    ).to("cuda")
  finally:
    tokenizer.padding_side, tokenizer.truncation_side = saved_sides

  with torch.inference_mode():
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        use_cache=use_cache,
        pad_token_id=tokenizer.eos_token_id,
    )

  # Drop the prompt by TOKEN position, not by character count. The decoded
  # text need not start with the exact prompt string: skip_special_tokens
  # removes <bos>, padding and any turn markers the tokenizer treats as
  # special, and decode(encode(x)) is not guaranteed to equal x.
  prompt_length = inputs["input_ids"].shape[1]
  new_tokens = outputs[:, prompt_length:]
  texts = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)

  return [text.strip() for text in texts]


def evaluate_model(prompts, max_new_tokens=64, batch_size=8):
  """Generate one prediction per prompt, in mini-batches.

  If a batch runs out of GPU memory, the cache is cleared and that batch is
  retried one prompt at a time.

  Args:
    prompts: list of prompt strings.
    max_new_tokens: forwarded to ``fast_generate``.
    batch_size: prompts per ``fast_generate`` call (default 8, as before).

  Returns:
    List of predictions in the same order as ``prompts``.

  Raises:
    RuntimeError: any non-OOM error from generation, or an OOM that also
      occurs when a single prompt is retried.
  """
  all_preds = []

  for start in tqdm(range(0, len(prompts), batch_size), desc="Evaluating"):
    batch = prompts[start:start + batch_size]

    hit_oom = False
    try:
      preds = fast_generate(batch, max_new_tokens=max_new_tokens)
    except RuntimeError as e:
      if "out of memory" not in str(e).lower():
        raise
      hit_oom = True

    if hit_oom:
      # Recover outside the ``except`` block. Inside it, the exception's
      # traceback still references the failed call's frames and their CUDA
      # tensors, so empty_cache() could not release that memory.
      print("OOM — retrying batch with size=1")
      torch.cuda.empty_cache()
      preds = []
      for prompt in batch:
        preds.extend(fast_generate([prompt], max_new_tokens=max_new_tokens))

    all_preds.extend(preds)

  return all_preds


# Evaluate model: one greedy prediction per prompt, index-aligned with all_refs.
all_preds = evaluate_model(all_prompts)

## Step 3.4: Evaluate

In [None]:
# Compute BLEU-4: corpus-level sacreBLEU (default max n-gram order 4), one
# reference per prediction. The reported "score" is on a 0-100 scale.
bleu = evaluate.load("sacrebleu")
bleu4 = bleu.compute(predictions=all_preds, references=all_refs)
print("BLEU-4:", bleu4["score"])

# Compute ROUGE-L. NOTE(review): with evaluate's defaults this is presumably an
# aggregated F-measure in [0, 1], a different scale from BLEU above — confirm.
rouge = evaluate.load("rouge")
rouge_result = rouge.compute(predictions=all_preds, references=all_refs)
print("ROUGE-L:", rouge_result["rougeL"])

# Step 4: Finetune - QLoRA

In [None]:
import time
import torch
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

## Step 4.1: Load model and tokenizer

In [None]:
# Hugging Face model id
model_id = "google/gemma-3-1b-it"

# Model class: the causal-LM class is always used; no id-based selection is
# actually performed.

model_class = AutoModelForCausalLM

# Use bfloat16 on GPUs with compute capability >= 8 (Ampere or newer),
# otherwise fall back to float16.
if torch.cuda.get_device_capability()[0] >= 8:
  torch_dtype = torch.bfloat16
else:
  torch_dtype = torch.float16

# Define model init arguments
model_kwargs = dict(
  attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
  torch_dtype=torch_dtype, # dtype passed to from_pretrained (bf16 or fp16, chosen above)
  device_map="auto", # Let transformers choose device placement
)

# BitsAndBytesConfig: 4-bit NF4 weights with double quantization; this is what
# makes this run QLoRA. Compute and storage dtypes follow torch_dtype.
model_kwargs["quantization_config"] = BitsAndBytesConfig(
  load_in_4bit=True,
  bnb_4bit_use_double_quant=True,
  bnb_4bit_quant_type='nf4',
  bnb_4bit_compute_dtype=model_kwargs['torch_dtype'],
  bnb_4bit_quant_storage=model_kwargs['torch_dtype'],
)

# Load model and tokenizer
model = model_class.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it") # Load the Instruction Tokenizer to use the official Gemma template

## Step 4.2: Prepare for finetuning

In [None]:
# Set LoraConfig: rank-16 adapters on all linear layers.
# NOTE(review): this notebook adds no new special tokens (the -it tokenizer is
# used as-is), so fully training lm_head/embed_tokens via modules_to_save may
# be unnecessary. It adds many trainable parameters and differs from a
# typical QLoRA setup — confirm this is intended.
peft_config = LoraConfig(
  lora_alpha=16,
  lora_dropout=0.05,
  r=16,
  bias="none",
  target_modules="all-linear",
  task_type="CAUSAL_LM",
  modules_to_save=["lm_head", "embed_tokens"] # fully train and save these modules alongside the adapters
)

# Set SFTConfig
# NOTE(review): transformers' "constant" scheduler ignores warmup, so
# warmup_ratio below presumably has no effect; use "constant_with_warmup" if
# warmup is intended — confirm.
args = SFTConfig(
  output_dir="./qlora-gemma",         # checkpoint/adapter directory; Step 5.2 loads from here
  max_length=2048,                         # max sequence length for model and packing of the dataset
  packing=True,                           # Groups multiple samples in the dataset into a single sequence
  num_train_epochs=2,                     # number of training epochs
  per_device_train_batch_size=1,          # batch size per device during training
  gradient_accumulation_steps=4,          # micro-batches accumulated per optimizer step (effective batch 1 x 4)
  gradient_checkpointing=True,            # use gradient checkpointing to save memory
  optim="adamw_torch_fused",              # use fused adamw optimizer
  logging_steps=10,                       # log every 10 steps
  save_strategy="epoch",                  # save checkpoint every epoch
  learning_rate=2e-4,                     # learning rate, based on QLoRA paper
  fp16=True if torch_dtype == torch.float16 else False,   # fp16 mixed precision on pre-Ampere GPUs
  bf16=True if torch_dtype == torch.bfloat16 else False,   # bf16 mixed precision on Ampere or newer
  max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
  warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper (see NOTE above)
  lr_scheduler_type="constant",           # use constant learning rate scheduler
  push_to_hub=False,                       # do not push the model to the Hub
  report_to="tensorboard",                # report metrics to tensorboard
  dataset_kwargs={
      "add_special_tokens": False, # We template with special tokens
      "append_concat_token": True, # Add EOS token as separator token between examples
  }
)

# Create Trainer object. Passing peft_config makes SFTTrainer wrap `model`
# with the LoRA adapters described above.
trainer = SFTTrainer(
  model=model,
  args=args,
  train_dataset=combined_train_formatted,
  peft_config=peft_config,
  processing_class=tokenizer
)

## Step 4.3: Finetune

In [None]:
# Reset the peak-memory counter so the value below reflects only this run.
torch.cuda.reset_peak_memory_stats()
start_time = time.time()

# Start training
trainer.train()

end_time = time.time()

# Compute peak memory: the highest memory allocated by PyTorch tensors on the
# current CUDA device since the reset above (not the allocator's reserved pool).
peak_memory = torch.cuda.max_memory_allocated() / 1024**3  # in GiB (1024**3 bytes)
# Compute training time: wall clock around trainer.train() only
training_time = end_time - start_time  # in seconds

print(f"Peak GPU memory: {peak_memory:.2f} GB")
print(f"Total training time: {training_time:.2f} seconds")

## Step 4.4: Save model

In [None]:
# Save model: with no argument, save_model() writes to args.output_dir
# ("./qlora-gemma"), the path Step 5.2 loads the adapter from.
trainer.save_model()

# Step 5: Evaluate QLoRA

In [None]:
import evaluate
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from tqdm import tqdm

import os
# Ask PyTorch's CUDA caching allocator to use expandable segments.
# NOTE(review): presumably read only when CUDA is first initialised; if training
# already ran in this kernel, this assignment likely has no effect — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

## Step 5.1: Format dataset to gemma format

In [None]:
# Format message to gemma format
def build_gemma_chat_prompt(messages):
  """Render a conversation as a Gemma chat prompt that ends at the model's turn.

  The last ``user`` message is the query being evaluated. Everything before it
  is rendered as conversation history. Everything after it (the reference
  assistant answer) is omitted. The prompt ends with ``<start_of_turn>model\\n``
  so generation continues as the model's reply.

  Each turn is rendered as ``<start_of_turn>{role}\\n{content}<end_of_turn>\\n``
  with role ``user`` for user messages and ``model`` for every other role.
  Content is stripped, mirroring the ``| trim`` in Gemma's chat template
  (NOTE(review): confirm against ``tokenizer.chat_template``). No ``<bos>`` is
  added here; presumably the tokenizer adds it when the prompt is encoded.

  Args:
    messages: list of ``{"role": ..., "content": str}`` dicts.

  Returns:
    The prompt string.

  Raises:
    ValueError: if ``messages`` contains no user message.
  """
  # Locate the final user query; it is rendered exactly once.
  last_user_idx = None
  for idx in range(len(messages) - 1, -1, -1):
    if messages[idx]["role"] == "user":
      last_user_idx = idx
      break
  if last_user_idx is None:
    raise ValueError("No user message found")

  def _turn(role, content):
    return f"<start_of_turn>{role}\n{content.strip()}<end_of_turn>\n"

  # History is every message *before* the final query. Iterating
  # messages[:-1] and then re-appending the last user message would
  # duplicate the query for a plain [user, assistant] row.
  parts = [
      _turn("user" if m["role"] == "user" else "model", m["content"])
      for m in messages[:last_user_idx]
  ]
  parts.append(_turn("user", messages[last_user_idx]["content"]))
  parts.append("<start_of_turn>model\n")
  return "".join(parts)

# Retrieve reference
def get_reference(messages):
  """Return the content of the last assistant message in ``messages``.

  This is the gold answer each prediction is scored against.

  Raises:
    ValueError: if no message has role ``"assistant"``.
  """
  match = next((m for m in reversed(messages) if m["role"] == "assistant"), None)
  if match is None:
    raise ValueError("No assistant reference found")
  return match["content"]


# Generate prompts: one Gemma-formatted prompt per test conversation.
all_prompts = [build_gemma_chat_prompt(item["messages"]) for item in combined_test_formatted]
# Generate references: the gold assistant answer for each conversation,
# index-aligned with all_prompts.
all_refs = [get_reference(item["messages"]) for item in combined_test_formatted]

## Step 5.2: Load model

In [None]:
# Load fine-tuned model.
# In Step 4.2, SFTTrainer wrapped `model` with the LoRA config. That injects
# the LoRA layers into this same base-model object in place, so passing it to
# PeftModel.from_pretrained would adapt already-adapted modules. Instead,
# reload a clean base model with the exact arguments used for training
# (model_id / model_kwargs from Step 4.1) and attach the adapter saved in
# Step 4.4.
base_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
model = PeftModel.from_pretrained(base_model, "./qlora-gemma")
model.eval()

## Step 5.3: Generate predictions

In [None]:
# Generate predictions
def fast_generate(batch, max_new_tokens=64, use_cache=False):
  """Greedily generate a continuation for every prompt in ``batch``.

  Uses the module-level ``tokenizer`` and ``model``; inputs are moved to
  ``"cuda"``.

  Args:
    batch: list of prompt strings (e.g. from ``build_gemma_chat_prompt``).
    max_new_tokens: maximum number of tokens generated per prompt.
    use_cache: forwarded to ``model.generate``. Defaults to ``False`` (the
      previous behaviour); ``True`` reuses past key/values and is usually
      much faster.

  Returns:
    List of generated strings, aligned with ``batch``. The prompt is removed,
    special tokens are skipped and surrounding whitespace is stripped.
  """
  # Decoder-only batch generation needs LEFT padding so every row ends exactly
  # where generation starts. It also needs LEFT truncation so an over-long
  # prompt keeps its trailing "<start_of_turn>model\n" marker; right
  # truncation would cut it off. The tokenizer's settings are restored
  # afterwards so other code sharing it (e.g. the trainer) is unaffected.
  saved_sides = (tokenizer.padding_side, tokenizer.truncation_side)
  tokenizer.padding_side = "left"
  tokenizer.truncation_side = "left"
  try:
    inputs = tokenizer(
        batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=2048,
    ).to("cuda")
  finally:
    tokenizer.padding_side, tokenizer.truncation_side = saved_sides

  with torch.inference_mode():
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        use_cache=use_cache,
        pad_token_id=tokenizer.eos_token_id,
    )

  # Drop the prompt by TOKEN position, not by character count. The decoded
  # text need not start with the exact prompt string: skip_special_tokens
  # removes <bos>, padding and any turn markers the tokenizer treats as
  # special, and decode(encode(x)) is not guaranteed to equal x.
  prompt_length = inputs["input_ids"].shape[1]
  new_tokens = outputs[:, prompt_length:]
  texts = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)

  return [text.strip() for text in texts]


def evaluate_model(prompts, max_new_tokens=64, batch_size=8):
  """Generate one prediction per prompt, in mini-batches.

  If a batch runs out of GPU memory, the cache is cleared and that batch is
  retried one prompt at a time.

  Args:
    prompts: list of prompt strings.
    max_new_tokens: forwarded to ``fast_generate``.
    batch_size: prompts per ``fast_generate`` call (default 8, as before).

  Returns:
    List of predictions in the same order as ``prompts``.

  Raises:
    RuntimeError: any non-OOM error from generation, or an OOM that also
      occurs when a single prompt is retried.
  """
  all_preds = []

  for start in tqdm(range(0, len(prompts), batch_size), desc="Evaluating"):
    batch = prompts[start:start + batch_size]

    hit_oom = False
    try:
      preds = fast_generate(batch, max_new_tokens=max_new_tokens)
    except RuntimeError as e:
      if "out of memory" not in str(e).lower():
        raise
      hit_oom = True

    if hit_oom:
      # Recover outside the ``except`` block. Inside it, the exception's
      # traceback still references the failed call's frames and their CUDA
      # tensors, so empty_cache() could not release that memory.
      print("OOM — retrying batch with size=1")
      torch.cuda.empty_cache()
      preds = []
      for prompt in batch:
        preds.extend(fast_generate([prompt], max_new_tokens=max_new_tokens))

    all_preds.extend(preds)

  return all_preds


# Evaluate model: one greedy prediction per prompt, index-aligned with all_refs.
all_preds = evaluate_model(all_prompts)

## Step 5.4: Evaluate

In [None]:
# Compute BLEU-4: corpus-level sacreBLEU (default max n-gram order 4), one
# reference per prediction. The reported "score" is on a 0-100 scale.
bleu = evaluate.load("sacrebleu")
bleu4 = bleu.compute(predictions=all_preds, references=all_refs)
print("BLEU-4:", bleu4["score"])

# Compute ROUGE-L. NOTE(review): with evaluate's defaults this is presumably an
# aggregated F-measure in [0, 1], a different scale from BLEU above — confirm.
rouge = evaluate.load("rouge")
rouge_result = rouge.compute(predictions=all_preds, references=all_refs)
print("ROUGE-L:", rouge_result["rougeL"])

# Step 6: Finetune - LoRA

In [None]:
import time
import torch
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

## Step 6.1: Load model and tokenizer

In [None]:
# Hugging Face model id
model_id = "google/gemma-3-1b-it" # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`

# The causal-LM class is always used; no id-based selection is performed.
model_class = AutoModelForCausalLM

# Use bfloat16 on GPUs with compute capability >= 8 (Ampere or newer),
# otherwise fall back to float16.
if torch.cuda.get_device_capability()[0] >= 8:
  torch_dtype = torch.bfloat16
else:
  torch_dtype = torch.float16

# Define model init arguments.
# Plain LoRA: the base weights are loaded un-quantized in torch_dtype. Do NOT
# add a BitsAndBytesConfig here. 4-bit loading would turn this into a second
# QLoRA run and make the QLoRA-vs-LoRA comparison (peak memory, training time,
# BLEU/ROUGE) meaningless.
# NOTE(review): on the float16 path, trainable weights left in fp16 (e.g. the
# modules_to_save copies of lm_head/embed_tokens) may make the fp16 grad
# scaler fail with "Attempting to unscale FP16 gradients". Loading in float32
# may be needed there — confirm on a pre-Ampere GPU.
model_kwargs = dict(
  attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
  torch_dtype=torch_dtype, # dtype of the un-quantized base weights
  device_map="auto", # Let transformers choose device placement
)

# Load model and tokenizer
model = model_class.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it") # Load the Instruction Tokenizer to use the official Gemma template

## Step 6.2: Prepare for finetuning

In [None]:
# Set LoraConfig: rank-16 adapters on the attention q_proj / v_proj layers only.
# The QLoRA run (Step 4.2) targets "all-linear", so the two runs also differ
# in adapter placement, not only in base-weight precision.
# NOTE(review): no new special tokens are added in this notebook, so fully
# training lm_head/embed_tokens via modules_to_save may be unnecessary —
# confirm this is intended.
peft_config = LoraConfig(
  r=16,
  lora_alpha=16,
  lora_dropout=0.05,
  target_modules=["q_proj", "v_proj"],
  bias="none",
  task_type="CAUSAL_LM",
  modules_to_save=["lm_head", "embed_tokens"]
)

# Set SFTConfig (same hyper-parameters as the QLoRA run, different output dir).
# NOTE(review): transformers' "constant" scheduler ignores warmup, so
# warmup_ratio below presumably has no effect; use "constant_with_warmup" if
# warmup is intended — confirm.
args = SFTConfig(
  output_dir="./lora-gemma",         # checkpoint/adapter directory; Step 7.2 loads from here
  max_length=2048,                         # max sequence length for model and packing of the dataset
  packing=True,                           # Groups multiple samples in the dataset into a single sequence
  num_train_epochs=2,                     # number of training epochs
  per_device_train_batch_size=1,          # batch size per device during training
  gradient_accumulation_steps=4,          # micro-batches accumulated per optimizer step (effective batch 1 x 4)
  gradient_checkpointing=True,            # use gradient checkpointing to save memory
  optim="adamw_torch_fused",              # use fused adamw optimizer
  logging_steps=10,                       # log every 10 steps
  save_strategy="epoch",                  # save checkpoint every epoch
  learning_rate=2e-4,                     # learning rate, based on QLoRA paper
  fp16=True if torch_dtype == torch.float16 else False,   # fp16 mixed precision on pre-Ampere GPUs
  bf16=True if torch_dtype == torch.bfloat16 else False,   # bf16 mixed precision on Ampere or newer
  max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
  warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper (see NOTE above)
  lr_scheduler_type="constant",           # use constant learning rate scheduler
  push_to_hub=False,                       # do not push the model to the Hub
  report_to="tensorboard",                # report metrics to tensorboard
  dataset_kwargs={
      "add_special_tokens": False, # We template with special tokens
      "append_concat_token": True, # Add EOS token as separator token between examples
  }
)

# Create Trainer object. Passing peft_config makes SFTTrainer wrap `model`
# with the LoRA adapters described above.
trainer = SFTTrainer(
  model=model,
  args=args,
  train_dataset=combined_train_formatted,
  peft_config=peft_config,
  processing_class=tokenizer
)

## Step 6.3: Finetune

In [None]:
# Reset the peak-memory counter so the value below reflects only this run.
torch.cuda.reset_peak_memory_stats()
start_time = time.time()

# Start training
trainer.train()

end_time = time.time()

# Compute peak memory: the highest memory allocated by PyTorch tensors on the
# current CUDA device since the reset above (not the allocator's reserved pool).
peak_memory = torch.cuda.max_memory_allocated() / 1024**3  # in GiB (1024**3 bytes)
# Compute training time: wall clock around trainer.train() only
training_time = end_time - start_time  # in seconds

print(f"Peak GPU memory: {peak_memory:.2f} GB")
print(f"Total training time: {training_time:.2f} seconds")

## Step 6.4: Save model

In [None]:
# Save model: with no argument, save_model() writes to args.output_dir
# ("./lora-gemma"), the path Step 7.2 loads the adapter from.
trainer.save_model()

# Step 7: Evaluate LoRA

In [None]:
import evaluate
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from tqdm import tqdm

import os
# Ask PyTorch's CUDA caching allocator to use expandable segments.
# NOTE(review): presumably read only when CUDA is first initialised; if training
# already ran in this kernel, this assignment likely has no effect — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

## Step 7.1: Format dataset to gemma format

In [None]:
# Format message to gemma format
def build_gemma_chat_prompt(messages):
  """Render a conversation as a Gemma chat prompt that ends at the model's turn.

  The last ``user`` message is the query being evaluated. Everything before it
  is rendered as conversation history. Everything after it (the reference
  assistant answer) is omitted. The prompt ends with ``<start_of_turn>model\\n``
  so generation continues as the model's reply.

  Each turn is rendered as ``<start_of_turn>{role}\\n{content}<end_of_turn>\\n``
  with role ``user`` for user messages and ``model`` for every other role.
  Content is stripped, mirroring the ``| trim`` in Gemma's chat template
  (NOTE(review): confirm against ``tokenizer.chat_template``). No ``<bos>`` is
  added here; presumably the tokenizer adds it when the prompt is encoded.

  Args:
    messages: list of ``{"role": ..., "content": str}`` dicts.

  Returns:
    The prompt string.

  Raises:
    ValueError: if ``messages`` contains no user message.
  """
  # Locate the final user query; it is rendered exactly once.
  last_user_idx = None
  for idx in range(len(messages) - 1, -1, -1):
    if messages[idx]["role"] == "user":
      last_user_idx = idx
      break
  if last_user_idx is None:
    raise ValueError("No user message found")

  def _turn(role, content):
    return f"<start_of_turn>{role}\n{content.strip()}<end_of_turn>\n"

  # History is every message *before* the final query. Iterating
  # messages[:-1] and then re-appending the last user message would
  # duplicate the query for a plain [user, assistant] row.
  parts = [
      _turn("user" if m["role"] == "user" else "model", m["content"])
      for m in messages[:last_user_idx]
  ]
  parts.append(_turn("user", messages[last_user_idx]["content"]))
  parts.append("<start_of_turn>model\n")
  return "".join(parts)

# Retrieve reference
def get_reference(messages):
  """Return the content of the last assistant message in ``messages``.

  This is the gold answer each prediction is scored against.

  Raises:
    ValueError: if no message has role ``"assistant"``.
  """
  match = next((m for m in reversed(messages) if m["role"] == "assistant"), None)
  if match is None:
    raise ValueError("No assistant reference found")
  return match["content"]

# Generate prompts: one Gemma-formatted prompt per test conversation.
all_prompts = [build_gemma_chat_prompt(item["messages"]) for item in combined_test_formatted]
# Generate references: the gold assistant answer for each conversation,
# index-aligned with all_prompts.
all_refs = [get_reference(item["messages"]) for item in combined_test_formatted]

## Step 7.2: Load model

In [None]:
# Load fine-tuned model.
# In Step 6.2, SFTTrainer wrapped `model` with the LoRA config. That injects
# the LoRA layers into this same base-model object in place, so passing it to
# PeftModel.from_pretrained would adapt already-adapted modules. Instead,
# reload a clean base model with the exact arguments used for the LoRA
# training run (model_id / model_kwargs from Step 6.1) and attach the adapter
# saved in Step 6.4.
base_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
model = PeftModel.from_pretrained(base_model, "./lora-gemma")
model.eval()

## Step 7.3: Generate predictions

In [None]:
# Generate predictions
def fast_generate(batch, max_new_tokens=64, use_cache=False):
  """Greedily generate a continuation for every prompt in ``batch``.

  Uses the module-level ``tokenizer`` and ``model``; inputs are moved to
  ``"cuda"``.

  Args:
    batch: list of prompt strings (e.g. from ``build_gemma_chat_prompt``).
    max_new_tokens: maximum number of tokens generated per prompt.
    use_cache: forwarded to ``model.generate``. Defaults to ``False`` (the
      previous behaviour); ``True`` reuses past key/values and is usually
      much faster.

  Returns:
    List of generated strings, aligned with ``batch``. The prompt is removed,
    special tokens are skipped and surrounding whitespace is stripped.
  """
  # Decoder-only batch generation needs LEFT padding so every row ends exactly
  # where generation starts. It also needs LEFT truncation so an over-long
  # prompt keeps its trailing "<start_of_turn>model\n" marker; right
  # truncation would cut it off. The tokenizer's settings are restored
  # afterwards so other code sharing it (e.g. the trainer) is unaffected.
  saved_sides = (tokenizer.padding_side, tokenizer.truncation_side)
  tokenizer.padding_side = "left"
  tokenizer.truncation_side = "left"
  try:
    inputs = tokenizer(
        batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=2048,
    ).to("cuda")
  finally:
    tokenizer.padding_side, tokenizer.truncation_side = saved_sides

  with torch.inference_mode():
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        use_cache=use_cache,
        pad_token_id=tokenizer.eos_token_id,
    )

  # Drop the prompt by TOKEN position, not by character count. The decoded
  # text need not start with the exact prompt string: skip_special_tokens
  # removes <bos>, padding and any turn markers the tokenizer treats as
  # special, and decode(encode(x)) is not guaranteed to equal x.
  prompt_length = inputs["input_ids"].shape[1]
  new_tokens = outputs[:, prompt_length:]
  texts = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)

  return [text.strip() for text in texts]


def evaluate_model(prompts, max_new_tokens=64, batch_size=8):
  """Generate one prediction per prompt, in mini-batches.

  If a batch runs out of GPU memory, the cache is cleared and that batch is
  retried one prompt at a time.

  Args:
    prompts: list of prompt strings.
    max_new_tokens: forwarded to ``fast_generate``.
    batch_size: prompts per ``fast_generate`` call (default 8, as before).

  Returns:
    List of predictions in the same order as ``prompts``.

  Raises:
    RuntimeError: any non-OOM error from generation, or an OOM that also
      occurs when a single prompt is retried.
  """
  all_preds = []

  for start in tqdm(range(0, len(prompts), batch_size), desc="Evaluating"):
    batch = prompts[start:start + batch_size]

    hit_oom = False
    try:
      preds = fast_generate(batch, max_new_tokens=max_new_tokens)
    except RuntimeError as e:
      if "out of memory" not in str(e).lower():
        raise
      hit_oom = True

    if hit_oom:
      # Recover outside the ``except`` block. Inside it, the exception's
      # traceback still references the failed call's frames and their CUDA
      # tensors, so empty_cache() could not release that memory.
      print("OOM — retrying batch with size=1")
      torch.cuda.empty_cache()
      preds = []
      for prompt in batch:
        preds.extend(fast_generate([prompt], max_new_tokens=max_new_tokens))

    all_preds.extend(preds)

  return all_preds


# Evaluate model: one greedy prediction per prompt, index-aligned with all_refs.
all_preds = evaluate_model(all_prompts)

## Step 7.4: Evaluate

In [None]:
# Compute BLEU-4: corpus-level sacreBLEU (default max n-gram order 4), one
# reference per prediction. The reported "score" is on a 0-100 scale.
bleu = evaluate.load("sacrebleu")
bleu4 = bleu.compute(predictions=all_preds, references=all_refs)
print("BLEU-4:", bleu4["score"])

# Compute ROUGE-L. NOTE(review): with evaluate's defaults this is presumably an
# aggregated F-measure in [0, 1], a different scale from BLEU above — confirm.
rouge = evaluate.load("rouge")
rouge_result = rouge.compute(predictions=all_preds, references=all_refs)
print("ROUGE-L:", rouge_result["rougeL"])