<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/RLHF_TUTOR_MISTRAL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers datasets trl peft accelerate bitsandbytes --q

In [2]:
!nvidia-smi

Tue Oct 29 01:47:09 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA L4                      Off | 00000000:00:03.0 Off |                    0 |
| N/A   51C    P8              13W /  72W |      1MiB / 23034MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
import warnings
warnings.filterwarnings("ignore")

warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")

import os
os.environ["WANDB_DISABLED"] = "true"  # Disable wandb
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import gc  # Import the garbage collector

from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig, EarlyStoppingCallback, TrainerCallback, TrainerState, TrainerControl # Import necessary modules
from transformers.modeling_outputs import SequenceClassifierOutput, SequenceClassifierOutputWithPast # Add this import

from datasets import load_dataset
from trl import RewardTrainer, RewardConfig
from accelerate import Accelerator
import torch
from peft import LoraConfig, get_peft_model  # Import PEFT modules


import torch.nn as nn  # Import the neural network module


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Enable anomaly detection
torch.autograd.set_detect_anomaly(True)


# Load the LLaMA 2 model and tokenizer
#model_name = "meta-llama/Llama-2-7b-hf"
#tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the Mistral 7B model and tokenizer
model_name = "mistralai/Mistral-7B-v0.1"  # Changed to Mistral 7B
tokenizer = AutoTokenizer.from_pretrained(model_name)


# Add a padding token to the tokenizer
tokenizer.pad_token = tokenizer.eos_token

# Quantization config
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,  # Use 4-bit quantization for lower memory usage
    bnb_4bit_use_double_quant=True,  # Enable double quantization
    bnb_4bit_quant_type="nf4",  # Use nf4 quantization type
    #bnb_4bit_compute_dtype=torch.float16  # Set compute dtype to float16
    bnb_4bit_compute_dtype=torch.bfloat16,  # Change to bfloat16
)

# Load the model with quantization
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=1,
    quantization_config=quantization_config,
    device_map="auto",
)

# Set pad_token_id in the model config
model.config.pad_token_id = tokenizer.pad_token_id

# PEFT configuration
peft_config = LoraConfig(
    r=8,  # Dimensionality of the low-rank matrices
    lora_alpha=16,  # Scaling factor
    lora_dropout=0.05,  # Dropout probability
    bias="none",  # No bias for the PEFT adapters
    task_type="SEQ_CLS",  # Sequence classification task
)

# Add PEFT adapters to the model
model = get_peft_model(model, peft_config)
print('\n')
print('Print the number of trainable parameters')
model.print_trainable_parameters()  # Print the number of trainable parameters
print('\n\n')

# Store the original model with PEFT adapters
original_model = model

# Load the Anthropic HH-RLHF dataset
dataset = load_dataset("Anthropic/hh-rlhf")

def format_data(example):
  if isinstance(example["chosen"], list):
    chosen_text = " ".join([item["text"] for item in example["chosen"]])
  else:
    chosen_text = example["chosen"]  # If it's a string, use it directly

  if isinstance(example["rejected"], list):
    rejected_text = " ".join([item["text"] for item in example["rejected"]])
  else:
    rejected_text = example["rejected"]  # If it's a string, use it directly

  # Tokenize the chosen and rejected texts with padding
  chosen_encoding = tokenizer(chosen_text, truncation=True, max_length=512, padding="max_length")
  rejected_encoding = tokenizer(rejected_text, truncation=True, max_length=512, padding="max_length")

  return {
      "input_ids_chosen": chosen_encoding["input_ids"],
      "attention_mask_chosen": chosen_encoding["attention_mask"],
      "input_ids_rejected": rejected_encoding["input_ids"],
      "attention_mask_rejected": rejected_encoding["attention_mask"],
  }

# Format the dataset
dataset = dataset.map(format_data)

