In [1]:
import gc
import re
import torch
import os
import json
import pandas as pd
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from typing import Optional, List, Dict, Callable

In [2]:
!pip install trl
!pip install git+https://github.com/huggingface/accelerate.git
!pip install git+https://github.com/huggingface/transformers.git
!pip install git+https://github.com/UKPLab/sentence-transformers.git
!pip install datasets
!pip install peft
!pip install sentencepiece
!pip install -U bitsandbytes
!pip install nbresuse
!jupyter serverextension enable --py nbresuse
!jupyter labextension install @jupyterlab/statusbar

Collecting trl
  Downloading trl-0.19.0-py3-none-any.whl.metadata (10 kB)
Collecting accelerate>=1.4.0 (from trl)
  Downloading accelerate-1.8.1-py3-none-any.whl.metadata (19 kB)
Collecting datasets>=3.0.0 (from trl)
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting transformers>=4.51.0 (from trl)
  Downloading transformers-4.53.0-py3-none-any.whl.metadata (39 kB)
Collecting huggingface_hub>=0.21.0 (from accelerate>=1.4.0->trl)
  Downloading huggingface_hub-0.33.2-py3-none-any.whl.metadata (14 kB)
Collecting safetensors>=0.4.3 (from accelerate>=1.4.0->trl)
  Downloading safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets>=3.0.0->trl)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets>=3.0.0->trl)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from da

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from trl import PPOConfig, PPOTrainer
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
import bitsandbytes.optim as bnb_optim

In [4]:
# Load the MiniLM sentence encoder at notebook scope.
# NOTE(review): nothing visible in this notebook reads `sentence_model`;
# MultiComponentRewardFunction.__init__ constructs its own SentenceTransformer
# from the same checkpoint name, so this cell may load the encoder twice — confirm
# whether it is still needed.
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

# Multi-Component Reward Function

