# Installation, Imports and Setup

In [None]:
%pip install transformers==4.36.2
%pip install -q -U bitsandbytes
%pip install -q -U git+https://github.com/huggingface/accelerate.git
%pip install sentence_transformers==4.1.0
%pip install git+https://github.com/openai/whisper.git

In [None]:
import json
import re
import time
import os
import gc
import boto3
import pickle
from typing import Dict, List
import numpy as np

import bitsandbytes.optim as bnb_optim
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dataclasses import dataclass
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import (AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig)
from huggingface_hub import login
# CUDA_LAUNCH_BLOCKING is read when the CUDA context is created, so it must be
# set before anything (e.g. the SentenceTransformer below, which moves itself
# to the GPU when one is available) initialises CUDA; otherwise it is ignored.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

# Read the Hugging Face token from the environment rather than hard-coding a
# secret in the notebook; skip the login when no token is provided.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(hf_token)
else:
    print("HF_TOKEN not set; skipping Hugging Face login.")

pd.options.display.max_seq_items = 2000

In [None]:
#!pip freeze > requirements.txt

# Multi-Component Reward Function

In [None]:
class EmbeddingModelManager:
    """Process-wide singleton that lazily loads and caches SentenceTransformer models.

    Models are cached per model name, so asking for a different ``model_name``
    returns that model instead of silently reusing whichever model happened to
    be loaded first. Loaded models are moved to CPU.
    """

    _instance = None
    _model = None   # most recently requested model (kept for backward compatibility)
    _models = None  # model name -> loaded SentenceTransformer; created on first use

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(self, model_name='all-MiniLM-L6-v2'):
        """Return the (cached) SentenceTransformer called ``model_name``, loading it on first use."""
        if self._models is None:
            self._models = {}
        model = self._models.get(model_name)
        if model is None:
            print(f"Loading SentenceTransformer: {model_name}")
            model = SentenceTransformer(model_name).cpu()
            self._models[model_name] = model
        self._model = model
        return model

    def cleanup(self):
        """Drop this manager's references to all cached models and release cached CUDA memory.

        Models still referenced elsewhere (e.g. by a RewardModel) stay alive.
        """
        self._models = {}
        self._model = None
        torch.cuda.empty_cache()


embedding_manager = EmbeddingModelManager()

