# In [1]:
import os
import pickle
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, roc_auc_score, accuracy_score
import numpy as np
import pandas as pd
from torchvision.transforms import functional as TF
from PIL import Image
from libraries.bcosconv2d import NormedConv2d
import pydicom 
import random
import matplotlib.pyplot as plt

from collections import OrderedDict

from libraries.bcosconv2d import NormedConv2d
from libraries.bcoslinear import BcosLinear
from pooling.flc_bcosconv2d import ModifiedFLCBcosConv2d

from pytorch_grad_cam import GradCAM  
from pytorch_grad_cam.utils.image import show_cam_on_image
import cv2



# Seed every RNG the script touches (Python, NumPy, PyTorch) with 0 so that
# repeated runs behave the same.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(0)

# Machine-specific locations: metadata CSV, DICOM image folder, pickled split
# definition and the fine-tuned checkpoint.
csv_path = r"G:\Meine Ablage\Universität\Master Thesis\Pneumonia\training\grouped_data.csv"
image_folder = r"C:\Users\Admin\Documents\rsna-pneumonia-detection-challenge\stage_2_train_images"
splits_path = r"G:\Meine Ablage\Universität\Master Thesis\Pneumonia\training\splits\splits_balanced_fix.pkl"
model_path = r"C:\Users\Admin\Documents\MasterThesis\results\ResNet50_BCos\light_oversamp_nonorm\seed_0\pneumonia_detection_model_resnet_bcos_bestf1_1.pth"

# Unpickle the split definition; later code reads splits[0][1] as the
# validation row positions. pickle executes arbitrary code on load, so only
# point this at files you produced yourself.
with open(splits_path, mode="rb") as splits_file:
    splits = pickle.load(splits_file)

# Dataset class
class PneumoniaDataset(Dataset):
    """Map-style dataset returning ``(image, label, patient_id)`` per dataframe row.

    Args:
        dataframe: table accessed positionally via ``.iloc``; each row must
            provide ``patientId`` and ``Target`` entries.
        image_folder: directory holding one ``<patientId>.dcm`` file per row.
        transform: optional callable applied to the RGB PIL image; skipped
            when falsy.
    """

    def __init__(self, dataframe, image_folder, transform=None):
        self.transform = transform
        self.image_folder = image_folder
        self.data = dataframe

    def __len__(self):
        # One sample per row of the wrapped table.
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        pid = record['patientId']
        dicom_file = os.path.join(self.image_folder, f"{pid}.dcm")
        target = record['Target']

        # Decode the DICOM pixel array and turn it into a 3-channel PIL image
        # (assumes a single-channel pixel array; PIL replicates it -- TODO confirm).
        pixels = pydicom.dcmread(dicom_file).pixel_array
        sample = Image.fromarray(pixels).convert("RGB")

        if self.transform:
            sample = self.transform(sample)

        return sample, torch.tensor(target, dtype=torch.long), pid

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# B-cos ResNet-50 from torch.hub; its classifier is replaced by a 2-output
# 1x1 NormedConv2d and the fine-tuned checkpoint is loaded on top.
model = torch.hub.load('B-cos/B-cos-v2', 'resnet50', pretrained=True)
model.fc.linear = NormedConv2d(2048, 2, kernel_size=(1, 1), stride=(1, 1), bias=False)
state_dict = torch.load(f=model_path, map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)  # nn.Module.to returns the same module

# Validation data: rows selected by the first split's second entry, served
# unshuffled in batches of 32 as [0, 1] float tensors.
transform = transforms.Compose(transforms=[transforms.ToTensor()])
data = pd.read_csv(filepath_or_buffer=csv_path)
first_split = splits[0]
val_idx = first_split[1]
val_data = data.iloc[val_idx]
val_dataset = PneumoniaDataset(dataframe=val_data, image_folder=image_folder, transform=transform)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)