In [5]:
"""
This notebook trains the policy with the REINFORCE algorithm. The policy is a
fine-tuned Llama 3.2-3B model.
"""
class MultiComponentRewardFunction:
    """
    Composite reward for generations of the form <think>...</think><answer>X</answer>.

    R_total = w_b * R_binary - w_h * P_hacker - w_e * P_exploit + w_f * B_format

    where
      * R_binary  is +1 / 0 / -1 (correct / incorrect / malformed output),
      * P_hacker  is the cosine similarity between the <think> content and the
                  closest "answer-leaking" phrase, counted only above tau_hacker,
      * P_exploit is a fixed penalty (lambda_e) when more than tau_exploit words
                  precede the <think> tag,
      * B_format  is a partial-credit bonus (0, 0.25 or 0.5) for format compliance.

    The value reported as 'r_normalized' is tanh(R_total).
    """

    # Same tag shape that validate_format() accepts: a single option letter,
    # optionally surrounded by whitespace.
    _STRICT_ANSWER_RE = re.compile(r'<answer>\s*([A-D])\s*</answer>', re.IGNORECASE)

    def __init__(self,
                 w_b = 1.0,  # Weight for binary reward
                 w_h = 0.5,  # Weight for hacker penalty
                 w_e = 0.3,  # Weight for exploit penalty
                 w_f = 0.3,  # Weight for soft format bonus
                 tau_hacker = 0.7,  # Similarity threshold for hacker penalty
                 tau_exploit = 15,  # Word count threshold for exploit penalty
                 lambda_e = 1.0,  # Fixed penalty value for exploit
                 embedding_model = 'all-MiniLM-L6-v2'):
        """
        Store the weights/thresholds, load the sentence encoder named by
        `embedding_model`, and pre-compute embeddings of the leak phrases.
        """
        self.w_b = w_b
        self.w_h = w_h
        self.w_e = w_e
        self.w_f = w_f
        self.tau_hacker = tau_hacker
        self.tau_exploit = tau_exploit
        self.lambda_e = lambda_e

        # Load sentence transformer for semantic similarity
        self.sentence_model = SentenceTransformer(embedding_model)

        # Prototypical phrases that state the answer instead of reasoning towards it
        self.answer_leaking_phrases = [
            "the correct answer is", "the answer is definitely", "the choice is clearly",
            "option A is the right one", "option B is the right one", "option C is the right one",
            "option D is the right one",
            "we can conclude the answer is", "the solution is A", "the solution is B", "the solution is C",
            "the solution is D",
            "therefore the answer is", "so the correct choice is", "the final answer is",
            "answer: A", "answer: B", "answer: C", "answer: D"
        ]

        # Pre-compute embeddings once; they are reused for every generation
        self.leak_embeddings = self.sentence_model.encode(self.answer_leaking_phrases)

    def extract_answer_choice(self, generation: str) -> Optional[str]:
        """
        Return the predicted option letter (upper-case) or None.

        A tag holding exactly one letter (the form validate_format() checks) is
        preferred, so a generation that passes validation is always graded on
        that tag. Otherwise fall back to the first standalone A-D letter inside
        the first <answer>...</answer> tag.
        """
        strict_match = self._STRICT_ANSWER_RE.search(generation)
        if strict_match:
            return strict_match.group(1).upper()

        answer_match = re.search(r'<answer>(.*?)</answer>', generation, re.IGNORECASE | re.DOTALL)
        if not answer_match:
            return None
        answer_content = answer_match.group(1).strip()
        match = re.search(r'\b([A-D])\b', answer_content, re.IGNORECASE)
        return match.group(1).upper() if match else None

    def extract_think_content(self, generation: str) -> str:
        """Return the text inside the first <think> tag without a leading 'Reasoning:' label ('' if absent)."""
        think_match = re.search(r'<think>(.*?)</think>', generation, re.DOTALL | re.IGNORECASE)
        if think_match:
            content = think_match.group(1).strip()
            return re.sub(r'^Reasoning:\s*', '', content)
        return ""

    def extract_pre_think_content(self, generation: str) -> str:
        """Return the text before the first <think> tag (the whole text if there is no tag)."""
        think_start = re.search(r'<think>', generation, re.IGNORECASE)
        return generation[:think_start.start()].strip() if think_start else generation.strip()

    def validate_format(self, generation: str) -> bool:
        """True when a non-empty <think> block and a single-letter <answer> tag are both present."""
        think_match = re.search(r'<think>\s*(.*?)\s*</think>', generation, re.DOTALL | re.IGNORECASE)
        has_think_content = think_match and think_match.group(1).strip()
        has_answer = bool(self._STRICT_ANSWER_RE.search(generation))
        return bool(has_think_content) and has_answer

    def compute_binary_reward(self, generation, correct_answer):
        """Return +1 for correct, 0 for incorrect (with valid format), -1 for wrong format."""
        if not self.validate_format(generation):
            return -1.0
        predicted_answer = self.extract_answer_choice(generation)
        if predicted_answer is None:
            return -1.0
        return 1.0 if predicted_answer == correct_answer.upper() else 0.0

    def compute_hacker_penalty(self, generation):
        """Return the max similarity of the <think> content to any leak phrase, or 0.0 at or below tau_hacker."""
        think_content = self.extract_think_content(generation)
        if not think_content:
            return 0.0
        think_embedding = self.sentence_model.encode([think_content])
        similarities = cosine_similarity(think_embedding, self.leak_embeddings)[0]
        max_similarity = np.max(similarities)
        return float(max_similarity) if max_similarity > self.tau_hacker else 0.0

    def compute_exploit_penalty(self, generation):
        """Return lambda_e when more than tau_exploit words precede the <think> tag, else 0.0."""
        pre_think_content = self.extract_pre_think_content(generation)
        word_count = len(pre_think_content.split())
        return self.lambda_e if word_count > self.tau_exploit else 0.0

    def soft_format_bonus(self, generation: str) -> float:
        """Return partial credit: 0.25 for a think tag pair plus 0.25 for a single-letter answer tag."""
        lowered = generation.lower()
        bonus = 0.0
        if '<think>' in lowered and '</think>' in lowered:
            bonus += 0.25
        if self._STRICT_ANSWER_RE.search(generation):
            bonus += 0.25
        return bonus

    def compute_total_reward(self, generation, correct_answer):
        """
        Compute every reward component for one generation.

        Returns a dict with 'r_binary', 'p_hacker', 'p_exploit', 'r_total'
        (weighted sum including the format bonus) and 'r_normalized' = tanh(r_total).
        """
        r_binary = self.compute_binary_reward(generation, correct_answer)
        p_hacker = self.compute_hacker_penalty(generation)
        p_exploit = self.compute_exploit_penalty(generation)
        r_total = self.w_b * r_binary - self.w_h * p_hacker - self.w_e * p_exploit
        r_total = r_total + self.w_f * self.soft_format_bonus(generation)
        # Plain float keeps the result dict free of numpy scalars
        r_normalized = float(np.tanh(r_total))
        return {
            'r_binary': r_binary, 'p_hacker': p_hacker, 'p_exploit': p_exploit,
            'r_total': r_total, 'r_normalized': r_normalized
        }

# Data Processor

