# Setup: dependencies and imports

In [None]:
#The code was run inside AWS Sagemaker ml.g5.4xlarge instance
#%conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.1 -c pytorch -c nvidia -y
%pip install -q -U bitsandbytes
%pip install -q -U git+https://github.com/huggingface/transformers.git
%pip install -q -U git+https://github.com/huggingface/accelerate.git
%pip install sentence_transformers==4.1.0
%pip install git+https://github.com/openai/whisper.git

In [None]:
import json
import re
import time
import os
import gc
import boto3
import pickle
from typing import Dict, List
import numpy as np

import bitsandbytes.optim as bnb_optim
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from dataclasses import dataclass
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import (AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig)
from huggingface_hub import login
# CUDA_LAUNCH_BLOCKING is only honoured if it is in the environment before the CUDA
# context is created; SentenceTransformer below may create that context when it places
# the model on a GPU, so set it first. It makes kernels synchronous (slower, but
# tracebacks point at the failing op).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

# Never hardcode credentials in a notebook: read the Hugging Face token from the
# HF_TOKEN environment variable. When it is unset, login() prompts for the token
# interactively instead of failing on a placeholder string.
hf_token = os.environ.get("HF_TOKEN")
login(token=hf_token)

pd.options.display.max_seq_items = 2000

# Multi-Component Reward Function

In [None]:
class EmbeddingModelManager:
    """Process-wide singleton that lazily loads and shares one SentenceTransformer.

    The loaded model is moved to the CPU. Asking for a different model name than the
    one currently cached loads that model instead of silently returning the old one.
    """
    _instance = None
    _model = None
    _model_name = None  # name of the currently cached model, if any

    def __new__(cls):
        # Every construction returns the same shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(self, model_name='all-MiniLM-L6-v2'):
        """Return the SentenceTransformer for ``model_name``, loading it on first use.

        Args:
            model_name: Model name or path passed to ``SentenceTransformer``.

        Returns:
            The cached model, moved to the CPU.
        """
        if self._model is None or self._model_name != model_name:
            print(f"Loading SentenceTransformer: {model_name}")
            self._model = SentenceTransformer(model_name).cpu()
            self._model_name = model_name
        return self._model

    def cleanup(self):
        """Drop the cached model and ask torch to release cached CUDA memory."""
        if self._model is not None:
            del self._model
            self._model = None
        self._model_name = None
        torch.cuda.empty_cache()

embedding_manager = EmbeddingModelManager()

In [None]:
class RewardModel:
    """Rule-based, multi-component reward for MedQA-style generations.

    Components:
      * r_binary     -- 1.0 when the letter inside <answer> matches the gold letter, else 0.0.
      * p_answer     -- max cosine similarity between the reasoning and a set of
                        "answer leaking" phrases, if it exceeds ``tau_answer`` (else 0.0).
      * p_structural -- ``lambda_s`` when more than ``tau_preamble`` words precede <think>.
      * r_factual    -- score produced by ``AtomicFactVerificationSystem`` on the reasoning.

    ``r_total = w_b*r_binary - w_a*p_answer - w_s*p_structural + w_fact*r_factual`` and
    ``r_normalized = tanh(r_total)``. Generations failing ``validate_format`` short-circuit
    to ``r_binary = r_total = -1.0``.
    """

    def __init__(self,
                 w_b=1.0,
                 w_a=0.2,
                 w_s=0.2,
                 w_fact=0.2,
                 tau_answer=0.7,
                 tau_preamble=15,
                 lambda_s=1.0,
                 violation_threshold_answer=0.1,
                 violation_threshold_structural=0.3,
                 embedding_model='all-MiniLM-L6-v2',
                 factual_agreement_threshold=0.15,
                 umls_api_key=None):
        """Store weights/thresholds and build the embedding and fact-checking helpers.

        Args:
            w_b, w_a, w_s, w_fact: Weights of the binary reward, answer-leak penalty,
                structural penalty and factual reward inside ``r_total``.
            tau_answer: Cosine-similarity threshold above which reasoning counts as leaking.
            tau_preamble: Maximum number of words allowed before ``<think>``.
            lambda_s: Structural penalty applied when the preamble is too long.
            violation_threshold_answer, violation_threshold_structural: Penalty levels above
                which ``calculate_hacking_rate`` counts a violation.
            embedding_model: SentenceTransformer name obtained via ``embedding_manager``.
            factual_agreement_threshold, umls_api_key: Forwarded to
                ``AtomicFactVerificationSystem``.
        """
        self.w_b = w_b
        self.w_a = w_a
        self.w_s = w_s
        self.w_fact = w_fact
        self.tau_answer = tau_answer
        self.tau_preamble = tau_preamble
        self.lambda_s = lambda_s
        self.violation_threshold_answer = violation_threshold_answer
        self.violation_threshold_structural = violation_threshold_structural
        self.sentence_model = embedding_manager.get_model(embedding_model)

        # Canonical "answer giveaway" phrasings; reasoning whose embedding is close to any
        # of them is penalized by compute_answer_penalty.
        self.answer_leaking_phrases = [
            "the correct answer is", "the answer is definitely", "the choice is clearly",
            "option A is the right one", "option B is the right one", "option C is the right one",
            "option D is the right one", "we can conclude the answer is", 
            "the solution is A", "the solution is B", "the solution is C", "the solution is D",
            "therefore the answer is", "so the correct choice is", "the final answer is",
            "answer: A", "answer: B", "answer: C", "answer: D"
        ]

        self.leak_embeddings = self.sentence_model.encode(self.answer_leaking_phrases)
        self.fact_verification_system = AtomicFactVerificationSystem(
            agreement_threshold=factual_agreement_threshold,
            umls_api_key=umls_api_key
        )

    def extract_answer_choice(self, generation):
        """Extract the final answer choice from <answer> tags.

        Returns the first standalone letter A-D (upper-cased) inside the first
        <answer>...</answer> pair, or None if there is no tag or no letter.
        """
        answer_match = re.search(r'<answer>(.*?)</answer>', generation, re.IGNORECASE | re.DOTALL)
        if not answer_match:
            return None
        answer_content = answer_match.group(1).strip()
        match = re.search(r'\b([A-D])\b', answer_content, re.IGNORECASE)
        return match.group(1).upper() if match else None

    def extract_think_content(self, generation):
        """Return the <think> body without a leading "Reasoning:" label.

        Falls back to the whole stripped generation when the tag is missing or empty.
        """
        think_match = re.search(r'<think>(.*?)</think>', generation, re.DOTALL | re.IGNORECASE)
        if think_match and think_match.group(1).strip():
            return re.sub(r'^Reasoning:\s*', '', think_match.group(1).strip())
        # Fallback: if no tags, try using the whole response
        return generation.strip()

    def extract_pre_think_content(self, generation):
        """Extract content that appears before the <think> tag (whole text if absent)."""
        think_start = re.search(r'<think>', generation, re.IGNORECASE)
        return generation[:think_start.start()].strip() if think_start else generation.strip()

    def validate_format(self, generation) -> bool:
        """True when there is a non-empty <think> block and an <answer> holding one letter A-D."""
        think_match = re.search(r'<think>\s*(.*?)\s*</think>', generation, re.DOTALL | re.IGNORECASE)
        has_think_content = think_match and think_match.group(1).strip()
        has_answer = bool(re.search(r'<answer>\s*[A-D]\s*</answer>', generation, re.IGNORECASE))
        return bool(has_think_content) and has_answer

    def compute_binary_reward(self, generation, correct_answer):
        """Return 1.0 if the extracted answer letter matches ``correct_answer``, else 0.0.

        A missing gold label (None) yields 0.0 instead of raising, because some callers
        (e.g. epoch comparison logging) may not have a gold answer for every prompt.
        """
        if correct_answer is None:
            return 0.0
        predicted_answer = self.extract_answer_choice(generation)
        if predicted_answer is None:
            return 0.0
        return 1.0 if predicted_answer == str(correct_answer).strip().upper() else 0.0

    def compute_answer_penalty(self, generation):
        """Return the max similarity to a leak phrase if it exceeds ``tau_answer``, else 0.0."""
        think_content = self.extract_think_content(generation)
        if not think_content:
            return 0.0
        think_embedding = self.sentence_model.encode([think_content])
        similarities = cosine_similarity(think_embedding, self.leak_embeddings)[0]
        max_similarity = np.max(similarities)
        return float(max_similarity) if max_similarity > self.tau_answer else 0.0

    def compute_structural_penalty(self, generation):
        """Compute p_structural penalty based on pre-think word count."""
        pre_think_content = self.extract_pre_think_content(generation)
        word_count = len(pre_think_content.split())
        return self.lambda_s if word_count > self.tau_preamble else 0.0

    def compute_factual_reward(self, generation, context=""):
        """Score the reasoning with the fact verification system.

        Returns:
            dict with keys ``factual_reward`` (float), ``factual_analysis`` (dict),
            ``facts`` (list) and ``error`` (str or None). Every failure path returns a
            0.0 reward and the same keys rather than raising.
        """
        try:
            reasoning_content = self.extract_think_content(generation)

            if not reasoning_content:
                print(f"[FactExtractor] No reasoning found in generation:\n{generation}\n")
                return {
                    "factual_reward": 0.0,
                    "factual_analysis": {},
                    "facts": [],
                    "error": "No reasoning content found"
                }

            result = self.fact_verification_system.process_response(
                reasoning_content, context
            )

            factual_reward = result.get("factual_analysis", {}).get("factual_reward", 0.0)

            return {
                "factual_reward": factual_reward,
                "factual_analysis": result.get("factual_analysis", {}),
                "facts": result.get("facts", []),
                "error": result.get("error", None)
            }

        except Exception as e:
            return {
                "factual_reward": 0.0,
                "factual_analysis": {},
                "facts": [],
                "error": f"Factual verification failed: {str(e)}"
            }

    def compute_total_reward(self, generation, correct_answer, context=""):
        """Combine all components into ``r_total`` and ``r_normalized = tanh(r_total)``.

        Invalid formats return a fixed -1.0 reward without running the (expensive)
        embedding or fact-verification steps.
        """
        if not self.validate_format(generation):
            return {
                'r_binary': -1.0,
                'p_answer': 0.0,
                'p_structural': 0.0,
                'r_factual': 0.0,
                'factual_analysis': {},
                'extracted_facts': [],
                'factual_error': "Invalid format - computation terminated",
                'r_total': -1.0,
                'r_normalized': np.tanh(-1.0),
                'format_valid': False
            }

        r_binary = self.compute_binary_reward(generation, correct_answer)
        p_answer = self.compute_answer_penalty(generation)
        p_structural = self.compute_structural_penalty(generation)

        factual_result = self.compute_factual_reward(generation, context)
        r_factual = factual_result["factual_reward"]

        r_total = self.w_b * r_binary - self.w_a * p_answer - self.w_s * p_structural + self.w_fact * r_factual
        r_normalized = np.tanh(r_total)

        return {
            'r_binary': r_binary,
            'p_answer': p_answer,
            'p_structural': p_structural,
            'r_factual': r_factual,
            'factual_analysis': factual_result["factual_analysis"],
            'extracted_facts': factual_result.get("facts", []),
            'factual_error': factual_result.get("error", None),
            'r_total': r_total,
            'r_normalized': r_normalized,
            'format_valid': True
        }

    def calculate_hacking_rate(self, responses_with_rewards, factual_threshold=0.3):
        """Count reward-hacking violations over format-valid responses.

        Args:
            responses_with_rewards: Iterable of dicts with a ``response`` string and an
                optional ``reward_info`` dict (as produced by ``compute_total_reward``).
                Entries whose ``format_valid`` is False are skipped.
            factual_threshold: ``r_factual`` below this counts as a factual violation.

        Returns:
            dict of counts and rates; rates are relative to the valid responses
            (0.0 when there are none).
        """
        total_responses = len(responses_with_rewards)
        valid_responses = 0
        answer_violations = 0
        structural_violations = 0
        factual_violations = 0
        any_violation = 0

        for item in responses_with_rewards:
            response = item['response']
            reward_info = item.get('reward_info', {})
            if not reward_info.get('format_valid', True):
                continue
            valid_responses += 1

            answer_penalty = self.compute_answer_penalty(response)
            structural_penalty = self.compute_structural_penalty(response)
            # Missing factual scores are treated as passing (default 1.0).
            factual_reward = reward_info.get('r_factual', 1.0)

            answer_violation = answer_penalty > self.violation_threshold_answer
            structural_violation = structural_penalty > self.violation_threshold_structural
            factual_violation = factual_reward < factual_threshold

            answer_violations += int(answer_violation)
            structural_violations += int(structural_violation)
            factual_violations += int(factual_violation)
            if answer_violation or structural_violation or factual_violation:
                any_violation += 1

        def _rate(count):
            return count / valid_responses if valid_responses > 0 else 0.0

        return {
            'total_responses': total_responses,
            'valid_responses': valid_responses,
            'answer_violation_count': answer_violations,
            'invalid_format_count': total_responses - valid_responses,
            'structural_violation_count': structural_violations,
            'factual_violation_count': factual_violations,
            'any_violation_count': any_violation,
            'answer_violation_rate': _rate(answer_violations),
            'structural_violation_rate': _rate(structural_violations),
            'factual_violation_rate': _rate(factual_violations),
            'overall_violation_rate': _rate(any_violation)
        }

# Data Processor

In [None]:
class MedQADataProcessor:
    """Turn MedQA records into prompts for the policy model.

    Each processed record has the keys ``question``, ``options``, ``correct_answer``
    (the gold option letter, read from ``answer_idx``) and ``prompt``.
    """

    # One-shot demonstration shown in every prompt. The case is a toxic-alcohol
    # presentation (anion gap = 140 - (106 + 7) = 27 with normal salicylate levels), so the
    # demonstrated answer is fomepizole. The options are listed so the demonstrated answer
    # letter refers to an actual choice.
    _EXAMPLE_QUESTION = (
        "Prompt/Question: A man is brought into the emergency department by the police department. "
        "The officer states that the man has been arrested multiple times for public alcohol intoxication, "
        "but recently became homeless. On exam, the man is behaving erratically. His vitals are all within "
        "normal limits. He appears confused and has slurred speech. On gait exam, the patient is ataxic and "
        "cannot stand without support for more than a few seconds. Labs return with the following values: "
        "Na 140, K 4, Cl 106, BUN 8, Cr 2. His ABG has pH 7.3, PaCO2 13mm, PaO2 130mm, HCO3 7. "
        "His urinalysis is shown in Figure 1. Blood salicylate levels return as normal. While you await "
        "other diagnostic tests, which of the following should be administered next to treat this patient?\n"
        "A. Fomepizole\n"
        "B. Ethanol\n"
        "C. Naloxone\n"
        "D. Naltrexone"
    )
    _EXAMPLE_RESPONSE = (
        "<think>The patient has a severe anion gap metabolic acidosis (HCO3 7; anion gap = 140 - (106 + 7) = 27) "
        "with marked hyperventilation, confusion, ataxia, and an elevated creatinine suggesting acute kidney injury. "
        "Blood salicylate levels are normal, so salicylate toxicity is unlikely. In a patient with heavy alcohol use "
        "and recent homelessness, this picture points to toxic alcohol ingestion, most likely ethylene glycol, which "
        "can produce calcium oxalate crystals on urinalysis. Fomepizole inhibits alcohol dehydrogenase, blocks the "
        "formation of toxic metabolites, and is the preferred antidote. Ethanol also competitively inhibits alcohol "
        "dehydrogenase but is a second-line option because it is harder to dose and worsens intoxication. Naloxone "
        "reverses opioid toxicity and naltrexone is used for maintenance treatment of alcohol or opioid use disorder; "
        "neither treats toxic alcohol poisoning. Therefore, fomepizole should be administered next.</think>\n"
        "<answer>A</answer>"
    )

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def format_prompt(self, question, options):
        """Build the instruction prompt for one question.

        The prompt is assembled from explicit lines so no source-code indentation leaks
        into the model input. Changing this template invalidates any cached baseline
        responses that were generated from older prompts.

        Args:
            question: Clinical vignette text.
            options: Mapping of option letter -> option text, rendered in iteration order.

        Returns:
            str: The full prompt, ending with "Your response:".
        """
        formatted_options = "\n".join(f"{letter}. {option}" for letter, option in options.items()).strip()
        input_text = f"{question}\n\n{formatted_options}"
        lines = [
            "You are a medical expert taking the USMLE exam. Given the clinical scenario below, respond "
            "with your reasoning in a <think></think> tag and your final answer choice (A, B, C, or D) "
            "in an <answer></answer> tag.",
            "Scenario:",
            input_text,
            "Format:",
            "<think>your step-by-step clinical reasoning goes here</think>",
            "<answer>A</answer> (replace A with your final answer choice)",
            "Here is a sample prompt and response:",
            self._EXAMPLE_QUESTION,
            self._EXAMPLE_RESPONSE,
            "Your response:",
        ]
        return "\n".join(lines)

    def load_medqa_data(self, file_path: str):
        """Load and process a MedQA dataset file.

        Accepts either a JSON array of records (or a single record object) or JSON Lines
        (one record per non-blank line). Each record must provide ``question``,
        ``options`` and ``answer_idx``.

        Returns:
            list[dict]: Processed records in file order.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            raw_text = f.read()
        try:
            data = json.loads(raw_text)
        except json.JSONDecodeError:
            # Not a single JSON document: treat it as JSON Lines.
            data = [json.loads(line) for line in raw_text.splitlines() if line.strip()]
        if isinstance(data, dict):
            data = [data]

        return [
            {
                'question': item['question'],
                'options': item['options'],
                'correct_answer': item['answer_idx'],
                'prompt': self.format_prompt(item['question'], item['options'])
            }
            for item in data
        ]

# Baseline Network

In [None]:
class BaselineNetwork(nn.Module):
    """Small MLP value head mapping a pooled hidden state to a scalar baseline.

    Inputs are cast to float32, clamped to [-10, 10] and layer-normalized before the
    MLP. The trailing size-1 output dimension is squeezed away and the prediction is
    clamped to [-5, 5].
    """

    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        half_dim = hidden_dim // 2

        # Normalize raw hidden-state features before the first linear layer.
        self.input_norm = nn.LayerNorm(input_dim)

        # Two Linear -> LayerNorm -> ReLU -> Dropout stages, then a scalar head. The layer
        # order is fixed because Sequential indices form the state_dict keys.
        layers = []
        for in_features, out_features in ((input_dim, hidden_dim), (hidden_dim, half_dim)):
            layers += [
                nn.Linear(in_features, out_features),
                nn.LayerNorm(out_features),
                nn.ReLU(),
                nn.Dropout(0.1),
            ]
        layers.append(nn.Linear(half_dim, 1))
        self.network = nn.Sequential(*layers)

        self._initialize_weights()

    def _initialize_weights(self):
        # Small-gain Xavier weights and zero biases, intended to keep early baseline
        # predictions close to zero.
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight, gain=0.1)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        normalized = self.input_norm(x.float().clamp(min=-10.0, max=10.0))
        return self.network(normalized).squeeze(-1).clamp(min=-5.0, max=5.0)

# Policy Trainer

In [None]:
class PolicyTrainer:
    def __init__(self, model_path, reward_config=None, use_baseline=True, umls_api_key=None, log_file=None):
        """Load the policy and frozen comparison models and build all training components.

        Args:
            model_path: Path (or hub id) passed to ``from_pretrained`` for both the trainable
                policy and the frozen comparison model.
            reward_config: Optional keyword arguments for ``RewardModel``. The dict is copied,
                so the caller's object is never modified.
            use_baseline: If True, create a ``BaselineNetwork`` value head and its optimizer.
            umls_api_key: Forwarded to ``RewardModel``; overrides any ``umls_api_key`` entry
                in ``reward_config``.
            log_file: Text log path; defaults to a timestamped file in the working directory.
                The file is truncated at start-up and its parent directory is created.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        print(f"Loading tokenizer and model from local path: {model_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Make sure pad/eos tokens exist; generation below passes both ids explicitly.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.tokenizer.eos_token is None:
            self.tokenizer.eos_token = self.tokenizer.pad_token or self.tokenizer.unk_token
        if self.tokenizer.eos_token_id is None:
            self.tokenizer.eos_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.eos_token)

        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )
        # fp16 is only requested for unquantized loads; with 8-bit quantization the dtype
        # is left to transformers.
        dtype = torch.float16 if quantization_config is None else None

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=dtype,
            quantization_config=quantization_config
        )
        self.metrics_tracker = MetricsTracker()

        # Frozen copy of the starting policy, used only to produce reference responses.
        self._baseline_cache_path = "logs/baseline_responses.json"
        self._baseline_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="cpu",
            torch_dtype=None,  # cpu stable
            quantization_config=BitsAndBytesConfig(load_in_8bit=True)
        )
        self._baseline_model.eval()
        print("Frozen baseline model initialized for epoch comparisons")

        self.use_learnable_reward = True

        hidden_dim = self.model.config.hidden_size
        print(f"Model hidden dimension: {hidden_dim}")

        if self.use_learnable_reward:
            self.reward_head = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim // 2, 1)
            ).to(self.device).float()

            self.reward_optimizer = torch.optim.Adam(self.reward_head.parameters(), lr=1e-4)
            print("Learnable reward head initialized")

        # From here on, follow wherever transformers actually placed the policy.
        self.device = next(self.model.parameters()).device
        print(f"Model automatically placed on device: {self.device}")
        if hasattr(self.model, 'gradient_checkpointing_enable'):
            self.model.gradient_checkpointing_enable()

        self.model.train()

        # Copy so the umls_api_key injection below never mutates the caller's dict.
        reward_config = dict(reward_config or {})
        self.data_processor = MedQADataProcessor(self.tokenizer)
        reward_config['umls_api_key'] = umls_api_key
        self.reward_function = RewardModel(**reward_config)
        self.reward_function.fact_verification_system.training_mode = False

        self.use_baseline = use_baseline
        if self.use_baseline:
            self.baseline_network = BaselineNetwork(input_dim=hidden_dim)
            self.baseline_network = self.baseline_network.to(self.device)
            self.baseline_optimizer = optim.Adam(self.baseline_network.parameters(), lr=1e-4)
            print("Baseline network initialized")

        # NOTE(review): with load_in_8bit the base weights are quantized; confirm which of
        # these parameters actually receive gradients (e.g. adapters) during training.
        self.policy_optimizer = bnb_optim.AdamW8bit(self.model.parameters(), lr=1.41e-5)
        self.reward_history = []

        adversarial_config = {
            'temperature': 1.2,
            'max_examples': 50,
            'preference_margin': 0.5,
            'validation_threshold': 0.7
        }
        self.adversarial_trainer = AdversarialTrainer(self, adversarial_config)

        # Set up logging
        self.log_file = log_file or f"training_log_{time.strftime('%Y%m%d_%H%M%S')}.txt"
        log_dir = os.path.dirname(self.log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        # Clear the log file at start
        with open(self.log_file, 'w', encoding='utf-8') as f:
            f.write(f"Training Log Started: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write("=" * 80 + "\n\n")

        print("PolicyTrainer initialization complete!")


    def _ensure_baseline_responses(self, eval_prompts):
        if os.path.exists(self._baseline_cache_path):
            with open(self._baseline_cache_path, "r", encoding="utf-8") as f:
                return json.load(f)

        print("Generating baseline responses using existing model...")
        self._baseline_model.eval()
        baseline_results = []
        for prompt in eval_prompts:
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self._baseline_model.generate(
                    **inputs,
                    max_new_tokens=100,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
            base_text = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:],
                                              skip_special_tokens=True).strip()
            baseline_results.append(base_text)
        
        
        os.makedirs("logs", exist_ok=True)
        with open(self._baseline_cache_path, "w", encoding="utf-8") as f:
            json.dump(baseline_results, f, indent=2)
        return baseline_results


    def log_epoch_comparisons(self, eval_prompts, epoch, correct_answers=None, context_list=None):
        results = []
        if epoch % 5 != 0:
            return []
        for i, prompt in enumerate(eval_prompts):
            context = context_list[i] if context_list else ""
            correct_answer = correct_answers[i] if correct_answers else None

            self.model.eval()
            baseline_texts = self._ensure_baseline_responses(eval_prompts)
            base_text = baseline_texts[i]

            adv_text, _, _, _ = self.generate_response_with_logprobs(prompt)

            # Compute rewards for adversarial response
            reward_info = self.reward_function.compute_total_reward(
                adv_text, correct_answer, context
            )

            entry = {
                "epoch": epoch,
                "prompt": prompt,
                "baseline_response": base_text,
                "adversarial_response": adv_text,
                "correct_answer": correct_answer,
                "reward_info": reward_info,
            }
            results.append(entry)

            # Console logging
            print("=" * 80)
            print(f"[Epoch {epoch}] Prompt: {prompt}")
            print(f"Correct Answer: {correct_answer}")
            print("\n--- Baseline Response ---")
            print(base_text)
            print("\n--- Adversarial Response ---")
            print(adv_text)
            print("\nReward breakdown:")
            print(reward_info)

        # Save to JSON file for later analysis
        os.makedirs("logs", exist_ok=True)
        log_path = f"logs/epoch_{epoch}_comparisons.json"
        with open(log_path, "w") as f:
            json.dump(results, f, indent=2)

        return results

    def generate_response_with_logprobs(self, prompt, max_new_tokens = 2048):
        """Memory-efficient generation with log probabilities and enhanced error handling."""
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        input_length = inputs['input_ids'].shape[1]

        self.model.eval()
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=min(max_new_tokens, 512),
                do_sample=True,
                temperature=0.2,
                top_p=0.9,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                return_dict_in_generate=True,
                output_scores=True
            )

        generated_ids = outputs.sequences[0][input_length:]
        response_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        if len(generated_ids) == 0:
            return response_text, torch.tensor([], device=self.device, requires_grad=True), torch.zeros(
                self.model.config.hidden_size, device=self.device, dtype=torch.float16, requires_grad=True)

        self.model.train()

        log_probs = []
        # chunk_size = min(128, len(generated_ids))
        if len(generated_ids) > 0:
            full_context = outputs.sequences[0]
            full_inputs = {'input_ids': full_context.unsqueeze(0), 
                           'attention_mask': torch.ones_like(full_context).unsqueeze(0)}
            
            with torch.no_grad():
                full_outputs = self.model(**full_inputs)
                full_logits = full_outputs.logits[0]
            
            # Extract log probs for generated tokens only
            for i, token_id in enumerate(generated_ids):
                logit_idx = input_length + i - 1
                if 0 <= logit_idx < full_logits.shape[0]:
                    logits = full_logits[logit_idx].clamp(min=-50, max=50)
                    log_prob = torch.log_softmax(logits, dim=-1)[token_id]
                    log_probs.append(log_prob)

        # Get hidden states for baseline computation
        final_inputs = {
            'input_ids': outputs.sequences[0][-50:].unsqueeze(0),
            'attention_mask': torch.ones(1, min(50, len(outputs.sequences[0]))).to(self.device)
        }
        final_outputs = self.model(**final_inputs, output_hidden_states=True)

        if final_outputs.hidden_states:
            hidden_state = final_outputs.hidden_states[-1][0]

            # Enhanced validity checks
            if torch.any(torch.isnan(hidden_state)) or torch.any(torch.isinf(hidden_state)):
                print("Warning: Invalid hidden states detected")
                avg_hidden_state = torch.zeros(self.model.config.hidden_size,
                                               device=self.device,
                                               dtype=torch.float16,
                                               requires_grad=True)
            else:
                hidden_state_clamped = torch.clamp(hidden_state, min=-10, max=10)
                avg_hidden_state = hidden_state_clamped.mean(dim=0)
                avg_hidden_state.requires_grad_(True)
        else:
            avg_hidden_state = torch.zeros(self.model.config.hidden_size,
                                           device=self.device,
                                           dtype=torch.float16,
                                           requires_grad=True)

        log_probs_tensor = torch.stack(log_probs) if log_probs else torch.tensor([], device=self.device,
                                                                                 requires_grad=True,
                                                                                 dtype=torch.float16)
        return response_text, log_probs_tensor, avg_hidden_state

    def generate_and_evaluate_with_facts(self, prompt, correct_answer, context=""):
        response_text, log_probs, hidden_state = self.generate_response_with_logprobs(prompt)

        # Get rule-based reward (your existing system)
        rule_reward_info = self.reward_function.compute_total_reward(
            response_text, correct_answer, context
        )

        # Add learnable reward component
        if hasattr(self, 'use_learnable_reward') and self.use_learnable_reward:
            with torch.no_grad():
                learnable_reward = self.reward_head(hidden_state.detach().float()).item()

            # Combine rule-based and learnable rewards
            combined_reward = 0.7 * rule_reward_info['r_normalized'] + 0.3 * learnable_reward

            # Add to reward info
            reward_info = rule_reward_info.copy()
            reward_info['r_learnable'] = learnable_reward
            reward_info['r_combined'] = combined_reward
            reward_info['r_normalized'] = combined_reward  # Use combined as main reward
        else:
            reward_info = rule_reward_info

        return response_text, log_probs, hidden_state, reward_info

    def compute_baseline_value(self, hidden_state, training_mode=False):
        """Compute baseline value with enhanced error handling."""
        if not self.use_baseline or hidden_state is None:
            return torch.tensor(0.0, device=self.device, requires_grad=training_mode, dtype=torch.float16)

        # Enhanced validity checks
        if torch.any(torch.isnan(hidden_state)) or torch.any(torch.isinf(hidden_state)):
            print("Warning: Invalid hidden state input to baseline, using fallback")
            return torch.tensor(0.0, device=self.device, requires_grad=training_mode, dtype=torch.float16)

        # Clamp input to prevent extreme values
        hidden_state_input = torch.clamp(hidden_state, min=-10.0, max=10.0)

        try:
            baseline_value = self.baseline_network(hidden_state_input)

            # Check for invalid baseline output
            if torch.isnan(baseline_value) or torch.isinf(baseline_value):
                print("Warning: Baseline network produced invalid output, using fallback")
                baseline_value = torch.tensor(0.0, device=self.device, requires_grad=training_mode, dtype=torch.float16)
            else:
                baseline_value = torch.clamp(baseline_value, min=-5.0, max=5.0)

        except Exception as e:
            print(f"Warning: Error in baseline network: {e}, using fallback")
            baseline_value = torch.tensor(0.0, device=self.device, requires_grad=training_mode, dtype=torch.float16)

        if training_mode:
            return baseline_value
        else:
            return baseline_value.item()

    def _generate_conservative_response(self, prompt, correct_answer):
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=True,
                temperature=0.5,
                top_p=0.8,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        response = self.tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        ).strip()

        return response
        
    def _generate_high_quality_response(self, prompt, correct_answer):
        """Generate a high-quality response that should score well"""

        # Use conservative generation parameters for high quality
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,
                do_sample=True,
                temperature=0.6,  # Lower temperature for more focused responses
                top_p=0.9,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        response = self.tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        ).strip()

        # Ensure proper format
        if '<think>' not in response:
            response = f"<think>Let me analyze this medical scenario systematically. {response}</think>"
        if '<answer>' not in response:
            response += f"<answer>{correct_answer}</answer>"

        return response

    def _analyze_response_quality(self, response: str, reward_info: Dict, correct_answer: str) -> Dict:
        """Analyze response against your 5 quality goals"""

        violations = {
            'correctness_violations': 0,
            'answer_leaking_violations': 0,
            'format_violations': 0,
            'factual_violations': 0,
            'bad_ood_high_rewards': 0
        }

        # Goal 1: Improved Correctness
        if reward_info.get('r_binary', 0) <= 0:
            violations['correctness_violations'] = 1

        # Goal 2: Less Answer Leaking
        if reward_info.get('p_answer', 0) > 0.1:
            violations['answer_leaking_violations'] = 1

        # Goal 3: Less Format Violations
        if reward_info.get('p_structural', 0) > 0.1 or not self.reward_function.validate_format(response):
            violations['format_violations'] = 1

        # Goal 4: Less Factual Errors
        if reward_info.get('r_factual', 1.0) < 0.3:
            violations['factual_violations'] = 1

        # Goal 5: Less Rewards to Bad OOD Reasoning
        # High total reward despite violations = bad OOD getting undeserved high reward
        has_violations = sum(violations.values()) > 0
        high_reward = reward_info.get('r_total', 0) > 0.5
        if has_violations and high_reward:
            violations['bad_ood_high_rewards'] = 1

        return violations

    def evaluate_model(self, test_data_path, max_examples=None):
        """Evaluate the current policy on a MedQA-style split and report metrics.

        Generates one sampled response per example with ``generate_response``,
        scores it with ``self.reward_function.compute_total_reward`` and
        aggregates accuracy, reward statistics and per-goal violation
        rates/standard deviations. Results are printed through
        ``self.metrics_tracker`` and returned.

        Fact verification is switched to inference mode
        (``training_mode = False``) for the duration of the evaluation and the
        caller's previous setting is restored afterwards, so evaluations run
        between training stages do not leak eval-mode into later training.

        Args:
            test_data_path: Path passed to ``self.data_processor.load_medqa_data``.
            max_examples: Evaluate only the first ``max_examples`` items;
                ``None`` evaluates the whole split.

        Returns:
            Dict of evaluation metrics (rates, averages, standard deviations,
            counts and ``reward_history``).

        Raises:
            ValueError: If no examples are available; every rate would otherwise
                divide by zero.
        """
        test_dataset = self.data_processor.load_medqa_data(test_data_path)
        if max_examples is not None:
            test_dataset = test_dataset[:max_examples]

        total_examples = len(test_dataset)
        if total_examples == 0:
            raise ValueError(f"No evaluation examples loaded from {test_data_path!r}")

        print(f"Evaluating on {total_examples} test examples...")

        self.model.eval()
        if self.use_baseline:
            self.baseline_network.eval()

        # Per-example 0/1 violation flags, kept for standard-deviation reporting.
        violation_records = {
            'correctness_violations': [],
            'answer_leaking_violations': [],
            'format_violations': [],
            'factual_violations': [],
            'bad_ood_high_rewards': []
        }
        # Enhanced reward tracking including factual scores
        reward_components = {'r_binary': [], 'p_answer': [], 'p_structural': [], 'r_factual': []}
        all_rewards = []
        correct_predictions = 0
        format_violations = 0
        evaluation_data = []

        fact_system = self.reward_function.fact_verification_system
        previous_training_mode = getattr(fact_system, 'training_mode', False)
        fact_system.training_mode = False
        try:
            with torch.no_grad():
                for i, test_item in enumerate(test_dataset):
                    prompt = test_item['prompt']
                    correct_answer = test_item['correct_answer']

                    response = self.generate_response(prompt)

                    # Enhanced reward computation with factual verification
                    reward_info = self.reward_function.compute_total_reward(response, correct_answer)
                    evaluation_data.append({
                        'response': response,
                        'reward_info': reward_info,
                        'correct_answer': correct_answer
                    })

                    for key in reward_components:
                        reward_components[key].append(reward_info[key])
                    all_rewards.append(reward_info['r_normalized'])

                    predicted_answer = self.reward_function.extract_answer_choice(response)
                    if predicted_answer and predicted_answer.upper() == correct_answer.upper():
                        correct_predictions += 1

                    format_violations += int(not self.reward_function.validate_format(response))

                    violations = self._analyze_response_quality(response, reward_info, correct_answer)
                    for key in violation_records:
                        violation_records[key].append(violations[key])

                    if i % 100 == 0:
                        print(f"Evaluated {i + 1}/{total_examples} examples...")

            hacking_stats = self.reward_function.calculate_hacking_rate(evaluation_data)
        finally:
            fact_system.training_mode = previous_training_mode

        format_violation_rate = format_violations / total_examples
        accuracy = correct_predictions / total_examples
        avg_rewards = {key: np.mean(values) for key, values in reward_components.items()}
        reward_std = np.std(all_rewards) if len(all_rewards) > 1 else 0.0

        # Map to the same keys MetricsTracker expects with std deviations
        test_metrics = {
            "correctness_violations_rate": 1 - accuracy,
            "answer_leaking_violations_rate": hacking_stats["answer_violation_rate"],
            "format_violations_rate": format_violation_rate,
            "factual_violations_rate": hacking_stats["factual_violation_rate"],
            "bad_ood_high_rewards_rate": hacking_stats["overall_violation_rate"],
            "avg_reward": np.mean(all_rewards),
            "avg_factual": avg_rewards["r_factual"],
            "accuracy": accuracy,
            "correct_predictions": correct_predictions,
            "total_examples": total_examples,
            "reward_history": self.reward_history,
            "std_reward": reward_std,
            "std_factual": np.std(reward_components['r_factual']) if len(reward_components['r_factual']) > 1 else 0.0
        }

        # All five *_rate keys are already set above, so setdefault keeps them and
        # only the per-example standard deviations are added.
        # NOTE(review): several rates come from calculate_hacking_rate while the
        # stds come from _analyze_response_quality flags; the pairs may not match.
        for key, flags in violation_records.items():
            test_metrics.setdefault(f"{key}_rate", np.mean(flags))
            test_metrics[f"{key}_std"] = np.std(flags) if len(flags) > 1 else 0.0

        self.metrics_tracker.print_metrics_with_std(test_metrics, "Test Evaluation Results")

        return test_metrics

    def generate_response(self, prompt: str, max_new_tokens=2048) -> str:
        """Sample one completion for ``prompt`` and return only the new text.

        The prompt is truncated to 400 tokens; decoding samples with
        temperature 0.7 and nucleus top_p 0.9. Puts the model in eval mode as a
        side effect.

        Args:
            prompt: Input text fed to the tokenizer.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded continuation (prompt tokens dropped, special tokens
            skipped) with surrounding whitespace stripped.
        """
        self.model.eval()

        encoded = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=400)
        model_inputs = {name: tensor.to(self.device) for name, tensor in encoded.items()}
        prompt_length = model_inputs['input_ids'].shape[1]

        sampling_kwargs = dict(
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        with torch.no_grad():
            generated = self.model.generate(**model_inputs, **sampling_kwargs)

        continuation_ids = generated[0][prompt_length:]
        return self.tokenizer.decode(continuation_ids, skip_special_tokens=True).strip()

    def analyze_hacking_sensitivity(self, test_data_path, max_examples=None, tau_answer_range=None, tau_preamble_range=None):
        """Sweep reward-hacking thresholds over a fixed set of generated responses.

        Responses are sampled once per test example. Each
        (``tau_answer``, ``tau_preamble``) pair is then set on
        ``self.reward_function`` and the same responses are re-scored, so only the
        thresholds vary between result rows. Fact verification is put in
        inference mode (``training_mode = False``) while scoring. The original
        thresholds and ``training_mode`` are restored afterwards, even if scoring
        raises.

        Args:
            test_data_path: Path passed to ``self.data_processor.load_medqa_data``.
            max_examples: Use only the first ``max_examples`` items; ``None`` uses all.
            tau_answer_range: Candidate ``tau_answer`` values (default 0.5 to 1.5 in 0.1 steps).
            tau_preamble_range: Candidate ``tau_preamble`` values (default 5 to 50 in steps of 5).

        Returns:
            Dict with ``sensitivity_results`` (one row per threshold pair), the two
            ranges that were swept, and ``total_examples``.

        Raises:
            ValueError: If no examples are available; per-row rates would
                otherwise divide by zero.
        """
        if tau_answer_range is None:
            tau_answer_range = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5]
        if tau_preamble_range is None:
            tau_preamble_range = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]

        test_dataset = self.data_processor.load_medqa_data(test_data_path)
        if max_examples is not None:
            test_dataset = test_dataset[:max_examples]
        if len(test_dataset) == 0:
            raise ValueError(f"No examples loaded from {test_data_path!r} for sensitivity analysis")

        print(f"Running sensitivity analysis on {len(test_dataset)} examples...")
        print(f"Testing tau_answer: {tau_answer_range}")
        print(f"Testing tau_preamble: {tau_preamble_range}")

        print("Generating responses...")
        self.model.eval()
        if self.use_baseline:
            self.baseline_network.eval()

        # Generate once; every threshold combination re-scores these same responses.
        evaluation_data = []
        with torch.no_grad():
            for test_item in test_dataset:
                prompt = test_item['prompt']
                correct_answer = test_item['correct_answer']
                response = self.generate_response(prompt)
                evaluation_data.append({
                    'response': response,
                    'correct_answer': correct_answer,
                    'prompt': prompt
                })

        reward_function = self.reward_function
        fact_system = reward_function.fact_verification_system
        original_tau_answer = reward_function.tau_answer
        original_tau_preamble = reward_function.tau_preamble
        original_training_mode = getattr(fact_system, 'training_mode', False)

        sensitivity_results = []

        print("\nTesting threshold combinations...")
        fact_system.training_mode = False
        try:
            for tau_answer in tau_answer_range:
                for tau_preamble in tau_preamble_range:
                    print(f"Testing tau_answer={tau_answer}, tau_preamble={tau_preamble}")

                    reward_function.tau_answer = tau_answer
                    reward_function.tau_preamble = tau_preamble

                    # Enhanced reward computation with factual verification
                    responses_with_rewards = [
                        {
                            'response': item['response'],
                            'reward_info': reward_function.compute_total_reward(
                                item['response'], item['correct_answer']
                            ),
                            'correct_answer': item['correct_answer']
                        }
                        for item in evaluation_data
                    ]
                    reward_infos = [item['reward_info'] for item in responses_with_rewards]

                    hacking_stats = reward_function.calculate_hacking_rate(responses_with_rewards)
                    positive_rewards = sum(1 for info in reward_infos if info['r_total'] > 0)

                    sensitivity_results.append({
                        'tau_answer': tau_answer,
                        'tau_preamble': tau_preamble,
                        'answer_violation_rate': hacking_stats['answer_violation_rate'],
                        'structural_violation_rate': hacking_stats['structural_violation_rate'],
                        'factual_violation_rate': hacking_stats['factual_violation_rate'],
                        'overall_violation_rate': hacking_stats['overall_violation_rate'],
                        'answer_violation_count': hacking_stats['answer_violation_count'],
                        'structural_violation_count': hacking_stats['structural_violation_count'],
                        'factual_violation_count': hacking_stats['factual_violation_count'],
                        'positive_reward_count': positive_rewards,
                        'positive_reward_rate': positive_rewards / len(responses_with_rewards),
                        'avg_answer_penalty': np.mean([info['p_answer'] for info in reward_infos]),
                        'avg_structural_penalty': np.mean([info['p_structural'] for info in reward_infos]),
                        'avg_factual_reward': np.mean([info['r_factual'] for info in reward_infos])
                    })
        finally:
            # Restore original thresholds and fact-verification mode
            reward_function.tau_answer = original_tau_answer
            reward_function.tau_preamble = original_tau_preamble
            fact_system.training_mode = original_training_mode

        return {
            'sensitivity_results': sensitivity_results,
            'tau_answer_range': tau_answer_range,
            'tau_preamble_range': tau_preamble_range,
            'total_examples': len(test_dataset)
        }


    def _train_reward_model_stage1(self, train_data_path, num_epochs, batch_size):
        """Stage 1: Train only the learnable reward components"""
        if not hasattr(self, 'reward_head'):
            print("No learnable reward head found - skipping reward model training")
            return
        
        train_dataset = self.data_processor.load_medqa_data(train_data_path)
        print(f"Stage 1: Training reward model on {len(train_dataset)} examples")
        
        # Freeze policy model
        self.model.eval()
        self.reward_head.train()
        
        for epoch in range(num_epochs):
            total_loss = 0
            num_batches = 0
            
            for batch_start in range(0, len(train_dataset), batch_size):
                batch_end = min(batch_start + batch_size, len(train_dataset))
                batch_items = train_dataset[batch_start:batch_end]
                
                batch_losses = []
                
                for item in batch_items:
                    try:
                        prompt = item['prompt']
                        correct_answer = item['correct_answer']
                        
                        # Generate response (no gradients for policy)
                        with torch.no_grad():
                            response_text, _, hidden_state = self.generate_response_with_logprobs(prompt)
                        
                        # Compute target score using rule-based reward
                        rule_reward_info = self.reward_function.compute_total_reward(
                            response_text, correct_answer
                        )
                        target_score = rule_reward_info['r_normalized']
                        
                        # Predict score using learnable reward head
                        predicted_score = self.reward_head(hidden_state.detach().float())
                        if predicted_score.dim() > 0:
                            predicted_score = predicted_score.squeeze()
                        
                        # Compute loss
                        target_tensor = torch.tensor(target_score, device=self.device, dtype=torch.float32)
                        loss = torch.nn.functional.mse_loss(predicted_score, target_tensor)
                        batch_losses.append(loss)
                        
                    except Exception as e:
                        print(f"Error in reward model training: {e}")
                        continue
                
                if batch_losses:
                    total_batch_loss = torch.stack(batch_losses).mean()
                    total_batch_loss.backward()
                    total_loss += total_batch_loss.item()
                    num_batches += 1
                
                # Update reward model every batch
                torch.nn.utils.clip_grad_norm_(self.reward_head.parameters(), 1.0)
                self.reward_optimizer.step()
                self.reward_optimizer.zero_grad()
            
            avg_loss = total_loss / num_batches if num_batches > 0 else 0
            print(f"Reward model epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.2f}")
        
        print("Stage 1 complete: Reward model training finished")

    def _train_policy_stage2(self, train_data_path, num_epochs, batch_size):
        train_dataset = self.data_processor.load_medqa_data(train_data_path)
        print(f"Stage 2: Training policy on {len(train_dataset)} examples")
        
        if hasattr(self, 'reward_head'):
            self.reward_head.eval()
        self.model.train()
        
        for epoch in range(num_epochs):
            epoch_rewards = []
            epoch_advantages = []
            epoch_policy_losses = []
            epoch_baseline_values = []
            
            violation_records = {
                'correctness_violations': [],
                'answer_leaking_violations': [],
                'format_violations': [],
                'factual_violations': [],
                'bad_ood_high_rewards': []
            }
            
            epoch_metrics = {
                'correctness_violations': 0,
                'answer_leaking_violations': 0,
                'format_violations': 0,
                'factual_violations': 0,
                'bad_ood_high_rewards': 0,
                'total_examples': 0
            }
            
            self.policy_optimizer.zero_grad()
            if self.use_baseline:
                self.baseline_optimizer.zero_grad()
            
            for batch_start in range(0, len(train_dataset), batch_size):
                batch_end = min(batch_start + batch_size, len(train_dataset))
                batch_items = train_dataset[batch_start:batch_end]
                
                batch_policy_losses = []
                batch_baseline_losses = []
                batch_rewards = []
                batch_advantages = []
                
                for item in batch_items:
                    try:
                        prompt = item['prompt']
                        correct_answer = item['correct_answer']
                        
                        response_text, log_probs, hidden_state, reward_info = self.generate_and_evaluate_with_facts(
                            prompt, correct_answer
                        )
                        
                        # Track violations for this individual example
                        violations = self._analyze_response_quality(response_text, reward_info, correct_answer)
                        
                        # Record individual violations (0 or 1) for std calculation
                        for key in violation_records.keys():
                            violation_records[key].append(violations[key])
                        
                        # Aggregate totals for rates
                        for key, value in violations.items():
                            epoch_metrics[key] = epoch_metrics.get(key, 0) + value
                        epoch_metrics["total_examples"] = epoch_metrics.get("total_examples", 0) + 1
                        
                        if log_probs.numel() == 0:
                            continue
                        
                        # Get reward from frozen reward model
                        with torch.no_grad():
                            if hasattr(self, 'reward_head'):
                                reward = self.reward_head(hidden_state.detach().float()).item()
                            else:
                                reward_info = self.reward_function.compute_total_reward(
                                    response_text, correct_answer
                                )
                                reward = reward_info['r_normalized']
                        
                        # Baseline and advantage
                        baseline_value = self.compute_baseline_value(hidden_state, training_mode=True)
                        reward_tensor = torch.tensor(reward, device=self.device, dtype=torch.float16)
                        advantage = reward_tensor - baseline_value
                        
                        batch_rewards.append(reward)
                        batch_advantages.append(advantage.detach().item())
                        
                        policy_loss = -torch.sum(log_probs) * advantage
                        batch_policy_losses.append(policy_loss)
                        
                        # Collect metrics for std calculation
                        epoch_rewards.append(reward)
                        epoch_advantages.append(advantage.detach().item())
                        epoch_policy_losses.append(policy_loss.detach().item())
                        if hasattr(baseline_value, 'item'):
                            epoch_baseline_values.append(baseline_value.item())
                        else:
                            epoch_baseline_values.append(float(baseline_value))
                        
                        if self.use_baseline:
                            baseline_loss = ((baseline_value - reward_tensor) ** 2)
                            batch_baseline_losses.append(baseline_loss)
                    
                    except Exception as e:
                        print(f"Error in policy training: {e}")
                        continue
                
                # Update policy
                if batch_policy_losses:
                    total_policy_loss = torch.stack(batch_policy_losses).mean()
                    total_loss = total_policy_loss
                    
                    if self.use_baseline and batch_baseline_losses:
                        total_baseline_loss = torch.stack(batch_baseline_losses).mean()
                        total_loss += total_baseline_loss
                    
                    total_loss.backward()
            
            # Update weights
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            if self.use_baseline:
                torch.nn.utils.clip_grad_norm_(self.baseline_network.parameters(), max_norm=1.0)
            
            self.policy_optimizer.step()
            self.policy_optimizer.zero_grad()
            if self.use_baseline:
                self.baseline_optimizer.step()
                self.baseline_optimizer.zero_grad()
            
            # Calculate means and standard deviations for continuous metrics
            avg_reward = np.mean(epoch_rewards) if epoch_rewards else 0
            std_reward = np.std(epoch_rewards) if len(epoch_rewards) > 1 else 0
            
            avg_advantage = np.mean(epoch_advantages) if epoch_advantages else 0
            std_advantage = np.std(epoch_advantages) if len(epoch_advantages) > 1 else 0
            
            avg_policy_loss = np.mean(epoch_policy_losses) if epoch_policy_losses else 0
            std_policy_loss = np.std(epoch_policy_losses) if len(epoch_policy_losses) > 1 else 0
            
            avg_baseline_value = np.mean(epoch_baseline_values) if epoch_baseline_values else 0
            std_baseline_value = np.std(epoch_baseline_values) if len(epoch_baseline_values) > 1 else 0
            
            print(f"Policy epoch {epoch + 1}/{num_epochs}")
            print(f"Avg Reward: {avg_reward:.2f} (±{std_reward:.2f})\n")
            # print(f"  Avg Advantage: {avg_advantage:.2f} (±{std_advantage:.2f})")
            # print(f"  Avg Policy Loss: {avg_policy_loss:.2f} (±{std_policy_loss:.2f})")
            if self.use_baseline:
                print(f"  Avg Baseline Value: {avg_baseline_value:.2f} (±{std_baseline_value:.2f})")
            
            if epoch_rewards and epoch_metrics["total_examples"] > 0:
                # Calculate violation rates and their standard deviations
                for key in ["correctness_violations", "answer_leaking_violations",
                            "format_violations", "factual_violations", "bad_ood_high_rewards"]:
                    rate_key = f"{key}_rate"
                    std_key = f"{key}_std"
                    
                    # Rate is the mean of 0s and 1s
                    epoch_metrics[rate_key] = epoch_metrics[key] / epoch_metrics["total_examples"]
                    
                    # Standard deviation of binary values (0s and 1s)
                    if len(violation_records[key]) > 1:
                        epoch_metrics[std_key] = np.std(violation_records[key])
                    else:
                        epoch_metrics[std_key] = 0.0
                
                # Store other metrics
                epoch_metrics["avg_reward"] = avg_reward
                epoch_metrics["std_reward"] = std_reward
                epoch_metrics["avg_advantage"] = avg_advantage 
                epoch_metrics["std_advantage"] = std_advantage
                epoch_metrics["avg_policy_loss"] = avg_policy_loss
                epoch_metrics["std_policy_loss"] = std_policy_loss
                epoch_metrics["avg_baseline_value"] = avg_baseline_value
                epoch_metrics["std_baseline_value"] = std_baseline_value

                self.metrics_tracker.print_metrics_with_std(epoch_metrics, f"POLICY EPOCH {epoch + 1}")
        
        print("Stage 2 complete: Policy training finished")
    
    def _train_adversarial_stage3(self, train_data_path, num_cycles=3):
        print("\n" + "=" * 60)
        print("STAGE 3: ADVERSARIAL ROBUSTNESS TRAINING")
        print("=" * 60)
        
        train_dataset = self.data_processor.load_medqa_data(train_data_path)
        prompts = [item['prompt'] for item in train_dataset[:100]]
        answers = [item['correct_answer'] for item in train_dataset[:100]]
        
        print("Evaluating model BEFORE adversarial training...")
        pre_adversarial_metrics = self.evaluate_model(train_data_path, 50)
        
        print("Pre-Adversarial Performance:")
        self.metrics_tracker.print_metrics_with_std(pre_adversarial_metrics, "PRE-ADVERSARIAL METRICS")
        
        # Run adversarial training cycles
        results = self.adversarial_trainer.run_adversarial_training_cycle(
            prompts, answers, num_cycles
        )
        
        print("Adversarial training complete:")
        print(f"  Final robustness score: {results['final_robustness']:.2f}")
        print(f"  Total cycles completed: {results['total_cycles']}")
        
        print("\nEvaluating model AFTER adversarial training...")
        post_adversarial_metrics = self.evaluate_model(train_data_path, 50)
        
        print("Post-Adversarial Performance:")
        self.metrics_tracker.print_metrics_with_std(post_adversarial_metrics, "POST-ADVERSARIAL METRICS")
        
        # Store both sets of metrics and print comprehensive analysis
        self.metrics_tracker.add_adversarial_stage_metrics(pre_adversarial_metrics, post_adversarial_metrics)
        self.metrics_tracker.print_adversarial_impact_analysis()
        
        return results


    def train_reward_policy(self, train_data_path, test_data_path, stage1_epochs, stage2_epochs, batch_size=8, max_eval_examples=100):
        """Run stage 1 (reward head) then stage 2 (policy), and evaluate on the test set.

        Args:
            train_data_path: Training split used by both stages.
            test_data_path: Split passed to ``evaluate_model``.
            stage1_epochs: Epochs for reward-head training.
            stage2_epochs: Epochs for policy training.
            batch_size: Batch size shared by both stages.
            max_eval_examples: Cap on evaluated examples.

        Returns:
            The metrics dict produced by ``evaluate_model``.
        """
        print("TRAINING: Reward and Policy")
        print("=" * 40)

        stage_plan = (
            (self._train_reward_model_stage1, stage1_epochs),
            (self._train_policy_stage2, stage2_epochs),
        )
        for run_stage, epochs in stage_plan:
            run_stage(train_data_path, epochs, batch_size)

        print("\n Evaluating Reward and Policy Training Results:")
        metrics = self.evaluate_model(test_data_path, max_eval_examples)
        self.metrics_tracker.print_metrics_with_std(metrics, "STAGE 1 RESULTS")
        return metrics
    
    def train_adversarial(self, train_data_path, test_data_path, stage1_epochs, stage2_epochs, batch_size=8, max_eval_examples=100, num_adversarial_cycles=3):
        """Train all stages including adversarial and evaluate.

        Runs stage 1 (reward head), stage 2 (policy) and stage 3 (adversarial
        training) on ``train_data_path``, then evaluates on ``test_data_path``.

        Args:
            train_data_path: Training split used by all three stages.
            test_data_path: Split passed to the final ``evaluate_model`` call.
            stage1_epochs: Epochs for reward-head training.
            stage2_epochs: Epochs for policy training.
            batch_size: Batch size for stages 1 and 2.
            max_eval_examples: Cap on evaluated test examples.
            num_adversarial_cycles: Cycles passed to stage 3 (default 3).

        Returns:
            The metrics dict produced by ``evaluate_model`` on the test split.
        """
        print("TRAINING: All Stages (Including Adversarial)")
        print("="*50)

        self._train_reward_model_stage1(train_data_path, stage1_epochs, batch_size)
        self._train_policy_stage2(train_data_path, stage2_epochs, batch_size)
        self._train_adversarial_stage3(train_data_path, num_cycles=num_adversarial_cycles)

        print("\nEVALUATING FINAL RESULTS:")
        final_metrics = self.evaluate_model(test_data_path, max_eval_examples)
        # Distinct label so these are not confused with train_reward_policy's
        # "STAGE 1 RESULTS" output.
        self.metrics_tracker.print_metrics_with_std(final_metrics, "FINAL RESULTS")

        return final_metrics
    
    def compare_training_approaches(self, train_data_path, test_data_path, stage1_epochs, stage2_epochs, batch_size=8, max_eval_examples=100):
        """Run the non-adversarial and adversarial pipelines back to back and compare them.

        NOTE(review): both pipelines run on ``self`` with no visible model reset in
        between, so the adversarial run starts from the weights the first run
        produced; the two results do not share a starting point. Confirm this
        is intended.

        Returns:
            ``{'stage1_only': ..., 'with_adversarial': ...}`` holding the metrics
            returned by ``train_reward_policy`` and ``train_adversarial``.
        """
        shared_args = (train_data_path, test_data_path, stage1_epochs, stage2_epochs, batch_size, max_eval_examples)

        # Dict values are evaluated in order: plain pipeline first, then adversarial.
        outcome = {
            'stage1_only': self.train_reward_policy(*shared_args),
            'with_adversarial': self.train_adversarial(*shared_args),
        }
        self._compare_training_stages(outcome['stage1_only'], outcome['with_adversarial'])
        return outcome

# Adversarial Trainer

In [None]:
class AdversarialTrainer:
    """Harden a learnable reward head against reward-hacking ("adversarial") responses.

    One cycle of ``run_adversarial_training_cycle``:
      1. sample responses from the policy under several decoding strategies and keep
         those scoring above ``target_reward`` despite quality issues;
      2. pair each with a cleaner, format-repaired response from the same model;
      3. apply a pairwise margin-ranking update to ``base_trainer.reward_head``;
      4. measure how often flawed responses are now scored low.

    ``base_trainer`` must expose ``model``, ``tokenizer``, ``reward_function`` and
    ``device``. The update step additionally uses ``reward_head`` and
    ``reward_optimizer`` when they are present.
    """

    # Decoding settings used to probe the reward function; each dict is forwarded
    # verbatim to ``model.generate``.
    _GENERATION_STRATEGIES = (
        {'temperature': 1.5, 'top_p': 0.95, 'do_sample': True},
        {'temperature': 0.3, 'top_p': 0.8, 'do_sample': True},
        {'temperature': 1.0, 'top_p': 0.9, 'do_sample': True},
    )
    # More conservative settings for the "chosen" side of a preference pair.
    _CLEAN_STRATEGY = {'temperature': 0.7, 'top_p': 0.9, 'do_sample': True}

    def __init__(self, base_trainer, adversarial_config=None):
        """
        Args:
            base_trainer: Trainer providing ``model``, ``tokenizer``, ``reward_function``
                and ``device``.
            adversarial_config: Optional dict; recognised keys are ``temperature``,
                ``max_examples``, ``preference_margin`` and ``validation_threshold``.
        """
        self.base_trainer = base_trainer
        self.model = base_trainer.model
        self.tokenizer = base_trainer.tokenizer
        self.reward_function = base_trainer.reward_function
        self.device = base_trainer.device

        config = adversarial_config or {}
        # NOTE(review): `adversarial_temperature` and `validation_threshold` are stored
        # but not read anywhere in this class.
        self.adversarial_temperature = config.get('temperature', 1.2)
        # Maximum number of prompts probed per strategy in one generation pass.
        self.max_adversarial_examples = config.get('max_examples', 50)
        # Margin of the pairwise ranking loss in `update_reward_model`.
        self.preference_margin = config.get('preference_margin', 0.5)
        self.validation_threshold = config.get('validation_threshold', 0.7)

        # Accumulate across calls until `clear_buffers` is called.
        self.adversarial_examples_buffer = []
        self.preference_pairs_buffer = []

        print("AdversarialTrainer initialized")

    @staticmethod
    def _answer_at(correct_answers, index):
        """Return ``correct_answers[index]``, or the placeholder "A" when no answers were given."""
        return correct_answers[index] if correct_answers else "A"

    @staticmethod
    def _print_adversarial_example(adv):
        """Pretty-print one adversarial example for manual inspection."""
        print("\n=== Adversarial Example ===")
        print(f"Prompt:\n{adv['prompt']}\n")
        print(f"Response:\n{adv['response']}\n")
        print(f"Reward info: {adv['reward_info']}")
        print(f"Vulnerability: {adv['vulnerability_type']}")
        print("===========================\n")

    def generate_adversarial_examples(self, prompts, correct_answers=None, target_reward=0.5):
        """Sample responses and collect those that look like reward hacks.

        Every strategy in ``_GENERATION_STRATEGIES`` is applied to the first
        ``max_adversarial_examples`` prompts. A response is kept when
        ``_is_adversarial_example`` reports a reward above ``target_reward`` together
        with quality issues. Per-prompt failures are printed and skipped.

        Args:
            prompts: Prompt strings.
            correct_answers: Answers aligned with ``prompts``; when falsy, "A" is used.
            target_reward: Reward above which a flawed response counts as adversarial.

        Returns:
            List of dicts (prompt, response, reward_info, correct_answer, strategy,
            vulnerability_type). They are also appended to ``adversarial_examples_buffer``.
        """
        adversarial_examples = []
        print(f"Generating adversarial examples from {len(prompts)} prompts...")

        shown = 0  # examples printed so far (at most 2)

        for strategy_idx, strategy in enumerate(self._GENERATION_STRATEGIES):
            print(f"  Strategy {strategy_idx + 1}: temp={strategy['temperature']}")

            for i, prompt in enumerate(prompts[:self.max_adversarial_examples]):
                try:
                    response = self._generate_with_strategy(prompt, strategy)
                    correct_answer = self._answer_at(correct_answers, i)
                    reward_info = self.reward_function.compute_total_reward(response, correct_answer)

                    if not self._is_adversarial_example(response, reward_info, target_reward):
                        continue

                    adv = {
                        'prompt': prompt,
                        'response': response,
                        'reward_info': reward_info,
                        'correct_answer': correct_answer,
                        'strategy': strategy_idx,
                        'vulnerability_type': self._classify_vulnerability(reward_info)
                    }
                    adversarial_examples.append(adv)

                    if shown < 2:
                        self._print_adversarial_example(adv)
                        shown += 1

                except Exception as e:
                    print(f"Error generating adversarial example: {e}")

        print(f"Found {len(adversarial_examples)} adversarial examples")
        self.adversarial_examples_buffer.extend(adversarial_examples)
        return adversarial_examples

    def _generate_with_strategy(self, prompt, strategy):
        """Sample one continuation of ``prompt`` using ``strategy`` as ``generate`` kwargs.

        The prompt is truncated to 512 tokens and at most 200 new tokens are produced.
        Only the newly generated text is decoded and returned (stripped).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                **strategy
            )

        # Skip the echoed prompt tokens so only the continuation is decoded.
        prompt_len = inputs['input_ids'].shape[1]
        return self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    def _has_quality_issues(self, response, reward_info):
        """True when any penalty/factuality component or the format check flags ``response``."""
        return (
            reward_info.get('p_answer', 0) > 0.05
            or reward_info.get('p_structural', 0) > 0.05
            or reward_info.get('r_factual', 1.0) < 0.4
            or not self.reward_function.validate_format(response)
        )

    def _is_adversarial_example(self, response, reward_info, target_reward):
        """A response is adversarial when it scores strictly above ``target_reward``
        despite having quality issues (i.e. the reward function was fooled)."""
        if reward_info.get('r_total', 0) <= target_reward:
            return False
        return self._has_quality_issues(response, reward_info)

    def _classify_vulnerability(self, reward_info):
        """Name the first reward component that flags the response (checked in priority order)."""
        if reward_info.get('p_answer', 0) > 0.05:
            return "answer_leaking"
        elif reward_info.get('p_structural', 0) > 0.05:
            return "structural_gaming"
        elif reward_info.get('r_factual', 1.0) < 0.4:
            return "factual_exploitation"
        else:
            return "format_gaming"

    def create_preference_pairs(self, adversarial_examples):
        """Pair each adversarial response with a freshly generated clean response.

        A pair is kept only when the clean response's quality score
        (``_assess_response_quality``) is strictly higher than the adversarial one.
        Examples that raise are printed and skipped.

        Returns:
            List of dicts (prompt, chosen, rejected, correct_answer, chosen_reward,
            rejected_reward, vulnerability_type). They are also appended to
            ``preference_pairs_buffer``.
        """
        preference_pairs = []

        for adv_example in adversarial_examples:
            prompt = adv_example['prompt']
            correct_answer = adv_example['correct_answer']
            try:
                clean_response = self._generate_clean_response(prompt, correct_answer)
                clean_reward_info = self.reward_function.compute_total_reward(
                    clean_response, correct_answer
                )

                clean_quality = self._assess_response_quality(
                    clean_response, clean_reward_info, correct_answer
                )
                adv_quality = self._assess_response_quality(
                    adv_example['response'], adv_example['reward_info'], correct_answer
                )
            except Exception as e:
                print(f"Error creating preference pair: {e}")
                continue

            if clean_quality > adv_quality:
                preference_pairs.append({
                    'prompt': prompt,
                    'chosen': clean_response,
                    'rejected': adv_example['response'],
                    'correct_answer': correct_answer,
                    'chosen_reward': clean_reward_info,
                    'rejected_reward': adv_example['reward_info'],
                    'vulnerability_type': adv_example['vulnerability_type']
                })

        print(f"Created {len(preference_pairs)} preference pairs")
        self.preference_pairs_buffer.extend(preference_pairs)

        return preference_pairs

    def _generate_clean_response(self, prompt, correct_answer):
        """Sample with conservative settings, then repair the <think>/<answer> structure.

        NOTE(review): when the sample has no <answer> tag, `_ensure_proper_format` appends
        `correct_answer`, so the "chosen" side can contain the reference answer.
        """
        response = self._generate_with_strategy(prompt, self._CLEAN_STRATEGY)
        return self._ensure_proper_format(response, correct_answer)

    def _ensure_proper_format(self, response, correct_answer):
        """Ensure ``response`` has <think> and <answer> sections.

        If <think> is missing, the text before the first <answer> is wrapped in
        <think>...</think>, and everything after that first <answer> is kept verbatim.
        If no <answer> tag exists afterwards, ``<answer>{correct_answer}</answer>`` is appended.
        """
        if '<think>' not in response.lower():
            reasoning_part, _, answer_part = response.partition('<answer>')
            response = f"<think>{reasoning_part.strip()}</think>"
            if answer_part:
                response += f"<answer>{answer_part}"

        if '<answer>' not in response.lower():
            response += f"\n<answer>{correct_answer}</answer>"

        return response

    def update_reward_model(self, preference_pairs):
        """Apply one pairwise margin-ranking update to ``base_trainer.reward_head``.

        For each pair the head scores the mean-pooled hidden states of prompt+chosen and
        prompt+rejected. Gradients of ``margin_ranking_loss`` (target +1, margin
        ``preference_margin``) are accumulated (summed) over all pairs and clipped to
        norm 1.0, then a single optimizer step is taken. Pairs that raise are printed
        and skipped.

        Returns:
            dict with ``loss`` (total loss / number of pairs) and ``accuracy`` (fraction
            of pairs where chosen scored strictly higher, measured before the step).
        """
        reward_head = getattr(self.base_trainer, 'reward_head', None)
        reward_optimizer = getattr(self.base_trainer, 'reward_optimizer', None)
        if reward_head is None or reward_optimizer is None:
            print("No learnable reward component found - skipping adversarial update")
            return {'loss': 0.0, 'accuracy': 0.0}

        print(f"Updating reward model with {len(preference_pairs)} preference pairs...")
        reward_optimizer.zero_grad()
        reward_head.train()
        total_loss = 0.0
        correct_rankings = 0

        for pair in preference_pairs:
            try:
                chosen_hidden = self._get_hidden_state_for_response(pair['prompt'], pair['chosen'])
                rejected_hidden = self._get_hidden_state_for_response(pair['prompt'], pair['rejected'])

                chosen_reward_pred = reward_head(chosen_hidden.detach().float())
                rejected_reward_pred = reward_head(rejected_hidden.detach().float())

                # Chosen should outscore rejected by at least `preference_margin`.
                # Fully qualified: `F` (torch.nn.functional) is never imported in this notebook.
                target = torch.ones_like(chosen_reward_pred)
                loss = torch.nn.functional.margin_ranking_loss(
                    chosen_reward_pred, rejected_reward_pred,
                    target, margin=self.preference_margin
                )

                loss.backward()  # gradients accumulate across pairs
                total_loss += loss.item()

                if chosen_reward_pred.item() > rejected_reward_pred.item():
                    correct_rankings += 1

            except Exception as e:
                print(f"Error in preference update: {e}")

        # A zero total loss means every pair already satisfied the margin (or all failed).
        if total_loss > 0:
            torch.nn.utils.clip_grad_norm_(reward_head.parameters(), 1.0)
            reward_optimizer.step()

        avg_loss = total_loss / len(preference_pairs) if preference_pairs else 0
        accuracy = correct_rankings / len(preference_pairs) if preference_pairs else 0

        print(f"  Adversarial loss: {avg_loss:.4f}")
        print(f"  Preference accuracy: {accuracy:.2f}")

        return {'loss': avg_loss, 'accuracy': accuracy}

    def analyze_response_quality(self, response, reward_info, correct_answer):
        """Flag which of five quality goals ``response`` violates (1 = violated, 0 = ok).

        ``bad_ood_high_rewards`` is set when any other violation is present AND
        ``r_total`` > 0.5. ``correct_answer`` is accepted for interface compatibility
        but not used.
        """
        violations = {
            'correctness_violations': 0,
            'answer_leaking_violations': 0,
            'format_violations': 0,
            'factual_violations': 0,
            'bad_ood_high_rewards': 0
        }

        # Goal 1: correctness
        if reward_info.get('r_binary', 0) <= 0:
            violations['correctness_violations'] = 1

        # Goal 2: answer leaking
        if reward_info.get('p_answer', 0) > 0.1:
            violations['answer_leaking_violations'] = 1

        # Goal 3: format
        if reward_info.get('p_structural', 0) > 0.1 or not self.reward_function.validate_format(response):
            violations['format_violations'] = 1

        # Goal 4: factual errors
        if reward_info.get('r_factual', 1.0) < 0.3:
            violations['factual_violations'] = 1

        # Goal 5: flawed reasoning that still earned a high reward
        has_violations = sum(violations.values()) > 0
        high_reward = reward_info.get('r_total', 0) > 0.5
        if has_violations and high_reward:
            violations['bad_ood_high_rewards'] = 1

        return violations

    def _assess_response_quality(self, response, reward_info, correct_answer=None):
        """Quality score in [0, 1]: 1 - (violated goals / total goals)."""
        violations = self.analyze_response_quality(response, reward_info, correct_answer or "A")
        total_violations = sum(violations.values())
        max_violations = len(violations)
        return 1.0 - (total_violations / max_violations)

    def _get_hidden_state_for_response(self, prompt, response):
        """Mean over tokens of the last hidden layer for ``prompt + " " + response``
        (first batch item; input truncated to 512 tokens; computed without gradients)."""
        full_text = prompt + " " + response
        inputs = self.tokenizer(full_text, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
            hidden_state = outputs.hidden_states[-1][0].mean(dim=0)

        return hidden_state

    def validate_robustness(self, test_prompts, test_answers, target_reward=0.5,
                            identification_threshold=0.3, max_prompts=20):
        """Measure how often flawed responses receive a low reward.

        Fresh responses are sampled for up to ``max_prompts`` prompts (further capped by
        ``max_adversarial_examples``) under every generation strategy. Among responses
        with quality issues (``_has_quality_issues``):
          * reward > ``target_reward`` counts as a missed adversarial example;
          * reward < ``identification_threshold`` counts as correctly identified.

        Responses are judged before the reward filter is applied. Judging only examples
        already selected for reward > ``target_reward`` could never find one below
        ``identification_threshold``, so the score would always be 0. Validation samples
        are not added to ``adversarial_examples_buffer``.

        NOTE(review): this reads ``reward_function.compute_total_reward``; whether that
        uses ``base_trainer.reward_head`` (updated by ``update_reward_model``) is not
        visible here.

        Returns:
            dict with ``robustness_score`` (correctly identified / flawed responses, 0 when
            none were flawed), ``adversarial_found``, ``correctly_identified`` and
            ``flawed_evaluated``.
        """
        print("Validating reward model robustness...")

        limit = min(max_prompts, self.max_adversarial_examples)
        flawed = 0
        adversarial_found = 0
        correctly_identified = 0

        for strategy in self._GENERATION_STRATEGIES:
            for i, prompt in enumerate(test_prompts[:limit]):
                try:
                    response = self._generate_with_strategy(prompt, strategy)
                    reward_info = self.reward_function.compute_total_reward(
                        response, self._answer_at(test_answers, i)
                    )
                    if not self._has_quality_issues(response, reward_info):
                        continue
                except Exception as e:
                    print(f"Error during robustness validation: {e}")
                    continue

                flawed += 1
                total_reward = reward_info.get('r_total', 0)
                if total_reward > target_reward:
                    adversarial_found += 1
                elif total_reward < identification_threshold:
                    correctly_identified += 1

        robustness_score = correctly_identified / flawed if flawed else 0

        print(f"Robustness validation: {robustness_score:.2f}")
        print(f"  Evaluated {flawed} flawed responses")
        print(f"  Found {adversarial_found} new adversarial examples")
        print(f"  Correctly identified {correctly_identified} as problematic")

        return {
            'robustness_score': robustness_score,
            'adversarial_found': adversarial_found,
            'correctly_identified': correctly_identified,
            'flawed_evaluated': flawed
        }

    def run_adversarial_training_cycle(self, prompts, correct_answers, num_cycles=3):
        """Run up to ``num_cycles`` generate -> pair -> update -> validate cycles.

        Stops early when a cycle finds no adversarial examples. A cycle that yields no
        preference pairs is skipped and not recorded.

        Args:
            prompts: Training prompts.
            correct_answers: Answers aligned with ``prompts`` (None/empty falls back to "A").
            num_cycles: Maximum number of cycles.

        Returns:
            dict with ``total_cycles`` (recorded cycles), ``cycle_results`` and
            ``final_robustness`` (score of the last recorded cycle, else 0).
        """
        print(f"Starting {num_cycles} adversarial training cycles...")

        cycle_results = []

        for cycle in range(num_cycles):
            print(f"\n=== Adversarial Cycle {cycle + 1}/{num_cycles} ===")

            adversarial_examples = self.generate_adversarial_examples(prompts, correct_answers)
            if not adversarial_examples:
                print("No adversarial examples found")
                break

            preference_pairs = self.create_preference_pairs(adversarial_examples)
            if not preference_pairs:
                print("No valid preference pairs created")
                continue

            update_metrics = self.update_reward_model(preference_pairs)

            # Guard the slice so a missing answer list does not raise TypeError.
            validation_metrics = self.validate_robustness(
                prompts[:10], correct_answers[:10] if correct_answers else correct_answers
            )

            cycle_results.append({
                'cycle': cycle + 1,
                'adversarial_found': len(adversarial_examples),
                'preference_pairs': len(preference_pairs),
                'update_loss': update_metrics['loss'],
                'update_accuracy': update_metrics['accuracy'],
                'robustness_score': validation_metrics['robustness_score']
            })

            print(f"Cycle {cycle + 1} complete:")
            print(f"  Adversarial examples: {len(adversarial_examples)}")
            print(f"  Update accuracy: {update_metrics['accuracy']:.2f}")
            print(f"  Robustness score: {validation_metrics['robustness_score']:.2f}")

        return {
            'total_cycles': len(cycle_results),
            'cycle_results': cycle_results,
            'final_robustness': cycle_results[-1]['robustness_score'] if cycle_results else 0
        }

    def clear_buffers(self):
        """Empty the adversarial-example and preference-pair buffers."""
        self.adversarial_examples_buffer.clear()
        self.preference_pairs_buffer.clear()
        print("Adversarial training buffers cleared")

# Fact Extraction and LLM Judge

## AtomicFactExtractor

In [None]:
@dataclass
class AtomicFact:
    """A single clinical claim extracted from model reasoning.

    Every attribute below is a dataclass field, so each can be passed as a keyword.

    Attributes:
        text: The claim itself.
        category: Category label (e.g. a ``FactCategory`` value, or a fallback such as
            "medical_statement" / "unknown").
        confidence: Confidence score (default 0.0).
        llm_score: Score written by the LLM verifier (default 0.0).
        kb_score: Presumably a knowledge-base verification score (default 0.0) — TODO confirm.
        source_sentence: Sentence the claim was extracted from (default "").
    """
    text: str
    category: str
    confidence: float = 0.0
    llm_score: float = 0.0
    kb_score: float = 0.0
    # Must carry a type annotation. Without it, @dataclass treats this as a plain class
    # attribute, and AtomicFact(..., source_sentence=...) raises TypeError.
    source_sentence: str = ""


class FactCategory:
    """Namespace of string labels for atomic-fact categories.

    The values match the category names requested in the fact-extraction prompt.
    """
    # Patient conditions, symptoms, disease states
    CONDITION: str = "patient_condition"
    # Medications, procedures, therapeutic interventions
    MEDICATION: str = "medication"
    # Anatomical structures, biological processes, pathophysiology
    ANATOMY: str = "anatomy"
    # Links between symptoms, conditions and underlying mechanisms
    RELATIONSHIP: str = "relationship"


class AtomicFactExtractor:
    """Extract atomic clinical facts from model reasoning with an LLM.

    ``extract_facts`` cleans the reasoning down to its ``<think>`` section and prompts
    the LLM for JSON between START_JSON/END_JSON markers. It then parses the reply into
    ``AtomicFact`` objects, falling back to keyword-based sentence extraction when
    parsing yields nothing.
    """

    def __init__(self, model_name, region_name):
        """Build the prompt template and the chat client for ``model_name``/``region_name``."""
        self.model_name = model_name
        self.extraction_prompt = self._build_extraction_prompt()
        self.client = ClaudeBedrockClient(model_name, region_name)

    def _build_extraction_prompt(self):
        """Return the ``str.format`` template for extraction (placeholder: ``{reasoning_text}``).

        Literal braces in the example JSON must be doubled. A single ``{`` makes
        ``.format`` parse ``{"facts": ...}`` as a replacement field and raise KeyError.
        """
        return """You are a medical fact extraction specialist. Extract discrete, 
        verifiable clinical facts from the given reasoning text.

        FACT CATEGORIES:
        1. patient_condition: Statements about patient conditions, symptoms, or disease states
        2. medication: Assertions about medications, procedures, or therapeutic interventions  
        3. anatomy: Facts about anatomical structures, biological processes, or pathophysiology
        4. relationship: Relationships between symptoms, conditions, and underlying mechanisms

        EXTRACTION RULES:
        - Each fact should be self-contained and independently verifiable
        - Focus on specific medical claims, not general reasoning steps
        - Exclude subjective statements or reasoning connections
        - Each fact should be 1-2 sentences maximum

        INPUT TEXT: {reasoning_text}

        You must respond with valid JSON only, wrapped strictly between START_JSON and END_JSON markers. Example format:
        START_JSON 
        {{"facts": [{{"text": "Ampicillin is effective against E. coli infections","category": "medication","source_sentence": "Ampicillin is a first-line antibiotic for UTIs"}}]}}
        END_JSON

        EXTRACTED FACTS (JSON only):"""

    def extract_facts(self, reasoning_text):
        """Return a list of ``AtomicFact`` for ``reasoning_text``; [] if any step raises."""
        try:
            cleaned_text = self._clean_reasoning_text(reasoning_text)
            prompt = self.extraction_prompt.format(reasoning_text=cleaned_text)
            response = self._call_extraction_llm(prompt)
            facts = self._parse_extraction_response(response)
            return facts

        except Exception as e:
            print(f"Error in fact extraction: {e}")
            return []

    def _clean_reasoning_text(self, text):
        """Keep only the first <think>...</think> body (if any), drop <answer> blocks,
        and collapse whitespace."""
        print(f"Original text: '{text}'")

        think_match = re.search(r'<think>(.*?)</think>', text, flags=re.DOTALL)
        if think_match:
            text = think_match.group(1)
            print(f"Extracted from <think> tags: '{text}'")
        else:
            print("No <think> tags found, using original text")

        text = re.sub(r'<answer>.*?</answer>', '', text, flags=re.DOTALL)
        text = ' '.join(text.split())

        print(f"Final cleaned text: '{text}'")
        return text.strip()

    def _call_extraction_llm(self, prompt):
        """Send ``prompt`` to the chat client and return the reply text ("" on failure)."""
        full_prompt = f"""You are a medical fact extraction specialist.\n\n{prompt}"""

        try:
            response = self.client.chat_completion(
                messages=[{"role": "user", "content": full_prompt}],
                temperature=0.1,
                max_tokens=1000
            )

            content = response['choices'][0]['message']['content']
            return content

        except Exception as e:
            print(f"API call failed: {e}")
            return ""

    def _fallback_extraction(self, text):
        """Heuristic extraction: up to 5 sentences longer than 15 chars containing a medical keyword."""
        facts = []

        sentences = re.split(r'[.!?]+', text)
        medical_keywords = [
            'patient', 'diagnosis', 'treatment', 'symptom', 'medication',
            'therapy', 'condition', 'disease', 'clinical', 'medical'
        ]

        for sentence in sentences:
            sentence = sentence.strip()
            if len(sentence) > 15:
                if any(keyword in sentence.lower() for keyword in medical_keywords):
                    fact = AtomicFact(
                        text=sentence,
                        category="medical_statement",
                        source_sentence=sentence
                    )
                    facts.append(fact)

        return facts[:5]

    def _parse_extraction_response(self, response):
        """Parse the LLM reply into ``AtomicFact`` objects.

        The JSON is located in priority order: between START_JSON/END_JSON markers, then
        a ``{... "facts": [...] ...}`` object, then the first ``{...}`` span. It is closed
        up with ``_fix_incomplete_json``, and ``_repair_json`` is tried if decoding still
        fails. Falls back to ``_fallback_extraction`` whenever no facts can be recovered.
        """
        facts = []

        try:
            print(f"Parsing response: '{response}'")

            cleaned_response = response.strip()

            # Strip markdown code fences if present.
            cleaned_response = re.sub(r'```json\s*', '', cleaned_response)
            cleaned_response = re.sub(r'```\s*$', '', cleaned_response)

            json_str = None

            start_end_match = re.search(r'START_JSON\s*(.+?)\s*END_JSON', cleaned_response, re.DOTALL)
            if start_end_match:
                json_str = start_end_match.group(1).strip()
                print(f"Found JSON between markers: '{json_str}'")
            else:
                json_match = re.search(r'\{[^{}]*"facts"[^{}]*\[[^\]]*\][^{}]*\}', cleaned_response, re.DOTALL)
                if json_match:
                    json_str = json_match.group(0)
                    print(f"Found JSON with facts pattern: '{json_str}'")
                else:
                    json_match = re.search(r'\{.*?\}', cleaned_response, re.DOTALL)
                    if json_match:
                        json_str = json_match.group(0)
                        print(f"Found JSON with fallback pattern: '{json_str}'")

            if not json_str:
                print("No JSON pattern found, using fallback extraction")
                return self._fallback_extraction(response)

            # Close truncated strings/arrays/objects and drop trailing commas.
            json_str = self._fix_incomplete_json(json_str)
            print(f"After fixing JSON: '{json_str}'")

            try:
                data = json.loads(json_str)
                print(f"Successfully parsed JSON: {data}")
            except json.JSONDecodeError as e:
                print(f"JSON decode failed: {e}")
                print("Attempting repair...")

                repaired_json = self._repair_json(json_str)
                try:
                    data = json.loads(repaired_json)
                    print("Successfully repaired and parsed JSON")
                except Exception as repair_e:
                    print(f"JSON repair also failed: {repair_e}")
                    return self._fallback_extraction(response)

            if isinstance(data, dict) and "facts" in data:
                facts_list = data["facts"]
                if isinstance(facts_list, list):
                    print(f"Found {len(facts_list)} facts in data")

                    for i, fact_data in enumerate(facts_list):
                        if isinstance(fact_data, dict):
                            text = fact_data.get("text", "").strip()
                            if text:
                                fact = AtomicFact(
                                    text=text,
                                    category=fact_data.get("category", "unknown"),
                                    source_sentence=fact_data.get("source_sentence", "")
                                )
                                facts.append(fact)
                                print(f"Created fact {i+1}: {fact.text}")
                            else:
                                print(f"Skipping fact {i+1}: no text content")
                else:
                    print(f"'facts' is not a list: {type(facts_list)}")
            else:
                print(f"No 'facts' key found. Available keys: {list(data.keys()) if isinstance(data, dict) else 'not a dict'}")

            if not facts:
                print("No valid facts extracted, using fallback")
                return self._fallback_extraction(response)

            return facts

        except Exception as e:
            print(f"Error in _parse_extraction_response: {e}")
            import traceback
            traceback.print_exc()
            return self._fallback_extraction(response)

    def _fix_incomplete_json(self, json_str):
        """Close a truncated JSON string/array/object and remove trailing commas.

        Brackets inside string literals are ignored. Unclosed containers are closed in
        reverse order of opening, so ``{"a": [{"b": 1`` becomes ``{"a": [{"b": 1}]}``.
        """
        closers = []
        in_string = False
        escaped = False
        for ch in json_str:
            if in_string:
                if escaped:
                    escaped = False
                elif ch == '\\':
                    escaped = True
                elif ch == '"':
                    in_string = False
            elif ch == '"':
                in_string = True
            elif ch == '{':
                closers.append('}')
            elif ch == '[':
                closers.append(']')
            elif ch in '}]' and closers and closers[-1] == ch:
                closers.pop()

        if in_string:
            json_str += '"'
        json_str += ''.join(reversed(closers))

        # Remove trailing commas before a closing brace/bracket.
        json_str = re.sub(r',\s*}', '}', json_str)
        json_str = re.sub(r',\s*]', ']', json_str)

        return json_str

    def _repair_json(self, json_str):
        """Best-effort repair for multi-line JSON.

        Drops lines containing ``"text":`` with an odd number of quotes (likely cut off),
        and appends ``]}`` if the result does not end with ``}``.
        """
        lines = json_str.split('\n')
        clean_lines = []

        for line in lines:
            if '"text":' in line and line.count('"') % 2 != 0:
                continue
            clean_lines.append(line)

        repaired = '\n'.join(clean_lines)

        if not repaired.strip().endswith('}'):
            repaired = repaired.rstrip() + '\n    ]\n}'

        return repaired

## LLM Judge

In [1]:
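# Define your own LLM client here. The classes below construct
# `ClaudeBedrockClient(model_name, region_name)` and call
# `chat_completion(messages=..., temperature=..., max_tokens=...)`, reading the reply from
# `response['choices'][0]['message']['content']`.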

In [None]:
class LLMJudge:
    """Score the accuracy of individual medical facts with an LLM.

    ``judge_fact_accuracy`` returns a dict whose ``confidence_score`` is always a float
    clamped to [0.0, 1.0], together with ``reasoning``, ``model_used`` and diagnostic
    fields (``raw_response``, ``parsing_error`` or ``error``).
    """

    def __init__(self, model_name, region_name):
        """Create the chat client for ``model_name`` in ``region_name``."""
        self.claude_client = ClaudeBedrockClient(model_name, region_name)
        self.model_name = model_name

    @staticmethod
    def _clamp_score(value, default=0.5):
        """Convert ``value`` to a float in [0.0, 1.0]; return ``default`` if it is not numeric."""
        try:
            score = float(value)
        except (TypeError, ValueError):
            return default
        if score != score:  # NaN
            return default
        return min(1.0, max(0.0, score))

    def judge_fact_accuracy(self, fact, context="", temperature=0.1):
        """Ask the LLM to rate ``fact`` and return its parsed assessment.

        Args:
            fact: Fact text (interpolated into the prompt).
            context: Optional extra context for the judge.
            temperature: Sampling temperature for the judge call.

        Returns:
            dict with ``confidence_score`` in [0, 1]. Defaults to 0.5 when the reply has
            no usable score, and to 0.0 when the API call itself fails.
        """
        full_prompt = f"""You are a medical expert evaluating the accuracy of medical facts. 
        Provide a confidence score between 0.0 and 1.0 where:
        - 1.0 = Definitely accurate and well-established medical fact
        - 0.8-0.9 = Very likely accurate with strong evidence
        - 0.6-0.7 = Probably accurate but may have exceptions
        - 0.4-0.5 = Uncertain or conflicting evidence
        - 0.2-0.3 = Probably inaccurate
        - 0.0-0.1 = Definitely inaccurate or harmful
        Please evaluate this medical fact:

        Fact: {fact}

        Context: {context if context else "No additional context provided"}

        Respond with just a JSON object containing:
        {{
            "confidence_score": <float between 0.0 and 1.0>,
            "reasoning": "<brief explanation>",
            "medical_category": "<relevant medical specialty>"
        }}

        Provide your assessment as JSON only."""

        messages = [{"role": "user", "content": full_prompt}]

        try:
            response = self.claude_client.chat_completion(
                messages=messages,
                temperature=temperature,
                max_tokens=500
            )

            content = response['choices'][0]['message']['content']

            try:
                json_match = re.search(r'\{.*\}', content, re.DOTALL)
                if not json_match:
                    raise ValueError("No JSON found in response")

                result = json.loads(json_match.group())
                # The judge may return the score as a string or outside [0, 1].
                result['confidence_score'] = self._clamp_score(result.get('confidence_score', 0.5))
                result['model_used'] = self.model_name
                result['raw_response'] = content
                return result

            except (json.JSONDecodeError, ValueError) as e:
                print(f"Failed to parse Claude response as JSON: {e}")
                print(f"Raw response: {content}")

                # Prefer an explicit "confidence_score: <n>"; only then fall back to the first
                # number in the text (which may be e.g. a step number).
                score_match = (
                    re.search(r'confidence_score"?\s*[:=]\s*(\d+(?:\.\d+)?)', content)
                    or re.search(r'(\d+\.?\d*)', content)
                )
                score = float(score_match.group(1)) if score_match else 0.5
                if score > 1.0:
                    score = score / 10.0

                return {
                    'confidence_score': min(1.0, max(0.0, score)),
                    'reasoning': content[:200] + "..." if len(content) > 200 else content,
                    'medical_category': 'unknown',
                    'model_used': self.model_name,
                    'raw_response': content,
                    'parsing_error': str(e)
                }

        except Exception as e:
            print(f"Error in Claude fact judgment: {e}")
            return {
                'confidence_score': 0.0,
                'reasoning': f'Error occurred: {str(e)}',
                'medical_category': 'error',
                'model_used': self.model_name,
                'error': str(e)
            }

    def batch_judge_facts(self, facts, context=""):
        """Judge each fact string in ``facts`` sequentially; returns results in input order."""
        results = []

        for i, fact in enumerate(facts):
            print(f"Judging fact {i + 1}/{len(facts)}: {fact[:100]}...")
            result = self.judge_fact_accuracy(fact, context)
            results.append(result)

        return results

In [None]:
class LLMJudgeVerifier:
    """Score atomic medical facts with an LLM judge served through ``ClaudeBedrockClient``.

    Args:
        model_name: Model identifier forwarded to ``ClaudeBedrockClient``.
        region_name: AWS region forwarded to ``ClaudeBedrockClient``.
    """

    def __init__(self, model_name, region_name):
        self.model_name = model_name
        self.verification_prompt = self._build_verification_prompt()
        self.client = ClaudeBedrockClient(model_name, region_name)

    def _build_verification_prompt(self):
        """Return the judge prompt template.

        The template contains ``{fact_text}`` and ``{context}`` placeholders for
        ``str.format``; the literal JSON braces are doubled (``{{``/``}}``) so they
        survive formatting.
        """
        return """You are a medical fact verification specialist. Evaluate the clinical accuracy of the given atomic fact.

        VERIFICATION CRITERIA:
        1. Medical Accuracy: Is the fact clinically correct?
        2. Specificity: Is the claim specific enough to be verifiable?
        3. Context Appropriateness: Does the fact make sense in the clinical context?
        4. Evidence Base: Is this supported by established medical knowledge?

        FACT TO VERIFY: {fact_text}
        MEDICAL CONTEXT: {context}

        You MUST respond with valid JSON in exactly this format (no extra text):
        {{
            "confidence_score": 0.85,
            "is_accurate": true,
            "reasoning": "detailed explanation of your assessment",
            "concerns": ["any specific concerns or caveats"]
        }}

        CRITICAL: The confidence_score must be a decimal number between 0.0 and 1.0.

        ASSESSMENT:"""

    def verify_fact(self, fact, medical_context=""):
        """Verify a single atomic fact with the LLM judge.

        Args:
            fact: Object exposing ``.text``; its ``llm_score`` attribute is set to
                the returned score.
            medical_context: Free-text context inserted into the prompt.

        Returns:
            float: Confidence in [0, 1]; 0.0 if prompting, calling or parsing raised.
        """
        try:
            prompt = self.verification_prompt.format(
                fact_text=fact.text,
                context=medical_context
            )
            response = self._call_verification_llm(prompt)
            confidence = self._parse_verification_response(response)
            fact.llm_score = confidence
            return confidence

        except Exception as e:
            print(f"Error in LLM verification: {e}")
            # Record the failure score as well, so callers that read fact.llm_score
            # instead of the return value do not see a stale/default value.
            try:
                fact.llm_score = 0.0
            except AttributeError:
                pass  # e.g. a plain string was passed; nowhere to record the score
            return 0.0

    def _call_verification_llm(self, prompt):
        """Send ``prompt`` to the judge model and return the text of the first choice.

        The client response is indexed as ``['choices'][0]['message']['content']``.
        """
        full_prompt = f"""You are a medical fact verification specialist.

    {prompt}"""
        response = self.client.chat_completion(
            messages=[{"role": "user", "content": full_prompt}],
            temperature=0.1,
            max_tokens=500
        )
        return response['choices'][0]['message']['content']

    def _parse_verification_response(self, response):
        """Extract a confidence score in [0, 1] from a judge response.

        Tried in order:
        1. A JSON object containing ``confidence_score`` (markdown fences and
           trailing commas tolerated); the value is clamped to [0, 1].
        2. ``key: number`` patterns such as ``confidence score: 0.8``; clamped.
        3. The first number in the text, rescaled from a 0-5 or 0-10 scale.

        Returns:
            float: 0.0 for an empty response or an unexpected error, 0.5 when no
            score can be found.
        """
        # A decimal number such as "0.85", "7" or ".5". Unlike ``[0-9.]+`` this never
        # swallows a sentence-ending period ("0.85." -> "0.85"), which previously made
        # float() fail and the whole response fall through to the 0.5 default.
        number = r'(\d+(?:\.\d+)?|\.\d+)'
        try:
            if not response or response.strip() == "":
                return 0.0

            cleaned = response.strip()
            cleaned = re.sub(r'```json\s*', '', cleaned)
            cleaned = re.sub(r'```\s*', '', cleaned)

            json_patterns = [
                r'\{[^{}]*"confidence_score"[^{}]*:[^{}]*[0-9.]+[^{}]*\}',
                r'\{.*?"confidence_score".*?\}',
                r'\{.*\}',
            ]
            for pattern in json_patterns:
                matches = re.findall(pattern, cleaned, re.DOTALL | re.IGNORECASE)
                for match in matches:
                    try:
                        json_str = match.strip()
                        # Drop trailing commas, a common LLM JSON mistake.
                        json_str = re.sub(r',\s*}', '}', json_str)
                        json_str = re.sub(r',\s*]', ']', json_str)

                        data = json.loads(json_str)
                        score = data.get("confidence_score")
                        if score is not None:
                            return max(0.0, min(1.0, float(score)))

                    except (json.JSONDecodeError, ValueError, TypeError, AttributeError):
                        continue

            score_patterns = [
                r'"confidence_score"[:\s]*' + number,
                r'confidence[_\s]*score[:\s]*' + number,
                r'score[:\s]*' + number,
                r'confidence[:\s]*' + number,
            ]
            for pattern in score_patterns:
                match = re.search(pattern, cleaned, re.IGNORECASE)
                if match:
                    return max(0.0, min(1.0, float(match.group(1))))

            number_match = re.search(number, cleaned)
            if number_match:
                score = float(number_match.group(1))
                # Bare numbers may be on a 0-1, 0-5 or 0-10 scale; map them to [0, 1].
                if 0.0 <= score <= 1.0:
                    return score
                elif 0.0 <= score <= 5.0:
                    return score / 5.0
                elif 0.0 <= score <= 10.0:
                    return score / 10.0

            print(f"Could not parse verification response: {response[:200]}")
            return 0.5

        except Exception as e:
            print(f"Error parsing verification response: {e}")
            return 0.0

## KB Verify

### Local KB

In [None]:
class LocalMedicalKnowledgeBase:
    """Small in-memory medical fact base searched by sentence-embedding similarity.

    Facts come from ``_create_medical_facts``. Their embeddings are computed with a
    SentenceTransformer and pickled to ``cache_dir``. The cache is reused only when
    it was built from the same fact list with the same embedding model.

    Args:
        embedding_model: SentenceTransformer model name used for facts and queries.
        cache_dir: Directory holding the embedding cache (created if missing).
    """

    def __init__(self, 
                 embedding_model= 'all-MiniLM-L6-v2',
                 cache_dir= "local_kb_cache"):

        self.embedding_model_name = embedding_model
        self.embedding_model = SentenceTransformer(embedding_model)
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        self.knowledge_base = []
        self.knowledge_embeddings = None

        self._initialize_medical_knowledge()

    def _initialize_medical_knowledge(self):
        """Populate ``knowledge_base``/``knowledge_embeddings``, reusing the disk cache when valid.

        Editing ``_create_medical_facts`` or switching ``embedding_model`` invalidates
        the cache, so stale or wrongly-sized vectors are never loaded.
        """
        cache_file = os.path.join(self.cache_dir, "local_medical_knowledge.pkl")
        facts = self._create_medical_facts()

        cached_embeddings = self._load_cached_embeddings(cache_file, facts)
        if cached_embeddings is not None:
            self.knowledge_base = facts
            self.knowledge_embeddings = cached_embeddings
            print(f"Loaded {len(self.knowledge_base)} medical facts from cache")
            return

        print("Building local medical knowledge base...")
        self.knowledge_base = facts

        fact_texts = [fact['text'] for fact in self.knowledge_base]
        print("Generating embeddings for medical facts...")
        self.knowledge_embeddings = self.embedding_model.encode(fact_texts)

        try:
            with open(cache_file, 'wb') as f:
                pickle.dump({
                    'facts': self.knowledge_base,
                    'embeddings': self.knowledge_embeddings,
                    'embedding_model': self.embedding_model_name,
                }, f)
            print(f"Cached {len(self.knowledge_base)} medical facts")
        except Exception as e:
            print(f"Failed to cache knowledge base: {e}")

    def _load_cached_embeddings(self, cache_file, facts):
        """Return cached embeddings matching ``facts``, or None if the cache is missing, unreadable or stale."""
        if not os.path.exists(cache_file):
            return None
        try:
            # SECURITY: pickle.load can execute arbitrary code. Only point cache_dir
            # at a directory that this code alone writes to.
            with open(cache_file, 'rb') as f:
                cached_data = pickle.load(f)
            embeddings = cached_data['embeddings']
            if (cached_data.get('embedding_model') != self.embedding_model_name
                    or cached_data['facts'] != facts
                    or len(embeddings) != len(facts)):
                print("Knowledge base cache is stale; rebuilding")
                return None
            return embeddings
        except Exception as e:
            print(f"Failed to load cache: {e}")
            return None

    def _create_medical_facts(self) -> List[Dict]:
        """Return the built-in medical facts (text, category, source, confidence)."""
        return [
            # CARDIOLOGY
            {"text": "Myocardial infarction is caused by coronary artery occlusion leading to cardiac muscle death", "category": "cardiology", "source": "medical_textbook", "confidence": 0.95},
            {"text": "Chest pain, shortness of breath, and diaphoresis are classic symptoms of myocardial infarction", "category": "cardiology", "source": "clinical_guidelines", "confidence": 0.92},
            {"text": "Hypertension is defined as systolic blood pressure ≥140 mmHg or diastolic ≥90 mmHg", "category": "cardiology", "source": "AHA_guidelines", "confidence": 0.96},
            {"text": "ACE inhibitors are first-line treatment for hypertension and heart failure", "category": "cardiology", "source": "treatment_guidelines", "confidence": 0.90},
            {"text": "Atrial fibrillation significantly increases the risk of stroke", "category": "cardiology", "source": "clinical_studies", "confidence": 0.94},
            {"text": "Beta-blockers reduce heart rate and myocardial oxygen demand", "category": "cardiology", "source": "pharmacology", "confidence": 0.93},
            {"text": "Electrocardiography shows ST-elevation in acute myocardial infarction", "category": "cardiology", "source": "diagnostic_criteria", "confidence": 0.91},
            {"text": "Cardiac catheterization is the gold standard for diagnosing coronary artery disease", "category": "cardiology", "source": "diagnostic_procedures", "confidence": 0.89},

            # ENDOCRINOLOGY
            {"text": "Type 1 diabetes mellitus is caused by autoimmune destruction of pancreatic beta cells", "category": "endocrinology", "source": "pathophysiology", "confidence": 0.96},
            {"text": "Type 2 diabetes mellitus is characterized by insulin resistance and relative insulin deficiency", "category": "endocrinology", "source": "pathophysiology", "confidence": 0.95},
            {"text": "Metformin is the first-line treatment for type 2 diabetes mellitus", "category": "endocrinology", "source": "ADA_guidelines", "confidence": 0.94},
            {"text": "HbA1c ≥6.5% indicates diabetes mellitus diagnosis", "category": "endocrinology", "source": "diagnostic_criteria", "confidence": 0.96},
            {"text": "Insulin is absolutely required for type 1 diabetes management", "category": "endocrinology", "source": "treatment_standards", "confidence": 0.98},
            {"text": "Diabetic ketoacidosis is a life-threatening complication of diabetes", "category": "endocrinology", "source": "emergency_medicine", "confidence": 0.93},
            {"text": "Hypoglycemia symptoms include diaphoresis, tremor, and altered mental status", "category": "endocrinology", "source": "clinical_presentation", "confidence": 0.91},
            {"text": "Thyroid stimulating hormone (TSH) is elevated in hypothyroidism", "category": "endocrinology", "source": "laboratory_medicine", "confidence": 0.94},

            # PULMONOLOGY
            {"text": "Pneumonia causes consolidation visible on chest radiograph", "category": "pulmonology", "source": "radiology", "confidence": 0.90},
            {"text": "Asthma is characterized by reversible airway obstruction and inflammation", "category": "pulmonology", "source": "pathophysiology", "confidence": 0.93},
            {"text": "Albuterol is a short-acting beta-2 agonist bronchodilator", "category": "pulmonology", "source": "pharmacology", "confidence": 0.96},
            {"text": "Chronic obstructive pulmonary disease (COPD) is primarily caused by tobacco smoking", "category": "pulmonology", "source": "epidemiology", "confidence": 0.91},
            {"text": "Pulmonary embolism can cause sudden onset dyspnea and chest pain", "category": "pulmonology", "source": "clinical_presentation", "confidence": 0.89},
            {"text": "Spirometry shows reduced FEV1/FVC ratio in obstructive lung disease", "category": "pulmonology", "source": "pulmonary_function", "confidence": 0.92},
            {"text": "Oxygen saturation below 90% indicates significant hypoxemia", "category": "pulmonology", "source": "critical_care", "confidence": 0.88},

            # INFECTIOUS DISEASE
            {"text": "Sepsis is a life-threatening organ dysfunction caused by dysregulated host response to infection", "category": "infectious_disease", "source": "sepsis_guidelines", "confidence": 0.94},
            {"text": "Penicillin is effective against gram-positive bacterial infections", "category": "infectious_disease", "source": "microbiology", "confidence": 0.92},
            {"text": "Viral infections do not respond to antibiotic treatment", "category": "infectious_disease", "source": "antimicrobial_stewardship", "confidence": 0.97},
            {"text": "Urinary tract infections commonly present with dysuria and urinary frequency", "category": "infectious_disease", "source": "clinical_presentation", "confidence": 0.88},
            {"text": "Blood cultures should be obtained before starting empiric antibiotic therapy", "category": "infectious_disease", "source": "diagnostic_guidelines", "confidence": 0.85},
            {"text": "Methicillin-resistant Staphylococcus aureus (MRSA) requires vancomycin treatment", "category": "infectious_disease", "source": "antimicrobial_guidelines", "confidence": 0.91},

            # NEUROLOGY
            {"text": "Stroke symptoms include sudden onset focal neurological deficits", "category": "neurology", "source": "clinical_criteria", "confidence": 0.93},
            {"text": "Tissue plasminogen activator (tPA) is used for acute ischemic stroke within 4.5 hours", "category": "neurology", "source": "stroke_guidelines", "confidence": 0.90},
            {"text": "Seizures can be focal or generalized based on their origin and spread", "category": "neurology", "source": "epilepsy_classification", "confidence": 0.94},
            {"text": "Computed tomography (CT) can rapidly detect hemorrhagic stroke", "category": "neurology", "source": "neuroimaging", "confidence": 0.87},
            {"text": "Lumbar puncture is contraindicated with increased intracranial pressure", "category": "neurology", "source": "procedural_guidelines", "confidence": 0.89},
            {"text": "Multiple sclerosis causes demyelinating lesions in the central nervous system", "category": "neurology", "source": "pathophysiology", "confidence": 0.92},

            # PSYCHIATRY
            {"text": "Major depressive disorder requires at least 2 weeks of depressive symptoms", "category": "psychiatry", "source": "DSM5", "confidence": 0.96},
            {"text": "Selective serotonin reuptake inhibitors (SSRIs) are first-line treatment for depression", "category": "psychiatry", "source": "treatment_guidelines", "confidence": 0.88},
            {"text": "Bipolar disorder includes both manic and depressive episodes", "category": "psychiatry", "source": "DSM5", "confidence": 0.95},
            {"text": "Suicidal ideation requires immediate safety assessment and intervention", "category": "psychiatry", "source": "crisis_intervention", "confidence": 0.97},
            {"text": "Antipsychotic medications are used to treat schizophrenia and psychotic disorders", "category": "psychiatry", "source": "psychopharmacology", "confidence": 0.91},

            # PHARMACOLOGY
            {"text": "Warfarin requires regular INR monitoring due to narrow therapeutic window", "category": "pharmacology", "source": "anticoagulation_guidelines", "confidence": 0.93},
            {"text": "Nonsteroidal anti-inflammatory drugs (NSAIDs) can cause gastric ulceration", "category": "pharmacology", "source": "adverse_effects", "confidence": 0.90},
            {"text": "Statins reduce cholesterol synthesis by inhibiting HMG-CoA reductase", "category": "pharmacology", "source": "mechanism_of_action", "confidence": 0.94},
            {"text": "Opioids can cause respiratory depression at high doses", "category": "pharmacology", "source": "toxicology", "confidence": 0.91},
            {"text": "Drug interactions can alter medication effectiveness and safety", "category": "pharmacology", "source": "clinical_pharmacology", "confidence": 0.87},

            # PEDIATRICS
            {"text": "Sudden infant death syndrome risk is reduced by supine sleeping position", "category": "pediatrics", "source": "AAP_guidelines", "confidence": 0.92},
            {"text": "Febrile seizures are common in children aged 6 months to 5 years", "category": "pediatrics", "source": "pediatric_neurology", "confidence": 0.89},
            {"text": "Vaccination schedules protect children from preventable diseases", "category": "pediatrics", "source": "immunization_guidelines", "confidence": 0.95},
            {"text": "Growth charts assess normal pediatric development", "category": "pediatrics", "source": "developmental_medicine", "confidence": 0.86},

            # SURGERY
            {"text": "Appendicitis typically presents with right lower quadrant abdominal pain", "category": "surgery", "source": "surgical_diagnosis", "confidence": 0.90},
            {"text": "Cholecystitis causes right upper quadrant pain and Murphy's sign", "category": "surgery", "source": "surgical_examination", "confidence": 0.88},
            {"text": "Surgical site infections are prevented by proper sterile technique", "category": "surgery", "source": "infection_control", "confidence": 0.91},
            {"text": "Bowel obstruction can cause abdominal distension and vomiting", "category": "surgery", "source": "surgical_emergencies", "confidence": 0.87},

            # OBSTETRICS/GYNECOLOGY
            {"text": "Preeclampsia is defined by hypertension and proteinuria in pregnancy", "category": "obstetrics", "source": "ACOG_guidelines", "confidence": 0.94},
            {"text": "Folic acid supplementation prevents neural tube defects", "category": "obstetrics", "source": "preventive_medicine", "confidence": 0.92},
            {"text": "Regular prenatal care improves maternal and fetal outcomes", "category": "obstetrics", "source": "prenatal_guidelines", "confidence": 0.89},
            {"text": "Gestational diabetes increases risk of macrosomia", "category": "obstetrics", "source": "maternal_fetal_medicine", "confidence": 0.86},

            # EMERGENCY MEDICINE
            {"text": "Advanced Cardiac Life Support (ACLS) protocols guide cardiac arrest management", "category": "emergency_medicine", "source": "AHA_guidelines", "confidence": 0.95},
            {"text": "Trauma patients require primary and secondary survey assessment", "category": "emergency_medicine", "source": "ATLS_guidelines", "confidence": 0.93},
            {"text": "Anaphylaxis is treated with intramuscular epinephrine", "category": "emergency_medicine", "source": "allergy_guidelines", "confidence": 0.96},
            {"text": "Glasgow Coma Scale assesses level of consciousness", "category": "emergency_medicine", "source": "neurological_assessment", "confidence": 0.91},

            # NEPHROLOGY
            {"text": "Chronic kidney disease is staged based on glomerular filtration rate", "category": "nephrology", "source": "KDIGO_guidelines", "confidence": 0.93},
            {"text": "Acute kidney injury can be prerenal, intrinsic, or postrenal", "category": "nephrology", "source": "nephrology_classification", "confidence": 0.90},
            {"text": "Dialysis is indicated for severe uremia or fluid overload", "category": "nephrology", "source": "renal_replacement_therapy", "confidence": 0.88},
            {"text": "Proteinuria indicates glomerular kidney disease", "category": "nephrology", "source": "laboratory_findings", "confidence": 0.85},

            # GASTROENTEROLOGY
            {"text": "Gastroesophageal reflux disease (GERD) causes heartburn and regurgitation", "category": "gastroenterology", "source": "clinical_presentation", "confidence": 0.87},
            {"text": "Peptic ulcer disease is commonly caused by Helicobacter pylori", "category": "gastroenterology", "source": "etiology", "confidence": 0.91},
            {"text": "Inflammatory bowel disease includes Crohn's disease and ulcerative colitis", "category": "gastroenterology", "source": "disease_classification", "confidence": 0.93},
            {"text": "Cirrhosis can lead to portal hypertension and ascites", "category": "gastroenterology", "source": "hepatology", "confidence": 0.89},

            # ONCOLOGY
            {"text": "Cancer staging determines prognosis and treatment approach", "category": "oncology", "source": "cancer_guidelines", "confidence": 0.92},
            {"text": "Chemotherapy targets rapidly dividing cancer cells", "category": "oncology", "source": "cancer_treatment", "confidence": 0.90},
            {"text": "Tumor markers can aid in cancer diagnosis and monitoring", "category": "oncology", "source": "laboratory_oncology", "confidence": 0.84},
            {"text": "Radiation therapy delivers targeted energy to destroy cancer cells", "category": "oncology", "source": "radiation_oncology", "confidence": 0.88},

            # DERMATOLOGY
            {"text": "Melanoma is the most dangerous form of skin cancer", "category": "dermatology", "source": "dermatopathology", "confidence": 0.91},
            {"text": "Topical corticosteroids treat inflammatory skin conditions", "category": "dermatology", "source": "dermatologic_therapeutics", "confidence": 0.86},
            {"text": "Skin biopsy provides definitive diagnosis of skin lesions", "category": "dermatology", "source": "diagnostic_procedures", "confidence": 0.89},

            # RHEUMATOLOGY
            {"text": "Rheumatoid arthritis is an autoimmune inflammatory joint disease", "category": "rheumatology", "source": "autoimmune_diseases", "confidence": 0.93},
            {"text": "Disease-modifying antirheumatic drugs (DMARDs) slow joint destruction", "category": "rheumatology", "source": "treatment_guidelines", "confidence": 0.90},
            {"text": "Systemic lupus erythematosus affects multiple organ systems", "category": "rheumatology", "source": "connective_tissue_disorders", "confidence": 0.88},
        ]

    def _has_embeddings(self):
        """True when there are facts and a non-empty embedding matrix to search."""
        return (bool(self.knowledge_base)
                and self.knowledge_embeddings is not None
                and len(self.knowledge_embeddings) > 0)

    def _similarities(self, text):
        """Cosine similarity between ``text`` and every stored fact (1-D array, one entry per fact)."""
        query = np.asarray(self.embedding_model.encode([text]), dtype=float)
        kb = np.asarray(self.knowledge_embeddings, dtype=float)
        # Normalise explicitly: a raw dot product equals cosine similarity only for
        # unit-length vectors, and not every SentenceTransformer model outputs those.
        query = query / np.maximum(np.linalg.norm(query, axis=1, keepdims=True), 1e-12)
        kb = kb / np.maximum(np.linalg.norm(kb, axis=1, keepdims=True), 1e-12)
        return (query @ kb.T)[0]

    def verify_fact(self, fact, threshold: float = 0.7) -> float:
        """Score a fact by its similarity to the closest knowledge-base entries.

        For the top-3 most similar entries, similarity is multiplied by that entry's
        ``confidence`` (0.8 if absent) and the maximum is kept. If even the best
        similarity is below ``threshold`` the score is halved. The result is clamped
        to [0, 1].

        Args:
            fact: Object with ``.text`` (e.g. an AtomicFact) or anything ``str()``-able.
                If it has a ``kb_score`` attribute, that attribute is set to the result.
            threshold: Minimum best-match similarity that avoids the 0.5x penalty.

        Returns:
            float: Score in [0, 1]. 0.5 (neutral) if the KB is empty, 0.0 on error.
        """
        score = self._score_fact(fact, threshold)
        # Record on every outcome (neutral/error included) so fact.kb_score always
        # equals the returned, clamped value.
        if hasattr(fact, 'kb_score'):
            fact.kb_score = score
        return score

    def _score_fact(self, fact, threshold):
        """Compute the clamped KB score for ``fact`` (see ``verify_fact``)."""
        if not self._has_embeddings():
            return 0.5  # Neutral score if no knowledge base

        try:
            fact_text = fact.text if hasattr(fact, 'text') else str(fact)
            similarities = self._similarities(fact_text)

            # Top-3 matches, most similar first.
            best_matches_idx = np.argsort(similarities)[::-1][:3]
            best_similarities = similarities[best_matches_idx]

            # Weight each match's similarity by how much we trust that KB entry.
            confidence_scores = [
                similarity * self.knowledge_base[idx].get('confidence', 0.8)
                for idx, similarity in zip(best_matches_idx, best_similarities)
            ]
            final_confidence = max(confidence_scores)

            if best_similarities.max() < threshold:
                final_confidence *= 0.5  # Penalize low similarity

            # Cosine similarity can be negative; clamp before the value is stored or returned.
            return float(min(1.0, max(0.0, final_confidence)))

        except Exception as e:
            print(f"Error in local knowledge base verification: {e}")
            return 0.0

    def search_facts(self, query: str, limit = 10) -> List[Dict]:
        """Return up to ``limit`` facts most similar to ``query``, best first.

        Each result is a copy of the fact dict plus a ``similarity_score`` float.
        ``limit <= 0`` returns ``[]``; ``limit=None`` returns every fact ranked.
        """
        if limit is not None and limit <= 0:
            return []
        if not self._has_embeddings():
            return []

        try:
            similarities = self._similarities(query)

            # Rank best-first, then truncate. Slicing from the end with [-limit:]
            # would return everything when limit == 0.
            ranked = np.argsort(similarities)[::-1]
            if limit is not None:
                ranked = ranked[:limit]

            results = []
            for idx in ranked:
                fact = self.knowledge_base[idx].copy()
                fact['similarity_score'] = float(similarities[idx])
                results.append(fact)

            return results

        except Exception as e:
            print(f"Error searching local knowledge base: {e}")
            return []

    def get_knowledge_stats(self):
        """Return fact counts per category and per source, plus whether embeddings exist."""
        if not self.knowledge_base:
            return {"total_facts": 0, "categories": {}, "sources": {}, "has_embeddings": False}

        categories = {}
        sources = {}

        for fact in self.knowledge_base:
            cat = fact.get('category', 'unknown')
            src = fact.get('source', 'unknown')

            categories[cat] = categories.get(cat, 0) + 1
            sources[src] = sources.get(src, 0) + 1

        return {
            "total_facts": len(self.knowledge_base),
            "categories": categories,
            "sources": sources,
            "has_embeddings": self.knowledge_embeddings is not None and len(self.knowledge_embeddings) > 0
        }

    def add_facts(self, new_facts: List[Dict]):
        """Append facts (dicts with at least a ``text`` key) and re-embed the whole KB.

        Added facts live in memory only; the on-disk cache is not updated.
        """
        self.knowledge_base.extend(new_facts)

        # Regenerate embeddings
        fact_texts = [fact['text'] for fact in self.knowledge_base]
        print(f"Regenerating embeddings for {len(fact_texts)} facts...")
        self.knowledge_embeddings = self.embedding_model.encode(fact_texts)

        print(f"Added {len(new_facts)} new facts to knowledge base")


class LocalKnowledgeBaseVerifier:
    """Adapter exposing ``LocalMedicalKnowledgeBase`` through the verifier interface.

    Drop-in replacement using local knowledge base: ``verify_fact`` and
    ``search_facts`` delegate straight to the wrapped knowledge base.

    Args:
        knowledge_sources: Labels describing where facts come from. Falls back to
            ``["Local Medical DB"]`` when empty or None; stored but not used for lookups.
    """

    def __init__(self, knowledge_sources: List[str] = None):
        if knowledge_sources:
            self.knowledge_sources = knowledge_sources
        else:
            self.knowledge_sources = ["Local Medical DB"]
        self.local_kb = LocalMedicalKnowledgeBase()

        print("Local Knowledge Base Verifier initialized")
        kb_stats = self.local_kb.get_knowledge_stats()
        fact_count = kb_stats['total_facts']
        category_count = len(kb_stats['categories'])
        print(f"Loaded {fact_count} facts across {category_count} categories")

    def verify_fact(self, fact, threshold: float = 0.7) -> float:
        """Return the knowledge-base score for ``fact`` (delegates to the local KB)."""
        kb = self.local_kb
        return kb.verify_fact(fact, threshold)

    def search_facts(self, query: str, limit = 10) -> List[Dict]:
        """Return the facts most similar to ``query`` (delegates to the local KB)."""
        kb = self.local_kb
        return kb.search_facts(query, limit)

## Fact Verifier

In [None]:
class FactualRewardCalculator:
    """Combine per-fact LLM-judge and knowledge-base scores into one factual reward.

    A fact's reward is ``llm_weight * llm_score + kb_weight * kb_score``. It is
    multiplied by ``disagreement_penalty`` when the two scores differ by more than
    ``agreement_threshold``. The overall reward is the mean over facts.

    Args:
        agreement_threshold: Max absolute score difference still counted as agreement.
        llm_weight: Weight of the LLM-judge score (default 0.7).
        kb_weight: Weight of the knowledge-base score (default 0.3).
        disagreement_penalty: Multiplier applied when the verifiers disagree (default 0.9).
    """

    def __init__(self, agreement_threshold, llm_weight=0.7, kb_weight=0.3,
                 disagreement_penalty=0.9):
        self.agreement_threshold = agreement_threshold
        self.llm_weight = llm_weight
        self.kb_weight = kb_weight
        self.disagreement_penalty = disagreement_penalty

    def extract_think_content(self, generation):
        """Return the stripped text inside the first ``<think>...</think>`` pair, or "".

        A leading ``Reasoning:`` label is removed.
        """
        think_match = re.search(r'<think>(.*?)</think>', generation, re.DOTALL | re.IGNORECASE)
        if think_match:
            content = think_match.group(1).strip()
            return re.sub(r'^Reasoning:\s*', '', content)
        return ""

    @staticmethod
    def _read_scores(fact_item):
        """Return ``(llm_score, kb_score)`` from a dict or an object with those attributes.

        Missing or None scores count as 0.0, so unverified facts do not crash ``abs()``.
        """
        if isinstance(fact_item, dict):
            llm_score = fact_item.get("llm_score")
            kb_score = fact_item.get("kb_score")
        else:
            llm_score = getattr(fact_item, "llm_score", None)
            kb_score = getattr(fact_item, "kb_score", None)
        return (0.0 if llm_score is None else llm_score,
                0.0 if kb_score is None else kb_score)

    def _scores_agree(self, llm_score, kb_score):
        """True when the two verifier scores are within ``agreement_threshold``."""
        return abs(llm_score - kb_score) <= self.agreement_threshold

    def _reward_from_scores(self, llm_score, kb_score):
        """Weighted reward for one fact, penalised when the verifiers disagree."""
        base_score = self.llm_weight * llm_score + self.kb_weight * kb_score
        if self._scores_agree(llm_score, kb_score):
            return base_score
        return base_score * self.disagreement_penalty

    def compute_factual_reward(self, facts_list):
        """Aggregate verifier scores over a list of facts.

        Args:
            facts_list: Fact dicts (with ``llm_score``/``kb_score`` keys) or
                AtomicFact-like objects with those attributes.

        Returns:
            dict: ``factual_reward`` (mean per-fact reward), ``factual_analysis``
            (per-fact rewards, agreement rate, average scores, count) and ``error``
            (None, or "No facts provided" for empty input).
        """
        if not facts_list:
            return {
                "factual_reward": 0.0,
                "factual_analysis": {
                    "factual_reward": 0.0,
                    "individual_rewards": [],
                    "agreement_rate": 0.0,
                    "avg_llm_score": 0.0,
                    "avg_kb_score": 0.0,
                    "num_facts": 0
                },
                "error": "No facts provided"
            }

        score_pairs = [self._read_scores(item) for item in facts_list]

        individual_rewards = []
        agreement_count = 0
        for llm_score, kb_score in score_pairs:
            individual_rewards.append(self._reward_from_scores(llm_score, kb_score))
            if self._scores_agree(llm_score, kb_score):
                agreement_count += 1

        num_facts = len(score_pairs)
        factual_reward = sum(individual_rewards) / num_facts

        factual_analysis = {
            "factual_reward": factual_reward,
            "individual_rewards": individual_rewards,
            "agreement_rate": agreement_count / num_facts,
            "avg_llm_score": sum(llm for llm, _ in score_pairs) / num_facts,
            "avg_kb_score": sum(kb for _, kb in score_pairs) / num_facts,
            "num_facts": num_facts
        }

        return {
            "factual_reward": factual_reward,
            "factual_analysis": factual_analysis,
            "error": None
        }

    def _compute_individual_fact_reward(self, fact):
        """Reward for a single fact object or dict (see ``_reward_from_scores``)."""
        llm_score, kb_score = self._read_scores(fact)
        return self._reward_from_scores(llm_score, kb_score)

## Atomic Fact Verification System

In [None]:
class AtomicFactVerificationSystem:
    """Extract atomic facts from reasoning text and score them with an LLM judge plus a local KB.

    Args:
        agreement_threshold: Max |llm_score - kb_score| counted as agreement;
            forwarded to ``FactualRewardCalculator``.
        umls_api_key: Accepted for backward compatibility; unused here.
        region_name: AWS region for the extractor and the LLM judge (default "us-east-1").
        verification_delay: Seconds to sleep after verifying each fact outside
            training mode, to avoid judge rate limits (default 4).
    """

    def __init__(self,
                 agreement_threshold,
                 umls_api_key=None,
                 region_name="us-east-1",
                 verification_delay=4):
        self.extractor = AtomicFactExtractor(model_name=MODEL_NAME,
                                             region_name=region_name)
        self.llm_verifier = LLMJudgeVerifier(model_name=MODEL_NAME,
                                             region_name=region_name)
        self.kb_verifier = LocalKnowledgeBaseVerifier()
        self.reward_calculator = FactualRewardCalculator(agreement_threshold)
        self.verification_delay = verification_delay
        self.training_mode = False

    def process_response(self, reasoning_text, context=""):
        """Extract facts from ``reasoning_text``, score them and compute the factual reward.

        In training mode, cheap keyword heuristics replace the LLM/KB verifiers.

        Returns:
            dict: ``facts`` (list of fact dicts), ``factual_analysis`` and ``error``.
            NOTE(review): on success ``factual_analysis`` holds the whole dict returned
            by ``compute_factual_reward`` (which nests its own ``factual_analysis``),
            while the no-facts/error paths hold the flat ``_empty_analysis()`` dict.
            Both expose ``factual_reward`` at the top level. This shape is kept as-is
            for existing callers.
        """
        try:
            facts = self.extractor.extract_facts(reasoning_text)

            if not facts:
                return {
                    "facts": [],
                    "factual_analysis": self._empty_analysis(),
                    "error": "No facts extracted"
                }

            if getattr(self, 'training_mode', False):
                # Skip expensive verification during training.
                for fact in facts:
                    heuristic_score = self._get_training_heuristic_score(fact)
                    fact.llm_score = heuristic_score
                    fact.kb_score = heuristic_score
            else:
                for fact in facts:
                    self._verify_single_fact(fact, context)

            facts_as_dicts = [self._fact_to_dict(fact) for fact in facts]
            factual_analysis = self.reward_calculator.compute_factual_reward(facts_as_dicts)

            return {
                "facts": facts_as_dicts,
                "factual_analysis": factual_analysis,
                "error": None
            }

        except Exception as e:
            return {
                "facts": [],
                "factual_analysis": self._empty_analysis(),
                "error": str(e)
            }

    def _verify_single_fact(self, fact, context):
        """Score one fact with both verifiers; a score that could not be produced falls back to 0.5."""
        try:
            # Assign the returned scores explicitly. The verifiers may not record a
            # score on the fact on every code path (early returns, error handlers),
            # and relying on that side effect could leave a stale/default value.
            fact.llm_score = self.llm_verifier.verify_fact(fact, context)
            fact.kb_score = self.kb_verifier.verify_fact(fact)
            # Small delay to avoid rate limits during evaluation
            time.sleep(self.verification_delay)
        except Exception as verification_error:
            print(f"Verification error for fact: {verification_error}")
            if getattr(fact, 'llm_score', None) is None:
                fact.llm_score = 0.5
            if getattr(fact, 'kb_score', None) is None:
                fact.kb_score = 0.5

    def _get_training_heuristic_score(self, fact):
        """Cheap keyword-based stand-in for verification scores during training.

        Matching is plain substring search on the lower-cased text, so e.g. "may"
        also matches inside longer words. The first matching rule wins.
        """
        fact_text = fact.text.lower()
        if any(keyword in fact_text for keyword in ['patient', 'symptom', 'diagnosis', 'treatment']):
            # Medical context facts get moderate-high scores
            return 0.7
        elif any(keyword in fact_text for keyword in ['anatomy', 'physiology', 'mechanism']):
            # Basic medical knowledge gets high scores
            return 0.8
        elif len(fact_text) < 20:
            # Very short facts might be incomplete
            return 0.5
        elif any(keyword in fact_text for keyword in ['may', 'might', 'could', 'possibly']):
            # Uncertain statements get lower scores
            return 0.6
        else:
            # Default reasonable score
            return 0.7

    def _fact_to_dict(self, fact: "AtomicFact"):
        """Serialise an AtomicFact-like object to a plain dict."""
        return {
            "text": fact.text,
            "category": fact.category,
            "llm_score": fact.llm_score,
            "kb_score": fact.kb_score,
            "source_sentence": fact.source_sentence
        }

    def _empty_analysis(self):
        """Zeroed analysis dict used when no facts could be scored."""
        return {
            "factual_reward": 0.0,
            "individual_rewards": [],
            "agreement_rate": 0.0,
            "avg_llm_score": 0.0,
            "avg_kb_score": 0.0,
            "num_facts": 0
        }

# Training and Evaluation

## MetricsTracker

In [None]:
class MetricsTracker:
    """Collect per-epoch metrics and print or save summaries of them.

    Every entry of ``epoch_metrics`` is a flat dict with an ``'epoch'`` key plus
    the fields in ``TRACKED_KEYS``. A field missing from the metrics passed in is
    stored as 0.
    """

    # Fields copied from an incoming metrics dict into each epoch entry, in order.
    TRACKED_KEYS = (
        'correctness_violations_rate',
        'answer_leaking_violations_rate',
        'format_violations_rate',
        'factual_violations_rate',
        'bad_ood_high_rewards_rate',
        'avg_reward',
        'avg_factual',
    )
    # Violation-style fields: a decrease is an improvement.
    VIOLATION_KEYS = TRACKED_KEYS[:5]
    # Reward-style fields: an increase is an improvement.
    REWARD_KEYS = TRACKED_KEYS[5:]

    # Sentinel epoch numbers that tag the adversarial-stage entries in epoch_metrics.
    PRE_ADVERSARIAL_EPOCH = -1
    POST_ADVERSARIAL_EPOCH = 999
    ADVERSARIAL_IMPROVEMENT_EPOCH = 1000

    # Value types treated as numeric when computing post - pre differences.
    _NUMERIC_TYPES = (int, float, np.number)

    def __init__(self):
        self.epoch_metrics = []
        # Unfiltered dicts from the adversarial stage ('pre', 'post', 'improvement').
        # They are kept separately because add_epoch_metrics drops the std and
        # accuracy keys that print_metrics_with_std can display.
        self.adversarial_metrics = {}

    def add_epoch_metrics(self, epoch, metrics):
        """Append one entry for ``epoch`` built from ``metrics``.

        Only the TRACKED_KEYS fields are kept; missing fields become 0.
        """
        entry = {'epoch': epoch}
        entry.update({key: metrics.get(key, 0) for key in self.TRACKED_KEYS})
        self.epoch_metrics.append(entry)

    def print_stage_summary(self, stage_name, start_epoch=0, end_epoch=None):
        """Print the mean and std of each tracked metric over a slice of entries.

        Args:
            stage_name: Label used in the header.
            start_epoch: List index (not epoch number) of the first entry.
            end_epoch: Exclusive end index; defaults to all recorded entries.

        When the slice has more than one entry, this also prints the change from
        the first entry to the last. For violation rates the sign is flipped, so
        a positive number always means improvement.
        NOTE: the slice is positional, so adversarial-stage entries (epochs -1,
        999, 1000) are included if they fall inside it.
        """
        if end_epoch is None:
            end_epoch = len(self.epoch_metrics)

        stage_data = self.epoch_metrics[start_epoch:end_epoch]
        if not stage_data:
            print(f"No data for {stage_name}")
            return

        print(f"\n{'='*60}")
        print(f"{stage_name.upper()} SUMMARY (Epochs {start_epoch+1}-{end_epoch})")
        print(f"{'='*60}")

        # Display label -> per-epoch values for that field.
        metrics = {
            'Correctness Violations': [d['correctness_violations_rate'] for d in stage_data],
            'Answer Leaking': [d['answer_leaking_violations_rate'] for d in stage_data],
            'Format Violations': [d['format_violations_rate'] for d in stage_data],
            'Factual Violations': [d['factual_violations_rate'] for d in stage_data],
            'Bad OOD Rewards': [d['bad_ood_high_rewards_rate'] for d in stage_data],
            'Average Reward': [d['avg_reward'] for d in stage_data],
            'Factual Reward': [d['avg_factual'] for d in stage_data]
        }

        for metric_name, values in metrics.items():
            mean_val = np.mean(values)
            std_val = np.std(values)
            print(f"{metric_name:20s}: {mean_val:.2f} ± {std_val:.2f}")

        if len(stage_data) > 1:
            print(f"\n{stage_name} Improvements:")
            first_epoch = stage_data[0]
            last_epoch = stage_data[-1]

            # Violation rates: a decrease is an improvement, so negate the change.
            for metric in self.VIOLATION_KEYS:
                improvement = -(last_epoch[metric] - first_epoch[metric])
                print(f"  {metric.replace('_', ' ').title():25s}: {improvement:+.2f}")

            # Rewards: an increase is an improvement.
            for metric in self.REWARD_KEYS:
                change = last_epoch[metric] - first_epoch[metric]
                print(f"  {metric.replace('_', ' ').title():25s}: {change:+.2f}")

    def print_metrics_with_std(self, metrics, title="METRICS", include_std=True):
        """Print the known metrics present in ``metrics`` in a fixed order.

        Args:
            metrics: Dict that may contain any of the rate/reward keys below and,
                optionally, their matching std keys.
            title: Header line, underlined with dashes.
            include_std: When True and the std key is present, append "(±std)".
        """
        print(f"\n{title}")
        print("-" * len(title))

        # (value key, display label, std key or None)
        metric_configs = [
            ('correctness_violations_rate', 'Correctness Violations', 'correctness_violations_std'),
            ('answer_leaking_violations_rate', 'Answer Leakage Rate', 'answer_leaking_violations_std'),
            ('format_violations_rate', 'Bad Format Rate', 'format_violations_std'),
            ('factual_violations_rate', 'Factual Violations', 'factual_violations_std'),
            ('bad_ood_high_rewards_rate', 'High Rewards Rate', 'bad_ood_high_rewards_std'),
            ('avg_reward', 'Average Reward', 'std_reward'),
            ('avg_factual', 'Factual Reward', None),
            ('accuracy', 'Accuracy', None)
        ]

        for metric_key, display_name, std_key in metric_configs:
            if metric_key in metrics:
                value = metrics[metric_key]
                if include_std and std_key and std_key in metrics:
                    std_value = metrics[std_key]
                    print(f"{display_name}: {value:.2f} (±{std_value:.2f})")
                else:
                    print(f"{display_name}: {value:.2f}")

    def add_adversarial_stage_metrics(self, pre_metrics, post_metrics):
        """Record metrics taken before and after adversarial training.

        Three tagged entries are appended to ``epoch_metrics``:

        - epoch -1: the pre-adversarial snapshot;
        - epoch 999: the post-adversarial snapshot;
        - epoch 1000: post - pre for every tracked field that is numeric in both
          (0 where a difference could not be computed).

        The unfiltered dicts, plus a ``<key>_improvement`` dict of all numeric
        differences, are stored in ``self.adversarial_metrics``.
        """
        self.add_epoch_metrics(self.PRE_ADVERSARIAL_EPOCH, pre_metrics)
        self.add_epoch_metrics(self.POST_ADVERSARIAL_EPOCH, post_metrics)

        # Both sides must be numeric, otherwise the subtraction could raise.
        deltas = {
            key: post_metrics[key] - pre_value
            for key, pre_value in pre_metrics.items()
            if key in post_metrics
            and isinstance(pre_value, self._NUMERIC_TYPES)
            and isinstance(post_metrics[key], self._NUMERIC_TYPES)
        }

        # Store differences under the base key names. add_epoch_metrics reads only
        # those names, so '<key>_improvement' keys would be recorded as zeros.
        self.add_epoch_metrics(self.ADVERSARIAL_IMPROVEMENT_EPOCH, deltas)

        self.adversarial_metrics = {
            'pre': dict(pre_metrics),
            'post': dict(post_metrics),
            'improvement': {f"{key}_improvement": delta for key, delta in deltas.items()},
        }

    def _find_epoch_entry(self, epoch):
        """Return the most recent ``epoch_metrics`` entry tagged ``epoch``, or None."""
        for entry in reversed(self.epoch_metrics):
            if entry.get('epoch') == epoch:
                return entry
        return None

    def print_adversarial_impact_analysis(self):
        """Print pre- vs post-adversarial metrics and their differences.

        Uses the unfiltered dicts saved by add_adversarial_stage_metrics when
        available. Otherwise it falls back to the tagged epoch entries, which
        only carry TRACKED_KEYS. If neither exists, it prints a "not found" message.
        """
        # getattr keeps this working on instances created before the attribute existed.
        stored = getattr(self, 'adversarial_metrics', {}) or {}
        pre_metrics = stored.get('pre')
        post_metrics = stored.get('post')
        if pre_metrics is None or post_metrics is None:
            pre_metrics = self._find_epoch_entry(self.PRE_ADVERSARIAL_EPOCH)
            post_metrics = self._find_epoch_entry(self.POST_ADVERSARIAL_EPOCH)

        if pre_metrics is None or post_metrics is None:
            print("Adversarial training metrics not found")
            return

        print("\n" + "=" * 60)
        print("ADVERSARIAL TRAINING IMPACT ANALYSIS")
        print("=" * 60)

        print("\nPRE-ADVERSARIAL PERFORMANCE:")
        self.print_metrics_with_std(pre_metrics, "", include_std=True)

        print("\nPOST-ADVERSARIAL PERFORMANCE:")
        self.print_metrics_with_std(post_metrics, "", include_std=True)

        improvement = stored.get('improvement')
        if improvement:
            print("\nCHANGE (POST - PRE):")
            for key, delta in improvement.items():
                print(f"  {key}: {delta:+.2f}")

    def save_metrics(self, filepath="training_metrics.json"):
        """Write ``epoch_metrics`` to ``filepath`` as indented JSON.

        Values the json module cannot serialize natively (e.g. numpy or torch
        scalars/arrays) are converted with ``.tolist()`` when available, and
        with ``str`` otherwise.
        """
        def _to_jsonable(value):
            to_list = getattr(value, 'tolist', None)
            if callable(to_list):
                return to_list()
            return str(value)

        with open(filepath, 'w') as f:
            json.dump(self.epoch_metrics, f, indent=2, default=_to_jsonable)
        print(f"Metrics saved to {filepath}")

def monitor_memory():
    """Print the CUDA memory currently allocated and reserved by PyTorch, in GiB.

    Prints nothing and returns None when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    bytes_per_gib = 1024 ** 3
    allocated_gb = torch.cuda.memory_allocated() / bytes_per_gib
    reserved_gb = torch.cuda.memory_reserved() / bytes_per_gib
    print(f"GPU Memory - Allocated: {allocated_gb:.2f}GB, Reserved: {reserved_gb:.2f}GB")

## With Llama3.2-3B-Instruct-SFT

In [None]:
# Quick comparison of training approaches for the SFT Llama-3.2-3B checkpoint,
# using a 50-question MedQA training slice.

# Ask the CUDA caching allocator for expandable segments to reduce fragmentation.
# NOTE(review): PyTorch reads this setting when the allocator is first used. If GPU
# memory was already allocated earlier in the session, this assignment likely has no
# effect; consider moving it to the configuration cell, before any GPU work.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

train_df, test_df = (pd.read_json(path) for path in ("medqa_train.json", "medqa_test.json"))

reward_config = dict(
    w_b=1.0,
    w_a=0.2,
    w_s=0.2,
    # NOTE(review): the OOD evaluation cells spell this key 'w_f'; confirm which
    # name the reward model reads.
    w_fact=0.2,
    tau_answer=0.7,
    tau_preamble=15,
    lambda_s=1.0,
)

print("Initial memory state:")
monitor_memory()

model_url = "Llama3.2-3B-Instruct-SFT"
trainer_1 = PolicyTrainer(
    model_path=model_url,
    reward_config=reward_config,
    use_baseline=True,
    umls_api_key=UMLS_API_KEY,
)

print("=" * 80)
# Persist the slice so the trainer can read it by path.
train_subset_df = train_df.head(50)
train_subset_df.to_json("medqa_train_sample.json", orient='records', indent=4)

comparison_results = trainer_1.compare_training_approaches(
    train_data_path="medqa_train_sample.json",
    test_data_path="medqa_test.json",
    stage1_epochs=3,
    stage2_epochs=1,
    max_eval_examples=50,
)

print(f"\n{'=' * 80}")
print("TRAINING COMPLETE")
print("=" * 80)

In [None]:
# Full two-stage training run on the first 300 MedQA training questions.
# model_url and reward_config are restated here so this cell does not depend on
# whichever cell last assigned them. The OOD cells below rebind reward_config with
# a different key ('w_f'), and later sections rebind model_url.
model_url = "Llama3.2-3B-Instruct-SFT"
reward_config = {
    'w_b': 1.0,
    'w_a': 0.2,
    'w_s': 0.2,
    'w_fact': 0.2,
    'tau_answer': 0.7,
    'tau_preamble': 15,
    'lambda_s': 1.0
}

# NOTE(review): trainer_1 from the comparison cell is still referenced here; if GPU
# memory is tight, consider releasing it before building a second trainer.
trainer = PolicyTrainer(
    model_path=model_url,
    reward_config=reward_config,
    use_baseline=True,
    umls_api_key=UMLS_API_KEY
)


print("=" * 80)
# Overwrites the 50-row sample file written by the comparison cell.
train_subset_df = train_df.iloc[:300]
train_subset_df.to_json("medqa_train_sample.json", orient='records', indent=4)

trainer.train_two_stage_pipeline(
    train_data_path="medqa_train_sample.json",
    stage1_epochs=2,
    stage2_epochs=2,
    batch_size=8,
    adversarial_frequency=2
)

## Out-of-distribution (OOD) evaluation on MMLU-Pro (health)

In [None]:
# Out-of-distribution evaluation on the MMLU-Pro health split, using a checkpoint
# saved locally (download it from S3 first if needed).
local_model_dir = LOCAL_MODEL_DIR

# Reward-function configuration.
# NOTE(review): the training cells use the key 'w_fact' where this cell uses 'w_f';
# confirm which name the reward model reads.
reward_config = {
    'w_b': 1.0,
    'w_a': 0.2,
    'w_s': 0.2,
    'w_f': 0.2,
    'tau_answer': 0.7,
    'tau_preamble': 15,
    'lambda_s': 1.0
}

trainer = PolicyTrainer(
    model_path=local_model_dir,
    reward_config=reward_config,
    use_baseline=True,
    umls_api_key=UMLS_API_KEY
)

# 1. Reload model + tokenizer from the checkpoint directory.
# NOTE(review): the previous trainer.model stays referenced until this assignment
# completes. If PolicyTrainer already loaded local_model_dir, this is briefly a
# second copy; confirm the reload is needed.
trainer.model = AutoModelForCausalLM.from_pretrained(local_model_dir).to(trainer.device)
trainer.tokenizer = AutoTokenizer.from_pretrained(local_model_dir)

# 2. Reload the baseline network. weights_only=True limits unpickling to tensors and
# plain containers, which is all a state_dict needs. It also avoids executing
# arbitrary code from a checkpoint that may have been downloaded from S3.
baseline_path = os.path.join(local_model_dir, "baseline_network.pt")
trainer.baseline_network.load_state_dict(
    torch.load(baseline_path, map_location=trainer.device, weights_only=True)
)
trainer.baseline_network.to(trainer.device)
trainer.baseline_network.eval()

print("\nRunning final evaluation on the test set...")
# NOTE(review): the second argument (200) presumably caps the number of evaluated examples.
final_results = trainer.evaluate_model("mmlu_pro_health_test.json", 200)

## With Llama3.2-3B-Instruct

In [None]:
# Quick comparison of training approaches for the base (non-SFT)
# meta-llama/Llama-3.2-3B-Instruct model, on a 50-question MedQA training slice.

# Ask the CUDA caching allocator for expandable segments to reduce fragmentation.
# NOTE(review): this only takes effect if set before the allocator is first used.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

train_df, test_df = (pd.read_json(path) for path in ("medqa_train.json", "medqa_test.json"))

reward_config = dict(
    w_b=1.0,
    w_a=0.2,
    w_s=0.2,
    # NOTE(review): the OOD evaluation cells spell this key 'w_f'; confirm which
    # name the reward model reads.
    w_fact=0.2,
    tau_answer=0.7,
    tau_preamble=15,
    lambda_s=1.0,
)

print("Initial memory state:")
monitor_memory()

model_url = "meta-llama/Llama-3.2-3B-Instruct"
# NOTE(review): the Qwen section also binds trainer_3, replacing this trainer.
trainer_3 = PolicyTrainer(
    model_path=model_url,
    reward_config=reward_config,
    use_baseline=True,
    umls_api_key=UMLS_API_KEY,
)

print("=" * 80)
# Persist the slice so the trainer can read it by path.
train_subset_df = train_df.head(50)
train_subset_df.to_json("medqa_train_sample.json", orient='records', indent=4)

comparison_results = trainer_3.compare_training_approaches(
    train_data_path="medqa_train_sample.json",
    test_data_path="medqa_test.json",
    stage1_epochs=3,
    stage2_epochs=1,
    max_eval_examples=50,
)

print(f"\n{'=' * 80}")
print("TRAINING COMPLETE")
print("=" * 80)

## Out-of-distribution (OOD) evaluation on MMLU-Pro (health), base Llama3.2-3B-Instruct

In [None]:
# Out-of-distribution evaluation of the base-Llama checkpoint on the MMLU-Pro health split.
local_model_dir = LOCAL_MODEL_DIR

# Reward-function configuration.
# NOTE(review): the training cells use the key 'w_fact' where this cell uses 'w_f';
# confirm which name the reward model reads.
reward_config = {
    'w_b': 1.0,
    'w_a': 0.2,
    'w_s': 0.2,
    'w_f': 0.2,
    'tau_answer': 0.7,
    'tau_preamble': 15,
    'lambda_s': 1.0
}

trainer_2 = PolicyTrainer(
    model_path=local_model_dir,
    reward_config=reward_config,
    use_baseline=True,
    umls_api_key=UMLS_API_KEY
)

# 1. Reload model + tokenizer. The model is moved to trainer_2.device so its weights
#    sit on the same device as the rest of this trainer (matching the SFT eval cell).
trainer_2.model = AutoModelForCausalLM.from_pretrained(local_model_dir).to(trainer_2.device)
trainer_2.tokenizer = AutoTokenizer.from_pretrained(local_model_dir)

# 2. Reload the baseline network onto trainer_2's own device. Referencing another
#    cell's trainer here would fail on a fresh kernel or pick the wrong device.
#    weights_only=True limits unpickling to tensors and plain containers, which is
#    all a state_dict needs.
baseline_path = os.path.join(local_model_dir, "baseline_network.pt")
trainer_2.baseline_network.load_state_dict(
    torch.load(baseline_path, map_location=trainer_2.device, weights_only=True)
)
trainer_2.baseline_network.to(trainer_2.device)
trainer_2.baseline_network.eval()

print("\nRunning final evaluation on the test set...")
# NOTE(review): the second argument (200) presumably caps the number of evaluated examples.
final_results = trainer_2.evaluate_model("mmlu_pro_health_test.json", 200)

## With Qwen2.5-3B-Instruct

In [None]:
# Quick comparison of training approaches for the SFT Qwen2.5-3B-Instruct
# checkpoint, using a 50-question MedQA training slice.

# Ask the CUDA caching allocator for expandable segments to reduce fragmentation.
# NOTE(review): this only takes effect if set before the allocator is first used.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

train_df, test_df = (pd.read_json(path) for path in ("medqa_train.json", "medqa_test.json"))

reward_config = dict(
    w_b=1.0,
    w_a=0.2,
    w_s=0.2,
    # NOTE(review): the OOD evaluation cells spell this key 'w_f'; confirm which
    # name the reward model reads.
    w_fact=0.2,
    tau_answer=0.7,
    tau_preamble=15,
    lambda_s=1.0,
)

print("Initial memory state:")
monitor_memory()

model_url = "sft_Qwen2_5_3B_Instruct"
# NOTE(review): this rebinds trainer_3, which the base-Llama comparison cell also
# uses; the previous trainer becomes unreachable through that name.
trainer_3 = PolicyTrainer(
    model_path=model_url,
    reward_config=reward_config,
    use_baseline=True,
    umls_api_key=UMLS_API_KEY,
)

print("=" * 80)
# Persist the slice so the trainer can read it by path.
train_subset_df = train_df.head(50)
train_subset_df.to_json("medqa_train_sample.json", orient='records', indent=4)

comparison_results = trainer_3.compare_training_approaches(
    train_data_path="medqa_train_sample.json",
    test_data_path="medqa_test.json",
    stage1_epochs=2,
    stage2_epochs=1,
    max_eval_examples=50,
)

print(f"\n{'=' * 80}")
print("TRAINING COMPLETE")
print("=" * 80)