In [None]:
import sys
import subprocess


In [None]:
from google.colab import drive

# Mount Google Drive; later cells work relative to MyDrive under this mount point.
GDRIVE_MOUNT_POINT = "/content/gdrive"
drive.mount(GDRIVE_MOUNT_POINT)

In [None]:
%cd /content/gdrive/MyDrive/
# `glob` ships with the Python standard library, so pip cannot install it; it is
# deliberately not installed here.
# `%pip` installs into the environment of the running kernel, unlike `!pip`.
%pip install nibabel
%pip install quantus
%pip install monai


In [None]:
#Author: Michail Mamalakis
#Version: 0.1
#Licence:
#email:mm2703@cam.ac.uk

# name of weights store of main segmentation
import nibabel as nib
from scipy import ndimage
import numpy as np
import glob
import skimage.transform as skTrans
from scipy.ndimage import rotate
import quantus

# Target geometry used by resize_volume: desired depth, width, height and
# channel count (the zoom-off path resizes to (d, w, h, c)).
d, w, h, c = 184, 184, 64, 1
# Container read by read_nifti_file: 'nii' loads NIfTI; any other value takes
# the GIFTI branch.
format_file = 'nii'


def read_nifti_file(filepath, fmt=None):
  """Load a NIfTI (or GIFTI) image file and return its data as a float array.

  Args:
    filepath: path of the image file.
    fmt: container format; defaults to the module-level ``format_file``.
      ``'nii'`` reads a NIfTI image through ``get_fdata()``; any other value
      reads a GIFTI image, whose data arrays are gathered with ``agg_data()``.

  Returns:
    numpy array of float64 values.
  """
  if fmt is None:
    fmt = format_file
  # nib.load guesses the image type from the file, so the same call opens both
  # NIfTI and GIFTI files (the path itself is not callable).
  scan = nib.load(filepath)
  if fmt == 'nii':
    return scan.get_fdata()
  # GiftiImage has no get_fdata(); agg_data() returns its data array(s) instead.
  return np.asarray(scan.agg_data(), dtype=np.float64)


def normalize(volume):
  """Min-max scale ``volume`` to the range [0, 1] and return it as float32.

  Args:
    volume: array-like of intensities, any shape. It is not modified.

  Returns:
    float32 ndarray with the same shape as ``volume``. A constant volume
    (max == min) maps to all zeros instead of the NaNs that a 0/0 division
    would produce.
  """
  volume = np.asarray(volume)
  v_min = np.min(volume)
  v_max = np.max(volume)
  value_range = v_max - v_min
  if value_range == 0:
    return np.zeros(volume.shape, dtype="float32")
  return ((volume - v_min) / value_range).astype("float32")


def resize_volume(img,zoom='off'):
  """Reorient a 4-D volume and resize it to the module-level target size.

  Steps, as implemented:
    1. If ``img.shape[1] != img.shape[2]`` the axes are permuted with a 4-axis
       permutation: ``(1, 2, 0, 3)`` when shape[1] < shape[2], otherwise
       ``(2, 0, 1, 3)``.
    2. zoom == 'off' (default): ``skimage.transform.resize`` to ``(d, w, h, c)``
       (module globals) with linear interpolation (order=1), keeping the input
       value range.
    3. zoom != 'off': rotate by 90 degrees with ``ndimage.rotate``
       (``reshape=False``, so the shape is unchanged), then ``ndimage.zoom``
       (order=1) with factors computed from ``d``, ``w`` and ``h``.

  Args:
    img: image array. The permutations in step 1 use four axes, so a 3-D input
      only works when ``img.shape[1] == img.shape[2]``.
    zoom: 'off' selects step 2; any other value selects step 3.

  Returns:
    The resized array.

  NOTE(review): not called anywhere visible in this notebook (the call in
  process_scan is commented out). In step 3 the zoom factors look swapped
  relative to their names: axis 1 is scaled by ``h / shape[2]`` and axis 2 by
  ``w / shape[1]``, so a square in-plane input ends up with sizes (d, h, w) on
  axes 0-2. Confirm this is intended.
  """
  # Debug output of the incoming shape.
  print(img.shape)
  if img.shape[1]<img.shape[2]:
    img=np.transpose(img,(1,2,0,3))
  elif img.shape[1]>img.shape[2]:
    img=np.transpose(img,(2,0,1,3))
  else:
    print('img is ok')
  if zoom!='off':
    desired_depth = d
    desired_width = w
    desired_height = h
    # Current size per axis.
    # NOTE(review): both branches below are identical.
    if len(img.shape)==4:
      current_depth = img.shape[0]
      current_width = img.shape[1]
      current_height = img.shape[2]
    else:
      current_depth = img.shape[0]
      current_width = img.shape[1]
      current_height = img.shape[2]
    # Zoom factor per axis = desired size / current size.
    depth = current_depth / desired_depth
    width = current_width / desired_width
    height = current_height / desired_height
    depth_factor = 1 / depth
    width_factor = 1 / width
    height_factor = 1 / height
    # Rotate 90 degrees in the plane of the first two axes (scipy's default
    # axes=(1, 0)); reshape=False keeps the array shape.
    img = ndimage.rotate(img, 90, reshape=False)
    # Resize; see the NOTE(review) in the docstring about the factor order.
    print(img.shape,width_factor,height_factor,depth_factor)
    if len(img.shape)==4:
      if img.shape[1]==img.shape[2]:
    #    img=np.transpose(img, (1, 2, 0, 3))
        print(img.shape)
      img = ndimage.zoom(img, (depth_factor, height_factor, width_factor,1), order=1)
    else:
      # NOTE(review): a 3-D input only reaches this point when
      # shape[1] == shape[2] (step 1 would have raised otherwise), so this
      # condition is always true here.
      if img.shape[1]==img.shape[2]:
    #   img=np.transpose(img, (1, 2, 0))
        img = ndimage.zoom(img, (depth_factor, height_factor, width_factor), order=1)
  else:
    # Default path: resample straight to the module-level (d, w, h, c).
    img=skTrans.resize(img,(d,w,h,c),order=1,preserve_range=True)
  return img

