LLM Fine-tuning (Vertex AI Edition)

We will fine-tune a Large Language Model (LLM) with standard open-source tools and Google Cloud's Vertex AI platform. Training uses Hugging Face's `transformers`, `datasets`, and `peft` libraries. Vertex AI Experiments tracks metrics, and Gemini on Vertex AI acts as the evaluation judge.

**Goal**: Fine-tune a `Gemma 3` model to speak like a Pirate!

**Tech Stack**:
*   **Model**: [Google Gemma 3 1B](https://huggingface.co/google/gemma-3-1b-it)
*   **Dataset**: [Pirate UltraChat](https://huggingface.co/datasets/winglian/pirate-ultrachat-10k)
*   **Logging & Eval**: Google Cloud Vertex AI (Experiments & Gemini Judge)


In [None]:
# Install dependencies
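# Gemma 3 needs a recent transformers release; accelerate is required for device_map="auto"
!pip install "transformers>=4.50" datasets peft accelerate google-cloud-aiplatform --quiet
!pip install torch --quiet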

In [None]:
import os
import json
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from peft import LoraConfig, get_peft_model

# Vertex AI
import vertexai
from vertexai.generative_models import GenerativeModel, SafetySetting

# TODO: Initialize Vertex AI
# vertexai.init(project="YOUR_PROJECT_ID", location="us-central1")

## 1. Data Preparation

We will use the `winglian/pirate-ultrachat-10k` dataset from Hugging Face. This dataset contains conversations re-written in a pirate style.

In [None]:
# Load dataset from Hugging Face
dataset_id = "winglian/pirate-ultrachat-10k"
dataset = load_dataset(dataset_id, split="train[:2000]") # Use a subset for speed

print(f"Loaded {len(dataset)} samples.")
print("Sample entry:", dataset[0])
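Before training, each conversation is turned into a single string with `tokenizer.apply_chat_template`. The sketch below approximates Gemma's chat format so you can see what the model actually trains on. It assumes records of `{"role", "content"}` dicts, Gemma's `<start_of_turn>`/`<end_of_turn>` markers, and `model` as the assistant role. `render_gemma_chat` is a hypothetical helper for illustration only. In practice, always use the tokenizer's own template, which also handles system prompts.

```python
def render_gemma_chat(messages, bos="<bos>"):
    """Approximate Gemma's chat format for user/assistant turns (hypothetical helper)."""
    parts = [bos]
    for msg in messages:
        if msg["role"] not in ("user", "assistant"):
            # The real template folds system prompts into the first user turn; not modeled here
            raise ValueError(f"unsupported role: {msg['role']}")
        # Gemma's template calls the assistant role "model"
        role = "model" if msg["role"] == "assistant" else "user"
        parts.append(f"<start_of_turn>{role}\n{msg['content'].strip()}<end_of_turn>\n")
    return "".join(parts)


example = [
    {"role": "user", "content": "Where be the treasure?"},
    {"role": "assistant", "content": "Buried on Skull Island, matey!"},
]
print(render_gemma_chat(example))
```

Comparing this with `tokenizer.apply_chat_template(example, tokenize=False)` is a quick way to check the real markers. The rendered string already begins with `<bos>`. Tokenizing it with the default `add_special_tokens=True` would therefore prepend a second BOS token.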

## 2. Model Setup (Gemma 3)

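We use `google/gemma-3-1b-it`, the instruction-tuned 1B-parameter Gemma 3 model. It is small enough to fine-tune with LoRA on a single GPU. Gemma 3 also comes in a 270M-parameter size (`google/gemma-3-270m-it`) if you want even faster iterations. Gemma models are gated on Hugging Face. Accept the license on the model page and authenticate (for example with `huggingface-cli login`) before downloading.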

In [None]:
model_id = "google/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
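# Gemma's tokenizer already defines a <pad> token; only fall back to EOS if it is missing
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token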

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
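A quick back-of-the-envelope check shows why a 1B-parameter model is convenient here. In `bfloat16` each parameter takes 2 bytes, so the weights alone need about 2 GB (1.86 GiB). The sketch assumes roughly 1e9 parameters, taken from the model name; `model.num_parameters()` gives the exact count. It ignores activations, gradients, and optimizer state, all of which add to the real footprint.

```python
def weight_memory_gib(num_params, bytes_per_param):
    """Memory needed just to hold the weights, in GiB."""
    return num_params * bytes_per_param / 1024**3


approx_params = 1_000_000_000  # assumption: ~1B parameters, per the model name
for dtype, nbytes in [("float32", 4), ("bfloat16", 2)]:
    print(f"{dtype}: {weight_memory_gib(approx_params, nbytes):.2f} GiB")
```

This prints about 3.73 GiB for `float32` and 1.86 GiB for `bfloat16`. With LoRA, gradients and optimizer state are kept only for the small adapter matrices, which keeps that extra overhead small.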

## 3. LoRA Configuration

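We use Low-Rank Adaptation (LoRA) to fine-tune efficiently. The base model's weights stay frozen, and only small low-rank adapter matrices added to selected layers are trained. We list the target modules explicitly: the four attention projections `q_proj`, `k_proj`, `v_proj`, and `o_proj`. Relying on `peft`'s per-architecture defaults would, for many models, cover only `q_proj` and `v_proj`.

In [None]:
def apply_lora(model):
    lora_config = LoraConfig(
        r=8,            # rank of the low-rank update matrices
        lora_alpha=16,  # the adapter update is scaled by lora_alpha / r
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Gemma 3 attention projections
        task_type="CAUSAL_LM",
        lora_dropout=0.05,
    )
    # Freezes the base weights and injects trainable LoRA adapters into the target modules
    return get_peft_model(model, lora_config)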

model = apply_lora(model)
model.print_trainable_parameters()
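LoRA leaves a frozen weight matrix `W0` (shape `d_out x d_in`) untouched and learns a low-rank update: `h = W0 x + (lora_alpha / r) * B A x`. Here `A` is `r x d_in` and `B` is `d_out x r`. Each adapted layer therefore adds `r * (d_in + d_out)` trainable parameters. With `r=8` and `lora_alpha=16`, the update is scaled by 2.

The hypothetical helpers below let you sanity-check the count reported by `print_trainable_parameters()`. The layer shapes and block count are placeholders, not Gemma 3's real dimensions. Read the actual values from your model (for example `model.config`).

```python
def lora_params_per_layer(d_in, d_out, r):
    """Trainable parameters LoRA adds to one linear layer: A is (r x d_in), B is (d_out x r)."""
    return r * d_in + d_out * r


def total_lora_params(layer_shapes, r, num_blocks):
    """Sum over the adapted projections in one block, times the number of transformer blocks."""
    per_block = sum(lora_params_per_layer(d_in, d_out, r) for d_in, d_out in layer_shapes)
    return per_block * num_blocks


r, lora_alpha = 8, 16
print("update scale:", lora_alpha / r)

# Placeholder (d_in, d_out) shapes for q/k/v/o projections; replace with your model's real ones
shapes = [(1024, 1024), (1024, 256), (1024, 256), (1024, 1024)]
print("LoRA parameters:", total_lora_params(shapes, r, num_blocks=24))
```

The count grows linearly with `r`, which is why small ranks such as 8 keep the trainable fraction tiny compared with the full model.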

## 4. Training Loop with Vertex AI Logging

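Instead of Comet, we track training metrics with Vertex AI Experiments. Each call to `train` creates a run inside an experiment and logs the smoothed training loss every 10 steps.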

In [None]:
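def train(model, dataset, tokenizer, max_steps=100):
    from google.cloud import aiplatform  # experiment-tracking API from the same SDK as vertexai

    # Create/reuse the experiment; project and location come from the earlier vertexai.init call.
    # Run names must be unique within an experiment, so change "run-1" for each new run.
    aiplatform.init(experiment="pirate-finetune-lab")
    aiplatform.start_run("run-1")

    # Lion is not part of torch.optim; AdamW is a standard choice for LoRA fine-tuning.
    # Only the LoRA adapter weights require gradients, so only they are optimized.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)
    model.train()

    # Turn a list of dataset records into a tokenized batch with masked labels
    def collate_fn(batch):
        texts = []
        for item in batch:
            # Assumes each record has a 'messages' list of {"role": ..., "content": ...} dicts.
            # Inspect dataset[0] and adjust the key name if the structure differs.
            chat = item["messages"]
            texts.append(tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False))

        # The chat template already starts with <bos>, so don't add special tokens again
        enc = tokenizer(texts, return_tensors="pt", padding=True, truncation=True,
                        max_length=512, add_special_tokens=False)
        # Ignore padding positions in the loss (-100 is the ignore index)
        enc["labels"] = enc["input_ids"].masked_fill(enc["attention_mask"] == 0, -100)
        return enc

    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

    losses = []
    for step, batch in enumerate(tqdm(dataloader, total=max_steps)):
        if step >= max_steps:
            break

        batch = {k: v.to(model.device) for k, v in batch.items()}
        outputs = model(**batch)  # passing labels makes the model return the causal LM loss
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        losses.append(loss.item())

        if step % 10 == 0:
            avg_loss = float(np.mean(losses[-10:]))
            print(f"Step {step}: Loss {avg_loss:.4f}")
            # Time-series metrics keep the step index, so Vertex AI can plot a loss curve
            aiplatform.log_time_series_metrics({"loss": avg_loss}, step=step)

    aiplatform.end_run()
    return model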

# Start training
# model = train(model, dataset, tokenizer)
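When `labels` are passed, Hugging Face causal LM models shift them internally. The logits at position t are scored against the token at position t+1, and any label equal to -100 is ignored. The pure-Python sketch below mirrors that computation on toy numbers to show why padding positions should be masked with -100 before computing the loss. The function names are hypothetical, and the token IDs and logits are made up for illustration.

```python
import math

IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss


def mask_padding(input_ids, attention_mask):
    """Copy input_ids into labels, replacing padding positions with IGNORE_INDEX."""
    return [tok if m == 1 else IGNORE_INDEX for tok, m in zip(input_ids, attention_mask)]


def causal_lm_loss(logits, labels):
    """Mean next-token cross-entropy: logits[t] predicts labels[t + 1]."""
    total, count = 0.0, 0
    for t in range(len(labels) - 1):
        target = labels[t + 1]
        if target == IGNORE_INDEX:
            continue  # padding does not contribute to the loss
        row = logits[t]
        log_norm = math.log(sum(math.exp(x) for x in row))
        total += log_norm - row[target]  # equals -log(softmax(row)[target])
        count += 1
    return total / count if count else 0.0


# Toy sequence over a 3-token vocabulary, right-padded with token 0
input_ids = [2, 1, 0, 0]
attention_mask = [1, 1, 0, 0]
labels = mask_padding(input_ids, attention_mask)
logits = [[0.0, 2.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
print(labels)  # [2, 1, -100, -100]
print(round(causal_lm_loss(logits, labels), 4))
```

Without the mask, the model would also be trained to predict pad tokens after the real text. With padding set to EOS, that teaches it to emit long runs of end-of-sequence tokens.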

## 5. Evaluation with Gemini (LLM-as-a-Judge)

We use Gemini 2.5 Flash on Vertex AI as an LLM judge to score the "pirate-ness" of a piece of text on a 1-10 scale.

In [None]:
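def evaluate_pirate_style(text):
    from vertexai.generative_models import GenerationConfig

    judge_model = GenerativeModel("gemini-2.5-flash")

    # Double quotes in the requested format: single-quoted keys are not valid JSON
    prompt = f"""
    You are a strict judge evaluating if the following text sounds like a real pirate.
    Text: "{text}"

    Rate the 'Pirate Style' on a scale of 1-10.
    Return ONLY a JSON object: {{"score": <number>}}
    """

    # Request a JSON response so the output can be parsed directly
    response = judge_model.generate_content(
        prompt,
        generation_config=GenerationConfig(response_mime_type="application/json"),
    )
    try:
        # Strip Markdown code fences in case the model still wraps its answer in them
        content = response.text.replace("```json", "").replace("```", "").strip()
        return json.loads(content)["score"]
    except (ValueError, KeyError, TypeError):
        # JSONDecodeError is a ValueError, and response.text can also raise ValueError
        # (e.g. for a blocked response). 0 is outside the 1-10 scale, so it marks a failure.
        return 0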

# Test evaluation
sample_text = "Arr matey! Hand over yer loot or walk the plank!"
score = evaluate_pirate_style(sample_text)
print(f"Pirate Score: {score}/10")
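Judge scores are easier to trust when parsing is strict and failures are counted rather than folded into the average. The sketch below can be tested offline. `parse_judge_score` and `summarize_scores` are hypothetical helpers, not part of any SDK. They strip Markdown fences, accept only numeric scores in the 1-10 range the prompt asks for, and average the valid ones. The sample replies are made up for illustration.

```python
import json


def parse_judge_score(raw_text):
    """Return the numeric score from a judge reply, or None if it is unusable."""
    cleaned = raw_text.replace("```json", "").replace("```", "").strip()
    try:
        score = json.loads(cleaned)["score"]
    except (ValueError, KeyError, TypeError):
        return None
    # Reject non-numeric values (bool is a subclass of int, so exclude it explicitly)
    if isinstance(score, bool) or not isinstance(score, (int, float)):
        return None
    return score if 1 <= score <= 10 else None


def summarize_scores(raw_replies):
    """Average the valid scores and report how many replies could not be used."""
    scores = [s for s in (parse_judge_score(r) for r in raw_replies) if s is not None]
    failures = len(raw_replies) - len(scores)
    mean = sum(scores) / len(scores) if scores else None
    return mean, failures


replies = ['{"score": 9}', '```json\n{"score": 7}\n```', "{'score': 8}", '{"score": 42}']
print(summarize_scores(replies))  # (8.0, 2)
```

Counting failures separately keeps an unparseable or out-of-range reply from dragging the average down, as a default score of 0 would.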

## Conclusion

In this lab you fine-tuned Gemma 3 1B with LoRA using Hugging Face tools. You tracked the training loss in Vertex AI Experiments and used Gemini on Vertex AI as an LLM judge to score pirate style.