# LLM Fine-tuning (Vertex AI Edition)

We will fine-tune a Large Language Model (LLM) with Hugging Face's `transformers`, `datasets`, and `peft` libraries, then track and evaluate it with Google Cloud's Vertex AI platform (Vertex AI Experiments and a Gemini judge).

**Goal**: Fine-tune a `Gemma 3` model to speak like a Pirate!

**Tech Stack**:
*   **Model**: [Google Gemma 3 1B](https://huggingface.co/google/gemma-3-1b-it)
*   **Dataset**: [Pirate UltraChat](https://huggingface.co/datasets/winglian/pirate-ultrachat-10k)
*   **Logging & Eval**: Google Cloud Vertex AI (Experiments & Gemini Judge)


In [None]:
# Install dependencies
# accelerate is required for device_map="auto" when loading the model
!pip install transformers datasets peft accelerate google-cloud-aiplatform matplotlib python-dotenv --quiet
!pip install torch --quiet

In [None]:
import os
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

import json
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from peft import LoraConfig, get_peft_model

# Vertex AI
import vertexai
from vertexai.generative_models import GenerativeModel, SafetySetting

# Initialize Vertex AI with the project ID read from .env (GOOGLE_CLOUD_PROJECT)
GCP_PROJECT_ID = os.getenv('GOOGLE_CLOUD_PROJECT')
vertexai.init(project=GCP_PROJECT_ID, location="us-east4")

In [None]:
def setup_device():
    """
    Checks for available hardware (CUDA, MPS, or CPU), 
    prints the status, and returns the torch device.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"✅ Using CUDA Device: {torch.cuda.get_device_name(0)}  |  {torch.cuda.memory_allocated(0) / 1024**3:.2f}GB / {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f}GB Total")
        
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print("🍎 Using Apple Metal Performance Shaders (MPS)")
    
    else:
        device = torch.device("cpu")
        print("💻 Using CPU")

    return device

device = setup_device()

Make sure the selected device is a GPU (CUDA or MPS), not the CPU: fine-tuning even a 1B-parameter model on CPU is impractically slow.

## 1. Data Preparation

We will use the `winglian/pirate-ultrachat-10k` dataset from Hugging Face. This dataset contains conversations re-written in a pirate style.

In [None]:
# Load dataset from Hugging Face
dataset_id = "winglian/pirate-ultrachat-10k"
dataset = load_dataset(dataset_id, split="train[:2000]") # Use a subset for speed

print(f"Loaded {len(dataset)} samples.")
print("Sample entry:", dataset[0])
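The training cell turns each conversation's `messages` list into a single string with `tokenizer.apply_chat_template`. To make that step concrete, the sketch below approximates what Gemma's template produces, assuming each `messages` entry is a list of `{"role", "content"}` dicts as the training code expects: turns are wrapped in `<start_of_turn>`/`<end_of_turn>`, the assistant role is renamed `model`, and a leading system message is folded into the first user turn. `render_gemma_chat` is a hypothetical helper for illustration only; the authoritative template ships with the tokenizer and may differ in details.

```python
def render_gemma_chat(messages, add_generation_prompt=False):
    """Hypothetical approximation of Gemma's chat template (illustration only)."""
    messages = list(messages)
    system_prefix = ""
    # Gemma has no system role: fold a leading system message into the first user turn.
    if messages and messages[0]["role"] == "system":
        system_prefix = messages[0]["content"] + "\n\n"
        messages = messages[1:]
    parts = ["<bos>"]
    for i, msg in enumerate(messages):
        # Gemma calls the assistant role "model".
        role = "model" if msg["role"] == "assistant" else msg["role"]
        prefix = system_prefix if i == 0 else ""
        parts.append(f"<start_of_turn>{role}\n{prefix}{msg['content'].strip()}<end_of_turn>\n")
    if add_generation_prompt:
        # Open a model turn so generation continues as the assistant.
        parts.append("<start_of_turn>model\n")
    return "".join(parts)


chat = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Ahoy!"},
]
print(render_gemma_chat(chat))
```

Training uses `add_generation_prompt=False` because the full conversation, including the pirate reply, is the target text; generation uses `True` so the model writes the next `model` turn.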

## 2. Model Setup (Gemma 3)

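We use `google/gemma-3-1b-it`, the instruction-tuned 1B-parameter model from Google's Gemma 3 open model family. It is small enough to fine-tune with LoRA on a single GPU. Gemma checkpoints are gated on Hugging Face: accept the license on the model page and authenticate (for example, by setting `HF_TOKEN` in your `.env` file) before downloading.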

In [None]:
model_id = "google/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
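# Gemma's tokenizer already defines a dedicated <pad> token; only fall back to EOS if it is missing
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token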

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
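As a rough check of why a 1B-parameter model fits on a single GPU, the hypothetical helper below estimates a lower bound on training memory. It assumes the frozen base weights are stored in bfloat16 (2 bytes per parameter) and that each trainable LoRA parameter costs 16 bytes (fp32 weight, fp32 gradient, and two fp32 AdamW moment buffers). It ignores activations, the KV cache, and framework overhead, and the parameter counts in the example are round illustrative numbers, not Gemma 3's exact size.

```python
def estimate_training_memory_gib(base_params, trainable_params, base_bytes=2, trainable_bytes=16):
    """Hypothetical lower-bound estimate of LoRA fine-tuning memory in GiB.

    base_bytes=2: frozen weights in bfloat16 (assumption).
    trainable_bytes=16: fp32 weight + grad + two AdamW moments (assumption).
    Activations and framework overhead are NOT included.
    """
    total_bytes = base_params * base_bytes + trainable_params * trainable_bytes
    return total_bytes / 1024**3


# ~1B frozen parameters plus ~1M trainable adapter parameters (illustrative numbers)
print(f"{estimate_training_memory_gib(1_000_000_000, 1_000_000):.2f} GiB")  # → 1.88 GiB
```

The frozen weights dominate; the optimizer state for the adapters is tiny by comparison, which is the main reason LoRA is cheap to train.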

In [None]:
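text = tokenizer.apply_chat_template([{"role": "user", "content": "What's the capital of Italy?"}], tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(model.device)  # template already adds <bos>
output = model.generate(**inputs, max_new_tokens=64)  # baseline answer before fine-tuning
print(tokenizer.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))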

## 3. LoRA Configuration

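We use Low-Rank Adaptation (LoRA) to fine-tune efficiently: the base weights stay frozen and only small low-rank adapter matrices are trained. We target the attention projection layers (`q_proj`, `k_proj`, `v_proj`, `o_proj`) explicitly instead of relying on `peft`'s per-architecture defaults.

In [None]:
def apply_lora(model):
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        # Attention projection layers in Gemma 3's decoder blocks
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM",
        lora_dropout=0.05
    )
    return get_peft_model(model, lora_config)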

model = apply_lora(model)
model.print_trainable_parameters()
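To illustrate what the adapter computes, here is a minimal NumPy sketch of a single LoRA-adapted linear layer (toy sizes, not Gemma's). The frozen weight `W` is augmented by a low-rank product `B @ A` scaled by `lora_alpha / r`, which is 2 for the `r=8`, `lora_alpha=16` config above. `B` starts at zero, so the adapted layer initially behaves exactly like the base layer; `lora_forward` and `lora_param_counts` are hypothetical helpers for illustration.

```python
import numpy as np


def lora_forward(x, W, A, B, alpha, r):
    """y = W x + (alpha / r) * B (A x): frozen W plus a scaled low-rank update."""
    return W @ x + (alpha / r) * (B @ (A @ x))


def lora_param_counts(d_in, d_out, r):
    full = d_in * d_out          # parameters in a full weight update
    lora = r * (d_in + d_out)    # parameters in A (r x d_in) and B (d_out x r)
    return full, lora


rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 16, 12, 8, 16
W = rng.standard_normal((d_out, d_in))        # frozen base weight
A = rng.standard_normal((r, d_in)) * 0.01     # trainable down-projection
B = np.zeros((d_out, r))                      # trainable up-projection, zero at init
x = rng.standard_normal(d_in)

y = lora_forward(x, W, A, B, alpha, r)
print(np.allclose(y, W @ x))  # True: the adapter is a no-op before training
print(lora_param_counts(1024, 1024, 8))  # (1048576, 16384)
```

For a 1024×1024 projection, rank 8 trains 16,384 parameters instead of 1,048,576 (about 1.6%), which is why `print_trainable_parameters()` reports only a small fraction of the model as trainable.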

## 4. Training Loop with Vertex AI Logging

Instead of a third-party tracker such as Comet, we log training metrics to Vertex AI Experiments.
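The training loop in this section feeds micro-batches of 4 examples and steps the optimizer only every `accumulation_steps` of them, dividing each micro-batch loss by `accumulation_steps` first. The NumPy sketch below uses a toy least-squares problem (not the notebook's model) to check why that scaling is right: summing the scaled micro-batch gradients reproduces the gradient of the mean loss over the whole effective batch.

```python
import numpy as np


def mse_grad(w, X, y):
    """Gradient of 0.5 * mean((X w - y)^2) with respect to w."""
    residual = X @ w - y
    return X.T @ residual / len(y)


rng = np.random.default_rng(0)
X = rng.standard_normal((32, 5))
y = rng.standard_normal(32)
w = rng.standard_normal(5)

accumulation_steps = 8
micro_batches = np.split(np.arange(32), accumulation_steps)  # 8 micro-batches of 4 rows

accumulated = np.zeros_like(w)
for idx in micro_batches:
    # Same idea as `loss / accumulation_steps` before `loss.backward()`
    accumulated += mse_grad(w, X[idx], y[idx]) / accumulation_steps

full_batch = mse_grad(w, X, y)
print(np.allclose(accumulated, full_batch))  # True
```

The equality is exact here because every micro-batch has the same size. With token-level language-model losses it is only approximate, since each micro-batch averages over a different number of non-padding tokens.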

In [None]:
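from google.cloud import aiplatform  # Vertex AI Experiments: start_run / log_metrics / end_run

def generate_response(model, tokenizer, prompt, max_new_tokens=128):
    """Generate one chat reply to `prompt` and return only the newly generated text."""
    text = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
    # The chat template already starts with <bos>, so don't let the tokenizer add a second one
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(model.device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

def train(model, dataset, tokenizer, max_steps=600, accumulation_steps=8, learning_rate=2e-4, eval_samples=5):
    # Attach a Vertex AI Experiment (project/location come from the earlier vertexai.init call)
    aiplatform.init(experiment="pirate-finetune-lab")
    # Run names must be unique within the experiment: rename the run (or pass resume=True) when re-running
    aiplatform.start_run("run-gemma3-pirate-advanced")
    
    # Optimizer over the trainable (LoRA) parameters only
    optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=learning_rate)
    
    # Scheduler: halve the LR when the training loss plateaus
    # Check every 5 *effective* (optimizer) updates
    eval_interval = accumulation_steps * 5
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)
    
    model.train()
    
    # Format each conversation with the chat template, tokenize, and build labels
    def collate_fn(batch):
        texts = [
            tokenizer.apply_chat_template(item['messages'], tokenize=False, add_generation_prompt=False)
            for item in batch
        ]
        # add_special_tokens=False: the chat template already inserts <bos>
        enc = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512, add_special_tokens=False)
        # Labels are the input ids; padded positions are set to -100 so the loss ignores them
        labels = enc["input_ids"].clone()
        labels[enc["attention_mask"] == 0] = -100
        enc["labels"] = labels
        return enc.to(model.device)

    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    data_iter = iter(dataloader)
    
    losses = []
    print(f"Starting Training (Max Steps: {max_steps}, Accumulation: {accumulation_steps})...")
    
    for step in tqdm(range(max_steps * accumulation_steps)):
        try:
            batch = next(data_iter)
        except StopIteration:
            # Restart the dataloader once the dataset is exhausted
            data_iter = iter(dataloader)
            batch = next(data_iter)
        
        # The model shifts labels internally for next-token prediction
        outputs = model(**batch)
        # Scale so that gradients summed over accumulation_steps micro-batches form an average
        loss = outputs.loss / accumulation_steps
        
        loss.backward()
        
        # Track real (un-scaled) loss
        losses.append(loss.item() * accumulation_steps)
        
        # --- Optimizer Step ---
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        # --- Scheduler & Logging Step ---
        if (step + 1) % eval_interval == 0:
            avg_loss = float(np.mean(losses))
            scheduler.step(avg_loss)
            
            current_lr = optimizer.param_groups[0]['lr']
            update_step = (step + 1) // accumulation_steps
            print(f"Step {update_step}: Loss {avg_loss:.4f} | LR {current_lr:.2e}")
            
            # Per-step metrics go to the run's time series (stored in the experiment's backing Vertex AI TensorBoard)
            aiplatform.log_time_series_metrics({
                "loss": avg_loss,
                "learning_rate": current_lr
            }, step=update_step)
            
            losses = [] # Reset buffer

    # --- Automated Evaluation Phase ---
    print("\nTraining Complete. Running Evaluation...")
    model.eval()
    eval_scores = []
    
    test_prompts = [
        "Tell me about your favorite treasure.",
        "What should we do with the rum?",
        "Where is the hidden island?",
        "How do you handle a mutiny?",
        "What is the best way to sail through a storm?"
    ]
    
    for i in range(min(eval_samples, len(test_prompts))):
        prompt = test_prompts[i]
        response = generate_response(model, tokenizer, prompt)
        score = evaluate_pirate_style(response)
        eval_scores.append(score)
        print(f"Sample {i+1} | Prompt: {prompt} | Score: {score}/10")
    
    avg_pirate_score = float(np.mean(eval_scores))
    
    # Summary metrics are logged once per run
    aiplatform.log_metrics({"final_avg_pirate_score": avg_pirate_score})
    
    print(f"\nFinal Average Pirate Score: {avg_pirate_score:.2f}/10")
    print("View your results at: https://console.cloud.google.com/vertex-ai/experiments")
    
    aiplatform.end_run()
    return model

# Run the Section 5 cell first so that `evaluate_pirate_style` is defined, then start training and logging:
# model = train(model, dataset, tokenizer, max_steps=100)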
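The training cell calls `ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)` on the averaged training loss every five optimizer updates. To see when that actually halves the learning rate, here is a simplified pure-Python re-implementation of its `mode='min'`, relative-threshold logic. `plateau_schedule` is a hypothetical helper, not PyTorch's code; it ignores `cooldown` and the `eps` guard.

```python
import math


def plateau_schedule(losses, lr, factor=0.5, patience=1, threshold=1e-4, min_lr=0.0):
    """Simplified sketch of ReduceLROnPlateau(mode='min', threshold_mode='rel', cooldown=0).

    Returns the learning rate in effect after each scheduler.step(loss) call.
    """
    best = math.inf
    num_bad = 0
    lrs = []
    for loss in losses:
        if loss < best * (1 - threshold):   # counts as an improvement
            best = loss
            num_bad = 0
        else:
            num_bad += 1
        if num_bad > patience:              # more than `patience` stale checks in a row
            lr = max(lr * factor, min_lr)
            num_bad = 0
        lrs.append(lr)
    return lrs


print(plateau_schedule([2.0, 1.5, 1.6, 1.7, 1.4], lr=2e-4))
```

With `patience=1`, one non-improving check is tolerated; the second consecutive one (the `1.7` above) halves the learning rate. Because each check here averages 40 micro-batches of training loss, noisy batches can still trigger a reduction, so a validation loss would be a steadier signal if one is available.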

## 5. Evaluation with Gemini (LLM-as-a-Judge)

We use Google's Gemini 2.5 Flash via Vertex AI as a judge to rate the "pirate-ness" of our model's responses.

In [None]:
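from vertexai.generative_models import GenerationConfig

def evaluate_pirate_style(text):
    judge_model = GenerativeModel("gemini-2.5-flash")
    
    # The example object must use double quotes: single-quoted keys are not valid JSON
    prompt = f"""
    You are a strict judge evaluating whether the following text sounds like a real pirate.
    Text: "{text}"
    
    Rate the 'Pirate Style' on a scale of 1-10.
    Return ONLY a JSON object: {{"score": <number>}}
    """
    
    # Request a JSON response so the reply is easier to parse
    response = judge_model.generate_content(
        prompt, generation_config=GenerationConfig(response_mime_type="application/json")
    )
    try:
        # Clean up code fences in case Gemini still returns them
        content = response.text.replace("```json", "").replace("```", "").strip()
        return json.loads(content)['score']
    except (ValueError, KeyError, TypeError):
        # Blocked or unparseable responses count as 0
        return 0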

# Test evaluation
sample_text = "Arr matey! Hand over yer loot or walk the plank!"
score = evaluate_pirate_style(sample_text)
print(f"Pirate Score: {score}/10")
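The judge cell relies on Gemini returning clean JSON. As a sketch of a more defensive parser (a hypothetical `parse_judge_score` helper using only the standard library), one could extract the first `{...}` object from the reply, accept it only if it is valid JSON, and clamp the score to the 1-10 rubric. Returning `None` instead of `0` lets the caller drop failed judgments from the average rather than count them as the worst possible score.

```python
import json
import re


def parse_judge_score(raw, low=1, high=10):
    """Hypothetical helper: pull a numeric 'score' out of a judge reply, or return None."""
    # Grab the first {...} object; this also skips ```json fences or stray prose around it.
    match = re.search(r"\{.*?\}", raw, flags=re.DOTALL)
    if match is None:
        return None
    try:
        score = float(json.loads(match.group(0))["score"])
    except (ValueError, KeyError, TypeError):
        # Invalid JSON (e.g. single-quoted keys), missing key, or non-numeric score
        return None
    # Clamp to the rubric range so an out-of-range reply cannot skew the average.
    return min(max(score, low), high)


print(parse_judge_score('```json\n{"score": 8}\n```'))  # 8.0
print(parse_judge_score("{'score': 7}"))                # None: not valid JSON
```

In the training loop, the caller could then average only the non-`None` scores, for example `np.mean([s for s in eval_scores if s is not None])`.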

## Conclusion

You have fine-tuned Gemma 3 1B with LoRA using Hugging Face `transformers` and `peft`, tracked training metrics in Vertex AI Experiments, and scored the fine-tuned model's responses with Gemini as an LLM judge.