# Split the dataset into train and eval sets
#train_dataset = dataset["train"].select(range(100000))  # Select first 100k examples for training
#eval_dataset = dataset["train"].select(range(100000, 110000))  # Select next 10k examples for evaluation
#Total steps = (Number of training examples) / (Effective batch size)
#             = 100,000 / 8
#             = 12,500

# Split the dataset into train and eval sets FOR POC
#train_dataset = dataset["train"].select(range(10000))  # Select first 10k examples
#eval_dataset = dataset["train"].select(range(10000, 11000))  # Select next 1k examples
#By making these changes, you'll reduce the total training steps from 12,500 to 1,250.
#This will significantly shorten the runtime for your POC, allowing you to experiment and
#iterate more quickly.


# Split the dataset (using 1,000 examples for the POC)
train_dataset = dataset["train"].select(range(1000))
eval_dataset = dataset["train"].select(range(1000, 2000))

# Training arguments as RewardConfig - Modified
training_args = RewardConfig(
    per_device_train_batch_size=1,  # Reduced batch size
    gradient_accumulation_steps=8,  # Increased gradient accumulation
    learning_rate=1e-6,  # Further reduced learning rate
    #fp16=False,  # Disable fp16
    fp16=True,  # Enable fp16 for potentially better performance
    logging_steps=25,
    output_dir="reward_model",
    num_train_epochs=1,  # You can increase this for better results
    report_to="none",  # Disable wandb reporting
    load_best_model_at_end=True,  # Ensure the best model is loaded
    evaluation_strategy="steps",  # Evaluate and save every "steps"
    save_strategy="steps",
    remove_unused_columns=False  # Prevent removal of unused columns
)

# Initialize the Accelerator
accelerator = Accelerator()

# Prepare the model and data loaders with accelerate
model, train_dataloader, eval_dataloader = accelerator.prepare(
    model, train_dataset, eval_dataset
)

# Train the reward model
trainer = RewardTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataloader,  # Use the prepared train data loader
    eval_dataset=eval_dataloader,  # Use the prepared eval data loader
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]  # Add early stopping
)

####### NEW ####

# Initialize the optimizer with a lower learning rate and gradient clipping
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)  # Slightly higher learning rate




# Create a gradient clipping function
#def clip_gradients(params):
#    torch.nn.utils.clip_grad_norm_(params, 1.0)  # Standard clip value

# Create a gradient clipping function with an adjustable norm (corrected)
def clip_gradients(params, clip_norm=1.0):
    torch.nn.utils.clip_grad_norm_(params, clip_norm)  # Correct usage of clip_grad_norm_

# Override the training step function to apply gradient clipping more frequently
class GradientClippingCallback(TrainerCallback):
    def on_step_end(self, args: training_args, state: TrainerState, control: TrainerControl, **kwargs):
        clip_gradients(kwargs["model"].parameters(), clip_norm=0.5)  # Example: Clip with norm 0.5
        return control



trainer.add_callback(GradientClippingCallback())  # Add the callback to the trainer


from tqdm import tqdm  # Import tqdm for the progress bar


# Calculate the total number of training steps across all epochs using enumerate
total_steps = sum(1 for _ in enumerate(train_dataloader)) * training_args.num_train_epochs

# Create a single tqdm progress bar for the entire training process
progress_bar = tqdm(total=total_steps, desc="Training Progress", leave=False)

In [5]:
print('\n')
print('Print the number of trainable parameters')
model.print_trainable_parameters()  # Print the number of trainable parameters
print('\n\n')

