In [None]:
from libraries_multilabel.energyPointGame import energy_point_game_mask
from libraries_multilabel.bcosconv2d import NormedConv2d


import random
import numpy as np
import torch
import pandas as pd
import pydicom
import os
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

#from dataset.augmentations import no_augmentations
from libraries_multilabel.bcosconv2d import NormedConv2d
#from blurpool.blur_bcosconv2d import ModifiedBcosConv2d
#from pooling.flc_bcosconv2d import ModifiedFLCBcosConv2d
from libraries_multilabel.bcoslinear import BcosLinear
from libraries_multilabel.MultiLabelExplanationWrapper import MultiLabelModelWrapper
from libraries_multilabel.MultiLabelDatasets import MultiLabelDatasetID



# Seed every RNG used by the script so the evaluation is reproducible.
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)

# NOTE(review): original_width/height only feed scale_x/scale_y below, which this
# cell does not use (boxes are read as-is) — confirm whether 1024 still applies.
original_width, original_height = 1024, 1024
explanation_width, explanation_height = 224, 224

image_folder = r"D:\vinbigdata-chest-xray-abnormalities-detection\train_png_224"
model_path = r"C:\Users\Admin\Documents\MasterThesis\results\VinBigData\ResNet_Bcos\no_nosamp\seed_0\pneumonia_detection_model_resnet_bcos_bestf1_1.pth"
csv_path_boxes = r"D:\vinbigdata-chest-xray-abnormalities-detection\train224.csv"
csv_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\multilabel_dataset.csv"
splits_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\vinbigdata_5fold_splits.pkl"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Alternative backbone (B-cos ViT), kept for reference:
#model = torch.hub.load('B-cos/B-cos-v2', 'vitc_b_patch1_14', pretrained=True)
#model[0].linear_head.linear = BcosLinear(in_features=768, out_features=2, bias=False, b=2)

model = torch.hub.load('B-cos/B-cos-v2', 'resnet50', pretrained=True)
# Swap the head for a 14-output B-cos head (one output per entry of class_names).
model.fc.linear = NormedConv2d(2048, 14, kernel_size=(1, 1), stride=(1, 1), bias=False) # code from B-cos paper reused to adjust network

# map_location lets a checkpoint saved on a GPU be loaded when only a CPU is
# available (device falls back to "cpu" above).
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)

model.to(device)


# Label table (one column per class name, looked up by image_id) and the
# bounding-box table.
data = pd.read_csv(csv_path)
data_boxes = pd.read_csv(csv_path_boxes)

# Precomputed cross-validation splits. pickle must only be used on trusted files.
with open(splits_path, 'rb') as splits_file:
    splits = pickle.load(splits_file)

# Evaluate on the validation indices (element 1) of the first fold only.
first_split = splits[0]
val_idx = first_split[1]
val_data = data.iloc[val_idx]

# Only tensor conversion here; the model's own transform is applied in the loop.
transform = transforms.Compose([transforms.ToTensor()])
val_dataset = MultiLabelDatasetID(val_data, image_folder, transform=transform)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

# NOTE(review): not used by the evaluation loop in this cell (boxes are read as-is).
scale_x = explanation_width / original_width
scale_y = explanation_height / original_height

# Order matters: class_idx indexes the model's per-class predictions/maps.
class_names = [
    "Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis"]

# Per-class accumulators of EBPG scores, split by the model's binary prediction.
class_results = {
    name: dict(proportions=[], correct=[], incorrect=[], correct_count=0, incorrect_count=0)
    for name in class_names
}

# Aggregate counters; not read by any of the code shown in this notebook.
proportions, proportions_correct, proportions_incorrect = [], [], []
count_correct = count_incorrect = 0

multiLabelWrapper = MultiLabelModelWrapper(model)
model.eval()
multiLabelWrapper.model.eval()

