# Calculate EPG (Precision) over all five folds

In [None]:
from libraries_multilabel.energyPointGame import energy_point_game_mask
from libraries_multilabel.bcosconv2d import NormedConv2d

import random
import numpy as np
import torch
import pandas as pd
import os
from PIL import Image
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from libraries_multilabel.bcoslinear import BcosLinear
from libraries_multilabel.MultiLabelExplanationWrapper import MultiLabelModelWrapper
from libraries_multilabel.MultiLabelDatasets import MultiLabelDatasetID

# Configuration
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path configurations
image_folder = r"D:\vinbigdata-chest-xray-abnormalities-detection\train_png_224"
csv_path_boxes = r"D:\vinbigdata-chest-xray-abnormalities-detection\train224.csv"
csv_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\multilabel_dataset.csv"
splits_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\vinbigdata_5fold_splits.pkl"

# Initialize datasets and transforms
data = pd.read_csv(csv_path)
data_boxes = pd.read_csv(csv_path_boxes)
with open(splits_path, 'rb') as f:
    splits = pickle.load(f)
    
transform = transforms.Compose([transforms.ToTensor()])
class_names = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis"]

class_names_extended = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis", "Average"]

seeds = [0, 1]
seed_results = {
    name: {
        'proportions': [],
        'correct': [],
        'incorrect': [],
    } for name in class_names
}

seed_results_avg = { 
        'proportions': [],
        'correct': [],
        'incorrect': []
        }

for seed in seeds:
    print(f"\n=== Seed {seed+1} Results ===")
    all_fold_results = []
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    fold_average_results = {name: {
            'avg_proportion': [],
            'avg_correct': [],
            'avg_incorrect': [],
            'correct_count': [],
            'incorrect_count': []
        } for name in class_names}

    fold_avg_row_results = {'proportions': [],
        'correct': [],
        'incorrect': []
        }

    # Process all 5 folds
    for fold in range(5):
        print(f"Processing fold {fold+1} of 5")
        model_path = fr"C:\Users\Admin\Documents\MasterThesis\results\VinBigData\ResNet_Bcos\no_nosamp\seed_{seed}\pneumonia_detection_model_resnet_bcos_bestf1_{fold+1}.pth"
        
        # Initialize results storage        
        fold_all_outputs = {name: {
            'proportions': [],
            'correct': [],
            'incorrect': [],
            'correct_count': 0,
            'incorrect_count': 0
        } for name in class_names} # need to be averaged for final_fold_result


        model = torch.hub.load('B-cos/B-cos-v2', 'resnet50', pretrained=True)
        model.fc.linear = NormedConv2d(2048, 14, kernel_size=(1, 1), bias=False)
        model.load_state_dict(torch.load(model_path))
        model.to(device)

        
        val_idx = splits[fold][1]
        val_data = data.iloc[val_idx]
        val_dataset = MultiLabelDatasetID(val_data, image_folder, transform=transform)
        val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

        multiLabelWrapper = MultiLabelModelWrapper(model)
        model.eval()
        multiLabelWrapper.model.eval()

        with torch.enable_grad():
            for images, labels, image_ids in val_loader:
                labels = labels.to(device)
                six_channel_images = []
                
                # Convert images to 6-channel format
                for img_tensor in images:
                    numpy_image = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    pil_image = Image.fromarray(numpy_image)
                    transformed_image = model.transform(pil_image)
                    six_channel_images.append(transformed_image)
                
                six_channel_images = torch.stack(six_channel_images).to(device)
                
                mask = torch.zeros((224, 224), dtype=torch.int32)
                for image, label, image_id in zip(six_channel_images, labels, image_ids):
                    image = image[None]  # Add batch dimension
                    expl = multiLabelWrapper.explain(image)
                    
                    # Process each class
                    for class_idx, class_name in enumerate(class_names):
                        filtered_rows = data_boxes[(data_boxes['image_id'] == image_id)]
                        labels_row = data[data['image_id'] == image_id]
                        
                        if not filtered_rows.empty and not labels_row.empty and labels_row[class_name].iloc[0] == 1:
                            prediction = expl["binary_predictions"][0][class_idx]
                            contribution_map = expl['contribution_maps'][class_idx]
                            contribution_map[contribution_map < 0] = 0
                            
                            # Create mask from bounding boxes
                            mask.zero_()
                            for _, row in filtered_rows.iterrows():
                                x_min = int(row["x_min"])
                                y_min = int(row["y_min"])
                                x_max = int(row["x_max"])
                                y_max = int(row["y_max"])
                                mask[y_min:y_max, x_min:x_max] = 1.0
                            
                            # Calculate energy point game metric
                            ebpg_result = energy_point_game_mask(mask.cpu(), contribution_map.cpu())
                            
                            # Update results
                            fold_all_outputs[class_name]['proportions'].append(ebpg_result)
                            if prediction == 1:
                                fold_all_outputs[class_name]['correct'].append(ebpg_result)
                                fold_all_outputs[class_name]['correct_count'] += 1
                            else:
                                fold_all_outputs[class_name]['incorrect'].append(ebpg_result)
                                fold_all_outputs[class_name]['incorrect_count'] += 1
                                
                        
        for class_name in class_names:
            class_fold_result = fold_all_outputs[class_name]
            
            avg_total = np.mean(class_fold_result['proportions']) if class_fold_result['proportions'] else 0
            avg_correct = np.mean(class_fold_result['correct']) if class_fold_result['correct'] else 0
            avg_incorrect = np.mean(class_fold_result['incorrect']) if class_fold_result['incorrect'] else 0
            
            fold_average_results[class_name]['avg_proportion'].append(round(avg_total, 4))
            fold_average_results[class_name]['avg_correct'].append(round(avg_correct, 4))
            fold_average_results[class_name]['avg_incorrect'].append(round(avg_incorrect, 4))
            fold_average_results[class_name]['correct_count'].append(class_fold_result['correct_count'])
            fold_average_results[class_name]['incorrect_count'].append(class_fold_result['incorrect_count'])

        rows = []
        for class_name in class_names:
            rows.append({
                'Class': class_name,
                'avg_proportion': fold_average_results[class_name]['avg_proportion'][fold],
                'avg_correct': fold_average_results[class_name]['avg_correct'][fold],
                'avg_incorrect': fold_average_results[class_name]['avg_incorrect'][fold],
                'correct_count': fold_average_results[class_name]['correct_count'][fold],
                'incorrect_count': fold_average_results[class_name]['incorrect_count'][fold]
            })

        df = pd.DataFrame(rows)

        # add an average row
        avg_row = {
            'Class': 'Average',
            'avg_proportion': round(df['avg_proportion'].mean(), 4),
            'avg_correct': round(df['avg_correct'].mean(), 4),
            'avg_incorrect': round(df['avg_incorrect'].mean(), 4),
            'correct_count': int(df['correct_count'].sum()),        
            'incorrect_count': int(df['incorrect_count'].sum()),    
        }
        fold_avg_row_results['proportions'].append(avg_row['avg_proportion'])
        fold_avg_row_results['correct'].append(avg_row['avg_correct'])
        fold_avg_row_results['incorrect'].append(avg_row['avg_incorrect'])

        
        df = pd.concat([
            df,
            pd.DataFrame([avg_row])  # Wrap in list to create 1-row DataFrame
        ], ignore_index=True)

        print(f"\nEnergy-Based Pointing Game Results for Fold {fold+1}:")
        print(df.to_string(index=False))        
        
        if fold == 4:
            for class_name in class_names:
                seed_results[class_name]['proportions'].append(np.mean(fold_average_results[class_name]['avg_proportion']))
                seed_results[class_name]['correct'].append(np.mean(fold_average_results[class_name]['avg_correct']))
                seed_results[class_name]['incorrect'].append(np.mean(fold_average_results[class_name]['avg_incorrect']))   
                
            seed_results_avg['proportions'].append(np.mean(fold_avg_row_results['proportions']))
            seed_results_avg['correct'].append(np.mean(fold_avg_row_results['correct']))
            seed_results_avg['incorrect'].append(np.mean(fold_avg_row_results['incorrect']))

                
        
                    

def calculate_stats(values):
    return f"{np.mean(values):.4f} ± {np.std(values):.4f}"

all_props = []
all_corrs = []
all_incorrects = []

print("\n\n=== Final Cross-Seed Averages ===")
for class_name in class_names:
    print(f"\nClass: {class_name}")
    print("Proportion:", calculate_stats(seed_results[class_name]['proportions']))
    print("Correct   :", calculate_stats(seed_results[class_name]['correct']))
    print("Incorrect :", calculate_stats(seed_results[class_name]['incorrect']))

print("\n=== Final Cross-Seed Average (across all classes) ===")
print("Proportion:", calculate_stats(seed_results_avg['proportions']))
print("Correct   :", calculate_stats(seed_results_avg['correct']))
print("Incorrect :", calculate_stats(seed_results_avg['incorrect']))



=== Seed 1 Results ===
Processing fold 1 of 5


Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.1880       0.1852         0.2004            497              114
       Atelectasis          0.3591       0.3756         0.3494             14               24
     Calcification          0.2033       0.2764         0.2024              1               90
      Cardiomegaly          0.2371       0.2364         0.2389            329              131
     Consolidation          0.3594       0.4531         0.3115             24               47
               ILD          0.3996       0.4838         0.3856             11               66
      Infiltration          0.3005       0.3574         0.2797             33               90
      Lung Opacity          0.2616       0.3418         0.1911            124              141
       Nodule/Mass          0.1664       0.1784         0.1645             23              142
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.1966       0.1859         0.2338            475              137
       Atelectasis          0.3275       0.4114         0.3112              6               31
     Calcification          0.2131       0.2561         0.2031             17               73
      Cardiomegaly          0.2495       0.2528         0.2385            355              105
     Consolidation          0.3846       0.5177         0.3602             11               60
               ILD          0.4403       0.5909         0.4156             11               67
      Infiltration          0.3006       0.4240         0.2568             32               90
      Lung Opacity          0.2788       0.3937         0.2064            102              162
       Nodule/Mass          0.2051       0.1579         0.2140             26              139
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.1852       0.1789         0.2229            533               88
       Atelectasis          0.3173       0.3742         0.2933             11               26
     Calcification          0.1710       0.1948         0.1693              6               84
      Cardiomegaly          0.2207       0.2213         0.2128            428               32
     Consolidation          0.3522       0.4799         0.2897             23               47
               ILD          0.4258       0.5504         0.3760             22               55
      Infiltration          0.2913       0.3824         0.2222             53               70
      Lung Opacity          0.2614       0.3500         0.1930            115              149
       Nodule/Mass          0.1934       0.2217         0.1912             12              153
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2015       0.1980         0.2258            536               76
       Atelectasis          0.3123       0.4417         0.2708              9               28
     Calcification          0.1932       0.0000         0.1932              0               90
      Cardiomegaly          0.2495       0.2406         0.2963            387               73
     Consolidation          0.3540       0.4850         0.3015             20               50
               ILD          0.4001       0.5428         0.3859              7               70
      Infiltration          0.3123       0.4150         0.2676             37               85
      Lung Opacity          0.2626       0.3108         0.1861            162              102
       Nodule/Mass          0.2066       0.2107         0.2044             59              106
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2048       0.2074         0.1855            538               73
       Atelectasis          0.3483       0.4380         0.3151             10               27
     Calcification          0.2075       0.2579         0.2046              5               86
      Cardiomegaly          0.2552       0.2535         0.2607            353              107
     Consolidation          0.3650       0.5242         0.3187             16               55
               ILD          0.4416       0.5443         0.4146             16               61
      Infiltration          0.3129       0.4670         0.2466             37               86
      Lung Opacity          0.2622       0.3669         0.2102             88              177
       Nodule/Mass          0.1916       0.2888         0.1790             19              147
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2011       0.2006         0.2050            535               76
       Atelectasis          0.3872       0.4202         0.3506             20               18
     Calcification          0.2268       0.2909         0.2180             11               80
      Cardiomegaly          0.2697       0.2659         0.2860            373               87
     Consolidation          0.3929       0.4733         0.3406             28               43
               ILD          0.4141       0.5469         0.3820             15               62
      Infiltration          0.3377       0.4316         0.2988             36               87
      Lung Opacity          0.2750       0.3881         0.2252             81              184
       Nodule/Mass          0.1823       0.2170         0.1732             34              131
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.1872       0.1762         0.2416            509              103
       Atelectasis          0.3173       0.2992         0.3194              4               33
     Calcification          0.2012       0.2752         0.1968              5               85
      Cardiomegaly          0.2412       0.2311         0.2715            345              115
     Consolidation          0.3970       0.5006         0.3506             22               49
               ILD          0.4401       0.6080         0.4210              8               70
      Infiltration          0.3134       0.4591         0.2777             24               98
      Lung Opacity          0.2769       0.3819         0.1894            120              144
       Nodule/Mass          0.2116       0.2616         0.1961             39              126
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.1912       0.1863         0.2074            479              142
       Atelectasis          0.3505       0.3438         0.3530             10               27
     Calcification          0.1876       0.2300         0.1722             24               66
      Cardiomegaly          0.2468       0.2498         0.2344            370               90
     Consolidation          0.3694       0.4883         0.3313             17               53
               ILD          0.4546       0.5260         0.4115             29               48
      Infiltration          0.3011       0.3677         0.2702             39               84
      Lung Opacity          0.2740       0.4245         0.2278             62              202
       Nodule/Mass          0.2034       0.1933         0.2057             30              135
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.1933       0.1880         0.2073            443              169
       Atelectasis          0.2996       0.3610         0.2577             15               22
     Calcification          0.1870       0.2437         0.1799             10               80
      Cardiomegaly          0.2507       0.2371         0.2701            271              189
     Consolidation          0.3512       0.3878         0.3023             40               30
               ILD          0.3887       0.5002         0.3546             18               59
      Infiltration          0.3047       0.3505         0.2315             75               47
      Lung Opacity          0.2603       0.3208         0.1888            143              121
       Nodule/Mass          0.2052       0.2184         0.1970             63              102
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.1787       0.1795         0.1746            507              104
       Atelectasis          0.3565       0.3590         0.3549             14               23
     Calcification          0.1882       0.3870         0.1815              3               88
      Cardiomegaly          0.2432       0.2386         0.2568            342              118
     Consolidation          0.3665       0.4155         0.3327             29               42
               ILD          0.4283       0.4829         0.3873             33               44
      Infiltration          0.3035       0.3502         0.2305             75               48
      Lung Opacity          0.2621       0.3232         0.1905            143              122
       Nodule/Mass          0.1833       0.2018         0.1782             36              130
  

# Calculate EPG (Precision) across both seeds 

In [6]:
from libraries_multilabel.energyPointGame import energy_point_game_mask
from libraries_multilabel.bcosconv2d import NormedConv2d

import random
import numpy as np
import torch
import pandas as pd
import os
from PIL import Image
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from libraries_multilabel.bcoslinear import BcosLinear
from libraries_multilabel.MultiLabelExplanationWrapper import MultiLabelModelWrapper
from libraries_multilabel.MultiLabelDatasets import MultiLabelDatasetID

# Configuration
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path configurations
image_folder = r"D:\vinbigdata-chest-xray-abnormalities-detection\train_png_224"
csv_path_boxes = r"D:\vinbigdata-chest-xray-abnormalities-detection\train224.csv"
csv_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\multilabel_dataset.csv"
splits_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\vinbigdata_5fold_splits.pkl"

# Initialize datasets and transforms
data = pd.read_csv(csv_path)
data_boxes = pd.read_csv(csv_path_boxes)
with open(splits_path, 'rb') as f:
    splits = pickle.load(f)
    
transform = transforms.Compose([transforms.ToTensor()])
class_names = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis"]

class_names_extended = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis", "Average"]

seeds = [0, 1]
seed_results = {
    name: {
        'proportions': [],
        'correct': [],
        'incorrect': [],
    } for name in class_names
}

seed_results_avg = { 
        'proportions': [],
        'correct': [],
        'incorrect': []
        }