def save(store,volume):
  """Write ``volume`` to the path ``store`` as a NIfTI-1 image.

  The image gets an identity affine, and its header declares uint32 as the
  on-disk data type. nibabel decides how a floating-point array is converted
  on write, possibly via the header's slope/intercept scaling.
  """
  identity_affine = np.eye(4)
  nifti = nib.Nifti1Image(volume, affine=identity_affine)
  header = nifti.header
  header.set_data_dtype(np.uint32)
  nib.save(nifti, store)

def process_scan(path,store,case_ex,comb=6,zoom='off',list_given='none',d=184,w=184,h=184,c=1):
  """Combine ``comb`` explanation volumes into a weighted mean and save it.

  The files matching ``path + '*' + case_ex`` are sorted by name. The first
  ``comb`` of them are read with ``read_nifti_file``, min-max scaled with
  ``normalize``, multiplied by their weight and summed. The sum is divided by
  the total of the weights actually applied, so the result is a true weighted
  mean. It is written with ``save`` to ``store + 'total_' + case_ex``.

  Weights:
    * ``list_given`` whenever it is supplied (anything other than 'none'/None);
    * otherwise, for ``comb == 2``: ``[0.85, 0.5, 0.1]`` (only the first two
      are used);
    * otherwise: ``[(comb-1)/comb, (comb-2)/comb, ..., 1/comb, 0.1]``.

  Args:
    path: glob prefix of the input files (directory + file-name prefix).
    store: prefix of the output file path.
    case_ex: file-name suffix used to select inputs, e.g. '.nii'.
    comb: number of volumes to combine (>= 1).
    zoom: currently unused (the resize step is disabled); kept for API
      compatibility.
    list_given: explicit per-volume weights, or 'none'/None for the defaults.
    d, w, h, c: shape of the zero-initialised accumulator. Each volume is
      broadcast against it, so the saved array has the broadcast shape.

  Raises:
    ValueError: if comb < 1, fewer than ``comb`` files or weights are
      available, or the applied weights sum to zero.
  """
  if comb < 1:
    raise ValueError("comb must be >= 1, got %r" % (comb,))
  files = sorted(glob.glob(path + '*' + case_ex))

  # isinstance guard: comparing a numpy array to 'none' would be elementwise.
  use_default = list_given is None or (isinstance(list_given, str) and list_given == 'none')
  if not use_default:
    weights = list(list_given)
  elif comb == 2:
    weights = [0.85, 0.5, 0.1]
  else:
    weights = [(comb - x) / comb for x in range(1, comb)] + [0.1]
  print(files, weights)

  if len(files) < comb:
    raise ValueError("expected at least %d files matching %r, found %d"
                     % (comb, path + '*' + case_ex, len(files)))
  if len(weights) < comb:
    raise ValueError("expected at least %d weights, got %d" % (comb, len(weights)))

  applied = weights[:comb]
  # Normalise by the weights actually applied (with comb == 2 and the default
  # list, the third weight is never used and must not enter the divisor).
  total_weight = float(sum(applied))
  if total_weight == 0:
    raise ValueError("the weights applied to the volumes sum to zero")

  volume_tot = np.zeros((d, w, h, c))
  for file_path, weight in zip(files, applied):
    volume = normalize(read_nifti_file(file_path))
    volume_tot = volume_tot + weight * volume
  save(store + 'total_' + case_ex, volume_tot / total_weight)



In [None]:
#case_ex='GradCam.nii.gz'
#store='XAI_MHL/Shap/R/R_PCS_1'
#path='XAI_MHL/Shap/R/PCS/'
#comb=6
#zoom='off'
#process_scan(path,store,case_ex,comb,zoom)


In [None]:
# `%pip` installs into the environment of the running kernel, unlike `!pip`.
%pip install torch

In [None]:
from functools import partial
from typing import Any, Union, Optional
from typing import Tuple, Dict
from torch import Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import sys
import subprocess
from monai.networks.layers.factories import Conv, Norm, Pool
from monai.networks.layers.utils import get_pool_layer
from monai.utils import ensure_tuple_rep
from monai.utils.module import look_up_option

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def get_inplanes():
    """Return the base channel widths of the four ResNet stages (a fresh list)."""
    return [2 ** power for power in range(6, 10)]