# Opt-in: score each class only against the boxes annotated for that class.
# NOTE(review): as written, the evaluation uses *every* box of the image for
# every class. Enabling this assumes train224.csv has a 'class_name' column
# whose values match class_names — confirm before switching it on.
FILTER_BOXES_BY_CLASS = False

# Binary ground-truth mask at explanation resolution; zeroed and refilled per class.
mask = torch.zeros((224, 224), dtype=torch.int32)

# explain() runs under enable_grad; presumably it backpropagates to build the
# contribution maps — TODO confirm in MultiLabelModelWrapper.
with torch.enable_grad():
    for images, _, image_ids in val_loader:
        # Round-trip each tensor through PIL so the B-cos model's own transform
        # builds its input encoding.
        six_channel_images = []
        for img_tensor in images:
            numpy_image = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
            pil_image = Image.fromarray(numpy_image)
            six_channel_images.append(model.transform(pil_image))
        six_channel_images = torch.stack(six_channel_images).to(device)

        for image, image_id in zip(six_channel_images, image_ids):
            image = image[None]  # add batch dimension
            # Predictions are taken from expl; a separate model(image) forward
            # pass is unnecessary because its output was never read.
            expl = multiLabelWrapper.explain(image)

            # Both lookups depend only on the image, so do them once per image
            # instead of once per class.
            image_boxes = data_boxes[data_boxes['image_id'] == image_id]
            labels_row = data[data['image_id'] == image_id]
            if image_boxes.empty or labels_row.empty:
                continue

            for class_idx, class_name in enumerate(class_names):
                # Only ground-truth positive classes are scored.
                if labels_row[class_name].iloc[0] != 1:
                    continue

                if FILTER_BOXES_BY_CLASS:
                    filtered_rows = image_boxes[image_boxes['class_name'] == class_name]
                    if filtered_rows.empty:
                        continue
                else:
                    filtered_rows = image_boxes

                prediction = expl["binary_predictions"][0][class_idx]
                contribution_map = expl['contribution_maps'][class_idx]
                # Keep only positive evidence (modifies the map inside expl in place).
                contribution_map[contribution_map < 0] = 0

                # Box coordinates are used as-is (presumably already in the
                # 224x224 space of train224.csv / the mask).
                mask.zero_()
                for _, row in filtered_rows.iterrows():
                    x_min, y_min = int(row["x_min"]), int(row["y_min"])
                    x_max, y_max = int(row["x_max"]), int(row["y_max"])
                    mask[y_min:y_max, x_min:x_max] = 1

                ebpg_result = energy_point_game_mask(mask.cpu(), contribution_map.cpu())

                # The class is a ground-truth positive here, so prediction == 1 is
                # a true positive ("correct") and anything else a false negative.
                class_record = class_results[class_name]
                class_record['proportions'].append(ebpg_result)
                if prediction == 1:
                    class_record['correct'].append(ebpg_result)
                    class_record['correct_count'] += 1
                else:
                    class_record['incorrect'].append(ebpg_result)
                    class_record['incorrect_count'] += 1
                

# Summarise the per-class EBPG scores into a table.
# The loop variable is `res`, not `data`: reusing `data` would replace the
# label DataFrame with a dict and break any later code that expects it.
results = []
for class_name in class_names:
    res = class_results[class_name]
    scores_all = res['proportions']
    scores_correct = res['correct']
    scores_incorrect = res['incorrect']

    # Mean of each list, or 0 when nothing was collected. 'correct' and
    # 'incorrect' are subsets of 'proportions', so they are empty whenever it is.
    avg_total = sum(scores_all) / len(scores_all) if scores_all else 0
    avg_correct = sum(scores_correct) / len(scores_correct) if scores_correct else 0
    avg_incorrect = sum(scores_incorrect) / len(scores_incorrect) if scores_incorrect else 0

    results.append({
        'Class': class_name,
        'Avg Total': round(avg_total, 4),
        'Avg Correct': round(avg_correct, 4),
        'Avg Incorrect': round(avg_incorrect, 4),
        'Correct Count': res['correct_count'],
        'Incorrect Count': res['incorrect_count'],
        'Total Samples': len(scores_all)
    })