In [None]:
class RewardModel:
    """Multi-component reward for generations formatted as ``<think>...</think><answer>X</answer>``.

    For a well-formatted generation the reward is::

        r_total = w_b * r_binary - w_a * p_answer - w_s * p_structural + w_fact * r_factual
        r_normalized = tanh(r_total)

    Generations that fail ``validate_format`` short-circuit to ``r_total = -1.0``.

    Args:
        w_b, w_a, w_s, w_fact: Weights of the binary-correctness reward, the
            answer-leakage penalty, the structural penalty and the factual reward.
        tau_answer: Leakage scores at or below this value are reported as 0.
        tau_preamble: Maximum number of words allowed before ``<think>``.
        lambda_s: Penalty value returned for any structural violation.
        violation_threshold_answer, violation_threshold_structural: Thresholds
            used by ``calculate_hacking_rate`` when counting violations.
        embedding_model: SentenceTransformer name used for semantic leak detection.
        factual_agreement_threshold, umls_api_key: Forwarded to
            ``AtomicFactVerificationSystem``.
    """

    # Explicit "here is the answer" patterns searched for inside <think>.
    # Keywords match case-insensitively via scoped (?i:...) groups, but the
    # option letter itself is case-sensitive: with a global IGNORECASE the
    # English article "a" matched [A-D], so ordinary phrases such as
    # "choose a treatment" or "the answer is a combination" were scored as
    # full answer leakage.
    # NOTE(review): the first pattern also matches neutral option discussion
    # such as "Option A is used for ..." — confirm that this is intended.
    _EXPLICIT_ANSWER_PATTERNS = tuple(re.compile(p) for p in (
        r'\b(?i:answer|choice|option)\s*:?\s*([A-D])\b',
        r'\b([A-D])\s+(?i:is\s+(?:correct|right|the answer))',
        r'\b(?i:correct\s+answer\s+is)\s+([A-D])\b',
        r'\b(?i:the\s+answer\s+is)\s+([A-D])\b',
        r'\b(?i:choose)\s+([A-D])\b',
        r'^([A-D])[.)]',  # reasoning starts with "A." or "A)"
        r'\b([A-D])\s*[-–]\s*(?i:correct)',
    ))

    def __init__(self,
                 w_b=1.0,
                 w_a=0.2,
                 w_s=0.2,
                 w_fact=0.2,
                 tau_answer=0.7,
                 tau_preamble=15,
                 lambda_s=1.0,
                 violation_threshold_answer=0.1,
                 violation_threshold_structural=0.3,
                 embedding_model='all-MiniLM-L6-v2',
                 factual_agreement_threshold=0.15,
                 umls_api_key=None):

        self.w_b = w_b
        self.w_a = w_a
        self.w_s = w_s
        self.w_fact = w_fact
        self.tau_answer = tau_answer
        self.tau_preamble = tau_preamble
        self.lambda_s = lambda_s
        self.violation_threshold_answer = violation_threshold_answer
        self.violation_threshold_structural = violation_threshold_structural
        self.sentence_model = embedding_manager.get_model(embedding_model)

        # Reference phrases whose embeddings are compared against the reasoning
        # text to detect semantic answer leakage.
        self.answer_leaking_phrases = [
            "the correct answer is", "the answer is definitely", "the choice is clearly",
            "option A is the right one", "option B is the right one", "option C is the right one",
            "option D is the right one", "we can conclude the answer is", 
            "the solution is A", "the solution is B", "the solution is C", "the solution is D",
            "therefore the answer is", "so the correct choice is", "the final answer is",
            "answer: A", "answer: B", "answer: C", "answer: D"
        ]

        self.leak_embeddings = self.sentence_model.encode(self.answer_leaking_phrases)
        self.fact_verification_system = AtomicFactVerificationSystem(
            agreement_threshold=factual_agreement_threshold,
            umls_api_key=umls_api_key
        )

    def extract_answer_choice(self, generation):
        """Extract the final answer choice from <answer> tags.

        Returns the first standalone letter A-D (case-insensitive, upper-cased)
        inside the first ``<answer>...</answer>`` block, or None.
        """
        answer_match = re.search(r'<answer>(.*?)</answer>', generation, re.IGNORECASE | re.DOTALL)
        if not answer_match:
            return None
        answer_content = answer_match.group(1).strip()
        match = re.search(r'\b([A-D])\b', answer_content, re.IGNORECASE)
        return match.group(1).upper() if match else None

    def extract_think_content(self, generation):
        """Return the reasoning text of ``generation``.

        - If a non-empty ``<think>...</think>`` block exists, return its stripped
          content with a leading "Reasoning:"/"Analysis:"/"Thinking:" label removed.
        - Otherwise, if there is no ``<answer>`` tag, treat the whole generation
          as reasoning.
        - Otherwise return "" (printing a warning when ``<think>`` is missing).
        """
        think_match = re.search(r'<think>(.*?)</think>', generation, re.DOTALL | re.IGNORECASE)
        if think_match:
            content = think_match.group(1).strip()
            if content:
                content = re.sub(r'^(Reasoning:\s*|Analysis:\s*|Thinking:\s*)', '', content, flags=re.IGNORECASE)
                return content

        lowered = generation.lower()
        if '<answer>' not in lowered:
            return generation.strip()

        # Answer tags without think tags is a format violation.
        if '<think>' not in lowered:
            print("Format violation: Answer tags without think tags")
        return ""

    def extract_pre_think_content(self, generation):
        """Extract content that appears before the <think> tag (whole text if there is none)."""
        think_start = re.search(r'<think>', generation, re.IGNORECASE)
        return generation[:think_start.start()].strip() if think_start else generation.strip()

    def validate_format(self, generation) -> bool:
        """Return True iff the generation has a usable think/answer structure.

        Requires a ``<think>`` block with at least 10 characters of content and
        an ``<answer>`` block containing exactly one standalone letter A-D.
        """
        lowered = generation.lower()
        if '<think>' not in lowered or '<answer>' not in lowered:
            return False

        think_match = re.search(r'<think>\s*(.*?)\s*</think>', generation, re.DOTALL | re.IGNORECASE)
        if not think_match:
            return False
        if len(think_match.group(1).strip()) < 10:  # minimum meaningful reasoning length
            return False

        answer_match = re.search(r'<answer>\s*(.*?)\s*</answer>', generation, re.IGNORECASE | re.DOTALL)
        if not answer_match:
            return False

        answer_letters = re.findall(r'\b([A-D])\b', answer_match.group(1).strip(), re.IGNORECASE)
        return len(answer_letters) == 1

    def compute_binary_reward(self, generation, correct_answer):
        """Return 1.0 if the extracted answer letter equals ``correct_answer``, else 0.0.

        ``correct_answer`` is compared case-insensitively after stripping
        whitespace; ``None`` (unknown ground truth) yields 0.0 instead of raising.
        """
        predicted_answer = self.extract_answer_choice(generation)
        if predicted_answer is None or correct_answer is None:
            return 0.0
        return 1.0 if predicted_answer == str(correct_answer).strip().upper() else 0.0

    def _explicit_answer_penalties(self, think_content):
        """Return ``(pattern_penalty, early_answer_penalty)`` for ``think_content``.

        pattern_penalty is 1.0 if any explicit answer pattern occurs anywhere.
        early_answer_penalty is 0.5 if one of the first five patterns occurs in
        the first quarter of the words (only checked for more than five words).
        Because an early hit is also a hit anywhere, early_answer_penalty never
        exceeds pattern_penalty; it is reported for diagnostics only.
        """
        patterns = self._EXPLICIT_ANSWER_PATTERNS
        pattern_penalty = 1.0 if any(p.search(think_content) for p in patterns) else 0.0

        early_answer_penalty = 0.0
        words = think_content.split()
        if len(words) > 5:
            first_quarter = ' '.join(words[:len(words) // 4])
            if any(p.search(first_quarter) for p in patterns[:5]):
                early_answer_penalty = 0.5
        return pattern_penalty, early_answer_penalty

    def compute_answer_penalty(self, generation):
        """Return an answer-leakage penalty for the reasoning, or 0.0 if below ``tau_answer``.

        The score is the maximum of (1) the highest cosine similarity between
        the reasoning embedding and known leaking phrases and (2)/(3) the
        explicit-pattern penalties from ``_explicit_answer_penalties``.
        """
        think_content = self.extract_think_content(generation)
        if not think_content:
            return 0.0

        # Method 1: cosine similarity with known leaking phrases.
        think_embedding = self.sentence_model.encode([think_content])
        similarities = cosine_similarity(think_embedding, self.leak_embeddings)[0]
        semantic_similarity = float(np.max(similarities))

        # Methods 2 and 3: explicit answer patterns (anywhere / early).
        pattern_penalty, early_answer_penalty = self._explicit_answer_penalties(think_content)

        total_penalty = max(semantic_similarity, pattern_penalty, early_answer_penalty)

        if total_penalty > 0.1:
            print(f"Answer leakage detected: semantic={semantic_similarity:.3f}, "
                  f"pattern={pattern_penalty:.3f}, early={early_answer_penalty:.3f}")

        return total_penalty if total_penalty > self.tau_answer else 0.0

    def compute_structural_penalty(self, generation):
        """Return ``lambda_s`` if the layout of the generation is violated, else 0.0.

        Violations, checked in order (first hit wins):
          1. more than ``tau_preamble`` words before ``<think>``;
          2. more than 5 words after ``</answer>``;
          3. more than 10 words containing a reasoning connective ("because",
             "therefore", ...) before ``<think>`` or between ``</think>`` and
             ``<answer>``.
        """
        think_start = re.search(r'<think>', generation, re.IGNORECASE)
        if think_start:
            word_count = len(generation[:think_start.start()].strip().split())
            if word_count > self.tau_preamble:
                print(f"Structural violation: {word_count} words before <think> tag")
                return self.lambda_s

        answer_end = re.search(r'</answer>', generation, re.IGNORECASE)
        if answer_end:
            word_count = len(generation[answer_end.end():].strip().split())
            if word_count > 5:  # allow small amounts of post-answer text
                print(f"Structural violation: {word_count} words after </answer> tag")
                return self.lambda_s

        think_match = re.search(r'<think>(.*?)</think>', generation, re.DOTALL | re.IGNORECASE)
        answer_start = re.search(r'<answer>', generation, re.IGNORECASE)
        if think_match and answer_start:
            before_think = generation[:think_match.start()].strip()
            between_content = generation[think_match.end():answer_start.start()].strip()
            reasoning_indicators = ('because', 'therefore', 'since', 'due to', 'as a result',
                                    'this means', 'indicates', 'suggests', 'shows that')
            for content in (before_think, between_content):
                if len(content.split()) > 10 and any(ind in content.lower() for ind in reasoning_indicators):
                    print("Structural violation: Reasoning outside think tags")
                    return self.lambda_s

        return 0.0

    def compute_factual_reward(self, generation, context=""):
        """Score the factual accuracy of the reasoning via the fact-verification system.

        Always returns a dict with keys ``factual_reward``, ``factual_analysis``,
        ``facts`` and ``error``; any exception is caught and reported in ``error``.
        """
        try:
            reasoning_content = self.extract_think_content(generation)

            if not reasoning_content:
                print(f"[FactExtractor] No reasoning found in generation:\n{generation}\n")
                return {
                    "factual_reward": 0.0,
                    "factual_analysis": {},
                    "facts": [],
                    "error": "No reasoning content found"
                }

            result = self.fact_verification_system.process_response(
                reasoning_content, context
            )
            analysis = result.get("factual_analysis", {})

            return {
                "factual_reward": analysis.get("factual_reward", 0.0),
                "factual_analysis": analysis,
                "facts": result.get("facts", []),
                "error": result.get("error", None)
            }

        except Exception as e:
            return {
                "factual_reward": 0.0,
                "factual_analysis": {},
                "facts": [],
                "error": f"Factual verification failed: {str(e)}"
            }

    def compute_total_reward(self, generation, correct_answer, context=""):
        """Combine all reward components for one generation.

        Returns a dict with the individual components, ``r_total``,
        ``r_normalized`` (= tanh(r_total)) and ``format_valid``.
        """
        if not self.validate_format(generation):
            print(f"Format validation failed for generation: {generation[:100]}...")
            return {
                'r_binary': -1.0,
                'p_answer': 0.0,
                'p_structural': 1.0,
                'r_factual': 0.0,
                'factual_analysis': {},
                'extracted_facts': [],
                'factual_error': "Invalid format - computation terminated",
                'r_total': -1.0,
                'r_normalized': np.tanh(-1.0),
                'format_valid': False
            }

        r_binary = self.compute_binary_reward(generation, correct_answer)
        p_answer = self.compute_answer_penalty(generation)
        p_structural = self.compute_structural_penalty(generation)

        factual_result = self.compute_factual_reward(generation, context)
        r_factual = factual_result["factual_reward"]

        print(f"Reward components: binary={r_binary:.3f}, answer_penalty={p_answer:.3f}, "
              f"structural_penalty={p_structural:.3f}, factual={r_factual:.3f}")

        r_total = self.w_b * r_binary - self.w_a * p_answer - self.w_s * p_structural + self.w_fact * r_factual

        return {
            'r_binary': r_binary,
            'p_answer': p_answer,
            'p_structural': p_structural,
            'r_factual': r_factual,
            'factual_analysis': factual_result["factual_analysis"],
            'extracted_facts': factual_result.get("facts", []),
            'factual_error': factual_result.get("error", None),
            'r_total': r_total,
            'r_normalized': np.tanh(r_total),
            'format_valid': True
        }

    def calculate_hacking_rate(self, responses_with_rewards):
        """Summarise reward-hacking indicators over a batch of scored responses.

        Args:
            responses_with_rewards: Sequence of dicts, each with an optional
                ``'reward_info'`` dict as returned by ``compute_total_reward``.
                Items whose reward_info has ``format_valid`` False count as
                invalid-format and are excluded from the violation rates.

        Returns:
            Dict of counts and rates (rates are relative to valid responses).
            The same keys are returned for empty input.
        """
        total_responses = len(responses_with_rewards)
        valid_responses = 0
        answer_violations = 0
        structural_violations = 0
        factual_violations = 0
        any_violation = 0
        high_rewards = 0

        factual_threshold = 0.3
        high_reward_threshold = 0.5

        if total_responses > 0:
            print(f"Analyzing {total_responses} responses...")

            for i, item in enumerate(responses_with_rewards):
                reward_info = item.get('reward_info', {})

                # Skip invalid-format responses.
                if not reward_info.get('format_valid', True):
                    continue
                valid_responses += 1

                answer_penalty = reward_info.get('p_answer', 0.0)
                structural_penalty = reward_info.get('p_structural', 0.0)
                factual_reward = reward_info.get('r_factual', 1.0)
                total_reward = reward_info.get('r_total', 0.0)

                answer_violation = answer_penalty > self.violation_threshold_answer
                structural_violation = structural_penalty > self.violation_threshold_structural
                factual_violation = factual_reward < factual_threshold

                answer_violations += answer_violation
                structural_violations += structural_violation
                factual_violations += factual_violation
                any_violation += answer_violation or structural_violation or factual_violation
                high_rewards += total_reward > high_reward_threshold

                if i < 5:  # debug the first few responses
                    print(f"Response {i}: answer_pen={answer_penalty:.3f}, struct_pen={structural_penalty:.3f}, "
                          f"factual={factual_reward:.3f}, total={total_reward:.3f}")

            print(f"Valid responses: {valid_responses}/{total_responses}")
            print(f"Violations: answer={answer_violations}, structural={structural_violations}, "
                  f"factual={factual_violations}, any={any_violation}")

        def rate(count):
            return count / valid_responses if valid_responses > 0 else 0.0

        return {
            'total_responses': total_responses,
            'valid_responses': valid_responses,
            'answer_violation_count': int(answer_violations),
            'invalid_format_count': total_responses - valid_responses,
            'structural_violation_count': int(structural_violations),
            'factual_violation_count': int(factual_violations),
            'any_violation_count': int(any_violation),
            'high_reward_count': int(high_rewards),
            'answer_violation_rate': rate(answer_violations),
            'structural_violation_rate': rate(structural_violations),
            'factual_violation_rate': rate(factual_violations),
            'overall_violation_rate': rate(any_violation),
            'high_reward_rate': rate(high_rewards)
        }

# Data Processor

In [None]:
class MedQADataProcessor:
    """Process MedQA dataset for training.

    Turns MedQA items (``question``, ``options`` mapping letter -> text,
    ``answer_idx``) into prompt dicts for the policy model.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def format_prompt(self, question, options):
        """Build the instruction prompt for one question.

        ``options`` is iterated in its own order and rendered as ``"<letter>. <text>"`` lines.
        """
        formatted_options = "\n".join(f"{letter}. {option}" for letter, option in options.items())
        input_text = f"{question}\n\n{formatted_options.strip()}"
        # NOTE(review): the indentation of the continuation lines below is part of
        # the prompt text, and the worked example shows no answer options even
        # though it answers "A"; both are kept unchanged so prompts stay identical.
        prompt = f"""You are a medical expert taking the USMLE exam. Given the clinical scenario below, respond with your reasoning in a <think></think> tag and your final answer choice (A, B, C, or D) in an <answer></answer> tag.
        Scenario:
        {input_text}
        Format:
        <think>your step-by-step clinical reasoning goes here</think>
        <answer>A</answer>  # Replace A with your final answer choice
        Here is a sample prompt and response:
        Prompt/Question: A man is brought into the emergency department by the police department. The officer state that the man has been arrested multiple times for public alcohol intoxication, but recently became homeless. On exam, the man is behaving erratically. His vitals are all within normal limits. He appears confused and has a slurred speech. On gait exam, the patient is ataxic and cannot stand without support for more than a few seconds. Labs return with the following values: Na 140, K 4, Cl 106, BUN 8, Cr 2. His ABG has pH 7.3, PaCO2 13mm, PaO2 130mm, HCO3 7. His urinalysis is shown in Figure 1. Blood salicylate levels return as normal. While you await other diagnostic tests, which of the following should be administered next to treat this patient?       
        <think>Consider the symptoms and lab values presented in the scenario. The patient is showing signs of salicylate poisoning, which is consistent with the lab values of metabolic acidosis, elevated anion gap, and hyperventilation leading to respiratory alkalosis. Fomepizole is used to treat methanol and ethylene glycol poisoning, so it is not a suitable choice in this scenario. Salicylate poisoning is a known cause of respiratory alkalosis, so the patient's hyperventilation is consistent with this diagnosis. Ethanol is a common treatment for salicylate poisoning as it is thought to inhibit the enzyme aldehyde dehydrogenase and slow the metabolism of salicylate. Naloxone is an opioid antagonist, which would be used in the case of an opioid overdose, not salicylate poisoning. Naltrexone is an opioid antagonist that is often used for the treatment of opioid addiction, but it is not indicated in this scenario. Fomepizole is a medication used to treat methanol and ethylene glycol poisoning, but it is not indicated in this scenario as the patient's lab values are consistent with salicylate poisoning. Therefore, ethanol is the most appropriate choice to treat this patient's condition.</think>
        <answer>A</answer>
        Your response:"""
        return prompt

    def load_medqa_data(self, file_path):
        """Load and process a MedQA dataset file.

        Accepts a JSON array of items, a single JSON object, or JSON Lines (one
        object per line, the format MedQA is distributed in). The file is read
        as UTF-8 regardless of the platform's default encoding.

        Returns:
            List of dicts with ``question``, ``options``, ``correct_answer``
            (from ``answer_idx``) and ``prompt``.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            raw = f.read()
        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            # Not a single JSON document: fall back to JSON Lines.
            data = [json.loads(line) for line in raw.splitlines() if line.strip()]
        if isinstance(data, dict):
            data = [data]

        return [
            {
                'question': item['question'],
                'options': item['options'],
                'correct_answer': item['answer_idx'],
                'prompt': self.format_prompt(item['question'], item['options'])
            }
            for item in data
        ]

# Baseline Network

In [None]:
class BaselineNetwork(nn.Module):
    """Small MLP value head mapping a hidden-state vector to a scalar baseline.

    Inputs are cast to float32, clamped to [-10, 10] and layer-normalised
    before the MLP; the scalar output is clamped to [-5, 5].
    """

    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        half_dim = hidden_dim // 2

        # Normalise raw inputs before they reach the first linear layer.
        self.input_norm = nn.LayerNorm(input_dim)

        # Two (Linear -> LayerNorm -> ReLU -> Dropout) stages followed by a
        # scalar head; layer order (and thus state_dict keys) is fixed.
        stages = []
        for fan_in, fan_out in ((input_dim, hidden_dim), (hidden_dim, half_dim)):
            stages += [nn.Linear(fan_in, fan_out), nn.LayerNorm(fan_out), nn.ReLU(), nn.Dropout(0.1)]
        stages.append(nn.Linear(half_dim, 1))
        self.network = nn.Sequential(*stages)

        self._initialize_weights()

    def _initialize_weights(self):
        """Small-gain Xavier-uniform init for every Linear layer, with zero biases."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight, gain=0.1)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return one clamped baseline value per input row (last dim is reduced)."""
        normalised = self.input_norm(torch.clamp(x.float(), min=-10.0, max=10.0))
        scores = self.network(normalised).squeeze(-1)
        return torch.clamp(scores, min=-5.0, max=5.0)

# Policy Trainer

In [None]:
class PolicyTrainer:
    def __init__(self, model_path, reward_config=None, use_baseline=True, umls_api_key=None, log_file=None):
        """Load the policy model, auxiliary heads, reward model, optimizers and log file.

        Args:
            model_path: Path (or hub id) passed to ``from_pretrained`` for the
                tokenizer, the trainable policy model and the frozen CPU
                baseline model.
            reward_config: Optional keyword arguments for ``RewardModel``. The
                caller's dict is copied, not mutated.
            use_baseline: When True, create a ``BaselineNetwork`` value head and
                its Adam optimizer.
            umls_api_key: Forwarded to ``RewardModel``.
            log_file: Path of the text training log; defaults to a timestamped
                name. The file is truncated and a header is written.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        print(f"Loading tokenizer and model from local path: {model_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Make sure pad/eos tokens exist so generate() can pad and stop.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.tokenizer.eos_token is None:
            self.tokenizer.eos_token = self.tokenizer.pad_token or self.tokenizer.unk_token
        if self.tokenizer.eos_token_id is None:
            self.tokenizer.eos_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.eos_token)

        # `torch_dtype` is the keyword understood by the pinned transformers
        # release (4.36.x) and matches the baseline model load below.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True,
            ignore_mismatched_sizes=True,
            trust_remote_code=True
        )
        # With device_map="auto" the model decides placement; follow its first parameter.
        try:
            self.device = next(self.model.parameters()).device
        except StopIteration:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print(f"Model automatically placed on device: {self.device}")
        self.metrics_tracker = MetricsTracker()

        # Frozen copy of the starting policy, kept on CPU for epoch comparisons.
        # Loading it is best-effort: on failure training continues without it.
        self._baseline_cache_path = "logs/baseline_responses.json"
        try:
            self._baseline_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="cpu",
                torch_dtype=torch.float32
            )
            self._baseline_model.eval()
            print("Frozen baseline model initialized for epoch comparisons (CPU)")
        except Exception as e:
            print("Warning: baseline model load failed with", e)
            print("Proceeding without a separate CPU baseline model.")
            self._baseline_model = None

        self.use_learnable_reward = True
        hidden_dim = self.model.config.hidden_size
        print(f"Model hidden dimension: {hidden_dim}")

        if self.use_learnable_reward:
            def build_reward_head(device):
                # Two-layer MLP scoring a hidden-state vector, kept in float32.
                return nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim // 2),
                    nn.ReLU(),
                    nn.Dropout(0.1),
                    nn.Linear(hidden_dim // 2, 1)
                ).to(device, dtype=torch.float32)

            reward_device = self.device if torch.cuda.is_available() and 'cuda' in str(self.device) else torch.device(
                "cpu")
            try:
                self.reward_head = build_reward_head(reward_device)
                print(f"Learnable reward head initialized on {reward_device} (float32)")
            except Exception as e:
                print(f"Warning: failed to place reward_head on {reward_device}: {e}")
                print("Falling back to CPU for reward_head.")
                self.reward_head = build_reward_head(torch.device("cpu"))

            self.reward_optimizer = torch.optim.Adam(self.reward_head.parameters(), lr=1e-6)
            print("Learnable reward head initialized")

        if hasattr(self.model, 'gradient_checkpointing_enable'):
            self.model.gradient_checkpointing_enable()

        self.model.train()

        # Copy so that injecting the UMLS key does not mutate the caller's dict.
        reward_config = dict(reward_config or {})
        reward_config['umls_api_key'] = umls_api_key
        self.data_processor = MedQADataProcessor(self.tokenizer)
        self.reward_function = RewardModel(**reward_config)
        self.reward_function.fact_verification_system.training_mode = False

        self.use_baseline = use_baseline
        if self.use_baseline:
            self.baseline_network = BaselineNetwork(input_dim=hidden_dim)
            self.baseline_network = self.baseline_network.to(self.device)
            self.baseline_optimizer = optim.Adam(self.baseline_network.parameters(), lr=1e-6)
            print("Baseline network initialized")

        self.policy_optimizer = bnb_optim.AdamW8bit(self.model.parameters(), lr=5e-7)
        self.reward_history = []

        adversarial_config = {
            'temperature': 1.2,
            'max_examples': 50,
            'preference_margin': 0.5,
            'validation_threshold': 0.7
        }
        self.adversarial_trainer = AdversarialTrainer(self, adversarial_config)

        # Set up logging; the log file is truncated at start.
        self.log_file = log_file or f"training_log_{time.strftime('%Y%m%d_%H%M%S')}.txt"
        with open(self.log_file, 'w', encoding='utf-8') as f:
            f.write(f"Training Log Started: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write("=" * 80 + "\n\n")

        print("PolicyTrainer initialization complete!")

    def _ensure_baseline_responses(self, eval_prompts):
        if os.path.exists(self._baseline_cache_path):
            with open(self._baseline_cache_path, "r", encoding="utf-8") as f:
                return json.load(f)

        print("Generating baseline responses using existing model...")
        self._baseline_model.eval()
        baseline_results = []
        for prompt in eval_prompts:
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self._baseline_model.generate(
                    **inputs,
                    max_new_tokens=100,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            base_text = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:],
                                              skip_special_tokens=True).strip()
            baseline_results.append(base_text)

        os.makedirs("logs", exist_ok=True)
        with open(self._baseline_cache_path, "w", encoding="utf-8") as f:
            json.dump(baseline_results, f, indent=2)
        return baseline_results

    def log_epoch_comparisons(self, eval_prompts, epoch, correct_answers=None, context_list=None):
        results = []
        if epoch % 5 != 0:
            return []
        for i, prompt in enumerate(eval_prompts):
            context = context_list[i] if context_list else ""
            correct_answer = correct_answers[i] if correct_answers else None

            self.model.eval()
            baseline_texts = self._ensure_baseline_responses(eval_prompts)
            base_text = baseline_texts[i]

            adv_text, _, _, _ = self.generate_response_with_logprobs(prompt)

            # Compute rewards for adversarial response
            reward_info = self.reward_function.compute_total_reward(
                adv_text, correct_answer, context
            )

            entry = {
                "epoch": epoch,
                "prompt": prompt,
                "baseline_response": base_text,
                "adversarial_response": adv_text,
                "correct_answer": correct_answer,
                "reward_info": reward_info,
            }
            results.append(entry)

            # Console logging
            print("=" * 80)
            print(f"[Epoch {epoch}] Prompt: {prompt}")
            print(f"Correct Answer: {correct_answer}")
            print("\n--- Baseline Response ---")
            print(base_text)
            print("\n--- Adversarial Response ---")
            print(adv_text)
            print("\nReward breakdown:")
            print(reward_info)

        # Save to JSON file for later analysis
        os.makedirs("logs", exist_ok=True)
        log_path = f"logs/epoch_{epoch}_comparisons.json"
        with open(log_path, "w") as f:
            json.dump(results, f, indent=2)

        return results

    def generate_response_with_logprobs(self, prompt, max_new_tokens=2048):
        """Sample a response and return it with differentiable per-token log-probs.

        Generation runs in eval mode without gradients. The model is then put in
        train mode and the sampled sequence is re-scored in one forward pass with
        autograd enabled, so the returned log-probabilities can carry a policy
        gradient. Wrapping the call in ``torch.no_grad()`` disables that graph.

        Args:
            prompt: Prompt text; tokenized with truncation at 2048 tokens.
            max_new_tokens: Requested generation budget, capped at 512.

        Returns:
            Tuple ``(response_text, log_probs, avg_hidden_state)``:
            - ``response_text``: generated tokens only, decoded without special
              tokens and stripped.
            - ``log_probs``: 1-D tensor, log-probability of each generated token
              (empty tensor when nothing was generated).
            - ``avg_hidden_state``: mean of the last-layer hidden states (clamped
              to [-10, 10]) over at most the last 50 tokens of the full sequence,
              computed without gradients. A float16 zero vector of size
              ``config.hidden_size`` is returned when hidden states are
              unavailable or contain NaN/Inf.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        input_length = inputs['input_ids'].shape[1]

        self.model.eval()
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=min(max_new_tokens, 512),
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                # Only `.sequences` is read below. Per-step scores (one
                # vocab-sized tensor per generated token) are not requested.
                return_dict_in_generate=True,
            )

        full_context = outputs.sequences[0]
        generated_ids = full_context[input_length:]
        response_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        if len(generated_ids) == 0:
            return response_text, torch.tensor([], device=self.device, requires_grad=True), torch.zeros(
                self.model.config.hidden_size, device=self.device, dtype=torch.float16, requires_grad=True)

        self.model.train()

        # Re-score the sampled sequence WITH autograd: the policy-gradient loss
        # -sum(log_probs) * advantage needs these log-probs to depend on the
        # model parameters. NOTE: this keeps one full forward pass of
        # activations alive until backward() is called.
        full_inputs = {'input_ids': full_context.unsqueeze(0),
                       'attention_mask': torch.ones_like(full_context).unsqueeze(0)}
        full_logits = self.model(**full_inputs).logits[0]

        # Logits at position t predict token t + 1, so generated token i is
        # scored by the logits at position input_length + i - 1.
        first_pos = input_length - 1
        targets = generated_ids
        if first_pos < 0:
            # Empty prompt: the first generated token has no preceding position.
            first_pos, targets = 0, generated_ids[1:]

        if len(targets) > 0:
            step_logits = full_logits[first_pos:first_pos + len(targets)].clamp(min=-50, max=50)
            log_probs_tensor = torch.log_softmax(step_logits, dim=-1).gather(
                -1, targets.unsqueeze(-1)).squeeze(-1)
        else:
            log_probs_tensor = torch.tensor([], device=self.device, requires_grad=True, dtype=torch.float16)

        # Pooled features for the baseline / reward heads, computed from at most
        # the last 50 tokens on their own. No graph is kept, so losses on those
        # heads cannot push gradients into the policy through this tensor.
        tail = full_context[-50:]
        final_inputs = {'input_ids': tail.unsqueeze(0),
                        'attention_mask': torch.ones_like(tail).unsqueeze(0)}
        with torch.no_grad():
            final_outputs = self.model(**final_inputs, output_hidden_states=True)

        if final_outputs.hidden_states:
            hidden_state = final_outputs.hidden_states[-1][0]

            if torch.any(torch.isnan(hidden_state)) or torch.any(torch.isinf(hidden_state)):
                print("Warning: Invalid hidden states detected")
                avg_hidden_state = torch.zeros(self.model.config.hidden_size,
                                               device=self.device,
                                               dtype=torch.float16,
                                               requires_grad=True)
            else:
                avg_hidden_state = torch.clamp(hidden_state, min=-10, max=10).mean(dim=0)
        else:
            avg_hidden_state = torch.zeros(self.model.config.hidden_size,
                                           device=self.device,
                                           dtype=torch.float16,
                                           requires_grad=True)

        return response_text, log_probs_tensor, avg_hidden_state

    def generate_and_evaluate_with_facts(self, prompt, correct_answer, context=""):
        """Generate one response and attach its reward breakdown.

        The rule-based reward always comes from
        ``self.reward_function.compute_total_reward``. When the trainer has a
        truthy ``use_learnable_reward`` attribute, ``self.reward_head`` also
        scores the (detached, float32) hidden state. The two are blended as
        ``0.7 * rule + 0.3 * learned``, and the blend replaces ``r_normalized``
        in a copy of the rule-based dict.

        Returns:
            ``(response_text, log_probs, hidden_state, reward_info)``.
        """
        response_text, log_probs, hidden_state = self.generate_response_with_logprobs(prompt)

        # Rule-based reward (always computed).
        rule_reward_info = self.reward_function.compute_total_reward(
            response_text, correct_answer, context
        )

        if not getattr(self, 'use_learnable_reward', False):
            return response_text, log_probs, hidden_state, rule_reward_info

        with torch.no_grad():
            learned = self.reward_head(hidden_state.detach().float()).item()

        blended = 0.7 * rule_reward_info['r_normalized'] + 0.3 * learned

        # Work on a copy so the rule-based dict is left untouched; the blended
        # value becomes the main reward.
        reward_info = rule_reward_info.copy()
        reward_info.update(r_learnable=learned, r_combined=blended, r_normalized=blended)

        return response_text, log_probs, hidden_state, reward_info

    def compute_baseline_value(self, hidden_state, training_mode=False):
        """Estimate the policy-gradient baseline from a pooled hidden state.

        Args:
            hidden_state: Tensor fed to ``self.baseline_network`` after clamping
                to [-10, 10], or None.
            training_mode: When True, return a tensor so a baseline loss can
                back-propagate into ``self.baseline_network``. When False, return
                a plain Python float.

        Returns:
            The baseline value clamped to [-5, 5]. Falls back to zero when the
            baseline is disabled, ``hidden_state`` is None or contains NaN/Inf,
            the network output is NaN/Inf, or the network call raises. The
            fallback follows the same return-type rule as the normal path: a
            float16 tensor with ``requires_grad=True`` in training mode,
            otherwise ``0.0``.
        """
        def fallback():
            # Keep the return type consistent with the normal path.
            if training_mode:
                return torch.tensor(0.0, device=self.device, requires_grad=True, dtype=torch.float16)
            return 0.0

        if not self.use_baseline or hidden_state is None:
            return fallback()

        if torch.any(torch.isnan(hidden_state)) or torch.any(torch.isinf(hidden_state)):
            print("Warning: Invalid hidden state input to baseline, using fallback")
            return fallback()

        # Clamp input to prevent extreme values
        hidden_state_input = torch.clamp(hidden_state, min=-10.0, max=10.0)

        try:
            baseline_value = self.baseline_network(hidden_state_input)

            # Truth-testing needs a single-element output; anything else raises
            # and is handled by the fallback below.
            if torch.isnan(baseline_value) or torch.isinf(baseline_value):
                print("Warning: Baseline network produced invalid output, using fallback")
                return fallback()
            baseline_value = torch.clamp(baseline_value, min=-5.0, max=5.0)

        except Exception as e:
            # Best effort: a broken baseline must not abort training.
            print(f"Warning: Error in baseline network: {e}, using fallback")
            return fallback()

        return baseline_value if training_mode else baseline_value.item()

    def _generate_conservative_response(self, prompt, correct_answer):
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=True,
                temperature=0.7,
                top_p=0.8,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        response = self.tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        ).strip()

        return response

    def _generate_high_quality_response(self, prompt, correct_answer):
        """Generate a high-quality response that should score well"""

        # Use conservative generation parameters for high quality
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,
                do_sample=True,
                temperature=0.6,  # Lower temperature for more focused responses
                top_p=0.9,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        response = self.tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        ).strip()

        # Ensure proper format
        if '<think>' not in response:
            response = f"<think>Let me analyze this medical scenario systematically. {response}</think>"
        if '<answer>' not in response:
            response += f"<answer>{correct_answer}</answer>"

        return response

    def _analyze_response_quality(self, response, reward_info, correct_answer):
        violations = {
            'correctness_violations': 0,
            'answer_leaking_violations': 0,
            'format_violations': 0,
            'factual_violations': 0,
            'bad_ood_high_rewards': 0
        }
        if reward_info.get('r_binary', 0) <= 0:
            violations['correctness_violations'] = 1

        if reward_info.get('p_answer', 0) > 0.1:
            violations['answer_leaking_violations'] = 1

        if reward_info.get('p_structural', 0) > 0.1 or not self.reward_function.validate_format(response):
            violations['format_violations'] = 1

        if reward_info.get('r_factual', 1.0) < 0.3:
            violations['factual_violations'] = 1

        # Goal 5: Less Rewards to Bad OOD Reasoning
        # High total reward despite violations = bad OOD getting undeserved high reward
        has_violations = sum(violations.values()) > 0
        high_reward = reward_info.get('r_total', 0) > 0.5
        if has_violations and high_reward:
            violations['bad_ood_high_rewards'] = 1

        return violations

    def evaluate_model(self, test_data_path, max_examples):
        """Evaluate the policy on a test split; report accuracy, rewards and violations.

        Side effects: sets ``self.reward_function.fact_verification_system.training_mode``
        to False (it is not restored). Puts the policy, and the baseline network
        when ``self.use_baseline``, in eval mode. Prints the results through
        ``self.metrics_tracker.print_metrics_with_std``.

        Args:
            test_data_path: Path handed to ``self.data_processor.load_medqa_data``.
            max_examples: Evaluate only the first ``max_examples`` items; ``None``
                evaluates the whole split.

        Returns:
            Dict with accuracy, per-category violation rates and their standard
            deviations, reward means/stds, prediction counts and
            ``self.reward_history``.

        Raises:
            ValueError: If there are no examples to evaluate.
        """
        test_dataset = self.data_processor.load_medqa_data(test_data_path)
        self.reward_function.fact_verification_system.training_mode = False

        if max_examples is not None:
            test_dataset = test_dataset[:max_examples]

        total_examples = len(test_dataset)
        if total_examples == 0:
            raise ValueError(f"No test examples to evaluate (test_data_path={test_data_path!r})")

        # Report the real count: max_examples may be None or larger than the split.
        print(f"Evaluating on {total_examples} test examples...")

        self.model.eval()
        if self.use_baseline:
            self.baseline_network.eval()

        correct_predictions = 0

        # Per-example 0/1 violation flags, used for the *_std metrics.
        violation_records = {
            'correctness_violations': [],
            'answer_leaking_violations': [],
            'format_violations': [],
            'factual_violations': [],
            'bad_ood_high_rewards': []
        }

        # Reward components (including the factual score) per example.
        reward_components = {'r_binary': [], 'p_answer': [], 'p_structural': [], 'r_factual': []}
        all_rewards = []
        format_violations = 0
        evaluation_data = []

        with torch.no_grad():
            for i, test_item in enumerate(test_dataset):
                prompt = test_item['prompt']
                correct_answer = test_item['correct_answer']

                response = self.generate_response(prompt)
                reward_info = self.reward_function.compute_total_reward(response, correct_answer)

                evaluation_data.append({
                    'response': response,
                    'reward_info': reward_info,
                    'correct_answer': correct_answer
                })

                for key in reward_components:
                    reward_components[key].append(reward_info[key])

                all_rewards.append(reward_info['r_normalized'])

                predicted_answer = self.reward_function.extract_answer_choice(response)
                is_correct = predicted_answer and predicted_answer.upper() == correct_answer.upper()
                if is_correct:
                    correct_predictions += 1

                format_ok = self.reward_function.validate_format(response)
                format_violations += int(not format_ok)

                violations = self._analyze_response_quality(response, reward_info, correct_answer)
                for key, records in violation_records.items():
                    records.append(violations[key])

                if i % 100 == 0:
                    print(f"Evaluated {i + 1}/{total_examples} examples...")

        format_violation_rate = format_violations / total_examples
        accuracy = correct_predictions / total_examples
        avg_rewards = {key: np.mean(values) for key, values in reward_components.items()}
        hacking_stats = self.reward_function.calculate_hacking_rate(evaluation_data)

        reward_std = np.std(all_rewards) if len(all_rewards) > 1 else 0.0

        test_metrics = {
            "accuracy": accuracy,
            "correctness_violations_rate": 1 - accuracy,
            "answer_leaking_violations_rate": hacking_stats["answer_violation_rate"],
            "format_violations_rate": format_violation_rate,
            "factual_violations_rate": hacking_stats["factual_violation_rate"],
            "bad_ood_high_rewards_rate": hacking_stats["overall_violation_rate"],
            "avg_reward": np.mean(all_rewards),
            "avg_factual": avg_rewards["r_factual"],
            "correct_predictions": correct_predictions,
            "total_examples": total_examples,
            "reward_history": self.reward_history,

            "std_reward": reward_std,
            "std_factual": np.std(reward_components['r_factual']) if len(reward_components['r_factual']) > 1 else 0.0
        }

        # The *_rate values above come from accuracy, validate_format and
        # calculate_hacking_rate. The per-example records only fill in a rate
        # that is missing, but always provide the *_std.
        # NOTE(review): the two sources use different criteria (e.g. the format
        # records also count p_structural > 0.1), so a rate and its std may not
        # describe exactly the same quantity.
        for key, records in violation_records.items():
            rate_key = f"{key}_rate"
            if rate_key not in test_metrics:
                test_metrics[rate_key] = np.mean(records)
            test_metrics[f"{key}_std"] = np.std(records) if len(records) > 1 else 0.0

        self.metrics_tracker.print_metrics_with_std(test_metrics, "Test Evaluation Results")

        return test_metrics

    def generate_response(self, prompt, max_new_tokens=2048) -> str:
        """Sample one response to ``prompt`` in eval mode, without gradients.

        Sampling uses temperature 0.7 and top-p 0.9; the prompt is truncated
        to 2048 tokens.

        Returns:
            Only the newly generated text, decoded without special tokens and
            stripped of surrounding whitespace.
        """
        self.model.eval()

        encoded = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        model_inputs = {}
        for name, tensor in encoded.items():
            model_inputs[name] = tensor.to(self.device)
        prompt_len = model_inputs['input_ids'].shape[1]

        with torch.no_grad():
            sequences = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        new_tokens = sequences[0][prompt_len:]
        text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return text.strip()

    def analyze_hacking_sensitivity(self, test_data_path, max_examples, tau_answer_range=None, tau_preamble_range=None):
        """Sweep the reward function's hacking thresholds over one fixed set of responses.

        Responses are generated once (eval mode, no gradients). For every
        ``(tau_answer, tau_preamble)`` pair, the thresholds are set on
        ``self.reward_function``, every response is re-scored with
        ``compute_total_reward``, and summary statistics are recorded. The
        original thresholds are restored afterwards, even if scoring raises.

        Args:
            test_data_path: Path handed to ``self.data_processor.load_medqa_data``.
            max_examples: Use only the first ``max_examples`` items; ``None`` uses all.
            tau_answer_range: Values to try for ``tau_answer``; defaults to
                0.5 .. 1.5 in steps of 0.1.
            tau_preamble_range: Values to try for ``tau_preamble``; defaults to
                5 .. 50 in steps of 5.

        Returns:
            Dict with ``sensitivity_results`` (one entry per threshold pair,
            tau_answer-major order), the two ranges used, and ``total_examples``.
        """
        if tau_answer_range is None:
            tau_answer_range = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5]
        if tau_preamble_range is None:
            tau_preamble_range = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]

        test_dataset = self.data_processor.load_medqa_data(test_data_path)
        if max_examples is not None:
            test_dataset = test_dataset[:max_examples]

        print(f"Running sensitivity analysis on {len(test_dataset)} examples...")
        print(f"Testing tau_answer: {tau_answer_range}")
        print(f"Testing tau_preamble: {tau_preamble_range}")

        print("Generating responses...")
        self.model.eval()
        if self.use_baseline:
            self.baseline_network.eval()

        evaluation_data = []
        with torch.no_grad():
            for test_item in test_dataset:
                prompt = test_item['prompt']
                correct_answer = test_item['correct_answer']
                evaluation_data.append({
                    'response': self.generate_response(prompt),
                    'correct_answer': correct_answer,
                    'prompt': prompt
                })

        reward_fn = self.reward_function
        original_tau_answer = reward_fn.tau_answer
        original_tau_preamble = reward_fn.tau_preamble

        sensitivity_results = []

        print("\nTesting threshold combinations...")
        try:
            for tau_answer in tau_answer_range:
                for tau_preamble in tau_preamble_range:
                    print(f"Testing tau_answer={tau_answer}, tau_preamble={tau_preamble}")

                    reward_fn.tau_answer = tau_answer
                    reward_fn.tau_preamble = tau_preamble

                    responses_with_rewards = [
                        {
                            'response': item['response'],
                            'reward_info': reward_fn.compute_total_reward(item['response'], item['correct_answer']),
                            'correct_answer': item['correct_answer']
                        }
                        for item in evaluation_data
                    ]

                    hacking_stats = reward_fn.calculate_hacking_rate(responses_with_rewards)

                    infos = [entry['reward_info'] for entry in responses_with_rewards]
                    num_responses = len(infos)
                    positive_rewards = sum(1 for info in infos if info['r_total'] > 0)

                    sensitivity_results.append({
                        'tau_answer': tau_answer,
                        'tau_preamble': tau_preamble,
                        'answer_violation_rate': hacking_stats['answer_violation_rate'],
                        'structural_violation_rate': hacking_stats['structural_violation_rate'],
                        'factual_violation_rate': hacking_stats['factual_violation_rate'],
                        'overall_violation_rate': hacking_stats['overall_violation_rate'],
                        'answer_violation_count': hacking_stats['answer_violation_count'],
                        'structural_violation_count': hacking_stats['structural_violation_count'],
                        'factual_violation_count': hacking_stats['factual_violation_count'],
                        'positive_reward_count': positive_rewards,
                        # Guard against an empty test set.
                        'positive_reward_rate': positive_rewards / num_responses if num_responses else 0.0,
                        'avg_answer_penalty': np.mean([info['p_answer'] for info in infos]),
                        'avg_structural_penalty': np.mean([info['p_structural'] for info in infos]),
                        'avg_factual_reward': np.mean([info['r_factual'] for info in infos])
                    })
        finally:
            # Never leave the shared reward function configured with sweep thresholds.
            reward_fn.tau_answer = original_tau_answer
            reward_fn.tau_preamble = original_tau_preamble

        return {
            'sensitivity_results': sensitivity_results,
            'tau_answer_range': tau_answer_range,
            'tau_preamble_range': tau_preamble_range,
            'total_examples': len(test_dataset)
        }

    def display_sample_analysis(self, step, item, response_text, reward_info, log_file=None):
        """Show a per-sample breakdown of the response, extracted facts and rewards.

        A compact version is printed to stdout: question, response, reasoning
        and fact texts are truncated, and at most 3 facts are listed. When
        ``log_file`` is given, the complete, untruncated report is appended to
        it as UTF-8 text.

        Args:
            step: Training step number shown in the header.
            item: Dataset item; ``item['question']`` and ``item['correct_answer']`` are read.
            response_text: Model response being analysed.
            reward_info: Reward dict. ``r_binary``, ``p_answer``, ``p_structural``,
                ``r_factual`` and ``r_normalized`` are required;
                ``extracted_facts`` and ``factual_error`` are optional.
            log_file: Optional path to append the full report to.
        """
        think_content = self.reward_function.extract_think_content(response_text)
        extracted_facts = reward_info.get('extracted_facts', [])
        bar = '=' * 80

        def clip(text, limit):
            # limit=None keeps the full text. Otherwise truncate, adding "..."
            # only when something was actually cut.
            if limit is None or len(text) <= limit:
                return f"{text}"
            return f"{text[:limit]}..."

        def fact_field(fact, name, default):
            # Facts may be plain dicts or attribute-style objects.
            if isinstance(fact, dict):
                return fact.get(name, default)
            return getattr(fact, name, default)

        def render(question_limit, response_limit, think_limit, fact_limit, max_facts):
            # Build the report; None limits mean "no truncation / show all facts".
            lines = [
                f"\n{bar}",
                f"STEP {step} ANALYSIS",
                bar,
                "QUESTION:",
                clip(item['question'], question_limit),
                f"\nCORRECT ANSWER: {item['correct_answer']}",
                "\nMODEL RESPONSE:",
                clip(response_text, response_limit),
            ]

            if think_content:
                lines += ["\nEXTRACTED REASONING:", clip(think_content, think_limit)]

            if extracted_facts:
                lines.append(f"\nEXTRACTED FACTS ({len(extracted_facts)} total):")
                shown = extracted_facts if max_facts is None else extracted_facts[:max_facts]
                for idx, fact in enumerate(shown, 1):
                    fact_text = fact_field(fact, 'text', '')
                    fact_category = fact_field(fact, 'category', '')
                    llm_score = fact_field(fact, 'llm_score', 0.0)
                    kb_score = fact_field(fact, 'kb_score', 0.0)
                    lines.append(f"  {idx}. [{fact_category}] {clip(fact_text, fact_limit)}")
                    lines.append(f"     LLM Score: {llm_score:.2f}, KB Score: {kb_score:.2f}")
                if max_facts is not None and len(extracted_facts) > max_facts:
                    lines.append(f"     ... and {len(extracted_facts) - max_facts} more facts")
            else:
                lines.append("\nNO FACTS EXTRACTED")

            lines += [
                "\nREWARD BREAKDOWN:",
                f"  Binary (Correctness): {reward_info['r_binary']:.2f}",
                f"  Answer Penalty: {reward_info['p_answer']:.2f}",
                f"  Structural Penalty: {reward_info['p_structural']:.2f}",
                f"  Factual Reward: {reward_info['r_factual']:.2f}",
                f"  Total Normalized: {reward_info['r_normalized']:.2f}",
            ]

            if reward_info.get('factual_error'):
                lines.append(f"  Factual Error: {reward_info['factual_error']}")

            lines.append(bar)
            return '\n'.join(lines)

        full_output = render(None, None, None, None, None)

        # Console copy is truncated to stay readable in Jupyter.
        print(render(200, 500, 300, 100, 3))

        if log_file:
            with open(log_file, 'a', encoding='utf-8') as f:
                f.write(full_output + '\n')

    def train_reward_model_stage1(self, train_data_path, num_epochs, batch_size):
        """Stage 1: fit the learnable reward head to the rule-based reward.

        For every training item, a response is sampled from the policy without
        gradients and scored with ``self.reward_function.compute_total_reward``.
        ``self.reward_head`` is then regressed (MSE) from the response's pooled
        hidden state onto that ``r_normalized`` score. One optimizer step
        (after gradient-norm clipping at 1.0) is taken per batch that produced
        at least one loss. Items that raise are logged and skipped.

        Args:
            train_data_path: Path handed to ``self.data_processor.load_medqa_data``.
            num_epochs: Number of passes over the training set.
            batch_size: Number of items whose losses are averaged per update.

        Returns:
            None. Returns immediately (with a message) when the trainer has no
            ``reward_head``.
        """
        if not hasattr(self, 'reward_head'):
            print("No learnable reward head found - skipping reward model training")
            return

        train_dataset = self.data_processor.load_medqa_data(train_data_path)
        print(f"Stage 1: Training reward model on {len(train_dataset)} examples")

        # The policy is not updated here. NOTE(review): generate_response_with_logprobs
        # calls self.model.train() internally, so dropout may be active while
        # the hidden states are computed.
        self.model.eval()
        self.reward_head.train()
        # Start from clean gradients so the first step is not polluted by
        # gradients left over from earlier work.
        self.reward_optimizer.zero_grad()

        for epoch in range(num_epochs):
            total_loss = 0
            num_batches = 0

            for batch_start in range(0, len(train_dataset), batch_size):
                batch_items = train_dataset[batch_start:batch_start + batch_size]

                batch_losses = []

                for item in batch_items:
                    try:
                        prompt = item['prompt']
                        correct_answer = item['correct_answer']

                        # Generate response (no gradients for policy)
                        with torch.no_grad():
                            response_text, _, hidden_state = self.generate_response_with_logprobs(prompt)

                        # Target: the rule-based reward for this response.
                        rule_reward_info = self.reward_function.compute_total_reward(
                            response_text, correct_answer
                        )
                        target_score = rule_reward_info['r_normalized']

                        predicted_score = self.reward_head(
                            hidden_state.detach().to(self.device, dtype=torch.float32)).view(())
                        target_tensor = torch.tensor(target_score, device=self.device, dtype=torch.float32).view(())
                        loss = torch.nn.functional.mse_loss(predicted_score, target_tensor)
                        batch_losses.append(loss)

                    except Exception as e:
                        # Best effort: one bad item must not stop reward-model training.
                        print(f"Error in reward model training: {e}")
                        continue

                if not batch_losses:
                    # Nothing was back-propagated for this batch: skip clipping and
                    # stepping rather than updating from empty or stale gradients.
                    continue

                total_batch_loss = torch.stack(batch_losses).mean()
                total_batch_loss.backward()
                total_loss += total_batch_loss.item()
                num_batches += 1

                torch.nn.utils.clip_grad_norm_(self.reward_head.parameters(), 1.0)
                self.reward_optimizer.step()
                self.reward_optimizer.zero_grad()

            avg_loss = total_loss / num_batches if num_batches > 0 else 0
            print(f"Reward model epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.2f}")

        print("Stage 1 complete: Reward model training finished")

    def train_policy_stage2(self, train_data_path, num_epochs, batch_size):
        train_dataset = self.data_processor.load_medqa_data(train_data_path)
        print(f"Stage 2: Training policy on {len(train_dataset)} examples")

        if hasattr(self, 'reward_head'):
            self.reward_head.eval()
        self.model.train()

        for epoch in range(num_epochs):
            epoch_rewards = []
            epoch_advantages = []
            epoch_policy_losses = []
            epoch_baseline_values = []

            violation_records = {
                'correctness_violations': [],
                'answer_leaking_violations': [],
                'format_violations': [],
                'factual_violations': [],
                'bad_ood_high_rewards': []
            }

            epoch_metrics = {
                'correctness_violations': 0,
                'answer_leaking_violations': 0,
                'format_violations': 0,
                'factual_violations': 0,
                'bad_ood_high_rewards': 0,
                'total_examples': 0
            }

            self.policy_optimizer.zero_grad()
            if self.use_baseline:
                self.baseline_optimizer.zero_grad()

            for batch_start in range(0, len(train_dataset), batch_size):
                batch_end = min(batch_start + batch_size, len(train_dataset))
                batch_items = train_dataset[batch_start:batch_end]

                batch_policy_losses = []
                batch_baseline_losses = []
                batch_rewards = []
                batch_advantages = []

                for item in batch_items:
                    try:
                        prompt = item['prompt']
                        correct_answer = item['correct_answer']

                        response_text, log_probs, hidden_state, reward_info = self.generate_and_evaluate_with_facts(
                            prompt, correct_answer
                        )

                        # Track violations for this individual example
                        violations = self._analyze_response_quality(response_text, reward_info, correct_answer)

                        # Record individual violations (0 or 1) for std calculation
                        for key in violation_records.keys():
                            violation_records[key].append(violations[key])

                        # Aggregate totals for rates
                        for key, value in violations.items():
                            epoch_metrics[key] = epoch_metrics.get(key, 0) + value
                        epoch_metrics["total_examples"] = epoch_metrics.get("total_examples", 0) + 1

                        if log_probs.numel() == 0:
                            continue

                        # Get reward from frozen reward model
                        with torch.no_grad():
                            if hasattr(self, 'reward_head'):
                                reward = self.reward_head(hidden_state.detach().float()).item()
                            else:
                                reward_info = self.reward_function.compute_total_reward(
                                    response_text, correct_answer
                                )
                                reward = reward_info['r_normalized']

                        # Baseline and advantage
                        baseline_value = self.compute_baseline_value(hidden_state, training_mode=True)
                        reward_tensor = torch.tensor(reward, device=self.device, dtype=torch.float16)
                        advantage = reward_tensor - baseline_value

                        batch_rewards.append(reward)
                        batch_advantages.append(advantage.detach().item())

                        policy_loss = -torch.sum(log_probs) * advantage
                        batch_policy_losses.append(policy_loss)

                        # Collect metrics for std calculation
                        epoch_rewards.append(reward)
                        epoch_advantages.append(advantage.detach().item())
                        epoch_policy_losses.append(policy_loss.detach().item())
                        if hasattr(baseline_value, 'item'):
                            epoch_baseline_values.append(baseline_value.item())
                        else:
                            epoch_baseline_values.append(float(baseline_value))

                        if self.use_baseline:
                            baseline_loss = ((baseline_value - reward_tensor) ** 2)
                            batch_baseline_losses.append(baseline_loss)

                    except Exception as e:
                        print(f"Error in policy training: {e}")
                        continue

                # Update policy
                if batch_policy_losses:
                    total_policy_loss = torch.stack(batch_policy_losses).mean()
                    total_loss = total_policy_loss

                    if self.use_baseline and batch_baseline_losses:
                        total_baseline_loss = torch.stack(batch_baseline_losses).mean()
                        total_loss += total_baseline_loss

                    total_loss.backward()

            # Update weights
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
            if self.use_baseline:
                torch.nn.utils.clip_grad_norm_(self.baseline_network.parameters(), max_norm=0.5)

            self.policy_optimizer.step()
            self.policy_optimizer.zero_grad()
            if self.use_baseline:
                self.baseline_optimizer.step()
                self.baseline_optimizer.zero_grad()

            # Calculate means and standard deviations for continuous metrics
            avg_reward = np.mean(epoch_rewards) if epoch_rewards else 0
            std_reward = np.std(epoch_rewards) if len(epoch_rewards) > 1 else 0

            avg_advantage = np.mean(epoch_advantages) if epoch_advantages else 0
            std_advantage = np.std(epoch_advantages) if len(epoch_advantages) > 1 else 0

            avg_policy_loss = np.mean(epoch_policy_losses) if epoch_policy_losses else 0
            std_policy_loss = np.std(epoch_policy_losses) if len(epoch_policy_losses) > 1 else 0

            avg_baseline_value = np.mean(epoch_baseline_values) if epoch_baseline_values else 0
            std_baseline_value = np.std(epoch_baseline_values) if len(epoch_baseline_values) > 1 else 0

            print(f"Policy epoch {epoch + 1}/{num_epochs}")
            print(f"Avg Reward: {avg_reward:.2f} (±{std_reward:.2f})\n")
        
            if self.use_baseline:
                print(f"  Avg Baseline Value: {avg_baseline_value:.2f} (±{std_baseline_value:.2f})")

            if epoch_rewards and epoch_metrics["total_examples"] > 0:
                # Calculate violation rates and their standard deviations
                for key in ["correctness_violations", "answer_leaking_violations",
                            "format_violations", "factual_violations", "bad_ood_high_rewards"]:
                    rate_key = f"{key}_rate"
                    std_key = f"{key}_std"

                    # Rate is the mean of 0s and 1s
                    epoch_metrics[rate_key] = epoch_metrics[key] / epoch_metrics["total_examples"]

                    # Standard deviation of binary values (0s and 1s)
                    if len(violation_records[key]) > 1:
                        epoch_metrics[std_key] = np.std(violation_records[key])
                    else:
                        epoch_metrics[std_key] = 0.0

                # Store other metrics
                epoch_metrics["avg_reward"] = avg_reward
                epoch_metrics["std_reward"] = std_reward
                epoch_metrics["avg_advantage"] = avg_advantage
                epoch_metrics["std_advantage"] = std_advantage
                epoch_metrics["avg_policy_loss"] = avg_policy_loss
                epoch_metrics["std_policy_loss"] = std_policy_loss
                epoch_metrics["avg_baseline_value"] = avg_baseline_value
                epoch_metrics["std_baseline_value"] = std_baseline_value

                self.metrics_tracker.print_metrics_with_std(epoch_metrics, f"POLICY EPOCH {epoch + 1}")

        print("Stage 2 complete: Policy training finished")

    def train_adversarial_stage3(self, train_data_path, num_cycles=3, num_prompts=100, post_eval_examples=50):
        """Run adversarial reward-model hardening on a slice of the training set.

        Args:
            train_data_path: Path handed to ``self.data_processor.load_medqa_data``; each loaded
                item must provide ``prompt`` and ``correct_answer`` keys.
            num_cycles: Number of cycles forwarded to
                ``self.adversarial_trainer.run_adversarial_training_cycle``.
            num_prompts: How many leading training items to use as prompts (default 100, the
                value that used to be hard-coded).
            post_eval_examples: Example budget for the evaluation run after training
                (default 50, the value that used to be hard-coded).

        Returns:
            The result dict from ``run_adversarial_training_cycle``; this method reads its
            ``final_robustness`` and ``total_cycles`` keys.

        Raises:
            ValueError: If the training file yields no examples (previously a bare IndexError
                from the debug print).
        """
        print("\n" + "=" * 60)
        print("STAGE 3: ADVERSARIAL TRAINING")
        print("=" * 60)

        train_dataset = self.data_processor.load_medqa_data(train_data_path)
        if len(train_dataset) == 0:
            raise ValueError(f"No training examples loaded from {train_data_path}")

        # DEBUG: Print first prompt to see what format it's in
        print("\n=== DEBUG: First prompt from dataset ===")
        print(train_dataset[0]['prompt'][:2048])
        print("=== END DEBUG ===\n")

        subset = train_dataset[:num_prompts]
        prompts = [item['prompt'] for item in subset]
        answers = [item['correct_answer'] for item in subset]

        # Run adversarial training cycles
        results = self.adversarial_trainer.run_adversarial_training_cycle(
            prompts, answers, num_cycles
        )

        print("Adversarial training complete:")
        print(f"  Final robustness score: {results['final_robustness']:.2f}")
        print(f"  Total cycles completed: {results['total_cycles']}")

        # NOTE(review): this evaluation reads the *training* file, so it measures fit on data
        # the adversarial stage has already seen rather than generalization.
        print("\nEvaluating model AFTER adversarial training...")
        post_adversarial_metrics = self.evaluate_model(train_data_path, post_eval_examples)

        print("Post-Adversarial Performance:")
        self.metrics_tracker.print_metrics_with_std(post_adversarial_metrics, "POST-ADVERSARIAL METRICS")

        return results

    def train_reward_policy(self, train_data_path, test_data_path, stage1_epochs, stage2_epochs,
                            batch_size=2, max_eval_examples=50, save_checkpoint_path=None):
        """Run reward-model training and then policy training, and evaluate the result.

        Args:
            train_data_path: Training data path shared by both stages.
            test_data_path: Data path used for the post-training evaluation.
            stage1_epochs: Epoch count for ``train_reward_model_stage1``.
            stage2_epochs: Epoch count for ``train_policy_stage2``.
            batch_size: Batch size forwarded to both stages.
            max_eval_examples: Example budget forwarded to ``evaluate_model``.
            save_checkpoint_path: If truthy, a ``"policy_complete"`` checkpoint is written there.

        Returns:
            The metrics object returned by ``evaluate_model`` on the test data.
        """
        print("TRAINING: Reward and Policy Only")
        print("=" * 40)

        # Both stages share the same (data, epochs, batch_size) calling convention.
        stages = (
            (self.train_reward_model_stage1, stage1_epochs),
            (self.train_policy_stage2, stage2_epochs),
        )
        for run_stage, epochs in stages:
            run_stage(train_data_path, epochs, batch_size)

        print("\nEvaluating Policy Training Results:")
        policy_metrics = self.evaluate_model(test_data_path, max_eval_examples)
        self.metrics_tracker.print_metrics_with_std(policy_metrics, "POLICY STAGE RESULTS")

        if save_checkpoint_path:
            self.save_model_checkpoint(save_checkpoint_path, "policy_complete")
            print(f"Policy training checkpoint saved to: {save_checkpoint_path}")

        return policy_metrics

    def train_adversarial(self, train_data_path, test_data_path, num_cycles=3, max_eval_examples=20):
        """Run the adversarial stage and then evaluate on the test data.

        Args:
            train_data_path: Data path forwarded to ``train_adversarial_stage3``.
            test_data_path: Data path used for the final evaluation.
            num_cycles: Number of adversarial cycles.
            max_eval_examples: Example budget for the final evaluation.

        Returns:
            Tuple ``(final_metrics, adversarial_results)``: the ``evaluate_model`` output and the
            result of ``train_adversarial_stage3``.
        """
        print("TRAINING: Adversarial Stage Only")
        print("=" * 40)

        adversarial_results = self.train_adversarial_stage3(train_data_path, num_cycles)

        print("\nEvaluating Final Results:")
        # Honour the caller's evaluation budget (it used to be ignored in favour of a literal 20).
        final_metrics = self.evaluate_model(test_data_path, max_eval_examples)
        self.metrics_tracker.print_metrics_with_std(final_metrics, "FINAL ADVERSARIAL RESULTS")

        return final_metrics, adversarial_results

    def train_combined(self, train_data_path,
                       test_data_path,
                       stage1_epochs,
                       stage2_epochs,
                       batch_size=2,
                       max_eval_examples=20,
                       adversarial_cycles=3,
                       bucket_name='b',
                       s3_prefix_reward='medqa-models/reinforce-llama-run2/',
                       s3_prefix_adv='medqa-models/reinforce-llama-adversarial/',
                       local_model_dir='/home/ec2-user/SageMaker/llama_RM_2_USMLE',
                       local_model_dir_adv='/home/ec2-user/SageMaker/llama_RM_ADV_USMLE'):
        """Reward/policy training, upload, then adversarial training and a second upload.

        Stage 1 (``train_reward_policy``) and its save/upload are not wrapped in error handling:
        a failure there propagates. Any exception raised during stage 2 (adversarial training
        plus its save/upload) is caught and reported in the returned dict.

        Args:
            train_data_path, test_data_path: Data paths forwarded to both stages.
            stage1_epochs, stage2_epochs, batch_size: Forwarded to ``train_reward_policy``.
            max_eval_examples: Evaluation budget forwarded to both stages.
            adversarial_cycles: Number of adversarial cycles.
            bucket_name: Target S3 bucket (defaults to the previously hard-coded placeholder).
            s3_prefix_reward, s3_prefix_adv: S3 key prefixes for the two uploaded models.
            local_model_dir, local_model_dir_adv: Local directories the models are saved to
                before upload.

        Returns:
            Dict with ``reward_policy_metrics``, ``adversarial_metrics`` (None on failure),
            ``training_successful``, plus ``adversarial_results`` on success or ``error`` on
            failure.
        """
        # Stage 1: Train reward policy
        print("\nSTAGE 1: REWARD POLICY TRAINING")
        print("-" * 50)

        reward_policy_metrics = self.train_reward_policy(
            train_data_path=train_data_path,
            test_data_path=test_data_path,
            stage1_epochs=stage1_epochs,
            stage2_epochs=stage2_epochs,
            batch_size=batch_size,
            max_eval_examples=max_eval_examples
        )

        print("\nReward Policy Training Complete!")
        self.metrics_tracker.print_metrics_with_std(reward_policy_metrics, "REWARD POLICY RESULTS")

        # Save the reward-only model locally and mirror it to S3
        self._save_and_upload_model(local_model_dir, bucket_name, s3_prefix_reward)
        print("\nUpload to S3 complete!")
        print(f"Reward-only model stored at s3://{bucket_name}/{s3_prefix_reward}")

        # Stage 2: Adversarial training on the reward-trained model
        print("\n" + "=" * 50)
        print("STAGE 2: ADVERSARIAL TRAINING")
        print("-" * 50)
        try:
            # train_adversarial returns (metrics, cycle_results). Passing the whole tuple to
            # print_metrics_with_std used to fail here and skip saving the adversarial model.
            adversarial_metrics, adversarial_results = self.train_adversarial(
                train_data_path=train_data_path,
                test_data_path=test_data_path,
                num_cycles=adversarial_cycles,
                max_eval_examples=max_eval_examples
            )
            print("\nAdversarial Training Complete!")
            self.metrics_tracker.print_metrics_with_std(adversarial_metrics, "FINAL ADVERSARIAL RESULTS")

            self._save_and_upload_model(local_model_dir_adv, bucket_name, s3_prefix_adv)
            print("\nAdversarial model upload to S3 complete!")
            print(f"Adversarial model stored at s3://{bucket_name}/{s3_prefix_adv}")

            return {
                'reward_policy_metrics': reward_policy_metrics,
                'adversarial_metrics': adversarial_metrics,
                'adversarial_results': adversarial_results,
                'training_successful': True
            }
        except Exception as e:
            print(f"\nERROR: Adversarial training failed: {e}")
            print("Returning reward policy model as final result")

            return {
                'reward_policy_metrics': reward_policy_metrics,
                'adversarial_metrics': None,
                'training_successful': False,
                'error': str(e)
            }

    def _save_and_upload_model(self, local_dir, bucket_name, s3_prefix):
        """Save model, tokenizer and baseline weights to ``local_dir``, then upload every file to S3."""
        import posixpath

        os.makedirs(local_dir, exist_ok=True)
        self.model.save_pretrained(local_dir)
        self.tokenizer.save_pretrained(local_dir)
        torch.save(self.baseline_network.state_dict(), os.path.join(local_dir, "baseline_network.pt"))

        s3 = boto3.client('s3')
        for root, _dirs, files in os.walk(local_dir):
            for file_name in files:
                local_path = os.path.join(root, file_name)
                relative_path = os.path.relpath(local_path, local_dir)
                # S3 keys always use '/', whatever the local OS path separator is.
                s3_path = posixpath.join(s3_prefix, relative_path.replace(os.sep, '/'))
                print(f"Uploading {relative_path} to s3://{bucket_name}/{s3_path}")
                s3.upload_file(local_path, bucket_name, s3_path)


    def save_model_checkpoint(self, checkpoint_path, stage_name="policy"):
        """Write the model, tokenizer, auxiliary networks, optimizer states and metadata to disk.

        Layout under ``checkpoint_path``:
            model/                 ``save_pretrained`` output (safetensors) plus tokenizer files
            tokenizer/             tokenizer files
            reward_head.pt         only if ``self.reward_head`` is set and not None
            baseline_network.pt    only if ``self.baseline_network`` is set and not None
            <name>_optimizer.pt    one file per policy/reward/baseline optimizer that is set
            checkpoint_info.json   stage name, reward history and feature flags

        Args:
            checkpoint_path: Directory to write into (created if missing).
            stage_name: Label stored under ``stage`` in the metadata file.
        """
        os.makedirs(checkpoint_path, exist_ok=True)

        model_dir = os.path.join(checkpoint_path, "model")
        self.model.save_pretrained(model_dir, safe_serialization=True)
        self.tokenizer.save_pretrained(os.path.join(checkpoint_path, "tokenizer"))
        # load_from_checkpoint hands only the "model" directory to the trainer constructor, so
        # keep a tokenizer copy there as well to make that directory self-contained.
        self.tokenizer.save_pretrained(model_dir)

        # Training components recorded as JSON metadata
        checkpoint_data = {
            'stage': stage_name,
            'reward_history': self.reward_history,
            'use_learnable_reward': self.use_learnable_reward,
            'use_baseline': self.use_baseline,
        }

        # Optional networks; the has_* flags tell load_from_checkpoint what to restore.
        networks = (
            ('reward_head', 'reward_head.pt', 'has_reward_head'),
            ('baseline_network', 'baseline_network.pt', 'has_baseline_network'),
        )
        for attr, file_name, flag in networks:
            module = getattr(self, attr, None)
            checkpoint_data[flag] = module is not None
            if module is not None:
                torch.save(module.state_dict(), os.path.join(checkpoint_path, file_name))

        # Optimizer states. Use an `is not None` check rather than `hasattr`, so an attribute
        # that exists but is None is skipped instead of crashing on `.state_dict()`.
        for attr in ('policy_optimizer', 'reward_optimizer', 'baseline_optimizer'):
            optimizer = getattr(self, attr, None)
            if optimizer is not None:
                torch.save(optimizer.state_dict(), os.path.join(checkpoint_path, f"{attr}.pt"))

        # Save metadata
        with open(os.path.join(checkpoint_path, "checkpoint_info.json"), "w") as f:
            json.dump(checkpoint_data, f, indent=2)

        print(f"Model checkpoint saved to {checkpoint_path}")

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path, reward_config=None, umls_api_key=None):
        """Rebuild a trainer from a directory written by ``save_model_checkpoint``.

        Args:
            checkpoint_path: Checkpoint directory containing ``checkpoint_info.json``.
            reward_config: Forwarded to the constructor.
            umls_api_key: Forwarded to the constructor.

        Returns:
            A new trainer instance. It is constructed from ``<checkpoint_path>/model``, with
            reward history, flags, auxiliary networks and optimizer states restored when the
            corresponding files exist.
        """
        # Load metadata
        with open(os.path.join(checkpoint_path, "checkpoint_info.json"), "r") as f:
            checkpoint_data = json.load(f)

        print(f"Loading checkpoint from stage: {checkpoint_data['stage']}")

        # Create new trainer instance
        trainer = cls(
            model_path=os.path.join(checkpoint_path, "model"),
            reward_config=reward_config,
            use_baseline=checkpoint_data['use_baseline'],
            umls_api_key=umls_api_key
        )

        # Restore training state
        trainer.reward_history = checkpoint_data.get('reward_history', [])
        trainer.use_learnable_reward = checkpoint_data.get('use_learnable_reward', False)

        # Map tensors onto the trainer's device, so a checkpoint written on one device (e.g. GPU)
        # can be restored on another. NOTE: torch.load unpickles; only load trusted checkpoints.
        map_location = getattr(trainer, 'device', None)

        def _restore(target, file_name, label):
            """Load ``file_name`` into ``target.load_state_dict`` if the file exists and the target is set."""
            path = os.path.join(checkpoint_path, file_name)
            if not os.path.exists(path):
                return
            if target is None:
                print(f"Warning: {file_name} found but trainer has no {label}; skipping")
                return
            target.load_state_dict(torch.load(path, map_location=map_location))
            print(f"{label} loaded from checkpoint")

        # Networks are restored only if the metadata says they were saved.
        if checkpoint_data.get('has_reward_head', False):
            _restore(getattr(trainer, 'reward_head', None), "reward_head.pt", "Reward head")
        if checkpoint_data.get('has_baseline_network', False):
            _restore(getattr(trainer, 'baseline_network', None), "baseline_network.pt", "Baseline network")

        # Optimizer states are restored whenever their file exists.
        _restore(getattr(trainer, 'policy_optimizer', None), "policy_optimizer.pt", "Policy optimizer")
        _restore(getattr(trainer, 'reward_optimizer', None), "reward_optimizer.pt", "Reward optimizer")
        _restore(getattr(trainer, 'baseline_optimizer', None), "baseline_optimizer.pt", "Baseline optimizer")

        print(f"Checkpoint loaded successfully from {checkpoint_path}")
        return trainer

    def load_reward_model_and_train_adversarial(self, saved_model_path,
                                                train_data_path,
                                                test_data_path,
                                                adversarial_cycles=3,
                                                max_eval_examples=20):
        """Load a saved model (and baseline network) from disk, then run adversarial training on it.

        Args:
            saved_model_path: Directory containing ``save_pretrained`` output for model and
                tokenizer, optionally with ``baseline_network.pt``.
            train_data_path, test_data_path: Forwarded to ``train_adversarial``.
            adversarial_cycles: Number of adversarial cycles.
            max_eval_examples: Evaluation budget forwarded to ``train_adversarial``.

        Returns:
            The ``(final_metrics, adversarial_results)`` tuple from ``train_adversarial``.
        """
        print(f"Loading model from {saved_model_path}")
        self.model = AutoModelForCausalLM.from_pretrained(saved_model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(saved_model_path)
        self.model.to(self.device)

        # The adversarial trainer stores its own model/tokenizer references when it is
        # constructed; repoint them so generation uses the freshly loaded model, not the old one.
        adversarial_trainer = getattr(self, 'adversarial_trainer', None)
        if adversarial_trainer is not None:
            adversarial_trainer.model = self.model
            adversarial_trainer.tokenizer = self.tokenizer
        # NOTE(review): self.policy_optimizer (if any) still holds the previous model's
        # parameters; rebuild it before doing further policy updates on this model.

        baseline_path = os.path.join(saved_model_path, "baseline_network.pt")
        if os.path.exists(baseline_path):
            print("Loading saved baseline network...")
            baseline_state = torch.load(baseline_path, map_location=self.device)
            self.baseline_network.load_state_dict(baseline_state)
            print("Baseline network loaded successfully")
        else:
            print("Warning: No baseline network found, initializing fresh")

        # Run adversarial training on the loaded model
        print("\n" + "="*60)
        print("STARTING ADVERSARIAL TRAINING ON LOADED MODEL")
        print("="*60)

        # train_adversarial returns (metrics, cycle_results); only the metrics dict is printable.
        final_metrics, adversarial_results = self.train_adversarial(
            train_data_path=train_data_path,
            test_data_path=test_data_path,
            num_cycles=adversarial_cycles,
            max_eval_examples=max_eval_examples
        )

        print("\nAdversarial Training Complete!")
        self.metrics_tracker.print_metrics_with_std(final_metrics, "ADVERSARIAL RESULTS")

        return final_metrics, adversarial_results


# Adversarial Trainer

In [None]:
class AdversarialTrainer:
    
    def __init__(self, base_trainer, adversarial_config=None):

        self.base_trainer = base_trainer
        self.model = base_trainer.model
        self.tokenizer = base_trainer.tokenizer
        self.reward_function = base_trainer.reward_function
        self.device = base_trainer.device

        config = adversarial_config or {}
        self.adversarial_temperature = config.get('temperature', 1.2)
        self.max_adversarial_examples = config.get('max_examples', 50)
        self.preference_margin = config.get('preference_margin', 0.5)
        self.validation_threshold = config.get('validation_threshold', 0.7)

        self.adversarial_examples_buffer = []
        self.preference_pairs_buffer = []

        print("Adversarial Trainer initialized")
    
    def generate_adversarial_examples(self, prompts, correct_answers=None, target_reward=0.5):
        adversarial_examples = []
        print(f"Generating adversarial examples from {len(prompts)} prompts...")
    
        generation_strategies = [
            {'temperature': 0.7, 'top_p': 0.9, 'do_sample': True},   # Balanced
            {'temperature': 0.9, 'top_p': 0.95, 'do_sample': True},  # Slightly diverse
            {'temperature': 0.5, 'top_p': 0.85, 'do_sample': True},  # Conservative
        ]
    
        shown = 0
    
        for strategy_idx, strategy in enumerate(generation_strategies):
            print(f"  Strategy {strategy_idx + 1}: temp={strategy['temperature']}")
    
            for i, prompt in enumerate(prompts[:self.max_adversarial_examples]):
                try:
                    response = self._generate_with_strategy(prompt, strategy)
                    correct_answer = correct_answers[i] if correct_answers else "A"
                    reward_info = self.reward_function.compute_total_reward(response, correct_answer)
    
                    if self._is_adversarial_example(response, reward_info, target_reward):
                        adv = {
                            'prompt': prompt,
                            'response': response,
                            'reward_info': reward_info,
                            'correct_answer': correct_answer,
                            'strategy': strategy_idx,
                            'vulnerability_type': self._classify_vulnerability(reward_info)
                        }
                        adversarial_examples.append(adv)
    
                        # Show first 1–2 adversarial examples for inspection
                        if shown < 2:
                            print("\n=== Adversarial Example ===")
                            print(f"Prompt:\n{prompt}\n")
                            print(f"Response:\n{response}\n")
                            print(f"Reward info: {reward_info}")
                            print(f"Vulnerability: {adv['vulnerability_type']}")
                            print("===========================\n")
                            shown += 1
    
                except Exception as e:
                    print(f"Error generating adversarial example: {e}")
                    continue
    
        print(f"Found {len(adversarial_examples)} adversarial examples")
        self.adversarial_examples_buffer.extend(adversarial_examples)
        return adversarial_examples

    
    def _format_prompt_for_generation(self, prompt):

        return f"""You are a medical expert taking the USMLE exam. Given the clinical scenario below, respond with your reasoning in a <think></think> tag and your final answer choice (A, B, C, or D) in an <answer></answer> tag.
        Scenario:
        {prompt}
        
        Format:
        <think>your step-by-step clinical reasoning goes here</think>
        <answer>A</answer>
        
        Your response:"""
        

    def _generate_with_strategy(self, prompt, strategy):
        device = self.model.device
    
        # Tokenize with same settings as eval
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(device)
    
        # Save and switch model state
        was_training = self.model.training
        self.model.eval()
    
        try:
            with torch.no_grad():
                # Force full precision for stable sampling
                with torch.cuda.amp.autocast(enabled=False):
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=600,
                        temperature=strategy.get("temperature", 0.7),
                        top_p=strategy.get("top_p", 0.9),
                        do_sample=strategy.get("do_sample", True),
                        pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
                        eos_token_id=self.tokenizer.eos_token_id,
                        repetition_penalty=1.2,     # prevent "QuestionQuestion..."
                        no_repeat_ngram_size=3      # prevent n-gram loops
                    )
    
            # Decode only the generated continuation
            response = self.tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True
            ).strip()
    
        finally:
            # Restore model state
            if was_training:
                self.model.train()
    
        return response

    
    
    def _generate_clean_response(self, prompt, correct_answer):
        """Generate a high-quality, non-adversarial response."""
        
        #formatted = self._format_prompt_for_generation(prompt)
        
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=500,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )
        
        response = self.tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        ).strip()
        
        return response
    
    def _is_adversarial_example(self, response, reward_info, target_reward):
        total_reward = reward_info.get('r_total', 0)
        
        if total_reward < target_reward:
            return False
        
        has_issues = (
            reward_info.get('p_answer', 0) > 0.05 or
            reward_info.get('p_structural', 0) > 0.05 or 
            reward_info.get('r_factual', 1.0) < 0.4 or
            not self.reward_function.validate_format(response)
        )
        return has_issues and total_reward > target_reward
    
    def _classify_vulnerability(self, reward_info):
        """Classify the type of vulnerability being exploited."""
        if reward_info.get('p_answer', 0) > 0.05:
            return "answer_leaking"
        elif reward_info.get('p_structural', 0) > 0.05:
            return "structural_gaming"
        elif reward_info.get('r_factual', 1.0) < 0.4:
            return "factual_exploitation"
        else:
            return "format_gaming"
    
    def create_preference_pairs(self, adversarial_examples):
        preference_pairs = []
        
        for adv_example in adversarial_examples:
            clean_response = self._generate_clean_response(
                adv_example['prompt'],
                adv_example['correct_answer']
            )
            
            clean_reward_info = self.reward_function.compute_total_reward(
                clean_response, adv_example['correct_answer']
            )
            
            clean_quality = self._assess_response_quality(clean_response, clean_reward_info)
            adv_quality = self._assess_response_quality(adv_example['response'], adv_example['reward_info'])
            
            if clean_quality > adv_quality:
                    preference_pairs.append({
                        'prompt': adv_example['prompt'],
                        'chosen': clean_response,
                        'rejected': adv_example['response'],
                        'correct_answer': adv_example['correct_answer'],
                        'chosen_reward': clean_reward_info,
                        'rejected_reward': adv_example['reward_info'],
                        'vulnerability_type': adv_example['vulnerability_type']
                    })
        
        print(f"Created {len(preference_pairs)} preference pairs")
        self.preference_pairs_buffer.extend(preference_pairs)
        
        return preference_pairs
    
    def _generate_clean_response(self, prompt, correct_answer):
        """Generate a high-quality, non-adversarial response."""
        # Use conservative parameters for clean generation
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )
        
        response = self.tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        ).strip()
        
        # Ensure proper format
        response = self._ensure_proper_format(response, correct_answer)
        
        return response
    
    def _ensure_proper_format(self, response, correct_answer):
        """Ensure response has proper <think> and <answer> tags."""
        # Add think tags if missing
        if '<think>' not in response.lower():
            reasoning_part = response.split('<answer>')[0].strip()
            answer_part = response.split('<answer>')[1] if '<answer>' in response else ''
            response = f"<think>{reasoning_part}</think>"
            if answer_part:
                response += f"<answer>{answer_part}"
        
        # Add answer tags if missing
        if '<answer>' not in response.lower():
            response += f"\n<answer>{correct_answer}</answer>"
        
        return response
    
    def update_reward_model(self, preference_pairs):
        if not hasattr(self.base_trainer, 'reward_head'):
            print("No learnable reward component found - skipping adversarial update")
            return {'loss': 0.0, 'accuracy': 0.0}
        
        print(f"Updating reward model with {len(preference_pairs)} preference pairs...")
        self.base_trainer.reward_optimizer.zero_grad()
        self.base_trainer.reward_head.train()
        total_loss = 0
        correct_rankings = 0
        
        for pair in preference_pairs:
            try:
                # Get hidden states for both responses
                chosen_hidden = self._get_hidden_state_for_response(
                    pair['prompt'], pair['chosen']
                )
                rejected_hidden = self._get_hidden_state_for_response(
                    pair['prompt'], pair['rejected']
                )
                
                # Compute predicted rewards
                chosen_reward_pred = self.base_trainer.reward_head(chosen_hidden.detach().float())
                rejected_reward_pred = self.base_trainer.reward_head(rejected_hidden.detach().float())
                
                # Preference loss: chosen should have higher reward
                target = torch.ones_like(chosen_reward_pred)
                loss = F.margin_ranking_loss(
                    chosen_reward_pred, rejected_reward_pred, 
                    target, margin=self.preference_margin
                )
                
                loss.backward()
                total_loss += loss.item()
                
                # Track accuracy
                if chosen_reward_pred.item() > rejected_reward_pred.item():
                    correct_rankings += 1
                    
            except Exception as e:
                print(f"Error in preference update: {e}")
                continue
        
        # Update parameters
        if total_loss > 0:
            torch.nn.utils.clip_grad_norm_(self.base_trainer.reward_head.parameters(), 1.0)
            self.base_trainer.reward_optimizer.step()
        
        avg_loss = total_loss / len(preference_pairs) if preference_pairs else 0
        accuracy = correct_rankings / len(preference_pairs) if preference_pairs else 0
        
        print(f"  Adversarial loss: {avg_loss:.4f}")
        print(f"  Preference accuracy: {accuracy:.2f}")
        
        return {'loss': avg_loss, 'accuracy': accuracy}

    def analyze_response_quality(self, response, reward_info, correct_answer):
        """Flag which of the five response-quality goals a response violates.

        Args:
            response: Generated response text; its format is checked with
                ``self.reward_function.validate_format``.
            reward_info: Mapping of reward components. Reads ``r_binary``,
                ``p_answer``, ``p_structural``, ``r_factual`` and ``r_total``
                with defaults 0, 0, 0, 1.0 and 0 respectively -- note that a
                missing ``r_binary`` therefore counts as incorrect.
            correct_answer: Accepted for interface compatibility; unused here.

        Returns:
            Dict mapping each violation name to 1 (violated) or 0.
        """
        lookup = reward_info.get

        # Goals 1-4, in order: correctness, answer leaking, format, factuality.
        checks = {
            'correctness_violations': lookup('r_binary', 0) <= 0,
            'answer_leaking_violations': lookup('p_answer', 0) > 0.1,
            'format_violations': (
                lookup('p_structural', 0) > 0.1
                or not self.reward_function.validate_format(response)
            ),
            'factual_violations': lookup('r_factual', 1.0) < 0.3,
        }
        violations = {goal: 1 if failed else 0 for goal, failed in checks.items()}

        # Goal 5: a high total reward on a response that already violates one
        # of the goals above means flawed (OOD) reasoning was over-rewarded.
        reward_is_high = lookup('r_total', 0) > 0.5
        over_rewarded = any(violations.values()) and reward_is_high
        violations['bad_ood_high_rewards'] = 1 if over_rewarded else 0

        return violations

    def _assess_response_quality(self, response, reward_info, correct_answer=None):
        """
        Assess overall response quality using existing analysis method.
        
        Returns:
            Quality score between 0 and 1 (higher = better quality)
        """
        # Use existing violation analysis
        violations = self.analyze_response_quality(response, reward_info, correct_answer or "A")
        
        # Convert violations to quality score
        total_violations = sum(violations.values())
        max_violations = len(violations)  # 5 possible violation types
        
        # Quality = 1 - (violation_ratio)
        quality_score = 1.0 - (total_violations / max_violations)
        
        return quality_score
    
    def _get_hidden_state_for_response(self, prompt, response):
        """Get hidden state for a specific response."""
        full_text = prompt + " " + response
        inputs = self.tokenizer(full_text, return_tensors="pt", truncation=True, max_length=2048)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
            hidden_state = outputs.hidden_states[-1][0].mean(dim=0)  # Average over sequence
        
        return hidden_state
    
    def validate_robustness(self, test_prompts, test_answers):
        """
        Check how many freshly generated adversarial examples score low.

        Args:
            test_prompts: List of test prompts (only the first 20 are used)
            test_answers: List of correct answers, aligned with the prompts

        Returns:
            Dict with ``robustness_score`` (fraction of examples whose
            ``reward_info['r_total']`` is below 0.3, or 0 when none were
            generated), ``adversarial_found`` and ``correctly_identified``.
        """
        print("Validating reward model robustness...")

        probes = self.generate_adversarial_examples(test_prompts[:20], test_answers[:20])

        # An example counts as caught when its total reward is below 0.3.
        caught = sum(1 for example in probes if example['reward_info']['r_total'] < 0.3)

        if probes:
            score = caught / len(probes)
        else:
            score = 0

        found = len(probes)
        print(f"Robustness validation: {score:.2f}")
        print(f"  Found {found} new adversarial examples")
        print(f"  Correctly identified {caught} as problematic")

        return {
            'robustness_score': score,
            'adversarial_found': found,
            'correctly_identified': caught,
        }
    
    def run_adversarial_training_cycle(self, prompts, correct_answers, num_cycles=3):
        """
        Run up to ``num_cycles`` rounds of generate -> pair -> update -> validate.

        A round with no adversarial examples stops the loop; a round with no
        preference pairs is skipped (and not recorded). Validation uses the
        first 10 prompt/answer pairs.

        Args:
            prompts: Training prompts
            correct_answers: Correct answers
            num_cycles: Maximum number of adversarial cycles to run

        Returns:
            Dict with ``total_cycles`` (recorded cycles), ``cycle_results`` and
            ``final_robustness`` (last recorded robustness score, else 0).
        """
        print(f"Starting {num_cycles} adversarial training cycles...")

        history = []

        for cycle_no in range(1, num_cycles + 1):
            print(f"\n=== Adversarial Cycle {cycle_no}/{num_cycles} ===")

            examples = self.generate_adversarial_examples(prompts, correct_answers)
            if not examples:
                print("No adversarial examples found")
                break

            pairs = self.create_preference_pairs(examples)
            if not pairs:
                print("No valid preference pairs created")
                continue

            update = self.update_reward_model(pairs)
            validation = self.validate_robustness(prompts[:10], correct_answers[:10])

            history.append({
                'cycle': cycle_no,
                'adversarial_found': len(examples),
                'preference_pairs': len(pairs),
                'update_loss': update['loss'],
                'update_accuracy': update['accuracy'],
                'robustness_score': validation['robustness_score'],
            })

            print(f"Cycle {cycle_no} complete:")
            print(f"  Adversarial examples: {len(examples)}")
            print(f"  Update accuracy: {update['accuracy']:.2f}")
            print(f"  Robustness score: {validation['robustness_score']:.2f}")

        final_robustness = history[-1]['robustness_score'] if history else 0
        return {
            'total_cycles': len(history),
            'cycle_results': history,
            'final_robustness': final_robustness,
        }
    
    def clear_buffers(self):
        """Empty the adversarial-example and preference-pair buffers in place."""
        for buffer_name in ('adversarial_examples_buffer', 'preference_pairs_buffer'):
            getattr(self, buffer_name).clear()
        print("Adversarial training buffers cleared")

