In [None]:
import torch 
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import EarlyStoppingCallback
from trl import AutoModelForCausalLMWithValueHead
from trl import PPOTrainer
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
import pandas as pd
from datasets import Dataset

In [None]:
# CONFIGURATION

# --- 1. Set the Input Policy Model ---
# Output directory of the SFT run; only used by the SFT-enhanced experiments.
# Defined here so the switch below works without depending on a later cell.
SFT_MODEL_OUTPUT_DIR = "my_sft_tuned_model_v1"

# For direct PPO (Experiments 1A & 1B), use the base model ID.
POLICY_MODEL_ID = "frhew/sigdial_ft_a2"
# For SFT-enhanced PPO (Experiments 2A & 2B), comment the line above and
# uncomment the line below to use the SFT model output directory instead.
# POLICY_MODEL_ID = SFT_MODEL_OUTPUT_DIR
# NOTE(review): no later cell may reassign POLICY_MODEL_ID, otherwise the
# choice made here is silently overridden.

# --- 2. Define the Instruction Prompt Template ---
def format_query(example, prompt_style="detailed"):
    """Build the instruction prompt for one dataset row.

    Writes the prompt into ``example['query']`` (overwriting any existing value)
    and returns the same dict, which is the shape ``Dataset.map`` expects.

    Args:
        example: A dataset row; the source sentence is read from ``'original'``.
        prompt_style: Which instruction template to use.
            ``"detailed"`` (default; Experiments 1B, 2A, 2B) gives more context
            and explicitly mentions "Leichter-Sprache-Regeln".
            ``"simple"`` (Experiment 1A) is a concise, no-context instruction.
            Select it with
            ``dataset.map(format_query, fn_kwargs={"prompt_style": "simple"})``.

    Returns:
        The same ``example`` dict, with ``'query'`` set.

    Raises:
        ValueError: If ``prompt_style`` is not ``"detailed"`` or ``"simple"``.
    """
    # Both templates wrap the source sentence; 'query' is the output of this
    # function, so it must never be used as an input here.
    sentence = example['original']

    if prompt_style == "detailed":
        example['query'] = f"Aufgabe: Vereinfache den folgenden Satz nach Leichter-Sprache-Regeln.\nSatz: {sentence}"
    elif prompt_style == "simple":
        example['query'] = f"Vereinfache diesen Satz: {sentence}"
    else:
        raise ValueError(f"Unknown prompt_style: {prompt_style!r}")

    return example


# --- 3. Define the Output Model Name ---
# Directory name under which the final PPO model is saved after training;
# choose something that identifies the run.
MODEL_FILE_NAME = "PPO_model_2K_4E_28"


# --- 4. Set Training Parameters ---
DATASET_SIZE = 2_048  # number of training prompts sampled for PPO
EPOCH = 4             # passed to PPOConfig as ppo_epochs


# --- 5. Verify Data Paths ---
# All three splits live in one directory; adjust _SPLIT_DIR if it moved.
_SPLIT_DIR = "sft_split_dataset"
MASTER_TRAIN_DATASET = f"{_SPLIT_DIR}/train.csv"
MASTER_EVAL_DATASET = f"{_SPLIT_DIR}/eval.csv"
MASTER_TEST_DATASET = f"{_SPLIT_DIR}/test.csv"

In [None]:
# --- Define Model Paths ---

# POLICY_MODEL_ID (direct PPO vs. SFT-enhanced PPO) is chosen once, in the
# CONFIGURATION cell. It is deliberately NOT reassigned here; reassigning it
# would silently override that choice.

# Define the pre-trained reward model here.
RM_PATH = "rm_out_rules_heavy_final"