def get_avgpool():
    """Return adaptive-avg-pool output sizes indexed by ``spatial_dims`` (1, 2 or 3)."""
    return [0, 1] + [(1,) * dims for dims in (2, 3)]


class ResNetBlock(nn.Module):
    """Basic residual block with two 3x3 convolutions (used by ``resnet18``).

    conv(stride) -> BN -> ReLU -> conv -> BN, then the shortcut (the input, or
    ``downsample(x)`` when given) is added and a final ReLU is applied. The
    block outputs ``planes * expansion`` channels with ``expansion == 1``.
    """

    expansion = 1

    def __init__(
        self,
        in_planes: int,
        planes: int,
        spatial_dims: int = 3,
        stride: int = 1,
        inplace: bool = True,
        downsample: Union[nn.Module,partial] = None,
    ) -> None:
        """
        Args:
            in_planes: channels of the block input.
            planes: channels produced by both convolutions.
            spatial_dims: number of spatial dimensions (selects the 1d/2d/3d layers).
            stride: stride of the first convolution.
            inplace: forwarded to both ReLU modules.
            downsample: module or partial applied to the input to form the shortcut.
        """
        super().__init__()
        make_conv = Conv[Conv.CONV, spatial_dims]
        make_norm = Norm[Norm.BATCH, spatial_dims]

        # Attribute names and assignment order define the state_dict keys and
        # the child-module order, so they must stay as they are.
        self.conv1 = make_conv(in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=inplace)
        self.conv2 = make_conv(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride
        self.relu2 = nn.ReLU(inplace=inplace)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both conv/BN stages, add the shortcut and apply the final ReLU."""
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.bn2(self.conv2(features))
        shortcut = x if self.downsample is None else self.downsample(x)
        features += shortcut
        return self.relu2(features)

class ResNetBottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand.

    The block outputs ``planes * expansion`` (= 4 * planes) channels. The
    shortcut is the input itself, or ``downsample(x)`` when one is given.
    """

    expansion = 4

    def __init__(
        self,
        in_planes: int,
        planes: int,
        spatial_dims: int = 3,
        stride: int = 1,
        inplace:bool=True,
        downsample: Union[nn.Module,partial] = None,
    ) -> None:
        """
        Args:
            in_planes: channels of the block input.
            planes: bottleneck width; the block outputs ``planes * 4`` channels.
            spatial_dims: number of spatial dimensions (selects the 1d/2d/3d layers).
            stride: stride of the middle 3x3 convolution.
            inplace: forwarded to the three ReLU modules.
            downsample: module or partial applied to the input to form the shortcut.
        """
        super().__init__()
        make_conv = Conv[Conv.CONV, spatial_dims]
        make_norm = Norm[Norm.BATCH, spatial_dims]
        expanded = planes * self.expansion

        # Attribute names and assignment order define the state_dict keys and
        # the child-module order, so they must stay as they are.
        self.conv1 = make_conv(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = make_norm(planes)
        self.conv2 = make_conv(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = make_norm(planes)
        self.conv3 = make_conv(planes, expanded, kernel_size=1, bias=False)
        self.bn3 = make_norm(expanded)
        self.relu = nn.ReLU(inplace=inplace)
        self.downsample = downsample
        self.stride = stride
        self.relu2 = nn.ReLU(inplace=inplace)
        self.relu3 = nn.ReLU(inplace=inplace)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three conv/BN stages, add the shortcut, then a final ReLU."""
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.relu3(self.bn2(self.conv2(features)))
        features = self.bn3(self.conv3(features))
        shortcut = x if self.downsample is None else self.downsample(x)
        features += shortcut
        return self.relu2(features)

class ResNet(nn.Module):
    """ResNet backbone for 1D/2D/3D inputs, adapted from MONAI's ``ResNet``.

    Pipeline: conv1 -> BN -> ReLU -> (optional max-pool) -> layer1..layer4 ->
    adaptive average pool -> flatten -> optional fully connected layer.

    Args:
        block: block class, or the string ``"basic"`` / ``"bottleneck"``.
        layers: number of blocks in each of the four stages.
        block_inplanes: base channel width of each stage (scaled by ``widen_factor``).
        spatial_dims: 1, 2 or 3.
        n_input_channels: channels of the input tensor.
        conv1_t_size: kernel size of the stem convolution.
        conv1_t_stride: stride of the stem convolution.
        no_max_pool: skip the stem max-pool when True.
        shortcut_type: "A" (strided identity + zero-padding) or "B"
            (1x1 conv + BN) for shortcuts that change shape.
        widen_factor: multiplier applied to ``block_inplanes``.
        num_classes: outputs of the final linear layer.
        feed_forward: if False, no final linear layer is created and ``forward``
            returns the flattened pooled features.
        inplace: forwarded to every ``nn.ReLU`` in the network.
        bias_downsample: bias flag of the 1x1 conv in type-"B" shortcuts.
        legacy_extra_block: if True, build ``layers[i] + 1`` blocks per stage (the
            former ``range(0, blocks)`` layout), which is only needed to load
            checkpoints saved with that layout. The default False builds exactly
            ``layers[i]`` blocks, as MONAI does, so ``[2, 2, 2, 2]`` is a real ResNet-18.
    """

    def __init__(
        self,
        block: Union[type[Union[ResNetBlock, ResNetBottleneck]], str],
        layers: list[int],
        block_inplanes: list[int],
        spatial_dims: int = 3,
        n_input_channels: int = 3,
        conv1_t_size: Union[tuple[int], int] = 7,
        conv1_t_stride: Union[tuple[int], int] = 1,
        no_max_pool: bool = False,
        shortcut_type: str = "B",
        widen_factor: float = 1.0,
        num_classes: int = 400,
        feed_forward: bool = True,
        inplace: bool = True,
        bias_downsample: bool = True,  # for backwards compatibility (also see PR #5477)
        legacy_extra_block: bool = False,
    ) -> None:
        super().__init__()

        if isinstance(block, str):
            if block == "basic":
                block = ResNetBlock
            elif block == "bottleneck":
                block = ResNetBottleneck
            else:
                raise ValueError("Unknown block '%s', use basic or bottleneck" % block)

        conv_type: type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims]
        norm_type: type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims]
        pool_type: type[Union[nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]] = Pool[Pool.MAX, spatial_dims]
        avgp_type: type[Union[nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d]] = Pool[
            Pool.ADAPTIVEAVG, spatial_dims
        ]

        block_avgpool = get_avgpool()
        block_inplanes = [int(x * widen_factor) for x in block_inplanes]
        self.inplace = inplace
        self.in_planes = block_inplanes[0]
        self.no_max_pool = no_max_pool
        self.bias_downsample = bias_downsample
        # Must be set before _make_layer runs below.
        self.legacy_extra_block = legacy_extra_block

        conv1_kernel_size = ensure_tuple_rep(conv1_t_size, spatial_dims)
        conv1_stride = ensure_tuple_rep(conv1_t_stride, spatial_dims)

        self.conv1 = conv_type(
            n_input_channels,
            self.in_planes,
            kernel_size=conv1_kernel_size,  # type: ignore
            stride=conv1_stride,  # type: ignore
            padding=tuple(k // 2 for k in conv1_kernel_size),  # type: ignore
            bias=False,
        )
        self.bn1 = norm_type(self.in_planes)
        self.relu = nn.ReLU(inplace=inplace)
        self.maxpool = pool_type(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, block_inplanes[0], layers[0], spatial_dims, shortcut_type)
        self.layer2 = self._make_layer(block, block_inplanes[1], layers[1], spatial_dims, shortcut_type, stride=2)
        self.layer3 = self._make_layer(block, block_inplanes[2], layers[2], spatial_dims, shortcut_type, stride=2)
        self.layer4 = self._make_layer(block, block_inplanes[3], layers[3], spatial_dims, shortcut_type, stride=2)
        self.avgpool = avgp_type(block_avgpool[spatial_dims])
        self.fc = nn.Linear(block_inplanes[3] * block.expansion, num_classes) if feed_forward else None
        self.relu2 = nn.ReLU(inplace=inplace)

        for m in self.modules():
            if isinstance(m, conv_type):
                nn.init.kaiming_normal_(torch.as_tensor(m.weight), mode="fan_out", nonlinearity="relu")
            elif isinstance(m, norm_type):
                nn.init.constant_(torch.as_tensor(m.weight), 1)
                nn.init.constant_(torch.as_tensor(m.bias), 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(torch.as_tensor(m.bias), 0)

    def _downsample_basic_block(self, x: torch.Tensor, planes: int, stride: int, spatial_dims: int = 3) -> torch.Tensor:
        """Type-"A" shortcut: strided identity (1-kernel avg-pool), zero-padded to ``planes`` channels."""
        out: torch.Tensor = get_pool_layer(("avg", {"kernel_size": 1, "stride": stride}), spatial_dims=spatial_dims)(x)
        zero_pads = torch.zeros(out.size(0), planes - out.size(1), *out.shape[2:], dtype=out.dtype, device=out.device)
        # Concatenate ``out`` itself, not ``out.data``: ``.data`` is detached from
        # autograd and would stop gradients from flowing through the shortcut.
        return torch.cat([out, zero_pads], dim=1)

    def _make_layer(
        self,
        block: type[Union[ResNetBlock, ResNetBottleneck]],
        planes: int,
        blocks: int,
        spatial_dims: int,
        shortcut_type: str,
        stride: int = 1,
    ) -> nn.Sequential:
        """Build one stage: a (possibly strided) first block plus the remaining blocks."""
        conv_type: type[nn.Module] = Conv[Conv.CONV, spatial_dims]
        norm_type: type[nn.Module] = Norm[Norm.BATCH, spatial_dims]

        downsample: Optional[Union[nn.Module, partial]] = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            if look_up_option(shortcut_type, {"A", "B"}) == "A":
                downsample = partial(
                    self._downsample_basic_block,
                    planes=planes * block.expansion,
                    stride=stride,
                    spatial_dims=spatial_dims,
                )
            else:
                downsample = nn.Sequential(
                    conv_type(
                        self.in_planes,
                        planes * block.expansion,
                        kernel_size=1,
                        stride=stride,
                        bias=self.bias_downsample,
                    ),
                    norm_type(planes * block.expansion),
                )
        layers = [
            block(
                in_planes=self.in_planes,
                planes=planes,
                spatial_dims=spatial_dims,
                stride=stride,
                inplace=self.inplace,
                downsample=downsample,
            )
        ]
        self.in_planes = planes * block.expansion
        # The first block is already built, so add blocks - 1 more (as MONAI
        # does). The legacy layout added ``blocks`` more (one too many).
        remaining = blocks if self.legacy_extra_block else blocks - 1
        for _ in range(remaining):
            layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims, inplace=self.inplace))

        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the stem, the four stages and pooling, and the optional FC layer."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if not self.no_max_pool:
            x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)

        x = x.view(x.size(0), -1)
        if self.fc is not None:
            x = self.fc(x)

        return x
def _resnet(
    arch: str,
    block: type[Union[ResNetBlock , ResNetBottleneck]],
    layers: list[int],
    block_inplanes: list[int],
    pretrained: bool,
    progress: bool,
    inplace=True,
    **kwargs: Any,
) -> ResNet:
    """Construct a ``ResNet`` with the given block type and stage layout.

    Args:
        arch: architecture name; not used by this function.
        block: residual block class.
        layers: number of blocks per stage.
        block_inplanes: base channel width per stage.
        pretrained: must be False; pretrained weights are not supported.
        progress: unused here (nothing is downloaded).
        inplace: forwarded to every ReLU of the network.
        **kwargs: further ``ResNet`` constructor arguments.

    Raises:
        NotImplementedError: if ``pretrained`` is True. This is checked before
            the model is built, so no time is spent constructing a network
            that would be discarded.
    """
    if pretrained:
        # The paper's author published the state_dict as a zipped archive on
        # Google Drive (2.8 GB for a ~150 MB state dict), so it is not fetched here.
        raise NotImplementedError(
            "Currently not implemented. You need to manually download weights provided by the paper's author"
            " and load them into the model with `state_dict`. See https://github.com/Tencent/MedicalNet."
            " Please ensure you pass the appropriate `shortcut_type` and `bias_downsample` args, as specified"
            " here: https://github.com/Tencent/MedicalNet/tree/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b#update20190730"
        )
    model: ResNet = ResNet(block, layers, block_inplanes, inplace=inplace, **kwargs)
    return model
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    """Build a ResNet-18-style network from basic blocks with two blocks per stage.

    Reference for the pretrained Med3D weights (not loadable here):
    `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.

    Args:
        pretrained (bool): must be False; True raises NotImplementedError from ``_resnet``.
        progress (bool): passed through to ``_resnet``, which does not use it.
        **kwargs: further ``ResNet`` arguments (e.g. ``spatial_dims``,
            ``n_input_channels``, ``num_classes``, ``inplace``).
    """
    blocks_per_stage = [2, 2, 2, 2]
    stage_widths = get_inplanes()
    return _resnet("resnet18", ResNetBlock, blocks_per_stage, stage_widths, pretrained, progress, **kwargs)


In [None]:
# `%pip` installs into the environment of the running kernel, unlike `!pip`.
%pip install tensorflow

In [None]:
#Author: Michail Mamalakis
#Version: 0.1
#Licence:MIT
#email:mm2703@cam.ac.uk

#an extension including Resnet3DBuilder from https://github.com/JihongJu/keras-resnet3d
# pip install git+https://github.com/JihongJu/keras-resnet3d.git

from __future__ import division, print_function
import os
import zipfile
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import  MaxPooling1D, Lambda, MultiHeadAttention, Input, Conv2D, Concatenate, MaxPooling2D, AveragePooling2D, AveragePooling1D, Dense, Flatten, Reshape, Activation, Dropout, Dense
from tensorflow.keras.models import Model
# name of weights store of main segmentation
from tensorflow.keras.utils import to_categorical

class create_3Dnet:
    """Builder for the 3D Keras classification networks used in this notebook.

    ``model_builder`` builds the feature extractor selected by ``model``
    ('simple_3d', 'simple_MHL' or 'double_3d') and appends the dense
    classification head from ``MLP``.

    Constructor args:
        model: which network ``model_builder`` builds.
        height, width, depth: spatial size; inputs are shaped
            ``(width, height, depth, 1)``.
        channels: stored only; the input layers here are single-channel.
        classes: number of softmax outputs of the head.
        name: stored and forwarded to ``tune_MHL``, which does not use it.
        do: dropout rate.
        path: directory searched for backbone weights (``path + "/" + b_w``).
        backbone: backbone passed to ``tune_MHL`` for 'simple_MHL'.
        paral: 'on' enables the parallel attention branch of ``tune_MHL``.
        b_w: backbone weight file name.

    Layer creation order inside each builder is significant: Keras auto-names
    layers in creation order, and ``load_weights(..., by_name=True)`` matches on
    those names.
    """

    def __init__(self, model, height, width, depth, channels, classes, name="", do=0.3,
                 path="/home/mm2703/code/s3D/test/", backbone="simple_3d", paral='off',
                 b_w='simple_3d_gpu_L_skeleton_3d_image_classification.h5'):
        self.model = model
        self.height = height
        self.width = width
        self.depth = depth
        self.channels = channels
        self.classes = classes
        self.path = path
        self.name = name
        self.do = do
        self.backbone = backbone
        self.backb_w = b_w
        self.par = paral

    def model_builder(self):
        """Build the network selected by ``self.model`` with the MLP head attached.

        For 'simple_MHL', weights found at ``self.path + "/" + self.backb_w`` are
        loaded into the backbone by layer name, skipping mismatches.

        Raises:
            ValueError: if ``self.model`` is not a known architecture.
        """
        if self.model == "simple_3d":
            init_model = self.simple_3d()
        elif self.model == 'simple_MHL':
            init_model = self.tune_MHL(backbone=self.backbone, name=self.name,
                                       store_model=self.path, parallel=self.par)
            model_file = str(self.path + "/" + self.backb_w)
            if os.path.exists(model_file):
                print(model_file)
                init_model.load_weights(model_file, by_name=True, skip_mismatch=True)
        elif self.model == 'double_3d':
            init_model = self.double_3d()
        else:
            print("no model is given")
            raise ValueError("Unknown model %r; expected 'simple_3d', 'simple_MHL' or 'double_3d'"
                             % (self.model,))
        return self.MLP(init_model)

    def simple_3d(self, backbone_use='off'):
        """Build a 3D CNN: four Conv3D+MaxPool stages, global average pooling, Dense(512), dropout.

        The final softmax Dense has 1024 units, or 15376 (= 124 * 124) when
        ``backbone_use`` is not 'off'; ``tune_MHL`` uses the latter for its
        'simple_3d_tune' backbone.
        """
        inputs = keras.Input((self.width, self.height, self.depth, 1))

        x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(inputs)
        x = layers.MaxPool3D(pool_size=2)(x)

        x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(x)
        x = layers.MaxPool3D(pool_size=2)(x)

        x = layers.Conv3D(filters=128, kernel_size=3, activation="relu")(x)
        x = layers.MaxPool3D(pool_size=2)(x)

        x = layers.Conv3D(filters=256, kernel_size=3, activation="relu")(x)
        x = layers.MaxPool3D(pool_size=2)(x)

        x = layers.GlobalAveragePooling3D()(x)
        x = layers.Dense(units=512, activation="relu")(x)
        x = layers.Dropout(self.do)(x)
        if backbone_use == 'off':
            outputs = layers.Dense(units=1024, activation="softmax")(x)
        else:
            outputs = layers.Dense(units=15376, activation="softmax")(x)
        model = Model(inputs, outputs, name="3dcnn")
        return model

    def double_3d(self):
        """Build a two-input 3D CNN whose per-stage features are cross-attended.

        Each input has its own Conv3D/MaxPool tower. The flattened stage outputs
        of the two towers are combined with MultiHeadAttention at each of the
        four stages. The results are concatenated, max-pooled (pool 64) and fed
        to a Dense(4096) output.
        """
        inputs1 = keras.Input((self.width, self.height, self.depth, 1))
        inputs2 = keras.Input((self.width, self.height, self.depth, 1))

        x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(inputs1)
        x1 = layers.MaxPool3D(pool_size=2)(x)

        x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(x1)
        x2 = layers.MaxPool3D(pool_size=2)(x)

        x = layers.Conv3D(filters=128, kernel_size=3, activation="relu")(x2)
        x3 = layers.MaxPool3D(pool_size=2)(x)

        x = layers.Conv3D(filters=256, kernel_size=3, activation="relu")(x3)
        x4 = layers.MaxPool3D(pool_size=2)(x)

        # The pooled/dense head of each tower is built (and so auto-named), but
        # only the stage outputs x1..x4 / y1..y4 feed the returned model.
        x = layers.GlobalAveragePooling3D()(x4)
        x = layers.Dense(units=512, activation="relu")(x)
        x5 = layers.Dropout(self.do)(x)

        y = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(inputs2)
        y1 = layers.MaxPool3D(pool_size=2)(y)

        y = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(y1)
        y2 = layers.MaxPool3D(pool_size=2)(y)

        y = layers.Conv3D(filters=128, kernel_size=3, activation="relu")(y2)
        y3 = layers.MaxPool3D(pool_size=2)(y)

        y = layers.Conv3D(filters=256, kernel_size=3, activation="relu")(y3)
        y4 = layers.MaxPool3D(pool_size=2)(y)

        y = layers.GlobalAveragePooling3D()(y4)
        y = layers.Dense(units=512, activation="relu")(y)
        y5 = layers.Dropout(self.do)(y)

        Rx1 = layers.Flatten(name='flatten_tunedRx1')(x1)
        Ry1 = layers.Flatten(name='flatten_tunedRy1')(y1)
        R1 = layers.MultiHeadAttention(num_heads=2, key_dim=self.height, attention_axes=(1))(Rx1, Ry1)

        Rx2 = layers.Flatten(name='flatten_tunedRx2')(x2)
        Ry2 = layers.Flatten(name='flatten_tunedRy2')(y2)
        R2 = layers.MultiHeadAttention(num_heads=2, key_dim=self.height, attention_axes=(1))(Rx2, Ry2)

        Rx3 = layers.Flatten(name='flatten_tunedRx3')(x3)
        Ry3 = layers.Flatten(name='flatten_tunedRy3')(y3)
        R3 = layers.MultiHeadAttention(num_heads=2, key_dim=self.height, attention_axes=(1))(Rx3, Ry3)

        Rx4 = layers.Flatten(name='flatten_tunedRx4')(x4)
        Ry4 = layers.Flatten(name='flatten_tunedRy4')(y4)
        R4 = layers.MultiHeadAttention(num_heads=2, key_dim=self.height, attention_axes=(1))(Rx4, Ry4)

        R = layers.Concatenate()([R1, R2, R3, R4])
        print(R.shape)
        # Add a trailing length-1 axis so MaxPooling1D pools along the
        # concatenated features. The Keras layer infers the length and keeps the
        # batch axis, instead of a fixed [1, 158898176, 1] that only works for
        # batch size 1 and one input size.
        R = layers.Reshape((-1, 1))(R)
        rg1 = layers.MaxPooling1D(pool_size=64)(R)
        rg1f = layers.Flatten(name='flatten_rg1')(rg1)
        rg = layers.Dense(units=(4096), activation="relu")(rg1f)
        print(rg.shape)
        return Model(inputs=[inputs1, inputs2], outputs=rg, name="double_3D")

    def tune_MHL(self, backbone="none", name="", attention="_3d_image_classification", store_model="", parallel='off'):
        """Build a backbone followed by multi-head self-attention and a Dense(1024) output.

        Args:
            backbone: 'none' (two conv stages), 'simple_3d_tune' (``simple_3d('on')``,
                with weights loaded from ``store_model + "/" + self.backb_w`` if
                present) or 'simple_3d' (four conv stages + dense layers).
            name, attention: unused.
            store_model: directory of the 'simple_3d_tune' weights.
            parallel: 'on' adds an attention branch over intermediate features of
                the 'simple_3d' backbone; any other value disables it.

        Raises:
            ValueError: for an unknown backbone, or ``parallel='on'`` with a
                backbone other than 'simple_3d' (only that backbone provides the
                intermediate features the parallel branch attends over).
        """
        if backbone not in ("none", "simple_3d_tune", "simple_3d"):
            print("No none backbone network try resnet50, densenet121, or none!")
            raise ValueError("Unknown backbone %r; expected 'none', 'simple_3d_tune' or 'simple_3d'"
                             % (backbone,))
        if parallel == 'on' and backbone != "simple_3d":
            raise ValueError("parallel='on' requires backbone='simple_3d', got %r" % (backbone,))

        inputs = keras.Input((self.width, self.height, self.depth, 1))
        if backbone == "none":
            x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(inputs)
            x = layers.MaxPool3D(pool_size=2)(x)
            x = layers.BatchNormalization()(x)

            x = layers.Conv3D(filters=128, kernel_size=3, activation="relu")(x)
            x = layers.MaxPool3D(pool_size=2)(x)
            x = layers.BatchNormalization()(x)
            print("case M-Head attention MHL ")
            x = layers.GlobalAveragePooling3D()(x)
            rc = layers.Dense(units=(self.height * self.width), activation="relu")(x)

        elif backbone == "simple_3d_tune":
            Smodel = self.simple_3d('on')
            model_file = str(store_model + "/" + self.backb_w)
            print(model_file)
            if os.path.exists(model_file):
                Smodel.load_weights(model_file, by_name=True, skip_mismatch=True)
                print('load denset weights')
            rc = Smodel(inputs)

        else:  # backbone == "simple_3d"
            x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(inputs)
            x = layers.MaxPool3D(pool_size=2)(x)
            x = layers.BatchNormalization()(x)

            x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(x)
            x = layers.MaxPool3D(pool_size=2)(x)
            x = layers.BatchNormalization()(x)

            # NOTE(review): rc1/rc2 are BatchNorm outputs used only by the
            # parallel branch; the main path continues from the un-normalised x.
            x = layers.Conv3D(filters=128, kernel_size=3, activation="relu")(x)
            x = layers.MaxPool3D(pool_size=2)(x)
            rc1 = layers.BatchNormalization()(x)

            x = layers.Conv3D(filters=256, kernel_size=3, activation="relu")(x)
            x = layers.MaxPool3D(pool_size=2)(x)
            rc2 = layers.BatchNormalization()(x)

            x = layers.GlobalAveragePooling3D()(x)
            x = layers.Dense(units=512, activation="relu")(x)
            rc = layers.Dense(units=(15376), activation="relu")(x)
            print("case M-Head attention simple model ")

        Rdo = layers.Flatten(name='flatten_tunedR')(rc)
        if parallel == 'on':
            Rd1 = layers.Flatten(name='flatten_tunedR1')(rc1)
            Rd2 = layers.Flatten(name='flatten_tunedR2')(rc2)
            R = layers.MultiHeadAttention(num_heads=3, key_dim=self.height, attention_axes=(1))(Rd1, Rd2, Rdo)
            Rd = R
        else:
            Rd = Rdo
        xrgb = layers.MultiHeadAttention(num_heads=2, key_dim=self.height, attention_axes=(1))(Rd, Rd)
        print(xrgb.shape)
        f = layers.Flatten(name='flatten_R')(xrgb)
        rgb = layers.Dense(units=(15376), activation="relu")(f)
        if parallel == 'on':
            # 15376 = 124 * 124: view both vectors as 124x124 maps, stack and max-pool them.
            rgb1 = layers.Reshape([124, 124, 1, 1])(rgb)
            rgb2 = layers.Reshape([124, 124, 1, 1])(rc)
            rgbc = layers.Concatenate(axis=3)([rgb1, rgb2])
            r = layers.Reshape([124, 124, 2, 1])(rgbc)
            rgbo = layers.MaxPool3D(pool_size=(1, 1, 2))(r)
        else:
            rgbo = rgb
        rgb11 = layers.Reshape([124, 124, 1, 1])(rgbo)
        RCC = layers.Conv3D(filters=self.depth, kernel_size=1, activation="relu")(rgb11)
        rgx = layers.Reshape([124, 124, self.depth, 1])(RCC)
        x = layers.GlobalAveragePooling3D()(rgx)
        Rdx = layers.Flatten(name='flatten_tunedRx')(x)
        rg = layers.Dense(units=1024, activation="relu")(Rdx)
        return Model(inputs, rg, name="3dmhl")

    def MLP(self, pretrained_model):
        """Append Flatten -> Dense(1024) -> Dropout -> Dense(512) -> Dropout -> softmax(classes)."""
        new_DL = pretrained_model.output
        new_DL = layers.Flatten()(new_DL)
        new_DL = layers.Dense(1024, activation="relu")(new_DL)
        new_DL = layers.Dropout(self.do)(new_DL)
        new_DL = layers.Dense(512, activation="relu")(new_DL)
        new_DL = layers.Dropout(self.do)(new_DL)
        new_DL = layers.Dense(self.classes, activation="softmax")(new_DL)
        return Model(inputs=pretrained_model.input, outputs=new_DL)

In [None]:
# Blend the first two sorted files matching 'XAI_simple/total/sf/R/nPCS*.nii'
# into one weighted volume, written to store + 'total_' + case_ex.
case_ex = '.nii'
store = 'XAI_simple/total/sf/R/nPCS_total'
path = 'XAI_simple/total/sf/R/nPCS'
comb = 2
list2 = [0.85, 0.1, 0.5]
zoom = 'off'
process_scan(path, store, case_ex, comb=comb, zoom=zoom, list_given=list2)

# Other explanation sets (directory, perturbation baseline):
#XAI_simple/total/sk/L/PCS *black perturbation
#XAI_simple/total/sf/L/PCS
#XAI_simple/total/sk/R/PCS *mean perturbation
#XAI_simple/total/sf/R/PCS *black perturbation

In [None]:
import monai
import torch

# Score one blended explanation map with Quantus (complexity and faithfulness
# correlation) against the trained simple 3D Keras CNN.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n = 1
# Blended map produced by process_scan (store + 'total_' + case_ex).
# Alternative map: 'XAI_simple/total/sf/R/nPCS2total_GradCam.nii'.
expl = 'XAI_simple/total/sf/R/nPCS_totaltotal_.nii'
input_path = 'XAI_simple/total/sf/R/nPCS0total_shape.nii'
weight_file = "simple3D/simple_3dRsf_simple3d_new_data_3d_image_classification.h5"

# Both volumes are min-max scaled to [0, 255].
x = 255 * normalize(read_nifti_file(input_path))
a = 255 * normalize(read_nifti_file(expl))
print(a.max())
y = np.ones(n, dtype=int)  # target label per sample; np.zeros(n) evaluates label 0

cn = create_3Dnet('simple_3d', 184, 184, 64, 1, 2, name="", do=0.3, backbone='none', paral='off')
net = cn.model_builder()
net.load_weights(weight_file, by_name=True)
model_3 = net
print(a.shape)

# NOTE(review): np.resize reinterprets (and, on a size mismatch, repeats or
# truncates) the raw element order; it is neither a spatial resample nor an
# axis transpose. Confirm the stored volumes are laid out so this is correct.
a = np.expand_dims(np.resize(a, [64, 184, 184, 1]), axis=0)  # (1, 64, 184, 184, 1)
x = np.expand_dims(np.resize(x, [64, 184, 184, 1]), axis=0)
y_ = y.reshape(n, 1)

# (1, 1, 184, 184, 64) copies for Complexity, and (1, 184, 184, 64, 1) copies,
# matching the model input (width, height, depth, 1), for FaithfulnessCorrelation.
x_ = x.transpose(0, 4, 3, 2, 1)
a_ = a.transpose(0, 4, 3, 2, 1)
x_1 = x.transpose(0, 3, 2, 1, 4)
a_1 = a.transpose(0, 3, 2, 1, 4)

rc = quantus.Complexity(disable_warnings=True)(model=net, x_batch=x_, y_batch=y_, a_batch=a_, device=device, channel_first=True)
print(rc)

# NOTE(review): x_1/a_1 are channel-last, yet channel_first=True is passed;
# confirm this is what Quantus expects for this Keras model. The perturbed
# index subsets are presumably random and unseeded, so results may vary between runs.
rd = quantus.FaithfulnessCorrelation(
    nr_runs=100,
    subset_size=4000,
    perturb_baseline="black",
    perturb_func=quantus.perturb_func.baseline_replacement_by_indices,
    similarity_func=quantus.similarity_func.correlation_pearson,
    abs=True,
    return_aggregate=False,
    disable_warnings=True,
)(model=net, x_batch=x_1, y_batch=y_, a_batch=a_1, channel_first=True, device=device)
print(rd)