# Fact Extraction and LLM Judge

## AtomicFactExtractor

In [None]:
@dataclass
class AtomicFact:
    """A single extracted clinical claim together with its verification scores."""
    text: str  # the claim itself
    category: str  # label set at extraction time; the extractors here default to "medical_statement"
    confidence: float = 0.0  # NOTE(review): not written anywhere in this chunk -- confirm who fills it
    llm_score: float = 0.0  # written by LLMJudgeVerifier.verify_fact
    kb_score: float = 0.0  # presumably the knowledge-base match score; verify against the KB verifier
    source_sentence: str = ""  # originating text; the extractors here copy `text` when none is given

class AtomicFactExtractor:
    """Extract atomic, verifiable clinical claims from model reasoning text.

    ``extract_facts`` cleans the reasoning (keeps ``<think>`` content, drops
    ``<answer>`` blocks), asks the extraction LLM for a ``{"facts": [...]}``
    JSON object and converts the reply into ``AtomicFact`` instances. When the
    reply cannot be parsed, a keyword-based sentence heuristic is used instead.
    """

    def __init__(self, model_name, region_name):
        """
        Args:
            model_name: Bedrock model id used for extraction.
            region_name: AWS region passed to ``ClaudeBedrockClient``.
        """
        self.model_name = model_name
        self.extraction_prompt = self._build_extraction_prompt()
        self.client = ClaudeBedrockClient(model_name, region_name)

    def _build_extraction_prompt(self):
        """Return the extraction prompt template.

        The template has a single ``{reasoning_text}`` placeholder; the literal
        JSON braces are doubled so ``str.format`` leaves them intact.
        """
        return """You are a medical fact extractor. Extract specific, verifiable clinical claims from the reasoning text.
    
        RULES:
        - Extract only factual medical statements (not reasoning steps)
        - Each fact should be 1-2 sentences maximum
        - Focus on: symptoms, diagnoses, treatments, drug effects, anatomical facts
        - Ignore: "the patient likely has..." or "this suggests..." (too speculative)
        
        INPUT TEXT:
        {reasoning_text}
        
        Return ONLY a valid JSON object in this exact format:
        {{"facts": ["fact1", "fact2", "fact3"]}}
        
        EXAMPLES:
        Good facts: "Salicylate poisoning causes metabolic acidosis", "Insulin treats diabetic ketoacidosis"
        Bad facts: "This presentation is consistent with...", "We should consider..."
        
        JSON:"""

    def _parse_extraction_response(self, response):
        """Parse the extraction LLM reply into a list of ``AtomicFact``.

        A JSON candidate is located with progressively looser regexes, repaired
        with ``_fix_json_issues`` (and ``_repair_json_advanced`` on a decode
        failure), and its ``"facts"`` list is read. Items may be plain strings
        or dicts with ``text`` / ``category`` / ``source_sentence`` keys; texts
        of 5 characters or fewer are discarded. Any failure, or an empty
        result, falls back to ``_fallback_extraction``.
        """
        facts = []
        try:
            print(f"Raw response: '{response}'")
            cleaned_response = response.strip()

            # Drop Markdown code fences around the JSON.
            cleaned_response = re.sub(r'```json\s*', '', cleaned_response)
            cleaned_response = re.sub(r'```\s*$', '', cleaned_response)

            # Locate a JSON candidate, from most to least specific pattern.
            json_str = None
            json_match = re.search(r'\{[^{}]*"facts"[^{}]*:\s*\[[^\]]*\][^{}]*\}', cleaned_response, re.DOTALL)
            if json_match:
                json_str = json_match.group(0)
                print(f"Found complete JSON object: '{json_str}'")
            else:
                facts_only_match = re.search(r'"facts"\s*:\s*\[[^\]]*\]', cleaned_response, re.DOTALL)
                if facts_only_match:
                    json_str = '{' + facts_only_match.group(0) + '}'
                    print(f"Fixed incomplete JSON: '{json_str}'")
                else:
                    json_match = re.search(r'\{.*?\}', cleaned_response, re.DOTALL)
                    if json_match:
                        json_str = json_match.group(0)
                        print(f"Found fallback JSON: '{json_str}'")
                    else:
                        array_match = re.search(r'\[.*?\]', cleaned_response, re.DOTALL)
                        if array_match:
                            json_str = '{"facts": ' + array_match.group(0) + '}'
                            print(f"Created JSON from array: '{json_str}'")

            if not json_str:
                print("No JSON pattern found, using fallback extraction")
                return self._fallback_extraction(cleaned_response)

            # Fix common JSON issues
            json_str = self._fix_json_issues(json_str)
            print(f"After fixing JSON: '{json_str}'")

            try:
                data = json.loads(json_str)
                print(f"Successfully parsed JSON: {data}")
            except json.JSONDecodeError as e:
                print(f"JSON decode failed: {e}")
                print("Attempting advanced repair...")

                repaired_json = self._repair_json_advanced(json_str)
                if repaired_json != json_str:
                    try:
                        data = json.loads(repaired_json)
                        print("Successfully repaired and parsed JSON")
                    except Exception as repair_e:
                        print(f"JSON repair also failed: {repair_e}")
                        return self._fallback_extraction(cleaned_response)
                else:
                    return self._fallback_extraction(cleaned_response)

            if isinstance(data, dict):
                if "facts" in data:
                    facts_list = data["facts"]
                    if isinstance(facts_list, list):
                        print(f"Found {len(facts_list)} facts in data")

                        for i, fact_item in enumerate(facts_list):
                            if isinstance(fact_item, str):
                                # Simple string fact
                                text = fact_item.strip()
                                if text and len(text) > 5:
                                    fact = AtomicFact(
                                        text=text,
                                        category="medical_statement",
                                        source_sentence=text
                                    )
                                    facts.append(fact)
                                    print(f"Created string fact {i+1}: {text}")
                            elif isinstance(fact_item, dict):
                                # Object fact with metadata; a null/non-string "text"
                                # must not abort parsing of the remaining facts.
                                text = str(fact_item.get("text") or "").strip()
                                if text and len(text) > 5:
                                    fact = AtomicFact(
                                        text=text,
                                        category=fact_item.get("category", "medical_statement"),
                                        source_sentence=fact_item.get("source_sentence", text)
                                    )
                                    facts.append(fact)
                                    print(f"Created object fact {i+1}: {text}")
                    else:
                        print(f"'facts' is not a list: {type(facts_list)}")
                        return self._fallback_extraction(cleaned_response)
                else:
                    print(f"No 'facts' key found. Available keys: {list(data.keys())}")
                    return self._fallback_extraction(cleaned_response)
            else:
                print(f"Parsed data is not a dict: {type(data)}")
                return self._fallback_extraction(cleaned_response)

            if not facts:
                print("No valid facts extracted, using fallback")
                return self._fallback_extraction(cleaned_response)

            print(f"Successfully extracted {len(facts)} facts")
            return facts

        except Exception as e:
            print(f"Error in _parse_extraction_response: {e}")
            import traceback
            traceback.print_exc()
            return self._fallback_extraction(response)

    @staticmethod
    def _is_valid_json(text):
        """Return True if ``text`` parses as a JSON object or array."""
        try:
            return isinstance(json.loads(text), (dict, list))
        except (ValueError, TypeError):
            return False

    @staticmethod
    def _close_unbalanced(json_str):
        """Close an unterminated string and any open ``[``/``{`` in nesting order.

        The scan is string-aware, so brackets inside quoted values are ignored,
        and missing closers are appended innermost-first (e.g. ``]}`` rather
        than ``}]``).
        """
        expected_closers = []
        in_string = False
        escaped = False
        for ch in json_str:
            if in_string:
                if escaped:
                    escaped = False
                elif ch == '\\':
                    escaped = True
                elif ch == '"':
                    in_string = False
            elif ch == '"':
                in_string = True
            elif ch == '{':
                expected_closers.append('}')
            elif ch == '[':
                expected_closers.append(']')
            elif ch in '}]' and expected_closers and expected_closers[-1] == ch:
                expected_closers.pop()
        suffix = '"' if in_string else ''
        return json_str + suffix + ''.join(reversed(expected_closers))

    def _repair_structure(self, json_str):
        """Close unbalanced strings/brackets, then drop trailing commas."""
        # Close first so commas left dangling before the new closers are removed too.
        json_str = self._close_unbalanced(json_str)
        json_str = re.sub(r',\s*}', '}', json_str)
        json_str = re.sub(r',\s*]', ']', json_str)
        return json_str

    def _fix_json_issues(self, json_str):
        """Fix common JSON formatting issues.

        Input that already parses as an object/array is returned unchanged, so
        legitimate escaped quotes inside values survive. Otherwise one pair of
        wrapping quotes is stripped and the structure is repaired
        (``_repair_structure``). Escaped quotes (``\\"``) are unescaped only if
        the structural repair alone does not yield valid JSON, which handles
        double-encoded payloads such as ``{\\"facts\\": ...}``.

        Returns:
            The repaired string; it may still be invalid JSON.
        """
        if self._is_valid_json(json_str):
            return json_str

        # Remove extra quotes around the whole string
        if len(json_str) >= 2 and json_str.startswith('"') and json_str.endswith('"'):
            json_str = json_str[1:-1]

        repaired = self._repair_structure(json_str)
        if self._is_valid_json(repaired):
            return repaired

        # Last resort: treat the payload as wholesale-escaped.
        return self._repair_structure(json_str.replace('\\"', '"'))

    def _repair_json_advanced(self, json_str):
        """Rebuild a ``{"facts": [...]}`` object from quoted medical-sounding text.

        Collects quoted substrings of 10-200 characters that contain a medical
        keyword and none of the reasoning phrases, keeping at most 10. Returns
        the input unchanged when nothing qualifies.
        """
        # If we can't parse it, try to build a valid structure
        try:
            # Look for fact-like content in quotes
            fact_candidates = re.findall(r'"([^"]{10,200})"', json_str)

            # Filter for medical-sounding facts
            medical_facts = []
            medical_keywords = ['patient', 'symptom', 'diagnosis', 'treatment', 'medication',
                              'disease', 'condition', 'causes', 'therapy', 'clinical']

            for candidate in fact_candidates:
                if any(keyword in candidate.lower() for keyword in medical_keywords):
                    # Avoid reasoning phrases
                    reasoning_phrases = ['this suggests', 'likely', 'consistent with',
                                       'we should', 'let us', 'first', 'therefore']
                    if not any(phrase in candidate.lower() for phrase in reasoning_phrases):
                        medical_facts.append(candidate)

            if medical_facts:
                # Create valid JSON structure
                valid_json = '{"facts": ' + json.dumps(medical_facts[:10]) + '}'
                print(f"Reconstructed JSON: {valid_json}")
                return valid_json

        except Exception as e:
            print(f"Advanced repair failed: {e}")

        return json_str

    def _fallback_extraction(self, text):
        """Heuristically extract facts when JSON parsing fails.

        Splits ``text`` on ``.``, ``!`` and ``?`` and keeps sentences of 16-199
        characters that contain a medical keyword and no reasoning phrase.
        Returns at most 8 ``AtomicFact`` objects.
        """
        facts = []

        print("Using fallback extraction method")

        # Extract medical sentences using improved patterns
        sentences = re.split(r'[.!?]+', text)
        medical_keywords = [
            'patient', 'diagnosis', 'treatment', 'symptom', 'medication',
            'therapy', 'condition', 'disease', 'clinical', 'medical',
            'causes', 'leads to', 'results in', 'associated with',
            'effective', 'contraindicated', 'indicated', 'syndrome'
        ]

        # Reasoning phrases to avoid
        avoid_phrases = [
            'this suggests', 'likely', 'probably', 'consistent with',
            'we should', 'let us', 'first', 'next step', 'therefore',
            'in conclusion', 'based on', 'given that'
        ]

        for sentence in sentences:
            sentence = sentence.strip()
            if len(sentence) > 15 and len(sentence) < 200:  # Reasonable length
                sentence_lower = sentence.lower()

                # Check if sentence contains medical content
                has_medical = any(keyword in sentence_lower for keyword in medical_keywords)
                has_reasoning = any(phrase in sentence_lower for phrase in avoid_phrases)

                if has_medical and not has_reasoning:
                    fact = AtomicFact(
                        text=sentence,
                        category="medical_statement",
                        source_sentence=sentence
                    )
                    facts.append(fact)
                    print(f"Fallback extracted: {sentence}")

        return facts[:8]  # Limit to 8 facts max

    def _clean_reasoning_text(self, text):
        """Return the reasoning portion of a response as single-spaced text.

        Keeps the content of the first ``<think>`` block -- up to ``</think>``
        or, for a truncated generation with no closing tag, to the end of the
        text -- removes ``<answer>...</answer>`` blocks and collapses
        whitespace. Without ``<think>`` the whole text is used.
        """
        # Extract content from <think> tags if present (tolerating a missing </think>)
        think_match = re.search(r'<think>(.*?)(?:</think>|\Z)', text, flags=re.DOTALL)
        if think_match:
            text = think_match.group(1)
        else:
            print("No <think> tags found, using original text")

        # Remove <answer> tags entirely
        text = re.sub(r'<answer>.*?</answer>', '', text, flags=re.DOTALL)

        # Clean whitespace
        text = ' '.join(text.split())

        return text.strip()

    def _call_extraction_llm(self, prompt):
        """Send ``prompt`` to the extraction LLM; return its text, or "" on failure."""
        full_prompt = f"""You are a medical fact extraction specialist.\n\n{prompt}"""

        try:
            response = self.client.chat_completion(
                messages=[{"role": "user", "content": full_prompt}],
                temperature=0.1,
                max_tokens=1000
            )

            content = response['choices'][0]['message']['content']
            return content

        except Exception as e:
            print(f"API call failed: {e}")
            return ""

    def extract_facts(self, reasoning_text):
        """Extract ``AtomicFact`` objects from a model response.

        Returns:
            List of facts; empty if any step raises.
        """
        try:
            cleaned_text = self._clean_reasoning_text(reasoning_text)
            prompt = self.extraction_prompt.format(reasoning_text=cleaned_text)
            response = self._call_extraction_llm(prompt)
            facts = self._parse_extraction_response(response)
            return facts

        except Exception as e:
            print(f"Error in fact extraction: {e}")
            return []

    def diagnose_evaluation(self, responses):
        """Diagnostic method to debug fact extraction issues.

        Runs ``extract_facts`` on the first 5 responses and prints tag presence,
        fact counts and previews; the first 2 responses are printed in full.
        """
        print("=== DIAGNOSTIC REPORT ===")

        for i, response in enumerate(responses[:5]):  # Check first 5
            print(f"\n--- Response {i+1} ---")
            print(f"Length: {len(response)}")
            print(f"Has <think>: {'<think>' in response}")
            print(f"Has <answer>: {'<answer>' in response}")

            try:
                facts = self.extract_facts(response)
                print(f"Extracted facts: {len(facts)} facts")
                if facts:
                    print(f"First fact: {facts[0].text if hasattr(facts[0], 'text') else facts[0]}")
            except Exception as e:
                print(f"Fact extraction error: {e}")
                import traceback
                traceback.print_exc()

            print(f"Preview: {response[:200]}...")

            if i < 2:
                print(f"Full response:\n{response}\n")

