#### Preprocess

In [None]:
import torchio as tio

def center_crop_brain(subject, target_size, z_offset=4):
    """Crop ``subject`` to ``target_size`` around the in-plane centre of the brain.

    The brain mask is ``subject['perf1'].data > 0``; its x/y centroid is taken on
    the middle slice ``z_mid = depth // 2``. The z-window of ``target_size[2]``
    slices starts at ``z_mid - (target_size[2] // 2 - z_offset)``, i.e. it is
    shifted ``z_offset`` slices towards higher z than a window centred on
    ``z_mid`` (z_offset -> to remove layers near skull base).

    A window that would extend past the end of the volume along any axis is
    shifted back inside it, so the output always has spatial shape ``target_size``.

    Args:
        subject: ``tio.Subject`` containing a ``'perf1'`` image.
        target_size: (x, y, z) output spatial shape.
        z_offset: slice shift of the z-window (see above).

    Returns:
        The cropped subject (output of ``tio.Crop``).

    Raises:
        ValueError: if the middle slice has no voxel > 0, or if a target
            dimension exceeds the subject's spatial shape.
    """
    import torch

    shape = subject.spatial_shape
    z_mid = shape[-1] // 2

    # In-plane brain centroid on the middle slice; index 0 of torch.where is the channel.
    brain_mask = subject['perf1'].data > 0
    coordinates = torch.where(brain_mask[:, :, :, z_mid])
    if coordinates[1].numel() == 0:
        raise ValueError(
            f"'perf1' has no voxels > 0 on the middle slice (z={z_mid}); "
            "cannot locate the brain centre."
        )
    centers = [int(torch.round(coords.float().mean())) for coords in coordinates[1:3]]

    desired_starts = (
        centers[0] - target_size[0] // 2,
        centers[1] - target_size[1] // 2,
        z_mid - (target_size[2] // 2 - z_offset),
    )

    # tio.Crop expects (x_ini, x_fin, y_ini, y_fin, z_ini, z_fin) voxel counts to remove.
    bounds = []
    for axis, desired in enumerate(desired_starts):
        size, extent = target_size[axis], shape[axis]
        if size > extent:
            raise ValueError(
                f"target_size[{axis}]={size} exceeds the subject's spatial size "
                f"{extent} on that axis"
            )
        # Clamp at 0 (as before) and at extent - size so no crop bound is negative.
        start = min(max(0, desired), extent - size)
        bounds.extend((start, extent - (start + size)))

    cropper = tio.Crop(cropping=tuple(bounds))
    return cropper(subject)

def preprocess(subject, target_size):
    """Reorient, resample and centre-crop the ``'perf1'`` image of a subject.

    1. ``ToCanonical`` + ``Resample`` to an in-plane spacing of 2 x 2 while keeping
       the image's own z-spacing (no resampling along the slice axis).
    2. ``center_crop_brain`` to ``target_size``; the z-offset depends on the number
       of slices after resampling: 0 for < 30 slices, 2 for < 34, otherwise 3.

    Note: ``subject['perf1']`` is replaced on the passed-in subject.
    Returns the cropped subject.
    """
    # The former `shape[1] == 256` check chose 2 in both branches, so it was dropped.
    resample_xy = 2
    resample_z = subject['perf1'].spacing[-1]  # do not resample in z-axis

    # Resample + crop instead of resizing: 'in most medical image applications, resizing
    # shouldn't be used as it will deform the physical object by scaling anisotropically
    # along the different dimensions. The solution is typically applying Resample and
    # CropOrPad' ~TORCHIO
    preprocess_tf = tio.Compose([
        tio.transforms.ToCanonical(),
        tio.transforms.Resample((resample_xy, resample_xy, resample_z)),
        # tio.transforms.ZNormalization(), #Perf-Images of Kai are already normalized
    ])

    for perf in ['perf1']:  # , 'perf2', 'perf3', 'perf4', 'perf5']:
        subject[perf] = preprocess_tf(subject[perf])

    # Crop around the centre; the z-offset depends on the number of slices.
    n_slices = subject['perf1'].data.shape[-1]
    if n_slices < 30:
        z_offset = 0
    elif n_slices < 34:
        z_offset = 2
    else:
        z_offset = 3

    return center_crop_brain(subject, target_size, z_offset=z_offset)

#### Dataset/Dataloader

In [None]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
import torchio as tio
import torch    
from fastai.vision.all import *
from fastai.data.core import DataLoaders

def concatenate_subject_images(subject):
    """Stack every ``tio.ScalarImage`` of ``subject`` along the channel axis (dim 0).

    Images are taken in the subject's item order; entries that are not
    ``tio.ScalarImage`` instances are ignored.
    """
    channels = tuple(
        subject[name].data
        for name, item in subject.items()
        if isinstance(item, tio.ScalarImage)
    )
    return torch.cat(channels, dim=0)

class Dataset(Dataset):
    """Map-style dataset over a sequence of torchio subjects.

    Each item is a dict with:
      - ``'image'``: the subject's scalar images concatenated along the channel
        axis (``concatenate_subject_images``), passed through ``image_transform``
        when one is given;
      - ``'label'``, ``'nihss'``, ``'acc'``: the subject's ``'target'``,
        ``'nihss'`` and ``'acc'`` entries.

    NOTE(review): the class shadows ``torch.utils.data.Dataset`` (its base), so
    re-running this cell makes it inherit from the previous version of itself.
    """
    def __init__(self, image_subjects, image_transform=None):
        self.image_subjects = image_subjects
        self.image_transform = image_transform

    def __len__(self):
        return len(self.image_subjects)

    def __getitem__(self, index):
        subject = self.image_subjects[index]
        image = concatenate_subject_images(subject)

        # Transforms return the transformed tensor (torchio transforms copy their
        # input by default), so the result must be kept. It used to be discarded,
        # which meant no augmentation was ever applied.
        if self.image_transform:
            image = self.image_transform(image)

        return {
            'image': image,
            'label': subject['target'],
            'nihss': subject['nihss'],
            'acc': subject['acc'],
        }

def make_dls(train_subjects, valid_subjects, train_tf, valid_tf, train_bn=4, valid_bn=8):
    """Build fastai ``DataLoaders`` from already-split lists of torchio subjects.

    train_subjects, valid_subjects: subjects of the training/validation set
    train_tf, valid_tf: image transforms for the training/validation set
    train_bn, valid_bn: training/validation batch size
    Only the training loader shuffles; neither drops the last partial batch.
    """
    loader_kwargs = dict(drop_last=False, pin_memory=True)

    train_loader = DataLoader(
        Dataset(train_subjects, image_transform=train_tf),
        batch_size=train_bn,
        shuffle=True,
        **loader_kwargs,
    )
    valid_loader = DataLoader(
        Dataset(valid_subjects, image_transform=valid_tf),
        batch_size=valid_bn,
        **loader_kwargs,
    )
    return DataLoaders(train_loader, valid_loader)
        
def make_dls_img(subjects, train_tf, valid_tf, train_bn=4, valid_bn=8, test_size=0.2, random_state=42, test_set=False):
    """Build fastai ``DataLoaders`` from one list of torchio subjects.

    train_tf, valid_tf: image transforms for the training/validation set
    train_bn, valid_bn: training/validation batch size
    test_size, random_state: forwarded to ``train_test_split`` (validation ratio, seed)
    test_set: if True, no split is made and both loaders iterate over ``subjects``

    Note: a later definition with the same name in this file replaces this one
    when the whole file is executed.
    """
    if not test_set:
        train_subjects, valid_subjects = train_test_split(
            subjects, test_size=test_size, random_state=random_state
        )
    else:
        train_subjects = valid_subjects = subjects

    loader_kwargs = dict(drop_last=False, pin_memory=True)
    train_loader = DataLoader(
        Dataset(train_subjects, image_transform=train_tf),
        batch_size=train_bn,
        shuffle=True,
        **loader_kwargs,
    )
    valid_loader = DataLoader(
        Dataset(valid_subjects, image_transform=valid_tf),
        batch_size=valid_bn,
        **loader_kwargs,
    )
    return DataLoaders(train_loader, valid_loader)

# class CombinedDataset(Dataset):
#     def __init__(self, image_subjects, tabular_data, image_transform=None):
#         self.image_subjects = image_subjects
#         self.tabular_data = tabular_data
#         self.image_transform = image_transform

#     def __len__(self):
#         return len(self.image_subjects)

#     def __getitem__(self, index):
#         image = self.image_subjects[index]['perf1'][tio.DATA]
#         if type(self.tabular_data)==np.ndarray:
#             tabular = self.tabular_data[index]
#         else:
#             tabular = self.tabular_data.loc[index].to_numpy()
#         label = self.image_subjects[index]['target']
#         nihss = self.image_subjects[index]['nihss']
#         acc = self.image_subjects[index]['acc']
        
#         # Apply transforms to subject       
#         if self.image_transform:
#             self.image_transform(image)
        
#         return {
#             'image': image,
#             'tabular': tabular,
#             'label': label,
#             'nihss': nihss,
#             'acc': acc,
#         }

# class Dataset(Dataset):
#     def __init__(self, image_subjects, image_transform=None):
#         self.image_subjects = image_subjects
#         self.image_transform = image_transform

#     def __len__(self):
#         return len(self.image_subjects)

#     def __getitem__(self, index):
#         image = self.image_subjects[index]['perf1'][tio.DATA]
#         label = self.image_subjects[index]['target']
#         nihss = self.image_subjects[index]['nihss']
#         acc = self.image_subjects[index]['acc']
        
#         # Apply transforms to subject       
#         if self.image_transform:
#             self.image_transform(image)
        
#         return {
#             'image': image,
#             'label': label,
#             'nihss': nihss,
#             'acc': acc,
#         }
        
def make_dls_img_and_tab(subjects, tabular, train_tf, valid_tf, train_bn=4, valid_bn=8, test_size=0.2, random_state=42, test_set=False):
    """Build fastai ``DataLoaders`` over torchio subjects paired with tabular rows.

    subjects, tabular: split jointly by ``train_test_split``, so they must have the
        same length and matching order
    train_tf, valid_tf: image transforms for the training/validation set
    train_bn, valid_bn: training/validation batch size
    test_size, random_state: validation ratio and split seed
    test_set: if True, no split is made and both loaders use all data

    NOTE(review): ``CombinedDataset`` only appears as commented-out code in this
    chunk; it must be defined elsewhere for this function to run.
    """
    if test_set:
        train_parts = valid_parts = (subjects, tabular)
    else:
        tr_subj, va_subj, tr_tab, va_tab = train_test_split(
            subjects, tabular, test_size=test_size, random_state=random_state
        )
        train_parts, valid_parts = (tr_subj, tr_tab), (va_subj, va_tab)

    loader_kwargs = dict(drop_last=False, pin_memory=True)
    train_loader = DataLoader(
        CombinedDataset(*train_parts, image_transform=train_tf),
        batch_size=train_bn,
        shuffle=True,
        **loader_kwargs,
    )
    valid_loader = DataLoader(
        CombinedDataset(*valid_parts, image_transform=valid_tf),
        batch_size=valid_bn,
        **loader_kwargs,
    )
    return DataLoaders(train_loader, valid_loader)

def make_dls_img(subjects, train_tf, valid_tf, train_bn=4, valid_bn=8, test_size=0.2, random_state=42, test_set=False):
    """Build fastai ``DataLoaders`` from one list of torchio subjects.

    train_tf, valid_tf: image transforms for the training/validation set
    train_bn, valid_bn: training/validation batch size
    test_size, random_state: forwarded to ``train_test_split`` (validation ratio, seed)
    test_set: if True, no split is made and both loaders iterate over ``subjects``

    Note: this redefines an earlier ``make_dls_img`` in this file.
    """
    if test_set:
        splits = (subjects, subjects)
    else:
        splits = train_test_split(subjects, test_size=test_size, random_state=random_state)
    train_subjects, valid_subjects = splits

    train_ds = Dataset(train_subjects, image_transform=train_tf)
    valid_ds = Dataset(valid_subjects, image_transform=valid_tf)

    shared = dict(drop_last=False, pin_memory=True)
    return DataLoaders(
        DataLoader(train_ds, batch_size=train_bn, shuffle=True, **shared),
        DataLoader(valid_ds, batch_size=valid_bn, **shared),
    )

#### Models

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.video import r3d_18

#Building blocks
class DAFT(nn.Module):
    """Channel-wise affine modulation of a feature map conditioned on tabular data.

    An MLP (``num_tabular_features -> 64 -> 2 * in_channels``) predicts, per
    sample, one scale and one shift for every channel. The feature map is then
    transformed as ``x * scale + shift``, broadcast over all spatial dimensions.
    """
    def __init__(self, in_channels, num_tabular_features):
        super().__init__()
        hidden = 64
        self.auxiliary_network = nn.Sequential(
            nn.Linear(num_tabular_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2 * in_channels),
        )

    def forward(self, x, tabular):
        n, c = x.shape[0], x.shape[1]
        # (N, C, 1, ..., 1) so the per-channel parameters broadcast over space.
        expand = (n, c) + (1,) * (x.dim() - 2)
        affine = self.auxiliary_network(tabular).view(n, c, 2)
        scale = affine[:, :, 0].reshape(expand)
        shift = affine[:, :, 1].reshape(expand)
        return x * scale + shift

class ResNet3D18(nn.Module):
    """torchvision R3D-18 adapted to ``in_channels`` inputs and ``num_classes`` outputs.

    The network is randomly initialised (no pretrained weights). The stem
    convolution is replaced so that it accepts ``in_channels`` channels
    (kernel (3, 7, 7), stride (1, 2, 2): no striding along depth), and the final
    ``fc`` layer is replaced to produce ``num_classes`` outputs.
    """
    def __init__(self, num_classes, in_channels=1):
        from torchvision.models.video import r3d_18

        super().__init__()
        # `weights=None` is the documented way to request random initialisation;
        # `weights=False` only worked through torchvision's legacy `pretrained` shim.
        self.resnet = r3d_18(weights=None)

        self.resnet.stem[0] = nn.Conv3d(in_channels, 64, kernel_size=(3, 7, 7),
                                        stride=(1, 2, 2), padding=(1, 3, 3), bias=False)

        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        return self.resnet(x)

#Combined model for imaging and tabular data
class CombinedModel_DAFT(nn.Module):
    """ResNet3D-18 image backbone conditioned on tabular data via DAFT modules.

    After each of the four residual stages, a DAFT module computes a per-channel
    scale and shift from ``tabular`` and applies it to the feature map. The
    backbone's ``fc`` produces ``num_classes`` values, which pass through one more
    ``num_classes -> num_classes`` linear layer (``self.fc``) to give the logits.
    """
    def __init__(self, num_classes, num_tabular_features, in_channels=3):
        super().__init__()
        self.resnet3d = ResNet3D18(num_classes, in_channels)
        self.fc = nn.Linear(num_classes, num_classes)

        # One DAFT per residual stage; widths match the outputs of layer1..layer4.
        self.daft1 = DAFT(64, num_tabular_features)
        self.daft2 = DAFT(128, num_tabular_features)
        self.daft3 = DAFT(256, num_tabular_features)
        self.daft4 = DAFT(512, num_tabular_features)

    def freeze_resnet(self, unfreeze_last_layer=True, unfreeze_last_conv_layers=2):
        """Freeze the backbone, then re-enable its last ``layer4`` blocks and/or its ``fc``.

        DAFT modules and ``self.fc`` are not affected.
        """
        backbone = self.resnet3d.resnet
        self.resnet3d.requires_grad_(False)

        if unfreeze_last_conv_layers > 0:
            tail_blocks = list(backbone.layer4.children())[-unfreeze_last_conv_layers:]
            for block in tail_blocks:
                block.requires_grad_(True)

        if unfreeze_last_layer:
            backbone.fc.requires_grad_(True)

    def unfreeze_resnet(self):
        """Make every backbone parameter trainable again."""
        self.resnet3d.requires_grad_(True)

    def forward(self, image, tabular):
        net = self.resnet3d.resnet
        features = net.stem(image)
        # Residual stage followed by its DAFT modulation.
        stages = (
            (net.layer1, self.daft1),
            (net.layer2, self.daft2),
            (net.layer3, self.daft3),
            (net.layer4, self.daft4),
        )
        for stage, daft in stages:
            features = daft(stage(features), tabular)
        pooled = torch.flatten(net.avgpool(features), 1)
        return self.fc(net.fc(pooled))

#### Learner/Callbacks

In [None]:
import numpy as np
from fastai.optimizer import Adam
from fastai.basics import progress_bar, CancelFitException
import torch
import matplotlib.pyplot as plt
from fastai.imports import noop
from sklearn.metrics import f1_score

class My_Learner_img_and_tab:
    """Learner for combined image and tabular data.

    Minimal callback-driven training loop. Each batch from ``dls`` is indexed as a
    dict (``batch['image']``, ``batch['tabular']``, ``batch['label']``) and the model
    is called as ``model(images, tabular)``. Every callback in ``cbs`` gets
    ``cb.learner = self`` and is triggered by event name through ``self(name)``.
    Per-epoch loss and weighted F1 are printed and stored in ``tr_losses_acc`` /
    ``va_losses_acc`` and ``tr_f1_scores`` / ``va_f1_scores``.
    """
    
    np = __import__('numpy')
    from fastai.optimizer import Adam
    
    def __init__(self, model, dls, loss_func, lr, cbs, opt_func=Adam):
        store_attr()
        # History exists from the start so `validate()` also works before `fit()`;
        # `fit()` resets it.
        self.tr_losses_acc, self.va_losses_acc = [], []
        self.tr_f1_scores, self.va_f1_scores = [], []
        for cb in cbs: cb.learner = self

    def one_batch(self):
        """Forward ``self.batch``; in training mode also backward and optimizer step."""
        import torch
        self('before_batch')

        images = self.batch['image'].float().cuda()
        tabular = self.batch['tabular'].float().cuda()
        yb = self.batch['label'].cuda()

        if self.model.training:
            self.preds = self.model(images, tabular)
            self.loss = self.loss_func(self.preds, yb)
            # Clear gradients from the previous step; otherwise they accumulate
            # over every batch of the fit.
            self.opt.zero_grad()
            self.loss.backward()
            self.opt.step()
        else:
            with torch.no_grad():
                self.preds = self.model(images, tabular)
                self.loss = self.loss_func(self.preds, yb)

        pred_labels = self.preds.argmax(dim=1)
        n = len(yb)
        self.accs.append((pred_labels == yb).float().sum())
        # Detach: a graph-attached loss would keep every batch's autograd graph
        # (and activations) alive until the end of the epoch.
        self.losses.append(self.loss.detach() * n)
        self.ns.append(n)
        self.y_pred.append(pred_labels.cpu())
        self.y_true.append(yb.cpu())
        self('after_batch')

    def one_epoch(self, train):
        """Run one pass over the train (``train=True``) or validation loader and log metrics."""
        from fastai.basics import progress_bar
        # `Module.train(mode)` propagates to all submodules (BatchNorm, Dropout);
        # assigning `self.model.training` only changed the flag on the root module.
        self.model.train(train)
        self('before_epoch')
        self.accs, self.losses, self.ns, self.y_pred, self.y_true = [], [], [], [], []
        
        dl = self.dls.train if train else self.dls.valid
        for self.num, self.batch in enumerate(progress_bar(dl, leave=False)):
            self.one_batch()
        
        n = sum(self.ns)
        epoch_loss = sum(self.losses).item() / n
        epoch_acc = sum(self.accs).item() / n
        y_true = torch.cat(self.y_true).numpy()
        y_pred = torch.cat(self.y_pred).numpy()
        f1 = f1_score(y_true, y_pred, average='weighted')
        
        phase = "TRAINING" if self.model.training else "VALIDATION"
        print(f'{"Epoch: " + str(self.epoch) if self.model.training else ""}')
        print(f'{phase} Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}, F1-score: {f1:.4f}')
        
        if self.model.training: 
            self.tr_losses_acc.append(epoch_loss)
            self.tr_f1_scores.append(f1)
        else: 
            self.va_losses_acc.append(epoch_loss)
            self.va_f1_scores.append(f1)
            print('----------------------------------------------------------')
        self('after_epoch')
    
    def fit(self, n_epochs):
        """Train for ``n_epochs`` (train + validation pass each); stops early on CancelFitException."""
        from fastai.basics import CancelFitException
        from fastai.basics import progress_bar
        self('before_fit')
        self.tr_losses_acc, self.va_losses_acc = [], []
        self.tr_f1_scores, self.va_f1_scores = [], []
        
        self.model = self.model.float().cuda()
    
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        self.n_epochs = n_epochs
        try:
            for self.epoch in progress_bar(range(n_epochs)):
                self.one_epoch(True)
                self.one_epoch(False)
        except CancelFitException: pass
        self('after_fit')
        
    def validate(self):
        """Run one validation pass, then plot the recorded loss and F1 histories."""
        self.one_epoch(False)
        
        import matplotlib.pyplot as plt
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        
        ax1.plot(self.tr_losses_acc, label='Training')
        ax1.plot(self.va_losses_acc, label='Validation')
        ax1.set_title('Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_yscale('log')
        ax1.legend()
        
        ax2.plot(self.tr_f1_scores, label='Training')
        ax2.plot(self.va_f1_scores, label='Validation')
        ax2.set_title('F1-score')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('F1-score')
        ax2.legend()
        
        plt.tight_layout()
        plt.show()
        self('after_fit')
        
    def __call__(self, name):
        """Invoke method ``name`` on every callback that defines it."""
        from fastai.imports import noop
        for cb in self.cbs: getattr(cb, name, noop)()

            

import numpy as np
from fastai.optimizer import Adam
from fastai.basics import progress_bar, CancelFitException
import torch
import matplotlib.pyplot as plt
from fastai.imports import noop
from sklearn.metrics import f1_score

class My_Learner_img:
    """Learner for image-only data.

    Minimal callback-driven training loop. Each batch from ``dls`` is indexed as a
    dict (``batch['image']``, ``batch['label']``) and the model is called as
    ``model(images)``. Every callback in ``cbs`` gets ``cb.learner = self`` and is
    triggered by event name through ``self(name)``. Per-epoch loss and weighted F1
    are printed and stored in ``tr_losses_acc`` / ``va_losses_acc`` and
    ``tr_f1_scores`` / ``va_f1_scores``.
    """
    
    np = __import__('numpy')
    from fastai.optimizer import Adam
    
    def __init__(self, model, dls, loss_func, lr, cbs, opt_func=Adam):
        store_attr()
        # History exists from the start so `validate()` also works before `fit()`;
        # `fit()` resets it.
        self.tr_losses_acc, self.va_losses_acc = [], []
        self.tr_f1_scores, self.va_f1_scores = [], []
        for cb in cbs: cb.learner = self

    def one_batch(self):
        """Forward ``self.batch``; in training mode also backward and optimizer step."""
        import torch
        self('before_batch')

        images = self.batch['image'].float().cuda()
        yb = self.batch['label'].cuda()

        if self.model.training:
            self.preds = self.model(images)
            self.loss = self.loss_func(self.preds, yb)
            # Clear gradients from the previous step; otherwise they accumulate
            # over every batch of the fit.
            self.opt.zero_grad()
            self.loss.backward()
            self.opt.step()
        else:
            with torch.no_grad():
                self.preds = self.model(images)
                self.loss = self.loss_func(self.preds, yb)

        pred_labels = self.preds.argmax(dim=1)
        n = len(yb)
        self.accs.append((pred_labels == yb).float().sum())
        # Detach: a graph-attached loss would keep every batch's autograd graph
        # (and activations) alive until the end of the epoch.
        self.losses.append(self.loss.detach() * n)
        self.ns.append(n)
        self.y_pred.append(pred_labels.cpu())
        self.y_true.append(yb.cpu())
        self('after_batch')

    def one_epoch(self, train):
        """Run one pass over the train (``train=True``) or validation loader and log metrics."""
        from fastai.basics import progress_bar
        # `Module.train(mode)` propagates to all submodules (BatchNorm, Dropout);
        # assigning `self.model.training` only changed the flag on the root module.
        self.model.train(train)
        self('before_epoch')
        self.accs, self.losses, self.ns, self.y_pred, self.y_true = [], [], [], [], []
        
        dl = self.dls.train if train else self.dls.valid
        for self.num, self.batch in enumerate(progress_bar(dl, leave=False)):
            self.one_batch()
        
        n = sum(self.ns)
        epoch_loss = sum(self.losses).item() / n
        epoch_acc = sum(self.accs).item() / n
        y_true = torch.cat(self.y_true).numpy()
        y_pred = torch.cat(self.y_pred).numpy()
        f1 = f1_score(y_true, y_pred, average='weighted')
        
        phase = "TRAINING" if self.model.training else "VALIDATION"
        print(f'{"Epoch: " + str(self.epoch) if self.model.training else ""}')
        print(f'{phase} Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}, F1-score: {f1:.4f}')
        
        if self.model.training: 
            self.tr_losses_acc.append(epoch_loss)
            self.tr_f1_scores.append(f1)
        else: 
            self.va_losses_acc.append(epoch_loss)
            self.va_f1_scores.append(f1)
            print('----------------------------------------------------------')
        self('after_epoch')
    
    def fit(self, n_epochs):
        """Train for ``n_epochs`` (train + validation pass each); stops early on CancelFitException."""
        from fastai.basics import CancelFitException
        from fastai.basics import progress_bar
        self('before_fit')
        self.tr_losses_acc, self.va_losses_acc = [], []
        self.tr_f1_scores, self.va_f1_scores = [], []
        
        self.model = self.model.float().cuda()
    
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        self.n_epochs = n_epochs
        try:
            for self.epoch in progress_bar(range(n_epochs)):
                self.one_epoch(True)
                self.one_epoch(False)
        except CancelFitException: pass
        self('after_fit')
        
    def validate(self):
        """Run one validation pass, then plot the recorded loss and F1 histories."""
        self.one_epoch(False)
        
        import matplotlib.pyplot as plt
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        
        ax1.plot(self.tr_losses_acc, label='Training')
        ax1.plot(self.va_losses_acc, label='Validation')
        ax1.set_title('Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_yscale('log')
        ax1.legend()
        
        ax2.plot(self.tr_f1_scores, label='Training')
        ax2.plot(self.va_f1_scores, label='Validation')
        ax2.set_title('F1-score')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('F1-score')
        ax2.legend()
        
        plt.tight_layout()
        plt.show()
        self('after_fit')
        
    def __call__(self, name):
        """Invoke method ``name`` on every callback that defines it."""
        from fastai.imports import noop
        for cb in self.cbs: getattr(cb, name, noop)()


#Callbacks
class GetAttr:
    """From FastAI: Inherit from this to have all attr accesses in `self._xtra` passed down to `self.default`

    More precisely: when normal attribute lookup on the instance fails, the name
    is looked up on the object stored in the attribute whose name is ``_default``
    (e.g. ``'learner'`` for ``Callback``). If ``_xtra`` is set, only the names
    listed in it are forwarded.
    """
    _default='default'
    def _component_attr_filter(self,k):
        # Never forward dunder names, `_xtra`, or the default attribute itself;
        # forwarding the latter while it is unset would recurse back into __getattr__.
        if k.startswith('__') or k in ('_xtra',self._default): return False
        xtra = getattr(self,'_xtra',None)
        return xtra is None or k in xtra
    # Forwardable attribute names of the default object (used by __dir__).
    def _dir(self): return [k for k in dir(getattr(self,self._default)) if self._component_attr_filter(k)]
    def __getattr__(self,k):
        # Python only calls this after normal lookup failed.
        if self._component_attr_filter(k):
            attr = getattr(self,self._default,None)
            if attr is not None: return getattr(attr,k)
        raise AttributeError(k)
    # NOTE(review): `custom_dir` is not defined in this chunk; presumably fastcore's
    # helper brought in by a fastai star import — TODO confirm it is in scope.
    def __dir__(self): return custom_dir(self,self._dir())
#     def __getstate__(self): return self.__dict__
    # Restore pickled state by writing directly into __dict__.
    def __setstate__(self,data): self.__dict__.update(data)

class Callback(GetAttr):
    """Base class for learner callbacks.

    Attribute reads that fail on the callback (``self.model``, ``self.opt``,
    ``self.epoch`` ...) are forwarded by ``GetAttr`` to ``self.learner``.
    """
    _default = 'learner'

class OneCycle(Callback):
    """Install Onecycle-learning rate.

    Over the first 25% of all training batches, the learning rate rises linearly
    from ``base_lr / 10`` to ``base_lr``. After that it decays linearly to 0. The
    value is written into the optimizer before every training batch and recorded
    in ``self.lrs``.
    """
    def __init__(self, base_lr): self.base_lr = base_lr
    def before_fit(self): self.lrs = []

    def before_batch(self):
        if not self.model.training: return
        n = len(self.dls.train)
        bn = self.epoch*n + self.num
        mn = self.n_epochs*n
        pct = bn/mn
        pct_start,div_start = 0.25,10
        if pct<pct_start:
            pct /= pct_start
            lr = (1-pct)*self.base_lr/div_start + pct*self.base_lr
        else:
            pct = (pct-pct_start)/(1-pct_start)
            lr = (1-pct)*self.base_lr
        self._set_lr(lr)
        self.lrs.append(lr)

    def _set_lr(self, lr):
        """Write ``lr`` where the optimizer actually reads it during ``step()``."""
        opt = self.opt
        if hasattr(opt, 'set_hyper'):
            # fastai Optimizer / OptimWrapper: lr is stored in the per-group hypers.
            opt.set_hyper('lr', lr)
        elif hasattr(opt, 'param_groups'):
            # torch.optim optimizers read lr from each param group.
            for group in opt.param_groups:
                group['lr'] = lr
        # Kept for code that reads `opt.lr`; on its own this attribute had no effect.
        opt.lr = lr

#### Interpretation

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.metrics import RocCurveDisplay, precision_recall_curve, average_precision_score
from sklearn.calibration import calibration_curve
from fastai.basics import progress_bar
from scipy import stats
import torchio as tio
from sklearn.utils import resample
import torch.nn.functional as F
from fastprogress.fastprogress import progress_bar

def create_combined_plot(cm_disp, y_true, y_pred_proba, class_names, n_classes, download, download_path, dpi, title, fig=None):
    """Draw confusion matrix, ROC curve and precision-recall curve stacked on one figure.

    cm_disp: ``ConfusionMatrixDisplay`` drawn in the top panel (titled ``title``).
    y_true: true labels; y_pred_proba: (n_samples, n_classes) class probabilities.
    n_classes: 2 -> binary ROC/PR on column 1 of ``y_pred_proba``; otherwise the
        multiclass helpers are used.
    class_names: unused here.
    download, download_path, dpi: forwarded unchanged to the plotting helpers.
    fig: figure to draw into; a new 8x18 inch figure is created when None.

    Returns the figure.
    """
    if fig is None:
        fig = plt.figure(figsize=(8, 18))

    axes = fig.subplots(3, 1)
    # Use `fig` explicitly: `plt.*` layout calls target the *current* figure,
    # which is not necessarily `fig` when one is passed in.
    fig.subplots_adjust(hspace=0.8)  # vertical space between subplots

    # Confusion Matrix
    cm_disp.plot(ax=axes[0], cmap='Blues', colorbar=False)
    for text in cm_disp.text_.ravel():
        text.set_fontsize(16)
    axes[0].set_title(title, fontsize=20, pad=20, fontweight='bold')
    axes[0].set_xlabel("Predicted", fontsize=18)
    axes[0].set_ylabel("True", fontsize=18)
    axes[0].tick_params(axis='both', labelsize=14)

    # ROC curve (middle) and precision-recall curve (bottom)
    if n_classes == 2:
        plot_binary_roc(y_true, y_pred_proba[:, 1], axes[1], download, download_path, dpi)
        plot_binary_pr(y_true, y_pred_proba[:, 1], axes[2], download, download_path, dpi)
    else:
        plot_multiclass_roc(y_true, y_pred_proba, n_classes, axes[1], download, download_path, dpi)
        plot_multiclass_pr(y_true, y_pred_proba, n_classes, axes[2], download, download_path, dpi)

    # Set equal aspect ratio for ROC and PR curves
    axes[1].set_aspect('equal', adjustable='box')
    axes[2].set_aspect('equal', adjustable='box')

    axes[1].set_title(' ', fontsize=20, pad=20, fontweight='bold')

    fig.tight_layout()

    return fig

def create_combined_plot_roc_prc(cm_disp, y_true, y_pred_proba, class_names, n_classes, download, download_path, dpi, title, fig=None):
    """Draw ROC curve (top, titled ``title``) and precision-recall curve (bottom) on one figure.

    cm_disp, class_names: unused here (kept for a signature parallel to
        ``create_combined_plot``).
    y_true: true labels; y_pred_proba: (n_samples, n_classes) class probabilities.
    n_classes: 2 -> binary ROC/PR on column 1 of ``y_pred_proba``; otherwise the
        multiclass helpers are used.
    download, download_path, dpi: forwarded unchanged to the plotting helpers.
    fig: figure to draw into; a new 8x18 inch figure is created when None.

    Returns the figure.
    """
    if fig is None:
        fig = plt.figure(figsize=(8, 18))

    axes = fig.subplots(2, 1)
    # Use `fig` explicitly: `plt.*` layout calls target the *current* figure,
    # which is not necessarily `fig` when one is passed in.
    fig.subplots_adjust(hspace=0.8)  # vertical space between subplots
    axes[0].set_title(title, fontsize=20, pad=20, fontweight='bold')

    if n_classes == 2:
        plot_binary_roc(y_true, y_pred_proba[:, 1], axes[0], download, download_path, dpi)
        plot_binary_pr(y_true, y_pred_proba[:, 1], axes[1], download, download_path, dpi)
    else:
        plot_multiclass_roc(y_true, y_pred_proba, n_classes, axes[0], download, download_path, dpi)
        plot_multiclass_pr(y_true, y_pred_proba, n_classes, axes[1], download, download_path, dpi)

    # Set equal aspect ratio for ROC and PR curves
    axes[0].set_aspect('equal', adjustable='box')
    axes[1].set_aspect('equal', adjustable='box')

    fig.tight_layout()

    return fig

def Interp_from_learner(learn, class_names,
                        c,
                        use_tabular=True,
                        download=False, 
                        download_path='/media/user/Elements/combined_plot.tiff',
                        dpi=1200,
                        title=None):
    """Get Class report incl. Sens/Spec/F1, Confusion matrix and ROC from learner

    Runs ``learn.model`` in eval mode over ``learn.dls.valid``, prints sklearn's
    classification report and draws ``create_combined_plot`` (confusion matrix,
    ROC, precision-recall).

    learn: learner with ``model`` and ``dls.valid``; batches are dicts with
        ``'image'``, ``'label'`` and (if ``use_tabular``) ``'tabular'``.
    class_names: display labels of the confusion matrix.
    c: unused.
    use_tabular: call the model as ``model(images, tabular)`` instead of ``model(images)``.
    download, download_path, dpi: if ``download``, save the figure to ``download_path``.
    title: title of the confusion-matrix panel.

    Returns (fig, cm_disp, y_true, y_pred_proba). The model's previous
    train/eval mode is restored afterwards.
    """
    y_pred, y_pred_proba, y_true = [], [], []

    # Cast/move the model once and use eval mode (BatchNorm/Dropout) for inference.
    learn.model = learn.model.float().cuda()
    model = learn.model
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in progress_bar(learn.dls.valid, leave=False):
                images = batch['image'].float().cuda()
                if use_tabular:
                    preds = model(images, batch['tabular'].float().cuda())
                else:
                    preds = model(images)

                preds_proba = F.softmax(preds, dim=1)
                y_pred_proba.append(preds_proba.cpu().numpy())
                y_pred.extend(preds.argmax(dim=1).cpu().numpy().tolist())
                y_true.extend(batch['label'].cpu().numpy().tolist())
    finally:
        model.train(was_training)

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    y_pred_proba = np.vstack(y_pred_proba)

    n_classes = y_pred_proba.shape[1]

    # Classification report and Confusion Matrix
    class_report = metrics.classification_report(y_true, y_pred)
    cm = metrics.confusion_matrix(y_true, y_pred)
    cm_disp = metrics.ConfusionMatrixDisplay(cm, display_labels=class_names)

    print('----------------------------------------------------------')
    print(class_report)
    print('----------------------------------------------------------')

    # Create combined plot
    fig = create_combined_plot(cm_disp, y_true, y_pred_proba, class_names, n_classes, download, download_path, dpi, title, fig=None)

    if download:
        fig.savefig(download_path, dpi=dpi)
    plt.show()

    # Return the figure along with cm_disp, y_true, and y_pred_proba
    return fig, cm_disp, y_true, y_pred_proba


def plot_binary_roc(y_true, y_pred_proba, ax, download, download_path, dpi):
    """Plot a binary ROC curve with a bootstrap ±1 SD band on ``ax``.

    y_true: binary labels; y_pred_proba: positive-class scores of the same length.
    ax: matplotlib Axes to draw on.
    download, download_path, dpi: not used here (uniform helper signature).

    1000 bootstrap resamples (seed 42) provide both the band around the mean
    interpolated TPR and the legend value "AUROC = mean ± SD" of the resampled
    AUCs. Resamples containing only one class are skipped (ROC/AUC is undefined
    for them).
    """
    n_bootstraps = 1000
    rng_seed = 42
    y_true = np.asarray(y_true)
    y_pred_proba = np.asarray(y_pred_proba)
    base_fpr = np.linspace(0, 1, 101)

    # One set of resamples feeds both the AUC statistics and the TPR band, so the
    # band and the reported ± value describe the same bootstrap distribution.
    rng = np.random.RandomState(rng_seed)
    bootstrapped_auc, tprs = [], []
    for _ in range(n_bootstraps):
        indices = rng.randint(0, len(y_true), len(y_true))
        y_true_bootstrap = y_true[indices]
        if np.unique(y_true_bootstrap).size < 2:
            continue
        fpr, tpr, _thresholds = metrics.roc_curve(y_true_bootstrap, y_pred_proba[indices])
        bootstrapped_auc.append(metrics.auc(fpr, tpr))
        tpr_interp = np.interp(base_fpr, fpr, tpr)
        tpr_interp[0] = 0.0
        tprs.append(tpr_interp)

    if bootstrapped_auc:
        auc, std_auc = np.mean(bootstrapped_auc), np.std(bootstrapped_auc)
    else:
        auc = std_auc = float('nan')

    display = RocCurveDisplay.from_predictions(
        y_true,
        y_pred_proba,
        ax=ax,
        color='royalblue',
    )

    # ±1 SD band around the mean bootstrap ROC curve, clipped to [0, 1].
    if tprs:
        tprs = np.array(tprs)
        mean_tprs = tprs.mean(axis=0)
        std_tprs = tprs.std(axis=0)
        tprs_upper = np.minimum(mean_tprs + std_tprs, 1)
        tprs_lower = np.maximum(mean_tprs - std_tprs, 0)
        ax.fill_between(base_fpr, tprs_lower, tprs_upper, color='royalblue', alpha=0.3)

    ax.set_aspect('equal')
    ax.plot([0, 1], [0, 1], linestyle='--', color='gray')

    ax.set_xlabel("False Positive Rate", fontsize=18)
    ax.set_ylabel("True Positive Rate", fontsize=18)
    ax.tick_params(axis='both', labelsize=14)
    ax.legend([display.line_],
              [f'AUROC = {auc:.2f} ± {std_auc:.2f}'],
              loc="lower right",
              prop={'size': 14})

def plot_binary_pr(y_true, y_pred_proba, ax, download, download_path, dpi):
    """
    Plot a binary precision-recall curve with a +/- 1 SD (bootstrapped AP) band on `ax`.

    - y_true: array-like of binary labels (positive class = 1)
    - y_pred_proba: array-like of positive-class scores, same length as y_true
    - ax: matplotlib Axes to draw on
    - download, download_path, dpi: accepted for signature compatibility with the
      other plotting helpers; this function does not save the figure itself.

    The band is the observed precision curve shifted by the SD of the bootstrapped
    average precision (1000 resamples, seed 42), clipped to [0, 1]. Resamples
    with no positive sample are skipped because AP is undefined for them.
    """
    n_bootstraps = 1000
    rng_seed = 42

    y_true = np.asarray(y_true)
    y_pred_proba = np.asarray(y_pred_proba)
    n_samples = len(y_true)

    precision, recall, _ = precision_recall_curve(y_true, y_pred_proba)
    average_precision = average_precision_score(y_true, y_pred_proba)

    bootstrapped_ap = []
    rng = np.random.RandomState(rng_seed)
    for _ in range(n_bootstraps):
        indices = rng.randint(0, n_samples, n_samples)
        y_true_bootstrap = y_true[indices]
        if not np.any(y_true_bootstrap == 1):
            continue  # AP undefined without positives
        bootstrapped_ap.append(average_precision_score(y_true_bootstrap, y_pred_proba[indices]))

    ap_std = np.std(bootstrapped_ap)

    ax.plot(recall, precision, color='chocolate', label=f'AUPRC = {average_precision:.2f} ± {ap_std:.2f}')
    ax.fill_between(recall,
                    np.clip(precision - ap_std, 0, 1),
                    np.clip(precision + ap_std, 0, 1),
                    color='chocolate', alpha=0.3)
    ax.set_xlim([-0.01, 1.01])
    ax.set_ylim([-0.01, 1.1])  # a little headroom above precision = 1
    ax.set_aspect('equal', adjustable='box')

    ax.set_xlabel("Recall", fontsize=18)
    ax.set_ylabel("Precision", fontsize=18)
    ax.tick_params(axis='both', labelsize=14)
    # A single legend call: a second call would discard the requested location.
    ax.legend(loc='lower left', prop={'size': 14})

def delong_roc_variance(ground_truth, predictions):
    """
    DeLong estimate of the variance of the empirical AUROC.

    - ground_truth: 1-D array-like of binary labels (1 = positive, 0 = negative)
    - predictions: 1-D array-like of scores; higher means more likely positive
    Returns the estimated variance of the AUROC as a float.
    Raises ValueError when there are fewer than two positives or two negatives,
    since the sample variances below are then undefined.

    Structural components, with psi(x, y) = 1 if x > y, 0.5 if x == y, else 0:
        V10[i] = mean_j psi(pos_i, neg_j)   (one per positive)
        V01[j] = mean_i psi(pos_i, neg_j)   (one per negative)
        Var(AUC) = var(V10) / m + var(V01) / n   (sample variances, ddof=1)
    """
    ground_truth = np.asarray(ground_truth)
    predictions = np.asarray(predictions, dtype=float)

    pos_scores = predictions[ground_truth == 1]
    neg_scores = predictions[ground_truth == 0]
    m, n = len(pos_scores), len(neg_scores)
    if m < 2 or n < 2:
        raise ValueError("DeLong variance needs at least two positive and two negative samples")

    # psi[i, j] compares positive i with negative j; ties count as one half.
    psi = (pos_scores[:, None] > neg_scores[None, :]) + 0.5 * (pos_scores[:, None] == neg_scores[None, :])
    v10 = psi.mean(axis=1)
    v01 = psi.mean(axis=0)

    return float(np.var(v10, ddof=1) / m + np.var(v01, ddof=1) / n)

def delong_test(ground_truth, predictions_1, predictions_2):
    """
    Two-sided DeLong test for the difference between two correlated AUROCs.

    - ground_truth: binary labels shared by both models (same samples, same order)
    - predictions_1, predictions_2: positive-class scores of the two models
    Returns (z, p): the z statistic of AUC_1 - AUC_2 and its two-sided p-value.

    When the estimated variance of the difference is not positive (e.g. the two
    prediction vectors are identical), z is 0 and p is 1 if the AUCs are equal,
    otherwise z is +/-inf and p is 0. This avoids dividing by zero.
    """
    var_auc_1 = delong_roc_variance(ground_truth, predictions_1)
    var_auc_2 = delong_roc_variance(ground_truth, predictions_2)

    auc_1 = metrics.roc_auc_score(ground_truth, predictions_1)
    auc_2 = metrics.roc_auc_score(ground_truth, predictions_2)

    cov_auc = delong_roc_covariance(ground_truth, predictions_1, predictions_2)

    auc_diff = auc_1 - auc_2
    var_diff = var_auc_1 + var_auc_2 - 2 * cov_auc
    if var_diff <= 0:
        if np.isclose(auc_diff, 0):
            return 0.0, 1.0
        return float(np.copysign(np.inf, auc_diff)), 0.0

    z = auc_diff / np.sqrt(var_diff)
    # sf = 1 - cdf, but without cancellation error for large |z|.
    p = 2 * stats.norm.sf(abs(z))

    return z, p

def delong_roc_covariance(ground_truth, predictions_1, predictions_2):
    """
    DeLong estimate of the covariance between the AUROCs of two models scored on
    the same samples.

    - ground_truth: 1-D array-like of binary labels (1 = positive, 0 = negative)
    - predictions_1, predictions_2: 1-D array-like scores of the two models
    Returns Cov(AUC_1, AUC_2) as a float.
    Raises ValueError when there are fewer than two positives or two negatives.

    With structural components V10 (per positive) and V01 (per negative), where
    psi(x, y) = 1 if x > y, 0.5 if x == y, else 0:
        Cov = cov(V10_1, V10_2) / m + cov(V01_1, V01_2) / n   (ddof=1)
    """
    ground_truth = np.asarray(ground_truth)
    is_pos = ground_truth == 1
    is_neg = ground_truth == 0
    m, n = int(is_pos.sum()), int(is_neg.sum())
    if m < 2 or n < 2:
        raise ValueError("DeLong covariance needs at least two positive and two negative samples")

    def _components(predictions):
        # Returns (V10, V01) for one model.
        scores = np.asarray(predictions, dtype=float)
        pos_scores, neg_scores = scores[is_pos], scores[is_neg]
        psi = (pos_scores[:, None] > neg_scores[None, :]) + 0.5 * (pos_scores[:, None] == neg_scores[None, :])
        return psi.mean(axis=1), psi.mean(axis=0)

    v10_1, v01_1 = _components(predictions_1)
    v10_2, v01_2 = _components(predictions_2)

    # np.cov normalises by N - 1 by default.
    cov_10 = np.cov(v10_1, v10_2)[0, 1]
    cov_01 = np.cov(v01_1, v01_2)[0, 1]

    return float(cov_10 / m + cov_01 / n)

def Interp_from_two_learners(learners, class_names, c,
                             download=False, 
                             download_path1='/media/user/Elements/ROC_combined.tiff',
                             download_path2='/media/user/Elements/PR_combined.tiff',
                             dpi=1200):
    """
    Get ROC, PRC, and a DeLong test comparing two learners on their validation sets.

    - learners: list of two learners [learner_img_tab, learner_img]. The first model
      is called as model(images, tabular) and the second as model(images).
    - class_names: list of class names ['A', 'B', ...] -> class 0, 1, ... (not used here)
    - c: in_channels (not used here)
    - download: if True, save the ROC plot to download_path1 and the PR plot to
      download_path2 (.tiff) at `dpi`

    Plots and the DeLong test are produced only when the models output two classes.
    Raises ValueError when the two learners' validation labels differ, because the
    paired DeLong test is only valid when both models score the same samples in the
    same order.
    """
    results = []
    for i, learn in enumerate(learners):
        # Cast/move/eval once per learner rather than on every batch.
        model = learn.model.float().cuda().eval()
        learn.model = model
        uses_tabular = i == 0  # first learner uses image & tabular data

        y_pred, y_pred_proba, y_true = [], [], []
        with torch.no_grad():
            for batch in progress_bar(learn.dls.valid, leave=False):
                images = batch['image'].float().cuda()
                yb = batch['label'].cuda()

                if uses_tabular:
                    tabular = batch['tabular'].float().cuda()
                    preds = model(images, tabular)
                else:
                    preds = model(images)

                preds_proba = F.softmax(preds, dim=1)  # logits -> probabilities
                y_pred_proba.append(preds_proba.cpu().numpy())

                y_pred.extend(preds.argmax(dim=1).cpu().numpy().tolist())
                y_true.extend(yb.cpu().numpy().tolist())

        results.append({
            'y_true': np.array(y_true),
            'y_pred': np.array(y_pred),
            'y_pred_proba': np.vstack(y_pred_proba)
        })

    n_classes = results[0]['y_pred_proba'].shape[1]

    if n_classes == 2:
        y_true = results[0]['y_true']
        if not np.array_equal(y_true, results[1]['y_true']):
            raise ValueError("Both learners must be evaluated on the same validation labels "
                             "(same samples, same order) for the paired DeLong test")

        plot_binary_roc_combined(results, download, download_path1, dpi)
        plot_binary_pr_combined(results, download, download_path2, dpi)

        y_pred_1 = results[0]['y_pred_proba'][:, 1]
        y_pred_2 = results[1]['y_pred_proba'][:, 1]

        z, p = delong_test(y_true, y_pred_1, y_pred_2)
        print("DeLong test results:")
        print(f"p-value: {p:.4f}")
        
def plot_binary_roc_combined(results, download, download_path, dpi, n_bootstraps=1000, confidence_level=0.95,
                             random_state=42):
    """
    Overlay the ROC curves of two models with bootstrapped confidence bands and a DeLong p-value.

    - results: list of two dicts with 'y_true' (labels) and 'y_pred_proba'
      (n_samples x 2 probabilities), ordered [Image+Tabular, Image Only]
    - download, download_path, dpi: save the figure to download_path when download is True
    - n_bootstraps: number of bootstrap resamples per model
    - confidence_level: normal-approximation band width (mean TPR +/- z * SD)
    - random_state: seed for the bootstrap so the bands are reproducible

    Resamples containing only one class are skipped because their ROC curve is undefined.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    colors = ['royalblue', 'darkorange']
    labels = ['Image+Tabular', 'Image Only']

    z_crit = stats.norm.ppf((1 + confidence_level) / 2)
    rng = np.random.RandomState(random_state)
    mean_fpr = np.linspace(0, 1, 100)

    for result, color, label in zip(results, colors, labels):
        y_true = np.asarray(result['y_true'])
        y_pred_proba = result['y_pred_proba'][:, 1]
        n_samples = len(y_true)

        # Observed ROC curve
        fpr, tpr, _ = metrics.roc_curve(y_true, y_pred_proba)
        roc_auc = metrics.auc(fpr, tpr)

        # Bootstrap TPR curves on a common FPR grid
        tprs = []
        for _ in range(n_bootstraps):
            indices = rng.randint(0, n_samples, n_samples)
            y_true_boot = y_true[indices]
            if len(np.unique(y_true_boot)) < 2:
                continue  # ROC undefined for a single-class resample
            fpr_boot, tpr_boot, _ = metrics.roc_curve(y_true_boot, y_pred_proba[indices])
            tprs.append(np.interp(mean_fpr, fpr_boot, tpr_boot))

        tprs = np.array(tprs)
        mean_tpr = np.mean(tprs, axis=0)
        std_tpr = np.std(tprs, axis=0)

        tprs_upper = np.minimum(mean_tpr + z_crit * std_tpr, 1)
        tprs_lower = np.maximum(mean_tpr - z_crit * std_tpr, 0)

        ax.plot(fpr, tpr, color=color, lw=2, label=f'{label} (AUC = {roc_auc:.2f})')
        ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color=color, alpha=0.3,
                        label=f'{confidence_level*100:.0f}% CI')

    ax.plot([0, 1], [0, 1], color='gray', linestyle='--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.legend(loc="lower right")
    ax.set_xlabel("False Positive Rate", fontsize=12)
    ax.set_ylabel("True Positive Rate", fontsize=12)

    # Paired DeLong test; assumes both results share the same labels in the same order.
    y_true = results[0]['y_true']
    y_pred_1 = results[0]['y_pred_proba'][:, 1]
    y_pred_2 = results[1]['y_pred_proba'][:, 1]
    _, p = delong_test(y_true, y_pred_1, y_pred_2)
    ax.text(0.75, 0.25, f"DeLong test:\np = {p:.4f}",
            transform=ax.transAxes, fontsize=10)

    if download:
        fig.savefig(download_path, dpi=dpi)
    plt.show()
    
def plot_binary_pr_combined(results, download, download_path, dpi, n_bootstraps=1000, confidence_level=0.95,
                            random_state=42):
    """
    Overlay the precision-recall curves of two models with bootstrapped confidence bands.

    - results: list of two dicts with 'y_true' (labels) and 'y_pred_proba'
      (n_samples x 2 probabilities), ordered [Image+Tabular, Image Only]
    - download, download_path, dpi: save the figure to download_path when download is True
    - n_bootstraps: number of bootstrap resamples per model
    - confidence_level: normal-approximation band width (mean precision +/- z * SD)
    - random_state: seed for the bootstrap so the bands are reproducible

    Resamples with no positive sample are skipped because precision/recall are undefined.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    colors = ['chocolate', 'forestgreen']
    labels = ['Image+Tabular', 'Image Only']

    z_crit = stats.norm.ppf((1 + confidence_level) / 2)
    rng = np.random.RandomState(random_state)
    mean_recall = np.linspace(0, 1, 100)

    for result, color, label in zip(results, colors, labels):
        y_true = np.asarray(result['y_true'])
        y_pred_proba = result['y_pred_proba'][:, 1]
        n_samples = len(y_true)

        # Observed PR curve
        precision, recall, _ = precision_recall_curve(y_true, y_pred_proba)
        average_precision = average_precision_score(y_true, y_pred_proba)

        # Bootstrap precision curves on a common recall grid
        precisions = []
        for _ in range(n_bootstraps):
            indices = rng.randint(0, n_samples, n_samples)
            y_true_boot = y_true[indices]
            if not np.any(y_true_boot == 1):
                continue  # no positives -> recall undefined
            precision_boot, recall_boot, _ = precision_recall_curve(y_true_boot, y_pred_proba[indices])
            # recall comes back in decreasing order; np.interp needs increasing x
            precisions.append(np.interp(mean_recall, recall_boot[::-1], precision_boot[::-1]))

        precisions = np.array(precisions)
        mean_precision = np.mean(precisions, axis=0)
        std_precision = np.std(precisions, axis=0)

        precisions_upper = np.minimum(mean_precision + z_crit * std_precision, 1)
        precisions_lower = np.maximum(mean_precision - z_crit * std_precision, 0)

        ax.plot(recall, precision, color=color, lw=2, label=f'{label} (AP = {average_precision:.2f})')
        ax.fill_between(mean_recall, precisions_lower, precisions_upper, color=color, alpha=0.3,
                        label=f'{confidence_level*100:.0f}% CI')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.legend(loc="lower left")
    ax.set_xlabel("Recall", fontsize=12)
    ax.set_ylabel("Precision", fontsize=12)

    if download:
        fig.savefig(download_path, dpi=dpi)
    plt.show()

import numpy as np
import torch
from sklearn.inspection import permutation_importance
from sklearn.metrics import log_loss

class CustomModelWrapper:
    """
    Minimal sklearn-style estimator around a two-input torch model, model(images, tabular).

    It lets sklearn.inspection.permutation_importance permute only the tabular
    features. The stored `images` are held fixed and paired row-by-row with the
    rows of X_tab in every predict_proba call.
    """

    def __init__(self, model, images, batch_size=None):
        """
        - model: torch module taking (images, tabular) and returning per-class scores
        - images: tensor whose first dimension matches the number of rows in X_tab
        - batch_size: rows per forward pass; None runs everything in one pass
        """
        self.model = model
        self.images = images
        self.batch_size = batch_size

    def fit(self, X, y):
        """No-op: the wrapped model is already trained."""
        return self

    def predict_proba(self, X_tab):
        """
        Return class probabilities (softmax over the model outputs) as a numpy array.

        The rest of this file applies softmax to the model's outputs to obtain
        probabilities, so the raw outputs are treated as logits here. Log-loss
        scoring requires probabilities.
        """
        device = next(self.model.parameters()).device
        X_tab = torch.as_tensor(X_tab, dtype=torch.float32, device=device)
        n_rows = X_tab.shape[0]
        step = max(n_rows, 1) if self.batch_size is None else self.batch_size

        chunks = []
        with torch.no_grad():
            for start in range(0, n_rows, step):
                stop = start + step
                images = self.images[start:stop].to(device)
                logits = self.model(images, X_tab[start:stop])
                chunks.append(torch.softmax(logits, dim=1).cpu())
        return torch.cat(chunks, dim=0).numpy()

def custom_scorer(estimator, X, y):
    """
    Scoring callable for permutation_importance: negative log-loss (higher is better).

    `labels` is passed explicitly so that log_loss still works when `y` contains
    only some of the classes (e.g. a small validation set). Class labels are
    assumed to be the column indices 0..K-1 of predict_proba's output.
    """
    y_pred = estimator.predict_proba(X)
    return -log_loss(y, y_pred, labels=np.arange(y_pred.shape[1]))

def get_tabular_importance(model, dataloader, n_repeats=10):
    """
    Permutation importance of each tabular feature, with the images held fixed.

    Every batch of `dataloader` is gathered (keys 'tabular', 'label', 'image'). The
    model is wrapped in CustomModelWrapper and scored with custom_scorer. Each
    feature is shuffled `n_repeats` times (random_state=42).
    Returns `importances_mean` from sklearn's permutation_importance (one value per
    tabular column).
    """
    model.eval()

    collected = {'tabular': [], 'label': [], 'image': []}
    with torch.no_grad():
        for batch in dataloader:
            for key, bucket in collected.items():
                bucket.append(batch[key])

    tabular = torch.cat(collected['tabular'], dim=0)
    labels = torch.cat(collected['label'], dim=0)
    images = torch.cat(collected['image'], dim=0)

    result = permutation_importance(
        CustomModelWrapper(model, images),
        tabular.cpu().numpy(),
        labels.cpu().numpy(),
        n_repeats=n_repeats,
        random_state=42,
        scoring=custom_scorer,
    )
    return result.importances_mean

def plot_feature_importances(importances, feature_names=None):
    """
    Visualize feature importances as a horizontal bar chart (no confidence intervals).

    Bars are sorted ascending, so the most important feature is drawn at the top.

    Parameters:
    - importances: array-like of importance scores, one per feature
    - feature_names: list of feature names in the same order as `importances`
      (optional; defaults to "Feature <original index>")
    """
    importances = np.asarray(importances)

    # Build default names from the ORIGINAL indices before sorting, so each bar
    # keeps the label of the feature it belongs to.
    if feature_names is None:
        feature_names = [f"Feature {i}" for i in range(len(importances))]

    sorted_idx = importances.argsort()
    importances = importances[sorted_idx]
    feature_names = [feature_names[i] for i in sorted_idx]

    y_pos = np.arange(len(feature_names))

    fig, ax = plt.subplots(figsize=(12.5, 8))

    ax.barh(y_pos, importances, align='center', alpha=0.8)

    ax.set_yticks(y_pos)
    ax.set_yticklabels(feature_names)
    ax.set_xlabel('Feature Importance')
    ax.set_title('Feature Importances')

    # Annotate each bar with its value
    for i, v in enumerate(importances):
        ax.text(v, i, f' {v:.3f}', va='center')

    plt.tight_layout()
    plt.show()

def Wrong_instances(learn, c, use_tabular=True, display=True):
    """Get wrongly predicted validation instances incl. Accessionnumber (has to be defined in subject).

    - learn: learner whose dls.valid batches provide 'image', 'label', 'nihss', 'acc'
      (and 'tabular' when use_tabular=True)
    - c: channels_in (not used here)
    - use_tabular: call the model as model(images, tabular) instead of model(images)
    - display: print one line per misclassified instance

    Returns (predictions, nihss): the predicted class and true NIHSS of every
    misclassified instance, in validation order.
    """
    predictions = []
    nihss = []

    # Cast/move/eval once rather than on every batch.
    model = learn.model.float().cuda().eval()
    learn.model = model

    with torch.no_grad():
        for num, batch in enumerate(progress_bar(learn.dls.valid, leave=False)):
            images = batch['image'].float().cuda()
            yb = batch['label'].cuda()
            y_nihss = batch['nihss'].float()
            acc_num = batch['acc']

            if use_tabular:
                preds = model(images, batch['tabular'].float().cuda())
            else:
                preds = model(images)

            probs = F.softmax(preds, dim=1)
            pred_classes = preds.argmax(dim=1)

            for i in range(len(yb)):
                pred_class = int(pred_classes[i])  # plain int, not a tensor repr, in the printout
                if pred_class == int(yb[i]):
                    continue
                pred_prob = probs[i, pred_class].item() * 100  # percentage

                if display:
                    print(f'Batch: {num}, Prediction: {pred_class} (Probability: {pred_prob:.2f}%), '
                          f'True NIHSS: {int(y_nihss[i])}, Accessionnumber: {int(acc_num[i])}')

                predictions.append(pred_class)
                nihss.append(int(y_nihss[i]))

    return predictions, nihss

def Correct_instances(learn, c, use_tabular=True, display=True):
    """Get correctly predicted validation instances incl. Accessionnumber (has to be defined in subject).

    - learn: learner whose dls.valid batches provide 'image', 'label', 'nihss', 'acc'
      (and 'tabular' when use_tabular=True)
    - c: channels_in (not used here)
    - use_tabular: call the model as model(images, tabular) instead of model(images)
    - display: print one line per correctly classified instance

    Returns (predictions, nihss): the predicted class and true NIHSS of every
    correctly classified instance, in validation order.
    """
    predictions = []
    nihss = []

    # Cast/move/eval once rather than on every batch.
    model = learn.model.float().cuda().eval()
    learn.model = model

    with torch.no_grad():
        for num, batch in enumerate(progress_bar(learn.dls.valid, leave=False)):
            images = batch['image'].float().cuda()
            yb = batch['label'].cuda()
            y_nihss = batch['nihss'].float()
            acc_num = batch['acc']

            if use_tabular:
                preds = model(images, batch['tabular'].float().cuda())
            else:
                preds = model(images)

            probs = F.softmax(preds, dim=1)
            pred_classes = preds.argmax(dim=1)

            for i in range(len(yb)):
                pred_class = int(pred_classes[i])  # plain int, not a tensor repr, in the printout
                if pred_class != int(yb[i]):
                    continue
                pred_prob = probs[i, pred_class].item() * 100  # percentage

                if display:
                    print(f'Batch: {num}, Prediction: {pred_class} (Probability: {pred_prob:.2f}%), '
                          f'True NIHSS: {int(y_nihss[i])}, Accessionnumber: {int(acc_num[i])}')

                predictions.append(pred_class)
                nihss.append(int(y_nihss[i]))

    return predictions, nihss

class GradCAM:
    """
    Grad-CAM for 3-D CNNs.

    Each channel of `target_layer`'s activations is weighted by the mean gradient
    of the target class score. The result is averaged over channels, rectified and
    scaled to [0, 1].

    Hooks are attached to `target_layer` on construction. Call `remove()` to detach
    them; otherwise they stay registered on the model.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Keep the handles so the hooks can be removed again.
        self._hook_handles = [
            self.target_layer.register_forward_hook(self.save_activation),
            # NOTE(review): register_backward_hook is deprecated in recent PyTorch.
            # register_full_backward_hook is documented to raise if a later layer
            # modifies the output in place (e.g. ReLU(inplace=True)), so the
            # legacy hook is kept here. Revisit for the target model.
            self.target_layer.register_backward_hook(self.save_gradient),
        ]

    def remove(self):
        """Detach the forward/backward hooks from target_layer."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def save_activation(self, module, input, output):
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def generate_cam(self, input_image, tabular_data=None, target_class=1):
        """
        Compute a Grad-CAM heatmap for `target_class`, resized to the input's spatial size.

        - input_image: 5-D [batch, channels, depth, height, width] tensor, or 4-D
          without the batch dimension (one is added). The map of the first sample
          is returned.
        - tabular_data: optional second model input (cast to float32)
        - target_class: index of the model output to explain
        Returns a numpy array [depth, height, width] scaled to [0, 1]. It is all
        zeros when no region contributes positively. Returns None if the hooks
        captured nothing.
        """
        from scipy.ndimage import zoom

        self.model.eval()
        # Reset so values from a previous call can never be reused by mistake.
        self.gradients = None
        self.activations = None

        # Ensure input_image is 5D: [batch, channels, depth, height, width]
        if input_image.dim() == 4:
            input_image = input_image.unsqueeze(0)

        # A backward pass is required even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            if tabular_data is not None:
                tabular_data = tabular_data.to(dtype=torch.float32)
                output = self.model(input_image, tabular_data)
            else:
                output = self.model(input_image)

            self.model.zero_grad()
            output[:, target_class].sum().backward()

        if self.gradients is None or self.activations is None:
            print("Warning: gradients or activations are None")
            return None

        # One weight per channel: mean gradient over batch and spatial axes.
        pooled_gradients = torch.mean(self.gradients, dim=[0, 2, 3, 4])
        weighted = self.activations * pooled_gradients.view(1, -1, 1, 1, 1)

        # Index the first sample instead of squeeze(), which would also drop size-1 spatial dims.
        heatmap = torch.mean(weighted, dim=1)[0].cpu().numpy()
        heatmap = np.maximum(heatmap, 0)
        peak = heatmap.max()
        if peak > 0:  # avoid 0/0 -> NaN when nothing contributes positively
            heatmap = heatmap / peak

        # Resize heatmap to match input image spatial dimensions
        zoom_factors = tuple(input_image.shape[2 + axis] / heatmap.shape[axis] for axis in range(3))
        return zoom(heatmap, zoom_factors)

    def plot_cam(self, input_image, tabular_data=None, download=False, sl=-1, target_class=1,
                 alpha=0.5,
                 norm=None,
                 download_path='/media/user/Elements/GradCAM.tiff',
                 dpi=300):
        """
        Show one slice of the first input channel next to the same slice with the
        Grad-CAM heatmap overlaid.

        - sl: slice index along the last spatial axis; -1 selects the middle slice
        - alpha: opacity of the heatmap overlay
        - norm: optional matplotlib normalization for the image
        - download: if True, save the figure to download_path at `dpi`
        Nothing is plotted if generate_cam returns None.
        """
        import matplotlib.pyplot as plt

        heatmap = self.generate_cam(input_image, tabular_data, target_class=target_class)
        if heatmap is None:
            return

        if sl == -1:
            sl = heatmap.shape[2] // 2

        image = input_image.detach().cpu().numpy()
        if image.ndim == 5:
            image = image[0]  # drop batch dim; generate_cam explains the first sample
        background = image[0, :, :, sl]  # first channel

        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(background, cmap='viridis', norm=norm)
        plt.title('Original Image')
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(background, cmap='viridis', norm=norm, alpha=1)
        plt.imshow(heatmap[:, :, sl], cmap='jet', alpha=alpha)
        plt.title('Grad-CAM')
        plt.axis('off')

        if download:
            plt.savefig(download_path, dpi=dpi)

        plt.show()