for seed in seeds:
    print(f"\n=== Seed {seed+1} Results ===")
    all_fold_results = []
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    fold_average_results = {name: {
            'avg_proportion': [],
            'avg_correct': [],
            'avg_incorrect': [],
            'correct_count': [],
            'incorrect_count': []
        } for name in class_names}

    fold_avg_row_results = {'proportions': [],
        'correct': [],
        'incorrect': []
        }

    # Process all 5 folds
    for fold in range(5):
        print(f"Processing fold {fold+1} of 5")
        model_path = fr"C:\Users\Admin\Documents\MasterThesis\results\VinBigData\ResNet_Bcos\light_oversamp\seed_{seed}\pneumonia_detection_model_resnet_bcos_bestf1_{fold+1}.pth"
        
        # Initialize results storage        
        fold_all_outputs = {name: {
            'proportions': [],
            'correct': [],
            'incorrect': [],
            'correct_count': 0,
            'incorrect_count': 0
        } for name in class_names} # need to be averaged for final_fold_result


        model = torch.hub.load('B-cos/B-cos-v2', 'resnet50', pretrained=True)
        model.fc.linear = NormedConv2d(2048, 14, kernel_size=(1, 1), bias=False)
        model.load_state_dict(torch.load(model_path))
        model.to(device)

        
        val_idx = splits[fold][1]
        val_data = data.iloc[val_idx]
        val_dataset = MultiLabelDatasetID(val_data, image_folder, transform=transform)
        val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

        multiLabelWrapper = MultiLabelModelWrapper(model)
        model.eval()
        multiLabelWrapper.model.eval()

        with torch.enable_grad():
            for images, labels, image_ids in val_loader:
                labels = labels.to(device)
                six_channel_images = []
                
                # Convert images to 6-channel format
                for img_tensor in images:
                    numpy_image = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    pil_image = Image.fromarray(numpy_image)
                    transformed_image = model.transform(pil_image)
                    six_channel_images.append(transformed_image)
                
                six_channel_images = torch.stack(six_channel_images).to(device)
                
                mask = torch.zeros((224, 224), dtype=torch.int32)
                for image, label, image_id in zip(six_channel_images, labels, image_ids):
                    image = image[None]  # Add batch dimension
                    expl = multiLabelWrapper.explain(image)
                    
                    # Process each class
                    for class_idx, class_name in enumerate(class_names):
                        filtered_rows = data_boxes[(data_boxes['image_id'] == image_id)]
                        labels_row = data[data['image_id'] == image_id]
                        
                        if not filtered_rows.empty and not labels_row.empty and labels_row[class_name].iloc[0] == 1:
                            prediction = expl["binary_predictions"][0][class_idx]
                            contribution_map = expl['contribution_maps'][class_idx]
                            contribution_map[contribution_map < 0] = 0
                            
                            # Create mask from bounding boxes
                            mask.zero_()
                            for _, row in filtered_rows.iterrows():
                                x_min = int(row["x_min"])
                                y_min = int(row["y_min"])
                                x_max = int(row["x_max"])
                                y_max = int(row["y_max"])
                                mask[y_min:y_max, x_min:x_max] = 1.0
                            
                            # Calculate energy point game metric
                            ebpg_result = energy_point_game_mask(mask.cpu(), contribution_map.cpu())
                            
                            # Update results
                            fold_all_outputs[class_name]['proportions'].append(ebpg_result)
                            if prediction == 1:
                                fold_all_outputs[class_name]['correct'].append(ebpg_result)
                                fold_all_outputs[class_name]['correct_count'] += 1
                            else:
                                fold_all_outputs[class_name]['incorrect'].append(ebpg_result)
                                fold_all_outputs[class_name]['incorrect_count'] += 1
                                
                        
        for class_name in class_names:
            class_fold_result = fold_all_outputs[class_name]
            
            avg_total = np.mean(class_fold_result['proportions']) if class_fold_result['proportions'] else 0
            avg_correct = np.mean(class_fold_result['correct']) if class_fold_result['correct'] else 0
            avg_incorrect = np.mean(class_fold_result['incorrect']) if class_fold_result['incorrect'] else 0
            
            fold_average_results[class_name]['avg_proportion'].append(round(avg_total, 4))
            fold_average_results[class_name]['avg_correct'].append(round(avg_correct, 4))
            fold_average_results[class_name]['avg_incorrect'].append(round(avg_incorrect, 4))
            fold_average_results[class_name]['correct_count'].append(class_fold_result['correct_count'])
            fold_average_results[class_name]['incorrect_count'].append(class_fold_result['incorrect_count'])

        rows = []
        for class_name in class_names:
            rows.append({
                'Class': class_name,
                'avg_proportion': fold_average_results[class_name]['avg_proportion'][fold],
                'avg_correct': fold_average_results[class_name]['avg_correct'][fold],
                'avg_incorrect': fold_average_results[class_name]['avg_incorrect'][fold],
                'correct_count': fold_average_results[class_name]['correct_count'][fold],
                'incorrect_count': fold_average_results[class_name]['incorrect_count'][fold]
            })

        df = pd.DataFrame(rows)

        # add an average row
        avg_row = {
            'Class': 'Average',
            'avg_proportion': round(df['avg_proportion'].mean(), 4),
            'avg_correct': round(df['avg_correct'].mean(), 4),
            'avg_incorrect': round(df['avg_incorrect'].mean(), 4),
            'correct_count': int(df['correct_count'].sum()),        
            'incorrect_count': int(df['incorrect_count'].sum()),    
        }
        fold_avg_row_results['proportions'].append(avg_row['avg_proportion'])
        fold_avg_row_results['correct'].append(avg_row['avg_correct'])
        fold_avg_row_results['incorrect'].append(avg_row['avg_incorrect'])

        
        df = pd.concat([
            df,
            pd.DataFrame([avg_row])  # Wrap in list to create 1-row DataFrame
        ], ignore_index=True)

        print(f"\nEnergy-Based Pointing Game Results for Fold {fold+1}:")
        print(df.to_string(index=False))        
        
        if fold == 4:
            for class_name in class_names:
                seed_results[class_name]['proportions'].append(np.mean(fold_average_results[class_name]['avg_proportion']))
                seed_results[class_name]['correct'].append(np.mean(fold_average_results[class_name]['avg_correct']))
                seed_results[class_name]['incorrect'].append(np.mean(fold_average_results[class_name]['avg_incorrect']))   
                
            seed_results_avg['proportions'].append(np.mean(fold_avg_row_results['proportions']))
            seed_results_avg['correct'].append(np.mean(fold_avg_row_results['correct']))
            seed_results_avg['incorrect'].append(np.mean(fold_avg_row_results['incorrect']))

                
        
                    

def calculate_stats(values):
    return f"{np.mean(values):.4f} ± {np.std(values):.4f}"

all_props = []
all_corrs = []
all_incorrects = []

print("\n\n=== Final Cross-Seed Averages ===")
for class_name in class_names:
    print(f"\nClass: {class_name}")
    print("Proportion:", calculate_stats(seed_results[class_name]['proportions']))
    print("Correct   :", calculate_stats(seed_results[class_name]['correct']))
    print("Incorrect :", calculate_stats(seed_results[class_name]['incorrect']))

print("\n=== Final Cross-Seed Average (across all classes) ===")
print("Proportion:", calculate_stats(seed_results_avg['proportions']))
print("Correct   :", calculate_stats(seed_results_avg['correct']))
print("Incorrect :", calculate_stats(seed_results_avg['incorrect']))



=== Seed 1 Results ===
Processing fold 1 of 5


Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2426       0.2416         0.2528            555               56
       Atelectasis          0.3714       0.3824         0.3643             15               23
     Calcification          0.2419       0.3042         0.2296             15               76
      Cardiomegaly          0.3057       0.3070         0.2954            412               48
     Consolidation          0.3813       0.4048         0.3685             25               46
               ILD          0.4165       0.5235         0.3862             17               60
      Infiltration          0.3152       0.3397         0.3000             47               76
      Lung Opacity          0.2814       0.3793         0.2181            104              161
       Nodule/Mass          0.1884       0.1923         0.1869             47              118
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2481       0.2414         0.3054            548               64
       Atelectasis          0.3675       0.3632         0.3698             13               24
     Calcification          0.2521       0.3233         0.2262             24               66
      Cardiomegaly          0.3032       0.3010         0.3232            415               45
     Consolidation          0.4183       0.4639         0.3964             23               48
               ILD          0.4701       0.4748         0.4682             23               55
      Infiltration          0.3143       0.3999         0.2710             41               81
      Lung Opacity          0.3054       0.3950         0.2249            125              139
       Nodule/Mass          0.2226       0.2190         0.2238             43              122
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2409       0.2388         0.2553            542               79
       Atelectasis          0.3798       0.4056         0.3715              9               28
     Calcification          0.2049       0.2652         0.1947             13               77
      Cardiomegaly          0.2986       0.2977         0.3062            412               48
     Consolidation          0.3591       0.5193         0.3226             13               57
               ILD          0.4483       0.5557         0.4131             19               58
      Infiltration          0.2974       0.4217         0.2480             35               88
      Lung Opacity          0.2885       0.3923         0.2231            102              162
       Nodule/Mass          0.2202       0.2496         0.2131             32              133
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2395       0.2370         0.2569            534               78
       Atelectasis          0.3244       0.3352         0.3209              9               28
     Calcification          0.2179       0.2696         0.2092             13               77
      Cardiomegaly          0.3112       0.3063         0.3366            386               74
     Consolidation          0.3939       0.4545         0.3661             22               48
               ILD          0.4282       0.5213         0.3977             19               58
      Infiltration          0.3268       0.3671         0.3064             41               81
      Lung Opacity          0.2953       0.3574         0.2451            118              146
       Nodule/Mass          0.2335       0.2465         0.2296             38              127
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2306       0.2299         0.2373            551               60
       Atelectasis          0.3862       0.4733         0.3540             10               27
     Calcification          0.2272       0.3034         0.2058             20               71
      Cardiomegaly          0.2997       0.2914         0.3566            401               59
     Consolidation          0.4021       0.4730         0.3780             18               53
               ILD          0.4519       0.5893         0.4003             21               56
      Infiltration          0.3221       0.4337         0.2702             39               84
      Lung Opacity          0.2913       0.3752         0.2240            118              147
       Nodule/Mass          0.2058       0.2160         0.2021             44              122
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2207       0.2208         0.2200            538               73
       Atelectasis          0.3964       0.4098         0.3876             15               23
     Calcification          0.2382       0.3927         0.2191             10               81
      Cardiomegaly          0.2802       0.2804         0.2797            354              106
     Consolidation          0.3571       0.4032         0.3415             18               53
               ILD          0.3924       0.4540         0.3628             25               52
      Infiltration          0.3119       0.3575         0.2762             54               69
      Lung Opacity          0.2740       0.3640         0.1935            125              140
       Nodule/Mass          0.1854       0.2019         0.1821             27              138
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2392       0.2315         0.3018            545               67
       Atelectasis          0.3585       0.3517         0.3637             16               21
     Calcification          0.2457       0.3609         0.2170             18               72
      Cardiomegaly          0.2799       0.2751         0.2998            369               91
     Consolidation          0.4130       0.4324         0.3931             36               35
               ILD          0.4797       0.5227         0.4556             28               50
      Infiltration          0.3216       0.4013         0.2813             41               81
      Lung Opacity          0.2990       0.3673         0.2327            130              134
       Nodule/Mass          0.2247       0.2331         0.2218             43              122
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2390       0.2359         0.2626            548               73
       Atelectasis          0.3838       0.4147         0.3723             10               27
     Calcification          0.2136       0.2298         0.2120              8               82
      Cardiomegaly          0.3033       0.3046         0.2917            414               46
     Consolidation          0.3805       0.5077         0.3295             20               50
               ILD          0.4665       0.5684         0.4083             28               49
      Infiltration          0.3140       0.4211         0.2500             46               77
      Lung Opacity          0.3014       0.3859         0.2276            123              141
       Nodule/Mass          0.2320       0.2968         0.2132             37              128
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2244       0.2243         0.2249            542               70
       Atelectasis          0.3120       0.3895         0.2833             10               27
     Calcification          0.2104       0.3152         0.1943             12               78
      Cardiomegaly          0.2936       0.2868         0.3286            385               75
     Consolidation          0.3851       0.4738         0.3445             22               48
               ILD          0.4206       0.5011         0.3720             29               48
      Infiltration          0.3222       0.3746         0.2870             49               73
      Lung Opacity          0.2855       0.3503         0.2244            128              136
       Nodule/Mass          0.2247       0.2338         0.2217             41              124
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2363       0.2379         0.2193            559               52
       Atelectasis          0.3742       0.3661         0.3761              7               30
     Calcification          0.2331       0.3723         0.2056             15               76
      Cardiomegaly          0.3055       0.3014         0.3366            406               54
     Consolidation          0.4101       0.5065         0.3843             15               56
               ILD          0.4600       0.5433         0.4307             20               57
      Infiltration          0.3325       0.4542         0.2841             35               88
      Lung Opacity          0.2952       0.3751         0.2358            113              152
       Nodule/Mass          0.2101       0.1963         0.2147             41              125
  

In [None]:
from libraries_multilabel.energyPointGame import energy_point_game_mask
from libraries_multilabel.bcosconv2d import NormedConv2d
from pooling.flc_bcosconv2d import ModifiedFLCBcosConv2d

import random
import numpy as np
import torch
import pandas as pd
import os
from PIL import Image
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from libraries_multilabel.bcoslinear import BcosLinear
from libraries_multilabel.MultiLabelExplanationWrapper import MultiLabelModelWrapper
from libraries_multilabel.MultiLabelDatasets import MultiLabelDatasetID

# Configuration
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path configurations
image_folder = r"D:\vinbigdata-chest-xray-abnormalities-detection\train_png_224"
csv_path_boxes = r"D:\vinbigdata-chest-xray-abnormalities-detection\train224.csv"
csv_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\multilabel_dataset.csv"
splits_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\vinbigdata_5fold_splits.pkl"

# Initialize datasets and transforms
data = pd.read_csv(csv_path)
data_boxes = pd.read_csv(csv_path_boxes)
with open(splits_path, 'rb') as f:
    splits = pickle.load(f)
    
transform = transforms.Compose([transforms.ToTensor()])
class_names = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis"]

class_names_extended = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis", "Average"]

seeds = [0, 1]
seed_results = {
    name: {
        'proportions': [],
        'correct': [],
        'incorrect': [],
    } for name in class_names
}

seed_results_avg = { 
        'proportions': [],
        'correct': [],
        'incorrect': []
        }