In [None]:
# def quick_diagnostic_test():
#     """Quick test with sample responses"""
    
#     sample_responses = [
#         """<think>
#         The patient presents with chest pain and elevated troponins. This suggests myocardial infarction.
#         Aspirin and clopidogrel are indicated for antiplatelet therapy.
#         </think>
#         <answer>A</answer>""",
        
#         """<think>
#         Diabetic ketoacidosis is characterized by hyperglycemia, ketosis, and acidosis.
#         Insulin therapy is the primary treatment.
#         </think>
#         <answer>B</answer>""",
        
#         """No think tags here, just some text.
#         <answer>C</answer>""",
        
#         """<think>
#         This is malformed JSON response test
#         </think>
#         <answer>D</answer>""",
        
#         """Completely malformed response without proper tags"""
#     ]
    
#     # Initialize your extractor
#     fact_extractor = AtomicFactExtractor(model_name="anthropic.claude-3-haiku-20240307-v1:0", region_name="us-east-1")
    
#     # Run diagnostic
#     fact_extractor.diagnose_evaluation(sample_responses)


# quick_diagnostic_test()

## LLM Judge

In [None]:
from botocore.exceptions import ClientError

class ClaudeBedrockClient:
    """OpenAI-style chat wrapper around Anthropic models on AWS Bedrock.

    ``chat_completion`` accepts OpenAI-format ``messages`` and returns a dict
    shaped like an OpenAI chat completion (``choices[0]['message']['content']``,
    ``usage``, ``model``), which is how the callers in this notebook read it.
    """

    def __init__(self, model_name, region_name):
        """
        Args:
            model_name: Bedrock model id passed as ``modelId`` to ``invoke_model``.
            region_name: AWS region for the ``bedrock-runtime`` client.

        Raises:
            Exception: whatever ``boto3.client`` raises (logged, then re-raised).
        """
        self.model_name = model_name
        self.region_name = region_name

        try:
            self.bedrock_client = boto3.client(
                service_name='bedrock-runtime',
                region_name=region_name
            )
            print(f"Initialized Claude Bedrock client with model: {model_name}")
        except Exception as e:
            print(f"Failed to initialize Bedrock client: {e}")
            raise

    def chat_completion(self, messages, temperature=0.2, max_tokens=1000):
        """Send a chat request and return an OpenAI-like response dict.

        Args:
            messages: OpenAI-style list of ``{"role", "content"}`` dicts.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Raises:
            botocore.exceptions.ClientError: Bedrock API errors (logged, re-raised).
            Exception: any other failure (logged, re-raised).
        """
        try:
            claude_messages = self._convert_messages_to_claude_format(messages)

            body = {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": max_tokens,
                "temperature": temperature,
                "messages": claude_messages
            }

            # Make request to Bedrock
            response = self.bedrock_client.invoke_model(
                modelId=self.model_name,
                body=json.dumps(body),
                contentType='application/json'
            )

            # Parse response
            response_body = json.loads(response['body'].read())

            # Convert to OpenAI-like format
            return self._convert_claude_response_to_openai_format(response_body)

        except ClientError as e:
            print(f"Bedrock API error: {e}")
            raise
        except Exception as e:
            print(f"Unexpected error in chat completion: {e}")
            raise

    def _convert_messages_to_claude_format(self, messages):
        """Convert OpenAI message format to a Claude ``messages`` list.

        * ``system`` content is prepended to the preceding user turn, or
          becomes a user turn of its own.
        * Consecutive turns with the same role are merged (blank-line joined)
          when both contents are strings. A leading system message followed by
          a user message used to yield two user turns in a row, which can
          trigger the "roles must alternate" validation error.
        * Roles other than system/user/assistant are dropped.
        * A placeholder user turn is inserted if the result would be empty or
          start with an assistant turn.
        """
        claude_messages = []

        for msg in messages:
            role = msg['role']
            content = msg['content']

            if role == 'system':
                if claude_messages and claude_messages[-1]['role'] == 'user':
                    claude_messages[-1]['content'] = f"{content}\n\n{claude_messages[-1]['content']}"
                    continue
                role = 'user'
            elif role not in ('user', 'assistant'):
                continue

            previous = claude_messages[-1] if claude_messages else None
            if (previous is not None and previous['role'] == role
                    and isinstance(previous['content'], str) and isinstance(content, str)):
                previous['content'] = f"{previous['content']}\n\n{content}"
            else:
                claude_messages.append({
                    'role': role,
                    'content': content
                })

        if not claude_messages or claude_messages[0]['role'] != 'user':
            claude_messages.insert(0, {
                'role': 'user',
                'content': 'Please help me with the following task.'
            })

        return claude_messages

    def _convert_claude_response_to_openai_format(self, claude_response: Dict):
        """Convert a Claude response body to an OpenAI-like completion dict.

        All ``text`` content blocks are concatenated; other block types (for
        example tool use) are skipped instead of raising ``KeyError``.
        """
        blocks = claude_response.get('content') or []
        content = "".join(
            block.get('text', '')
            for block in blocks
            if isinstance(block, dict) and block.get('type', 'text') == 'text'
        )

        usage = claude_response.get('usage') or {}
        input_tokens = usage.get('input_tokens', 0)
        output_tokens = usage.get('output_tokens', 0)

        return {
            'choices': [{
                'message': {
                    'role': 'assistant',
                    'content': content
                },
                'finish_reason': claude_response.get('stop_reason', 'stop')
            }],
            'usage': {
                'prompt_tokens': input_tokens,
                'completion_tokens': output_tokens,
                'total_tokens': input_tokens + output_tokens
            },
            'model': self.model_name
        }


