In [None]:
import numpy as np
import json
import torch
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    BitsAndBytesConfig
)
from common import constants
from common.mwoz_data import CustomMwozDataset
from common.utils import compute_prediction_scores
from typing import List, Dict
from tqdm import tqdm


class ModelEvaluator:
    """Evaluate a merged, 4-bit quantized causal LM on the entity-prediction test set.

    The model family is inferred from ``model_name``: names containing
    "mistral" (case-insensitive) are treated as Mistral, everything else as
    Gemma. The family selects the attention implementation used at load time
    and the marker (``[/INST]`` or ``ASSISTANT:``) that separates the prompt
    from the answer in decoded text.
    """

    def __init__(self, model_name: str):
        """Load the merged checkpoint and tokenizer registered for ``model_name``.

        Args:
            model_name: Key into ``constants.MERGED_MODEL`` (load path) and
                ``constants.TEST_RESULT_FILE`` (prediction output path).
        """
        self.model_name = model_name
        # Kept for callers that read it; inputs are placed on ``self.model.device``.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_type = 'mistral' if 'mistral' in model_name.lower() else 'gemma'
        self.response_marker = "[/INST]" if self.model_type == 'mistral' else "ASSISTANT:"
        self.model, self.tokenizer = self._load_model()

    def _load_model(self):
        """Load the NF4 4-bit quantized model and its tokenizer.

        Returns:
            ``(model, tokenizer)``. The tokenizer's pad token is set to its EOS
            token and its padding side to ``'right'``.
        """
        model_path = constants.MERGED_MODEL[self.model_name]

        # Load in the same way as training
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

        model_kwargs = {
            "quantization_config": bnb_config,
            "low_cpu_mem_usage": True,
            # Mistral loads with FlashAttention-2; every other model uses eager attention.
            "attn_implementation": (
                "flash_attention_2" if self.model_type == 'mistral' else "eager"
            ),
        }

        print(f"Loading model from {model_path}")
        model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        tokenizer.pad_token = tokenizer.eos_token
        # Training-time setting; evaluate() temporarily switches to left
        # padding while generating.
        tokenizer.padding_side = 'right'

        return model, tokenizer

    def decode_response(self, output_ids: torch.Tensor, to_print: bool = False) -> str:
        """Decode generated token ids into a cleaned entity-type string.

        Special tokens are skipped. If the response marker occurs in the
        decoded text, only the text between its first and second occurrence
        is kept. Any response containing ``'no entity'`` collapses to
        ``'[no entity]'``; otherwise square brackets are removed.

        Args:
            output_ids: Generated token ids for a single sequence.
            to_print: If true, print the raw and cleaned responses.

        Returns:
            The cleaned response string.
        """
        decoded = self.tokenizer.decode(output_ids, skip_special_tokens=True)

        if to_print:
            print("Raw response:", decoded)

        # Extract the answer part when the prompt template leaked into the output.
        if self.response_marker in decoded:
            et_preds = decoded.split(self.response_marker)[1].strip()
        else:
            et_preds = decoded

        # Normalise to the format used during training.
        if 'no entity' in et_preds:
            et_preds = '[no entity]'
        else:
            et_preds = et_preds.replace('[', '').replace(']', '').strip()

        if to_print:
            print("Cleaned response:", et_preds)

        return et_preds

    def generate_with_strategy(self,
                               input_ids: torch.Tensor,
                               attention_mask: torch.Tensor,
                               strategy: str = "greedy",
                               **kwargs) -> List[str]:
        """Generate exactly one cleaned response per prompt with the chosen decoding strategy.

        Args:
            input_ids: Prompt token ids, shape ``(batch, prompt_len)``. Generated
                tokens are taken from position ``prompt_len`` onward, so batched
                prompts should be left-padded.
            attention_mask: Attention mask matching ``input_ids``.
            strategy: ``"greedy"`` (default), ``"beam"`` or ``"sampling"``. Any
                other value falls back to greedy decoding.
            **kwargs: Optional overrides:
                ``max_new_tokens`` (default ``constants.MAX_NEW_TOKENS``);
                beam: ``num_beams`` (5), ``length_penalty`` (0.6),
                ``num_return_sequences`` (1), ``no_repeat_ngram_size`` (3);
                sampling: ``num_samples`` (5), ``temperature`` (0.7),
                ``top_p`` (0.9).

        Returns:
            One response per prompt. When several sequences are returned per
            prompt, sampling uses a majority vote over the decoded responses,
            and beam search uses the first (top-scoring) returned beam.
        """
        from collections import Counter

        # Look ``constants`` up only when no override is supplied.
        if "max_new_tokens" in kwargs:
            max_new_tokens = kwargs["max_new_tokens"]
        else:
            max_new_tokens = constants.MAX_NEW_TOKENS

        generation_config = {
            "max_new_tokens": max_new_tokens,
            "do_sample": False,
            "num_beams": 1,
            # Pad token was set to EOS; passing it explicitly avoids generate()'s warning.
            "pad_token_id": self.tokenizer.pad_token_id,
        }

        if strategy == "beam":
            generation_config.update({
                "num_beams": kwargs.get("num_beams", 5),
                "length_penalty": kwargs.get("length_penalty", 0.6),
                "num_return_sequences": kwargs.get("num_return_sequences", 1),
                "no_repeat_ngram_size": kwargs.get("no_repeat_ngram_size", 3),
            })
        elif strategy == "sampling":
            generation_config.update({
                "do_sample": True,
                "num_return_sequences": kwargs.get("num_samples", 5),
                "temperature": kwargs.get("temperature", 0.7),
                "top_p": kwargs.get("top_p", 0.9),
            })

        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )

        prompt_len = input_ids.shape[1]
        # Decide from the config actually sent to generate(), not from kwargs
        # whose defaults may differ (e.g. num_samples defaults to 5 above).
        num_return = generation_config.get("num_return_sequences", 1)

        if num_return <= 1:
            return [self.decode_response(ids) for ids in outputs[:, prompt_len:]]

        # generate() emits the candidates for each prompt as consecutive rows.
        grouped = outputs.reshape(input_ids.shape[0], num_return, outputs.shape[-1])
        responses = []
        for candidates in grouped:
            decoded = [self.decode_response(seq[prompt_len:]) for seq in candidates]
            if strategy == "sampling":
                responses.append(Counter(decoded).most_common(1)[0][0])
            else:
                responses.append(decoded[0])
        return responses

    def evaluate(self,
                 test_data_path: str,
                 strategy: str = "greedy",
                 batch_size: int = 1,
                 save_results: bool = True,
                 **generation_kwargs) -> Dict:
        """Generate predictions for the test set, score them, and optionally save them.

        Args:
            test_data_path: File passed to ``CustomMwozDataset`` as ``data_filename``.
            strategy: Decoding strategy forwarded to ``generate_with_strategy``.
            batch_size: Number of prompts per generation call.
            save_results: If true, write per-turn predictions to
                ``constants.TEST_RESULT_FILE[model_name]`` suffixed with ``.<strategy>``.
            **generation_kwargs: Extra options forwarded to ``generate_with_strategy``.

        Returns:
            Dict with ``precision``, ``recall``, ``f1`` and ``accuracy`` as
            computed by ``compute_prediction_scores``.
        """
        print(f"\nEvaluating {self.model_name} using {strategy} decoding...")

        test_dataset = CustomMwozDataset(
            self.tokenizer,
            data_filename=test_data_path,
            model_type=self.model_type,
            mode='infer'
        ).data

        dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        all_responses = []
        self.model.eval()

        # Decoder-only models must be left-padded for batched generation;
        # otherwise shorter prompts are continued after their pad tokens.
        original_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = 'left'
        try:
            with torch.no_grad():
                for batch in tqdm(dataloader, desc="Generating responses"):
                    inputs = self.tokenizer(
                        batch['text'],
                        padding="longest",
                        return_tensors="pt"
                    ).to(self.model.device)

                    all_responses.extend(self.generate_with_strategy(
                        inputs['input_ids'],
                        inputs['attention_mask'],
                        strategy=strategy,
                        **generation_kwargs
                    ))
        finally:
            self.tokenizer.padding_side = original_padding_side

        # The whole test dataset serves as ground truth, aligned by position.
        prec, rec, f1, accuracy = compute_prediction_scores(all_responses, test_dataset)

        metrics = {
            "precision": prec,
            "recall": rec,
            "f1": f1,
            "accuracy": accuracy
        }

        print(f"Results for {self.model_name} using {strategy} decoding:")
        print(json.dumps(metrics, indent=2))

        if save_results:
            self._save_predictions(test_dataset, all_responses, strategy)

        return metrics

    def _save_predictions(self, test_dataset, responses: List[str], strategy: str) -> None:
        """Write one ``{uuid, turn_id, prediction}`` JSON record per response.

        Each response is split on ``'|'``. Items are stripped, and blank items
        and the ``[no entity]`` placeholder are dropped, so an empty list means
        no entity was predicted.
        """
        output = []
        for idx, resp in enumerate(responses):
            entity_types = [t.strip() for t in resp.split('|')]
            output.append({
                'uuid': test_dataset[idx]['uuid'],
                'turn_id': test_dataset[idx]['turn_id'],
                'prediction': [t for t in entity_types if t and t != '[no entity]'],
            })

        save_path = f"{constants.TEST_RESULT_FILE[self.model_name]}.{strategy}"
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(output, f, indent=2)