for seed in seeds:
    print(f"\n=== Seed {seed+1} Results ===")
    all_fold_results = []
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    fold_average_results = {name: {
            'avg_proportion': [],
            'avg_correct': [],
            'avg_incorrect': [],
            'correct_count': [],
            'incorrect_count': []
        } for name in class_names}

    fold_avg_row_results = {'proportions': [],
        'correct': [],
        'incorrect': []
        }

    # Process all 5 folds
    for fold in range(5):
        print(f"Processing fold {fold+1} of 5")
        model_path = fr"C:\Users\Admin\Documents\MasterThesis\results\VinBigData\ResNet_Bcos_FLC\no_nosamp\seed_{seed}\pneumonia_detection_model_resnet_bcos_bestf1_{fold+1}.pth"
        
        # Initialize results storage        
        fold_all_outputs = {name: {
            'proportions': [],
            'correct': [],
            'incorrect': [],
            'correct_count': 0,
            'incorrect_count': 0
        } for name in class_names} # need to be averaged for final_fold_result


        model = torch.hub.load('B-cos/B-cos-v2', 'resnet50', pretrained=True)
        model.fc.linear = NormedConv2d(2048, 14, kernel_size=(1, 1), bias=False)
        model.layer2[0].conv2 = ModifiedFLCBcosConv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), b=2, transpose=True)
        model.layer2[0].downsample[0] = ModifiedFLCBcosConv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), b=2, transpose=False)

        model.layer3[0].conv2 = ModifiedFLCBcosConv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), b=2, transpose=True)
        model.layer3[0].downsample[0] = ModifiedFLCBcosConv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), b=2, transpose=False)

        model.layer4[0].conv2 = ModifiedFLCBcosConv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), b=2, transpose=True)
        model.layer4[0].downsample[0] = ModifiedFLCBcosConv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), b=2, transpose=False)    

        model.load_state_dict(torch.load(model_path))
        model.to(device)

        
        val_idx = splits[fold][1]
        val_data = data.iloc[val_idx]
        val_dataset = MultiLabelDatasetID(val_data, image_folder, transform=transform)
        val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

        multiLabelWrapper = MultiLabelModelWrapper(model)
        model.eval()
        multiLabelWrapper.model.eval()

        with torch.enable_grad():
            for images, labels, image_ids in val_loader:
                labels = labels.to(device)
                six_channel_images = []
                
                # Convert images to 6-channel format
                for img_tensor in images:
                    numpy_image = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    pil_image = Image.fromarray(numpy_image)
                    transformed_image = model.transform(pil_image)
                    six_channel_images.append(transformed_image)
                
                six_channel_images = torch.stack(six_channel_images).to(device)
                
                mask = torch.zeros((224, 224), dtype=torch.int32)
                for image, label, image_id in zip(six_channel_images, labels, image_ids):
                    image = image[None]  # Add batch dimension
                    expl = multiLabelWrapper.explain(image)
                    
                    # Process each class
                    for class_idx, class_name in enumerate(class_names):
                        filtered_rows = data_boxes[(data_boxes['image_id'] == image_id)]
                        labels_row = data[data['image_id'] == image_id]
                        
                        if not filtered_rows.empty and not labels_row.empty and labels_row[class_name].iloc[0] == 1:
                            prediction = expl["binary_predictions"][0][class_idx]
                            contribution_map = expl['contribution_maps'][class_idx]
                            contribution_map[contribution_map < 0] = 0
                            
                            # Create mask from bounding boxes
                            mask.zero_()
                            for _, row in filtered_rows.iterrows():
                                x_min = int(row["x_min"])
                                y_min = int(row["y_min"])
                                x_max = int(row["x_max"])
                                y_max = int(row["y_max"])
                                mask[y_min:y_max, x_min:x_max] = 1.0
                            
                            # Calculate energy point game metric
                            ebpg_result = energy_point_game_mask(mask.cpu(), contribution_map.cpu())
                            
                            # Update results
                            fold_all_outputs[class_name]['proportions'].append(ebpg_result)
                            if prediction == 1:
                                fold_all_outputs[class_name]['correct'].append(ebpg_result)
                                fold_all_outputs[class_name]['correct_count'] += 1
                            else:
                                fold_all_outputs[class_name]['incorrect'].append(ebpg_result)
                                fold_all_outputs[class_name]['incorrect_count'] += 1
                                
                        
        for class_name in class_names:
            class_fold_result = fold_all_outputs[class_name]
            
            avg_total = np.mean(class_fold_result['proportions']) if class_fold_result['proportions'] else 0
            avg_correct = np.mean(class_fold_result['correct']) if class_fold_result['correct'] else 0
            avg_incorrect = np.mean(class_fold_result['incorrect']) if class_fold_result['incorrect'] else 0
            
            fold_average_results[class_name]['avg_proportion'].append(round(avg_total, 4))
            fold_average_results[class_name]['avg_correct'].append(round(avg_correct, 4))
            fold_average_results[class_name]['avg_incorrect'].append(round(avg_incorrect, 4))
            fold_average_results[class_name]['correct_count'].append(class_fold_result['correct_count'])
            fold_average_results[class_name]['incorrect_count'].append(class_fold_result['incorrect_count'])

        rows = []
        for class_name in class_names:
            rows.append({
                'Class': class_name,
                'avg_proportion': fold_average_results[class_name]['avg_proportion'][fold],
                'avg_correct': fold_average_results[class_name]['avg_correct'][fold],
                'avg_incorrect': fold_average_results[class_name]['avg_incorrect'][fold],
                'correct_count': fold_average_results[class_name]['correct_count'][fold],
                'incorrect_count': fold_average_results[class_name]['incorrect_count'][fold]
            })

        df = pd.DataFrame(rows)

        # add an average row
        avg_row = {
            'Class': 'Average',
            'avg_proportion': round(df['avg_proportion'].mean(), 4),
            'avg_correct': round(df['avg_correct'].mean(), 4),
            'avg_incorrect': round(df['avg_incorrect'].mean(), 4),
            'correct_count': int(df['correct_count'].sum()),        
            'incorrect_count': int(df['incorrect_count'].sum()),    
        }
        fold_avg_row_results['proportions'].append(avg_row['avg_proportion'])
        fold_avg_row_results['correct'].append(avg_row['avg_correct'])
        fold_avg_row_results['incorrect'].append(avg_row['avg_incorrect'])

        
        df = pd.concat([
            df,
            pd.DataFrame([avg_row])  # Wrap in list to create 1-row DataFrame
        ], ignore_index=True)

        print(f"\nEnergy-Based Pointing Game Results for Fold {fold+1}:")
        print(df.to_string(index=False))        
        
        if fold == 4:
            for class_name in class_names:
                seed_results[class_name]['proportions'].append(np.mean(fold_average_results[class_name]['avg_proportion']))
                seed_results[class_name]['correct'].append(np.mean(fold_average_results[class_name]['avg_correct']))
                seed_results[class_name]['incorrect'].append(np.mean(fold_average_results[class_name]['avg_incorrect']))   
                
            seed_results_avg['proportions'].append(np.mean(fold_avg_row_results['proportions']))
            seed_results_avg['correct'].append(np.mean(fold_avg_row_results['correct']))
            seed_results_avg['incorrect'].append(np.mean(fold_avg_row_results['incorrect']))

                
        
                    

def calculate_stats(values):
    return f"{np.mean(values):.4f} ± {np.std(values):.4f}"

all_props = []
all_corrs = []
all_incorrects = []

print("\n\n=== Final Cross-Seed Averages ===")
for class_name in class_names:
    print(f"\nClass: {class_name}")
    print("Proportion:", calculate_stats(seed_results[class_name]['proportions']))
    print("Correct   :", calculate_stats(seed_results[class_name]['correct']))
    print("Incorrect :", calculate_stats(seed_results[class_name]['incorrect']))

print("\n=== Final Cross-Seed Average (across all classes) ===")
print("Proportion:", calculate_stats(seed_results_avg['proportions']))
print("Correct   :", calculate_stats(seed_results_avg['correct']))
print("Incorrect :", calculate_stats(seed_results_avg['incorrect']))




=== Seed 1 Results ===
Processing fold 1 of 5


Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main
Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2326       0.2330         0.2303            514               97
       Atelectasis          0.3712       0.3683         0.3730             15               23
     Calcification          0.2294       0.8209         0.2228              1               90
      Cardiomegaly          0.3149       0.3137         0.3171            300              160
     Consolidation          0.3390       0.4380         0.2428             35               36
               ILD          0.4289       0.5211         0.3390             38               39
      Infiltration          0.3042       0.3633         0.2401             64               59
      Lung Opacity          0.2575       0.3254         0.1959            126              139
       Nodule/Mass          0.1733       0.2046         0.1723              5              160
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main
Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2263       0.2254         0.2287            430              182
       Atelectasis          0.3550       0.4863         0.3064             10               27
     Calcification          0.2483       0.2965         0.2472              2               88
      Cardiomegaly          0.2740       0.2758         0.2676            358              102
     Consolidation          0.3733       0.4417         0.3338             26               45
               ILD          0.4987       0.5453         0.4876             15               63
      Infiltration          0.3115       0.4702         0.2424             37               85
      Lung Opacity          0.2661       0.3532         0.1737            136              128
       Nodule/Mass          0.1971       0.1506         0.2078             31              134
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main
Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2401       0.2432         0.2179            544               77
       Atelectasis          0.3441       0.3954         0.3300              8               29
     Calcification          0.1994       0.2624         0.1973              3               87
      Cardiomegaly          0.3078       0.3038         0.3330            397               63
     Consolidation          0.3279       0.4731         0.2738             19               51
               ILD          0.4778       0.5430         0.4565             19               58
      Infiltration          0.3018       0.4245         0.2605             31               92
      Lung Opacity          0.2646       0.3672         0.1978            104              160
       Nodule/Mass          0.2027       0.2342         0.1986             19              146
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main
Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2253       0.2287         0.2102            501              111
       Atelectasis          0.3389       0.4402         0.3266              4               33
     Calcification          0.2005       0.5720         0.1963              1               89
      Cardiomegaly          0.3248       0.3249         0.3244            410               50
     Consolidation          0.3485       0.4248         0.2525             39               31
               ILD          0.4434       0.5902         0.4239              9               68
      Infiltration          0.3190       0.4642         0.2455             41               81
      Lung Opacity          0.2644       0.3250         0.1808            153              111
       Nodule/Mass          0.2060       0.2411         0.1975             32              133
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main
Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2335       0.2404         0.2136            453              158
       Atelectasis          0.3682       0.5108         0.3349              7               30
     Calcification          0.2155       0.4674         0.2099              2               89
      Cardiomegaly          0.2939       0.2882         0.3247            388               72
     Consolidation          0.3501       0.4631         0.2764             28               43
               ILD          0.4941       0.6136         0.4521             20               57
      Infiltration          0.3111       0.4705         0.2763             22              101
      Lung Opacity          0.2616       0.3620         0.1925            108              157
       Nodule/Mass          0.1883       0.0795         0.1903              3              163
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main
Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2326       0.2330         0.2303            514               97
       Atelectasis          0.3712       0.3683         0.3730             15               23
     Calcification          0.2294       0.8209         0.2228              1               90
      Cardiomegaly          0.3149       0.3137         0.3171            300              160
     Consolidation          0.3390       0.4380         0.2428             35               36
               ILD          0.4289       0.5211         0.3390             38               39
      Infiltration          0.3042       0.3633         0.2401             64               59
      Lung Opacity          0.2575       0.3254         0.1959            126              139
       Nodule/Mass          0.1733       0.2046         0.1723              5              160
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main
Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2263       0.2254         0.2287            430              182
       Atelectasis          0.3550       0.4863         0.3064             10               27
     Calcification          0.2483       0.2965         0.2472              2               88
      Cardiomegaly          0.2740       0.2758         0.2676            358              102
     Consolidation          0.3733       0.4417         0.3338             26               45
               ILD          0.4987       0.5453         0.4876             15               63
      Infiltration          0.3115       0.4702         0.2424             37               85
      Lung Opacity          0.2661       0.3532         0.1737            136              128
       Nodule/Mass          0.1971       0.1506         0.2078             31              134
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main
Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2401       0.2432         0.2179            544               77
       Atelectasis          0.3441       0.3954         0.3300              8               29
     Calcification          0.1994       0.2624         0.1973              3               87
      Cardiomegaly          0.3078       0.3038         0.3330            397               63
     Consolidation          0.3279       0.4731         0.2738             19               51
               ILD          0.4778       0.5430         0.4565             19               58
      Infiltration          0.3018       0.4245         0.2605             31               92
      Lung Opacity          0.2646       0.3672         0.1978            104              160
       Nodule/Mass          0.2027       0.2342         0.1986             19              146
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main
Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2253       0.2287         0.2102            501              111
       Atelectasis          0.3389       0.4402         0.3266              4               33
     Calcification          0.2005       0.5720         0.1963              1               89
      Cardiomegaly          0.3248       0.3249         0.3244            410               50
     Consolidation          0.3485       0.4248         0.2525             39               31
               ILD          0.4434       0.5902         0.4239              9               68
      Infiltration          0.3190       0.4642         0.2455             41               81
      Lung Opacity          0.2644       0.3250         0.1808            153              111
       Nodule/Mass          0.2060       0.2411         0.1975             32              133
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main
Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2335       0.2404         0.2136            453              158
       Atelectasis          0.3682       0.5108         0.3349              7               30
     Calcification          0.2155       0.4673         0.2099              2               89
      Cardiomegaly          0.2939       0.2882         0.3247            388               72
     Consolidation          0.3501       0.4631         0.2764             28               43
               ILD          0.4941       0.6136         0.4521             20               57
      Infiltration          0.3111       0.4705         0.2763             22              101
      Lung Opacity          0.2616       0.3620         0.1925            108              157
       Nodule/Mass          0.1883       0.0795         0.1903              3              163
  

In [9]:
from libraries_multilabel.energyPointGame import energy_point_game_mask
from libraries_multilabel.bcosconv2d import NormedConv2d

import random
import numpy as np
import torch
import pandas as pd
import os
from PIL import Image
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from libraries_multilabel.bcoslinear import BcosLinear
from libraries_multilabel.MultiLabelExplanationWrapper import MultiLabelModelWrapper
from libraries_multilabel.MultiLabelDatasets import MultiLabelDatasetID
from pooling.blur_bcosconv2d import ModifiedBcosConv2d


# Configuration
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path configurations
image_folder = r"D:\vinbigdata-chest-xray-abnormalities-detection\train_png_224"
csv_path_boxes = r"D:\vinbigdata-chest-xray-abnormalities-detection\train224.csv"
csv_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\multilabel_dataset.csv"
splits_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\vinbigdata_5fold_splits.pkl"

# Initialize datasets and transforms
data = pd.read_csv(csv_path)
data_boxes = pd.read_csv(csv_path_boxes)
with open(splits_path, 'rb') as f:
    splits = pickle.load(f)
    
transform = transforms.Compose([transforms.ToTensor()])
class_names = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis"]

class_names_extended = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis", "Average"]

seeds = [0, 1]
seed_results = {
    name: {
        'proportions': [],
        'correct': [],
        'incorrect': [],
    } for name in class_names
}

seed_results_avg = { 
        'proportions': [],
        'correct': [],
        'incorrect': []
        }

for seed in seeds:
    print(f"\n=== Seed {seed+1} Results ===")
    all_fold_results = []
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    fold_average_results = {name: {
            'avg_proportion': [],
            'avg_correct': [],
            'avg_incorrect': [],
            'correct_count': [],
            'incorrect_count': []
        } for name in class_names}

    fold_avg_row_results = {'proportions': [],
        'correct': [],
        'incorrect': []
        }

    # Process all 5 folds
    for fold in range(5):
        print(f"Processing fold {fold+1} of 5")
        model_path = fr"C:\Users\Admin\Documents\MasterThesis\results\VinBigData\ResNet_Bcos_Blur\light_oversamp\seed_{seed}\pneumonia_detection_model_resnet_bcos_bestf1_{fold+1}.pth"
        
        # Initialize results storage        
        fold_all_outputs = {name: {
            'proportions': [],
            'correct': [],
            'incorrect': [],
            'correct_count': 0,
            'incorrect_count': 0
        } for name in class_names} # need to be averaged for final_fold_result


        model = torch.hub.load('B-cos/B-cos-v2', 'resnet50', pretrained=True)
        model.layer2[0].conv2 = ModifiedBcosConv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), b=2)
        model.layer2[0].downsample[0] = ModifiedBcosConv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), b=2)

        model.layer3[0].conv2 = ModifiedBcosConv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), b=2)
        model.layer3[0].downsample[0] = ModifiedBcosConv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), b=2)

        model.layer4[0].conv2 = ModifiedBcosConv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), b=2)
        model.layer4[0].downsample[0] = ModifiedBcosConv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), b=2)
        model.fc.linear = NormedConv2d(2048, 14, kernel_size=(1, 1), bias=False)
        model.load_state_dict(torch.load(model_path))
        model.to(device)

        
        val_idx = splits[fold][1]
        val_data = data.iloc[val_idx]
        val_dataset = MultiLabelDatasetID(val_data, image_folder, transform=transform)
        val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

        multiLabelWrapper = MultiLabelModelWrapper(model)
        model.eval()
        multiLabelWrapper.model.eval()

        with torch.enable_grad():
            for images, labels, image_ids in val_loader:
                labels = labels.to(device)
                six_channel_images = []
                
                # Convert images to 6-channel format
                for img_tensor in images:
                    numpy_image = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    pil_image = Image.fromarray(numpy_image)
                    transformed_image = model.transform(pil_image)
                    six_channel_images.append(transformed_image)
                
                six_channel_images = torch.stack(six_channel_images).to(device)
                
                mask = torch.zeros((224, 224), dtype=torch.int32)
                for image, label, image_id in zip(six_channel_images, labels, image_ids):
                    image = image[None]  # Add batch dimension
                    expl = multiLabelWrapper.explain(image)
                    
                    # Process each class
                    for class_idx, class_name in enumerate(class_names):
                        filtered_rows = data_boxes[(data_boxes['image_id'] == image_id)]
                        labels_row = data[data['image_id'] == image_id]
                        
                        if not filtered_rows.empty and not labels_row.empty and labels_row[class_name].iloc[0] == 1:
                            prediction = expl["binary_predictions"][0][class_idx]
                            contribution_map = expl['contribution_maps'][class_idx]
                            contribution_map[contribution_map < 0] = 0
                            
                            # Create mask from bounding boxes
                            mask.zero_()
                            for _, row in filtered_rows.iterrows():
                                x_min = int(row["x_min"])
                                y_min = int(row["y_min"])
                                x_max = int(row["x_max"])
                                y_max = int(row["y_max"])
                                mask[y_min:y_max, x_min:x_max] = 1.0
                            
                            # Calculate energy point game metric
                            ebpg_result = energy_point_game_mask(mask.cpu(), contribution_map.cpu())
                            
                            # Update results
                            fold_all_outputs[class_name]['proportions'].append(ebpg_result)
                            if prediction == 1:
                                fold_all_outputs[class_name]['correct'].append(ebpg_result)
                                fold_all_outputs[class_name]['correct_count'] += 1
                            else:
                                fold_all_outputs[class_name]['incorrect'].append(ebpg_result)
                                fold_all_outputs[class_name]['incorrect_count'] += 1
                                
                        
        for class_name in class_names:
            class_fold_result = fold_all_outputs[class_name]
            
            avg_total = np.mean(class_fold_result['proportions']) if class_fold_result['proportions'] else 0
            avg_correct = np.mean(class_fold_result['correct']) if class_fold_result['correct'] else 0
            avg_incorrect = np.mean(class_fold_result['incorrect']) if class_fold_result['incorrect'] else 0
            
            fold_average_results[class_name]['avg_proportion'].append(round(avg_total, 4))
            fold_average_results[class_name]['avg_correct'].append(round(avg_correct, 4))
            fold_average_results[class_name]['avg_incorrect'].append(round(avg_incorrect, 4))
            fold_average_results[class_name]['correct_count'].append(class_fold_result['correct_count'])
            fold_average_results[class_name]['incorrect_count'].append(class_fold_result['incorrect_count'])

        rows = []
        for class_name in class_names:
            rows.append({
                'Class': class_name,
                'avg_proportion': fold_average_results[class_name]['avg_proportion'][fold],
                'avg_correct': fold_average_results[class_name]['avg_correct'][fold],
                'avg_incorrect': fold_average_results[class_name]['avg_incorrect'][fold],
                'correct_count': fold_average_results[class_name]['correct_count'][fold],
                'incorrect_count': fold_average_results[class_name]['incorrect_count'][fold]
            })

        df = pd.DataFrame(rows)

        # add an average row
        avg_row = {
            'Class': 'Average',
            'avg_proportion': round(df['avg_proportion'].mean(), 4),
            'avg_correct': round(df['avg_correct'].mean(), 4),
            'avg_incorrect': round(df['avg_incorrect'].mean(), 4),
            'correct_count': int(df['correct_count'].sum()),        
            'incorrect_count': int(df['incorrect_count'].sum()),    
        }
        fold_avg_row_results['proportions'].append(avg_row['avg_proportion'])
        fold_avg_row_results['correct'].append(avg_row['avg_correct'])
        fold_avg_row_results['incorrect'].append(avg_row['avg_incorrect'])

        
        df = pd.concat([
            df,
            pd.DataFrame([avg_row])  # Wrap in list to create 1-row DataFrame
        ], ignore_index=True)

        print(f"\nEnergy-Based Pointing Game Results for Fold {fold+1}:")
        print(df.to_string(index=False))        
        
        if fold == 4:
            for class_name in class_names:
                seed_results[class_name]['proportions'].append(np.mean(fold_average_results[class_name]['avg_proportion']))
                seed_results[class_name]['correct'].append(np.mean(fold_average_results[class_name]['avg_correct']))
                seed_results[class_name]['incorrect'].append(np.mean(fold_average_results[class_name]['avg_incorrect']))   
                
            seed_results_avg['proportions'].append(np.mean(fold_avg_row_results['proportions']))
            seed_results_avg['correct'].append(np.mean(fold_avg_row_results['correct']))
            seed_results_avg['incorrect'].append(np.mean(fold_avg_row_results['incorrect']))

                
        
                    