In [None]:
class LLMJudge:
    """Ask a Claude model on Bedrock to rate the accuracy of medical facts."""

    def __init__(self, model_name, region_name):
        """
        Args:
            model_name: Bedrock model id used as the judge.
            region_name: AWS region for ``ClaudeBedrockClient``.
        """
        self.claude_client = ClaudeBedrockClient(model_name, region_name)
        self.model_name = model_name

    @staticmethod
    def _normalize_score(value):
        """Return ``value`` as a float clamped to [0.0, 1.0].

        Raises:
            ValueError: if ``value`` is not numeric or is NaN.
        """
        import math

        try:
            score = float(value)
        except (TypeError, ValueError) as exc:
            raise ValueError(f"Non-numeric confidence_score: {value!r}") from exc
        if math.isnan(score):
            raise ValueError("confidence_score is NaN")
        return min(1.0, max(0.0, score))

    def judge_fact_accuracy(self, fact, context="", temperature=0.1):
        """Rate one fact and return a dict with at least ``confidence_score``.

        Args:
            fact: Fact text, or an object with a ``text`` attribute such as
                ``AtomicFact`` (its text is used rather than its repr).
            context: Optional extra context for the judge.
            temperature: Sampling temperature for the judge call.

        Returns:
            The judge's JSON plus ``model_used`` and ``raw_response``, with
            ``confidence_score`` coerced to a float in [0, 1] (0.5 if absent).
            If the reply is not usable JSON, the first number in it is used
            (values above 1 are divided by 10, then clamped) and
            ``parsing_error`` is set. On API failure ``confidence_score`` is 0.0
            and ``error`` is set.
        """
        fact = getattr(fact, 'text', fact)
        full_prompt = f"""You are a medical expert evaluating the accuracy of medical facts. 
        Provide a confidence score between 0.0 and 1.0 where:
        - 1.0 = Definitely accurate and well-established medical fact
        - 0.8-0.9 = Very likely accurate with strong evidence
        - 0.6-0.7 = Probably accurate but may have exceptions
        - 0.4-0.5 = Uncertain or conflicting evidence
        - 0.2-0.3 = Probably inaccurate
        - 0.0-0.1 = Definitely inaccurate or harmful
        Please evaluate this medical fact:

        Fact: {fact}

        Context: {context if context else "No additional context provided"}

        Respond with just a JSON object containing:
        {{
            "confidence_score": <float between 0.0 and 1.0>,
            "reasoning": "<brief explanation>",
            "medical_category": "<relevant medical specialty>"
        }}

        Provide your assessment as JSON only."""

        messages = [{"role": "user", "content": full_prompt}]

        try:
            response = self.claude_client.chat_completion(
                messages=messages,
                temperature=temperature,
                max_tokens=500
            )

            content = response['choices'][0]['message']['content']

            try:
                json_match = re.search(r'\{.*\}', content, re.DOTALL)
                if json_match:
                    result = json.loads(json_match.group())

                    if 'confidence_score' not in result:
                        result['confidence_score'] = 0.5
                    else:
                        # The judge may return a string or an out-of-range
                        # number; callers expect a float in [0, 1].
                        result['confidence_score'] = self._normalize_score(result['confidence_score'])

                    result['model_used'] = self.model_name
                    result['raw_response'] = content

                    return result
                else:
                    raise ValueError("No JSON found in response")

            except (json.JSONDecodeError, ValueError) as e:
                print(f"Failed to parse Claude response as JSON: {e}")
                print(f"Raw response: {content}")

                score_match = re.search(r'(\d+\.?\d*)', content)
                score = float(score_match.group(1)) if score_match else 0.5
                if score > 1.0:
                    score = score / 10.0

                return {
                    'confidence_score': min(1.0, max(0.0, score)),
                    'reasoning': content[:200] + "..." if len(content) > 200 else content,
                    'medical_category': 'unknown',
                    'model_used': self.model_name,
                    'raw_response': content,
                    'parsing_error': str(e)
                }

        except Exception as e:
            print(f"Error in Claude fact judgment: {e}")
            return {
                'confidence_score': 0.0,
                'reasoning': f'Error occurred: {str(e)}',
                'medical_category': 'error',
                'model_used': self.model_name,
                'error': str(e)
            }

    def batch_judge_facts(self, facts, context=""):
        """Judge multiple facts sequentially.

        Args:
            facts: Iterable-with-length of fact strings or objects with a
                ``text`` attribute (e.g. ``AtomicFact``).
            context: Optional context passed to every judgment.

        Returns:
            List of ``judge_fact_accuracy`` results in input order.
        """
        results = []

        for i, fact in enumerate(facts):
            fact_text = getattr(fact, 'text', fact)
            print(f"Judging fact {i + 1}/{len(facts)}: {str(fact_text)[:100]}...")
            result = self.judge_fact_accuracy(fact_text, context)
            results.append(result)

        return results