In [6]:
class MedQADataProcessor:
    """Turn MedQA records into prompt/answer dictionaries used for training and evaluation."""

    def __init__(self, tokenizer):
        # Stored for callers; the methods of this class do not use it.
        self.tokenizer = tokenizer

    def format_prompt(self, question: str, options: Dict[str, str]) -> str:
        """
        Render one question and its options into the instruction prompt.

        Options are listed as "<letter>. <text>", one per line, in the iteration
        order of `options`. The continuation lines of the template keep their
        source indentation, so the prompt text contains those leading spaces.
        """
        formatted_options = "".join(f"{letter}. {option}\n" for letter, option in options.items())
        input_text = f"{question}\n\n{formatted_options.strip()}"
        prompt = f"""You are a medical expert taking the USMLE exam. Given the clinical scenario below, respond with your reasoning in a <think></think> tag and your final answer choice (A, B, C, or D) in an <answer></answer> tag.
        Scenario:
        {input_text}
        Format:
        <think>your step-by-step clinical reasoning goes here</think>
        <answer>A</answer>  # Replace A with your final answer choice
        Your response:"""
        return prompt

    def load_medqa_data(self, file_path: str) -> List[Dict]:
        """
        Load a JSON file holding a list of MedQA records.

        Each record must provide 'question', 'options' and 'answer_idx'; a missing
        key raises KeyError. Returns one dict per record with the keys
        'question', 'options', 'correct_answer' and 'prompt'.
        """
        # JSON text is UTF-8 by specification; do not depend on the platform
        # default encoding, which breaks on non-ASCII characters on some systems.
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return [
            {
                'question': item['question'],
                'options': item['options'],
                'correct_answer': item['answer_idx'],
                'prompt': self.format_prompt(item['question'], item['options'])
            }
            for item in data
        ]

In [7]:
class BaselineNetwork(nn.Module):
    """
    Small MLP that maps a feature vector to one bounded scalar.

    Inputs are cast to float32, clamped to [-10, 10] and layer-normalised; two
    hidden blocks (Linear -> LayerNorm -> ReLU -> Dropout) feed a final
    single-unit linear layer whose output is clamped to [-5, 5].
    """

    def __init__(self, input_dim: int, hidden_dim: int = 256):
        super().__init__()
        half_dim = hidden_dim // 2

        # Normalise raw inputs before the first projection
        self.input_norm = nn.LayerNorm(input_dim)

        # Build the two hidden blocks, then the scalar head
        layers = []
        for fan_in, fan_out in ((input_dim, hidden_dim), (hidden_dim, half_dim)):
            layers.extend([
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),  # keeps intermediate activations well scaled
                nn.ReLU(),
                nn.Dropout(0.1),
            ])
        layers.append(nn.Linear(half_dim, 1))
        self.network = nn.Sequential(*layers)

        # Shrink the initial weights so early outputs stay close to zero
        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise every linear layer: Xavier-uniform weights with gain 0.1, zero biases."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight, gain=0.1)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return the clamped scalar prediction; the trailing unit dimension is squeezed away."""
        bounded_input = torch.clamp(x.float(), min=-10.0, max=10.0)
        value = self.network(self.input_norm(bounded_input)).squeeze(-1)
        return torch.clamp(value, min=-5.0, max=5.0)

