In [None]:
import os
import torch
from nnunet.training.network_training.UniSeg_Trainer import UniSeg_Trainer
from nnunet.paths import network_training_output_dir
from nnunet.run.default_configuration import get_default_configuration

def evaluate_model(trainer, data_loader):
    """Run ``trainer.network`` over ``data_loader`` and collect prompts, predictions and targets.

    The network is switched to eval mode for the duration of the loop, so train-only behaviour
    such as dropout is disabled. Every sub-module's previous train/eval flag is restored
    afterwards, even if the loop raises.

    Args:
        trainer: object exposing ``network`` and ``device``. ``network`` is a ``torch.nn.Module``
            whose forward accepts ``(inputs, task_ids, get_prompt=True)`` and then returns four
            tensors: output, intermediate prompt, dynamic prompt, task prompt.
        data_loader: iterable of dicts with tensor entries ``'data'``, ``'task_id'`` and
            ``'target'``.

    Returns:
        Tuple ``(prompts, predicted_segmentations, ground_truth_segmentations)`` of lists with
        one entry per batch.
        - Each ``prompts`` entry is a dict holding NumPy arrays under ``'intermedia_prompt'``,
          ``'dynamic_prompt'`` and ``'task_prompt'``.
        - The other two lists hold the network outputs and the targets as NumPy arrays.
    """
    network = trainer.network
    # Remember every sub-module's own flag (not just the root's), so that modules which were
    # deliberately kept in eval mode are not flipped to train mode when we restore.
    previous_modes = [(module, module.training) for module in network.modules()]
    prompts = []
    predicted_segmentations = []
    ground_truth_segmentations = []

    network.eval()
    try:
        with torch.no_grad():
            for data in data_loader:
                # Move inputs, task ids and targets to the trainer's device.
                inputs = data['data'].to(trainer.device)
                task_ids = data['task_id'].to(trainer.device)
                targets = data['target'].to(trainer.device)

                # Call the module rather than .forward() so any registered hooks still run.
                outputs, intermedia_prompt, dynamic_prompt, task_prompt = network(
                    inputs, task_ids, get_prompt=True)

                prompts.append({
                    'intermedia_prompt': intermedia_prompt.cpu().numpy(),
                    'dynamic_prompt': dynamic_prompt.cpu().numpy(),
                    'task_prompt': task_prompt.cpu().numpy(),
                })
                predicted_segmentations.append(outputs.cpu().numpy())
                ground_truth_segmentations.append(targets.cpu().numpy())
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    return prompts, predicted_segmentations, ground_truth_segmentations

def get_metric(output, target, inter_mediate_prompt, dynamic_prompt, task_prompt, features_x):
    """Sum ``task_prompt`` over the foreground of the last target, downsampled by a factor of 2.

    Only ``target[-1]`` is used, and ``target`` itself is left unmodified.
    - It is downsampled with ``scale_factor=0.5`` in trilinear mode, so it must be a 5-D tensor.
    - Every voxel with a positive interpolated value becomes foreground.
    - ``task_prompt`` is multiplied by that 0/1 mask, with broadcasting.
    NOTE(review): the extra halving presumably brings the mask to ``task_prompt``'s spatial
    size; confirm.

    ``output``, ``inter_mediate_prompt``, ``dynamic_prompt`` and ``features_x`` are accepted for
    interface compatibility but are not used.

    Returns:
        The masked ``task_prompt`` reduced over its leading dimension. This is the same result
        the builtin ``sum`` over the tensor gives. NOTE(review): if a scalar metric is intended,
        this should be ``.sum()`` over all dimensions; confirm.
    """
    last_target = target[-1]
    # Trilinear interpolation is only implemented for floating dtypes, so integer label maps
    # are cast. Floating inputs are passed through unchanged.
    if not last_target.is_floating_point():
        last_target = last_target.float()
    downsampled = torch.nn.functional.interpolate(
        last_target, scale_factor=0.5, mode='trilinear', align_corners=False)
    mask = (downsampled > 0).float()
    positive_prompt = task_prompt * mask
    return positive_prompt.sum(dim=0)