def calculate_stats(values):
    return f"{np.mean(values):.4f} ± {np.std(values):.4f}"

all_props = []
all_corrs = []
all_incorrects = []

print("\n\n=== Final Cross-Seed Averages ===")
for class_name in class_names:
    print(f"\nClass: {class_name}")
    print("Proportion:", calculate_stats(seed_results[class_name]['proportions']))
    print("Correct   :", calculate_stats(seed_results[class_name]['correct']))
    print("Incorrect :", calculate_stats(seed_results[class_name]['incorrect']))

print("\n=== Final Cross-Seed Average (across all classes) ===")
print("Proportion:", calculate_stats(seed_results_avg['proportions']))
print("Correct   :", calculate_stats(seed_results_avg['correct']))
print("Incorrect :", calculate_stats(seed_results_avg['incorrect']))


=== Seed 1 Results ===
Processing fold 1 of 5


Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2033       0.2076         0.1722            537               74
       Atelectasis          0.4348       0.4414         0.4309             14               24
     Calcification          0.2268       0.3246         0.2105             13               78
      Cardiomegaly          0.2954       0.2949         0.2995            409               51
     Consolidation          0.3871       0.4597         0.3605             19               52
               ILD          0.4454       0.5532         0.4125             18               59
      Infiltration          0.3122       0.3551         0.2837             49               74
      Lung Opacity          0.2633       0.3413         0.1947            124              141
       Nodule/Mass          0.1624       0.1566         0.1653             55              110
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2096       0.2063         0.2362            544               68
       Atelectasis          0.3476       0.3986         0.3165             14               23
     Calcification          0.2153       0.2536         0.2021             23               67
      Cardiomegaly          0.2955       0.2898         0.3317            398               62
     Consolidation          0.3807       0.4024         0.3682             26               45
               ILD          0.4844       0.5940         0.4160             30               48
      Infiltration          0.2790       0.3970         0.2389             31               91
      Lung Opacity          0.2603       0.3409         0.1720            138              126
       Nodule/Mass          0.1819       0.1725         0.1847             38              127
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2189       0.2154         0.2408            536               85
       Atelectasis          0.3829       0.4184         0.3659             12               25
     Calcification          0.1951       0.2551         0.1822             16               74
      Cardiomegaly          0.3158       0.3157         0.3165            407               53
     Consolidation          0.3755       0.5077         0.2924             27               43
               ILD          0.4739       0.5624         0.4262             27               50
      Infiltration          0.2848       0.3962         0.2312             40               83
      Lung Opacity          0.2606       0.3403         0.1900            124              140
       Nodule/Mass          0.2034       0.2084         0.2011             50              115
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2238       0.2244         0.2194            536               76
       Atelectasis          0.3359       0.4293         0.2853             13               24
     Calcification          0.1994       0.2190         0.1934             21               69
      Cardiomegaly          0.3372       0.3346         0.3562            404               56
     Consolidation          0.3605       0.4368         0.3155             26               44
               ILD          0.4246       0.5400         0.3813             21               56
      Infiltration          0.3110       0.3768         0.2752             43               79
      Lung Opacity          0.2490       0.3029         0.2061            117              147
       Nodule/Mass          0.2052       0.1818         0.2132             42              123
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2273       0.2277         0.2233            555               56
       Atelectasis          0.3804       0.3527         0.3921             11               26
     Calcification          0.2120       0.2612         0.1982             20               71
      Cardiomegaly          0.3231       0.3225         0.3285            413               47
     Consolidation          0.3997       0.4382         0.3894             15               56
               ILD          0.4818       0.6141         0.4255             23               54
      Infiltration          0.3109       0.4135         0.2749             32               91
      Lung Opacity          0.2570       0.3191         0.2200             99              166
       Nodule/Mass          0.1823       0.1642         0.1905             52              114
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2225       0.2236         0.2146            537               74
       Atelectasis          0.4310       0.4591         0.4081             17               21
     Calcification          0.2251       0.2482         0.1982             49               42
      Cardiomegaly          0.3372       0.3344         0.3530            391               69
     Consolidation          0.3763       0.4303         0.3521             22               49
               ILD          0.4188       0.4810         0.3870             26               51
      Infiltration          0.3273       0.4537         0.2951             25               98
      Lung Opacity          0.2454       0.2899         0.1978            137              128
       Nodule/Mass          0.1732       0.1619         0.1775             45              120
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2245       0.2218         0.2480            550               62
       Atelectasis          0.3768       0.4288         0.3326             17               20
     Calcification          0.2400       0.4051         0.2146             12               78
      Cardiomegaly          0.3101       0.3069         0.3347            406               54
     Consolidation          0.4279       0.4817         0.3693             37               34
               ILD          0.5118       0.5531         0.4956             22               56
      Infiltration          0.3107       0.3861         0.2725             41               81
      Lung Opacity          0.2826       0.3938         0.2033            110              154
       Nodule/Mass          0.2033       0.2070         0.2022             38              127
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2342       0.2335         0.2402            558               63
       Atelectasis          0.3862       0.4349         0.3599             13               24
     Calcification          0.2167       0.1912         0.2214             14               76
      Cardiomegaly          0.3265       0.3286         0.3102            408               52
     Consolidation          0.3988       0.5232         0.3490             20               50
               ILD          0.5019       0.5666         0.4744             23               54
      Infiltration          0.3141       0.4199         0.2720             35               88
      Lung Opacity          0.2781       0.3755         0.2097            109              155
       Nodule/Mass          0.2216       0.1940         0.2268             26              139
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2153       0.2142         0.2227            531               81
       Atelectasis          0.3291       0.4624         0.2727             11               26
     Calcification          0.1846       0.2579         0.1688             16               74
      Cardiomegaly          0.3087       0.3044         0.3306            386               74
     Consolidation          0.3748       0.4682         0.3196             26               44
               ILD          0.4353       0.5131         0.3857             30               47
      Infiltration          0.3094       0.3792         0.2727             42               80
      Lung Opacity          0.2486       0.3071         0.1874            135              129
       Nodule/Mass          0.2029       0.2008         0.2039             52              113
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.2138       0.2149         0.2032            555               56
       Atelectasis          0.3909       0.4221         0.3777             11               26
     Calcification          0.2197       0.3048         0.2092             10               81
      Cardiomegaly          0.2877       0.2801         0.3406            402               58
     Consolidation          0.3885       0.5240         0.3491             16               55
               ILD          0.4542       0.5861         0.4079             20               57
      Infiltration          0.3055       0.4261         0.2613             33               90
      Lung Opacity          0.2455       0.3184         0.1921            112              153
       Nodule/Mass          0.1817       0.1567         0.1870             29              137
  

# Recall Calculation
- Normal FLC

In [1]:
from libraries_multilabel.energyPointGame import energy_point_game_recall
from libraries_multilabel.bcosconv2d import NormedConv2d

import random
import numpy as np
import torch
import pandas as pd
import os
from PIL import Image
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from libraries_multilabel.bcoslinear import BcosLinear
from libraries_multilabel.MultiLabelExplanationWrapper import MultiLabelModelWrapper
from libraries_multilabel.MultiLabelDatasets import MultiLabelDatasetID

# Configuration
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path configurations
image_folder = r"D:\vinbigdata-chest-xray-abnormalities-detection\train_png_224"
csv_path_boxes = r"D:\vinbigdata-chest-xray-abnormalities-detection\train224.csv"
csv_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\multilabel_dataset.csv"
splits_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\vinbigdata_5fold_splits.pkl"

# Initialize datasets and transforms
data = pd.read_csv(csv_path)
data_boxes = pd.read_csv(csv_path_boxes)
with open(splits_path, 'rb') as f:
    splits = pickle.load(f)
    
transform = transforms.Compose([transforms.ToTensor()])
class_names = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis"]

class_names_extended = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis", "Average"]

seeds = [0, 1]
seed_results = {
    name: {
        'proportions': [],
        'correct': [],
        'incorrect': [],
    } for name in class_names
}

seed_results_avg = { 
        'proportions': [],
        'correct': [],
        'incorrect': []
        }

for seed in seeds:
    print(f"\n=== Seed {seed+1} Results ===")
    all_fold_results = []
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    fold_average_results = {name: {
            'avg_proportion': [],
            'avg_correct': [],
            'avg_incorrect': [],
            'correct_count': [],
            'incorrect_count': []
        } for name in class_names}

    fold_avg_row_results = {'proportions': [],
        'correct': [],
        'incorrect': []
        }

    # Process all 5 folds
    for fold in range(5):
        print(f"Processing fold {fold+1} of 5")
        model_path = fr"C:\Users\Admin\Documents\MasterThesis\results\VinBigData\ResNet_Bcos\no_nosamp\seed_{seed}\pneumonia_detection_model_resnet_bcos_bestf1_{fold+1}.pth"
        
        # Initialize results storage        
        fold_all_outputs = {name: {
            'proportions': [],
            'correct': [],
            'incorrect': [],
            'correct_count': 0,
            'incorrect_count': 0
        } for name in class_names} # need to be averaged for final_fold_result


        model = torch.hub.load('B-cos/B-cos-v2', 'resnet50', pretrained=True)
        model.fc.linear = NormedConv2d(2048, 14, kernel_size=(1, 1), bias=False)
        model.load_state_dict(torch.load(model_path))
        model.to(device)

        
        val_idx = splits[fold][1]
        val_data = data.iloc[val_idx]
        val_dataset = MultiLabelDatasetID(val_data, image_folder, transform=transform)
        val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

        multiLabelWrapper = MultiLabelModelWrapper(model)
        model.eval()
        multiLabelWrapper.model.eval()

        with torch.enable_grad():
            for images, labels, image_ids in val_loader:
                labels = labels.to(device)
                six_channel_images = []
                
                # Convert images to 6-channel format
                for img_tensor in images:
                    numpy_image = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    pil_image = Image.fromarray(numpy_image)
                    transformed_image = model.transform(pil_image)
                    six_channel_images.append(transformed_image)
                
                six_channel_images = torch.stack(six_channel_images).to(device)
                
                mask = torch.zeros((224, 224), dtype=torch.int32)
                for image, label, image_id in zip(six_channel_images, labels, image_ids):
                    image = image[None]  # Add batch dimension
                    expl = multiLabelWrapper.explain(image)
                    
                    # Process each class
                    for class_idx, class_name in enumerate(class_names):
                        filtered_rows = data_boxes[(data_boxes['image_id'] == image_id)]
                        labels_row = data[data['image_id'] == image_id]
                        
                        if not filtered_rows.empty and not labels_row.empty and labels_row[class_name].iloc[0] == 1:
                            prediction = expl["binary_predictions"][0][class_idx]
                            contribution_map = expl['contribution_maps'][class_idx]
                                                        
                            # Create mask from bounding boxes
                            mask.zero_()
                            for _, row in filtered_rows.iterrows():
                                x_min = int(row["x_min"])
                                y_min = int(row["y_min"])
                                x_max = int(row["x_max"])
                                y_max = int(row["y_max"])
                                mask[y_min:y_max, x_min:x_max] = 1.0
                            
                            # Calculate energy point game metric
                            ebpg_result = energy_point_game_recall(mask.cpu(), contribution_map.cpu())
                            
                            # Update results
                            fold_all_outputs[class_name]['proportions'].append(ebpg_result)
                            if prediction == 1:
                                fold_all_outputs[class_name]['correct'].append(ebpg_result)
                                fold_all_outputs[class_name]['correct_count'] += 1
                            else:
                                fold_all_outputs[class_name]['incorrect'].append(ebpg_result)
                                fold_all_outputs[class_name]['incorrect_count'] += 1
                                
                        
        for class_name in class_names:
            class_fold_result = fold_all_outputs[class_name]
            
            avg_total = np.mean(class_fold_result['proportions']) if class_fold_result['proportions'] else 0
            avg_correct = np.mean(class_fold_result['correct']) if class_fold_result['correct'] else 0
            avg_incorrect = np.mean(class_fold_result['incorrect']) if class_fold_result['incorrect'] else 0
            
            fold_average_results[class_name]['avg_proportion'].append(round(avg_total, 4))
            fold_average_results[class_name]['avg_correct'].append(round(avg_correct, 4))
            fold_average_results[class_name]['avg_incorrect'].append(round(avg_incorrect, 4))
            fold_average_results[class_name]['correct_count'].append(class_fold_result['correct_count'])
            fold_average_results[class_name]['incorrect_count'].append(class_fold_result['incorrect_count'])

        rows = []
        for class_name in class_names:
            rows.append({
                'Class': class_name,
                'avg_proportion': fold_average_results[class_name]['avg_proportion'][fold],
                'avg_correct': fold_average_results[class_name]['avg_correct'][fold],
                'avg_incorrect': fold_average_results[class_name]['avg_incorrect'][fold],
                'correct_count': fold_average_results[class_name]['correct_count'][fold],
                'incorrect_count': fold_average_results[class_name]['incorrect_count'][fold]
            })

        df = pd.DataFrame(rows)

        # add an average row
        avg_row = {
            'Class': 'Average',
            'avg_proportion': round(df['avg_proportion'].mean(), 4),
            'avg_correct': round(df['avg_correct'].mean(), 4),
            'avg_incorrect': round(df['avg_incorrect'].mean(), 4),
            'correct_count': int(df['correct_count'].sum()),        
            'incorrect_count': int(df['incorrect_count'].sum()),    
        }
        fold_avg_row_results['proportions'].append(avg_row['avg_proportion'])
        fold_avg_row_results['correct'].append(avg_row['avg_correct'])
        fold_avg_row_results['incorrect'].append(avg_row['avg_incorrect'])

        
        df = pd.concat([
            df,
            pd.DataFrame([avg_row])  # Wrap in list to create 1-row DataFrame
        ], ignore_index=True)

        print(f"\nEnergy-Based Pointing Game Results for Fold {fold+1}:")
        print(df.to_string(index=False))        
        
        if fold == 4:
            for class_name in class_names:
                seed_results[class_name]['proportions'].append(np.mean(fold_average_results[class_name]['avg_proportion']))
                seed_results[class_name]['correct'].append(np.mean(fold_average_results[class_name]['avg_correct']))
                seed_results[class_name]['incorrect'].append(np.mean(fold_average_results[class_name]['avg_incorrect']))   
                
            seed_results_avg['proportions'].append(np.mean(fold_avg_row_results['proportions']))
            seed_results_avg['correct'].append(np.mean(fold_avg_row_results['correct']))
            seed_results_avg['incorrect'].append(np.mean(fold_avg_row_results['incorrect']))

                
        
                    