In [8]:
class MedREINFORCETrainer:
    def __init__(self, model_path: str, reward_config: Dict = None, use_baseline: bool = True):
        """
        Load the tokenizer and 8-bit policy model and set up reward, baseline and optimizers.

        Args:
            model_path: Path or identifier passed to `from_pretrained` for both
                the tokenizer and the causal language model.
            reward_config: Keyword arguments forwarded to
                MultiComponentRewardFunction; None uses its defaults.
            use_baseline: When True, create a BaselineNetwork sized to the
                model's hidden dimension together with its own Adam optimizer.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load tokenizer and model
        print(f"Loading tokenizer and model from local path: {model_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # Resolve EOS before PAD: filling PAD from EOS first would copy a missing
        # EOS and leave the pad token unset when both are absent initially.
        if self.tokenizer.eos_token is None:
            self.tokenizer.eos_token = self.tokenizer.pad_token or self.tokenizer.unk_token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.tokenizer.eos_token_id is None:
            self.tokenizer.eos_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.eos_token)
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )

        # Load model with memory optimizations (8-bit weights, fp16 compute dtype)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map={"": self.device},
            quantization_config=quantization_config
        )

        # Enable gradient checkpointing to save memory
        if hasattr(self.model, 'gradient_checkpointing_enable'):
            self.model.gradient_checkpointing_enable()

        self.model.train()

        # Get the actual hidden dimension from the model
        hidden_dim = self.model.config.hidden_size
        print(f"Model hidden dimension: {hidden_dim}")

        # Initialize reward function and data processor
        reward_config = reward_config or {}
        self.reward_function = MultiComponentRewardFunction(**reward_config)
        self.data_processor = MedQADataProcessor(self.tokenizer)

        self.use_baseline = use_baseline
        if self.use_baseline:
            # The baseline keeps its default parameter dtype; BaselineNetwork.forward
            # casts its input to float before use.
            self.baseline_network = BaselineNetwork(input_dim=hidden_dim).to(self.device)
            self.baseline_optimizer = optim.Adam(self.baseline_network.parameters(), lr=1e-4)

        # Initialize policy optimizer
        # NOTE(review): every model parameter is handed to the optimizer; which of
        # them actually receive gradients under 8-bit loading is not visible here — confirm.
        self.policy_optimizer = bnb_optim.AdamW8bit(self.model.parameters(), lr=1.41e-5)
        self.reward_history = []

    def generate_response_with_logprobs(self, prompt: str, max_new_tokens: int = 512):
        """
        Memory-efficient generation with log probabilities.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=400)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        input_length = inputs['input_ids'].shape[1]

        # Generate response WITHOUT gradients to save memory
        self.model.eval()
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                return_dict_in_generate=True,
                output_scores=True
            )

        generated_ids = outputs.sequences[0][input_length:]
        response_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        if len(generated_ids) == 0:
            return response_text, torch.tensor([], device=self.device, requires_grad=True), torch.zeros(self.model.config.hidden_size, device=self.device, dtype=torch.float16, requires_grad=True)

        # Now compute log probabilities WITH gradients for ONLY generated tokens
        self.model.train()

        # Process tokens in smaller chunks to save memory
        log_probs = []
        chunk_size = 32  # Process 32 tokens at a time

        for chunk_start in range(0, len(generated_ids), chunk_size):
            chunk_end = min(chunk_start + chunk_size, len(generated_ids))
            chunk_tokens = generated_ids[chunk_start:chunk_end]

            # Create context for this chunk
            context_start = max(0, input_length + chunk_start - 50)  # Keep 50 tokens of context
            context_tokens = outputs.sequences[0][context_start:input_length + chunk_end]

            # Forward pass for this chunk only
            chunk_inputs = {
                'input_ids': context_tokens.unsqueeze(0),
                'attention_mask': torch.ones_like(context_tokens).unsqueeze(0)
            }

            chunk_outputs = self.model(**chunk_inputs)
            chunk_logits = chunk_outputs.logits[0]

            # Calculate log probs for this chunk with numerical stability
            context_len = len(context_tokens) - len(chunk_tokens)
            for i, token_id in enumerate(chunk_tokens):
                logit_idx = context_len + i - 1
                if logit_idx >= 0 and logit_idx < chunk_logits.shape[0]:
                    # Clamp logits to prevent overflow/underflow
                    logits = chunk_logits[logit_idx].clamp(min=-50, max=50)
                    log_prob = torch.log_softmax(logits, dim=-1)[token_id]

                    # Check for invalid values and handle them
                    if torch.isnan(log_prob) or torch.isinf(log_prob):
                        print("Warning: Invalid log_prob detected, using fallback value")
                        log_prob = torch.tensor(-10.0, device=self.device, requires_grad=True, dtype=torch.float16)

                    log_probs.append(log_prob)

            # Clear intermediate computations
            del chunk_outputs, chunk_logits
            torch.cuda.empty_cache()

        # Get hidden state for baseline (WITH gradients for training)
        final_inputs = {
            'input_ids': outputs.sequences[0][-50:].unsqueeze(0),  # Reduced to 50 tokens
            'attention_mask': torch.ones(1, min(50, len(outputs.sequences[0]))).to(self.device)
        }
        final_outputs = self.model(**final_inputs, output_hidden_states=True)

        if final_outputs.hidden_states:
            hidden_state = final_outputs.hidden_states[-1][0]  # [seq_len, hidden_dim]

            # Check for invalid values before averaging
            if torch.any(torch.isnan(hidden_state)) or torch.any(torch.isinf(hidden_state)):
                print("Warning: Invalid hidden states detected")
                avg_hidden_state = torch.zeros(self.model.config.hidden_size, 
                                             device=self.device, dtype=torch.float16, 
                                             requires_grad=True)
            else:
                # Clamp and average
                hidden_state_clamped = torch.clamp(hidden_state, min=-10, max=10)
                avg_hidden_state = hidden_state_clamped.mean(dim=0)
                avg_hidden_state.requires_grad_(True)
        else:
            avg_hidden_state = torch.zeros(self.model.config.hidden_size, device=self.device, 
                                           dtype=torch.float16, requires_grad=True)

        log_probs_tensor = torch.stack(log_probs) if log_probs else torch.tensor([], device=self.device, requires_grad=True, dtype=torch.float16)
        return response_text, log_probs_tensor, avg_hidden_state

    def compute_baseline_value(self, hidden_state, training_mode=False):
        """Compute baseline value from hidden state with proper error handling."""
        if not self.use_baseline or hidden_state is None:
            return torch.tensor(0.0, device=self.device, requires_grad=training_mode, dtype=torch.float16)

        # Ensure input is in correct format and check for invalid values
        if torch.any(torch.isnan(hidden_state)) or torch.any(torch.isinf(hidden_state)):
            print("Warning: Invalid hidden state input to baseline, using fallback")
            return torch.tensor(0.0, device=self.device, requires_grad=training_mode, dtype=torch.float16)

        # Clamp input to prevent extreme values
        hidden_state_input = torch.clamp(hidden_state, min=-10.0, max=10.0)

        try:
            baseline_value = self.baseline_network(hidden_state_input)

            # Check for invalid baseline output
            if torch.isnan(baseline_value) or torch.isinf(baseline_value):
                print("Warning: Baseline network produced invalid output, using fallback")
                baseline_value = torch.tensor(0.0, device=self.device, requires_grad=training_mode, dtype=torch.float16)
            else:
                # Clamp baseline value to reasonable range
                baseline_value = torch.clamp(baseline_value, min=-5.0, max=5.0)

        except Exception as e:
            print(f"Warning: Error in baseline network: {e}, using fallback")
            baseline_value = torch.tensor(0.0, device=self.device, requires_grad=training_mode, dtype=torch.float16)

        # Return tensor (with gradients) for training, scalar for evaluation
        if training_mode:
            return baseline_value  # Keep gradients and tensor
        else:
            return baseline_value.item()  # Convert to scalar

    def train_with_reinforce(self, train_data_path: str, num_epochs=1, batch_size=2, gradient_accumulation_steps=8):
        """
        Memory-efficient REINFORCE training with proper error handling.
        """
        train_dataset = self.data_processor.load_medqa_data(train_data_path)
        print(f"Loaded {len(train_dataset)} training examples.")

        self.model.train()
        if self.use_baseline:
            self.baseline_network.train()

        for epoch in range(num_epochs):
            print(f"\n=== Epoch {epoch + 1}/{num_epochs} ===")
            epoch_rewards = []
            epoch_losses = []

            # Clear optimizers
            self.policy_optimizer.zero_grad()
            if self.use_baseline:
                self.baseline_optimizer.zero_grad()

            for step, batch_start in enumerate(range(0, len(train_dataset), batch_size)):
                batch_end = min(batch_start + batch_size, len(train_dataset))
                batch_items = train_dataset[batch_start:batch_end]

                # Initialize batch tracking variables
                batch_policy_losses = []
                batch_baseline_losses = []
                batch_rewards = []
                batch_advantages = []
                item_data_list = []  # Initialize here to avoid UnboundLocalError

                for item in batch_items:
                    try:
                        prompt = item['prompt']
                        correct_answer = item['correct_answer']

                        # Generate response
                        response_text, log_probs, hidden_state = self.generate_response_with_logprobs(prompt)
                        if log_probs.numel() == 0:
                            continue
                        # Check for invalid log probabilities
                        if torch.any(torch.isnan(log_probs)) or torch.any(torch.isinf(log_probs)):
                            print("Warning: Invalid log_probs detected, skipping sample")
                            continue
                        # Compute reward
                        reward_info = self.reward_function.compute_total_reward(response_text, correct_answer)
                        reward = reward_info['r_normalized']

                        # Clamp reward to reasonable range
                        reward = max(-5.0, min(5.0, reward))

                        if step % 10 == 0:  # Print every 10 steps to avoid spam
                            print(f"\n--- Step {step} Sample ---")
                            print(f"Prompt: {prompt[:100]}...")
                            print(f"Generated: {response_text}")
                            print(f"Correct Answer: {correct_answer}")
                            print(f"Step {step}, R_total: {reward:.3f}, R_binary: {reward_info['r_binary']:.2f}, "
                                  f"P_hacker: {reward_info['p_hacker']:.2f}, P_exploit: {reward_info['p_exploit']:.2f}")
                            print("--- End Sample ---\n")

                        # Compute baseline and advantage
                        baseline_value = self.compute_baseline_value(hidden_state, training_mode=True)

                        # Convert reward to tensor for computation
                        reward_tensor = torch.tensor(reward, device=self.device, dtype=torch.float16, requires_grad=False)
                        advantage = reward_tensor - baseline_value
                        batch_advantages.append(advantage.detach().item())

                        # Store data for batch processing
                        batch_rewards.append(reward)

                        # Store the raw data for later processing
                        item_data = {
                            'log_probs': log_probs,
                            'advantage_tensor': advantage,
                            'baseline_value': baseline_value,
                            'reward': reward_tensor
                        }
                        item_data_list.append(item_data)

                        # Baseline loss
                        if self.use_baseline:
                            baseline_loss = ((baseline_value - reward_tensor) ** 2) / gradient_accumulation_steps

                            # Check for invalid baseline loss
                            if torch.isnan(baseline_loss) or torch.isinf(baseline_loss):
                                print(f"Warning: Invalid baseline_loss detected, skipping sample")
                                continue

                            batch_baseline_losses.append(baseline_loss)

                    except Exception as e:
                        print(f"Error processing sample: {e}")
                        continue

                # Skip if no valid samples
                if not item_data_list:
                    continue

                # Normalize advantages across the batch to reduce variance
                if batch_advantages and len(batch_advantages) > 1:
                    adv_mean = np.mean(batch_advantages)
                    adv_std = np.std(batch_advantages) + 1e-8

                    # Apply normalization and compute policy losses
                    for i, item_data in enumerate(item_data_list):
                        normalized_advantage = (batch_advantages[i] - adv_mean) / adv_std
                        normalized_advantage = max(-2.0, min(2.0, normalized_advantage))

                        # Scale log_probs by sequence length to prevent explosion
                        log_probs = item_data['log_probs']
                        normalized_log_probs = log_probs / max(1.0, len(log_probs))

                        # Policy loss with normalized advantage
                        policy_loss = -torch.sum(normalized_log_probs) * normalized_advantage / gradient_accumulation_steps

                        # Check for invalid policy loss
                        if torch.isnan(policy_loss) or torch.isinf(policy_loss):
                            print(f"Warning: Invalid policy_loss detected, skipping sample")
                            continue

                        batch_policy_losses.append(policy_loss)
                else:
                    # Fallback for single sample or no valid advantages
                    for item_data in item_data_list:
                        advantage = torch.clamp(item_data['advantage_tensor'], min=-2.0, max=2.0)
                        log_probs = item_data['log_probs']
                        normalized_log_probs = log_probs / max(1.0, len(log_probs))
                        policy_loss = -torch.sum(normalized_log_probs) * advantage / gradient_accumulation_steps
                        batch_policy_losses.append(policy_loss)

                if not batch_policy_losses:
                    continue

                # Accumulate gradients
                total_policy_loss = torch.stack(batch_policy_losses).sum()

                if self.use_baseline and batch_baseline_losses:
                    total_baseline_loss = torch.stack(batch_baseline_losses).sum()
                    total_loss = total_policy_loss + total_baseline_loss
                else:
                    total_loss = total_policy_loss

                # Check if loss is reasonable before backward pass
                if torch.isnan(total_loss) or torch.isinf(total_loss) or abs(total_loss.item()) > 100.0:
                    print(f"Warning: Extreme loss detected: {total_loss.item():.3f}, skipping batch")
                    self.policy_optimizer.zero_grad()
                    if self.use_baseline:
                        self.baseline_optimizer.zero_grad()
                    continue

                total_loss.backward()

                epoch_rewards.extend(batch_rewards)
                epoch_losses.append(total_policy_loss.item() * gradient_accumulation_steps)

                # Update weights every gradient_accumulation_steps
                if (step + 1) % gradient_accumulation_steps == 0:
                    # Clip gradients
                    grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                    if self.use_baseline:
                        baseline_grad_norm = torch.nn.utils.clip_grad_norm_(self.baseline_network.parameters(), max_norm=1.0)

                    # Update policy
                    self.policy_optimizer.step()
                    self.policy_optimizer.zero_grad()

                    # Update baseline
                    if self.use_baseline:
                        self.baseline_optimizer.step()
                        self.baseline_optimizer.zero_grad()

                # Memory cleanup
                torch.cuda.empty_cache()

                if step % 10 == 0:
                    avg_reward = np.mean(batch_rewards) if batch_rewards else 0
                    print(f"Step {step}, Avg Reward: {avg_reward:.3f}")

            # Final weight update if needed
            if len(train_dataset) % gradient_accumulation_steps != 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
                if self.use_baseline:
                    torch.nn.utils.clip_grad_norm_(self.baseline_network.parameters(), max_norm=0.5)
                self.policy_optimizer.step()
                self.policy_optimizer.zero_grad()
                if self.use_baseline:
                    self.baseline_optimizer.step()
                    self.baseline_optimizer.zero_grad()

            # Epoch summary
            if epoch_rewards:
                avg_epoch_reward = np.mean(epoch_rewards)
                avg_epoch_loss = np.mean(epoch_losses) if epoch_losses else 0
                print(f"Epoch {epoch + 1} - Avg Reward: {avg_epoch_reward:.3f}, Avg Loss: {avg_epoch_loss:.3f}")
                self.reward_history.extend(epoch_rewards)

    def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str:
        """
        Generate response for evaluation.
        """
        self.model.eval()

        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=400)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        response = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
        return response.strip()

    def evaluate_model(self, test_data_path, max_examples) -> Dict:
        """
        Evaluate trained model on the test set.
        """
        test_dataset = self.data_processor.load_medqa_data(test_data_path)
        
        # Limit to first N examples if specified
        if max_examples is not None:
            test_dataset = test_dataset[:max_examples]
    
        print(f"Evaluating on {max_examples} test examples...")

        # Ensure model is in eval mode
        self.model.eval()
        if self.use_baseline:
            self.baseline_network.eval()

        correct_predictions = 0
        total_examples = len(test_dataset)

        reward_components = {'r_binary': [], 'p_hacker': [], 'p_exploit': []}
        format_violations = 0
        with torch.no_grad():
            for i, test_item in enumerate(test_dataset):
                prompt = test_item['prompt']
                correct_answer = test_item['correct_answer']

                # Generate response
                response = self.generate_response(prompt)

                # Compute rewards
                reward_info = self.reward_function.compute_total_reward(response, correct_answer)

                # Track reward components
                for key in reward_components:
                    reward_components[key].append(reward_info[key])

                # Check if prediction is correct
                predicted_answer = self.reward_function.extract_answer_choice(response)
                if predicted_answer and predicted_answer.upper() == correct_answer.upper():
                    correct_predictions += 1

                # Check if format is correct
                format_ok = self.reward_function.validate_format(response)
                format_violations += int(not format_ok)

                if i % 100 == 0:
                    print(f"Evaluated {i + 1}/{total_examples} examples...")
            format_violation_rate = format_violations / total_examples
            print(f"Format Violation Rate: {format_violation_rate:.2%}")

        accuracy = correct_predictions / total_examples

        # Compute average reward components
        avg_rewards = {key: np.mean(values) for key, values in reward_components.items()}

        print(f"Evaluation complete. Accuracy: {accuracy:.3f}")
        print(f"Average Binary Reward: {avg_rewards['r_binary']:.3f}")
        print(f"Average Hacker Penalty: {avg_rewards['p_hacker']:.3f}")
        print(f"Average Exploit Penalty: {avg_rewards['p_exploit']:.3f}")

        return {
            "accuracy": accuracy,
            "correct_predictions": correct_predictions,
            "total_examples": total_examples,
            "avg_rewards": avg_rewards,
            "reward_history": self.reward_history
        }

