<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/MISTRAL_TPU_FT_T2SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets -q
!pip install evaluate -q

In [None]:
import torch
import os
import evaluate
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset
from huggingface_hub import login
from google.colab import userdata

# Import the necessary libraries for TPU
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# This function will be executed on each TPU core.
def fine_tune_on_tpu(index, model_id, train_file, test_file, hf_token):
    """Fine-tune a causal LM on one TPU core; launched once per core by ``xmp.spawn``.

    Args:
        index: Process ordinal supplied by ``xmp.spawn``; only used for logging here.
        model_id: Hugging Face Hub id used to load both the tokenizer and the model.
        train_file: JSON file whose records carry a ``"messages"`` chat list; training data.
        test_file: Same format as ``train_file``; evaluation data.
        hf_token: Hugging Face access token passed to ``huggingface_hub.login``.

    Every spawned process must run this function to completion: ``Trainer``
    performs cross-process synchronization (including inside ``save_model``)
    that all TPU processes have to reach.
    """
    # Authenticate with Hugging Face on each process
    login(token=hf_token)

    # Get the specific TPU device for this process
    device = xm.xla_device()
    print(f"Process {index} on device: {device}")

    # Load the tokenizer; padding reuses the EOS token.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # Load the datasets (a single JSON file is exposed as the "train" split).
    train_dataset = load_dataset("json", data_files=train_file, split="train")
    eval_dataset = load_dataset("json", data_files=test_file, split="train")

    def tokenize_conversation(sample):
        """Render one chat sample with the chat template, tokenize it, and add causal-LM labels."""
        # Training conversations are presumably complete (they already contain
        # the assistant answer) — TODO confirm against the dataset — so no
        # generation prompt is appended after them.
        prompt = tokenizer.apply_chat_template(
            sample["messages"], tokenize=False, add_generation_prompt=False
        )
        # The rendered template already contains its special tokens (e.g. BOS);
        # letting the tokenizer add them again would duplicate them.
        tokenized = tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=512,
            add_special_tokens=False,
        )
        # Without "labels" the model returns no loss and Trainer.train() fails.
        # Padding is masked to -100 via the attention mask rather than by token
        # id: pad == EOS here, so an id-based mask would also hide the real EOS.
        # Note: the loss therefore covers the whole conversation (prompt + answer).
        tokenized["labels"] = [
            token_id if attended else -100
            for token_id, attended in zip(
                tokenized["input_ids"], tokenized["attention_mask"]
            )
        ]
        return tokenized

    tokenized_train_dataset = train_dataset.map(tokenize_conversation, batched=False)
    tokenized_eval_dataset = eval_dataset.map(tokenize_conversation, batched=False)

    # Load the model directly to the TPU device
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16
    ).to(device)

    def preprocess_logits_for_metrics(logits, labels):
        """Reduce logits to predicted token ids before Trainer accumulates them."""
        # Only the argmax ids are needed by compute_metrics; keeping the full
        # vocabulary-sized logits for the whole eval set would exhaust host memory.
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)

    def compute_metrics(eval_preds):
        """Return teacher-forced exact-match accuracy over the evaluation set.

        ``eval_preds.predictions`` holds token ids (see
        ``preprocess_logits_for_metrics``); ``eval_preds.label_ids`` holds the
        labels, with -100 on ignored (padding) positions.
        """
        predictions = eval_preds.predictions
        labels = eval_preds.label_ids

        # Causal LM: the prediction at position t is for token t + 1.
        predictions = predictions[:, :-1]
        labels = labels[:, 1:]
        valid = labels != -100

        exact_matches = 0
        for pred_row, label_row, keep in zip(predictions, labels, valid):
            # Compare only positions that carry a real label.
            pred_text = tokenizer.decode(pred_row[keep], skip_special_tokens=True)
            label_text = tokenizer.decode(label_row[keep], skip_special_tokens=True)
            if pred_text.strip() == label_text.strip():
                exact_matches += 1

        total = len(labels)
        return {"exact_match_accuracy": exact_matches / total if total else 0.0}

    # Define TrainingArguments
    training_args = TrainingArguments(
        output_dir="./mistral-7b-text-to-sql",
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        weight_decay=0.01,
        logging_steps=10,
        save_strategy="epoch",
        eval_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss", # We still use loss for best model
        report_to="none"
    )

    # Trainer Initialization
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    # Start Training
    print(f"Starting fine-tuning on device {device}...")
    trainer.train()

    # save_model must be called on EVERY process under torch_xla: it
    # synchronizes the processes and only the master actually writes files.
    # Guarding the call with is_master_ordinal() would leave the master
    # waiting forever for the other processes.
    trainer.save_model("./mistral-7b-text-to-sql-finetuned")
    if xm.is_master_ordinal():
        print("Fine-tuning complete. Model saved.")

if __name__ == '__main__':
    # Run configuration: base model and the chat-formatted JSON datasets.
    model_id = "mistralai/Mistral-7B-Instruct-v0.1"
    train_file, test_file = "train_dataset.json", "test_dataset.json"

    # Hugging Face access token, read from Colab's secret store.
    token = userdata.get('HF_TOKEN')

    print("Launching fine-tuning processes on TPU cores...")
    # nprocs=None lets torch_xla use all available devices; "fork" is used
    # instead of spawning fresh interpreters (as required from a notebook).
    spawn_args = (model_id, train_file, test_file, token)
    xmp.spawn(fine_tune_on_tpu, args=spawn_args, nprocs=None, start_method="fork")

Launching fine-tuning processes on TPU cores...
Process 6 on device: xla:0Process 2 on device: xla:0

Process 4 on device: xla:0
Process 7 on device: xla:1
Process 0 on device: xla:0
Process 1 on device: xla:1
Process 3 on device: xla:1
Process 5 on device: xla:1
