<a href="https://colab.research.google.com/github/frank-morales2020/Cloud_curious/blob/master/Mistral_TPU_Fine_Tuning_for_Text_to_SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q datasets
!pip install -q evaluate

In [None]:
import torch
import os
import evaluate
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset
from huggingface_hub import login
from google.colab import userdata

# Import the necessary libraries for TPU
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# --- Dataset Preparation ---
# System prompt shared by every training conversation; `{schema}` is replaced
# with the sample's schema text.
# NOTE(review): "an text" is a grammatical slip, but the wording is kept
# verbatim so any inference-time prompt built from the same text keeps
# matching what the model was trained on.
system_message = (
    "You are an text to SQL query translator. Users will ask you questions in English "
    "and you will generate a SQL query based on the provided SCHEMA.\n"
    "SCHEMA:\n"
    "{schema}"
)


# Map one dataset sample to the OpenAI (OAI) chat "messages" format.
def create_conversation(sample):
  """Turn one sql-create-context sample into a three-turn chat conversation.

  Args:
    sample: mapping with "context" (schema text inserted into the system
      prompt), "question" (user turn) and "answer" (assistant turn) keys.

  Returns:
    {"messages": [system_turn, user_turn, assistant_turn]}, each turn a dict
    with "role" and "content" keys.
  """
  system_turn = {"role": "system", "content": system_message.format(schema=sample["context"])}
  user_turn = {"role": "user", "content": sample["question"]}
  assistant_turn = {"role": "assistant", "content": sample["answer"]}
  return {"messages": [system_turn, user_turn, assistant_turn]}

# Load the base dataset ("b-mc2/sql-create-context") from the Hugging Face Hub,
# take a reproducible subset, convert it to chat format, split it and save it.
DATASET_ID = "b-mc2/sql-create-context"
NUM_SAMPLES = 12500      # subset size used for faster experimentation
NUM_TEST_SAMPLES = 2500  # portion of the subset held out for evaluation
SEED = 42                # fixed seed so shuffling and splitting are reproducible

print(f"Loading dataset '{DATASET_ID}' from Hugging Face Hub...")
dataset = load_dataset(DATASET_ID, split="train")

# Shuffle and keep a subset. Clamp to the dataset size so select() never
# indexes past the end; raise NUM_SAMPLES (or drop this step) to use more data.
num_samples = min(NUM_SAMPLES, len(dataset))
print(f"Original dataset size: {len(dataset)} samples. Selecting {num_samples} samples.")
dataset = dataset.shuffle(seed=SEED).select(range(num_samples))

# Convert each sample into the chat "messages" format consumed by the
# tokenizer's chat template; drop every original column.
print("Converting dataset to OAI messages format...")
dataset = dataset.map(create_conversation, remove_columns=dataset.column_names, batched=False)

# Split into training and test sets (10,000 / 2,500 with the defaults above).
# A fractional test_size keeps the same 20% ratio if the subset was clamped.
print("Splitting dataset into training and test sets...")
dataset_splits = dataset.train_test_split(test_size=NUM_TEST_SAMPLES / NUM_SAMPLES, seed=SEED)
train_dataset_hf = dataset_splits["train"]
eval_dataset_hf = dataset_splits["test"]

print(f"Train dataset size: {len(train_dataset_hf)}")
print(f"Test dataset size: {len(eval_dataset_hf)}")

# Save the processed splits; the fine-tuning processes reload them from disk.
# With orient="records", datasets writes JSON Lines (one record per line).
train_file = "train_dataset.json"
test_file = "test_dataset.json"
print(f"Saving training dataset to {train_file}...")
train_dataset_hf.to_json(train_file, orient="records")
print(f"Saving evaluation dataset to {test_file}...")
eval_dataset_hf.to_json(test_file, orient="records")