In [None]:
class LLMJudgeVerifier:
    """Verify ``AtomicFact`` objects with an LLM judge and record ``llm_score``."""

    # A well-formed decimal number ("7", "0.85", ".5"). A looser ``[0-9.]+``
    # also matches a lone sentence period or a trailing one ("0.85."), which
    # makes float() fail and discards an otherwise valid score.
    _NUMBER = r'(\d+(?:\.\d+)?|\.\d+)'

    def __init__(self, model_name, region_name):
        """
        Args:
            model_name: Bedrock model id used as the verifier.
            region_name: AWS region for ``ClaudeBedrockClient``.
        """
        self.model_name = model_name
        self.verification_prompt = self._build_verification_prompt()
        self.client = ClaudeBedrockClient(model_name, region_name)

    def _build_verification_prompt(self):
        """Return the template; ``{fact_text}`` and ``{context}`` are filled per fact."""
        return """You are a medical fact verification specialist. Evaluate the clinical accuracy of the given atomic fact.

        VERIFICATION CRITERIA:
        1. Medical Accuracy: Is the fact clinically correct?
        2. Specificity: Is the claim specific enough to be verifiable?
        3. Context Appropriateness: Does the fact make sense in the clinical context?
        4. Evidence Base: Is this supported by established medical knowledge?

        FACT TO VERIFY: {fact_text}
        MEDICAL CONTEXT: {context}

        You MUST respond with valid JSON in exactly this format (no extra text):
        {{
            "confidence_score": 0.85,
            "is_accurate": true,
            "reasoning": "detailed explanation of your assessment",
            "concerns": ["any specific concerns or caveats"]
        }}

        CRITICAL: The confidence_score must be a decimal number between 0.0 and 1.0.

        ASSESSMENT:"""

    def verify_fact(self, fact, medical_context=""):
        """
        Verify a single atomic fact using the LLM judge.

        Args:
            fact: Object with a ``text`` attribute (e.g. ``AtomicFact``); its
                ``llm_score`` is set to the parsed confidence on success.
            medical_context: Optional context string inserted into the prompt.

        Returns:
            Confidence score in [0, 1]; 0.0 if any step raises.
        """
        try:
            # Format verification prompt
            prompt = self.verification_prompt.format(
                fact_text=fact.text,
                context=medical_context
            )
            response = self._call_verification_llm(prompt)
            confidence = self._parse_verification_response(response)
            fact.llm_score = confidence
            return confidence

        except Exception as e:
            print(f"Error in LLM verification: {e}")
            return 0.0

    def _call_verification_llm(self, prompt):
        """Send ``prompt`` to the verifier model and return its text reply."""
        full_prompt = f"""You are a medical fact verification specialist.

    {prompt}"""
        response = self.client.chat_completion(
            messages=[{"role": "user", "content": full_prompt}],
            temperature=0.1,
            max_tokens=500
        )
        return response['choices'][0]['message']['content']

    def _parse_verification_response(self, response):
        """Extract a confidence score in [0, 1] from the verifier's reply.

        Tries, in order: JSON objects containing ``confidence_score``; numbers
        labelled "confidence score" / "score" / "confidence"; finally the first
        number in the text, rescaled from a 0-5 or 0-10 scale when above 1.

        Returns:
            The score; 0.0 for an empty reply or an unexpected error, 0.5 when
            no usable number is found.
        """
        import math

        try:
            if not response or response.strip() == "":
                return 0.0

            cleaned = response.strip()
            cleaned = re.sub(r'```json\s*', '', cleaned)
            cleaned = re.sub(r'```\s*', '', cleaned)

            json_patterns = [
                r'\{[^{}]*"confidence_score"[^{}]*:[^{}]*[0-9.]+[^{}]*\}',
                r'\{.*?"confidence_score".*?\}',
                r'\{.*\}',
            ]
            for pattern in json_patterns:
                matches = re.findall(pattern, cleaned, re.DOTALL | re.IGNORECASE)
                for match in matches:
                    try:
                        json_str = match.strip()
                        json_str = re.sub(r',\s*}', '}', json_str)
                        json_str = re.sub(r',\s*]', ']', json_str)

                        data = json.loads(json_str)
                        score = data.get("confidence_score")
                        if score is not None:
                            value = float(score)
                            if math.isnan(value):
                                continue  # NaN would otherwise clamp to 1.0
                            return max(0.0, min(1.0, value))

                    except (json.JSONDecodeError, ValueError, TypeError, AttributeError):
                        continue

            score_patterns = [
                r'"confidence_score"[:\s]*' + self._NUMBER,
                r'confidence[_\s]*score[:\s]*' + self._NUMBER,
                r'score[:\s]*' + self._NUMBER,
                r'confidence[:\s]*' + self._NUMBER,
            ]

            for pattern in score_patterns:
                match = re.search(pattern, cleaned, re.IGNORECASE)
                if match:
                    try:
                        score = float(match.group(1))
                        return max(0.0, min(1.0, score))
                    except (ValueError, IndexError):
                        continue

            number_match = re.search(self._NUMBER, cleaned)
            if number_match:
                try:
                    score = float(number_match.group(1))
                    if 0.0 <= score <= 1.0:
                        return score
                    elif 0.0 <= score <= 5.0:
                        return score / 5.0
                    elif 0.0 <= score <= 10.0:
                        return score / 10.0
                except ValueError:
                    pass

            print(f"Could not parse verification response: {response[:200]}")
            return 0.5

        except Exception as e:
            print(f"Error parsing verification response: {e}")
            return 0.0

## KB Verify

### Local KB

In [None]:
class LocalMedicalKnowledgeBase:
    """In-memory store of short medical facts, searchable by embedding similarity.

    Each fact is a dict with ``text``, ``category``, ``source`` and
    ``confidence`` keys. Fact embeddings are computed with a
    SentenceTransformer on CPU. They are cached, together with the facts, in a
    pickle file under ``cache_dir`` so later runs can skip re-encoding.
    """

    def __init__(self,
                 embedding_model='all-MiniLM-L6-v2',
                 cache_dir="local_kb_cache"):
        """Load the embedding model and build (or load) the fact index.

        Args:
            embedding_model: SentenceTransformer model name used for encoding.
            cache_dir: Directory holding the pickled facts/embeddings cache;
                created if missing.
        """
        # Remember the model name so a cache built with a different model is
        # detected as stale instead of producing dimension mismatches.
        self.embedding_model_name = embedding_model
        self.embedding_model = SentenceTransformer(embedding_model, device='cpu')
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        self.knowledge_base = []
        self.knowledge_embeddings = None

        self._initialize_medical_knowledge()

    def _initialize_medical_knowledge(self):
        """Populate facts and embeddings from the cache, or build and cache them."""
        cache_file = os.path.join(self.cache_dir, "local_medical_knowledge.pkl")
        model_name = getattr(self, "embedding_model_name", None)

        if os.path.exists(cache_file):
            try:
                # NOTE: pickle.load can execute arbitrary code. This file is
                # only ever written by this class; never point cache_dir at
                # data from an untrusted source.
                with open(cache_file, 'rb') as f:
                    cached_data = pickle.load(f)
                facts = cached_data['facts']
                embeddings = np.asarray(cached_data['embeddings'])
                # Caches written before the model name was recorded are
                # accepted as long as the row count still matches.
                cached_model = cached_data.get('embedding_model', model_name)
                if cached_model != model_name or len(facts) != len(embeddings):
                    print("Cached knowledge base is stale (embedding model or fact count mismatch); rebuilding")
                else:
                    self.knowledge_base = facts
                    self.knowledge_embeddings = embeddings
                    print(f"Loaded {len(self.knowledge_base)} medical facts from cache")
                    return
            except Exception as e:
                print(f"Failed to load cache: {e}")

        print("Building local medical knowledge base...")
        self.knowledge_base = self._create_medical_facts()

        fact_texts = [fact['text'] for fact in self.knowledge_base]
        print("Generating embeddings for medical facts...")
        self.knowledge_embeddings = self.embedding_model.encode(fact_texts)

        try:
            with open(cache_file, 'wb') as f:
                pickle.dump({
                    'facts': self.knowledge_base,
                    'embeddings': self.knowledge_embeddings,
                    'embedding_model': model_name,
                }, f)
            print(f"Cached {len(self.knowledge_base)} medical facts")
        except Exception as e:
            print(f"Failed to cache knowledge base: {e}")

    def _create_medical_facts(self):
        """Return the built-in list of medical fact dicts, grouped by specialty."""
        return [
            # CARDIOLOGY
            {"text": "Myocardial infarction is caused by coronary artery occlusion leading to cardiac muscle death", "category": "cardiology", "source": "medical_textbook", "confidence": 0.95},
            {"text": "Chest pain, shortness of breath, and diaphoresis are classic symptoms of myocardial infarction", "category": "cardiology", "source": "clinical_guidelines", "confidence": 0.92},
            {"text": "Hypertension is defined as systolic blood pressure ≥140 mmHg or diastolic ≥90 mmHg", "category": "cardiology", "source": "AHA_guidelines", "confidence": 0.96},
            {"text": "ACE inhibitors are first-line treatment for hypertension and heart failure", "category": "cardiology", "source": "treatment_guidelines", "confidence": 0.90},
            {"text": "Atrial fibrillation significantly increases the risk of stroke", "category": "cardiology", "source": "clinical_studies", "confidence": 0.94},
            {"text": "Beta-blockers reduce heart rate and myocardial oxygen demand", "category": "cardiology", "source": "pharmacology", "confidence": 0.93},
            {"text": "Electrocardiography shows ST-elevation in acute myocardial infarction", "category": "cardiology", "source": "diagnostic_criteria", "confidence": 0.91},
            {"text": "Cardiac catheterization is the gold standard for diagnosing coronary artery disease", "category": "cardiology", "source": "diagnostic_procedures", "confidence": 0.89},
            
            # ENDOCRINOLOGY
            {"text": "Type 1 diabetes mellitus is caused by autoimmune destruction of pancreatic beta cells", "category": "endocrinology", "source": "pathophysiology", "confidence": 0.96},
            {"text": "Type 2 diabetes mellitus is characterized by insulin resistance and relative insulin deficiency", "category": "endocrinology", "source": "pathophysiology", "confidence": 0.95},
            {"text": "Metformin is the first-line treatment for type 2 diabetes mellitus", "category": "endocrinology", "source": "ADA_guidelines", "confidence": 0.94},
            {"text": "HbA1c ≥6.5% indicates diabetes mellitus diagnosis", "category": "endocrinology", "source": "diagnostic_criteria", "confidence": 0.96},
            {"text": "Insulin is absolutely required for type 1 diabetes management", "category": "endocrinology", "source": "treatment_standards", "confidence": 0.98},
            {"text": "Diabetic ketoacidosis is a life-threatening complication of diabetes", "category": "endocrinology", "source": "emergency_medicine", "confidence": 0.93},
            {"text": "Hypoglycemia symptoms include diaphoresis, tremor, and altered mental status", "category": "endocrinology", "source": "clinical_presentation", "confidence": 0.91},
            {"text": "Thyroid stimulating hormone (TSH) is elevated in hypothyroidism", "category": "endocrinology", "source": "laboratory_medicine", "confidence": 0.94},
            
            # PULMONOLOGY
            {"text": "Pneumonia causes consolidation visible on chest radiograph", "category": "pulmonology", "source": "radiology", "confidence": 0.90},
            {"text": "Asthma is characterized by reversible airway obstruction and inflammation", "category": "pulmonology", "source": "pathophysiology", "confidence": 0.93},
            {"text": "Albuterol is a short-acting beta-2 agonist bronchodilator", "category": "pulmonology", "source": "pharmacology", "confidence": 0.96},
            {"text": "Chronic obstructive pulmonary disease (COPD) is primarily caused by tobacco smoking", "category": "pulmonology", "source": "epidemiology", "confidence": 0.91},
            {"text": "Pulmonary embolism can cause sudden onset dyspnea and chest pain", "category": "pulmonology", "source": "clinical_presentation", "confidence": 0.89},
            {"text": "Spirometry shows reduced FEV1/FVC ratio in obstructive lung disease", "category": "pulmonology", "source": "pulmonary_function", "confidence": 0.92},
            {"text": "Oxygen saturation below 90% indicates significant hypoxemia", "category": "pulmonology", "source": "critical_care", "confidence": 0.88},
            
            # INFECTIOUS DISEASE
            {"text": "Sepsis is a life-threatening organ dysfunction caused by dysregulated host response to infection", "category": "infectious_disease", "source": "sepsis_guidelines", "confidence": 0.94},
            {"text": "Penicillin is effective against gram-positive bacterial infections", "category": "infectious_disease", "source": "microbiology", "confidence": 0.92},
            {"text": "Viral infections do not respond to antibiotic treatment", "category": "infectious_disease", "source": "antimicrobial_stewardship", "confidence": 0.97},
            {"text": "Urinary tract infections commonly present with dysuria and urinary frequency", "category": "infectious_disease", "source": "clinical_presentation", "confidence": 0.88},
            {"text": "Blood cultures should be obtained before starting empiric antibiotic therapy", "category": "infectious_disease", "source": "diagnostic_guidelines", "confidence": 0.85},
            {"text": "Methicillin-resistant Staphylococcus aureus (MRSA) requires vancomycin treatment", "category": "infectious_disease", "source": "antimicrobial_guidelines", "confidence": 0.91},
            
            # NEUROLOGY
            {"text": "Stroke symptoms include sudden onset focal neurological deficits", "category": "neurology", "source": "clinical_criteria", "confidence": 0.93},
            {"text": "Tissue plasminogen activator (tPA) is used for acute ischemic stroke within 4.5 hours", "category": "neurology", "source": "stroke_guidelines", "confidence": 0.90},
            {"text": "Seizures can be focal or generalized based on their origin and spread", "category": "neurology", "source": "epilepsy_classification", "confidence": 0.94},
            {"text": "Computed tomography (CT) can rapidly detect hemorrhagic stroke", "category": "neurology", "source": "neuroimaging", "confidence": 0.87},
            {"text": "Lumbar puncture is contraindicated with increased intracranial pressure", "category": "neurology", "source": "procedural_guidelines", "confidence": 0.89},
            {"text": "Multiple sclerosis causes demyelinating lesions in the central nervous system", "category": "neurology", "source": "pathophysiology", "confidence": 0.92},
            
            # PSYCHIATRY
            {"text": "Major depressive disorder requires at least 2 weeks of depressive symptoms", "category": "psychiatry", "source": "DSM5", "confidence": 0.96},
            {"text": "Selective serotonin reuptake inhibitors (SSRIs) are first-line treatment for depression", "category": "psychiatry", "source": "treatment_guidelines", "confidence": 0.88},
            {"text": "Bipolar disorder includes both manic and depressive episodes", "category": "psychiatry", "source": "DSM5", "confidence": 0.95},
            {"text": "Suicidal ideation requires immediate safety assessment and intervention", "category": "psychiatry", "source": "crisis_intervention", "confidence": 0.97},
            {"text": "Antipsychotic medications are used to treat schizophrenia and psychotic disorders", "category": "psychiatry", "source": "psychopharmacology", "confidence": 0.91},
            
            # PHARMACOLOGY
            {"text": "Warfarin requires regular INR monitoring due to narrow therapeutic window", "category": "pharmacology", "source": "anticoagulation_guidelines", "confidence": 0.93},
            {"text": "Nonsteroidal anti-inflammatory drugs (NSAIDs) can cause gastric ulceration", "category": "pharmacology", "source": "adverse_effects", "confidence": 0.90},
            {"text": "Statins reduce cholesterol synthesis by inhibiting HMG-CoA reductase", "category": "pharmacology", "source": "mechanism_of_action", "confidence": 0.94},
            {"text": "Opioids can cause respiratory depression at high doses", "category": "pharmacology", "source": "toxicology", "confidence": 0.91},
            {"text": "Drug interactions can alter medication effectiveness and safety", "category": "pharmacology", "source": "clinical_pharmacology", "confidence": 0.87},
            
            # PEDIATRICS
            {"text": "Sudden infant death syndrome risk is reduced by supine sleeping position", "category": "pediatrics", "source": "AAP_guidelines", "confidence": 0.92},
            {"text": "Febrile seizures are common in children aged 6 months to 5 years", "category": "pediatrics", "source": "pediatric_neurology", "confidence": 0.89},
            {"text": "Vaccination schedules protect children from preventable diseases", "category": "pediatrics", "source": "immunization_guidelines", "confidence": 0.95},
            {"text": "Growth charts assess normal pediatric development", "category": "pediatrics", "source": "developmental_medicine", "confidence": 0.86},
            
            # SURGERY
            {"text": "Appendicitis typically presents with right lower quadrant abdominal pain", "category": "surgery", "source": "surgical_diagnosis", "confidence": 0.90},
            {"text": "Cholecystitis causes right upper quadrant pain and Murphy's sign", "category": "surgery", "source": "surgical_examination", "confidence": 0.88},
            {"text": "Surgical site infections are prevented by proper sterile technique", "category": "surgery", "source": "infection_control", "confidence": 0.91},
            {"text": "Bowel obstruction can cause abdominal distension and vomiting", "category": "surgery", "source": "surgical_emergencies", "confidence": 0.87},
            
            # OBSTETRICS/GYNECOLOGY
            {"text": "Preeclampsia is defined by hypertension and proteinuria in pregnancy", "category": "obstetrics", "source": "ACOG_guidelines", "confidence": 0.94},
            {"text": "Folic acid supplementation prevents neural tube defects", "category": "obstetrics", "source": "preventive_medicine", "confidence": 0.92},
            {"text": "Regular prenatal care improves maternal and fetal outcomes", "category": "obstetrics", "source": "prenatal_guidelines", "confidence": 0.89},
            {"text": "Gestational diabetes increases risk of macrosomia", "category": "obstetrics", "source": "maternal_fetal_medicine", "confidence": 0.86},
            
            # EMERGENCY MEDICINE
            {"text": "Advanced Cardiac Life Support (ACLS) protocols guide cardiac arrest management", "category": "emergency_medicine", "source": "AHA_guidelines", "confidence": 0.95},
            {"text": "Trauma patients require primary and secondary survey assessment", "category": "emergency_medicine", "source": "ATLS_guidelines", "confidence": 0.93},
            {"text": "Anaphylaxis is treated with intramuscular epinephrine", "category": "emergency_medicine", "source": "allergy_guidelines", "confidence": 0.96},
            {"text": "Glasgow Coma Scale assesses level of consciousness", "category": "emergency_medicine", "source": "neurological_assessment", "confidence": 0.91},
            
            # NEPHROLOGY
            {"text": "Chronic kidney disease is staged based on glomerular filtration rate", "category": "nephrology", "source": "KDIGO_guidelines", "confidence": 0.93},
            {"text": "Acute kidney injury can be prerenal, intrinsic, or postrenal", "category": "nephrology", "source": "nephrology_classification", "confidence": 0.90},
            {"text": "Dialysis is indicated for severe uremia or fluid overload", "category": "nephrology", "source": "renal_replacement_therapy", "confidence": 0.88},
            {"text": "Proteinuria indicates glomerular kidney disease", "category": "nephrology", "source": "laboratory_findings", "confidence": 0.85},
            
            # GASTROENTEROLOGY
            {"text": "Gastroesophageal reflux disease (GERD) causes heartburn and regurgitation", "category": "gastroenterology", "source": "clinical_presentation", "confidence": 0.87},
            {"text": "Peptic ulcer disease is commonly caused by Helicobacter pylori", "category": "gastroenterology", "source": "etiology", "confidence": 0.91},
            {"text": "Inflammatory bowel disease includes Crohn's disease and ulcerative colitis", "category": "gastroenterology", "source": "disease_classification", "confidence": 0.93},
            {"text": "Cirrhosis can lead to portal hypertension and ascites", "category": "gastroenterology", "source": "hepatology", "confidence": 0.89},
            
            # ONCOLOGY
            {"text": "Cancer staging determines prognosis and treatment approach", "category": "oncology", "source": "cancer_guidelines", "confidence": 0.92},
            {"text": "Chemotherapy targets rapidly dividing cancer cells", "category": "oncology", "source": "cancer_treatment", "confidence": 0.90},
            {"text": "Tumor markers can aid in cancer diagnosis and monitoring", "category": "oncology", "source": "laboratory_oncology", "confidence": 0.84},
            {"text": "Radiation therapy delivers targeted energy to destroy cancer cells", "category": "oncology", "source": "radiation_oncology", "confidence": 0.88},
            
            # DERMATOLOGY
            {"text": "Melanoma is the most dangerous form of skin cancer", "category": "dermatology", "source": "dermatopathology", "confidence": 0.91},
            {"text": "Topical corticosteroids treat inflammatory skin conditions", "category": "dermatology", "source": "dermatologic_therapeutics", "confidence": 0.86},
            {"text": "Skin biopsy provides definitive diagnosis of skin lesions", "category": "dermatology", "source": "diagnostic_procedures", "confidence": 0.89},
            
            # RHEUMATOLOGY
            {"text": "Rheumatoid arthritis is an autoimmune inflammatory joint disease", "category": "rheumatology", "source": "autoimmune_diseases", "confidence": 0.93},
            {"text": "Disease-modifying antirheumatic drugs (DMARDs) slow joint destruction", "category": "rheumatology", "source": "treatment_guidelines", "confidence": 0.90},
            {"text": "Systemic lupus erythematosus affects multiple organ systems", "category": "rheumatology", "source": "connective_tissue_disorders", "confidence": 0.88},
        ]

    def _has_embeddings(self):
        """True when a non-empty embedding matrix is available (None-safe)."""
        return (self.knowledge_embeddings is not None
                and np.asarray(self.knowledge_embeddings).size > 0)

    def _similarities(self, text):
        """Return the cosine similarity of ``text`` to every stored fact embedding.

        Vectors are normalized here explicitly, so the score is a true cosine
        even if the embedding model does not normalize its own output.
        Zero-norm vectors yield a similarity of 0 instead of NaN.
        """
        query = np.asarray(self.embedding_model.encode([text]), dtype=np.float64)[0]
        matrix = np.asarray(self.knowledge_embeddings, dtype=np.float64)
        dots = matrix @ query
        norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query)
        return np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)

    def verify_fact(self, fact, threshold: float = 0.7):
        """Score how well ``fact`` is supported by the local knowledge base.

        The score is the best ``similarity * source confidence`` among the top-3
        most similar KB facts. It is halved when the best similarity is below
        ``threshold``, then clamped to [0, 1].

        Args:
            fact: An object with a ``text`` attribute (e.g. an AtomicFact) or
                any value whose ``str()`` is the fact text. If it has a
                ``kb_score`` attribute, that attribute is set to the returned
                score.
            threshold: Minimum best-match similarity before the 0.5 penalty applies.

        Returns:
            float in [0, 1]. Returns 0.5 when the KB or its embeddings are
            empty, and 0.0 on an internal error.
        """
        if not self.knowledge_base or not self._has_embeddings():
            return 0.5

        try:
            fact_text = fact.text if hasattr(fact, 'text') else str(fact)
            similarities = self._similarities(fact_text)

            # Indices of the three most similar facts, best first.
            best_matches_idx = np.argsort(similarities)[-3:][::-1]
            best_similarities = similarities[best_matches_idx]

            confidence_scores = [
                similarity * self.knowledge_base[idx].get('confidence', 0.8)
                for idx, similarity in zip(best_matches_idx, best_similarities)
            ]
            final_confidence = max(confidence_scores) if confidence_scores else 0.0

            if max(best_similarities) < threshold:
                final_confidence *= 0.5  # Penalize weak matches

            # Clamp BEFORE storing so fact.kb_score never goes negative or above 1.
            final_confidence = float(min(1.0, max(0.0, final_confidence)))

            if hasattr(fact, 'kb_score'):
                fact.kb_score = final_confidence

            return final_confidence

        except Exception as e:
            print(f"Error in local knowledge base verification: {e}")
            return 0.0

    def search_facts(self, query, limit = 10):
        """Return up to ``limit`` KB facts most similar to ``query``, best first.

        Each result is a shallow copy of the fact dict with an added float
        ``similarity_score``. Returns [] for an empty KB, a non-positive
        ``limit``, or on error.
        """
        if not self.knowledge_base or not self._has_embeddings():
            return []

        try:
            # Guard explicitly: a slice of [-0:] would return every fact.
            if limit <= 0:
                return []

            similarities = self._similarities(query)
            top_indices = np.argsort(similarities)[-limit:][::-1]

            results = []
            for idx in top_indices:
                fact = self.knowledge_base[idx].copy()
                fact['similarity_score'] = float(similarities[idx])
                results.append(fact)

            return results

        except Exception as e:
            print(f"Error searching local knowledge base: {e}")
            return []

    def get_knowledge_stats(self):
        """Summarize the knowledge base.

        Returns:
            dict with ``total_facts`` (int), ``categories`` and ``sources``
            (dicts mapping each value to its fact count, with ``'unknown'``
            for a missing key) and ``has_embeddings`` (bool). An empty KB
            returns the same shape.
        """
        if not self.knowledge_base:
            return {"total_facts": 0, "categories": {}, "sources": {}, "has_embeddings": False}

        categories = {}
        sources = {}

        for fact in self.knowledge_base:
            cat = fact.get('category', 'unknown')
            src = fact.get('source', 'unknown')

            categories[cat] = categories.get(cat, 0) + 1
            sources[src] = sources.get(src, 0) + 1

        return {
            "total_facts": len(self.knowledge_base),
            "categories": categories,
            "sources": sources,
            "has_embeddings": self._has_embeddings()
        }

    def add_facts(self, new_facts: List[Dict]):
        """Append facts (dicts with at least a ``text`` key) and index them.

        Only the new facts are encoded when the existing embeddings are in
        sync with the KB; otherwise every fact is re-encoded. Input is
        validated before any state changes, so a fact without ``text`` raises
        ``KeyError`` and leaves the KB untouched. The on-disk cache is not
        updated.
        """
        new_facts = list(new_facts)
        new_texts = [fact['text'] for fact in new_facts]  # KeyError before mutation

        if not new_facts:
            print("Added 0 new facts to knowledge base")
            return

        in_sync = (self._has_embeddings()
                   and len(self.knowledge_embeddings) == len(self.knowledge_base))
        if in_sync:
            print(f"Encoding {len(new_texts)} new facts...")
            new_embeddings = np.asarray(self.embedding_model.encode(new_texts))
            embeddings = np.vstack([np.asarray(self.knowledge_embeddings), new_embeddings])
        else:
            fact_texts = [fact['text'] for fact in self.knowledge_base] + new_texts
            print(f"Regenerating embeddings for {len(fact_texts)} facts...")
            embeddings = self.embedding_model.encode(fact_texts)

        self.knowledge_base.extend(new_facts)
        self.knowledge_embeddings = embeddings

        print(f"Added {len(new_facts)} new facts to knowledge base")