# Create and print dataframe
results_df = pd.DataFrame(results)
print("\nEnergy-Based Pointing Game Results by Class:")
print(results_df.to_string(index=False))



Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main



Energy-Based Pointing Game Results by Class:
             Class  Avg Total  Avg Correct  Avg Incorrect  Correct Count  Incorrect Count  Total Samples
Aortic enlargement     0.1880       0.1852         0.2004            497              114            611
       Atelectasis     0.3591       0.3756         0.3494             14               24             38
     Calcification     0.2033       0.2764         0.2024              1               90             91
      Cardiomegaly     0.2371       0.2364         0.2389            329              131            460
     Consolidation     0.3594       0.4531         0.3115             24               47             71
               ILD     0.3996       0.4838         0.3856             11               66             77
      Infiltration     0.3005       0.3574         0.2797             33               90            123
      Lung Opacity     0.2616       0.3418         0.1911            124              141            265
       No

# Evaluate the energy-based pointing game over all five folds

In [None]:
from libraries_multilabel.energyPointGame import energy_point_game_mask
from libraries_multilabel.bcosconv2d import NormedConv2d

import random
import numpy as np
import torch
import pandas as pd
import os
from PIL import Image
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from libraries_multilabel.bcoslinear import BcosLinear
from libraries_multilabel.MultiLabelExplanationWrapper import MultiLabelModelWrapper
from libraries_multilabel.MultiLabelDatasets import MultiLabelDatasetID