# --- Fine-tuning Function (executed on each TPU core) ---
def fine_tune_on_tpu(index, model_id, train_file, test_file, hf_token):
  # Authenticate with Hugging Face on each process
  login(token=hf_token)

  # Get the specific TPU device for this process
  device = xm.xla_device()
  print(f"Process {index} on device: {device}")

  # Load model and tokenizer
  tokenizer = AutoTokenizer.from_pretrained(model_id)
  # Set padding token and side for the tokenizer
  # Using eos_token as pad_token is a common practice for GPT-like models
  tokenizer.pad_token = tokenizer.eos_token
  tokenizer.padding_side = "right" # Important for Causal LMs

  # Load the datasets from the JSON files saved previously
  # These are now the pre-processed datasets in conversation format
  print(f"Process {index}: Loading training dataset from {train_file}")
  train_dataset = load_dataset("json", data_files=train_file, split="train")
  print(f"Process {index}: Loading evaluation dataset from {test_file}")
  eval_dataset = load_dataset("json", data_files=test_file, split="train") # Note: split="train" often means the primary split in a single-file JSON load

  # Tokenize the conversations
  # The `create_conversation` step already formatted the data, now we just tokenize it
  def tokenize_conversation(sample):
    # apply_chat_template takes the list of messages and converts them to a single string
    # with special tokens indicating roles (e.g., <s>[INST]...[/INST] )
    prompt = tokenizer.apply_chat_template(
        sample["messages"],
        tokenize=False,  # We want the string output first
        add_generation_prompt=False # No need to add generation prompt here for fine-tuning
    )
    # Then tokenize the string
    return tokenizer(prompt, truncation=True, padding="max_length", max_length=512) # Added max_length for consistency

  print(f"Process {index}: Tokenizing training dataset...")
  tokenized_train_dataset = train_dataset.map(tokenize_conversation, batched=True, remove_columns=["messages"])
  print(f"Process {index}: Tokenizing evaluation dataset...")
  tokenized_eval_dataset = eval_dataset.map(tokenize_conversation, batched=True, remove_columns=["messages"])

  # Load the model directly to the TPU device
  print(f"Process {index}: Loading model {model_id} to device {device}...")
  model = AutoModelForCausalLM.from_pretrained(
      model_id,
      torch_dtype=torch.bfloat16, # Use bfloat16 for TPU efficiency
  ).to(device)

  # Define the custom compute_metrics function for evaluation
  def compute_metrics(eval_preds):
    preds = eval_preds.predictions
    labels = eval_preds.label_ids

    # Replace -100 in labels, which are padding tokens, with the actual padding token ID
    labels = labels.copy() # Avoid modifying original array
    labels[labels == -100] = tokenizer.pad_token_id

    # Decode predictions and labels
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Calculate exact match accuracy
    exact_match_accuracy = 0
    for pred, label in zip(decoded_preds, decoded_labels):
      if pred.strip() == label.strip(): # Strip whitespace for robust comparison
        exact_match_accuracy += 1

    return {"exact_match_accuracy": exact_match_accuracy / len(decoded_preds)}

  # Define TrainingArguments
  training_args = TrainingArguments(
      output_dir="./mistral-7b-text-to-sql",
      num_train_epochs=3,
      per_device_train_batch_size=4,
      per_device_eval_batch_size=4,
      gradient_accumulation_steps=4, # Accumulate gradients over 4 steps
      learning_rate=2e-5,
      weight_decay=0.01,
      logging_steps=10, # Log every 10 steps
      save_strategy="epoch", # Save model at the end of each epoch
      eval_strategy="epoch", # Evaluate at the end of each epoch
      load_best_model_at_end=True, # Load the best model based on metric_for_best_model
      metric_for_best_model="eval_loss", # Use validation loss to determine the best model
      report_to="none", # Do not report to external services like Weights & Biases
      # Additional arguments specific to XLA (TPU) training if needed:
      # xla_device_type="TPU",
      # sharded_ddp="simple",
  )

  # Trainer Initialization
  trainer = Trainer(
      model=model,
      args=training_args,
      train_dataset=tokenized_train_dataset,
      eval_dataset=tokenized_eval_dataset,
      tokenizer=tokenizer,
      compute_metrics=compute_metrics, # Pass the custom metrics function
  )

  # Start Training
  print(f"Process {index}: Starting fine-tuning on device {device}...")
  trainer.train()

  # Save the model from the master process only
  if xm.is_master_ordinal():
    save_path = "./mistral-7b-text-to-sql-finetuned"
    trainer.save_model(save_path)
    print(f"Fine-tuning complete. Model saved to {save_path}.")