def calculate_stats(values):
    return f"{np.mean(values):.4f} ± {np.std(values):.4f}"

all_props = []
all_corrs = []
all_incorrects = []

print("\n\n=== Final Cross-Seed Averages ===")
for class_name in class_names:
    print(f"\nClass: {class_name}")
    print("Proportion:", calculate_stats(seed_results[class_name]['proportions']))
    print("Correct   :", calculate_stats(seed_results[class_name]['correct']))
    print("Incorrect :", calculate_stats(seed_results[class_name]['incorrect']))

print("\n=== Final Cross-Seed Average (across all classes) ===")
print("Proportion:", calculate_stats(seed_results_avg['proportions']))
print("Correct   :", calculate_stats(seed_results_avg['correct']))
print("Incorrect :", calculate_stats(seed_results_avg['incorrect']))



=== Seed 1 Results ===
Processing fold 1 of 5


Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9992       0.9999         0.9964            497              114
       Atelectasis          0.9997       0.9998         0.9996             14               24
     Calcification          0.9982       0.9998         0.9982              1               90
      Cardiomegaly          0.9988       0.9999         0.9962            329              131
     Consolidation          0.9998       0.9998         0.9999             24               47
               ILD          0.9990       0.9999         0.9988             11               66
      Infiltration          0.9998       0.9999         0.9997             33               90
      Lung Opacity          0.9999       0.9998         0.9999            124              141
       Nodule/Mass          0.9999       0.9998         0.9999             23              142
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9998       0.9998         0.9998            475              137
       Atelectasis          0.9542       0.9998         0.9453              6               31
     Calcification          0.9998       0.9997         0.9998             17               73
      Cardiomegaly          0.9998       0.9998         0.9997            355              105
     Consolidation          0.9997       0.9998         0.9996             11               60
               ILD          0.9993       0.9998         0.9992             11               67
      Infiltration          0.9983       0.9998         0.9978             32               90
      Lung Opacity          0.9994       0.9997         0.9992            102              162
       Nodule/Mass          0.9997       0.9997         0.9997             26              139
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9998       0.9998         0.9998            533               88
       Atelectasis          0.9988       0.9999         0.9983             11               26
     Calcification          0.9998       0.9996         0.9998              6               84
      Cardiomegaly          0.9998       0.9998         0.9998            428               32
     Consolidation          0.9998       0.9998         0.9998             23               47
               ILD          0.9998       0.9998         0.9998             22               55
      Infiltration          0.9999       0.9999         0.9998             53               70
      Lung Opacity          0.9998       0.9998         0.9998            115              149
       Nodule/Mass          0.9998       0.9997         0.9998             12              153
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9984       0.9998         0.9883            536               76
       Atelectasis          0.9998       0.9996         0.9998              9               28
     Calcification          0.9998       0.0000         0.9998              0               90
      Cardiomegaly          0.9989       0.9999         0.9938            387               73
     Consolidation          0.9998       0.9997         0.9998             20               50
               ILD          0.9868       0.9999         0.9855              7               70
      Infiltration          0.9999       0.9999         0.9999             37               85
      Lung Opacity          0.9997       0.9998         0.9995            162              102
       Nodule/Mass          0.9999       0.9999         0.9999             59              106
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9998       0.9998         0.9998            538               73
       Atelectasis          0.9641       0.9999         0.9509             10               27
     Calcification          0.9998       0.9998         0.9998              5               86
      Cardiomegaly          0.9998       0.9998         0.9997            353              107
     Consolidation          0.9946       0.9998         0.9931             16               55
               ILD          0.9995       0.9998         0.9994             16               61
      Infiltration          0.9996       0.9999         0.9995             37               86
      Lung Opacity          0.9998       0.9998         0.9998             88              177
       Nodule/Mass          0.9998       0.9998         0.9998             19              147
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9989       0.9998         0.9926            535               76
       Atelectasis          0.9980       0.9995         0.9962             20               18
     Calcification          0.9996       0.9995         0.9996             11               80
      Cardiomegaly          0.9996       0.9998         0.9988            373               87
     Consolidation          0.9995       0.9996         0.9995             28               43
               ILD          0.9975       0.9998         0.9970             15               62
      Infiltration          0.9984       0.9998         0.9979             36               87
      Lung Opacity          0.9996       0.9996         0.9995             81              184
       Nodule/Mass          0.9998       0.9998         0.9998             34              131
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9993       0.9997         0.9968            509              103
       Atelectasis          0.9457       0.9998         0.9391              4               33
     Calcification          0.9907       0.9998         0.9902              5               85
      Cardiomegaly          0.9975       0.9998         0.9908            345              115
     Consolidation          0.9995       0.9997         0.9995             22               49
               ILD          0.9898       0.9994         0.9887              8               70
      Infiltration          0.9911       0.9999         0.9890             24               98
      Lung Opacity          0.9955       0.9997         0.9919            120              144
       Nodule/Mass          0.9997       0.9997         0.9997             39              126
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9997       0.9997         0.9998            479              142
       Atelectasis          0.9978       0.9997         0.9971             10               27
     Calcification          0.9997       0.9996         0.9998             24               66
      Cardiomegaly          0.9997       0.9997         0.9997            370               90
     Consolidation          0.9996       0.9995         0.9996             17               53
               ILD          0.9996       0.9997         0.9996             29               48
      Infiltration          0.9997       0.9998         0.9997             39               84
      Lung Opacity          0.9995       0.9996         0.9995             62              202
       Nodule/Mass          0.9997       0.9997         0.9997             30              135
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9998       0.9998         0.9998            443              169
       Atelectasis          0.9998       0.9998         0.9998             15               22
     Calcification          0.9998       0.9998         0.9998             10               80
      Cardiomegaly          0.9998       0.9998         0.9997            271              189
     Consolidation          0.9998       0.9998         0.9998             40               30
               ILD          0.9999       0.9999         0.9999             18               59
      Infiltration          0.9999       0.9999         0.9998             75               47
      Lung Opacity          0.9998       0.9998         0.9998            143              121
       Nodule/Mass          0.9998       0.9998         0.9998             63              102
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9999       0.9999         0.9999            507              104
       Atelectasis          0.9728       0.9999         0.9563             14               23
     Calcification          0.9999       0.9997         0.9999              3               88
      Cardiomegaly          0.9999       1.0000         0.9999            342              118
     Consolidation          0.9998       0.9998         0.9999             29               42
               ILD          0.9998       0.9999         0.9997             33               44
      Infiltration          0.9999       0.9999         0.9999             75               48
      Lung Opacity          0.9999       0.9999         0.9999            143              122
       Nodule/Mass          0.9999       0.9999         0.9999             36              130
  

In [2]:
from libraries_multilabel.energyPointGame import energy_point_game_recall
from libraries_multilabel.bcosconv2d import NormedConv2d

import random
import numpy as np
import torch
import pandas as pd
import os
from PIL import Image
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from libraries_multilabel.bcoslinear import BcosLinear
from libraries_multilabel.MultiLabelExplanationWrapper import MultiLabelModelWrapper
from libraries_multilabel.MultiLabelDatasets import MultiLabelDatasetID

# Configuration
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path configurations
image_folder = r"D:\vinbigdata-chest-xray-abnormalities-detection\train_png_224"
csv_path_boxes = r"D:\vinbigdata-chest-xray-abnormalities-detection\train224.csv"
csv_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\multilabel_dataset.csv"
splits_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\vinbigdata_5fold_splits.pkl"

# Initialize datasets and transforms
data = pd.read_csv(csv_path)
data_boxes = pd.read_csv(csv_path_boxes)
with open(splits_path, 'rb') as f:
    splits = pickle.load(f)
    
transform = transforms.Compose([transforms.ToTensor()])
class_names = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis"]

class_names_extended = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis", "Average"]

seeds = [0, 1]
seed_results = {
    name: {
        'proportions': [],
        'correct': [],
        'incorrect': [],
    } for name in class_names
}

seed_results_avg = { 
        'proportions': [],
        'correct': [],
        'incorrect': []
        }

for seed in seeds:
    print(f"\n=== Seed {seed+1} Results ===")
    all_fold_results = []
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    fold_average_results = {name: {
            'avg_proportion': [],
            'avg_correct': [],
            'avg_incorrect': [],
            'correct_count': [],
            'incorrect_count': []
        } for name in class_names}

    fold_avg_row_results = {'proportions': [],
        'correct': [],
        'incorrect': []
        }

    # Process all 5 folds
    for fold in range(5):
        print(f"Processing fold {fold+1} of 5")
        model_path = fr"C:\Users\Admin\Documents\MasterThesis\results\VinBigData\ResNet_Bcos\light_oversamp\seed_{seed}\pneumonia_detection_model_resnet_bcos_bestf1_{fold+1}.pth"
        
        # Initialize results storage        
        fold_all_outputs = {name: {
            'proportions': [],
            'correct': [],
            'incorrect': [],
            'correct_count': 0,
            'incorrect_count': 0
        } for name in class_names} # need to be averaged for final_fold_result


        model = torch.hub.load('B-cos/B-cos-v2', 'resnet50', pretrained=True)
        model.fc.linear = NormedConv2d(2048, 14, kernel_size=(1, 1), bias=False)
        model.load_state_dict(torch.load(model_path))
        model.to(device)

        
        val_idx = splits[fold][1]
        val_data = data.iloc[val_idx]
        val_dataset = MultiLabelDatasetID(val_data, image_folder, transform=transform)
        val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

        multiLabelWrapper = MultiLabelModelWrapper(model)
        model.eval()
        multiLabelWrapper.model.eval()

        with torch.enable_grad():
            for images, labels, image_ids in val_loader:
                labels = labels.to(device)
                six_channel_images = []
                
                # Convert images to 6-channel format
                for img_tensor in images:
                    numpy_image = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    pil_image = Image.fromarray(numpy_image)
                    transformed_image = model.transform(pil_image)
                    six_channel_images.append(transformed_image)
                
                six_channel_images = torch.stack(six_channel_images).to(device)
                
                mask = torch.zeros((224, 224), dtype=torch.int32)
                for image, label, image_id in zip(six_channel_images, labels, image_ids):
                    image = image[None]  # Add batch dimension
                    expl = multiLabelWrapper.explain(image)
                    
                    # Process each class
                    for class_idx, class_name in enumerate(class_names):
                        filtered_rows = data_boxes[(data_boxes['image_id'] == image_id)]
                        labels_row = data[data['image_id'] == image_id]
                        
                        if not filtered_rows.empty and not labels_row.empty and labels_row[class_name].iloc[0] == 1:
                            prediction = expl["binary_predictions"][0][class_idx]
                            contribution_map = expl['contribution_maps'][class_idx]
                                                        
                            # Create mask from bounding boxes
                            mask.zero_()
                            for _, row in filtered_rows.iterrows():
                                x_min = int(row["x_min"])
                                y_min = int(row["y_min"])
                                x_max = int(row["x_max"])
                                y_max = int(row["y_max"])
                                mask[y_min:y_max, x_min:x_max] = 1.0
                            
                            # Calculate energy point game metric
                            ebpg_result = energy_point_game_recall(mask.cpu(), contribution_map.cpu())
                            
                            # Update results
                            fold_all_outputs[class_name]['proportions'].append(ebpg_result)
                            if prediction == 1:
                                fold_all_outputs[class_name]['correct'].append(ebpg_result)
                                fold_all_outputs[class_name]['correct_count'] += 1
                            else:
                                fold_all_outputs[class_name]['incorrect'].append(ebpg_result)
                                fold_all_outputs[class_name]['incorrect_count'] += 1
                                
                        
        for class_name in class_names:
            class_fold_result = fold_all_outputs[class_name]
            
            avg_total = np.mean(class_fold_result['proportions']) if class_fold_result['proportions'] else 0
            avg_correct = np.mean(class_fold_result['correct']) if class_fold_result['correct'] else 0
            avg_incorrect = np.mean(class_fold_result['incorrect']) if class_fold_result['incorrect'] else 0
            
            fold_average_results[class_name]['avg_proportion'].append(round(avg_total, 4))
            fold_average_results[class_name]['avg_correct'].append(round(avg_correct, 4))
            fold_average_results[class_name]['avg_incorrect'].append(round(avg_incorrect, 4))
            fold_average_results[class_name]['correct_count'].append(class_fold_result['correct_count'])
            fold_average_results[class_name]['incorrect_count'].append(class_fold_result['incorrect_count'])

        rows = []
        for class_name in class_names:
            rows.append({
                'Class': class_name,
                'avg_proportion': fold_average_results[class_name]['avg_proportion'][fold],
                'avg_correct': fold_average_results[class_name]['avg_correct'][fold],
                'avg_incorrect': fold_average_results[class_name]['avg_incorrect'][fold],
                'correct_count': fold_average_results[class_name]['correct_count'][fold],
                'incorrect_count': fold_average_results[class_name]['incorrect_count'][fold]
            })

        df = pd.DataFrame(rows)

        # add an average row
        avg_row = {
            'Class': 'Average',
            'avg_proportion': round(df['avg_proportion'].mean(), 4),
            'avg_correct': round(df['avg_correct'].mean(), 4),
            'avg_incorrect': round(df['avg_incorrect'].mean(), 4),
            'correct_count': int(df['correct_count'].sum()),        
            'incorrect_count': int(df['incorrect_count'].sum()),    
        }
        fold_avg_row_results['proportions'].append(avg_row['avg_proportion'])
        fold_avg_row_results['correct'].append(avg_row['avg_correct'])
        fold_avg_row_results['incorrect'].append(avg_row['avg_incorrect'])

        
        df = pd.concat([
            df,
            pd.DataFrame([avg_row])  # Wrap in list to create 1-row DataFrame
        ], ignore_index=True)

        print(f"\nEnergy-Based Pointing Game Results for Fold {fold+1}:")
        print(df.to_string(index=False))        
        
        if fold == 4:
            for class_name in class_names:
                seed_results[class_name]['proportions'].append(np.mean(fold_average_results[class_name]['avg_proportion']))
                seed_results[class_name]['correct'].append(np.mean(fold_average_results[class_name]['avg_correct']))
                seed_results[class_name]['incorrect'].append(np.mean(fold_average_results[class_name]['avg_incorrect']))   
                
            seed_results_avg['proportions'].append(np.mean(fold_avg_row_results['proportions']))
            seed_results_avg['correct'].append(np.mean(fold_avg_row_results['correct']))
            seed_results_avg['incorrect'].append(np.mean(fold_avg_row_results['incorrect']))

                
        
                    

def calculate_stats(values):
    return f"{np.mean(values):.4f} ± {np.std(values):.4f}"

all_props = []
all_corrs = []
all_incorrects = []

print("\n\n=== Final Cross-Seed Averages ===")
for class_name in class_names:
    print(f"\nClass: {class_name}")
    print("Proportion:", calculate_stats(seed_results[class_name]['proportions']))
    print("Correct   :", calculate_stats(seed_results[class_name]['correct']))
    print("Incorrect :", calculate_stats(seed_results[class_name]['incorrect']))

print("\n=== Final Cross-Seed Average (across all classes) ===")
print("Proportion:", calculate_stats(seed_results_avg['proportions']))
print("Correct   :", calculate_stats(seed_results_avg['correct']))
print("Incorrect :", calculate_stats(seed_results_avg['incorrect']))



