<a href="https://colab.research.google.com/github/frank-morales2020/Cloud_curious/blob/master/FTA_DYNAMIC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers accelerate trl bitsandbytes --quiet

## Universal FineTuningAgent

In [None]:
# Install necessary libraries
!pip install transformers accelerate trl bitsandbytes datasets --quiet

import os

# Set environment variables for debugging
os.environ["WANDB_MODE"] = "offline"
os.environ["WANDB_DISABLED"] = "true"
# A `!export VAR=...` line runs in a short-lived subshell, so the variable
# never reaches this Python process. Setting it on os.environ does, and it is
# done before torch is imported so CUDA sees it when it initialises.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # Enable synchronous CUDA error reporting

# Import necessary modules
from transformers import TrainingArguments
import accelerate

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    AutoConfig,
    AutoModelForSeq2SeqLM,
    Trainer,
    DataCollatorForLanguageModeling,
)

from datasets import load_dataset, DatasetDict

import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import warnings
from trl import SFTTrainer

warnings.filterwarnings("ignore")

# Model/dataset pairs that FineTuningAgent.run() walks through, in order.
# Each entry is a dict with the keys "model_id" and "dataset_name".
#
# NOTE(review): only "anthropic/hh-rlhf" has dedicated handling in
# FineTuningAgent._orient; every other dataset is expected to expose
# "question" and "answer" columns. Several ids below (safety-gym, MineRL,
# rl-chatbot) look like RL environments or placeholders rather than datasets
# that load_dataset() can resolve - confirm before relying on them.
RL_PAIRS = [
    dict(model_id=model_id, dataset_name=dataset_name)
    for model_id, dataset_name in (
        # (model id, dataset id)
        ("google/flan-t5-xl", "anthropic/hh-rlhf"),
        ("bigscience/T0_3B", "openai/safety-gym"),
        ("EleutherAI/gpt-neo-125M", "MineRL/MineRLBasaltFindCave-v0"),
        ("facebook/blenderbot-400M-distill", "stanfordnlp/coqa"),
        ("microsoft/DialoGPT-medium", "huggingface/rl-chatbot"),
    )
]