# --- Main execution block ---
if __name__ == '__main__':
  # Read the Hugging Face token from Colab secrets (a secret named 'HF_TOKEN').
  # NOTE: userdata.get can raise (rather than return None) when the secret is
  # missing or the notebook has not been granted access; treat both as "no token".
  try:
    hf_token = userdata.get('HF_TOKEN')
  except (userdata.SecretNotFoundError, userdata.NotebookAccessError):
    hf_token = None
  if not hf_token:
    # Raise instead of calling exit(): in a notebook kernel exit() is not a
    # reliable way to stop the cell, and the workers would then run without a token.
    raise SystemExit("Error: Hugging Face token not found. Please set 'HF_TOKEN' in Colab secrets.")

  model_id = "mistralai/Mistral-7B-Instruct-v0.1"

  # `train_file` and `test_file` were written by the dataset-preparation step,
  # which runs before the processes are spawned.
  print("Launching fine-tuning processes on TPU cores...")
  # xmp.spawn calls fine_tune_on_tpu(index, *args) in each process;
  # nprocs=None means use all available TPU devices.
  xmp.spawn(
      fine_tune_on_tpu,
      args=(model_id, train_file, test_file, hf_token),
      nprocs=None,
      start_method="fork",  # "fork" lets children inherit the notebook's state
  )

Loading dataset 'b-mc2/sql-create-context' from Hugging Face Hub...


README.md: 0.00B [00:00, ?B/s]

sql_create_context_v4.json:   0%|          | 0.00/21.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/78577 [00:00<?, ? examples/s]

Original dataset size: 78577 samples. Selecting 12500 samples.
Converting dataset to OAI messages format...


Map:   0%|          | 0/12500 [00:00<?, ? examples/s]

Splitting dataset into training and test sets...
Train dataset size: 10000
Test dataset size: 2500
Saving training dataset to train_dataset.json...


Creating json from Arrow format:   0%|          | 0/10 [00:00<?, ?ba/s]

Saving evaluation dataset to test_dataset.json...


Creating json from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Launching fine-tuning processes on TPU cores...
Process 7 on device: xla:1
Process 5 on device: xla:1Process 4 on device: xla:0
Process 0 on device: xla:0

Process 2 on device: xla:0Process 3 on device: xla:1



tokenizer_config.json:   0%|          | 0.00/2.10k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

Process 4: Loading training dataset from train_dataset.jsonProcess 3: Loading training dataset from train_dataset.json
Process 0: Loading training dataset from train_dataset.json
Process 7: Loading training dataset from train_dataset.json

Process 2: Loading training dataset from train_dataset.jsonProcess 5: Loading training dataset from train_dataset.json



Generating train split: 0 examples [00:00, ? examples/s]

Process 0: Loading evaluation dataset from test_dataset.json
Process 7: Loading evaluation dataset from test_dataset.json
Process 4: Loading evaluation dataset from test_dataset.jsonProcess 3: Loading evaluation dataset from test_dataset.json
Process 5: Loading evaluation dataset from test_dataset.json
Process 2: Loading evaluation dataset from test_dataset.json



Generating train split: 0 examples [00:00, ? examples/s]

Process 0: Tokenizing training dataset...
Process 7: Tokenizing training dataset...
Process 4: Tokenizing training dataset...
Process 2: Tokenizing training dataset...


Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Process 5: Tokenizing training dataset...


Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Process 3: Tokenizing training dataset...


Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Process 7: Tokenizing evaluation dataset...
Process 0: Tokenizing evaluation dataset...


Map:   0%|          | 0/2500 [00:00<?, ? examples/s]

Map:   0%|          | 0/2500 [00:00<?, ? examples/s]

Process 7: Loading model mistralai/Mistral-7B-Instruct-v0.1 to device xla:1...
Process 0: Loading model mistralai/Mistral-7B-Instruct-v0.1 to device xla:0...


config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

Process 2: Tokenizing evaluation dataset...
Process 2: Loading model mistralai/Mistral-7B-Instruct-v0.1 to device xla:0...
Process 3: Tokenizing evaluation dataset...
Process 3: Loading model mistralai/Mistral-7B-Instruct-v0.1 to device xla:1...
Process 4: Tokenizing evaluation dataset...
Process 4: Loading model mistralai/Mistral-7B-Instruct-v0.1 to device xla:0...
Process 5: Tokenizing evaluation dataset...


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Process 5: Loading model mistralai/Mistral-7B-Instruct-v0.1 to device xla:1...


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]