<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/Training_Smart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install/upgrade the fine-tuning stack (datasets, transformers, accelerate, peft,
# trl, bitsandbytes, sentencepiece) plus `interpret`.
# `colab-env` is imported in the next cell; presumably it loads secrets such as
# HUGGINGFACE_ACCESS_TOKEN_WRITE into os.environ -- confirm.
!pip install -q -U datasets transformers accelerate peft trl bitsandbytes sentencepiece interpret
!pip install colab-env --quiet

In [None]:
import colab_env
import os

# Hub write token read from the environment; None when the variable is unset.
access_token_write = os.environ.get("HUGGINGFACE_ACCESS_TOKEN_WRITE")

In [None]:
import sqlite3
import torch
from datasets import load_dataset
from peft import get_peft_model, prepare_model_for_kbit_training
from interpret.glassbox import ExplainableBoostingClassifier  # For the EBM
from sklearn.feature_extraction.text import TfidfVectorizer  # For EBM
# Add necessary libraries for RLHF (e.g., trl, stable-baselines3)
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

In [None]:
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig,
)
from torch.nn import Module, Linear, ReLU, MSELoss
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
from huggingface_hub import login

try:
    # Removed from recent PEFT releases (superseded by prepare_model_for_kbit_training).
    # The pip cell upgrades PEFT with -U, so an unconditional import would raise
    # ImportError and abort this whole cell.
    from peft import prepare_model_for_int8_training
except ImportError:
    prepare_model_for_int8_training = None

# 0. Login to Hugging Face Hub with the write token read from the environment,
#    also registering it as a git credential (add_to_git_credential=True).
login(
    token=access_token_write,
    add_to_git_credential=True,
)

# 1. Load Mistral model and tokenizer
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import setup_chat_format

# Hugging Face model id
model_id = "mistralai/Mistral-7B-Instruct-v0.1"

# 4-bit NF4 quantization with double quantization; matmuls are computed in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load model and tokenizer.
# "sdpa" (PyTorch scaled-dot-product attention) instead of "flash_attention_2":
# the flash-attn package is not installed by the pip cell, and requesting
# flash_attention_2 without it fails at load time.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
tokenizer.padding_side = "right"  # to prevent warnings

# Reuse the out-of-vocabulary token (unk_token) as the padding token.
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.unk_token_id

# setup_chat_format(model, tokenizer) is intentionally NOT called:
# - this checkpoint is already an instruction-tuned model (the original note was
#   "remove if you start from a fine-tuned model");
# - the prompts built during preprocessing are plain text and never use a chat template;
# - recent TRL versions refuse to overwrite a tokenizer's existing chat template,
#   and the ChatML tokens it adds would get resized, never-trained embeddings.

# 2. Load the Spider text-to-SQL dataset (all splits) from the Hugging Face Hub.
# NOTE(review): newer `datasets` versions may resolve this id to "xlangai/spider" — confirm.
spider_dataset = load_dataset(path="spider")

# 3. Preprocess the data
def preprocess_function(examples, max_length=512):
    """Tokenize a batch of Spider rows into causal-LM training examples.

    For a decoder-only model, the prompt and the target SQL must be in the same
    sequence. ``input_ids`` is ``prompt + query + eos``. ``labels`` repeats it
    with the prompt and the padding set to -100, so the loss covers only the SQL
    tokens. (Tokenizing question and query separately produced labels that were
    not aligned with the inputs.) Each row is truncated/padded to ``max_length``.

    Args:
        examples: batched columns containing at least "question" and "query" lists.
        max_length: fixed sequence length after truncation/padding.

    Returns:
        dict with "input_ids", "attention_mask" and "labels" lists of length
        ``max_length`` per row.
    """
    pad_id = tokenizer.pad_token_id
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for question, query in zip(examples["question"], examples["query"]):
        # The newline separates the question from the SQL the model must continue with.
        prompt_ids = tokenizer(f"Translate to SQL: {question}\n")["input_ids"]
        # EOS teaches the model where the SQL ends.
        target_ids = tokenizer(query, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id]

        input_ids = (prompt_ids + target_ids)[:max_length]
        labels = ([-100] * len(prompt_ids) + target_ids)[:max_length]
        n_pad = max_length - len(input_ids)

        batch["input_ids"].append(input_ids + [pad_id] * n_pad)
        batch["attention_mask"].append([1] * len(input_ids) + [0] * n_pad)
        # Padding never contributes to the loss.
        batch["labels"].append(labels + [-100] * n_pad)
    return batch