=== Seed 1 Results ===
Processing fold 1 of 5


Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          1.0000       1.0000         0.9999            555               56
       Atelectasis          0.9991       0.9999         0.9987             15               23
     Calcification          0.9998       0.9996         0.9999             15               76
      Cardiomegaly          1.0000       1.0000         1.0000            412               48
     Consolidation          0.9999       0.9999         0.9998             25               46
               ILD          0.9998       1.0000         0.9997             17               60
      Infiltration          0.9998       1.0000         0.9998             47               76
      Lung Opacity          0.9999       0.9999         0.9999            104              161
       Nodule/Mass          1.0000       0.9999         1.0000             47              118
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9999       1.0000         0.9997            548               64
       Atelectasis          0.9960       0.9999         0.9939             13               24
     Calcification          1.0000       0.9999         1.0000             24               66
      Cardiomegaly          1.0000       1.0000         1.0000            415               45
     Consolidation          0.9999       0.9999         0.9999             23               48
               ILD          0.9999       0.9999         0.9999             23               55
      Infiltration          0.9999       1.0000         0.9998             41               81
      Lung Opacity          0.9999       0.9999         0.9999            125              139
       Nodule/Mass          0.9999       0.9999         0.9999             43              122
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          1.0000       1.0000         0.9999            542               79
       Atelectasis          0.9999       0.9999         0.9999              9               28
     Calcification          0.9999       0.9999         1.0000             13               77
      Cardiomegaly          1.0000       1.0000         0.9999            412               48
     Consolidation          0.9866       0.9998         0.9835             13               57
               ILD          0.9999       0.9999         0.9999             19               58
      Infiltration          0.9999       0.9999         0.9999             35               88
      Lung Opacity          0.9999       0.9999         0.9999            102              162
       Nodule/Mass          0.9999       0.9999         1.0000             32              133
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          1.0000       1.0000         1.0000            534               78
       Atelectasis          0.9971       1.0000         0.9961              9               28
     Calcification          1.0000       1.0000         0.9999             13               77
      Cardiomegaly          1.0000       1.0000         0.9999            386               74
     Consolidation          0.9999       0.9999         0.9999             22               48
               ILD          1.0000       0.9999         1.0000             19               58
      Infiltration          1.0000       1.0000         0.9999             41               81
      Lung Opacity          0.9999       0.9999         1.0000            118              146
       Nodule/Mass          0.9999       1.0000         0.9999             38              127
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          1.0000       1.0000         0.9999            551               60
       Atelectasis          0.9999       0.9999         0.9999             10               27
     Calcification          0.9999       0.9999         0.9999             20               71
      Cardiomegaly          0.9999       1.0000         0.9998            401               59
     Consolidation          0.9986       0.9997         0.9982             18               53
               ILD          0.9999       0.9999         0.9999             21               56
      Infiltration          0.9999       1.0000         0.9999             39               84
      Lung Opacity          0.9999       0.9999         0.9999            118              147
       Nodule/Mass          0.9999       0.9999         0.9999             44              122
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          1.0000       1.0000         0.9999            538               73
       Atelectasis          0.9997       0.9999         0.9996             15               23
     Calcification          0.9999       0.9999         0.9999             10               81
      Cardiomegaly          1.0000       1.0000         1.0000            354              106
     Consolidation          0.9999       0.9999         0.9999             18               53
               ILD          0.9999       0.9999         0.9999             25               52
      Infiltration          0.9999       0.9999         1.0000             54               69
      Lung Opacity          0.9999       0.9999         0.9999            125              140
       Nodule/Mass          1.0000       1.0000         1.0000             27              138
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9999       1.0000         0.9998            545               67
       Atelectasis          1.0000       0.9999         1.0000             16               21
     Calcification          1.0000       0.9999         1.0000             18               72
      Cardiomegaly          1.0000       1.0000         0.9999            369               91
     Consolidation          0.9999       0.9999         0.9999             36               35
               ILD          0.9999       0.9999         1.0000             28               50
      Infiltration          0.9998       1.0000         0.9997             41               81
      Lung Opacity          0.9999       0.9999         0.9999            130              134
       Nodule/Mass          0.9999       1.0000         0.9999             43              122
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          1.0000       1.0000         1.0000            548               73
       Atelectasis          0.9915       1.0000         0.9883             10               27
     Calcification          1.0000       0.9999         1.0000              8               82
      Cardiomegaly          1.0000       1.0000         1.0000            414               46
     Consolidation          0.9985       0.9999         0.9980             20               50
               ILD          0.9999       0.9999         0.9999             28               49
      Infiltration          0.9999       0.9999         0.9998             46               77
      Lung Opacity          0.9999       0.9999         1.0000            123              141
       Nodule/Mass          1.0000       1.0000         1.0000             37              128
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          1.0000       1.0000         1.0000            542               70
       Atelectasis          0.9999       0.9999         1.0000             10               27
     Calcification          1.0000       1.0000         1.0000             12               78
      Cardiomegaly          1.0000       1.0000         1.0000            385               75
     Consolidation          0.9998       0.9999         0.9998             22               48
               ILD          1.0000       0.9999         1.0000             29               48
      Infiltration          0.9999       1.0000         0.9999             49               73
      Lung Opacity          1.0000       0.9999         1.0000            128              136
       Nodule/Mass          1.0000       1.0000         0.9999             41              124
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9999       0.9999         0.9999            559               52
       Atelectasis          0.9996       0.9999         0.9995              7               30
     Calcification          0.9999       0.9998         0.9999             15               76
      Cardiomegaly          0.9999       0.9999         0.9994            406               54
     Consolidation          0.9985       0.9996         0.9981             15               56
               ILD          0.9999       0.9998         0.9999             20               57
      Infiltration          0.9999       0.9999         0.9999             35               88
      Lung Opacity          0.9999       0.9999         0.9999            113              152
       Nodule/Mass          0.9998       0.9999         0.9997             41              125
  

In [4]:
from libraries_multilabel.energyPointGame import energy_point_game_recall
from libraries_multilabel.bcosconv2d import NormedConv2d

import random
import numpy as np
import torch
import pandas as pd
import os
from PIL import Image
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from libraries_multilabel.bcoslinear import BcosLinear
from libraries_multilabel.MultiLabelExplanationWrapper import MultiLabelModelWrapper
from libraries_multilabel.MultiLabelDatasets import MultiLabelDatasetID
from pooling.flc_bcosconv2d import ModifiedFLCBcosConv2d

# Configuration
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path configurations
image_folder = r"D:\vinbigdata-chest-xray-abnormalities-detection\train_png_224"
csv_path_boxes = r"D:\vinbigdata-chest-xray-abnormalities-detection\train224.csv"
csv_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\multilabel_dataset.csv"
splits_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\vinbigdata_5fold_splits.pkl"


# Initialize datasets and transforms
data = pd.read_csv(csv_path)
data_boxes = pd.read_csv(csv_path_boxes)
with open(splits_path, 'rb') as f:
    splits = pickle.load(f)
    
transform = transforms.Compose([transforms.ToTensor()])
class_names = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis"]

class_names_extended = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis", "Average"]

seeds = [0, 1]
seed_results = {
    name: {
        'proportions': [],
        'correct': [],
        'incorrect': [],
    } for name in class_names
}

seed_results_avg = { 
        'proportions': [],
        'correct': [],
        'incorrect': []
        }

for seed in seeds:
    print(f"\n=== Seed {seed+1} Results ===")
    all_fold_results = []
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    fold_average_results = {name: {
            'avg_proportion': [],
            'avg_correct': [],
            'avg_incorrect': [],
            'correct_count': [],
            'incorrect_count': []
        } for name in class_names}

    fold_avg_row_results = {'proportions': [],
        'correct': [],
        'incorrect': []
        }

    # Process all 5 folds
    for fold in range(5):
        print(f"Processing fold {fold+1} of 5")
        model_path = fr"C:\Users\Admin\Documents\MasterThesis\results\VinBigData\ResNet_Bcos_FLC\no_nosamp\seed_{seed}\pneumonia_detection_model_resnet_bcos_bestf1_{fold+1}.pth"
        
        # Initialize results storage        
        fold_all_outputs = {name: {
            'proportions': [],
            'correct': [],
            'incorrect': [],
            'correct_count': 0,
            'incorrect_count': 0
        } for name in class_names} # need to be averaged for final_fold_result


        model = torch.hub.load('B-cos/B-cos-v2', 'resnet50', pretrained=True)
        model.fc.linear = NormedConv2d(2048, 14, kernel_size=(1, 1), bias=False)
        model.layer2[0].conv2 = ModifiedFLCBcosConv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), b=2, transpose=True)
        model.layer2[0].downsample[0] = ModifiedFLCBcosConv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), b=2, transpose=False)

        model.layer3[0].conv2 = ModifiedFLCBcosConv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), b=2, transpose=True)
        model.layer3[0].downsample[0] = ModifiedFLCBcosConv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), b=2, transpose=False)

        model.layer4[0].conv2 = ModifiedFLCBcosConv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), b=2, transpose=True)
        model.layer4[0].downsample[0] = ModifiedFLCBcosConv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), b=2, transpose=False)    

        model.load_state_dict(torch.load(model_path))
        model.to(device)

        
        val_idx = splits[fold][1]
        val_data = data.iloc[val_idx]
        val_dataset = MultiLabelDatasetID(val_data, image_folder, transform=transform)
        val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

        multiLabelWrapper = MultiLabelModelWrapper(model)
        model.eval()
        multiLabelWrapper.model.eval()

        with torch.enable_grad():
            for images, labels, image_ids in val_loader:
                labels = labels.to(device)
                six_channel_images = []
                
                # Convert images to 6-channel format
                for img_tensor in images:
                    numpy_image = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    pil_image = Image.fromarray(numpy_image)
                    transformed_image = model.transform(pil_image)
                    six_channel_images.append(transformed_image)
                
                six_channel_images = torch.stack(six_channel_images).to(device)
                
                mask = torch.zeros((224, 224), dtype=torch.int32)
                for image, label, image_id in zip(six_channel_images, labels, image_ids):
                    image = image[None]  # Add batch dimension
                    expl = multiLabelWrapper.explain(image)
                    
                    # Process each class
                    for class_idx, class_name in enumerate(class_names):
                        filtered_rows = data_boxes[(data_boxes['image_id'] == image_id)]
                        labels_row = data[data['image_id'] == image_id]
                        
                        if not filtered_rows.empty and not labels_row.empty and labels_row[class_name].iloc[0] == 1:
                            prediction = expl["binary_predictions"][0][class_idx]
                            contribution_map = expl['contribution_maps'][class_idx]
                                                        
                            # Create mask from bounding boxes
                            mask.zero_()
                            for _, row in filtered_rows.iterrows():
                                x_min = int(row["x_min"])
                                y_min = int(row["y_min"])
                                x_max = int(row["x_max"])
                                y_max = int(row["y_max"])
                                mask[y_min:y_max, x_min:x_max] = 1.0
                            
                            # Calculate energy point game metric
                            ebpg_result = energy_point_game_recall(mask.cpu(), contribution_map.cpu())
                            
                            # Update results
                            fold_all_outputs[class_name]['proportions'].append(ebpg_result)
                            if prediction == 1:
                                fold_all_outputs[class_name]['correct'].append(ebpg_result)
                                fold_all_outputs[class_name]['correct_count'] += 1
                            else:
                                fold_all_outputs[class_name]['incorrect'].append(ebpg_result)
                                fold_all_outputs[class_name]['incorrect_count'] += 1
                                
                        
        for class_name in class_names:
            class_fold_result = fold_all_outputs[class_name]
            
            avg_total = np.mean(class_fold_result['proportions']) if class_fold_result['proportions'] else 0
            avg_correct = np.mean(class_fold_result['correct']) if class_fold_result['correct'] else 0
            avg_incorrect = np.mean(class_fold_result['incorrect']) if class_fold_result['incorrect'] else 0
            
            fold_average_results[class_name]['avg_proportion'].append(round(avg_total, 4))
            fold_average_results[class_name]['avg_correct'].append(round(avg_correct, 4))
            fold_average_results[class_name]['avg_incorrect'].append(round(avg_incorrect, 4))
            fold_average_results[class_name]['correct_count'].append(class_fold_result['correct_count'])
            fold_average_results[class_name]['incorrect_count'].append(class_fold_result['incorrect_count'])

        rows = []
        for class_name in class_names:
            rows.append({
                'Class': class_name,
                'avg_proportion': fold_average_results[class_name]['avg_proportion'][fold],
                'avg_correct': fold_average_results[class_name]['avg_correct'][fold],
                'avg_incorrect': fold_average_results[class_name]['avg_incorrect'][fold],
                'correct_count': fold_average_results[class_name]['correct_count'][fold],
                'incorrect_count': fold_average_results[class_name]['incorrect_count'][fold]
            })

        df = pd.DataFrame(rows)

        # add an average row
        avg_row = {
            'Class': 'Average',
            'avg_proportion': round(df['avg_proportion'].mean(), 4),
            'avg_correct': round(df['avg_correct'].mean(), 4),
            'avg_incorrect': round(df['avg_incorrect'].mean(), 4),
            'correct_count': int(df['correct_count'].sum()),        
            'incorrect_count': int(df['incorrect_count'].sum()),    
        }
        fold_avg_row_results['proportions'].append(avg_row['avg_proportion'])
        fold_avg_row_results['correct'].append(avg_row['avg_correct'])
        fold_avg_row_results['incorrect'].append(avg_row['avg_incorrect'])

        
        df = pd.concat([
            df,
            pd.DataFrame([avg_row])  # Wrap in list to create 1-row DataFrame
        ], ignore_index=True)

        print(f"\nEnergy-Based Pointing Game Results for Fold {fold+1}:")
        print(df.to_string(index=False))        
        
        if fold == 4:
            for class_name in class_names:
                seed_results[class_name]['proportions'].append(np.mean(fold_average_results[class_name]['avg_proportion']))
                seed_results[class_name]['correct'].append(np.mean(fold_average_results[class_name]['avg_correct']))
                seed_results[class_name]['incorrect'].append(np.mean(fold_average_results[class_name]['avg_incorrect']))   
                
            seed_results_avg['proportions'].append(np.mean(fold_avg_row_results['proportions']))
            seed_results_avg['correct'].append(np.mean(fold_avg_row_results['correct']))
            seed_results_avg['incorrect'].append(np.mean(fold_avg_row_results['incorrect']))

                
        
                    

def calculate_stats(values):
    return f"{np.mean(values):.4f} ± {np.std(values):.4f}"

all_props = []
all_corrs = []
all_incorrects = []

print("\n\n=== Final Cross-Seed Averages ===")
for class_name in class_names:
    print(f"\nClass: {class_name}")
    print("Proportion:", calculate_stats(seed_results[class_name]['proportions']))
    print("Correct   :", calculate_stats(seed_results[class_name]['correct']))
    print("Incorrect :", calculate_stats(seed_results[class_name]['incorrect']))

print("\n=== Final Cross-Seed Average (across all classes) ===")
print("Proportion:", calculate_stats(seed_results_avg['proportions']))
print("Correct   :", calculate_stats(seed_results_avg['correct']))
print("Incorrect :", calculate_stats(seed_results_avg['incorrect']))