# --- Load the Reference Model ---
# The policy model (value head + LoRA) is loaded in a later cell.
# The reference model is a frozen copy of the original policy; PPO measures the
# KL penalty against it. It is loaded without LoRA, in bfloat16 to save memory.
# NOTE(review): with a LoRA policy, TRL's PPOTrainer can derive reference
# log-probs by disabling the adapters (ref_model=None), which would avoid
# keeping a second full copy of the weights in memory — worth evaluating.
# First load is slow: ~30 min to download 7 shards of close to 5 GB each.
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    POLICY_MODEL_ID,
    torch_dtype=torch.bfloat16,  # lower precision to reduce memory
)

In [None]:
# --- Load Tokenizers ---
# Both tokenizers pad on the left: the policy is a causal LM that continues the
# prompt, so any padding has to come before the prompt tokens.
policy_tokenizer = AutoTokenizer.from_pretrained(POLICY_MODEL_ID, padding_side='left')
reward_tokenizer = AutoTokenizer.from_pretrained(RM_PATH, padding_side='left')

# Some tokenizers ship without a pad token; fall back to EOS so that batched
# tokenisation (padding=True / padding="max_length") works.
# TODO: assess which padding side the reward model expects — if its
# classification head reads a fixed position, left vs. right padding matters.
if policy_tokenizer.pad_token is None:
    policy_tokenizer.pad_token = policy_tokenizer.eos_token
if reward_tokenizer.pad_token is None:
    reward_tokenizer.pad_token = reward_tokenizer.eos_token

# Only the tokenizers are ready at this point; the reward model is loaded next.
print("Tokenizers loaded successfully.")

In [None]:
# --- Load Custom Reward Model ---
# Sequence-classification model whose logits are used as the PPO reward.
reward_model = AutoModelForSequenceClassification.from_pretrained(RM_PATH)

# Keep the model's pad id in sync with its tokenizer. If the tokenizer's pad
# token was only set to EOS above, the model config may still have no pad id,
# and some sequence-classification heads then reject batches larger than 1.
if reward_model.config.pad_token_id is None:
    reward_model.config.pad_token_id = reward_tokenizer.pad_token_id

In [None]:
# --- Configure Hardware Device ---
# Device used for the reward model and its inputs.
# Preference order: CUDA GPU, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

In [None]:
# Put the reward model on the selected `device` and switch it to inference
# mode; it only scores responses and is never trained in this notebook.
reward_model = reward_model.to(device)
reward_model.eval()

In [None]:
# ==============================================================================
#  2. PREPARE DATASET FOR PPO
# ==============================================================================


# Read the training split that supplies the policy's prompts; the first CSV
# column is used as the index.
train_df = pd.read_csv(MASTER_TRAIN_DATASET, index_col=0)
df = train_df  # same object, kept under its old name
print(f"Training set size: {len(train_df)}")
ppo_dataset = Dataset.from_pandas(train_df)


In [None]:

# Turn every row into an instruction prompt. format_query (defined in the
# CONFIGURATION cell) stores the result in the 'query' column.
# The eval split is formatted separately, later on.
ppo_dataset = ppo_dataset.map(format_query)

# Sanity check: show the first formatted prompt.
print("Example of a formatted query:")
first_row = ppo_dataset[0]
print(first_row['query'])

In [None]:
# Build the PPO prompt set from a random sample of the training data.
# 1024-2048 prompts are often sufficient for a PPO run: enough variety to
# generate the experiences the policy learns from, without the memory cost of
# the full set. (Other sizes tried: 4,096, 6,144.)
#
# format_query is applied AFTER sampling: building a fresh Dataset from
# train_df would otherwise drop the 'query' column created by an earlier map.
sample_size = min(DATASET_SIZE, len(train_df))  # never ask for more rows than exist
ppo_dataset = Dataset.from_pandas(
    train_df.sample(n=sample_size, random_state=42),
    preserve_index=False,  # don't carry the shuffled pandas index along as a column
).map(format_query)

# Token budget per query (instruction prefix + sentence).
QUERY_MAX_LENGTH = 40