In [9]:
def monitor_memory():
    """Print how much CUDA memory is currently allocated and reserved, in GB.

    Does nothing (and prints nothing) when CUDA is not available.
    """
    # Guard clause: nothing to report on CPU-only machines.
    if not torch.cuda.is_available():
        return

    bytes_per_gb = 1024 ** 3  # torch reports bytes; convert to GB for display
    memory_allocated = torch.cuda.memory_allocated() / bytes_per_gb
    memory_reserved = torch.cuda.memory_reserved() / bytes_per_gb
    print(f"GPU Memory - Allocated: {memory_allocated:.2f}GB, Reserved: {memory_reserved:.2f}GB")

In [10]:
def main(
    model_path: str = "Llama3.2-3B-Instruct-SFT",
    train_data_path: str = "medqa_train.json",
    test_data_path: str = "medqa_test.json",
    train_sample_path: str = "medqa_train_sample.json",
    train_sample_size: int = 1000,
    num_epochs: int = 2,
    batch_size: int = 4,
    eval_num_examples: int = 200,
    reward_config: Optional[Dict[str, float]] = None,
):
    """Run the end-to-end REINFORCE experiment: load data, train, evaluate.

    Steps performed, in order:
      1. Report initial GPU memory.
      2. Load the train/test JSON files into DataFrames.
      3. Build a ``MedREINFORCETrainer`` from ``model_path`` and the reward config.
      4. Generate and score a response for the first test question (before training).
      5. Write the first ``train_sample_size`` training rows to ``train_sample_path``
         and train on that file with ``train_with_reinforce``.
      6. Call ``evaluate_model`` on ``test_data_path`` and print the metrics as JSON.

    Calling ``main()`` with no arguments reproduces the original hard-coded run.

    Args:
        model_path: Directory (or identifier) handed to ``MedREINFORCETrainer``.
        train_data_path: JSON file read with ``pd.read_json`` for training rows.
        test_data_path: JSON file read with ``pd.read_json`` for test rows; also
            passed to ``evaluate_model``. Rows must provide ``question``,
            ``options`` and ``answer_idx`` columns.
        train_sample_path: Where the training subset is written (records JSON).
        train_sample_size: Number of leading training rows kept for the subset.
        num_epochs: Forwarded to ``train_with_reinforce``.
        batch_size: Forwarded to ``train_with_reinforce``.
        eval_num_examples: Forwarded as the second positional argument of
            ``evaluate_model`` (presumably a cap on evaluated examples -- confirm
            against ``evaluate_model``'s signature).
        reward_config: Settings forwarded unchanged to ``MedREINFORCETrainer``.
            When ``None`` the notebook's original values are used.

    Returns:
        None. Results are printed.
    """
    # Set the allocator option before anything below touches CUDA. The caching
    # allocator picks this up when it initializes, so it may have no effect if
    # CUDA memory was already allocated earlier in the kernel session.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    print("Initial memory state:")
    monitor_memory()

    # The default medqa_*.json files are expected to be records-oriented JSON
    # exports of the Hugging Face dataset "GBaker/MedQA-USMLE-4-options"
    # (train / test splits), e.g. via
    #   load_dataset(...)[split].to_pandas().to_json(path, orient='records', indent=4)
    train_df = pd.read_json(train_data_path)
    test_df = pd.read_json(test_data_path)

    # Configuration for reward function (defaults match the original notebook run)
    if reward_config is None:
        reward_config = {
            'w_b': 1.0,
            'w_h': 0.5,
            'w_e': 0.3,
            'w_f': 0.3,
            'tau_hacker': 0.7,
            'tau_exploit': 15,
            'lambda_e': 1.0,
        }

    # Initialize trainer
    trainer = MedREINFORCETrainer(
        model_path=model_path,
        reward_config=reward_config,
        use_baseline=True
    )

    # Setup sample tracking
    print("\n" + "=" * 80)
    print("SETTING UP REASONING TRACKING")
    print("=" * 80)

    # Track the first question of the test set before training
    sample_question_data = test_df.iloc[0]
    sample_prompt = trainer.data_processor.format_prompt(
        sample_question_data['question'],
        sample_question_data['options']
    )
    sample_correct_answer = sample_question_data['answer_idx']

    def check_reasoning_on_sample(stage: str):
        """Generate a response for the tracked sample prompt and print its reward breakdown."""
        print("\n" + "-" * 30 + f" {stage} " + "-" * 30)
        print(f"Prompt: {sample_question_data['question']}")
        print(f"Correct Answer: {sample_correct_answer}")

        response = trainer.generate_response(sample_prompt)
        reward_info = trainer.reward_function.compute_total_reward(response, sample_correct_answer)

        print("\n--- Model Generation ---")
        print(response)
        print("------------------------")

        print("\n--- Analysis ---")
        print(f"Predicted Answer: {trainer.reward_function.extract_answer_choice(response)}")
        print(f"Binary Reward (Correctness): {reward_info['r_binary']:.2f}")
        print(f"Hacker Penalty: {reward_info['p_hacker']:.2f}")
        print(f"Exploit Penalty: {reward_info['p_exploit']:.2f}")
        print(f"Normalized Total Reward: {reward_info['r_normalized']:.2f}")
        # Footer width matches the header: 30 + 1 + len(stage) + 1 + 30 dashes/chars.
        print("-" * (62 + len(stage)))

    # Check reasoning BEFORE training
    check_reasoning_on_sample(stage="BEFORE TRAINING")

    # Run Training with REINFORCE
    print("\n" + "=" * 80)
    print("STARTING REINFORCE TRAINING")
    print("=" * 80)

    # Keep only the leading rows as a small training sample and persist it,
    # because train_with_reinforce takes a file path rather than a DataFrame.
    train_subset_df = train_df.iloc[:train_sample_size]
    train_subset_df.to_json(train_sample_path, orient='records', indent=4)

    trainer.train_with_reinforce(
        train_data_path=train_sample_path,
        num_epochs=num_epochs,
        batch_size=batch_size
    )

    print("\n" + "=" * 80)
    print("TRAINING COMPLETE")
    print("=" * 80)

    # Final evaluation on the test file. NOTE(review): the second argument
    # (200 by default) looks like an example cap, so this is likely NOT the
    # full test set -- confirm against evaluate_model.
    print("\nRunning final evaluation on the test set...")
    final_results = trainer.evaluate_model(test_data_path, eval_num_examples)
    print("\nFinal Evaluation Metrics:")
    # reward_history is excluded from the printout. default=float lets json
    # encode NumPy scalar types it does not support natively (e.g. np.float32)
    # instead of raising TypeError at the very end of a long run.
    printable_results = {k: v for k, v in final_results.items() if k != 'reward_history'}
    print(json.dumps(printable_results, indent=4, default=float))