print('\n')
# Training loop with adjusted input handling and tensor conversion
for epoch in range(training_args.num_train_epochs):
    # Create a tqdm progress bar for the training loop
    #progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{training_args.num_train_epochs}", leave=False)

    for step, batch in enumerate(train_dataloader):

        #if int(step)%100==0:
          #print(f"Epoch {epoch}, Step {step}")

        # --- Check for NaN in input tensors ---
        for key, value in batch.items():
            if torch.is_tensor(value) and torch.isnan(value).any():
                print(f"Warning: NaN values found in input tensor '{key}' at epoch {epoch}, step {step}")
                # Handle NaN values (e.g., replace with 0, skip the batch, etc.)
                # Example: Replace NaN with 0
                batch[key] = torch.nan_to_num(value, nan=0.0)

        # Extract the input tensors from the batch
        input_ids_chosen = batch.get("input_ids_chosen")
        attention_mask_chosen = batch.get("attention_mask_chosen")
        input_ids_rejected = batch.get("input_ids_rejected")
        attention_mask_rejected = batch.get("attention_mask_rejected")



        # Convert lists to tensors and move to device
        # Ensure tensors are on the correct device and have the correct data type
        input_ids_chosen = torch.tensor(input_ids_chosen, device=device, dtype=torch.long)
        attention_mask_chosen = torch.tensor(attention_mask_chosen, device=device, dtype=torch.long)
        input_ids_rejected = torch.tensor(input_ids_rejected, device=device, dtype=torch.long)
        attention_mask_rejected = torch.tensor(attention_mask_rejected, device=device, dtype=torch.long)



        # Reshape input_ids before concatenating
        input_ids_chosen = input_ids_chosen.unsqueeze(0)
        input_ids_rejected = input_ids_rejected.unsqueeze(0)
        input_ids = torch.cat([input_ids_chosen, input_ids_rejected], dim=0)

        # Reshape attention masks before concatenating
        attention_mask_chosen = attention_mask_chosen.unsqueeze(0)
        attention_mask_rejected = attention_mask_rejected.unsqueeze(0)
        #attention_mask = torch.cat([attention_mask_chosen, attention_mask_rejected], dim=0)


        attention_mask = torch.cat([attention_mask_chosen, attention_mask_rejected], dim=0)


        # Pass the concatenated inputs to the model
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

        # --- Modified loss calculation and gradient handling ---
        logits = outputs.logits

        # 1: Check for NaN in logits and replace with 0
        #logits = torch.nan_to_num(logits, nan=0.0)


         #--- Check for NaN in logits and handle them ---
        if torch.isnan(logits).any():
            print("Warning: NaN values found in logits!")
            # Handle NaN values (e.g., replace with 0, skip the batch, etc.)
            # Example: Replace NaN with 0
            logits = torch.nan_to_num(logits, nan=0.0)


        # 2: Apply sigmoid to get probabilities
        probs = torch.sigmoid(logits)

        # 3: Ensure probs are float32 before clipping
        probs = probs.type(torch.float32)

        # 3a: Cast logits to float32 before loss calculation
        logits = logits.type(torch.float32)  # Cast logits to float32


        # 4: Clip probabilities to [0, 1]
        #probs = torch.clamp(probs, 0.0, 1.0)

        #--- Ensure all tensors have the same dtype (float32) ---
        probs = torch.sigmoid(logits)
        #probs = torch.clamp(probs, 0.0, 1.0) # Clip probabilities (optional)


        # 5: Create labels with the same shape as logits (probs) and move to device
        labels = torch.zeros(logits.size(), dtype=torch.float32, device=device)
        labels[0, 0] = 1  # Label for chosen text
        labels[1, 0] = 0  # Label for rejected text

        # 6: Print shapes and dtypes for debugging
        #print("Logits shape:", logits.shape, "dtype:", logits.dtype)
        #print("Probs shape:", probs.shape, "dtype:", probs.dtype)
        #print("Labels shape:", labels.shape, "dtype:", labels.dtype)


        # 7xs: Use BCELoss
        loss_fn = nn.BCELoss()
        loss = loss_fn(probs, labels)


        # --- Gradient Scaling ---
        loss = loss / training_args.gradient_accumulation_steps  # Scale the loss

        if int(step)%100==0:
          #print(f"Loss: {loss}")
          print('\n')
          print(f"Epoch {epoch}, Step {step}, Loss: {loss}")
          print('\n')


        # --- More frequent gradient clipping ---
        clip_gradients(model.parameters(), 0.5)  # Clip every step

        # Clip gradients more frequently (e.g., every 10 steps)
        #if step % 10 == 0:
        #    clip_gradients(model.parameters(), 0.5)  # Correct usage



        # --- Gradient Handling ---
        # 7: Reduce learning rate (if needed)
        #optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6) # Example: Reduced lr

        # 8: Lower gradient accumulation steps (if needed)
        # Example: Reduced accumulation to potentially improve stability
        #training_args.gradient_accumulation_steps = 4


        # --- Backpropagation and Optimization ---
        #loss.backward() # Calculate gradients

        # 9: Use gradient clipping
        #clip_gradients(params=model.parameters())  # Clip gradients

        #optimizer.step() # Update model parameters using scaler
        #optimizer.zero_grad() # Reset gradients after each step

        # Update the progress bar after each step
        #print('\n')
        progress_bar.update(1)  # Manually update
        #print('\n')