def tokenize_query(examples):
    """Tokenize the 'query' column for the PPO trainer.

    Every query is truncated/padded to exactly QUERY_MAX_LENGTH tokens. With the
    policy tokenizer's left padding, pad tokens come first and are marked by
    attention_mask == 0.
    NOTE(review): fixed-length padding is presumably needed because the
    trainer's dataloader stacks rows without dynamic padding — verify before
    switching to variable-length queries.
    """
    return policy_tokenizer(examples["query"],
                            truncation=True,
                            padding="max_length",
                            max_length=QUERY_MAX_LENGTH)


# Tokenize, dropping every text/metadata column so only model inputs remain.
ppo_dataset = ppo_dataset.map(tokenize_query, batched=True,
                              remove_columns=ppo_dataset.column_names)

# Queries that use the whole window may have been truncated (with right-side
# truncation that cuts the end of the sentence — confirm tokenizer setting).
n_at_limit = sum(1 for mask in ppo_dataset["attention_mask"] if mask[0] == 1)
print(f"{n_at_limit}/{len(ppo_dataset)} queries reach {QUERY_MAX_LENGTH} tokens (possibly truncated).")

ppo_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

In [None]:
# ==============================================================================
#  3. CONFIGURE PPO (WITH LORA)
# ==============================================================================

from trl import PPOConfig, PPOTrainer
from peft import LoraConfig

# --- LoRA (Low-Rank Adaptation) settings ---
# Only small low-rank adapter matrices are trained; the base weights stay frozen.
lora_settings = dict(
    task_type="CAUSAL_LM",  # tells PEFT this is a text-generating (GPT-style) model
    r=16,                   # adapter rank: higher = more trainable params / capacity (typical: 8, 16, 32)
    lora_alpha=32,          # scale of the adapter update; set to 2 * r, a common convention
    lora_dropout=0.05,      # dropout on the adapter path to curb overfitting
    bias="none",            # keep every bias term frozen; train only the LoRA weights
)
lora_config = LoraConfig(**lora_settings)


In [None]:
# --- Load the Policy Model with LoRA ---
# TRL's value-head wrapper accepts a PEFT config directly, so the LoRA adapters
# are attached during loading and the model comes out in the form the
# PPOTrainer expects.
policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    POLICY_MODEL_ID,
    peft_config=lora_config,     # attach the LoRA adapters at load time
    torch_dtype=torch.bfloat16,  # lower precision to save memory
)

# Optional: gradient checkpointing trades a little compute for a lot of memory.
# policy_model.gradient_checkpointing_enable()





In [None]:
# --- Define the PPOConfig ---
# A conservative, stable set of hyperparameters for a single run without a full search.
config = PPOConfig(
    model_name=POLICY_MODEL_ID,
    # 1.4e-5 is a common and effective starting point for fine-tuning.
    learning_rate=1.4e-5,
    # Optimisation epochs over each collected PPO batch (TRL default: 4).
    ppo_epochs=EPOCH,
    # Number of prompts collected before each PPO optimisation (32 also tried).
    batch_size=16,
    # Each PPO batch is split into mini-batches for the update (8 also tried).
    mini_batch_size=4,

    # Memory saving: accumulate gradients over 4 mini-batches per optimiser step.
    # 4 (mini_batch_size) * 4 (accumulation_steps) = 16 = batch_size, i.e. one
    # optimiser step per pass over a PPO batch.
    gradient_accumulation_steps=4,

    # No external logging integration (e.g. WandB).
    log_with=None,

    # "full" computes the KL over the whole token distribution instead of the
    # sampled-token estimate ("kl", the default). The full KL cannot become
    # negative, which keeps the penalty well-behaved with the LoRA policy.
    kl_penalty="full",

    # `target_kl` is ONLY used to stop the PPO epochs early, and only when
    # early_stopping=True. It does not set the strength of the KL penalty
    # (that is controlled by init_kl_coef and the adaptive-KL `target`).
    # early_stopping is spelled out so this is visible; set it to True to make
    # target_kl actually cap how far a single PPO step may move the policy.
    early_stopping=False,
    target_kl=0.05,

    # Scale reward scores with running statistics — key for stable PPO.
    use_score_scaling=True,
    # Weight of the value-function loss in the PPO objective.
    vf_coef=0.1,
)


