In [None]:
# ==========================================================================================
#
#  NOTEBOOK FOR FINE-TUNING GEMMA 3 1B LOCALLY
#
# ------------------------------------------------------------------------------------------
#
#  **DISCLAIMER**
#
#  This notebook is intended for educational purposes only.
#
#  - Date: July 2025
#  - Not suitable for production environments.
#  - Use at your own risk.
#
# ==========================================================================================
#
#  Overview:
#
#  This notebook provides a step-by-step guide to fine-tuning the Gemma 3 1B model.
#  The process involves:
#
#      01. Serving the Gemma 3 1B base model (the instruction-tuned `gemma-3-1b-it` checkpoint) and getting a baseline answer.
#      02. Fine-tuning the model with a custom dataset.
#      03. Saving the fine-tuned adapter and the full merged model to local directories.
#      04. Serving the fine-tuned model from local storage and comparing its answers
#          with the base model.
#
# ------------------------------------------------------------------------------------------
#
#  Requirements:
#  - A local or cloud environment (like Colab) with a GPU (e.g., NVIDIA L4).
#  - A Hugging Face account that has accepted the Gemma 3 1B license terms, plus an API token with read permission.
#
# ==========================================================================================


In [None]:
# --- 1. Installation ---
# Install necessary Python packages
!pip install -q -U transformers datasets accelerate peft trl bitsandbytes vllm

# Restart the session to apply the newly installed packages
import os
os.kill(os.getpid(), 9)


In [None]:
# --- 2. Setup ---
# Authenticate against the Hugging Face Hub. Downloading the Gemma 3 weights requires
# an account that has accepted the model's terms; paste a read-permission token
# into the prompt that appears below.
import huggingface_hub

huggingface_hub.notebook_login()

In [None]:
# Necessary imports
import torch
import gc
import os
import re
from datasets import load_dataset
from peft import LoraConfig, PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer
from vllm import LLM, SamplingParams

# --- 3. Configuration ---
# Identifiers for the model and dataset, plus the local paths every later step uses.

# The instruction-tuned ("-it") checkpoint is the starting point for fine-tuning.
BASE_MODEL_ID = "google/gemma-3-1b-it"
DATASET_ID = "fredmo/gemma_ft_dc_rules_dataset"

# Local output locations
TRAINING_OUTPUT_DIR = "./results"               # trainer checkpoints (LoRA adapters)
MERGED_MODEL_DIR = "./merged_fine_tuned_model"  # base model with the adapter merged in
BASE_ANSWER_FILE = "base_model_answer.txt"      # baseline answer, re-read for the comparison

