In [None]:
from libraries_multilabel.energyPointGame import energy_point_game_mask
from libraries_multilabel.bcosconv2d import NormedConv2d
from torchvision import models, transforms
from torchvision.models import ResNet50_Weights, resnet50



import random
import numpy as np
import torch
import pandas as pd
import os
from PIL import Image
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from libraries_multilabel.bcoslinear import BcosLinear
from libraries_multilabel.MultiLabelExplanationWrapper import MultiLabelModelWrapper
from libraries_multilabel.MultiLabelDatasets import MultiLabelDatasetID
from cam.layercam import LayerCAM
import torch.nn as nn

# Configuration
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path configurations
image_folder = r"D:\vinbigdata-chest-xray-abnormalities-detection\train_png_224"
csv_path_boxes = r"D:\vinbigdata-chest-xray-abnormalities-detection\train224.csv"
csv_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\multilabel_dataset.csv"
splits_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\vinbigdata_5fold_splits.pkl"

# Initialize datasets and transforms
data = pd.read_csv(csv_path)
data_boxes = pd.read_csv(csv_path_boxes)
# NOTE(review): pickle.load can execute arbitrary code; only load split files you created yourself.
with open(splits_path, 'rb') as f:
    splits = pickle.load(f)

transform = transforms.Compose([transforms.ToTensor()])
class_names = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis"]

# Same classes plus the macro-average row label used in the per-fold tables.
class_names_extended = class_names + ["Average"]

N_FOLDS = 5
BOX_COLUMNS = ["x_min", "y_min", "x_max", "y_max"]

# A per-class pointing game must only count the boxes annotated for that class;
# otherwise boxes of co-occurring findings inflate the score. Fall back to
# "all boxes of the image" if the box CSV carries no class column.
BOXES_HAVE_CLASS = "class_name" in data_boxes.columns

# Index both tables once instead of scanning them for every (image, class) pair.
# Rows without coordinates cannot contribute to a mask and would crash int().
boxes_by_image = {
    image_key: group
    for image_key, group in data_boxes.dropna(subset=BOX_COLUMNS).groupby("image_id")
}
# keep="first" mirrors the original `.iloc[0]` lookup if an image_id repeats.
labels_by_image = data.drop_duplicates("image_id", keep="first").set_index("image_id")


def _mean_or_nan(values):
    """Return the mean of *values*, ignoring NaNs, or NaN if nothing remains.

    NaN (not 0) marks "no samples", so empty groups - e.g. a class without
    incorrect predictions in a fold - do not pull later averages toward 0.
    """
    arr = np.asarray(values, dtype=float)
    arr = arr[~np.isnan(arr)]
    return float(arr.mean()) if arr.size else float("nan")


def _as_key(image_id):
    """Turn a collated image id into a hashable key (0-d tensors become Python scalars)."""
    return image_id.item() if torch.is_tensor(image_id) else image_id


def _build_mask(boxes, height, width):
    """Return an int32 (height, width) mask that is 1 inside any of *boxes* (end-exclusive)."""
    mask = torch.zeros((height, width), dtype=torch.int32)
    for x_min, y_min, x_max, y_max in boxes[BOX_COLUMNS].itertuples(index=False):
        # Clamp the start at 0 so a negative coordinate cannot wrap to the far edge.
        mask[max(int(y_min), 0):int(y_max), max(int(x_min), 0):int(x_max)] = 1
    return mask


seeds = [0, 1]
seed_results = {
    name: {
        'proportions': [],
        'correct': [],
        'incorrect': [],
    } for name in class_names
}

seed_results_avg = {
        'proportions': [],
        'correct': [],
        'incorrect': []
        }