# Reproducibility: seed NumPy, Python and torch RNGs (in that order).
for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
    seed_fn(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input locations.
image_folder = r"D:\vinbigdata-chest-xray-abnormalities-detection\train_png_224"
csv_path_boxes = r"D:\vinbigdata-chest-xray-abnormalities-detection\train224.csv"
csv_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\multilabel_dataset.csv"
splits_path = r"G:\Meine Ablage\Universität\Master Thesis\Multi-Classification\training\vinbigdata_5fold_splits.pkl"

# Label table, bounding boxes and fold splits (pickle: trusted files only).
data = pd.read_csv(csv_path)
data_boxes = pd.read_csv(csv_path_boxes)
with open(splits_path, 'rb') as splits_file:
    splits = pickle.load(splits_file)

transform = transforms.Compose([transforms.ToTensor()])

# Order matters: class_idx indexes the model's per-class predictions/maps.
class_names = [
    "Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis"]

# Per-class accumulators, shared by all folds so the final table pools them.
class_results = {
    name: dict(proportions=[], correct=[], incorrect=[], correct_count=0, incorrect_count=0)
    for name in class_names
}

# Opt-in: score each class only against the boxes annotated for that class.
# NOTE(review): as written, the evaluation uses *every* box of the image for
# every class. Enabling this assumes train224.csv has a 'class_name' column
# whose values match class_names — confirm before switching it on.
FILTER_BOXES_BY_CLASS = False

# Process all 5 folds: each fold's checkpoint is evaluated on that fold's
# validation split, and the scores accumulate into the shared class_results.
for fold in range(5):
    print(f"Processing fold {fold+1} of 5")

    # 1. Load model for current fold
    model_path = fr"C:\Users\Admin\Documents\MasterThesis\results\VinBigData\ResNet_Bcos\no_nosamp\seed_0\pneumonia_detection_model_resnet_bcos_bestf1_{fold+1}.pth"

    model = torch.hub.load('B-cos/B-cos-v2', 'resnet50', pretrained=True)
    model.fc.linear = NormedConv2d(2048, 14, kernel_size=(1, 1), bias=False)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)

    # 2. Setup validation data for current fold
    val_idx = splits[fold][1]
    val_data = data.iloc[val_idx]
    val_dataset = MultiLabelDatasetID(val_data, image_folder, transform=transform)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

    # 3. Process validation set
    multiLabelWrapper = MultiLabelModelWrapper(model)
    model.eval()
    multiLabelWrapper.model.eval()

    # Binary ground-truth mask at explanation resolution; zeroed per class.
    mask = torch.zeros((224, 224), dtype=torch.int32)

    with torch.enable_grad():
        for images, _, image_ids in val_loader:
            # Convert images to the model's input format via its own transform.
            six_channel_images = []
            for img_tensor in images:
                numpy_image = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                pil_image = Image.fromarray(numpy_image)
                six_channel_images.append(model.transform(pil_image))
            six_channel_images = torch.stack(six_channel_images).to(device)

            # Generate explanations
            for image, image_id in zip(six_channel_images, image_ids):
                image = image[None]  # Add batch dimension
                expl = multiLabelWrapper.explain(image)

                # Per-image lookups, hoisted out of the class loop.
                image_boxes = data_boxes[data_boxes['image_id'] == image_id]
                labels_row = data[data['image_id'] == image_id]
                if image_boxes.empty or labels_row.empty:
                    continue

                # Process each ground-truth positive class
                for class_idx, class_name in enumerate(class_names):
                    if labels_row[class_name].iloc[0] != 1:
                        continue

                    if FILTER_BOXES_BY_CLASS:
                        filtered_rows = image_boxes[image_boxes['class_name'] == class_name]
                        if filtered_rows.empty:
                            continue
                    else:
                        filtered_rows = image_boxes

                    prediction = expl["binary_predictions"][0][class_idx]
                    contribution_map = expl['contribution_maps'][class_idx]
                    # Keep only positive evidence (modifies the map inside expl in place).
                    contribution_map[contribution_map < 0] = 0

                    # Create mask from bounding boxes. Coordinates are used as-is,
                    # matching the single-fold evaluation: this cell never defined
                    # scale_x/scale_y (NameError in a fresh kernel), and rescaling
                    # boxes that are presumably already in 224px space (train224.csv)
                    # would shrink them into the top-left corner.
                    mask.zero_()
                    for _, row in filtered_rows.iterrows():
                        x_min, y_min = int(row["x_min"]), int(row["y_min"])
                        x_max, y_max = int(row["x_max"]), int(row["y_max"])
                        mask[y_min:y_max, x_min:x_max] = 1

                    # Calculate energy point game metric
                    ebpg_result = energy_point_game_mask(mask.cpu(), contribution_map.cpu())

                    # Ground-truth positive: prediction == 1 is a true positive
                    # ("correct"), anything else a false negative.
                    class_record = class_results[class_name]
                    class_record['proportions'].append(ebpg_result)
                    if prediction == 1:
                        class_record['correct'].append(ebpg_result)
                        class_record['correct_count'] += 1
                    else:
                        class_record['incorrect'].append(ebpg_result)
                        class_record['incorrect_count'] += 1

# Pool the per-class scores collected across all folds into one table.
# The loop variable is `res`, not `data`: reusing `data` would replace the
# label DataFrame with a dict and break any later code that expects it.
results = []
for class_name in class_names:
    res = class_results[class_name]

    # Mean of each score list, or 0 when nothing was collected.
    avg_total = np.mean(res['proportions']) if res['proportions'] else 0
    avg_correct = np.mean(res['correct']) if res['correct'] else 0
    avg_incorrect = np.mean(res['incorrect']) if res['incorrect'] else 0

    results.append({
        'Class': class_name,
        'Avg Total': round(avg_total, 4),
        'Avg Correct': round(avg_correct, 4),
        'Avg Incorrect': round(avg_incorrect, 4),
        'Correct Count': res['correct_count'],
        'Incorrect Count': res['incorrect_count'],
        'Total Samples': len(res['proportions'])
    })

# Display final results
results_df = pd.DataFrame(results)
print("\nAverage Energy-Based Pointing Game Results Across All Folds:")
print(results_df.to_string(index=False))


