<a href="https://colab.research.google.com/github/frank-morales2020/Cloud_curious/blob/master/FTA_DYNAMIC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers accelerate trl bitsandbytes --quiet

## Univeral FineTuningAgent

In [1]:
!pip install transformers accelerate trl bitsandbytes --quiet

import os

# Set environment variables for debugging
os.environ["WANDB_MODE"] = "offline"
os.environ["WANDB_DISABLED"] = "true"
!export CUDA_LAUNCH_BLOCKING=1  # Enable synchronous CUDA error reporting

# Import necessary modules
from transformers import TrainingArguments
import accelerate

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    AutoConfig,
    AutoModelForSeq2SeqLM,
    Trainer,
    DataCollatorForLanguageModeling,
)

from datasets import load_dataset, DatasetDict, Dataset
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import warnings
from trl import SFTTrainer

warnings.filterwarnings("ignore")

# Define the rl_pairs
RL_PAIRS = [
    {
        "model_id": "mistralai/Mistral-7B-Instruct-v0.1",  # Mistral model
        "dataset_name": "b-mc2/sql-create-context",  # SQL dataset
    }
]

class FineTuningAgent:
    def __init__(self, model_id, dataset_name, config):
        """
        Initializes the FineTuningAgent with model ID, dataset name, and configuration.
        """
        self.model_id = model_id
        self.dataset_name = dataset_name
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dataset = None # Initialize the dataset as None

    def _observe(self):
        """
        Observes the environment by loading the model, tokenizer, and dataset.
        """
        # 1. Load Model and Tokenizer (with quantization if enabled)
        quantization_config = None  # Initialize as None to allow for disabling quantization
        if self.config.get("quantization"):
            if "mistral" in self.model_id.lower():
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16,
                )
            else:
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_use_double_quant=False,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.float32,
                )

        # Determine the correct model class based on architecture
        config = AutoConfig.from_pretrained(self.model_id)
        if config.is_encoder_decoder:
            #model_class = AutoModelForSeq2SeqLM

            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                self.model_id,
                quantization_config=quantization_config,
                trust_remote_code=True,
            )
        else:
            #model_class = AutoModelForCausalLM

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                quantization_config=quantization_config,
                trust_remote_code=True,
            )


        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_id, trust_remote_code=True
        )
        self.model.config.use_cache = False

        # Add padding token if it does not exist
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            #self.model.resize_token_embeddings(len(self.tokenizer))

        # Move model to device
        self.model.to(self.device)

        print("\n")
        print(f"Model: {self.model}")
        print("\n")

        # 2. Load Dataset (using dataset name from Hugging Face Hub or other sources)
        self.dataset = load_dataset(self.dataset_name)

    def _orient(self):
        """
        Orients the agent by formatting the dataset and preparing training arguments.
        """
        # Check if dataset is None or empty before proceeding
        if self.dataset is None:
            print("Dataset is not loaded. Skipping dataset processing.")
            return  # Exit the function if dataset is None
        if isinstance(self.dataset, DatasetDict):
           if all(len(self.dataset[split]) == 0 for split in self.dataset):
               print("Dataset is empty. Skipping dataset processing.")
               return  # Exit the function if the dataset is empty
        elif len(self.dataset) == 0:
           print("Dataset is empty. Skipping dataset processing.")
           return

        # Define system message
        system_message = "You are a helpful and harmless AI assistant."

        # Convert dataset to OAI messages based on dataset name
        if self.dataset_name == "anthropic/hh-rlhf":
            def create_conversation(sample):
                return {
                    "messages": [
                        {"role": "system", "content": system_message},
                        {"role": "user", "content": sample["chosen"]},
                        {"role": "assistant", "content": sample["rejected"]},
                    ]
                }
        else:
            def create_conversation(sample):
                return {
                    "messages": [
                        {"role": "system", "content": system_message},
                        {"role": "user", "content": sample["question"]},
                        {"role": "assistant", "content": sample["answer"]},
                    ],
                }

        # Batch processing
        batch_size = 1000  # Adjust the batch size as needed
        if isinstance(self.dataset, DatasetDict):
            # Create a new DatasetDict to hold the updated data
            new_dataset = DatasetDict()
            for split in self.dataset:
                # Initialize an empty list to hold the updated data for this split
                updated_split = []
                for i in range(0, len(self.dataset[split]), batch_size):
                    batch = self.dataset[split][i: i + batch_size]
                    batch_dataset = Dataset.from_list(batch)  # Convert batch to Dataset
                    if self.dataset_name == "anthropic/hh-rlhf":
                        mapped_batch = batch_dataset.map(create_conversation)
                    else:
                        # Remove all columns except messages
                        mapped_batch = batch_dataset.map(
                            create_conversation, remove_columns=batch_dataset.column_names
                        )
                    updated_split.extend(mapped_batch)
                # Add the updated split to the new_dataset
                new_dataset[split] = Dataset.from_list(updated_split)
            # Reassign the dataset to the newly created DatasetDict
            self.dataset = new_dataset
        else:
            # Initialize an empty list to hold the updated data
            updated_dataset = []
            for i in range(0, len(self.dataset), batch_size):
                batch = self.dataset[i: i + batch_size]
                batch_dataset = Dataset.from_list(batch)  # Convert batch to Dataset
                if self.dataset_name == "anthropic/hh-rlhf":
                    mapped_batch = batch_dataset.map(create_conversation)
                else:
                    # Remove all columns except messages
                    mapped_batch = batch_dataset.map(
                        create_conversation, remove_columns=batch_dataset.column_names
                    )
                updated_dataset.extend(mapped_batch)
            # Concatenate the updated_dataset back into one dataset
            self.dataset = Dataset.from_list(updated_dataset)

        # Conditional split based on test_split_percentage
        if self.config.get("test_split_percentage") is not None:
            test_split_percentage = self.config["test_split_percentage"]
            # Check if the dataset is already a DatasetDict with splits
            if isinstance(self.dataset, DatasetDict):
                # Assume we want to split the training data
                if "train" in self.dataset:
                    test_size = int(
                        len(self.dataset["train"]) * test_split_percentage
                    )
                    test_size = max(
                        test_size, 1
                    )  # Ensure a minimum test size
                    # Split the train dataset
                    train_dataset_dict = self.dataset["train"].train_test_split(
                        test_size=test_size
                    )
                    # Update the dataset with the new train split and add the test split
                    self.dataset = DatasetDict(
                        {
                            "train": train_dataset_dict["train"],
                            "test": train_dataset_dict["test"],
                        }
                    )
            else:
                # If the dataset is not a DatasetDict, then it must be a list
                test_size = int(len(self.dataset) * test_split_percentage)
                test_size = max(
                    test_size, 1
                )  # Ensure a minimum test size
                # Split the dataset
                train_dataset_dict = self.dataset.train_test_split(
                    test_size=test_size
                )
                # Update the dataset with the new splits
                self.dataset = train_dataset_dict

        # 3. Prepare Training Arguments
        self.training_args = TrainingArguments(
            **self.config.get("training_args", {})
        )
        self.training_args.remove_unused_columns = False


        print("\n")
        print(f"Orient - Training Arguments: {self.training_args}")
        print("\n")

        print("\n")
        print(f"Orient - Dataset: {self.dataset}")
        print("\n")

    def _decide(self):
        """
        Decides on the fine-tuning strategy, including LoRA configuration.
        """
        # 4. PEFT Configuration (LoRA)
        if self.config.get("lora"):
            self.model = prepare_model_for_kbit_training(self.model)

            # Determine the correct target modules based on model architecture
            if self.model.config.is_encoder_decoder:
                target_modules = [
                    "q", "k", "v", "o", "wi_0", "wi_1", "wo"
                ]  # For encoder-decoder models
                task_type = "SEQ_2_SEQ_LM"
            elif "mistral" in self.model_id.lower():
                target_modules = [
                    "q_proj", "k_proj", "v_proj", "o_proj"
                ]  # For Mistral models
                task_type = "CAUSAL_LM"
            else:
                target_modules = [
                    "c_attn", "c_proj", "w1", "w2"
                ]  # For other causal models
                task_type = "CAUSAL_LM"

            print("\n")
            print(f"LORA Target Modules: {target_modules}")
            print("\n")

            # Updated LoRA config
            peft_config = LoraConfig(
                lora_alpha=128,
                lora_dropout=0.05,
                r=256,
                bias="none",
                target_modules="all-linear",
                task_type="CAUSAL_LM",
            )

            self.peft_config = peft_config
            self.model = get_peft_model(self.model, peft_config)

            print("\n")
            self.model.print_trainable_parameters()
            print("\n")

    def _act(self):
        """
        Acts by preprocessing the dataset and initializing the training loop.
        """
        try:
            print("Preprocessing dataset...")
            print("Dataset before preprocessing:", self.dataset)
            # Batch preprocessing
            batch_size = 1000  # Adjust the batch size as needed
            new_train_dataset = [] #initialize an empty list to hold the transformed dataset
            for i in range(0, len(self.dataset["train"]), batch_size):
                batch = self.dataset["train"][i: i + batch_size]
                batch_dataset = Dataset.from_dict(batch) # Convert batch to Dataset
                transformed_batch_dataset = batch_dataset.map(
                    self._preprocess_function,
                    batched=True,
                    remove_columns=self.dataset["train"].column_names,
                )
                new_train_dataset.append(transformed_batch_dataset) #append the transformed dataset to the list
            # concatenate the transformed_dataset back into one dataset
            self.dataset["train"] = Dataset.from_dict({k: [item for sublist in [d[k] for d in new_train_dataset] for item in sublist] for k in new_train_dataset[0].column_names})
            print("Dataset preprocessed successfully.")
            print("Dataset after preprocessing:", self.dataset)

            print("Creating data collator...")
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=self.tokenizer, mlm=False
            )
            print("Data collator created successfully.")

            print("Initializing trainer...")
            if self.dataset_name == "anthropic/hh-rlhf":
                # Use a regular Trainer for RLHF (you'll need to adapt this for your RL library)
                print("Using Trainer for RLHF.")
                self.trainer = Trainer(
                    model=self.model,
                    args=self.training_args,
                    train_dataset=self.dataset["train"],
                    eval_dataset=self.dataset.get(
                        "test"
                    ),  # Use get() to handle missing "test"
                    preprocess_logits_for_metrics=False,
                    data_collator=data_collator,  # pass the data_collator
                )
                #... (configure the Trainer for RL - this will depend on your RL library)...
            else:
                # Use SFTTrainer for supervised fine-tuning
                print("Using SFTTrainer for supervised fine-tuning.")
                self.trainer = SFTTrainer(
                    model=self.model,
                    args=self.training_args,
                    train_dataset=self.dataset["train"],
                    eval_dataset=self.dataset.get("test"),
                    preprocess_logits_for_metrics=False,
                    peft_config=self.peft_config,
                    data_collator=data_collator,  # pass the data_collator
                )
            print("Trainer initialized successfully.")
            print("Training arguments:", self.training_args)

        except Exception as e:
            print(f"An error occurred in _act(): {e}")
            raise  # Re-raise the exception to preserve the stack trace

    def _preprocess_function(self, example):
            """
            Preprocesses the data by combining context and question, tokenizing, and formatting.
            """
            if self.dataset_name == "b-mc2/sql-create-context":
                # Concatenate all messages into a single string for input.
                input_text = "".join([msg['content'] for msg in example['messages']])
                label_text = input_text
            else:
                input_text = "".join([msg['content'] for msg in example['messages']])
                label_text = example["answer"]

            # Tokenize the inputs and labels
            model_inputs = self.tokenizer(
                input_text, max_length=1024, truncation=True, padding="max_length"
            )

            with self.tokenizer.as_target_tokenizer():
                label_ids = self.tokenizer(
                    label_text, max_length=512, truncation=True, padding="max_length"
                )

            model_inputs["labels"] = label_ids["input_ids"]

            return model_inputs

    def run(self):
        """
        Executes the OODA loop and fine-tunes the language model using RL or SFT.
        """
        for pair in RL_PAIRS:
            self.model_id = pair["model_id"]
            self.dataset_name = pair["dataset_name"]

            print("Observe: Start")
            print("\n")  # Separate Observe stage
            self._observe()
            print("Observe: End")
            print("\n")  # Separates the observe stage from the Orient stage

            print("Orient: Start")
            print("\n")  # Separate Orient stage
            self._orient()
            print("Orient: End")
            print("\n")  # Separates the Orient stage from the Decide stage

            print("Decide: Start")
            print("\n")  # Separate Decide stage
            self._decide()
            print("Decide: End")
            print("\n")  # Separates the Decide stage from the Act stage

            print("Act: Start")
            print("\n")  # Separate Act stage
            self._act()
            print("Act: End")
            print("\n")  # Separates the Act stage from the training info

            print(
                f"Start: Fine-tuning for Model: {self.model_id}, "
                f"Dataset: {self.dataset_name}"
            )
            print("\n")  # Separate training information

            # Start the training
            print("Starting training...")
            self.trainer.train()
            print("Training completed.")

            print(
                f"End: Fine-tuning for Model: {self.model_id}, "
                f"Dataset: {self.dataset_name}"
            )
            print("\n")  # Separate training information

            # Explore synergies between Newton, Galileo, Einstein, and Hinton
            self._explore_synergies()

    def _explore_synergies(self):
        """
        Explores the potential synergies between Newton, Galileo, Einstein, and Hinton in the context of AI.
        """
        print("\nExploring synergies between Newton, Galileo, Einstein, and Hinton in the context of AI:")
        # Add your code here to explore the synergies.
        # This could involve generating text, analyzing data, or conducting experiments.
        # For example, you could use the fine-tuned model to generate text about the contributions of each figure to AI.
        # You could also analyze the model's performance on different tasks to see how it reflects the principles of these figures.
        # Be creative and explore the connections between these influential figures and the field of AI.
        print("Synergies exploration completed.\n")

# Example Usage with rl_pairs
config = {
    "training_args": {
        "output_dir": "./results",
        "num_train_epochs": 1,
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 1,
        "optim": "adamw_torch_fused",
        "learning_rate": 2e-4,
        "bf16": False,
        "max_grad_norm": 0.3,
        "warmup_ratio": 0.03,
        "lr_scheduler_type": "constant",
        "logging_steps": 1,
        "evaluation_strategy": "steps",
        "eval_steps": 1,
        "tf32": True,  # Enable TF32 for A100
    },
    "quantization": True,  # Add quantization
    "lora": True,
    "test_split_percentage": 0.25,
}

agent = FineTuningAgent(model_id=None, dataset_name=None, config=config)

# Iterate through the rl_pairs and run the fine-tuning for each pair
agent.run()  # The run() method now handles the iteration

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Observe: Start




`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Model: MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)
      )
    )
    (no

KeyError: 0