class FineTuningAgent:
    """
    Fine-tunes Hugging Face language models, one model/dataset pair at a time,
    with the work organised as an OODA loop:

    * Observe - load the model, tokenizer and dataset.
    * Orient  - turn dataset rows into chat-style ``messages``, carve out an
      evaluation split and build the ``TrainingArguments``.
    * Decide  - optionally wrap the model with LoRA adapters.
    * Act     - tokenize the data and build the trainer.

    ``run()`` then calls ``trainer.train()``. The training performed here is
    supervised fine-tuning; no reward model or RL algorithm is involved.

    Recognised ``config`` keys: ``training_args`` (dict forwarded to
    ``TrainingArguments``), ``quantization`` (bool, 4-bit loading), ``lora``
    (bool) and ``test_split_percentage`` (fraction of the train split held
    out for evaluation, or ``None`` to keep the dataset as loaded).
    """

    # Token budgets used when tokenizing model inputs and seq2seq targets.
    MAX_INPUT_TOKENS = 1024
    MAX_TARGET_TOKENS = 512
    # Label id skipped by the loss; used to blank out padding positions.
    IGNORE_INDEX = -100

    def __init__(self, model_id, dataset_name, config):
        """
        Stores the model ID, dataset name and configuration.

        Args:
            model_id: Hugging Face model id, or ``None`` to let ``run()``
                iterate over the module-level ``RL_PAIRS``.
            dataset_name: Hugging Face dataset id, or ``None`` (see above).
            config: dict of settings; see the class docstring for the keys.
        """
        self.model_id = model_id
        self.dataset_name = dataset_name
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Populated by the OODA stages; defined here so that every stage can
        # rely on the attributes existing.
        self.model = None
        self.tokenizer = None
        self.dataset = None
        self.training_args = None
        self.peft_config = None
        self.trainer = None

        # Remember an explicitly requested pair. run() rewrites model_id and
        # dataset_name while looping, so they cannot serve as the record.
        self._explicit_pair = (
            {"model_id": model_id, "dataset_name": dataset_name}
            if model_id and dataset_name
            else None
        )

    def _observe(self):
        """
        Observes the environment by loading the model, tokenizer, and dataset.
        """
        # 1. Load Model and Tokenizer (with quantization if enabled)
        quantized = bool(self.config.get("quantization"))
        quantization_config = None
        if quantized:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float32,
            )

        # Encoder-decoder checkpoints need the seq2seq head, others the causal one.
        model_config = AutoConfig.from_pretrained(self.model_id)
        if model_config.is_encoder_decoder:
            model_class = AutoModelForSeq2SeqLM
        else:
            model_class = AutoModelForCausalLM

        # NOTE(security): trust_remote_code=True lets the model repository run
        # its own Python code locally; only use model ids you trust.
        self.model = model_class.from_pretrained(
            self.model_id,
            quantization_config=quantization_config,
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_id, trust_remote_code=True
        )
        self.model.config.use_cache = False

        # Add padding token if it does not exist
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            self.model.resize_token_embeddings(len(self.tokenizer))

        # bitsandbytes places quantized weights itself and rejects .to(), so
        # only full-precision models are moved explicitly.
        if not quantized:
            self.model.to(self.device)

        print("\n")
        print(f"Model: {self.model}")
        print("\n")

        # 2. Load Dataset (using dataset name from Hugging Face Hub or other sources)
        self.dataset = load_dataset(self.dataset_name)

    @staticmethod
    def _split_hh_transcript(transcript, marker="\n\nAssistant:"):
        """
        Splits a dialogue transcript at its last assistant turn.

        Args:
            transcript: full dialogue text in which assistant turns are
                introduced by ``marker``.
            marker: text that opens an assistant turn.

        Returns:
            ``(prompt, response)``: ``prompt`` is everything up to and
            including the final ``marker``, ``response`` is the text after it.
            Both are stripped. When ``marker`` is absent the whole transcript
            is the prompt and the response is ``""``.
        """
        head, found, tail = transcript.rpartition(marker)
        if not found:
            return transcript.strip(), ""
        return (head + marker).strip(), tail.strip()

    def _example_columns(self):
        """Returns the column names of the split that training will use."""
        if isinstance(self.dataset, DatasetDict):
            if "train" in self.dataset:
                return self.dataset["train"].column_names
            return next(iter(self.dataset.values())).column_names
        return self.dataset.column_names

    def _orient(self):
        """
        Orients the agent by formatting the dataset and preparing training arguments.

        Raises:
            ValueError: if a dataset other than ``anthropic/hh-rlhf`` lacks
                the ``question`` / ``answer`` columns this method relies on.
        """
        system_message = "You are a helpful and harmless AI assistant."

        if self.dataset_name == "anthropic/hh-rlhf":
            # "chosen" and "rejected" each hold a complete dialogue. The
            # training target is the final assistant reply of the *chosen*
            # dialogue, and the turns before it form the prompt. The rejected
            # dialogue is deliberately not used as a target.
            split_transcript = self._split_hh_transcript

            def create_conversation(sample):
                prompt, response = split_transcript(sample["chosen"])
                return {
                    "messages": [
                        {"role": "system", "content": system_message},
                        {"role": "user", "content": prompt},
                        {"role": "assistant", "content": response},
                    ]
                }

        else:
            missing = {"question", "answer"} - set(self._example_columns())
            if missing:
                raise ValueError(
                    f"Dataset '{self.dataset_name}' has no {sorted(missing)} "
                    "column(s); add a dataset-specific conversation builder "
                    "in _orient()."
                )

            def create_conversation(sample):
                return {
                    "messages": [
                        {"role": "system", "content": system_message},
                        {"role": "user", "content": sample["question"]},
                        {"role": "assistant", "content": sample["answer"]},
                    ]
                }

        # map() keeps the existing columns and adds "messages"; nothing has to
        # be copied over by hand or removed at this point.
        self.dataset = self.dataset.map(create_conversation)

        # Conditional split based on test_split_percentage
        test_fraction = self.config.get("test_split_percentage")
        if test_fraction is not None:
            if isinstance(self.dataset, DatasetDict):
                # Only the "train" split is re-split; a test split shipped
                # with the dataset is replaced by the new hold-out set.
                source = self.dataset["train"] if "train" in self.dataset else None
            else:
                source = self.dataset
            if source is not None:
                # Ensure a minimum test size of one example.
                test_size = max(int(len(source) * test_fraction), 1)
                parts = source.train_test_split(test_size=test_size)
                self.dataset = DatasetDict(
                    {"train": parts["train"], "test": parts["test"]}
                )

        # 3. Prepare Training Arguments
        # Work on a copy so the caller's config dict is left untouched.
        training_kwargs = dict(self.config.get("training_args") or {})
        training_kwargs["remove_unused_columns"] = False
        self.training_args = TrainingArguments(**training_kwargs)

        print("\n")
        print(f"Dataset: {self.dataset}")
        print("\n")

    def _decide(self):
        """
        Decides on the fine-tuning strategy: wraps the model with LoRA
        adapters when ``config["lora"]`` is truthy, otherwise leaves it as is.
        """
        # Reset first so a disabled-LoRA run never sees a stale config.
        self.peft_config = None
        if not self.config.get("lora"):
            return

        # 4. PEFT Configuration (LoRA)
        # The k-bit preparation step is only meaningful for quantized weights.
        if self.config.get("quantization"):
            self.model = prepare_model_for_kbit_training(self.model)

        # Preferred adapter targets per architecture family.
        if self.model.config.is_encoder_decoder:
            preferred = ["q", "k", "v", "o", "wi_0", "wi_1", "wo"]
            task_type = "SEQ_2_SEQ_LM"
        elif "mistral" in self.model_id.lower():
            preferred = [
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ]
            task_type = "CAUSAL_LM"
        else:
            preferred = ["c_attn", "c_proj", "w1", "w2"]
            task_type = "CAUSAL_LM"

        target_modules = self._resolve_lora_targets(preferred)
        print(f"LoRA Target Modules: {target_modules}")

        self.peft_config = LoraConfig(
            lora_alpha=16,
            lora_dropout=0.1,
            r=8,
            bias="none",
            target_modules=target_modules,
            task_type=task_type,
        )
        self.model = get_peft_model(self.model, self.peft_config)
        print("\n")
        self.model.print_trainable_parameters()
        print("\n")

    def _resolve_lora_targets(self, preferred):
        """
        Picks LoRA target module names that actually exist in ``self.model``.

        Args:
            preferred: candidate module names (last path component only),
                in priority order.

        Returns:
            The members of ``preferred`` present in the model. If none are
            present, the sorted names of the model's ``torch.nn.Linear``
            layers, with ``lm_head`` left out.

        Raises:
            ValueError: if neither approach yields a module name.
        """
        present = {
            name.rsplit(".", 1)[-1]
            for name, _ in self.model.named_modules()
            if name
        }
        matched = [name for name in preferred if name in present]
        if matched:
            return matched

        # The name table does not fit this architecture; fall back to the
        # Linear layers found in the model. The output head is excluded so
        # the adapters stay on the transformer blocks.
        discovered = {
            full_name.rsplit(".", 1)[-1]
            for full_name in self._get_lora_target_modules(self.model)
        }
        discovered.discard("lm_head")
        if not discovered:
            raise ValueError(
                f"No LoRA target modules found for model '{self.model_id}'."
            )
        return sorted(discovered)

    def _get_lora_target_modules(self, model):
        """
        Parses the model structure and returns the LoRA target modules.

        Args:
            model: a ``torch.nn.Module``.

        Returns:
            The fully qualified names of all ``torch.nn.Linear`` submodules,
            in ``named_modules()`` order.
        """
        return [
            name
            for name, module in model.named_modules()
            if isinstance(module, torch.nn.Linear)
        ]

    def _act(self):
        """
        Acts by tokenizing the dataset and building ``self.trainer``.
        """
        # Local import: the file's import block does not bring this name in.
        from transformers import default_data_collator

        # Preprocess the data; only the tokenized columns are kept.
        self.dataset = self.dataset.map(
            self._preprocess_function,
            batched=True,
            remove_columns=self.dataset["train"].column_names,
        )

        # 6. Initialize Trainer based on dataset type
        if self.dataset_name == "anthropic/hh-rlhf":
            trainer_class = Trainer
        else:
            trainer_class = SFTTrainer

        # * _preprocess_function already pads every row to a fixed length and
        #   builds the labels, so batches only need stacking. A language-
        #   modeling collator would replace those labels with a copy of
        #   input_ids.
        # * The model already carries its LoRA adapters from _decide(), so no
        #   peft_config is handed to the trainer a second time.
        self.trainer = trainer_class(
            model=self.model,
            args=self.training_args,
            train_dataset=self.dataset["train"],
            eval_dataset=self.dataset.get("test"),  # None when there is no test split
            data_collator=default_data_collator,
        )

    def _preprocess_function(self, examples):
        """
        Tokenizes a batch of conversations into model inputs and labels.

        For every conversation the last non-empty user and assistant messages
        are used; conversations missing either one are dropped.

        * Encoder-decoder models: the prompt is the input and the assistant
          response is the label sequence.
        * Decoder-only models: prompt and response are joined into a single
          sequence that serves as both input and label, since a causal LM can
          only learn the response if it is part of the input sequence.

        Padding positions in the labels are set to ``IGNORE_INDEX``.

        Args:
            examples: batch dict with a ``"messages"`` list; each item is a
                list of ``{"role": ..., "content": ...}`` dicts.

        Returns:
            The tokenizer output with an added ``"labels"`` entry.
        """
        prompts = []
        responses = []
        for conversation in examples["messages"]:
            # Later messages with the same role override earlier ones.
            content_by_role = {
                message["role"]: message["content"] for message in conversation
            }
            user_prompt = content_by_role.get("user", "")
            assistant_response = content_by_role.get("assistant", "")
            # Only keep the example if both sides have content
            if user_prompt and assistant_response:
                prompts.append(f"### Prompt: {user_prompt}")
                responses.append(assistant_response)

        if self.model.config.is_encoder_decoder:
            model_inputs = self.tokenizer(
                prompts,
                max_length=self.MAX_INPUT_TOKENS,
                truncation=True,
                padding="max_length",
                return_attention_mask=True,
            )
            # text_target is the replacement for the deprecated
            # as_target_tokenizer() context manager.
            targets = self.tokenizer(
                text_target=responses,
                max_length=self.MAX_TARGET_TOKENS,
                truncation=True,
                padding="max_length",
                return_attention_mask=True,
            )
        else:
            eos = self.tokenizer.eos_token or ""
            texts = [
                f"{prompt}\n### Response: {response}{eos}"
                for prompt, response in zip(prompts, responses)
            ]
            model_inputs = self.tokenizer(
                texts,
                max_length=self.MAX_INPUT_TOKENS,
                truncation=True,
                padding="max_length",
                return_attention_mask=True,
            )
            targets = model_inputs

        # Mask by attention mask rather than by pad id, so a tokenizer whose
        # pad token doubles as EOS still keeps its real EOS in the labels.
        model_inputs["labels"] = [
            [
                token_id if attended else self.IGNORE_INDEX
                for token_id, attended in zip(token_ids, mask)
            ]
            for token_ids, mask in zip(
                targets["input_ids"], targets["attention_mask"]
            )
        ]
        return model_inputs

    def _release_resources(self):
        """
        Drops the previous model, trainer and dataset and frees cached GPU
        memory, so the next model is not loaded next to the old one.
        """
        import gc

        self.trainer = None
        self.model = None
        self.dataset = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _run_stage(self, name, stage):
        """Runs one OODA stage between its "Start" and "End" banners."""
        print(f"{name}: Start")
        print("\n")
        stage()
        print(f"{name}: End")
        print("\n")

    def run(self):
        """
        Executes the OODA loop and fine-tunes each model/dataset pair.

        If both ``model_id`` and ``dataset_name`` were given to the
        constructor, only that pair is processed; otherwise every entry of
        the module-level ``RL_PAIRS`` is processed in order. The model and
        trainer of the last processed pair stay available on the agent.
        """
        explicit_pair = getattr(self, "_explicit_pair", None)
        pairs = [explicit_pair] if explicit_pair else RL_PAIRS

        for pair in pairs:
            self.model_id = pair["model_id"]
            self.dataset_name = pair["dataset_name"]

            # Free the previous pair's objects before loading the next model.
            self._release_resources()

            self._run_stage("Observe", self._observe)
            self._run_stage("Orient", self._orient)
            self._run_stage("Decide", self._decide)
            self._run_stage("Act", self._act)

            print(
                f"Start: Fine-tuning for Model: {self.model_id}, "
                f"Dataset: {self.dataset_name}"
            )
            print("\n")

            print("# Start the training")
            self.trainer.train()
            print(
                f"End: Fine-tuning for Model: {self.model_id}, "
                f"Dataset: {self.dataset_name}"
            )
            print("\n")