spider_dataset = spider_dataset.map(
    function=preprocess_function,
    batched=True,
    # Drop the raw Spider columns so only the tokenized fields remain.
    remove_columns=spider_dataset["train"].column_names,
)

# 4. Split the dataset: small train/validation slices for a quick proof of concept.
train_dataset, eval_dataset = (
    spider_dataset[split_name].select(range(row_count))
    for split_name, row_count in (("train", 1000), ("validation", 100))
)

# 5. Define the Energy-Based Model (EBM) component
class SimpleEBM(Module):
    """Two-layer MLP mapping each feature vector to a single energy value.

    Input ``(..., input_dim)`` -> output ``(...)``: one energy per leading
    position, e.g. one per sequence for a pooled ``(batch, input_dim)`` input.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear1 = Linear(input_dim, hidden_dim)
        self.linear2 = Linear(hidden_dim, 1)  # Output a single energy value
        self.relu = ReLU()

    def forward(self, x):
        x = self.relu(self.linear1(x))
        energy = self.linear2(x)
        # Drop only the trailing size-1 energy axis. A bare squeeze() would also
        # collapse a batch of one into a 0-dim tensor.
        return energy.squeeze(-1)

# 6. Combine Mistral and EBM
class HybridT2SQL(Module):
    """Causal LM plus an energy head that scores each input sequence.

    The energy head receives the LM's final-layer hidden states, mean-pooled
    over non-padding tokens. Its input size must therefore equal the LM's
    ``config.hidden_size``.
    """

    def __init__(self, mistral_model, ebm):
        super().__init__()
        self.mistral_model = mistral_model
        self.ebm = ebm

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the LM and the energy head.

        Args:
            input_ids: token ids (presumably ``(batch, seq)``).
            attention_mask: 1 for real tokens, 0 for padding; None means all tokens are real.
            labels: LM targets (-100 = ignored). When given, they are passed to the
                LM so it computes its own cross-entropy loss.

        Returns:
            With ``labels``: scalar loss = MSE(energy, 0) + LM loss.
            Without ``labels``: the EBM's energy for each sequence.
        """
        # Pass labels so the LM returns a loss (otherwise `.loss` is None). Request
        # hidden states as well: logits are vocabulary-sized and do not match the
        # EBM's hidden_size input.
        mistral_output = self.mistral_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
        )
        hidden_states = mistral_output.hidden_states[-1]

        if attention_mask is None:
            attention_mask = torch.ones(hidden_states.shape[:2], device=hidden_states.device)
        mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype).unsqueeze(-1)
        # Mean-pool over real tokens so padding does not influence the energy.
        pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

        # The EBM is a separate module whose dtype/device can differ from the
        # (quantized, possibly sharded) LM output.
        ebm_param = next(self.ebm.parameters(), None)
        if ebm_param is not None:
            pooled = pooled.to(device=ebm_param.device, dtype=ebm_param.dtype)
        energy = self.ebm(pooled)

        if labels is None:
            return energy

        # Demonstration objective: push the energy of training sequences towards 0.
        target_energy = torch.zeros_like(energy)
        loss = MSELoss()(energy, target_energy)
        if mistral_output.loss is not None:
            loss = loss + mistral_output.loss.to(loss.device)  # combine with Mistral's loss
        return loss

# 7. Initialize the model
# `model` currently holds the 4-bit Mistral causal LM loaded in step 1; the
# name `mistral_model` did not exist until it is bound here.
mistral_model = model