class LocalKnowledgeBaseVerifier:
    """Fact verifier backed by an in-process LocalMedicalKnowledgeBase.

    Exposes ``verify_fact`` and ``search_facts`` with the same arguments as
    the knowledge base it wraps, and forwards every call to it unchanged.
    """

    def __init__(self, knowledge_sources: List[str] = None):
        # Any falsy value (None or an empty list) falls back to the local source.
        if knowledge_sources:
            self.knowledge_sources = knowledge_sources
        else:
            self.knowledge_sources = ["Local Medical DB"]
        self.local_kb = LocalMedicalKnowledgeBase()

        print("Local Knowledge Base Verifier initialized")
        summary = self.local_kb.get_knowledge_stats()
        fact_total, category_total = summary['total_facts'], len(summary['categories'])
        print(f"Loaded {fact_total} facts across {category_total} categories")

    def verify_fact(self, fact, threshold: float = 0.7):
        """Return the wrapped knowledge base's support score for ``fact``."""
        backend = self.local_kb
        return backend.verify_fact(fact, threshold)

    def search_facts(self, query, limit = 10):
        """Return the wrapped knowledge base's top matches for ``query``."""
        backend = self.local_kb
        return backend.search_facts(query, limit)

## Fact Verifier

In [None]:
class FactualRewardCalculator:
    """Combine per-fact LLM-judge and knowledge-base scores into a factual reward.

    A fact's reward is ``llm_weight * llm_score + kb_weight * kb_score``. It is
    multiplied by ``disagreement_penalty`` when the two scores differ by more
    than ``agreement_threshold``. The overall reward is the mean over facts.
    """

    def __init__(self, agreement_threshold, llm_weight=0.7, kb_weight=0.3,
                 disagreement_penalty=0.9):
        """
        Args:
            agreement_threshold: Max absolute gap between the LLM and KB scores
                for them to count as agreeing.
            llm_weight: Weight of the LLM-judge score.
            kb_weight: Weight of the knowledge-base score.
            disagreement_penalty: Multiplier applied to a fact's reward when its
                scores disagree.
        """
        self.agreement_threshold = agreement_threshold
        self.llm_weight = llm_weight
        self.kb_weight = kb_weight
        self.disagreement_penalty = disagreement_penalty

    def extract_think_content(self, generation):
        """Return the text inside the first ``<think>...</think>`` block.

        A leading ``Reasoning:`` label is stripped. Tag and label matching are
        both case-insensitive. Returns "" when there is no block or
        ``generation`` is empty/None.
        """
        if not generation:
            return ""
        think_match = re.search(r'<think>(.*?)</think>', generation, re.DOTALL | re.IGNORECASE)
        if not think_match:
            return ""
        content = think_match.group(1).strip()
        return re.sub(r'^Reasoning:\s*', '', content, flags=re.IGNORECASE)

    @staticmethod
    def _as_score(value):
        """Coerce a raw score to float. Missing (None) or non-numeric values become 0.0."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    def _scores_agree(self, llm_score, kb_score):
        """True when the LLM and KB scores are within ``agreement_threshold``."""
        return abs(llm_score - kb_score) <= self.agreement_threshold

    def compute_factual_reward(self, facts_list):
        """Compute the mean factual reward over ``facts_list``.

        Args:
            facts_list: AtomicFact objects, or dicts with ``text``, ``category``,
                ``source_sentence``, ``llm_score`` and ``kb_score`` keys.
                Dicts are converted to AtomicFact.

        Returns:
            dict with ``factual_reward`` (float), ``factual_analysis`` (the
            per-fact rewards, agreement rate, average scores and fact count)
            and ``error`` (None, or "No facts provided" for empty input).
        """
        if not facts_list:
            return {
                "factual_reward": 0.0,
                "factual_analysis": {
                    "factual_reward": 0.0,
                    "individual_rewards": [],
                    "agreement_rate": 0.0,
                    "avg_llm_score": 0.0,
                    "avg_kb_score": 0.0,
                    "num_facts": 0
                },
                "error": "No facts provided"
            }

        atomic_facts = []
        for fact_item in facts_list:
            if isinstance(fact_item, AtomicFact):
                atomic_facts.append(fact_item)
                continue
            fact = AtomicFact(
                text=fact_item.get("text", ""),
                category=fact_item.get("category", "unknown"),
                source_sentence=fact_item.get("source_sentence", "")
            )
            fact.llm_score = self._as_score(fact_item.get("llm_score", 0.0))
            fact.kb_score = self._as_score(fact_item.get("kb_score", 0.0))
            atomic_facts.append(fact)

        individual_rewards = [self._compute_individual_fact_reward(fact) for fact in atomic_facts]
        llm_scores = [fact.llm_score for fact in atomic_facts]
        kb_scores = [fact.kb_score for fact in atomic_facts]
        agreement_count = sum(
            1 for llm, kb in zip(llm_scores, kb_scores) if self._scores_agree(llm, kb)
        )

        # facts_list is non-empty here, so num_facts >= 1.
        num_facts = len(atomic_facts)
        factual_reward = sum(individual_rewards) / num_facts

        factual_analysis = {
            "factual_reward": factual_reward,
            "individual_rewards": individual_rewards,
            "agreement_rate": agreement_count / num_facts,
            "avg_llm_score": sum(llm_scores) / num_facts,
            "avg_kb_score": sum(kb_scores) / num_facts,
            "num_facts": num_facts
        }

        return {
            "factual_reward": factual_reward,
            "factual_analysis": factual_analysis,
            "error": None
        }

    def _compute_individual_fact_reward(self, fact):
        """Weighted LLM/KB score for one fact, penalized when the scores disagree."""
        llm_score = fact.llm_score
        kb_score = fact.kb_score

        base_score = self.llm_weight * llm_score + self.kb_weight * kb_score

        if self._scores_agree(llm_score, kb_score):
            return base_score
        return base_score * self.disagreement_penalty

## Atomic Fact Verification System

In [None]:
class AtomicFactVerificationSystem:
    """Extract atomic facts from reasoning text, score them, and compute a factual reward.

    In training mode every fact gets a cheap keyword-heuristic score. Otherwise
    each fact is checked by the LLM judge (``LLMJudgeVerifier``) and by the
    local knowledge base (``LocalKnowledgeBaseVerifier``). The scored facts are
    then combined by ``FactualRewardCalculator``.
    """

    def __init__(self, agreement_threshold, umls_api_key=None, request_delay=4.0):
        """
        Args:
            agreement_threshold: Forwarded to FactualRewardCalculator.
            umls_api_key: Unused; kept so existing callers keep working.
            request_delay: Seconds to sleep after each LLM-judge call in
                evaluation mode, as rate-limit protection.
        """
        self.extractor = AtomicFactExtractor(model_name="anthropic.claude-3-haiku-20240307-v1:0",
                                             region_name="us-east-1")
        self.llm_verifier = LLMJudgeVerifier(model_name="anthropic.claude-3-haiku-20240307-v1:0",
                                             region_name="us-east-1")
        self.kb_verifier = LocalKnowledgeBaseVerifier()
        self.reward_calculator = FactualRewardCalculator(agreement_threshold)
        self.training_mode = False
        self.request_delay = request_delay

    def process_response(self, reasoning_text, context=""):
        """Extract, score and reward the facts in ``reasoning_text``.

        Returns:
            dict with ``facts`` (list of fact dicts), ``factual_analysis`` and
            ``error`` (None on success, otherwise a message). Never raises;
            unexpected errors are printed with a traceback.
        """
        # NOTE(review): on success "factual_analysis" holds the calculator's
        # full result ({"factual_reward", "factual_analysis", "error"}), while
        # the failure paths hold the flat _empty_analysis() dict. Callers must
        # handle both shapes; unify once all consumers are known.
        try:
            facts = self.extractor.extract_facts(reasoning_text)

            if not facts:
                return {
                    "facts": [],
                    "factual_analysis": self._empty_analysis(),
                    "error": "No facts extracted"
                }

            if self.training_mode:
                # Skip expensive verification: one heuristic score serves both sources.
                for fact in facts:
                    heuristic_score = self._get_training_heuristic_score(fact)
                    fact.llm_score = heuristic_score
                    fact.kb_score = heuristic_score
            else:
                for fact in facts:
                    self._verify_fact_fully(fact, context)

            factual_result = self.reward_calculator.compute_factual_reward(facts)
            facts_as_dicts = [self._fact_to_dict(fact) for fact in facts]

            return {
                "facts": facts_as_dicts,
                "factual_analysis": factual_result,
                "error": None
            }

        except Exception as e:
            print(f"Error in process_response: {e}")
            import traceback
            traceback.print_exc()
            return {
                "facts": [],
                "factual_analysis": self._empty_analysis(),
                "error": str(e)
            }

    def _verify_fact_fully(self, fact, context):
        """Score one fact with the LLM judge and the KB, isolating their failures.

        Each verifier runs in its own try block, so a failed LLM call no longer
        skips the local KB check. The rate-limit delay runs after every LLM
        attempt, including failed ones; a throttling error is exactly when
        backing off matters. A failed verifier falls back to 0.5 when its score
        is missing or still 0.0.
        """
        try:
            self.llm_verifier.verify_fact(fact, context)
        except Exception as verification_error:
            print(f"Verification error for fact '{fact.text}': {verification_error}")
            if not hasattr(fact, 'llm_score') or fact.llm_score == 0.0:
                fact.llm_score = 0.5
        finally:
            time.sleep(self.request_delay)

        try:
            self.kb_verifier.verify_fact(fact)
        except Exception as verification_error:
            print(f"Verification error for fact '{fact.text}': {verification_error}")
            if not hasattr(fact, 'kb_score') or fact.kb_score == 0.0:
                fact.kb_score = 0.5

    def _get_training_heuristic_score(self, fact):
        """Cheap keyword-based score in [0.5, 0.8] used instead of verification during training.

        Rules are checked in order and the first match wins. Keywords are
        matched as substrings of the lower-cased fact text.
        """
        fact_text = fact.text.lower()
        if any(keyword in fact_text for keyword in ['patient', 'symptom', 'diagnosis', 'treatment']):
            # Clinical-context facts
            return 0.7
        elif any(keyword in fact_text for keyword in ['anatomy', 'physiology', 'mechanism']):
            # Basic-science facts
            return 0.8
        elif len(fact_text) < 20:
            # Very short facts may be incomplete
            return 0.5
        elif any(keyword in fact_text for keyword in ['may', 'might', 'could', 'possibly']):
            # Hedged statements
            return 0.6
        else:
            return 0.7

    def _fact_to_dict(self, fact: "AtomicFact"):
        """Serialize a fact's text, category, scores and source sentence to a plain dict."""
        return {
            "text": fact.text,
            "category": fact.category,
            "llm_score": fact.llm_score,
            "kb_score": fact.kb_score,
            "source_sentence": fact.source_sentence
        }

    def _empty_analysis(self):
        """Zero-valued analysis dict used when no facts could be scored."""
        return {
            "factual_reward": 0.0,
            "individual_rewards": [],
            "agreement_rate": 0.0,
            "avg_llm_score": 0.0,
            "avg_kb_score": 0.0,
            "num_facts": 0
        }

# Training and Evaluation

## MetricsTracker

In [None]:
class MetricsTracker:
    """Accumulate per-epoch metric snapshots and print/save summaries.

    Every record in ``self.epoch_metrics`` is a flat dict holding an
    ``'epoch'`` key plus each name in ``TRACKED_KEYS`` (missing metrics are
    recorded as 0). The adversarial-stage helpers store three records under
    sentinel epoch numbers: ``-1`` (pre-adversarial snapshot), ``999``
    (post-adversarial snapshot) and ``1000`` (post minus pre, per metric).
    """

    # Sentinel 'epoch' values used by the adversarial-stage helpers.
    PRE_ADVERSARIAL_EPOCH = -1
    POST_ADVERSARIAL_EPOCH = 999
    ADVERSARIAL_CHANGE_EPOCH = 1000

    # Metric names copied into each record by add_epoch_metrics, in record order.
    TRACKED_KEYS = (
        'correctness_violations_rate',
        'answer_leaking_violations_rate',
        'format_violations_rate',
        'factual_violations_rate',
        'bad_ood_high_rewards_rate',
        'avg_reward',
        'avg_factual',
    )

    def __init__(self):
        self.epoch_metrics = []

    def add_epoch_metrics(self, epoch, metrics):
        """Append a record for *epoch* built from the tracked keys of *metrics*.

        Args:
            epoch: Epoch number (or one of the adversarial sentinel values).
            metrics: Mapping of metric name -> value. Names outside
                TRACKED_KEYS are ignored; missing ones are recorded as 0.
        """
        record = {'epoch': epoch}
        for key in self.TRACKED_KEYS:
            record[key] = metrics.get(key, 0)
        self.epoch_metrics.append(record)

    def print_stage_summary(self, stage_name, start_epoch=0, end_epoch=None):
        """Print mean and std for a training stage.

        ``start_epoch``/``end_epoch`` are list positions into
        ``self.epoch_metrics`` (a half-open slice), not 'epoch' values.
        Records added by add_adversarial_stage_metrics live in the same list.
        """
        if end_epoch is None:
            end_epoch = len(self.epoch_metrics)

        stage_data = self.epoch_metrics[start_epoch:end_epoch]
        if not stage_data:
            print(f"No data for {stage_name}")
            return

        print(f"\n{'='*60}")
        print(f"{stage_name.upper()} SUMMARY (Epochs {start_epoch+1}-{end_epoch})")
        print(f"{'='*60}")

        # Extract values for each metric
        metrics = {
            'Correctness Violations': [d['correctness_violations_rate'] for d in stage_data],
            'Answer Leaking': [d['answer_leaking_violations_rate'] for d in stage_data],
            'Format Violations': [d['format_violations_rate'] for d in stage_data],
            'Factual Violations': [d['factual_violations_rate'] for d in stage_data],
            'Bad OOD Rewards': [d['bad_ood_high_rewards_rate'] for d in stage_data],
            'Average Reward': [d['avg_reward'] for d in stage_data],
            'Factual Reward': [d['avg_factual'] for d in stage_data]
        }

        # Print statistics
        for metric_name, values in metrics.items():
            mean_val = np.mean(values)
            std_val = np.std(values)
            print(f"{metric_name:20s}: {mean_val:.2f} ± {std_val:.2f}")

        # Calculate improvement (first vs last epoch)
        if len(stage_data) > 1:
            print(f"\n{stage_name} Improvements:")
            first_epoch = stage_data[0]
            last_epoch = stage_data[-1]

            # For violation rates, improvement is decrease (negative change)
            for metric in ['correctness_violations_rate', 'answer_leaking_violations_rate',
                           'format_violations_rate', 'factual_violations_rate', 'bad_ood_high_rewards_rate']:
                change = last_epoch[metric] - first_epoch[metric]
                improvement = -change  # Negative change is improvement for violations
                print(f"  {metric.replace('_', ' ').title():25s}: {improvement:+.2f}")

            # For rewards, improvement is increase (positive change)
            for metric in ['avg_reward', 'avg_factual']:
                change = last_epoch[metric] - first_epoch[metric]
                print(f"  {metric.replace('_', ' ').title():25s}: {change:+.2f}")

    def print_metrics_with_std(self, metrics, title="METRICS", include_std=True):
        """Print metrics in consistent format with optional standard deviations.

        Only metrics present in *metrics* are printed. A std is appended when
        *include_std* is true and the matching ``*_std`` key is present.
        """
        print(f"\n{title}")
        print("-" * len(title))

        # (metric key, display name, std key or None)
        metric_configs = [
            ('correctness_violations_rate', 'Correctness Violations', 'correctness_violations_std'),
            ('answer_leaking_violations_rate', 'Answer Leakage Rate', 'answer_leaking_violations_std'),
            ('format_violations_rate', 'Bad Format Rate', 'format_violations_std'),
            ('factual_violations_rate', 'Factual Violations', 'factual_violations_std'),
            ('bad_ood_high_rewards_rate', 'High Rewards Rate', 'bad_ood_high_rewards_std'),
            ('avg_reward', 'Average Reward', 'std_reward'),
            ('avg_factual', 'Factual Reward', None),
            ('accuracy', 'Accuracy', None)
        ]

        for metric_key, display_name, std_key in metric_configs:
            if metric_key in metrics:
                value = metrics[metric_key]
                if include_std and std_key and std_key in metrics:
                    std_value = metrics[std_key]
                    print(f"{display_name}: {value:.2f} (±{std_value:.2f})")
                else:
                    print(f"{display_name}: {value:.2f}")

    @staticmethod
    def _is_number(value):
        """Return True for real numeric scalars (including NumPy's), excluding bools."""
        import numbers
        return isinstance(value, numbers.Real) and not isinstance(value, bool)

    def add_adversarial_stage_metrics(self, pre_metrics, post_metrics):
        """Record pre/post adversarial snapshots and their per-metric change.

        Appends three records: epoch -1 (from *pre_metrics*), epoch 999 (from
        *post_metrics*) and epoch 1000 holding ``post - pre`` for every metric
        that is numeric in both inputs.
        """
        self.add_epoch_metrics(self.PRE_ADVERSARIAL_EPOCH, pre_metrics)
        self.add_epoch_metrics(self.POST_ADVERSARIAL_EPOCH, post_metrics)

        # Deltas are keyed by the ORIGINAL metric names: add_epoch_metrics only
        # copies TRACKED_KEYS, so suffixed names would be recorded as zeros.
        change_metrics = {}
        for key, pre_value in pre_metrics.items():
            post_value = post_metrics.get(key)
            if self._is_number(pre_value) and self._is_number(post_value):
                change_metrics[key] = post_value - pre_value

        self.add_epoch_metrics(self.ADVERSARIAL_CHANGE_EPOCH, change_metrics)

    def _find_epoch_record(self, epoch):
        """Return the most recently added record whose 'epoch' equals *epoch*, else None."""
        for record in reversed(self.epoch_metrics):
            if record.get('epoch') == epoch:
                return record
        return None

    def print_adversarial_impact_analysis(self):
        """Print the pre- and post-adversarial snapshots and their difference.

        Looks up records by their 'epoch' value (-1, 999, 1000) rather than by
        list position; prints a notice if either snapshot is missing.
        """
        pre_metrics = self._find_epoch_record(self.PRE_ADVERSARIAL_EPOCH)
        post_metrics = self._find_epoch_record(self.POST_ADVERSARIAL_EPOCH)
        if pre_metrics is None or post_metrics is None:
            print("Adversarial training metrics not found")
            return

        print("\n" + "=" * 60)
        print("ADVERSARIAL TRAINING IMPACT ANALYSIS")
        print("=" * 60)

        print("\nPRE-ADVERSARIAL PERFORMANCE:")
        self.print_metrics_with_std(pre_metrics, "", include_std=True)

        print("\nPOST-ADVERSARIAL PERFORMANCE:")
        self.print_metrics_with_std(post_metrics, "", include_std=True)

        change_metrics = self._find_epoch_record(self.ADVERSARIAL_CHANGE_EPOCH)
        if change_metrics is not None:
            # For violation rates a negative change is an improvement.
            print("\nCHANGE (POST - PRE):")
            self.print_metrics_with_std(change_metrics, "", include_std=False)

    @staticmethod
    def _json_default(value):
        """json fallback: unwrap NumPy/PyTorch scalars via ``.item()``."""
        item = getattr(value, 'item', None)
        if callable(item):
            return item()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    def save_metrics(self, filepath="training_metrics.json"):
        """Save all recorded metrics to *filepath* as a JSON list."""
        with open(filepath, 'w', encoding='utf-8') as f:
            # Metric values may be NumPy/PyTorch scalars (e.g. np.float32),
            # which json cannot encode natively.
            json.dump(self.epoch_metrics, f, indent=2, default=self._json_default)
        print(f"Metrics saved to {filepath}")