for seed in seeds:
    print(f"\n=== Seed {seed+1} Results ===")
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # One unrounded entry per fold; rounding happens only for display.
    fold_average_results = {name: {
            'avg_proportion': [],
            'avg_correct': [],
            'avg_incorrect': [],
            'correct_count': [],
            'incorrect_count': []
        } for name in class_names}

    # Per-fold macro averages over classes (the "Average" row).
    fold_avg_row_results = {'proportions': [],
        'correct': [],
        'incorrect': []
        }

    for fold in range(N_FOLDS):
        print(f"Processing fold {fold+1} of {N_FOLDS}")
        model_path = fr"C:\Users\Admin\Documents\MasterThesis\results\VinBigData\ResNet_Baseline\no_nosamp\seed_{seed}\pneumonia_detection_model_resnet_baseline_bestf1_{fold+1}.pth"

        # Raw EBPG values of this fold, split by whether the (positive) class was predicted.
        fold_all_outputs = {name: {
            'proportions': [],
            'correct': [],
            'incorrect': [],
            'correct_count': 0,
            'incorrect_count': 0
        } for name in class_names}

        # The strict load below overwrites every parameter and buffer, so the
        # ImageNet weights (and their download) are not needed.
        model = models.resnet50(weights=None)
        model.fc = nn.Linear(model.fc.in_features, len(class_names))
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()

        # ResNet-50's layer4 has three Bottleneck blocks, so layer4[2] == layer4[-1]
        # (the B-cos variant targets layer4[-1].conv3).
        model_dict = dict(
            type="resnet50",
            layer_name="layer4",
            arch=model,
            target_layer=model.layer4[2].conv3,
        )
        cam = LayerCAM(model_dict)

        val_idx = splits[fold][1]
        val_data = data.iloc[val_idx]
        val_dataset = MultiLabelDatasetID(val_data, image_folder, transform=transform)
        val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

        multiLabelWrapper = MultiLabelModelWrapper(model)
        multiLabelWrapper.model.eval()

        # LayerCAM needs gradients w.r.t. the target layer's activations.
        with torch.enable_grad():
            for images, _, image_ids in val_loader:
                images = images.to(device)
                for image, image_id in zip(images, image_ids):
                    key = _as_key(image_id)
                    image_boxes = boxes_by_image.get(key)
                    if image_boxes is None or key not in labels_by_image.index:
                        continue
                    image_labels = labels_by_image.loc[key]

                    # Only ground-truth-positive classes are evaluated.
                    positive_classes = [
                        (class_idx, class_name)
                        for class_idx, class_name in enumerate(class_names)
                        if image_labels[class_name] == 1
                    ]
                    if not positive_classes:
                        continue

                    image = image[None]  # Add batch dimension
                    probs = torch.sigmoid(model(image))
                    binary_preds = (probs > 0.5).int()
                    height, width = image.shape[-2:]

                    for class_idx, class_name in positive_classes:
                        if BOXES_HAVE_CLASS:
                            class_boxes = image_boxes[image_boxes["class_name"] == class_name]
                        else:
                            class_boxes = image_boxes
                        # TODO(review): a positive label without matching boxes is skipped;
                        # confirm the label CSV and box CSV agree on class names.
                        if class_boxes.empty:
                            continue

                        contribution_map = cam(image, class_idx=class_idx)
                        contribution_map[contribution_map < 0] = 0
                        contribution_map = contribution_map.squeeze(0).squeeze(0)

                        mask = _build_mask(class_boxes, height, width)

                        # float() detaches the scalar so no tensor/graph is retained per sample.
                        ebpg_result = float(energy_point_game_mask(mask.cpu(), contribution_map.cpu()))

                        outputs = fold_all_outputs[class_name]
                        outputs['proportions'].append(ebpg_result)
                        if binary_preds[0][class_idx].item() == 1:
                            outputs['correct'].append(ebpg_result)
                            outputs['correct_count'] += 1
                        else:
                            outputs['incorrect'].append(ebpg_result)
                            outputs['incorrect_count'] += 1

        rows = []
        for class_name in class_names:
            outputs = fold_all_outputs[class_name]
            avg_total = _mean_or_nan(outputs['proportions'])
            avg_correct = _mean_or_nan(outputs['correct'])
            avg_incorrect = _mean_or_nan(outputs['incorrect'])

            summary = fold_average_results[class_name]
            summary['avg_proportion'].append(avg_total)
            summary['avg_correct'].append(avg_correct)
            summary['avg_incorrect'].append(avg_incorrect)
            summary['correct_count'].append(outputs['correct_count'])
            summary['incorrect_count'].append(outputs['incorrect_count'])

            rows.append({
                'Class': class_name,
                'avg_proportion': avg_total,
                'avg_correct': avg_correct,
                'avg_incorrect': avg_incorrect,
                'correct_count': outputs['correct_count'],
                'incorrect_count': outputs['incorrect_count'],
            })

        df = pd.DataFrame(rows)

        # Macro average over classes; pandas .mean() skips NaN (classes without samples).
        avg_row = {
            'Class': 'Average',
            'avg_proportion': df['avg_proportion'].mean(),
            'avg_correct': df['avg_correct'].mean(),
            'avg_incorrect': df['avg_incorrect'].mean(),
            'correct_count': int(df['correct_count'].sum()),
            'incorrect_count': int(df['incorrect_count'].sum()),
        }
        fold_avg_row_results['proportions'].append(avg_row['avg_proportion'])
        fold_avg_row_results['correct'].append(avg_row['avg_correct'])
        fold_avg_row_results['incorrect'].append(avg_row['avg_incorrect'])

        df = pd.concat([
            df,
            pd.DataFrame([avg_row])  # Wrap in list to create 1-row DataFrame
        ], ignore_index=True)

        print(f"\nEnergy-Based Pointing Game Results for Fold {fold+1}:")
        print(df.round(4).to_string(index=False))

    # Aggregate this seed over all folds (NaN folds, i.e. without samples, are skipped).
    for class_name in class_names:
        summary = fold_average_results[class_name]
        seed_results[class_name]['proportions'].append(_mean_or_nan(summary['avg_proportion']))
        seed_results[class_name]['correct'].append(_mean_or_nan(summary['avg_correct']))
        seed_results[class_name]['incorrect'].append(_mean_or_nan(summary['avg_incorrect']))

    for metric in ('proportions', 'correct', 'incorrect'):
        seed_results_avg[metric].append(_mean_or_nan(fold_avg_row_results[metric]))