=== Seed 1 Results ===
Processing fold 1 of 5


Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9998       0.9999         0.9997            514               97
       Atelectasis          0.9996       0.9998         0.9995             15               23
     Calcification          0.9997       0.9997         0.9997              1               90
      Cardiomegaly          0.9998       0.9998         0.9998            300              160
     Consolidation          0.9998       0.9998         0.9999             35               36
               ILD          0.9998       0.9999         0.9997             38               39
      Infiltration          0.9999       0.9999         0.9998             64               59
      Lung Opacity          0.9998       0.9998         0.9998            126              139
       Nodule/Mass          0.9999       0.9999         0.9999              5              160
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9950       0.9997         0.9837            430              182
       Atelectasis          0.9542       0.9998         0.9373             10               27
     Calcification          0.9884       0.9992         0.9881              2               88
      Cardiomegaly          0.9958       0.9998         0.9818            358              102
     Consolidation          0.9990       0.9996         0.9986             26               45
               ILD          0.9973       0.9999         0.9967             15               63
      Infiltration          0.9757       0.9998         0.9652             37               85
      Lung Opacity          0.9870       0.9998         0.9735            136              128
       Nodule/Mass          0.9991       0.9997         0.9990             31              134
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9992       0.9999         0.9944            544               77
       Atelectasis          0.9984       0.9999         0.9979              8               29
     Calcification          0.9999       1.0000         0.9999              3               87
      Cardiomegaly          0.9999       0.9999         0.9996            397               63
     Consolidation          0.9998       0.9999         0.9998             19               51
               ILD          0.9985       0.9999         0.9980             19               58
      Infiltration          0.9993       0.9999         0.9991             31               92
      Lung Opacity          0.9994       0.9999         0.9990            104              160
       Nodule/Mass          0.9999       0.9999         0.9999             19              146
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9997       0.9997         0.9996            501              111
       Atelectasis          0.9969       0.9994         0.9966              4               33
     Calcification          0.9996       0.9997         0.9996              1               89
      Cardiomegaly          0.9993       0.9997         0.9956            410               50
     Consolidation          0.9961       0.9995         0.9920             39               31
               ILD          0.9997       0.9998         0.9997              9               68
      Infiltration          0.9994       0.9997         0.9992             41               81
      Lung Opacity          0.9942       0.9996         0.9868            153              111
       Nodule/Mass          0.9986       0.9998         0.9984             32              133
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9965       0.9992         0.9887            453              158
       Atelectasis          0.9918       0.9997         0.9899              7               30
     Calcification          0.9859       0.9994         0.9855              2               89
      Cardiomegaly          0.9969       0.9995         0.9833            388               72
     Consolidation          0.9994       0.9997         0.9992             28               43
               ILD          0.9960       0.9994         0.9948             20               57
      Infiltration          0.9988       0.9998         0.9986             22              101
      Lung Opacity          0.9984       0.9996         0.9976            108              157
       Nodule/Mass          0.9982       0.9998         0.9982              3              163
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9998       0.9999         0.9997            514               97
       Atelectasis          0.9996       0.9998         0.9995             15               23
     Calcification          0.9997       0.9997         0.9997              1               90
      Cardiomegaly          0.9998       0.9998         0.9998            300              160
     Consolidation          0.9998       0.9998         0.9999             35               36
               ILD          0.9998       0.9999         0.9997             38               39
      Infiltration          0.9999       0.9999         0.9998             64               59
      Lung Opacity          0.9998       0.9998         0.9998            126              139
       Nodule/Mass          0.9999       0.9999         0.9999              5              160
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9950       0.9997         0.9837            430              182
       Atelectasis          0.9542       0.9998         0.9373             10               27
     Calcification          0.9884       0.9992         0.9881              2               88
      Cardiomegaly          0.9958       0.9998         0.9818            358              102
     Consolidation          0.9990       0.9996         0.9986             26               45
               ILD          0.9973       0.9999         0.9967             15               63
      Infiltration          0.9757       0.9998         0.9652             37               85
      Lung Opacity          0.9870       0.9998         0.9735            136              128
       Nodule/Mass          0.9991       0.9997         0.9990             31              134
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9992       0.9999         0.9944            544               77
       Atelectasis          0.9984       0.9999         0.9979              8               29
     Calcification          0.9999       1.0000         0.9999              3               87
      Cardiomegaly          0.9999       0.9999         0.9996            397               63
     Consolidation          0.9998       0.9999         0.9998             19               51
               ILD          0.9985       0.9999         0.9980             19               58
      Infiltration          0.9993       0.9999         0.9991             31               92
      Lung Opacity          0.9994       0.9999         0.9990            104              160
       Nodule/Mass          0.9999       0.9999         0.9999             19              146
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9997       0.9997         0.9996            501              111
       Atelectasis          0.9969       0.9994         0.9966              4               33
     Calcification          0.9996       0.9997         0.9996              1               89
      Cardiomegaly          0.9993       0.9997         0.9956            410               50
     Consolidation          0.9961       0.9995         0.9920             39               31
               ILD          0.9997       0.9998         0.9997              9               68
      Infiltration          0.9994       0.9997         0.9992             41               81
      Lung Opacity          0.9942       0.9996         0.9868            153              111
       Nodule/Mass          0.9986       0.9998         0.9984             32              133
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9965       0.9992         0.9887            453              158
       Atelectasis          0.9918       0.9997         0.9899              7               30
     Calcification          0.9859       0.9994         0.9855              2               89
      Cardiomegaly          0.9969       0.9995         0.9833            388               72
     Consolidation          0.9994       0.9997         0.9992             28               43
               ILD          0.9960       0.9994         0.9948             20               57
      Infiltration          0.9988       0.9998         0.9986             22              101
      Lung Opacity          0.9984       0.9996         0.9976            108              157
       Nodule/Mass          0.9982       0.9998         0.9982              3              163
  

In [5]:
from libraries_multilabel.energyPointGame import energy_point_game_recall
from libraries_multilabel.bcosconv2d import NormedConv2d

import random
import numpy as np
import torch
import pandas as pd
import os
from PIL import Image
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from libraries_multilabel.bcoslinear import BcosLinear
from libraries_multilabel.MultiLabelExplanationWrapper import MultiLabelModelWrapper
from libraries_multilabel.MultiLabelDatasets import MultiLabelDatasetID

# Configuration
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path configurations
image_folder = r"D:\vinbigdata-chest-xray-abnormalities-detection\train_png_224"
csv_path_boxes = r"D:\vinbigdata-chest-xray-abnormalities-detection\train224.csv"
csv_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\multilabel_dataset.csv"
splits_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\vinbigdata_5fold_splits.pkl"

# Initialize datasets and transforms
data = pd.read_csv(csv_path)
data_boxes = pd.read_csv(csv_path_boxes)
with open(splits_path, 'rb') as f:
    splits = pickle.load(f)
    
transform = transforms.Compose([transforms.ToTensor()])
class_names = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis"]

class_names_extended = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis", "Average"]

seeds = [0, 1]
seed_results = {
    name: {
        'proportions': [],
        'correct': [],
        'incorrect': [],
    } for name in class_names
}

seed_results_avg = { 
        'proportions': [],
        'correct': [],
        'incorrect': []
        }

for seed in seeds:
    print(f"\n=== Seed {seed+1} Results ===")
    all_fold_results = []
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    fold_average_results = {name: {
            'avg_proportion': [],
            'avg_correct': [],
            'avg_incorrect': [],
            'correct_count': [],
            'incorrect_count': []
        } for name in class_names}

    fold_avg_row_results = {'proportions': [],
        'correct': [],
        'incorrect': []
        }

    # Process all 5 folds
    for fold in range(5):
        print(f"Processing fold {fold+1} of 5")
        model_path = fr"C:\Users\Admin\Documents\MasterThesis\results\VinBigData\ResNet_Bcos\no_nosamp\seed_{seed}\pneumonia_detection_model_resnet_bcos_bestf1_{fold+1}.pth"
        
        # Initialize results storage        
        fold_all_outputs = {name: {
            'proportions': [],
            'correct': [],
            'incorrect': [],
            'correct_count': 0,
            'incorrect_count': 0
        } for name in class_names} # need to be averaged for final_fold_result


        model = torch.hub.load('B-cos/B-cos-v2', 'resnet50', pretrained=True)
        model.fc.linear = NormedConv2d(2048, 14, kernel_size=(1, 1), bias=False)
        model.load_state_dict(torch.load(model_path))
        model.to(device)

        
        val_idx = splits[fold][1]
        val_data = data.iloc[val_idx]
        val_dataset = MultiLabelDatasetID(val_data, image_folder, transform=transform)
        val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

        multiLabelWrapper = MultiLabelModelWrapper(model)
        model.eval()
        multiLabelWrapper.model.eval()

        with torch.enable_grad():
            for images, labels, image_ids in val_loader:
                labels = labels.to(device)
                six_channel_images = []
                
                # Convert images to 6-channel format
                for img_tensor in images:
                    numpy_image = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    pil_image = Image.fromarray(numpy_image)
                    transformed_image = model.transform(pil_image)
                    six_channel_images.append(transformed_image)
                
                six_channel_images = torch.stack(six_channel_images).to(device)
                
                mask = torch.zeros((224, 224), dtype=torch.int32)
                for image, label, image_id in zip(six_channel_images, labels, image_ids):
                    image = image[None]  # Add batch dimension
                    expl = multiLabelWrapper.explain(image)
                    
                    # Process each class
                    for class_idx, class_name in enumerate(class_names):
                        filtered_rows = data_boxes[(data_boxes['image_id'] == image_id)]
                        labels_row = data[data['image_id'] == image_id]
                        
                        if not filtered_rows.empty and not labels_row.empty and labels_row[class_name].iloc[0] == 1:
                            prediction = expl["binary_predictions"][0][class_idx]
                            contribution_map = expl['contribution_maps'][class_idx]
                                                        
                            # Create mask from bounding boxes
                            mask.zero_()
                            for _, row in filtered_rows.iterrows():
                                x_min = int(row["x_min"])
                                y_min = int(row["y_min"])
                                x_max = int(row["x_max"])
                                y_max = int(row["y_max"])
                                mask[y_min:y_max, x_min:x_max] = 1.0
                            
                            # Calculate energy point game metric
                            ebpg_result = energy_point_game_recall(mask.cpu(), contribution_map.cpu())
                            
                            # Update results
                            fold_all_outputs[class_name]['proportions'].append(ebpg_result)
                            if prediction == 1:
                                fold_all_outputs[class_name]['correct'].append(ebpg_result)
                                fold_all_outputs[class_name]['correct_count'] += 1
                            else:
                                fold_all_outputs[class_name]['incorrect'].append(ebpg_result)
                                fold_all_outputs[class_name]['incorrect_count'] += 1
                                
                        
        for class_name in class_names:
            class_fold_result = fold_all_outputs[class_name]
            
            avg_total = np.mean(class_fold_result['proportions']) if class_fold_result['proportions'] else 0
            avg_correct = np.mean(class_fold_result['correct']) if class_fold_result['correct'] else 0
            avg_incorrect = np.mean(class_fold_result['incorrect']) if class_fold_result['incorrect'] else 0
            
            fold_average_results[class_name]['avg_proportion'].append(round(avg_total, 4))
            fold_average_results[class_name]['avg_correct'].append(round(avg_correct, 4))
            fold_average_results[class_name]['avg_incorrect'].append(round(avg_incorrect, 4))
            fold_average_results[class_name]['correct_count'].append(class_fold_result['correct_count'])
            fold_average_results[class_name]['incorrect_count'].append(class_fold_result['incorrect_count'])

        rows = []
        for class_name in class_names:
            rows.append({
                'Class': class_name,
                'avg_proportion': fold_average_results[class_name]['avg_proportion'][fold],
                'avg_correct': fold_average_results[class_name]['avg_correct'][fold],
                'avg_incorrect': fold_average_results[class_name]['avg_incorrect'][fold],
                'correct_count': fold_average_results[class_name]['correct_count'][fold],
                'incorrect_count': fold_average_results[class_name]['incorrect_count'][fold]
            })

        df = pd.DataFrame(rows)

        # add an average row
        avg_row = {
            'Class': 'Average',
            'avg_proportion': round(df['avg_proportion'].mean(), 4),
            'avg_correct': round(df['avg_correct'].mean(), 4),
            'avg_incorrect': round(df['avg_incorrect'].mean(), 4),
            'correct_count': int(df['correct_count'].sum()),        
            'incorrect_count': int(df['incorrect_count'].sum()),    
        }
        fold_avg_row_results['proportions'].append(avg_row['avg_proportion'])
        fold_avg_row_results['correct'].append(avg_row['avg_correct'])
        fold_avg_row_results['incorrect'].append(avg_row['avg_incorrect'])

        
        df = pd.concat([
            df,
            pd.DataFrame([avg_row])  # Wrap in list to create 1-row DataFrame
        ], ignore_index=True)

        print(f"\nEnergy-Based Pointing Game Results for Fold {fold+1}:")
        print(df.to_string(index=False))        
        
        if fold == 4:
            for class_name in class_names:
                seed_results[class_name]['proportions'].append(np.mean(fold_average_results[class_name]['avg_proportion']))
                seed_results[class_name]['correct'].append(np.mean(fold_average_results[class_name]['avg_correct']))
                seed_results[class_name]['incorrect'].append(np.mean(fold_average_results[class_name]['avg_incorrect']))   
                
            seed_results_avg['proportions'].append(np.mean(fold_avg_row_results['proportions']))
            seed_results_avg['correct'].append(np.mean(fold_avg_row_results['correct']))
            seed_results_avg['incorrect'].append(np.mean(fold_avg_row_results['incorrect']))

                
        
                    

def calculate_stats(values):
    return f"{np.mean(values):.4f} ± {np.std(values):.4f}"

all_props = []
all_corrs = []
all_incorrects = []

print("\n\n=== Final Cross-Seed Averages ===")
for class_name in class_names:
    print(f"\nClass: {class_name}")
    print("Proportion:", calculate_stats(seed_results[class_name]['proportions']))
    print("Correct   :", calculate_stats(seed_results[class_name]['correct']))
    print("Incorrect :", calculate_stats(seed_results[class_name]['incorrect']))

print("\n=== Final Cross-Seed Average (across all classes) ===")
print("Proportion:", calculate_stats(seed_results_avg['proportions']))
print("Correct   :", calculate_stats(seed_results_avg['correct']))
print("Incorrect :", calculate_stats(seed_results_avg['incorrect']))



=== Seed 1 Results ===
Processing fold 1 of 5


Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9992       0.9999         0.9964            497              114
       Atelectasis          0.9997       0.9998         0.9996             14               24
     Calcification          0.9982       0.9998         0.9982              1               90
      Cardiomegaly          0.9988       0.9999         0.9962            329              131
     Consolidation          0.9998       0.9998         0.9999             24               47
               ILD          0.9990       0.9999         0.9988             11               66
      Infiltration          0.9998       0.9999         0.9997             33               90
      Lung Opacity          0.9999       0.9998         0.9999            124              141
       Nodule/Mass          0.9999       0.9998         0.9999             23              142
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9998       0.9998         0.9998            475              137
       Atelectasis          0.9542       0.9998         0.9453              6               31
     Calcification          0.9998       0.9997         0.9998             17               73
      Cardiomegaly          0.9998       0.9998         0.9997            355              105
     Consolidation          0.9997       0.9998         0.9996             11               60
               ILD          0.9993       0.9998         0.9992             11               67
      Infiltration          0.9983       0.9998         0.9978             32               90
      Lung Opacity          0.9994       0.9997         0.9992            102              162
       Nodule/Mass          0.9997       0.9997         0.9997             26              139
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9998       0.9998         0.9998            533               88
       Atelectasis          0.9988       0.9999         0.9983             11               26
     Calcification          0.9998       0.9996         0.9998              6               84
      Cardiomegaly          0.9998       0.9998         0.9998            428               32
     Consolidation          0.9998       0.9998         0.9998             23               47
               ILD          0.9998       0.9998         0.9998             22               55
      Infiltration          0.9999       0.9999         0.9998             53               70
      Lung Opacity          0.9998       0.9998         0.9998            115              149
       Nodule/Mass          0.9998       0.9997         0.9998             12              153
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9984       0.9998         0.9883            536               76
       Atelectasis          0.9998       0.9996         0.9998              9               28
     Calcification          0.9998       0.0000         0.9998              0               90
      Cardiomegaly          0.9989       0.9999         0.9938            387               73
     Consolidation          0.9998       0.9997         0.9998             20               50
               ILD          0.9868       0.9999         0.9855              7               70
      Infiltration          0.9999       0.9999         0.9999             37               85
      Lung Opacity          0.9997       0.9998         0.9995            162              102
       Nodule/Mass          0.9999       0.9999         0.9999             59              106
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9998       0.9998         0.9998            538               73
       Atelectasis          0.9641       0.9999         0.9509             10               27
     Calcification          0.9998       0.9998         0.9998              5               86
      Cardiomegaly          0.9998       0.9998         0.9997            353              107
     Consolidation          0.9946       0.9998         0.9931             16               55
               ILD          0.9995       0.9998         0.9994             16               61
      Infiltration          0.9996       0.9999         0.9995             37               86
      Lung Opacity          0.9998       0.9998         0.9998             88              177
       Nodule/Mass          0.9998       0.9998         0.9998             19              147
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9989       0.9998         0.9926            535               76
       Atelectasis          0.9980       0.9995         0.9962             20               18
     Calcification          0.9996       0.9995         0.9996             11               80
      Cardiomegaly          0.9996       0.9998         0.9988            373               87
     Consolidation          0.9995       0.9996         0.9995             28               43
               ILD          0.9975       0.9998         0.9970             15               62
      Infiltration          0.9984       0.9998         0.9979             36               87
      Lung Opacity          0.9996       0.9996         0.9995             81              184
       Nodule/Mass          0.9998       0.9998         0.9998             34              131
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9993       0.9997         0.9968            509              103
       Atelectasis          0.9457       0.9998         0.9391              4               33
     Calcification          0.9907       0.9998         0.9902              5               85
      Cardiomegaly          0.9975       0.9998         0.9908            345              115
     Consolidation          0.9995       0.9997         0.9995             22               49
               ILD          0.9898       0.9994         0.9887              8               70
      Infiltration          0.9911       0.9999         0.9890             24               98
      Lung Opacity          0.9955       0.9997         0.9919            120              144
       Nodule/Mass          0.9997       0.9997         0.9997             39              126
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9997       0.9997         0.9998            479              142
       Atelectasis          0.9978       0.9997         0.9971             10               27
     Calcification          0.9997       0.9996         0.9998             24               66
      Cardiomegaly          0.9997       0.9997         0.9997            370               90
     Consolidation          0.9996       0.9995         0.9996             17               53
               ILD          0.9996       0.9997         0.9996             29               48
      Infiltration          0.9997       0.9998         0.9997             39               84
      Lung Opacity          0.9995       0.9996         0.9995             62              202
       Nodule/Mass          0.9997       0.9997         0.9997             30              135
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9998       0.9998         0.9998            443              169
       Atelectasis          0.9998       0.9998         0.9998             15               22
     Calcification          0.9998       0.9998         0.9998             10               80
      Cardiomegaly          0.9998       0.9998         0.9997            271              189
     Consolidation          0.9998       0.9998         0.9998             40               30
               ILD          0.9999       0.9999         0.9999             18               59
      Infiltration          0.9999       0.9999         0.9998             75               47
      Lung Opacity          0.9998       0.9998         0.9998            143              121
       Nodule/Mass          0.9998       0.9998         0.9998             63              102
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9999       0.9999         0.9999            507              104
       Atelectasis          0.9728       0.9999         0.9563             14               23
     Calcification          0.9999       0.9997         0.9999              3               88
      Cardiomegaly          0.9999       1.0000         0.9999            342              118
     Consolidation          0.9998       0.9998         0.9999             29               42
               ILD          0.9998       0.9999         0.9997             33               44
      Infiltration          0.9999       0.9999         0.9999             75               48
      Lung Opacity          0.9999       0.9999         0.9999            143              122
       Nodule/Mass          0.9999       0.9999         0.9999             36              130
  