# Save the trained model
trainer.save_model("reward_model")



Print the number of trainable parameters
trainable params: 3,411,968 || all params: 7,114,076,160 || trainable%: 0.0480







Training Progress:   2%|▏         | 18/1000 [00:55<46:17,  2.83s/it]



Epoch 0, Step 0, Loss: 0.21893148124217987




Training Progress:  12%|█▏        | 118/1000 [03:27<22:13,  1.51s/it]



Epoch 0, Step 100, Loss: 0.6079674363136292




Training Progress:  22%|██▏       | 218/1000 [05:59<19:30,  1.50s/it]



Epoch 0, Step 200, Loss: 0.09748947620391846




Training Progress:  32%|███▏      | 318/1000 [08:31<17:08,  1.51s/it]



Epoch 0, Step 300, Loss: 0.027507081627845764




Training Progress:  42%|████▏     | 418/1000 [11:03<14:48,  1.53s/it]



Epoch 0, Step 400, Loss: 0.06258226931095123




Training Progress:  52%|█████▏    | 518/1000 [13:35<12:05,  1.51s/it]



Epoch 0, Step 500, Loss: 0.04816298559308052




Training Progress:  62%|██████▏   | 618/1000 [16:08<09:34,  1.50s/it]



Epoch 0, Step 600, Loss: 0.09454096853733063




Training Progress:  72%|███████▏  | 718/1000 [18:40<07:10,  1.53s/it]



Epoch 0, Step 700, Loss: 0.11152325570583344




Training Progress:  82%|████████▏ | 818/1000 [21:12<04:45,  1.57s/it]



Epoch 0, Step 800, Loss: 0.13049648702144623




Training Progress:  92%|█████████▏| 918/1000 [23:44<02:04,  1.52s/it]



Epoch 0, Step 900, Loss: 0.30196622014045715




Training Progress: 1017it [26:14,  1.51s/it]

In [7]:
# Test cases
def evaluate_example(prompt, chosen, rejected):
  inputs = tokenizer(
      [f"{prompt} {chosen}", f"{prompt} {rejected}"],
      return_tensors="pt",
      padding=True
  ).to(accelerator.device)  # Move inputs to the appropriate device

  #print("Input IDs:", inputs["input_ids"])
  #print("Attention Mask:", inputs["attention_mask"])

  outputs = model(**inputs)
  chosen_score = outputs.logits[0].item()
  rejected_score = outputs.logits[1].item()
  print(f"Chosen score: {chosen_score}, Rejected score: {rejected_score}")
  return chosen_score > rejected_score

# Example usage
prompt = "What is the capital of France?"
chosen = "Paris"
rejected = "London"
print('\n')
print(f"Prompt: {prompt}, Chosen: {chosen}, Rejected: {rejected}")

print('\n')
if evaluate_example(prompt, chosen, rejected):
  print("Test passed!")
else:
  print("Test failed.")



Prompt: What is the capital of France?, Chosen: Paris, Rejected: London


Chosen score: -3.775390625, Rejected score: -6.671875
Test passed!