def calculate_stats(values):
    """Format the mean and population std (ddof=0) of *values* as "mean ± std".

    NaN entries (e.g. a seed/fold without samples for a metric) are ignored so a
    single missing value does not turn the whole summary into ``nan``. If no
    finite value remains, ``"nan ± nan"`` is returned without numpy warnings.
    """
    arr = np.asarray(values, dtype=float)
    arr = arr[~np.isnan(arr)]
    if arr.size == 0:
        return "nan ± nan"
    return f"{arr.mean():.4f} ± {arr.std():.4f}"

# Accumulators kept as module globals in case later notebook cells use them.
all_props, all_corrs, all_incorrects = [], [], []

# (seed_results key, printed label) pairs, in display order.
_report_rows = (
    ('proportions', "Proportion:"),
    ('correct', "Correct   :"),
    ('incorrect', "Incorrect :"),
)

print("\n\n=== Final Cross-Seed Averages ===")
for class_name in class_names:
    per_class = seed_results[class_name]
    print(f"\nClass: {class_name}")
    for metric_key, metric_label in _report_rows:
        print(metric_label, calculate_stats(per_class[metric_key]))

# Same summary for the per-fold macro-average row, aggregated over seeds.
print("\n=== Final Cross-Seed Average (across all classes) ===")
for metric_key, metric_label in _report_rows:
    print(metric_label, calculate_stats(seed_results_avg[metric_key]))



=== Seed 1 Results ===
Processing fold 1 of 5




torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])


: 