<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/Mistral_TPU_Fine_Tuning_for_Text_to_SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q datasets
!pip install -q evaluate

In [None]:
!pip install tensorflow-datasets -q
!pip install flax -q
!pip install optax -q
!pip install tensorflow_cpu -q

In [2]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import optax # For optimizers
import flax.linen as nn # For neural network modules
from flax.training import train_state
import tensorflow_datasets as tfds # For loading datasets
import tensorflow as tf # Used for tfds data loading, not for model building
import time
import math # For math.inf
import numpy as np # For np.array
import warnings
warnings.filterwarnings("ignore")

print(f"JAX version: {jax.__version__}")
print(f"Optax version: {optax.__version__}")
print(f"TensorFlow version (used for data loading): {tf.__version__}")

# --- 1. TPU Initialization (JAX style) ---
try:
    devices = jax.devices()
    tpu_devices = [d for d in devices if d.platform == 'tpu']
    if not tpu_devices:
        raise ValueError("No TPU devices found.")
    print(f"Found JAX devices: {devices}")
    print(f"Number of TPU devices available: {len(tpu_devices)}")
except ValueError as e:
    print(f"ERROR: {e}. Please ensure your Colab runtime is set to TPU.")
    print("Go to 'Runtime' -> 'Change runtime type' and select 'TPU'.")
    raise SystemExit("TPU not found or not initialized for JAX.")

JAX version: 0.5.3
Optax version: 0.2.5
TensorFlow version (used for data loading): 2.20.0
Found JAX devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Number of TPU devices available: 8


In [3]:
from jax.sharding import NamedSharding, PartitionSpec as P
# Create a Sharding object to distribute a value across devices:
mesh = jax.make_mesh((4, 2), ('x', 'y'))
# Create an array of random values:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
y = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
jax.debug.visualize_array_sharding(y)

In [4]:
mesh = jax.make_mesh((4, 2), ('x', 'y'))

x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))

@jax.jit
def f(x):
  x = x + 1
  y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('y', 'x')))
  return y


In [5]:
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
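
The per-device block sizes that the sharding visualizations above display can be worked out by hand. The following is a minimal sketch in plain Python rather than a JAX API: `shard_shape` and `mesh_axes` are hypothetical names introduced here for illustration, and the helper assumes every array dimension divides evenly by the mesh axis it is split along.

```python
def shard_shape(global_shape, mesh_axes, partition_spec):
    """Per-device block shape when each array dimension is split along a named mesh axis."""
    shape = []
    for dim, axis_name in zip(global_shape, partition_spec):
        # None means the dimension is not split (replicated along that axis)
        parts = 1 if axis_name is None else mesh_axes[axis_name]
        if dim % parts != 0:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        shape.append(dim // parts)
    return tuple(shape)

# Mirrors jax.make_mesh((4, 2), ('x', 'y')) from the cells above
mesh_axes = {"x": 4, "y": 2}

before = shard_shape((8192, 8192), mesh_axes, ("x", "y"))  # layout of the input x
after = shard_shape((8192, 8192), mesh_axes, ("y", "x"))   # layout requested inside f
print(before, after)
```

Under these assumptions, each of the 8 devices holds a 2048 x 4096 block of the input, and the sharding constraint inside `f` asks for 4096 x 2048 blocks instead.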

* Problem: Fine-tuning on the TPU fails with a RESOURCE_EXHAUSTED error, which means the device ran out of memory. This is a common challenge when fine-tuning large models like Mistral-7B, as they require significant memory.

To address this, the training configuration in the fine-tuning cell below reduces the memory footprint per training step.

* Solution:

Memory Optimization
The RuntimeError: RESOURCE_EXHAUSTED indicates that the TPU does not have enough memory to allocate the buffers needed for training. Specifically, the allocator requested 112 MB when only about 35 MB was free. To resolve this, the memory consumed by each training step has to be reduced.
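
For a rough sense of scale, the memory taken by the model weights alone can be estimated from the parameter count and the bytes stored per parameter. This is a back-of-the-envelope sketch, not a measurement: `ASSUMED_PARAMS` is an assumed approximate parameter count for a 7B-class model, and the estimate ignores activations, gradients, optimizer state, and quantization metadata.

```python
# Bytes stored per parameter for a few common weight formats
BYTES_PER_PARAM = {"float32": 4.0, "bfloat16": 2.0, "int8": 1.0, "nf4": 0.5}

def weight_memory_gib(num_params, dtype):
    """Approximate memory for the weights alone, in GiB."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30

ASSUMED_PARAMS = 7.2e9  # assumption: approximate size of a 7B-class model

for dtype in BYTES_PER_PARAM:
    print(f"{dtype:>8}: {weight_memory_gib(ASSUMED_PARAMS, dtype):.1f} GiB")
```

The weights are a fixed cost; the batch size and sequence length control the activation memory that comes on top of it, which is what the changes below target.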

The TrainingArguments in the fine-tuning cell below are changed in two ways:

1. Reducing per_device_train_batch_size: This has been changed from 4 to 1. Each individual TPU core now processes only one sample per forward and backward pass, which significantly reduces the immediate memory demand.

2. Increasing gradient_accumulation_steps: This has been increased from 4 to 16. The model processes 16 small batches (each of size 1) and accumulates their gradients before performing a single weight update. The per-device effective batch size therefore stays the same (4 samples/batch * 4 accumulation steps = 16 before, 1 sample/batch * 16 accumulation steps = 16 now), but the memory load is spread over more steps.

This combination lets training proceed within the TPU's memory capacity while each weight update still averages gradients over the same number of samples as before.
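
The equivalence behind gradient accumulation can be checked numerically. The sketch below is an illustration under simple assumptions (a toy linear model with a mean-squared-error loss in NumPy, not the Trainer's internals); `effective_batch_size` and `mse_grad` are hypothetical helpers introduced here.

```python
import numpy as np

def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_devices=1):
    """Number of samples that contribute to one optimizer step."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices

def mse_grad(w, X, y):
    """Gradient of 0.5 * mean((X @ w - y) ** 2) with respect to w."""
    residual = X @ w - y
    return X.T @ residual / len(y)

rng = np.random.default_rng(0)
X = rng.normal(size=(16, 3))
y = rng.normal(size=16)
w = np.zeros(3)

# One pass over all 16 samples at once
full_batch_grad = mse_grad(w, X, y)

# 16 micro-batches of size 1, averaging their gradients before the update
accumulation_steps = 16
accumulated_grad = np.zeros_like(w)
for i in range(accumulation_steps):
    micro_grad = mse_grad(w, X[i:i + 1], y[i:i + 1])
    accumulated_grad += micro_grad / accumulation_steps

print(np.allclose(full_batch_grad, accumulated_grad))
print(effective_batch_size(4, 4), effective_batch_size(1, 16))
```

Only one micro-batch of activations is alive at a time in the loop, which is where the memory saving comes from. If each TPU core processes its own data shard, the global batch per update would be the per-device figure multiplied by the number of cores, which the `num_devices` argument models.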

In [2]:
!pip install -U bitsandbytes -q

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
import torch
import os
import evaluate
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, set_seed, BitsAndBytesConfig
from datasets import load_dataset
from huggingface_hub import login
from google.colab import userdata

# Import the necessary libraries for TPU
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# Set a random seed for reproducibility across different runs.
# This helps ensure that your results are consistent if you rerun the code.
set_seed(42)

# --- Dataset Preparation ---
# Define the system message for the conversation template
system_message = """You are a text-to-SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.
SCHEMA:
{schema}"""

# Function to convert dataset samples to OAI (OpenAI) message format
def create_conversation(sample):
  return {
    "messages": [
      {"role": "system", "content": system_message.format(schema=sample["context"])},
      {"role": "user", "content": sample["question"]},
      {"role": "assistant", "content": sample["answer"]}
    ]
  }

# Load the base dataset from the Hugging Face Hub
# This is the "b-mc2/sql-create-context" dataset
print("Loading dataset 'b-mc2/sql-create-context' from Hugging Face Hub...")
dataset = load_dataset("b-mc2/sql-create-context", split="train")

# Shuffle the dataset and select a subset (e.g., 12,500 samples) for faster experimentation
# You can adjust this range or remove it to use the full dataset
print(f"Original dataset size: {len(dataset)} samples. Selecting 12500 samples.")
dataset = dataset.shuffle(seed=42).select(range(12500)) # Using a fixed seed for reproducibility

# Convert the selected dataset samples into the required conversation format
# This step prepares the data for the Mistral model's chat template
print("Converting dataset to OAI messages format...")
dataset = dataset.map(create_conversation, remove_columns=dataset.features, batched=False)

# Split the dataset into training and test sets
# 10,000 samples for training, 2,500 for testing
print("Splitting dataset into training and test sets...")
dataset_splits = dataset.train_test_split(test_size=2500/12500, seed=42) # Using a fixed seed for reproducibility
train_dataset_hf = dataset_splits["train"]
eval_dataset_hf = dataset_splits["test"]

print(f"Train dataset size: {len(train_dataset_hf)}")
print(f"Test dataset size: {len(eval_dataset_hf)}")

# Save the processed datasets to JSON files, which the fine-tuning script expects
train_file = "train_dataset.json"
test_file = "test_dataset.json"
print(f"Saving training dataset to {train_file}...")
train_dataset_hf.to_json(train_file, orient="records")
print(f"Saving evaluation dataset to {test_file}...")
eval_dataset_hf.to_json(test_file, orient="records")

# --- Fine-tuning Function (executed on each TPU core) ---
def fine_tune_on_tpu(index, model_id, train_file, test_file, hf_token):
  # Authenticate with Hugging Face on each process
  login(token=hf_token)

  # Get the specific TPU device for this process
  device = xm.xla_device()
  print(f"Process {index} on device: {device}")

  # Load model and tokenizer
  tokenizer = AutoTokenizer.from_pretrained(model_id)
  # Set padding token and side for the tokenizer
  # Using eos_token as pad_token is a common practice for GPT-like models
  tokenizer.pad_token = tokenizer.eos_token
  tokenizer.padding_side = "right" # Important for Causal LMs

  # Load the datasets from the JSON files saved previously
  # These are now the pre-processed datasets in conversation format
  print(f"Process {index}: Loading training dataset from {train_file}")
  train_dataset = load_dataset("json", data_files=train_file, split="train")
  print(f"Process {index}: Loading evaluation dataset from {test_file}")
  eval_dataset = load_dataset("json", data_files=test_file, split="train") # Note: split="train" often means the primary split in a single-file JSON load

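  # Tokenize the conversations
  # The `create_conversation` step already formatted the data, now we just tokenize it.
  # This function is mapped with batched=True, so sample["messages"] is a list of conversations.
  def tokenize_conversation(sample):
    # apply_chat_template converts each list of messages into a single string
    # with special tokens indicating roles (e.g., <s>[INST]...[/INST] )
    prompts = tokenizer.apply_chat_template(
        sample["messages"],
        tokenize=False,  # We want the string output first
        add_generation_prompt=False # No need to add generation prompt here for fine-tuning
    )
    # Then tokenize the strings, padding or truncating each one to 512 tokens
    tokens = tokenizer(prompts, truncation=True, padding="max_length", max_length=512)
    # Trainer needs labels to compute a causal-LM loss: copy input_ids and set
    # padded positions to -100 so the loss ignores them
    tokens["labels"] = [
        [tok if keep else -100 for tok, keep in zip(ids, mask)]
        for ids, mask in zip(tokens["input_ids"], tokens["attention_mask"])
    ]
    return tokens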

  print(f"Process {index}: Tokenizing training dataset...")
  tokenized_train_dataset = train_dataset.map(tokenize_conversation, batched=True, remove_columns=["messages"])
  print(f"Process {index}: Tokenizing evaluation dataset...")
  tokenized_eval_dataset = eval_dataset.map(tokenize_conversation, batched=True, remove_columns=["messages"])

  # --- Quantization Configuration ---
  # Define BitsAndBytesConfig for 4-bit quantization
  # This significantly reduces memory usage by loading the model in lower precision
  bnb_config = BitsAndBytesConfig(
      load_in_4bit=True,
      bnb_4bit_quant_type="nf4", # Normalized float 4-bit quantization
      bnb_4bit_compute_dtype=torch.bfloat16, # Compute in bfloat16 for better precision
      bnb_4bit_use_double_quant=False, # Optional: double quantization
  )

  # Load the model with quantization.
  # device_map="auto" lets accelerate decide the device placement.
  # Caveat: the bitsandbytes 4-bit kernels are written for CUDA GPUs, so this load path may not be supported on XLA/TPU devices.
  print(f"Process {index}: Loading model {model_id} with 4-bit quantization and auto device map...")
  model = AutoModelForCausalLM.from_pretrained(
      model_id,
      quantization_config=bnb_config, # Apply the quantization configuration
      device_map="auto", # Let accelerate manage device placement
  )

  # Enable gradient checkpointing to save memory.
  # This trades a small amount of compute time for significant memory reduction.
  model.gradient_checkpointing_enable()


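  # Define the custom compute_metrics function for evaluation
  def compute_metrics(eval_preds):
    preds = eval_preds.predictions
    labels = eval_preds.label_ids

    # Trainer passes the model's raw outputs; reduce logits of shape
    # (batch, seq_len, vocab) to predicted token IDs
    if isinstance(preds, tuple):
      preds = preds[0]
    if preds.ndim == 3:
      preds = preds.argmax(axis=-1)

    # The logits at position i predict token i + 1, so align predictions with the next-token labels
    preds = preds[:, :-1].copy()
    labels = labels[:, 1:].copy() # Avoid modifying original array

    # -100 marks positions the loss ignores (e.g., padding); replace them with the
    # padding token ID in both arrays so they decode identically
    ignored = labels == -100
    labels[ignored] = tokenizer.pad_token_id
    preds[ignored] = tokenizer.pad_token_id

    # Decode predictions and labels
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Calculate exact match accuracy
    exact_matches = 0
    for pred, label in zip(decoded_preds, decoded_labels):
      if pred.strip() == label.strip(): # Strip whitespace for robust comparison
        exact_matches += 1

    return {"exact_match_accuracy": exact_matches / len(decoded_preds)}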

  # Define TrainingArguments
  training_args = TrainingArguments(
      output_dir="./mistral-7b-text-to-sql",
      num_train_epochs=3,
      per_device_train_batch_size=1, # Reduced to 1 to save memory
      per_device_eval_batch_size=4, # Keep eval batch size as it has less memory impact
      gradient_accumulation_steps=16, # Increased to 16 to compensate for smaller train batch size
      learning_rate=2e-5,
      weight_decay=0.01,
      logging_steps=10, # Log every 10 steps
      save_strategy="epoch", # Save model at the end of each epoch
      eval_strategy="epoch", # Evaluate at the end of each epoch
      load_best_model_at_end=True, # Load the best model based on metric_for_best_model
      metric_for_best_model="eval_loss", # Use validation loss to determine the best model
      report_to="none", # Do not report to external services like Weights & Biases
      gradient_checkpointing=True, # Enabled gradient checkpointing in TrainingArguments
      # No XLA-specific arguments are set here; Trainer selects the XLA device itself when torch_xla is available.
  )

  # Trainer Initialization
  trainer = Trainer(
      model=model,
      args=training_args,
      train_dataset=tokenized_train_dataset,
      eval_dataset=tokenized_eval_dataset,
      tokenizer=tokenizer,
      compute_metrics=compute_metrics, # Pass the custom metrics function
  )

  # Start Training
  print(f"Process {index}: Starting fine-tuning on device {device}...")
  trainer.train()

  # Save the model from the master process only
  if xm.is_master_ordinal():
    save_path = "./mistral-7b-text-to-sql-finetuned"
    trainer.save_model(save_path)
    print(f"Fine-tuning complete. Model saved to {save_path}.")

# --- Main execution block ---
if __name__ == '__main__':
  # Get the Hugging Face token from Colab secrets.
  # Ensure you have a secret named 'HF_TOKEN' in Google Colab.
  hf_token = userdata.get('HF_TOKEN')
  if not hf_token:
    # Stop here: fine_tune_on_tpu calls login(token=hf_token) in every process, so a token is required
    raise SystemExit("Hugging Face token not found. Please set 'HF_TOKEN' in Colab secrets.")

  model_id = "mistralai/Mistral-7B-Instruct-v0.1"

  # The dataset preparation steps are now integrated before spawning processes.
  # The `train_file` and `test_file` variables are defined above.

  print("Launching fine-tuning processes on TPU cores...")
  # xmp.spawn will call fine_tune_on_tpu function on each TPU core.
  # nprocs=None means use all available TPU cores.
  xmp.spawn(
      fine_tune_on_tpu,
      args=(model_id, train_file, test_file, hf_token),
      nprocs=None,
      start_method="fork" # "fork" is often preferred for Colab/TPU environments
  )


Loading dataset 'b-mc2/sql-create-context' from Hugging Face Hub...
Original dataset size: 78577 samples. Selecting 12500 samples.
Converting dataset to OAI messages format...


Map:   0%|          | 0/12500 [00:00<?, ? examples/s]

Splitting dataset into training and test sets...
Train dataset size: 10000
Test dataset size: 2500
Saving training dataset to train_dataset.json...


Creating json from Arrow format:   0%|          | 0/10 [00:00<?, ?ba/s]

Saving evaluation dataset to test_dataset.json...


Creating json from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Launching fine-tuning processes on TPU cores...
Process 2 on device: xla:0
Process 6 on device: xla:0
Process 4 on device: xla:0
Process 0 on device: xla:0
Process 3 on device: xla:1

Process 6: Loading training dataset from train_dataset.json
Process 2: Loading training dataset from train_dataset.json
Process 0: Loading training dataset from train_dataset.json
Process 4: Loading training dataset from train_dataset.json
Process 3: Loading training dataset from train_dataset.json


Generating train split: 0 examples [00:00, ? examples/s]

Process 6: Loading evaluation dataset from test_dataset.json
Process 2: Loading evaluation dataset from test_dataset.json
Process 3: Loading evaluation dataset from test_dataset.json
Process 0: Loading evaluation dataset from test_dataset.json
Process 4: Loading evaluation dataset from test_dataset.json


Generating train split: 0 examples [00:00, ? examples/s]

Process 6: Tokenizing training dataset...
Process 2: Tokenizing training dataset...
Process 4: Tokenizing training dataset...
Process 3: Tokenizing training dataset...

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Process 0: Tokenizing training dataset...



Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Process 0: Tokenizing evaluation dataset...
Process 6: Tokenizing evaluation dataset...


Map:   0%|          | 0/2500 [00:00<?, ? examples/s]

Map:   0%|          | 0/2500 [00:00<?, ? examples/s]

Process 4: Tokenizing evaluation dataset...


Map:   0%|          | 0/2500 [00:00<?, ? examples/s]

Process 6: Loading model mistralai/Mistral-7B-Instruct-v0.1 with 4-bit quantization and auto device map...
Process 4: Loading model mistralai/Mistral-7B-Instruct-v0.1 with 4-bit quantization and auto device map...
Process 0: Loading model mistralai/Mistral-7B-Instruct-v0.1 with 4-bit quantization and auto device map...




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Process 2: Tokenizing evaluation dataset...
Process 2: Loading model mistralai/Mistral-7B-Instruct-v0.1 with 4-bit quantization and auto device map...




Process 3: Tokenizing evaluation dataset...
Process 3: Loading model mistralai/Mistral-7B-Instruct-v0.1 with 4-bit quantization and auto device map...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
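
The exact-match metric in the fine-tuning cell depends on how predictions are aligned with labels. The sketch below illustrates that alignment on token IDs with NumPy. It is a simplified stand-in, not the notebook's `compute_metrics`: it compares IDs instead of decoded strings, and `exact_match_from_logits` is a hypothetical helper name.

```python
import numpy as np

IGNORE_INDEX = -100  # label value that the loss skips

def exact_match_from_logits(logits, labels):
    """Fraction of sequences whose predicted next tokens match every non-ignored label."""
    # The prediction made at position i targets the token at position i + 1
    preds = logits.argmax(axis=-1)[:, :-1]
    targets = labels[:, 1:]
    keep = targets != IGNORE_INDEX
    # A sequence counts as a match only if all of its kept positions agree
    per_sequence = np.all((preds == targets) | ~keep, axis=1)
    return float(per_sequence.mean())

# Toy batch: 2 sequences, 4 positions, vocabulary of 5 tokens
labels = np.array([[1, 2, 3, IGNORE_INDEX],
                   [1, 4, 0, 2]])
logits = np.zeros((2, 4, 5))
# Sequence 0 predicts 2 then 3; its third prediction falls on an ignored label
logits[0, 0, 2] = 1.0
logits[0, 1, 3] = 1.0
# Sequence 1 predicts 4, then 1 (the label is 0), then 2
logits[1, 0, 4] = 1.0
logits[1, 1, 1] = 1.0
logits[1, 2, 2] = 1.0

score = exact_match_from_logits(logits, labels)
print(score)
```

In this toy batch the first sequence matches on all kept positions and the second has one wrong token, so half of the sequences count as exact matches.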