# Entry point: run the whole pipeline when this cell/script is executed
# directly, but not when the code is imported as a module.
if __name__ == "__main__":
    main()

Initial memory state:
GPU Memory - Allocated: 0.08GB, Reserved: 0.10GB
Using device: cuda
Loading tokenizer and model from local path: Llama3.2-3B-Instruct-SFT




Model hidden dimension: 3072

SETTING UP REASONING TRACKING

------------------------------ BEFORE TRAINING ------------------------------
Prompt: A junior orthopaedic surgery resident is completing a carpal tunnel repair with the department chairman as the attending physician. During the case, the resident inadvertently cuts a flexor tendon. The tendon is repaired without complication. The attending tells the resident that the patient will do fine, and there is no need to report this minor complication that will not harm the patient, as he does not want to make the patient worry unnecessarily. He tells the resident to leave this complication out of the operative report. Which of the following is the correct next action for the resident to take?
Correct Answer: B

--- Model Generation ---
<think>
        The resident should not leave out a complication of the case in the operative report. The attending's directive to do so is inappropriate and goes against the principles of transparenc

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.



--- Step 0 Sample ---
Prompt: You are a medical expert taking the USMLE exam. Given the clinical scenario below, respond with your...
Generated: <think>
        The patient presents with symptoms of a urinary tract infection (UTI), which is common in pregnant women. The patient's symptoms, such as burning during urination, have worsened despite increased fluid intake and cranberry extract, suggesting a bacterial infection. The absence of costovertebral angle tenderness and a normal physical exam support this diagnosis. The first-line treatment for uncomplicated UTIs in pregnant women is typically an antibiotic that is safe for the mother and fetus. Nitrofurantoin is often avoided in pregnancy due to concerns about fetal exposure to the drug. Ceftriaxone is typically used for more complicated infections or in patients who are allergic to other antibiotics. Doxycycline is contraindicated in pregnancy due to the risk of inhibiting bone growth and causing tooth discoloration in the fetus.