def main():
    """Evaluate each fine-tuned model with every decoding strategy and save a summary.

    Metrics for every (model, strategy) pair are printed and written to
    ``evaluation_results.json``. Per-turn predictions are saved by
    ``ModelEvaluator.evaluate``.
    """
    import gc

    # Set random seeds for reproducibility
    np.random.seed(constants.SEED)
    torch.manual_seed(constants.SEED)

    # Models to evaluate
    models = ['mistral-entity-pred', 'gemma-entity-pred']

    # Decoding strategies and their parameters
    strategies = {
        "greedy": {},
        "beam": {
            "num_beams": 5,
            "length_penalty": 0.6
        },
        "sampling": {
            "num_samples": 5,
            "temperature": 0.7,
            "top_p": 0.9
        }
    }

    test_data_path = f'{constants.DATA_DIR}test.json'
    all_results = {}

    for model_name in models:
        print(f"\nEvaluating {model_name}...")
        evaluator = ModelEvaluator(model_name)
        model_results = {}

        for strategy, params in strategies.items():
            model_results[strategy] = evaluator.evaluate(
                test_data_path=test_data_path,
                strategy=strategy,
                batch_size=1,  # Can be increased based on GPU memory
                **params
            )

        all_results[model_name] = model_results

        # Drop this model before the next one loads; otherwise the old
        # evaluator stays referenced and both checkpoints occupy GPU memory.
        del evaluator
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Save overall results
    print("\nAll Results:")
    print(json.dumps(all_results, indent=2))

    with open('evaluation_results.json', 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2)


if __name__ == "__main__":
    main()