In [None]:
# --- Initialize the PPOTrainer ---
# Wire together the config, the LoRA policy (with value head), the frozen
# reference model, the policy tokenizer and the tokenized prompt dataset.
trainer_components = {
    "config": config,
    "model": policy_model,
    "ref_model": ref_model,
    "tokenizer": policy_tokenizer,
    "dataset": ppo_dataset,
}
ppo_trainer = PPOTrainer(**trainer_components)

In [None]:
# ==============================================================================
#  4. THE PPO TRAINING LOOP
# ==============================================================================
# (PPOTrainer and tqdm are imported in the first cell.)

# Sampling settings for generating the simplified responses.
# top_k=0 and top_p=1.0 switch both filters off (plain sampling from the
# policy); min_length=-1 sets no minimum length.
generation_kwargs = {
    "min_length": -1,
    "top_k": 0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": policy_tokenizer.eos_token_id,
    "max_new_tokens": 40,  # Control the length of the simplification
}

mean_rewards = []  # mean reward of every PPO batch, in training order

# NOTE: config.ppo_epochs (= EPOCH) is also used inside ppo_trainer.step() as
# the number of optimisation passes per collected batch; here it additionally
# sets how many times the whole prompt dataset is traversed.
for epoch in range(config.ppo_epochs):
    for i, batch in enumerate(tqdm(ppo_trainer.dataloader, desc=f"PPO Epoch {epoch+1}")):
        # A. Get prompts (queries) from the batch. They were left-padded to a
        # fixed length; keep only the real tokens (attention_mask == 1) so the
        # pad/EOS tokens are not fed to the policy as part of the prompt.
        queries_list = [
            query[mask.bool()]
            for query, mask in zip(batch["input_ids"], batch["attention_mask"])
        ]

        # B. Generate responses from the policy model.
        # return_prompt=False returns only the newly generated tokens. With
        # TRL's default (True), each "response" would start with the prompt,
        # so the reward model would score the instruction too and PPO would
        # optimise the prompt tokens as if they were part of the response.
        response_tensors = ppo_trainer.generate(
            queries_list, return_prompt=False, **generation_kwargs
        )
        batch["response"] = policy_tokenizer.batch_decode(
            response_tensors, skip_special_tokens=True
        )

        # C. Score the responses with the custom reward model.
        # NOTE(review): only the generated text is scored, without the source
        # sentence — confirm this matches how the reward model was trained.
        with torch.no_grad():
            inputs = reward_tokenizer(
                batch["response"], return_tensors="pt", padding=True, truncation=True
            ).to(device)
            # One scalar logit per response (assumes a single-label reward head).
            reward_logits = reward_model(**inputs).logits.squeeze(-1)
            # ppo_trainer.step takes rewards as a list of per-sample tensors.
            rewards = [r for r in reward_logits]

        # Mean reward of this batch — the key metric to watch.
        batch_mean_reward = reward_logits.float().mean().item()
        mean_rewards.append(batch_mean_reward)
        if i % 10 == 0:
            print(f"Epoch {epoch+1}, Step {i}, Mean Reward: {batch_mean_reward:.4f}")

        # D. PPO optimisation step: update the policy from the rewards while
        # penalising KL divergence from the reference model.
        stats = ppo_trainer.step(queries_list, response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards)

print("PPO Training Finished!")

# Save the final tuned model under the name chosen in the CONFIGURATION cell.
ppo_trainer.save_pretrained(MODEL_FILE_NAME)
print(f"Model saved to '{MODEL_FILE_NAME}'")


# Plot the reward to observe the mean trend development
# import matplotlib.pyplot as plt
# plt.plot(mean_rewards)
# plt.title("Mean Reward per Batch")
# plt.xlabel("PPO Step")
# plt.ylabel("Mean Reward")
# plt.show()