# Example Usage with rl_pairs
#
# Debug-oriented settings: bf16, tf32 and quantization are left off while the
# root cause of a training error is being tracked down.
config = {
    # Forwarded as keyword arguments to transformers.TrainingArguments.
    "training_args": {
        "output_dir": "./results",
        "num_train_epochs": 1,
        "per_device_train_batch_size": 1,  # reduced the batch_size
        "gradient_accumulation_steps": 1,  # reduced the number of steps
        # The library warns that the WANDB_DISABLED environment variable is
        # deprecated and asks for report_to instead; "none" switches every
        # logging integration off.
        "report_to": "none",
        "optim": "adamw_torch_fused",
        "learning_rate": 2e-4,
        "bf16": False,  # disabled bfloat16 to find the root of the error.
        # NOTE: tf32 is intentionally not set; it was removed as a possible
        # cause of the error and may not be supported by the hardware.
        "max_grad_norm": 0.3,
        "warmup_ratio": 0.03,
        "lr_scheduler_type": "constant",
        "logging_steps": 1,
        # NOTE(review): newer transformers releases appear to have renamed
        # this key to "eval_strategy" - confirm against the installed version.
        "evaluation_strategy": "steps",
        # NOTE(review): eval_steps=1 evaluates after every optimizer step;
        # with a large evaluation split this will dominate the runtime.
        "eval_steps": 1,
    },
    "quantization": False,  # disabled quantization to find the root of the error.
    "lora": True,  # Enable LORA
    "test_split_percentage": 0.25,  # Use 25% of the data for evaluation
}

agent = FineTuningAgent(
    model_id=None,  # We'll set this in the loop
    dataset_name=None,  # We'll set this in the loop
    config=config,
)

# Iterate through the rl_pairs and run the fine-tuning for each pair
agent.run()  # The run() method now handles the iteration

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Observe: Start




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Model: T5ForConditionalGeneration(
  (shared): Embedding(32128, 2048)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 2048)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=2048, out_features=2048, bias=False)
              (k): Linear(in_features=2048, out_features=2048, bias=False)
              (v): Linear(in_features=2048, out_features=2048, bias=False)
              (o): Linear(in_features=2048, out_features=2048, bias=False)
              (relative_attention_bias): Embedding(32, 32)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=2048, out_features=5120, bias=False)
              (wi_1): Linear(in_features=2048, out_features=5120, bias=False

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Observe: End


Orient: Start




Dataset: DatasetDict({
    train: Dataset({
        features: ['chosen', 'rejected', 'messages'],
        num_rows: 120600
    })
    test: Dataset({
        features: ['chosen', 'rejected', 'messages'],
        num_rows: 40200
    })
})


Orient: End


Decide: Start


LoRA Target Modules: ['q', 'k', 'v', 'o', 'wi_0', 'wi_1', 'wo']