# Destination folder for the generated visualisations (created if missing).
directory = r"C:\Users\Admin\Documents\MasterThesis\comparison_images\ResNet50_BCos_GradCAM"
os.makedirs(name=directory, exist_ok=True)

# --- B-cos explanations for the first validation images ---------------------
model.eval()

# The original loop stopped on `i > 100`, i.e. it saved images 0..100.
max_images = 101

# No torch.no_grad() around this loop: model.explain is presumably
# autograd-based (B-cos explanations are derived from gradients), so a
# surrounding no_grad() only works if explain() re-enables grad internally.
# Nothing else here builds an autograd graph, so dropping it costs nothing.
i = 0
for images, _labels, patient_ids in val_loader:
    if i >= max_images:
        break

    # model.transform works on PIL images, so convert the (CPU) batch straight
    # to uint8 HWC arrays instead of moving it to `device` and back first.
    six_channel_images = torch.stack([
        model.transform(
            Image.fromarray(
                (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
            )
        )
        for img_tensor in images
    ]).to(device)

    for image, patient_id in zip(six_channel_images, patient_ids):
        if i >= max_images:
            break

        expl = model.explain(image[None])  # add a batch dimension of size 1

        explanation_path = os.path.join(directory, f"{patient_id}_bcos_explanation.png")
        plt.figure()
        plt.imshow(expl["explanation"])
        plt.axis('off')
        plt.savefig(explanation_path, bbox_inches="tight", pad_inches=0)
        plt.close()
        i += 1
            
            
# --- Grad-CAM on the same validation images ---------------------------------
# Grad-CAM needs gradients, not training mode. model.train() would make
# normalisation layers use per-batch statistics (and, for BatchNorm-style
# layers, update their running statistics) during these single-image passes.
# pytorch_grad_cam's BaseCAM typically calls model.eval() itself as well;
# saying so here keeps the intent explicit and version-independent.
model.eval()
target_layer = model.layer4[-1]

# The original loop stopped on `i > 100`, i.e. it saved images 0..100.
max_images = 101

i = 0
# Used as a context manager, GradCAM removes the forward/backward hooks it
# registers on `model` when the block exits instead of leaving them attached.
with GradCAM(model=model, target_layers=[target_layer]) as gradcam:
    for images, _labels, patient_ids in val_loader:
        if i >= max_images:
            break

        # uint8 HWC copies of the batch: the overlay backgrounds and also the
        # PIL input for model.transform (computed once, without a trip through `device`).
        original_images = [
            (img.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
            for img in images
        ]
        six_channel_images = torch.stack(
            [model.transform(Image.fromarray(np_img)) for np_img in original_images]
        ).to(device)

        for idx, (image, patient_id) in enumerate(zip(six_channel_images, patient_ids)):
            if i >= max_images:
                break

            # Kept from the original; Grad-CAM itself works from hooked layer
            # activations/gradients, so this is harmless rather than required.
            image.requires_grad_(True)

            # No `targets` passed: pytorch_grad_cam explains the top-scoring class.
            grayscale_cam = gradcam(input_tensor=image.unsqueeze(0))[0]

            # show_cam_on_image expects a float image in [0, 1].
            rgb_img = original_images[idx] / 255.0
            # NOTE(review): the CAM is computed on model.transform's output and
            # stretched over the full original image; if that transform crops,
            # the overlay will be spatially offset -- confirm.
            heatmap_resized = cv2.resize(grayscale_cam, (rgb_img.shape[1], rgb_img.shape[0]))
            cam_image = show_cam_on_image(rgb_img, heatmap_resized)

            path_gradcam = os.path.join(directory, f"{patient_id}_gradcam.png")
            plt.figure()
            plt.imshow(cam_image)
            plt.axis('off')
            plt.savefig(path_gradcam, bbox_inches="tight", pad_inches=0)
            plt.close()
            i += 1

# Out: Using cache found in C:\Users\Admin/.cache\torch\hub\B-cos_B-cos-v2_main