# 7.5 Prepare for k-bit training and apply LoRA to the Mistral backbone.
# The checkpoint is loaded in 4-bit (bnb_config), so prepare_model_for_kbit_training
# is the right helper; prepare_model_for_int8_training is removed from recent PEFT.
# PEFT is applied to the transformers model itself, not the custom wrapper,
# because it needs the model's `.config` and a causal-LM task type.
mistral_model = prepare_model_for_kbit_training(mistral_model)
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,  # decoder-only model, not SEQ_2_SEQ_LM
)
mistral_model = get_peft_model(mistral_model, lora_config)
mistral_model.print_trainable_parameters()  # counts LoRA weights only, not the EBM

# Energy head on the final hidden states. It stays in default float32 and is
# moved next to the model so its parameters train on the same device.
input_dim = mistral_model.config.hidden_size
hidden_dim = 256
ebm = SimpleEBM(input_dim, hidden_dim).to(mistral_model.device)
model = HybridT2SQL(mistral_model, ebm)

# 8. Define TrainingArguments with PEFT
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    # The model computes in bfloat16 (bnb_4bit_compute_dtype / torch_dtype). fp16
    # AMP's GradScaler cannot unscale bf16 gradients.
    bf16=True,
    num_train_epochs=1,
    logging_steps=10,
    # `evaluation_strategy` was renamed to `eval_strategy` and the old name removed.
    eval_strategy="epoch",
    # load_best_model_at_end requires save and eval strategies to match.
    save_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=True,
    # No compute_metrics is defined, so only the eval loss exists to select on.
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    prediction_loss_only=True,
    push_to_hub=True,
    hub_model_id="frankmorales2020/mistral-7b-ebm-rlhf-finetuned-t2sql",  # Your Hub model ID
    # Add more arguments as needed for pushing to Hub (e.g., hub_token)
)

# 9. Define Reward Function/Class (Example)
class SQLRewardFunction:
    """Score generated SQL: 1.0 if SQLite can parse it, 0.0 otherwise.

    Each decoded query is compiled with ``EXPLAIN`` against an empty in-memory
    SQLite database, so nothing is executed. No schema is loaded, so
    "no such table/column" errors are accepted. Only syntax errors, empty
    output, or output containing several statements score 0.0.
    """

    # Substrings of SQLite error messages that signal a parse failure rather than
    # a missing-schema error.
    _SYNTAX_ERROR_MARKERS = ("syntax error", "incomplete input", "unrecognized token")

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def _is_valid_sql(self, conn, query):
        """Return True if ``query`` is a single statement SQLite can parse."""
        if not query.strip():
            return False
        try:
            # EXPLAIN only compiles the statement against a throwaway in-memory DB;
            # the model-generated text is never executed.
            conn.execute(f"EXPLAIN {query}")
        except sqlite3.OperationalError as err:
            message = str(err)
            return not any(marker in message for marker in self._SYNTAX_ERROR_MARKERS)
        except (sqlite3.Error, sqlite3.Warning, ValueError):
            # e.g. several statements in one sample, or NUL/unencodable characters.
            return False
        return True

    def __call__(self, samples):
        """Decode token-id sequences and return a float tensor of 0/1 rewards."""
        decoded_queries = [
            self.tokenizer.decode(sample, skip_special_tokens=True)
            for sample in samples
        ]
        conn = sqlite3.connect(":memory:")
        try:
            rewards = [
                1.0 if self._is_valid_sql(conn, query) else 0.0
                for query in decoded_queries
            ]
        finally:
            conn.close()
        return torch.tensor(rewards)

# 10. Initialize the reward function. It scores decoded SQL for a later RL stage
#     (e.g. TRL's PPO); the supervised fine-tuning below does not consume it.
reward_fn = SQLRewardFunction(tokenizer)


# 11. Supervised fine-tuning of the hybrid model.
# SFTTrainer accepts no `reward_model` argument, and the tokenized datasets have
# no "text" column. A plain Trainer on the pre-tokenized data is used instead.
class _HybridTrainer(Trainer):
    """Trainer for HybridT2SQL, whose forward returns a bare loss tensor.

    The stock Trainer expects a dict/tuple model output and would index into
    the 0-dim loss tensor.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        loss = model(**inputs)
        return (loss, {"loss": loss}) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        # Only a loss is available (no logits), so evaluation reports eval_loss only.
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            loss = model(**inputs)
        return loss.detach(), None, None


trainer = _HybridTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# 12. Train the model
trainer.train()