def _run_forward_iterations(trainer, data_gen, n_iterations, keep_results):
    """Run ``n_iterations`` forward-only ``trainer.run_iteration`` calls and gather the results.

    Returns a dict of per-iteration lists. The lists stay empty when ``keep_results`` is False;
    the iterations and ``get_metric`` calls still run in that case.
    """
    results = {
        "outputs": [],
        "inter_mediate_prompts": [],
        "dynamic_prompts": [],
        "task_prompts": [],
        "features_xs": [],
        "targets": [],
        "metrics": [],
    }
    # No gradients are needed here. Disable autograd explicitly so that, if run_iteration does
    # not already do so, the retained outputs do not keep their autograd graphs alive.
    with torch.no_grad():
        for _ in range(n_iterations):
            # NOTE(review): the positional flags are presumably do_backprop=False,
            # run_online_evaluation=False and a UniSeg-specific "return extras" flag; confirm
            # against UniSeg_Trainer.run_iteration.
            output, inter_mediate_prompt, dynamic_prompt, task_prompt, features_x, target = \
                trainer.run_iteration(data_gen, False, False, True)
            metric = get_metric(output, target, inter_mediate_prompt, dynamic_prompt,
                                task_prompt, features_x)
            if keep_results:
                results["outputs"].append(output)
                results["inter_mediate_prompts"].append(inter_mediate_prompt)
                results["dynamic_prompts"].append(dynamic_prompt)
                results["task_prompts"].append(task_prompt)
                results["features_xs"].append(features_x)
                results["targets"].append(target)
                results["metrics"].append(metric)
    return results


def collect_model_info_and_evaluate(model_checkpoints, exp_name="UniSeg_Trainer", network="3d_fullres",
                                    task="Task097_11task", network_trainer="UniSeg_Trainer",
                                    plans_identifier="DoDNetPlans", fold=0, *, return_outputs=False):
    """Load each checkpoint into one trainer and run forward passes over its training generator.

    A single trainer is built from ``get_default_configuration``. Then, for each checkpoint
    file:
    - it is loaded from ``<output_folder>/fold_<fold>/<name>``;
    - ``len(trainer.tr_gen.generator._data)`` forward-only iterations are run on ``trainer.tr_gen``.

    NOTE(review): ``exp_name`` is passed as the first argument to ``get_default_configuration``,
    which presumably is a UniSeg-modified signature; confirm.

    Args:
        model_checkpoints: iterable of checkpoint file names inside the fold directory.
        exp_name, network, task, network_trainer, plans_identifier: forwarded to
            ``get_default_configuration``.
        fold: cross-validation fold whose checkpoints are loaded.
        return_outputs: when True, also return the per-iteration outputs, prompts, features,
            targets and ``get_metric`` values for every checkpoint.

    Returns:
        ``(tr_gen, trainer)`` for the last checkpoint. When ``return_outputs`` is True, returns
        ``(tr_gen, trainer, results)`` instead. ``results`` maps each checkpoint name to a dict
        of per-iteration lists; a repeated name keeps only its last run.

    Raises:
        ValueError: if ``model_checkpoints`` is empty.
    """
    model_checkpoints = list(model_checkpoints)
    if not model_checkpoints:
        raise ValueError("model_checkpoints must contain at least one checkpoint file name")

    plans_file, output_folder_name, dataset_directory, batch_dice, stage, \
        trainer_class = get_default_configuration(exp_name, network, task, network_trainer, plans_identifier)
    trainer = trainer_class(plans_file, fold, output_folder=output_folder_name, dataset_directory=dataset_directory,
                            batch_dice=batch_dice, stage=stage, unpack_data=True, deterministic=True, fp16=True)

    collected = {}
    tr_gen = None
    for model_checkpoint in model_checkpoints:
        path_checkpoint = os.path.join(output_folder_name, f"fold_{fold}", model_checkpoint)
        trainer.load_checkpoint(path_checkpoint)

        tr_gen = trainer.tr_gen
        # One iteration per case in the training data dict (relies on a private attribute).
        n_iterations = len(tr_gen.generator._data)
        collected[model_checkpoint] = _run_forward_iterations(
            trainer, tr_gen, n_iterations, keep_results=return_outputs)

    if return_outputs:
        return tr_gen, trainer, collected
    return tr_gen, trainer
# Checkpoint file(s), relative to <output_folder>/fold_<fold>/, to load and analyse.
model_checkpoints = ["model_best.model"]
tr_gen, trainer = collect_model_info_and_evaluate(
    model_checkpoints,
    exp_name="UniSeg_Trainer",
    network="3d_fullres",
    task="Task097_11task",
    network_trainer="UniSeg_Trainer",
    plans_identifier="DoDNetPlans",
)