<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/Phi_2_TPU_FT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://huggingface.co/microsoft/phi-2

In [None]:
!pip install transformers==4.47.0 datasets optimum-tpu==0.2.3 torch-xla==2.5.1 -q

In [2]:
from google.colab import userdata
from huggingface_hub import login
from google.colab import drive
token=userdata.get('HF_TOKEN')
login(token)

In [None]:
!pip install -q jax-ai-stack==2025.4.9
!pip install -Uq "jax[tpu]==0.5.3" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -Uq tiktoken matplotlib kaggle wandb tpu-info orbax-checkpoint==0.11.12
!pip install -Uq datasets

In [None]:
import warnings
# Ignore the specific JAX warning about skipped cross-host ArrayMetadata validation
warnings.filterwarnings(
    "ignore",
    message=".*Skipped cross-host ArrayMetadata validation because only one process is found.*",
    category=UserWarning,  # Or Warning if the category is different
)

# All necessary imports from the original notebook
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, NamedSharding
from jax import random
import jax.nn.initializers as init
import jax.nn as nn
from jax.lib import xla_bridge
from jax.experimental.mesh_utils import create_device_mesh
import optax
import time
import orbax.checkpoint as orbax
import numpy as np
import shutil
from datasets import load_dataset
from transformers import GPT2Tokenizer
import tiktoken
import flax.nnx as nnx

In [5]:
import jax
mesh = jax.make_mesh((1,), ('batch',))

In [6]:
# Create an array of random values:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
# Changed PartitionSpec to use the 'batch' axis to match the mesh
y = jax.device_put(x, NamedSharding(mesh, P('batch', None)))
jax.debug.visualize_array_sharding(y)

In [None]:
import torch_xla
device=torch_xla.device


In [None]:
!pip install -Uq datasets bitsandbytes

In [None]:
!pip install -q peft

In [None]:
from google.colab import userdata
api_key = userdata.get('WANDB_KEY')
import wandb
wandb.login(key=api_key)

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from datasets import load_dataset
from huggingface_hub import login
from google.colab import userdata
import torch_xla.core.xla_model as xm
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import math
import warnings
warnings.filterwarnings("ignore")

# Authenticate with Hugging Face Hub using your token.
try:
    token = userdata.get('HF_TOKEN')
    login(token=token)
except Exception as e:
    print(f"Failed to login to Hugging Face Hub: {e}")

# Load the Phi-2 model and tokenizer without quantization.
# BitsAndBytes is not compatible with TPUs.

device = xm.xla_device()
print(f"--- Loading non-quantized Phi-2 on {device} ---")
model_name = "microsoft/Phi-2"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'right'

# Load the model with BFloat16 precision, which is a good balance of memory and performance
# for TPUs, and then move it to the XLA device.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
)
model.to(device)

# Prepare the model for PEFT training
model.config.use_cache = False  # Recommended for training
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

print("--- Model loaded successfully! ---")

# Load and format the dataset.
print("--- Loading dataset... ---")
# Split the dataset into train and test sets to have a proper evaluation
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
split_dataset = dataset.train_test_split(test_size=0.1)
train_dataset = split_dataset['train']
eval_dataset = split_dataset['test']

def format_data(examples):
    formatted_texts = []
    for instruction, context, response in zip(examples['instruction'], examples['context'], examples['response']):
        formatted_text = f"Instruction: {instruction}\n"
        if context:
            formatted_text += f"Context: {context}\n"
        formatted_text += f"Response: {response}"
        formatted_texts.append(formatted_text)
    return {"text": formatted_texts}

formatted_train_dataset = train_dataset.map(format_data, batched=True)
formatted_eval_dataset = eval_dataset.map(format_data, batched=True)

# Tokenize the dataset and add labels for the trainer.
def tokenize_function(examples):
    tokenized_inputs = tokenizer(examples["text"], padding="max_length", truncation=True)
    tokenized_inputs["labels"] = tokenized_inputs["input_ids"].copy()
    return tokenized_inputs

tokenized_train_dataset = formatted_train_dataset.map(tokenize_function, batched=True)
tokenized_eval_dataset = formatted_eval_dataset.map(tokenize_function, batched=True)

# Define TrainingArguments for TPU and initialize the Trainer.
output_dir = "./phi-2_finetuned"
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    num_train_epochs=1,
    fp16=False,
    bf16=True,
    optim="adamw_torch_xla",
    eval_strategy="steps",
    eval_steps=500,
    logging_steps=500,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_eval_dataset,
    tokenizer=tokenizer
)

In [14]:
print("--- Starting fine-tuning process... ---")
trainer.train()

print("--- Fine-tuning completed! ---")

--- Starting fine-tuning process... ---


Step,Training Loss,Validation Loss
500,0.1872,0.167103
1000,0.1803,0.163098
1500,0.166,0.161683
2000,0.1969,0.161072
2500,0.1557,0.160491
3000,0.1835,0.16018
3500,0.1542,0.159701
4000,0.1617,0.159485
4500,0.173,0.159292
5000,0.1783,0.159065


Step,Training Loss,Validation Loss
500,0.1872,0.167103
1000,0.1803,0.163098
1500,0.166,0.161683
2000,0.1969,0.161072
2500,0.1557,0.160491
3000,0.1835,0.16018
3500,0.1542,0.159701
4000,0.1617,0.159485
4500,0.173,0.159292
5000,0.1783,0.159065


--- Fine-tuning completed! ---


In [15]:
# Evaluate and calculate perplexity after training
print("--- Evaluating model and calculating perplexity... ---")
eval_results = trainer.evaluate()
perplexity = math.exp(eval_results['eval_loss'])
print(f"Final Evaluation Loss: {eval_results['eval_loss']:.4f}")
print(f"Final Perplexity: {perplexity:.4f}")

--- Evaluating model and calculating perplexity... ---


Final Evaluation Loss: 0.1580
Final Perplexity: 1.1712
