# In [None]:
from libraries.energyPointGame import energy_point_game, energy_point_game_recall
from libraries.bcosconv2d import NormedConv2d


import random
import numpy as np
import torch
import pandas as pd
import pydicom
import os
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from libraries.bcosconv2d import NormedConv2d
from blurpool.blur_bcosconv2d import ModifiedBcosConv2d
from pooling.flc_bcosconv2d import ModifiedFLCBcosConv2d
from libraries.bcoslinear import BcosLinear



# Seed the Python, NumPy and PyTorch RNGs so runs are reproducible.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Native image resolution and the resolution at which boxes are evaluated.
original_width = original_height = 1024
explanation_width = explanation_height = 224

# Machine-specific data locations.
image_folder = r"C:\Users\Admin\Documents\rsna-pneumonia-detection-challenge\stage_2_train_images"
csv_path = r"C:\Users\Admin\Documents\rsna-pneumonia-detection-challenge\stage_2_train_labels.csv"
csv_path_splits = r"G:\Meine Ablage\Universität\Master Thesis\Pneumonia\training\grouped_data.csv"
splits_path = r"G:\Meine Ablage\Universität\Master Thesis\Pneumonia\training\splits\splits_balanced_fix.pkl"

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# `data` may hold several rows per patient (one per box); `data_splits` is the
# table that the fold indices in `splits` index into.
data = pd.read_csv(csv_path)
data_splits = pd.read_csv(csv_path_splits)

# splits[k][1] holds the validation indices of fold k.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(splits_path, 'rb') as f:
    splits = pickle.load(f)

# Evaluate on the full validation set of each of the five folds.

### alternative in new models
class PneumoniaDataset(Dataset):
    """Map-style dataset that reads one DICOM image per dataframe row.

    Each row must provide ``patientId`` (the ``.dcm`` file stem inside
    ``image_folder``) and ``Target`` (the class label). Items are
    ``(image, label_tensor, patient_id)`` tuples, where ``image`` is an RGB
    PIL image, or whatever ``transform`` returns for it when one is given.
    """

    def __init__(self, dataframe, image_folder, transform=None):
        self.data = dataframe
        self.image_folder = image_folder
        self.transform = transform

    def __len__(self):
        # Number of rows in the backing dataframe.
        return self.data.shape[0]

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        pid = record['patientId']
        target = record['Target']
        dcm_path = os.path.join(self.image_folder, f"{pid}.dcm")

        # Decode the DICOM pixel data and convert it to a 3-channel RGB image.
        pixels = pydicom.dcmread(dcm_path).pixel_array
        img = Image.fromarray(pixels).convert("RGB")

        if self.transform:
            img = self.transform(img)

        return img, torch.tensor(target, dtype=torch.long), pid


# Convert PIL images to channel-first float tensors scaled to [0, 1].
# No mean/std (e.g. ImageNet) normalization is applied here.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Factors that map box coordinates from the original DICOM resolution
# (original_width x original_height) to the explanation resolution
# (explanation_width x explanation_height).
scale_x = explanation_width / original_width
scale_y = explanation_height / original_height

# Per-fold mean bounding-box area fractions, collected across folds for:
# all scored validation positives / misclassified ones / correctly classified ones.
avg_proportions = []
avg_proportions_incorrect = []
avg_proportions_correct = []

def _bbox_area_fraction(boxes):
    """Return the summed area of ``boxes`` as a fraction of the explanation image.

    ``boxes`` is a dataframe with ``width``/``height`` columns at the original
    image resolution. Each box is rescaled by ``scale_x``/``scale_y`` and
    rounded to whole pixels. Its area is then divided by the explanation
    image area, and the per-box fractions are summed, so overlapping boxes
    are counted more than once.
    """
    img_area = explanation_height * explanation_width
    total = 0.0
    for _, box in boxes.iterrows():
        box_w = round(box["width"] * scale_x)
        box_h = round(box["height"] * scale_y)
        total += (box_w * box_h) / img_area
    return total


def _mean(values):
    """Return the arithmetic mean of ``values``, or NaN when it is empty."""
    return sum(values) / len(values) if values else float("nan")


# Positive (Target == 1) boxes grouped by patient. Built once, so scoring a
# validation image is a dict lookup instead of a scan of the whole label table.
positive_boxes = {
    patient_id: boxes
    for patient_id, boxes in data[data["Target"] == 1].groupby("patientId")
}