In [6]:
from libraries_multilabel.energyPointGame import energy_point_game_recall
from libraries_multilabel.bcosconv2d import NormedConv2d

import random
import numpy as np
import torch
import pandas as pd
import os
from PIL import Image
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from libraries_multilabel.bcoslinear import BcosLinear
from libraries_multilabel.MultiLabelExplanationWrapper import MultiLabelModelWrapper
from libraries_multilabel.MultiLabelDatasets import MultiLabelDatasetID
from pooling.blur_bcosconv2d import ModifiedBcosConv2d


# Configuration
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path configurations
image_folder = r"D:\vinbigdata-chest-xray-abnormalities-detection\train_png_224"
csv_path_boxes = r"D:\vinbigdata-chest-xray-abnormalities-detection\train224.csv"
csv_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\multilabel_dataset.csv"
splits_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\vinbigdata_5fold_splits.pkl"


# Initialize datasets and transforms
data = pd.read_csv(csv_path)
data_boxes = pd.read_csv(csv_path_boxes)
with open(splits_path, 'rb') as f:
    splits = pickle.load(f)
    
transform = transforms.Compose([transforms.ToTensor()])
class_names = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis"]

class_names_extended = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis", "Average"]

seeds = [0, 1]
seed_results = {
    name: {
        'proportions': [],
        'correct': [],
        'incorrect': [],
    } for name in class_names
}

seed_results_avg = { 
        'proportions': [],
        'correct': [],
        'incorrect': []
        }

for seed in seeds:
    print(f"\n=== Seed {seed+1} Results ===")
    all_fold_results = []
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    fold_average_results = {name: {
            'avg_proportion': [],
            'avg_correct': [],
            'avg_incorrect': [],
            'correct_count': [],
            'incorrect_count': []
        } for name in class_names}

    fold_avg_row_results = {'proportions': [],
        'correct': [],
        'incorrect': []
        }

    # Process all 5 folds
    for fold in range(5):
        print(f"Processing fold {fold+1} of 5")
        model_path = fr"C:\Users\Admin\Documents\MasterThesis\results\VinBigData\ResNet_Bcos_Blur\light_oversamp\seed_{seed}\pneumonia_detection_model_resnet_bcos_bestf1_{fold+1}.pth"
        
        # Initialize results storage        
        fold_all_outputs = {name: {
            'proportions': [],
            'correct': [],
            'incorrect': [],
            'correct_count': 0,
            'incorrect_count': 0
        } for name in class_names} # need to be averaged for final_fold_result


        model = torch.hub.load('B-cos/B-cos-v2', 'resnet50', pretrained=True)
        model.fc.linear = NormedConv2d(2048, 14, kernel_size=(1, 1), bias=False)
        model.layer2[0].conv2 = ModifiedBcosConv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), b=2)
        model.layer2[0].downsample[0] = ModifiedBcosConv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), b=2)

        model.layer3[0].conv2 = ModifiedBcosConv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), b=2)
        model.layer3[0].downsample[0] = ModifiedBcosConv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), b=2)

        model.layer4[0].conv2 = ModifiedBcosConv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), b=2)
        model.layer4[0].downsample[0] = ModifiedBcosConv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), b=2)

        model.load_state_dict(torch.load(model_path))
        model.to(device)

        
        val_idx = splits[fold][1]
        val_data = data.iloc[val_idx]
        val_dataset = MultiLabelDatasetID(val_data, image_folder, transform=transform)
        val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

        multiLabelWrapper = MultiLabelModelWrapper(model)
        model.eval()
        multiLabelWrapper.model.eval()

        with torch.enable_grad():
            for images, labels, image_ids in val_loader:
                labels = labels.to(device)
                six_channel_images = []
                
                # Convert images to 6-channel format
                for img_tensor in images:
                    numpy_image = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    pil_image = Image.fromarray(numpy_image)
                    transformed_image = model.transform(pil_image)
                    six_channel_images.append(transformed_image)
                
                six_channel_images = torch.stack(six_channel_images).to(device)
                
                mask = torch.zeros((224, 224), dtype=torch.int32)
                for image, label, image_id in zip(six_channel_images, labels, image_ids):
                    image = image[None]  # Add batch dimension
                    expl = multiLabelWrapper.explain(image)
                    
                    # Process each class
                    for class_idx, class_name in enumerate(class_names):
                        filtered_rows = data_boxes[(data_boxes['image_id'] == image_id)]
                        labels_row = data[data['image_id'] == image_id]
                        
                        if not filtered_rows.empty and not labels_row.empty and labels_row[class_name].iloc[0] == 1:
                            prediction = expl["binary_predictions"][0][class_idx]
                            contribution_map = expl['contribution_maps'][class_idx]
                                                        
                            # Create mask from bounding boxes
                            mask.zero_()
                            for _, row in filtered_rows.iterrows():
                                x_min = int(row["x_min"])
                                y_min = int(row["y_min"])
                                x_max = int(row["x_max"])
                                y_max = int(row["y_max"])
                                mask[y_min:y_max, x_min:x_max] = 1.0
                            
                            # Calculate energy point game metric
                            ebpg_result = energy_point_game_recall(mask.cpu(), contribution_map.cpu())
                            
                            # Update results
                            fold_all_outputs[class_name]['proportions'].append(ebpg_result)
                            if prediction == 1:
                                fold_all_outputs[class_name]['correct'].append(ebpg_result)
                                fold_all_outputs[class_name]['correct_count'] += 1
                            else:
                                fold_all_outputs[class_name]['incorrect'].append(ebpg_result)
                                fold_all_outputs[class_name]['incorrect_count'] += 1
                                
                        
        for class_name in class_names:
            class_fold_result = fold_all_outputs[class_name]
            
            avg_total = np.mean(class_fold_result['proportions']) if class_fold_result['proportions'] else 0
            avg_correct = np.mean(class_fold_result['correct']) if class_fold_result['correct'] else 0
            avg_incorrect = np.mean(class_fold_result['incorrect']) if class_fold_result['incorrect'] else 0
            
            fold_average_results[class_name]['avg_proportion'].append(round(avg_total, 4))
            fold_average_results[class_name]['avg_correct'].append(round(avg_correct, 4))
            fold_average_results[class_name]['avg_incorrect'].append(round(avg_incorrect, 4))
            fold_average_results[class_name]['correct_count'].append(class_fold_result['correct_count'])
            fold_average_results[class_name]['incorrect_count'].append(class_fold_result['incorrect_count'])

        rows = []
        for class_name in class_names:
            rows.append({
                'Class': class_name,
                'avg_proportion': fold_average_results[class_name]['avg_proportion'][fold],
                'avg_correct': fold_average_results[class_name]['avg_correct'][fold],
                'avg_incorrect': fold_average_results[class_name]['avg_incorrect'][fold],
                'correct_count': fold_average_results[class_name]['correct_count'][fold],
                'incorrect_count': fold_average_results[class_name]['incorrect_count'][fold]
            })

        df = pd.DataFrame(rows)

        # add an average row
        avg_row = {
            'Class': 'Average',
            'avg_proportion': round(df['avg_proportion'].mean(), 4),
            'avg_correct': round(df['avg_correct'].mean(), 4),
            'avg_incorrect': round(df['avg_incorrect'].mean(), 4),
            'correct_count': int(df['correct_count'].sum()),        
            'incorrect_count': int(df['incorrect_count'].sum()),    
        }
        fold_avg_row_results['proportions'].append(avg_row['avg_proportion'])
        fold_avg_row_results['correct'].append(avg_row['avg_correct'])
        fold_avg_row_results['incorrect'].append(avg_row['avg_incorrect'])

        
        df = pd.concat([
            df,
            pd.DataFrame([avg_row])  # Wrap in list to create 1-row DataFrame
        ], ignore_index=True)

        print(f"\nEnergy-Based Pointing Game Results for Fold {fold+1}:")
        print(df.to_string(index=False))        
        
        if fold == 4:
            for class_name in class_names:
                seed_results[class_name]['proportions'].append(np.mean(fold_average_results[class_name]['avg_proportion']))
                seed_results[class_name]['correct'].append(np.mean(fold_average_results[class_name]['avg_correct']))
                seed_results[class_name]['incorrect'].append(np.mean(fold_average_results[class_name]['avg_incorrect']))   
                
            seed_results_avg['proportions'].append(np.mean(fold_avg_row_results['proportions']))
            seed_results_avg['correct'].append(np.mean(fold_avg_row_results['correct']))
            seed_results_avg['incorrect'].append(np.mean(fold_avg_row_results['incorrect']))

                
        
                    

def calculate_stats(values):
    return f"{np.mean(values):.4f} ± {np.std(values):.4f}"

all_props = []
all_corrs = []
all_incorrects = []

print("\n\n=== Final Cross-Seed Averages ===")
for class_name in class_names:
    print(f"\nClass: {class_name}")
    print("Proportion:", calculate_stats(seed_results[class_name]['proportions']))
    print("Correct   :", calculate_stats(seed_results[class_name]['correct']))
    print("Incorrect :", calculate_stats(seed_results[class_name]['incorrect']))

print("\n=== Final Cross-Seed Average (across all classes) ===")
print("Proportion:", calculate_stats(seed_results_avg['proportions']))
print("Correct   :", calculate_stats(seed_results_avg['correct']))
print("Incorrect :", calculate_stats(seed_results_avg['incorrect']))


=== Seed 1 Results ===
Processing fold 1 of 5


Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9998       0.9999         0.9998            537               74
       Atelectasis          0.9812       0.9996         0.9705             14               24
     Calcification          0.9993       0.9990         0.9994             13               78
      Cardiomegaly          0.9996       0.9997         0.9987            409               51
     Consolidation          0.9986       0.9996         0.9982             19               52
               ILD          0.9989       0.9999         0.9986             18               59
      Infiltration          0.9970       0.9999         0.9951             49               74
      Lung Opacity          0.9998       0.9997         0.9998            124              141
       Nodule/Mass          0.9984       0.9997         0.9978             55              110
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9999       0.9999         0.9998            544               68
       Atelectasis          0.9800       0.9998         0.9680             14               23
     Calcification          0.9998       0.9997         0.9999             23               67
      Cardiomegaly          0.9999       0.9999         0.9998            398               62
     Consolidation          0.9998       0.9998         0.9998             26               45
               ILD          0.9997       0.9997         0.9996             30               48
      Infiltration          0.9998       0.9999         0.9998             31               91
      Lung Opacity          0.9998       0.9998         0.9998            138              126
       Nodule/Mass          0.9998       0.9997         0.9998             38              127
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9999       0.9999         0.9999            536               85
       Atelectasis          0.9960       0.9998         0.9942             12               25
     Calcification          0.9998       0.9998         0.9998             16               74
      Cardiomegaly          0.9997       0.9999         0.9986            407               53
     Consolidation          0.9990       0.9997         0.9986             27               43
               ILD          0.9996       0.9999         0.9995             27               50
      Infiltration          0.9998       0.9999         0.9998             40               83
      Lung Opacity          0.9998       0.9998         0.9998            124              140
       Nodule/Mass          0.9998       0.9998         0.9999             50              115
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9998       0.9998         0.9996            536               76
       Atelectasis          0.9694       0.9995         0.9532             13               24
     Calcification          0.9958       0.9996         0.9947             21               69
      Cardiomegaly          0.9998       0.9998         0.9996            404               56
     Consolidation          0.9970       0.9994         0.9956             26               44
               ILD          0.9996       0.9994         0.9997             21               56
      Infiltration          0.9996       0.9998         0.9996             43               79
      Lung Opacity          0.9997       0.9995         0.9998            117              147
       Nodule/Mass          0.9996       0.9996         0.9996             42              123
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9999       0.9999         0.9998            555               56
       Atelectasis          0.9827       0.9999         0.9754             11               26
     Calcification          0.9987       0.9994         0.9986             20               71
      Cardiomegaly          0.9999       0.9999         0.9999            413               47
     Consolidation          0.9993       0.9996         0.9992             15               56
               ILD          0.9998       0.9996         0.9999             23               54
      Infiltration          0.9998       0.9999         0.9998             32               91
      Lung Opacity          0.9998       0.9998         0.9998             99              166
       Nodule/Mass          0.9998       0.9998         0.9997             52              114
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 1:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9998       0.9999         0.9989            537               74
       Atelectasis          0.9990       0.9997         0.9984             17               21
     Calcification          0.9996       0.9995         0.9998             49               42
      Cardiomegaly          0.9995       0.9998         0.9973            391               69
     Consolidation          0.9708       0.9998         0.9578             22               49
               ILD          0.9944       0.9998         0.9916             26               51
      Infiltration          0.9831       0.9999         0.9789             25               98
      Lung Opacity          0.9997       0.9998         0.9997            137              128
       Nodule/Mass          0.9993       0.9999         0.9990             45              120
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 2:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9998       0.9998         0.9998            550               62
       Atelectasis          0.9803       0.9998         0.9637             17               20
     Calcification          0.9997       0.9995         0.9997             12               78
      Cardiomegaly          0.9995       0.9998         0.9971            406               54
     Consolidation          0.9967       0.9996         0.9935             37               34
               ILD          0.9979       0.9997         0.9972             22               56
      Infiltration          0.9998       0.9998         0.9998             41               81
      Lung Opacity          0.9997       0.9997         0.9998            110              154
       Nodule/Mass          0.9996       0.9998         0.9995             38              127
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 3:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9999       0.9999         0.9998            558               63
       Atelectasis          0.9977       0.9997         0.9965             13               24
     Calcification          0.9985       0.9998         0.9983             14               76
      Cardiomegaly          0.9997       0.9999         0.9981            408               52
     Consolidation          0.9753       0.9996         0.9656             20               50
               ILD          0.9941       0.9998         0.9917             23               54
      Infiltration          0.9987       0.9998         0.9983             35               88
      Lung Opacity          0.9998       0.9997         0.9998            109              155
       Nodule/Mass          0.9996       0.9999         0.9996             26              139
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 4:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9998       0.9998         0.9998            531               81
       Atelectasis          0.9599       0.9994         0.9432             11               26
     Calcification          0.9994       0.9996         0.9994             16               74
      Cardiomegaly          0.9995       0.9998         0.9980            386               74
     Consolidation          0.9979       0.9995         0.9969             26               44
               ILD          0.9987       0.9996         0.9982             30               47
      Infiltration          0.9998       0.9999         0.9997             42               80
      Lung Opacity          0.9998       0.9997         0.9999            135              129
       Nodule/Mass          0.9998       0.9998         0.9997             52              113
  

Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results for Fold 5:
             Class  avg_proportion  avg_correct  avg_incorrect  correct_count  incorrect_count
Aortic enlargement          0.9998       0.9998         0.9998            555               56
       Atelectasis          0.9669       0.9996         0.9530             11               26
     Calcification          0.9958       0.9991         0.9954             10               81
      Cardiomegaly          0.9997       0.9998         0.9996            402               58
     Consolidation          0.9961       0.9995         0.9951             16               55
               ILD          0.9934       0.9996         0.9912             20               57
      Infiltration          0.9993       0.9998         0.9991             33               90
      Lung Opacity          0.9993       0.9996         0.9990            112              153
       Nodule/Mass          0.9996       0.9997         0.9995             29              137
  