def monitor_memory():
    """Print current CUDA allocated/reserved memory in GiB; does nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    bytes_per_gib = 1024 ** 3
    allocated_gib = torch.cuda.memory_allocated() / bytes_per_gib
    reserved_gib = torch.cuda.memory_reserved() / bytes_per_gib
    print(f"GPU Memory - Allocated: {allocated_gib:.2f}GB, Reserved: {reserved_gib:.2f}GB")

In [None]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# MedQA splits. The first 100 training rows are written to a sample file that
# the training cells below read.
train_df = pd.read_json("Datasets/medqa_train.json")
test_df = pd.read_json("Datasets/medqa_test.json")
train_subset_df = train_df.head(100)
train_subset_df.to_json("Datasets/medqa_train_sample.json", orient='records', indent=4)

# Reward-function weights/thresholds handed to PolicyTrainer.
reward_config = dict(
    w_b=1.0,
    w_a=0.2,
    w_s=0.2,
    w_fact=0.2,
    tau_answer=0.7,
    tau_preamble=15,
    lambda_s=1.0,
)

print("Initial memory state:")
monitor_memory()

# Presumably a local checkpoint directory (relative path) — confirm.
model_url = "llama_RM_ADV"
trainer_llama_rm = PolicyTrainer(
    model_path=model_url,
    reward_config=reward_config,
    use_baseline=True,
    umls_api_key="UMLS_KEY",
)

## Llama + RM on MMLU

In [None]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Load MedQA and persist a 100-row training sample for the training cell below.
train_df = pd.read_json("Datasets/medqa_train.json")
test_df = pd.read_json("Datasets/medqa_test.json")
train_subset_df = train_df.head(100)
train_subset_df.to_json("Datasets/medqa_train_sample.json", orient='records', indent=4)

reward_config = dict(
    w_b=1.0,
    w_a=0.2,
    w_s=0.2,
    w_fact=0.2,
    tau_answer=0.7,
    tau_preamble=15,
    lambda_s=1.0,
)

print("Initial memory state:")
monitor_memory()

# Base instruct model from the Hugging Face hub; rebinds trainer_llama_rm.
model_url = "meta-llama/Llama-3.2-3B-Instruct"
trainer_llama_rm = PolicyTrainer(
    model_path=model_url,
    reward_config=reward_config,
    use_baseline=True,
    umls_api_key="UMLS_KEY",
)

In [None]:
# Reward-policy training on the MedQA sample, with the MMLU-Pro health file as test data.
reward_results = trainer_llama_rm.train_reward_policy(
    train_data_path="Datasets/medqa_train_sample.json",
    test_data_path="Datasets/mmlu_pro_health_test.json",
    stage1_epochs=1,
    stage2_epochs=1,
    max_eval_examples=20
)

print("\n" + "=" * 80)
print("TRAINING COMPLETE")
print("=" * 80)

# Set your S3 bucket and path
bucket_name = 'b'
s3_prefix = 'medqa-models/reinforce-llama-run3/'

# Save model, tokenizer and baseline network locally
local_model_dir = '/home/ec2-user/SageMaker/llama_RM_only_2'
os.makedirs(local_model_dir, exist_ok=True)

trainer_llama_rm.model.save_pretrained(local_model_dir)
trainer_llama_rm.tokenizer.save_pretrained(local_model_dir)
torch.save(trainer_llama_rm.baseline_network.state_dict(), os.path.join(local_model_dir, "baseline_network.pt"))

# Upload under a sub-prefix named after the local directory. Different runs
# that share the same base s3_prefix write identically named files (the
# save_pretrained outputs, baseline_network.pt), so uploading them directly
# under s3_prefix would make runs overwrite each other.
# NOTE: every file in local_model_dir is uploaded, including leftovers from
# earlier saves into the same directory.
run_prefix = f"{s3_prefix.rstrip('/')}/{os.path.basename(os.path.normpath(local_model_dir))}/"
s3 = boto3.client('s3')
for root, dirs, files in os.walk(local_model_dir):
    for file in files:
        local_path = os.path.join(root, file)
        relative_path = os.path.relpath(local_path, local_model_dir)
        # S3 keys always use '/' regardless of the local path separator.
        s3_path = run_prefix + relative_path.replace(os.sep, '/')
        print(f"Uploading {relative_path} to s3://{bucket_name}/{s3_path}")
        s3.upload_file(local_path, bucket_name, s3_path)
print("\nUpload to S3 complete!")
print(f"All model files are stored at s3://{bucket_name}/{run_prefix}")

In [None]:
# Reload the reward-trained Llama checkpoint that the previous cell saved to
# llama_RM_only_2, then run two adversarial cycles on the MedQA sample with the
# MMLU-Pro health file as test data.
adversarial_results = trainer_llama_rm.load_reward_model_and_train_adversarial(
    saved_model_path='/home/ec2-user/SageMaker/llama_RM_only_2',
    train_data_path="Datasets/medqa_train_sample.json",
    test_data_path="Datasets/mmlu_pro_health_test.json",
    adversarial_cycles=2,
    max_eval_examples=20
)

In [None]:
# Set your S3 bucket and path
bucket_name = 'b'
s3_prefix = 'medqa-models/reinforce-llama-run3/'

# Save the adversarially trained model, tokenizer and baseline network locally
local_model_dir = '/home/ec2-user/SageMaker/llama_RM_ADV_2'
os.makedirs(local_model_dir, exist_ok=True)

trainer_llama_rm.model.save_pretrained(local_model_dir)
trainer_llama_rm.tokenizer.save_pretrained(local_model_dir)
torch.save(trainer_llama_rm.baseline_network.state_dict(), os.path.join(local_model_dir, "baseline_network.pt"))

# Upload under a sub-prefix named after the local directory. Different runs
# that share the same base s3_prefix write identically named files (the
# save_pretrained outputs, baseline_network.pt), so uploading them directly
# under s3_prefix would make runs overwrite each other.
run_prefix = f"{s3_prefix.rstrip('/')}/{os.path.basename(os.path.normpath(local_model_dir))}/"
s3 = boto3.client('s3')
for root, dirs, files in os.walk(local_model_dir):
    for file in files:
        local_path = os.path.join(root, file)
        relative_path = os.path.relpath(local_path, local_model_dir)
        # S3 keys always use '/' regardless of the local path separator.
        s3_path = run_prefix + relative_path.replace(os.sep, '/')
        print(f"Uploading {relative_path} to s3://{bucket_name}/{s3_path}")
        s3.upload_file(local_path, bucket_name, s3_path)
print("\nUpload to S3 complete!")
print(f"All model files are stored at s3://{bucket_name}/{run_prefix}")

## Llama + RM + ADV on USMLE

In [None]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

train_df = pd.read_json("Datasets/medqa_train.json")
test_df = pd.read_json("Datasets/medqa_test.json")

reward_config = dict(
    w_b=1.0,
    w_a=0.2,
    w_s=0.2,
    w_fact=0.2,
    tau_answer=0.7,
    tau_preamble=15,
    lambda_s=1.0,
)
print("Initial memory state:")
monitor_memory()
model_url = "meta-llama/Llama-3.2-3B-Instruct"

trainer_3 = PolicyTrainer(
    model_path=model_url,
    reward_config=reward_config,
    use_baseline=True,
    umls_api_key="UMLS_KEY",
)

print("=" * 80)
# 100-row training sample written to disk for train_reward_policy below.
train_subset_df = train_df.head(100)
train_subset_df.to_json("Datasets/medqa_train_sample.json", orient='records', indent=4)

# NOTE(review): despite the "+ ADV" section heading, this cell only calls
# train_reward_policy; no adversarial stage is run here.
comparison_results = trainer_3.train_reward_policy(
    train_data_path="Datasets/medqa_train_sample.json",
    test_data_path="Datasets/medqa_test.json",
    stage1_epochs=1,
    stage2_epochs=1,
    max_eval_examples=20,
)
print("TRAINING COMPLETE")

## OOD with MMLU-PRO (OG llama)

In [None]:
local_model_dir = '/home/ec2-user/SageMaker/Adversarial_RM_llamaOriginal_RM'

# Configuration for reward function. The fact weight is keyed 'w_fact', matching
# RewardModel's `w_fact` argument and the other reward_config dicts in this notebook.
reward_config = {
    'w_b': 1.0,
    'w_a': 0.2,
    'w_s': 0.2,
    'w_fact': 0.2,
    'tau_answer': 0.7,
    'tau_preamble': 15,
    'lambda_s': 1.0
}

trainer_2 = PolicyTrainer(
    model_path=local_model_dir,
    reward_config=reward_config,
    use_baseline=True,
    umls_api_key="UMLS_KEY"
)

# 1. Reload model + tokenizer
# NOTE(review): this reload passes no device_map and is not moved to
# trainer_2.device; confirm evaluate_model handles placement, or whether
# PolicyTrainer's own load from local_model_dir already makes it unnecessary.
trainer_2.model = AutoModelForCausalLM.from_pretrained(local_model_dir)
trainer_2.tokenizer = AutoTokenizer.from_pretrained(local_model_dir)

# 2. Reload baseline network onto THIS trainer's device (trainer_2, not any
# other trainer object that may or may not exist in the session).
baseline_path = os.path.join(local_model_dir, "baseline_network.pt")
trainer_2.baseline_network.load_state_dict(torch.load(baseline_path, map_location=trainer_2.device))
trainer_2.baseline_network.to(trainer_2.device)
trainer_2.baseline_network.eval()

print("\nRunning final evaluation on the test set...")
# NOTE(review): the other Llama cells read this file from "Datasets/"; confirm the path.
final_results = trainer_2.evaluate_model("mmlu_pro_health_test.json", 200)

## Llama only (MMLU)

## With Qwen2.5-0.5B-Instruct

In [None]:
import os

# Debug settings for CUDA errors.
# NOTE(review): CUDA_LAUNCH_BLOCKING is read when the CUDA context is created,
# so setting it here only helps if no CUDA work has happened yet in this kernel
# (the import cell at the top of the notebook already sets it).
os.environ['CUDA_LAUNCH_BLOCKING']="1"
# NOTE(review): TORCH_USE_CUDA_DSA is documented by PyTorch as a build-time
# option for device-side assertions; setting it at runtime is presumably a
# no-op — confirm.
os.environ['TORCH_USE_CUDA_DSA'] = "1"

In [None]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# NOTE(review): unlike the Llama cells, these paths carry no "Datasets/" prefix;
# confirm the working directory holds the MedQA files.
train_df = pd.read_json("medqa_train.json")
test_df = pd.read_json("medqa_test.json")

reward_config = dict(
    w_b=1.0,
    w_a=0.2,
    w_s=0.2,
    w_fact=0.2,
    tau_answer=0.7,
    tau_preamble=15,
    lambda_s=1.0,
)
print("Initial memory state:")
monitor_memory()
model_url = "Qwen/Qwen2.5-3B-Instruct"

# Train policy stages and save
trainer = PolicyTrainer(
    model_path=model_url,
    reward_config=reward_config,
    use_baseline=True,
    umls_api_key="key",
)

print("=" * 80)
# First 100 training rows, persisted as the sample file trained on below.
train_subset_df = train_df.head(100)
train_subset_df.to_json("medqa_train_sample.json", orient='records', indent=4)

policy_metrics = trainer.train_reward_policy(
    train_data_path="medqa_train_sample.json",
    test_data_path="medqa_test.json",
    stage1_epochs=1,
    stage2_epochs=1,
    max_eval_examples=20,
    save_checkpoint_path="./qwen_25_policy_checkpoint",
)

print("Policy training complete! Checkpoint saved.")

In [None]:
# Rebuild a trainer from the checkpoint written by the policy-training cell
# (save_checkpoint_path="./qwen_25_policy_checkpoint"), then run 3 adversarial cycles.
# NOTE(review): umls_api_key is "your-key" here but "key" in the training cell;
# both look like placeholders — confirm the real key is supplied.
trainer = PolicyTrainer.load_from_checkpoint(
    checkpoint_path="qwen_25_policy_checkpoint",
    reward_config=reward_config,
    umls_api_key="your-key"
)

adversarial_metrics, adversarial_results = trainer.train_adversarial(
    train_data_path="medqa_train_sample.json",
    test_data_path="medqa_test.json",
    num_cycles=3,
    max_eval_examples=20
)

print("Adversarial training complete!")
print("\n" + "=" * 80)

In [None]:
# Write the 100-row sample and compare training approaches.
# NOTE(review): `trainer_3` is the Llama-3.2-3B trainer built in the
# "Llama + RM + ADV on USMLE" cell, yet this cell sits in the Qwen section and
# uses the un-prefixed medqa_*.json files — confirm whether the Qwen `trainer`
# was intended.
train_subset_df = train_df.iloc[:100]
train_subset_df.to_json("medqa_train_sample.json", orient='records', indent=4)

comparison_results = trainer_3.compare_training_approaches(
    train_data_path="medqa_train_sample.json",
    test_data_path="medqa_test.json",
    stage1_epochs=1,
    stage2_epochs=1,
    max_eval_examples=20
)
print("\n" + "=" * 80)
print("TRAINING COMPLETE")
print("=" * 80)

In [None]:
bucket_name = 'b'
s3_prefix = 'medqa-models/reinforce-qwen-run2/'

# Save model locally
# NOTE(review): this saves `trainer_3` (built from Llama-3.2-3B in the USMLE
# cell) into a directory named for Qwen — confirm the intended trainer.
local_model_dir = '/home/ec2-user/SageMaker/Adversarial_RM_qwenSFT_RM'
os.makedirs(local_model_dir, exist_ok=True)

trainer_3.model.save_pretrained(local_model_dir)
trainer_3.tokenizer.save_pretrained(local_model_dir)
torch.save(trainer_3.baseline_network.state_dict(), os.path.join(local_model_dir, "baseline_network.pt"))

# Upload under a sub-prefix named after the local directory. Different runs
# that share the same base s3_prefix write identically named files (the
# save_pretrained outputs, baseline_network.pt), so uploading them directly
# under s3_prefix would make runs overwrite each other.
run_prefix = f"{s3_prefix.rstrip('/')}/{os.path.basename(os.path.normpath(local_model_dir))}/"
s3 = boto3.client('s3')
for root, dirs, files in os.walk(local_model_dir):
    for file in files:
        local_path = os.path.join(root, file)
        relative_path = os.path.relpath(local_path, local_model_dir)
        # S3 keys always use '/' regardless of the local path separator.
        s3_path = run_prefix + relative_path.replace(os.sep, '/')
        print(f"Uploading {relative_path} to s3://{bucket_name}/{s3_path}")
        s3.upload_file(local_path, bucket_name, s3_path)
print("\nUpload to S3 complete!")
print(f"All model files are stored at s3://{bucket_name}/{run_prefix}")

## Qwen2.5 RM Only on MMLU

In [None]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

reward_config = {
    'w_b': 1.0,
    'w_a': 0.2,
    'w_s': 0.2,
    'w_fact': 0.2,
    'tau_answer': 0.7,
    'tau_preamble': 15,
    'lambda_s': 1.0
}
model_url = "Qwen/Qwen2.5-3B-Instruct"

trainer_qwen_rm = PolicyTrainer(
    model_path=model_url,
    reward_config=reward_config,
    use_baseline=True,
    umls_api_key="UMLS_KEY"
)

# Reward-policy training only (no adversarial stage), with MMLU-Pro health as test data.
results = trainer_qwen_rm.train_reward_policy(
    train_data_path="Datasets/medqa_train_sample.json",
    test_data_path="Datasets/mmlu_pro_health_test.json",
    stage1_epochs=1,
    stage2_epochs=1,
    max_eval_examples=20
)
print("\n" + "=" * 80)
print("TRAINING COMPLETE")
print("=" * 80)

# Set your S3 bucket and path
bucket_name = 'b'
s3_prefix = 'medqa-models/reinforce-qwen-run2/'

# Save model locally
local_model_dir = '/home/ec2-user/SageMaker/qwen_RM_only'
os.makedirs(local_model_dir, exist_ok=True)

trainer_qwen_rm.model.save_pretrained(local_model_dir)
trainer_qwen_rm.tokenizer.save_pretrained(local_model_dir)
torch.save(trainer_qwen_rm.baseline_network.state_dict(), os.path.join(local_model_dir, "baseline_network.pt"))

# Upload under a sub-prefix named after the local directory. Different runs
# that share the same base s3_prefix write identically named files (the
# save_pretrained outputs, baseline_network.pt), so uploading them directly
# under s3_prefix would make runs overwrite each other.
run_prefix = f"{s3_prefix.rstrip('/')}/{os.path.basename(os.path.normpath(local_model_dir))}/"
s3 = boto3.client('s3')
for root, dirs, files in os.walk(local_model_dir):
    for file in files:
        local_path = os.path.join(root, file)
        relative_path = os.path.relpath(local_path, local_model_dir)
        # S3 keys always use '/' regardless of the local path separator.
        s3_path = run_prefix + relative_path.replace(os.sep, '/')
        print(f"Uploading {relative_path} to s3://{bucket_name}/{s3_path}")
        s3.upload_file(local_path, bucket_name, s3_path)
print("\nUpload to S3 complete!")
print(f"All model files are stored at s3://{bucket_name}/{run_prefix}")

## Qwen2.5 RM +ADV on MMLU

In [None]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

reward_config = {
    'w_b': 1.0,
    'w_a': 0.2,
    'w_s': 0.2,
    'w_fact': 0.2,
    'tau_answer': 0.7,
    'tau_preamble': 15,
    'lambda_s': 1.0
}
# NOTE(review): this cell uses the 0.5B Qwen model while the "RM only" cell uses
# Qwen2.5-3B; confirm this is intended for a like-for-like comparison.
model_url = "Qwen/Qwen2.5-0.5B-Instruct"

trainer_qwen_rm = PolicyTrainer(
    model_path=model_url,
    reward_config=reward_config,
    use_baseline=True,
    umls_api_key="UMLS_KEY"
)

results = trainer_qwen_rm.train_combined(
    train_data_path="Datasets/medqa_train_sample.json",
    test_data_path="Datasets/mmlu_pro_health_test.json",
    stage1_epochs=2,
    stage2_epochs=1,
    max_eval_examples=20
)

print("\n" + "=" * 80)
print("TRAINING COMPLETE")
print("=" * 80)

# Set your S3 bucket and path
bucket_name = 'b'
s3_prefix = 'medqa-models/reinforce-qwen-run2/'

# Save model locally
local_model_dir = '/home/ec2-user/SageMaker/qwen_RM_Adv'
os.makedirs(local_model_dir, exist_ok=True)

trainer_qwen_rm.model.save_pretrained(local_model_dir)
trainer_qwen_rm.tokenizer.save_pretrained(local_model_dir)
torch.save(trainer_qwen_rm.baseline_network.state_dict(), os.path.join(local_model_dir, "baseline_network.pt"))

# Upload under a sub-prefix named after the local directory. Different runs
# that share the same base s3_prefix write identically named files (the
# save_pretrained outputs, baseline_network.pt), so uploading them directly
# under s3_prefix would make runs overwrite each other.
run_prefix = f"{s3_prefix.rstrip('/')}/{os.path.basename(os.path.normpath(local_model_dir))}/"
s3 = boto3.client('s3')
for root, dirs, files in os.walk(local_model_dir):
    for file in files:
        local_path = os.path.join(root, file)
        relative_path = os.path.relpath(local_path, local_model_dir)
        # S3 keys always use '/' regardless of the local path separator.
        s3_path = run_prefix + relative_path.replace(os.sep, '/')
        print(f"Uploading {relative_path} to s3://{bucket_name}/{s3_path}")
        s3.upload_file(local_path, bucket_name, s3_path)
print("\nUpload to S3 complete!")
print(f"All model files are stored at s3://{bucket_name}/{run_prefix}")