for fold in range(5):
    model_path = f"C:/Users/Admin/Documents/MasterThesis/results/Pneumonia/ResNet50_FLC/no_nosamp/seed_1/pneumonia_detection_model_resnet_bestf1_{fold+1}.pth"

    # Rebuild the architecture of the saved checkpoint: B-cos ResNet-50 with
    # its stride-2 convolutions replaced by ModifiedFLCBcosConv2d and a
    # 2-class head.
    model = torch.hub.load('B-cos/B-cos-v2', 'resnet50', pretrained=True)

    model.layer2[0].conv2 = ModifiedFLCBcosConv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), b=2, transpose=True)
    model.layer2[0].downsample[0] = ModifiedFLCBcosConv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), b=2, transpose=False)

    model.layer3[0].conv2 = ModifiedFLCBcosConv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), b=2, transpose=True)
    model.layer3[0].downsample[0] = ModifiedFLCBcosConv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), b=2, transpose=False)

    model.layer4[0].conv2 = ModifiedFLCBcosConv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), b=2, transpose=True)
    model.layer4[0].downsample[0] = ModifiedFLCBcosConv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), b=2, transpose=False)

    model.fc.linear = NormedConv2d(2048, 2, kernel_size=(1, 1), stride=(1, 1), bias=False) # code from B-cos paper reused to adjust network
    # Alternative architecture (kept for reference):
    #model = torch.hub.load('B-cos/B-cos-v2', 'vitc_b_patch1_14', pretrained=True)
    #model[0].linear_head.linear = BcosLinear(in_features=768, out_features=2, bias=False, b=2)

    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    # Validation split of this fold. The model's own input transform is applied
    # directly to the RGB PIL image. This avoids a redundant
    # tensor -> uint8 (truncating astype) -> PIL round trip.
    val_data = data_splits.iloc[splits[fold][1]]
    val_dataset = PneumoniaDataset(val_data, image_folder, transform=model.transform)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

    bbox_areas = []
    bbox_areas_correct = []
    bbox_areas_incorrect = []

    with torch.no_grad():
        for images, _labels, patient_ids in val_loader:
            for image, patient_id in zip(images, patient_ids):
                boxes = positive_boxes.get(patient_id)
                if boxes is None:
                    continue  # only images with at least one positive box are scored
                # The prediction is taken from explain(); no separate
                # model(image) call is needed.
                expl = model.explain(image[None].to(device))
                bbox_total = _bbox_area_fraction(boxes)

                bbox_areas.append(bbox_total)
                # Every scored image is positive, so predicting class 1 is correct.
                if expl["prediction"] == 1:
                    bbox_areas_correct.append(bbox_total)
                else:
                    bbox_areas_incorrect.append(bbox_total)

    count_correct = len(bbox_areas_correct)
    count_incorrect = len(bbox_areas_incorrect)

    if not bbox_areas:
        # Nothing to average; do not report values left over from an earlier fold.
        print(f"Fold {fold + 1}: no positive validation images; skipping.", flush=True)
        continue

    avg_proportion = _mean(bbox_areas)
    avg_proportion_incorrect = _mean(bbox_areas_incorrect)  # NaN if no misclassified positives
    avg_proportion_correct = _mean(bbox_areas_correct)      # NaN if no correctly classified positives

    avg_proportions.append(avg_proportion)
    # Record a group mean only when the group is non-empty, so a NaN from
    # one fold does not poison the cross-fold average.
    if bbox_areas_incorrect:
        avg_proportions_incorrect.append(avg_proportion_incorrect)
    if bbox_areas_correct:
        avg_proportions_correct.append(avg_proportion_correct)

    avg_proportion = round(avg_proportion, 4)
    avg_proportion_incorrect = round(avg_proportion_incorrect, 4)
    avg_proportion_correct = round(avg_proportion_correct, 4)

    # NOTE(review): despite the label, these values are mean bounding-box area
    # fractions, not energy-based pointing-game scores. The text is kept
    # unchanged so existing logs stay comparable.
    print(f"Average Energy-Based Pointing Game Proportion (Positive): {avg_proportion}")
    print(f"Average Energy-Based Pointing Game Proportion (Positive) of Incorrectly Classified Images: {avg_proportion_incorrect}, Count: {count_incorrect}", flush=True)
    print(f"Average Energy-Based Pointing Game Proportion (Positive) of Correctly Classified Images: {avg_proportion_correct}, Count: {count_correct}", flush=True)


def _mean_over_folds(values):
    """Return the mean of per-fold averages, or NaN when no fold contributed one."""
    return sum(values) / len(values) if values else float("nan")


# Cross-fold averages of the per-fold means. A group for which no fold
# produced a value is reported as NaN instead of raising ZeroDivisionError.
final_avg_prop = round(_mean_over_folds(avg_proportions), 4)
final_avg_prop_incorrect = round(_mean_over_folds(avg_proportions_incorrect), 4)
final_avg_prop_correct = round(_mean_over_folds(avg_proportions_correct), 4)
print()
# NOTE(review): the values are mean bounding-box area fractions; the
# "Energy-Based Pointing Game" wording is kept so existing logs stay comparable.
print(f"Average Energy-Based Pointing Game Proportion (Positive) over all folds: {final_avg_prop}", flush=True)
print(f"Average Energy-Based Pointing Game Proportion (Positive) of Incorrectly Classified Images over all folds: {final_avg_prop_incorrect}", flush=True)
print(f"Average Energy-Based Pointing Game Proportion (Positive) of Correctly Classified Images over all folds: {final_avg_prop_correct}", flush=True)