# Make sure both output directories exist (existing directories are left as-is).
for _output_dir in (TRAINING_OUTPUT_DIR, MERGED_MODEL_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# Helper function to clear GPU memory
def clear_gpu_memory():
    """Free memory held by objects that are no longer referenced.

    The Python garbage collector runs first, so tensors kept alive only by
    reference cycles are actually destroyed. Only then is PyTorch asked to
    return its now-unused cached CUDA blocks to the driver. Emptying the cache
    *before* collecting would leave the memory of those cycle-held tensors
    cached.

    The CUDA step is skipped on machines without CUDA.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# --- 4. Base Model Inference ---
# Record the un-tuned model's answer first; it is compared against the
# fine-tuned model's answer at the end of the notebook.
print("\n--- Performing base inference ---")
llm_base = None
try:
    # vLLM engine for the base model, limited to half of the GPU's memory.
    llm_base = LLM(
        model=BASE_MODEL_ID,
        trust_remote_code=True,
        dtype=torch.float16,
        gpu_memory_utilization=0.5,
    )

    question = "What are Cloud Front Ends and how do they relate to customer VMs in Google Cloud?"
    sampling_params = SamplingParams(max_tokens=256, top_p=0.9, temperature=0.7)

    # The -it checkpoint expects its chat template wrapped around the user turn.
    tokenizer_base = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
    prompt = tokenizer_base.apply_chat_template(
        [{"role": "user", "content": question}],
        tokenize=False,
        add_generation_prompt=True,
    )

    # One prompt in the batch -> first request -> its first completion.
    first_completion = llm_base.generate([prompt], sampling_params)[0].outputs[0]
    base_model_answer = first_completion.text.strip()

    # Persist the baseline so the comparison step can read it back from disk.
    with open(BASE_ANSWER_FILE, "w") as answer_file:
        answer_file.write(base_model_answer)
    print(f"\nBase model answer saved to {BASE_ANSWER_FILE}")

finally:
    # Drop the engine (if it was created) and reclaim GPU memory.
    if llm_base:
        del llm_base
    clear_gpu_memory()
    print("Base inference complete and resources cleared.")

# --- 5. Fine-Tuning ---
# QLoRA fine-tuning: the base model is loaded 4-bit quantized and only LoRA adapter
# weights are trained. Checkpoints are written to TRAINING_OUTPUT_DIR once per epoch.
print("\n--- Performing fine-tuning ---")
trainer = None
model = None
try:
    # Load the dataset; the fixed shuffle seed keeps the example order reproducible.
    dataset = load_dataset(DATASET_ID, split="train").shuffle(seed=42)

    # Configure BitsAndBytes for 4-bit (NF4) quantization to save memory
    compute_dtype = torch.float16
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
    )

    # Load the model with the quantization config
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    # The KV cache only helps generation; it is turned off while training.
    model.config.use_cache = False

    # Configure the tokenizer. Fall back to EOS for padding only when the tokenizer
    # has no pad token of its own: overwriting an existing pad token with EOS can make
    # collators that mask padding in the labels also mask genuine EOS tokens.
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Configure LoRA (Low-Rank Adaptation) for efficient fine-tuning
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Configure training arguments.
    # `warmup_ratio` only takes effect with a warmup-aware scheduler: transformers'
    # plain "constant" scheduler ignores warmup entirely, so "constant_with_warmup"
    # is used to actually apply the intended 3% warmup.
    training_args = TrainingArguments(
        output_dir=TRAINING_OUTPUT_DIR,
        num_train_epochs=2,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        optim="adamw_torch",
        save_strategy="epoch",
        logging_steps=25,
        learning_rate=2e-4,
        fp16=True,
        max_grad_norm=0.3,
        warmup_ratio=0.03,
        group_by_length=True,
        lr_scheduler_type="constant_with_warmup",
        report_to="none",
    )

    # Render each conversation in the "messages" column with the chat template.
    formatted_dataset = dataset.map(
        lambda ex: {"text": tokenizer.apply_chat_template(ex["messages"], tokenize=False)}
    )

    # Initialize and run the trainer. The tokenizer configured above is handed to the
    # trainer explicitly (`processing_class`, the TRL >= 0.12 name for the former
    # `tokenizer` argument); otherwise the trainer loads its own copy and ignores it.
    trainer = SFTTrainer(
        model=model,
        train_dataset=formatted_dataset,
        peft_config=peft_config,
        args=training_args,
        processing_class=tokenizer,
    )
    trainer.train()
    print("\nFine-tuning complete.")

finally:
    # Both names were bound (to None) before the `try`, so they always exist here.
    # The trainer holds its own reference to the model, so both must be dropped
    # before the GPU memory can be reclaimed.
    del trainer, model
    clear_gpu_memory()
    print("Fine-tuning resources cleared.")

# --- 6. Merge Adapter and Save Model ---
# Find the adapter checkpoint with the highest step number from training.
# Only directories named exactly "checkpoint-<step>" are considered, so stray files or
# oddly named entries cannot break the numeric comparison.
# NOTE: nothing in this notebook clears TRAINING_OUTPUT_DIR between runs, so a leftover
# checkpoint from an earlier, longer run would be picked here; delete old checkpoints
# before re-training.
checkpoint_dirs = [
    d for d in os.listdir(TRAINING_OUTPUT_DIR)
    if re.fullmatch(r"checkpoint-\d+", d)
    and os.path.isdir(os.path.join(TRAINING_OUTPUT_DIR, d))
]
if not checkpoint_dirs:
    raise FileNotFoundError(
        f"No 'checkpoint-<step>' directories found in {TRAINING_OUTPUT_DIR!r}; "
        "did the fine-tuning step finish and save a checkpoint?"
    )
latest_checkpoint_name = max(checkpoint_dirs, key=lambda d: int(d.rsplit("-", 1)[1]))
ADAPTER_PATH = os.path.join(TRAINING_OUTPUT_DIR, latest_checkpoint_name)
print(f"Using adapter from: {ADAPTER_PATH}")

# Fold the fine-tuned LoRA adapter into the base weights so the result can be loaded
# as one self-contained model (no PEFT needed at serving time).
print("\n--- Merging adapter to create a full model ---")
try:
    # Un-quantized float16 base weights; the adapter is merged into these.
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )
    # Attach the adapter, then merge its weights and strip the PEFT wrappers.
    merged_model = PeftModel.from_pretrained(base_model, ADAPTER_PATH).merge_and_unload()

    # Save model and tokenizer side by side so the directory loads on its own.
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
    for artifact in (merged_model, tokenizer):
        artifact.save_pretrained(MERGED_MODEL_DIR)
    print(f"Full fine-tuned model saved to {MERGED_MODEL_DIR}")

finally:
    # Drop whichever model objects were created, then reclaim memory.
    if "merged_model" in locals():
        del merged_model
    if "base_model" in locals():
        del base_model
    clear_gpu_memory()
    print("Merging resources cleared.")


# --- 7. Inference with Fine-Tuned Model (with Fallback) ---
# Ask the fine-tuned model the same question the base model answered in step 4.
# vLLM on the merged model is tried first; Transformers + PEFT on the adapter is the fallback.
print("\n--- Performing inference with the fine-tuned model ---")

question = "What are Cloud Front Ends and how do they relate to customer VMs in Google Cloud?"
finetuned_model_answer = None

# Primary Method: Use vLLM with the full merged model for fast inference
llm_finetuned = None
try:
    print("\nAttempting to use primary method: vLLM with full merged model...")
    llm_finetuned = LLM(
        model=MERGED_MODEL_DIR,  # Load from the local directory
        trust_remote_code=True,
        max_model_len=2048,
        gpu_memory_utilization=0.8,
        dtype=torch.float16,
    )
    tokenizer_finetuned = AutoTokenizer.from_pretrained(MERGED_MODEL_DIR)

    messages = [{"role": "user", "content": question}]
    prompt = tokenizer_finetuned.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=256)
    outputs = llm_finetuned.generate([prompt], sampling_params)
    finetuned_model_answer = outputs[0].outputs[0].text.strip()
    print("vLLM method succeeded.")

except Exception as e:
    print(f"vLLM method failed: {e}")
    finetuned_model_answer = None  # Ensure fallback runs

finally:
    # Release the engine on failure as well as on success, so a failed attempt does
    # not keep GPU memory that the fallback needs. (Dropping the Python reference may
    # not release every allocation vLLM made, but it is the best effort available here.)
    if llm_finetuned is not None:
        del llm_finetuned
    clear_gpu_memory()

# Fallback Method: If vLLM fails, use Transformers + PEFT with the adapter
if finetuned_model_answer is None:
    base_model_peft = None
    peft_model = None
    try:
        print("\nAttempting fallback method: Transformers + PEFT...")
        # Load the base model
        base_model_peft = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        # Apply the LoRA adapter from the local path
        peft_model = PeftModel.from_pretrained(base_model_peft, ADAPTER_PATH)
        tokenizer_peft = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

        messages = [{"role": "user", "content": question}]
        prompt = tokenizer_peft.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # The rendered chat prompt already contains its special tokens (for Gemma,
        # including BOS), so the tokenizer must not prepend a second BOS.
        inputs = tokenizer_peft(prompt, return_tensors="pt", add_special_tokens=False).to(peft_model.device)
        with torch.no_grad():
            # Same sampling settings as the vLLM path, so both methods produce comparable answers.
            outputs = peft_model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer_peft.eos_token_id,
            )
        # Decode only the newly generated tokens. Slicing the decoded *string* by
        # len(prompt) is wrong: skip_special_tokens=True removes the template's special
        # tokens, so the decoded text is shorter than `prompt` and the slice cuts into
        # the answer.
        prompt_token_count = inputs["input_ids"].shape[1]
        finetuned_model_answer = tokenizer_peft.decode(
            outputs[0][prompt_token_count:], skip_special_tokens=True
        ).strip()
        print("PEFT fallback method succeeded.")

    except Exception as e2:
        print(f"PEFT fallback also failed: {e2}")
        finetuned_model_answer = "Error: Both vLLM and PEFT methods failed. Check model, adapter, or config files."

    finally:
        # Free the fallback models whether or not generation succeeded.
        del base_model_peft, peft_model
        clear_gpu_memory()


# --- 8. Final Comparison ---
# Print the reference answer followed by the base and fine-tuned model answers.

print("\n--- Reading base model answer from local file ---")
with open(BASE_ANSWER_FILE, "r") as answer_file:
    base_model_answer = answer_file.read()

desired_answer = "Cloud Front Ends are specific GFEs located in the same cloud region as customer VMs for minimizing latency. They allow customer VMs to communicate with Google APIs and services without needing external IP addresses."

print("\n\n" + "==========" + " FINAL RESULTS " + "==========")
comparison_sections = (
    ("\n## Desired Answer: ##", desired_answer),
    ("\n## Base Model Answer: ##", base_model_answer),
    ("\n## Fine-tuned Model Answer: ##", finetuned_model_answer),
)
for section_heading, section_text in comparison_sections:
    print(section_heading)
    print(section_text)
print("\n" + "==========")

print("\n--- Script Finished. ---")