<a href="https://colab.research.google.com/github/nimrashaheen001/Programming_for_AI/blob/main/MMGFL(Final).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install SimpleITK
import SimpleITK as sitk

import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image
import glob
import cv2
import numpy as np
from skimage import transform
import os
import pandas as pd
import torchvision
import csv


def resize_image(itkimage, newSize, resamplemethod):
    """Resample a 3D SimpleITK image onto a grid of ``newSize`` voxels.

    The output spacing is ``originSpacing * originSize / newSize`` so that
    size * spacing (the grid extent) is unchanged; origin and direction are
    taken from ``itkimage`` via ``SetReferenceImage``.

    Args:
        itkimage: source ``sitk.Image`` (the identity transform below is 3D).
        newSize: target voxel count per axis, in the same axis order as
            ``itkimage.GetSize()``.
        resamplemethod: SimpleITK interpolator (e.g. ``sitk.sitkLinear``).
            ``sitk.sitkNearestNeighbor`` produces a UInt8 image, any other
            interpolator a Float32 image.

    Returns:
        The resampled ``sitk.Image``.

    Raises:
        ValueError: if any entry of ``newSize`` is not positive.
    """
    origin_size = np.asarray(itkimage.GetSize(), dtype=float)
    origin_spacing = np.asarray(itkimage.GetSpacing(), dtype=float)
    target_size = np.asarray(newSize, dtype=float)
    if np.any(target_size <= 0):
        raise ValueError(f"newSize must be positive, got {newSize}")

    new_spacing = origin_spacing * (origin_size / target_size)
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the drop-in.
    target_size = target_size.astype(int)

    resampler = sitk.ResampleImageFilter()
    if resamplemethod == sitk.sitkNearestNeighbor:
        resampler.SetOutputPixelType(sitk.sitkUInt8)
    else:
        resampler.SetOutputPixelType(sitk.sitkFloat32)

    resampler.SetReferenceImage(itkimage)
    resampler.SetSize(target_size.tolist())
    resampler.SetOutputSpacing(new_spacing.tolist())
    resampler.SetTransform(sitk.Transform(3, sitk.sitkIdentity))
    resampler.SetInterpolator(resamplemethod)
    return resampler.Execute(itkimage)

def norm_image(image):
    """Min-max rescale a tensor to the range [0, 1].

    Args:
        image: tensor of any shape.

    Returns:
        ``(image - min) / (max - min)`` as a floating tensor. A constant
        tensor (max == min) would give 0/0 = NaN everywhere, so an all-zero
        tensor of the same floating dtype is returned instead.
    """
    lo = torch.min(image)
    hi = torch.max(image)
    span = hi - lo
    if span == 0:
        # Constant input: avoid NaN; keep the dtype true division would yield.
        return torch.zeros_like(image, dtype=torch.result_type(image, 1.0))
    return (image - lo) / span

from sklearn.metrics import confusion_matrix

def specificity(Y_test, Y_pred, n):
    """Per-class (one-vs-rest) specificity TN / (TN + FP).

    Args:
        Y_test: ground-truth labels.
        Y_pred: predicted labels.
        n: number of classes to report; class ``i`` is row/column ``i`` of
            the confusion matrix returned by ``confusion_matrix``.

    Returns:
        List of ``n`` specificities, one per class.

    NOTE(review): without an explicit ``labels=`` the matrix only covers the
    labels that actually occur, so a class missing from both inputs shrinks
    it (IndexError or misaligned classes).
    """
    matrix = confusion_matrix(Y_test, Y_pred)
    total = np.sum(matrix)
    scores = []
    for cls in range(n):
        true_pos = matrix[cls][cls]
        false_neg = np.sum(matrix[cls, :]) - true_pos
        false_pos = np.sum(matrix[:, cls]) - true_pos
        true_neg = total - true_pos - false_neg - false_pos
        scores.append(true_neg / (true_neg + false_pos))
    return scores


def ACC(Y_test, Y_pred, n):
    """Per-class (one-vs-rest) accuracy (TP + TN) / total.

    Args:
        Y_test: ground-truth labels.
        Y_pred: predicted labels.
        n: number of classes to report; class ``i`` is row/column ``i`` of
            the confusion matrix returned by ``confusion_matrix``.

    Returns:
        List of ``n`` accuracies, one per class.
    """
    matrix = confusion_matrix(Y_test, Y_pred)
    total = np.sum(matrix)
    scores = []
    for cls in range(n):
        hits = matrix[cls][cls]
        missed = np.sum(matrix[cls, :]) - hits
        false_alarms = np.sum(matrix[:, cls]) - hits
        rejections = total - hits - missed - false_alarms
        scores.append((hits + rejections) / total)
    return scores




class ADMdataset(Dataset):
    """Dataset of (image volume, label, tabular features) triples.

    ``data_txt`` lists one sample per line as ``<series_dir> <label>`` (the
    format produced by ``creat_filelist``). The label is the text after the
    last space, so series paths may themselves contain spaces. Blank lines
    are skipped.

    ``__getitem__`` reads the image series in ``series_dir`` with SimpleITK,
    resamples it to 64x64x64, min-max normalises it and prepends a channel
    axis. The tabular features come from ``tabular.csv`` three directory
    levels above ``series_dir``. The row used is the one whose ``PTID`` equals
    the basename of ``series_dir``; its remaining columns are converted to
    float in file order.
    """

    def __init__(self, data_txt):
        self.data_txt = data_txt
        self.datasets = []

        # ``with`` closes the list file once it has been parsed.
        with open(self.data_txt, 'r') as listing:
            for raw_line in listing:
                line = raw_line.rstrip()
                if not line:
                    continue  # tolerate blank / trailing lines
                image_file, sep, image_label = line.rpartition(' ')
                if not sep:
                    raise ValueError(
                        f"Malformed line in {self.data_txt!r} "
                        f"(expected '<path> <label>'): {line!r}"
                    )
                self.datasets.append([image_file, image_label])

    def __getitem__(self, idx):
        series_dir, label_text = self.datasets[idx]

        # tabular.csv lives three directory levels above the series folder.
        dir_name = os.path.dirname(os.path.dirname(os.path.dirname(series_dir)))
        txt_file_path = os.path.join(dir_name, 'tabular.csv')

        series_reader = sitk.ImageSeriesReader()
        fileNames = series_reader.GetGDCMSeriesFileNames(series_dir)
        series_reader.SetFileNames(fileNames)
        images = series_reader.Execute()
        images = resize_image(images, (64, 64, 64), resamplemethod=sitk.sitkLinear)
        img_vol = norm_image(torch.from_numpy(sitk.GetArrayFromImage(images)))
        img_vol = img_vol.unsqueeze(0).float()  # prepend channel axis

        ptid = os.path.basename(series_dir)
        data = None
        with open(txt_file_path, 'r', newline='') as csvfile:
            for row in csv.DictReader(csvfile):
                if row['PTID'] == ptid:
                    del row['PTID']
                    # First matching row wins; duplicates are no longer concatenated.
                    data = torch.tensor([float(v) for v in row.values()])
                    break
        if data is None:
            # Previously surfaced as an UnboundLocalError on ``data``.
            raise KeyError(f"PTID {ptid!r} not found in {txt_file_path}")

        image_label = torch.tensor(float(label_text))
        return img_vol, image_label, data

    def __len__(self):
        return len(self.datasets)

from _py_abc import ABCMeta
from typing import Optional, Dict, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections.abc import Sequence # Import Sequence from collections.abc instead of collections



def conv3d(in_channels, out_channels, kernel_size=3, stride=1):
    """Build a bias-free ``nn.Conv3d`` with "same"-style padding.

    Padding is ``kernel_size // 2`` per axis, so odd kernels keep the spatial
    size at stride 1. This gives 1 for the 3x3x3 default and 0 for 1x1x1, as
    before. Previously every kernel other than the integer 1 got padding 1;
    that was wrong for larger kernels and for tuple kernels such as (1, 1, 1).

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        kernel_size: int or per-axis tuple.
        stride: convolution stride.
    """
    if isinstance(kernel_size, int):
        padding = kernel_size // 2
    else:
        padding = tuple(k // 2 for k in kernel_size)
    return nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False)


class ConvBnReLU(nn.Module):
    """Bias-free 3D convolution -> BatchNorm3d -> in-place ReLU.

    Args:
        in_channels: input channel count of the convolution.
        out_channels: output channel count of the convolution.
        bn_momentum: momentum of the BatchNorm running statistics.
        kernel_size: forwarded to ``nn.Conv3d``.
        stride: forwarded to ``nn.Conv3d``.
        padding: forwarded to ``nn.Conv3d``.
    """

    def __init__(
        self, in_channels, out_channels, bn_momentum=0.05, kernel_size=3, stride=1, padding=1,
    ):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        return self.relu(normalized)


class ResBlock(nn.Module):
    """Basic 3D residual block: two ``conv3d`` layers with BatchNorm plus a shortcut.

    When ``stride != 1`` or the channel count changes, the shortcut is a
    1x1x1 strided ``conv3d`` followed by BatchNorm; otherwise it is identity.
    """

    def __init__(self, in_channels, out_channels, bn_momentum=0.05, stride=1):
        super().__init__()
        self.conv1 = conv3d(in_channels, out_channels, stride=stride)
        self.bn1 = nn.BatchNorm3d(out_channels, momentum=bn_momentum)
        self.conv2 = conv3d(out_channels, out_channels)
        self.bn2 = nn.BatchNorm3d(out_channels, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)

        needs_projection = stride != 1 or in_channels != out_channels
        self.downsample = (
            nn.Sequential(
                conv3d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm3d(out_channels, momentum=bn_momentum),
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)

class Attention(nn.Module):
    """Multi-head self-attention over a sequence of token vectors.

    ``forward`` takes and returns tensors of shape ``(B, N, dim)``; ``dim``
    must be divisible by ``heads`` for the per-head reshape to succeed.
    """

    def __init__(
        self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0
    ):
        super().__init__()
        self.num_heads = heads
        per_head = dim // heads
        # A falsy qk_scale (None or 0) falls back to 1/sqrt(head_dim).
        self.scale = qk_scale if qk_scale else per_head ** -0.5

        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(dropout_rate)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout_rate)

    def forward(self, x):
        batch, tokens, channels = x.shape
        heads = self.num_heads
        # (B, N, 3*C) -> (3, B, heads, N, C // heads)
        packed = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads)
        packed = packed.permute(2, 0, 3, 1, 4)
        # index rather than tuple-unpack a tensor (TorchScript-friendly)
        query, key, value = packed[0], packed[1], packed[2]

        weights = (query @ key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        mixed = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))

from _py_abc import ABCMeta
from typing import Optional, Dict, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections.abc import Sequence

from collections import OrderedDict
from typing import Any, Dict, Optional, Sequence

import torch
import torch.nn as nn


class ConcatHNN1FC(nn.Module):
    """3D ResNet encoder whose pooled image features are concatenated with tabular data.

    The concatenated vector goes through ``fc`` then ``fc1`` with no
    nonlinearity between them (ConcatHNN2FC inserts a ReLU there).
    """

    def __init__(self, in_channels=1, n_outputs=3, bn_momentum=0.1, n_basefilters=4, ndim_non_img=10) -> None:
        super().__init__()
        nb = n_basefilters
        self.conv1 = ConvBnReLU(in_channels, nb, bn_momentum=bn_momentum)
        self.pool1 = nn.MaxPool3d(2, stride=2)  # 32
        self.block1 = ResBlock(nb, nb, bn_momentum=bn_momentum)
        self.block2 = ResBlock(nb, 2 * nb, bn_momentum=bn_momentum, stride=2)  # 16
        self.block3 = ResBlock(2 * nb, 4 * nb, bn_momentum=bn_momentum, stride=2)  # 8
        self.block4 = ResBlock(4 * nb, 8 * nb, bn_momentum=bn_momentum, stride=2)  # 4
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(8 * nb + ndim_non_img, 32)
        self.fc1 = nn.Linear(32, n_outputs)

    def forward(self, image, tabular):
        x = self.pool1(self.conv1(image))
        for stage in (self.block1, self.block2, self.block3, self.block4):
            x = stage(x)
        x = self.global_pool(x)
        joint = torch.cat((x.view(x.size(0), -1), tabular), dim=1)
        return self.fc1(self.fc(joint))


class ConcatHNN2FC(nn.Module):
    """3D ResNet encoder + tabular concatenation followed by a ReLU bottleneck MLP.

    Head: ``Linear(8*n_basefilters + ndim_non_img, bottleneck_dim)`` -> ReLU
    -> ``Linear(bottleneck_dim, n_outputs)``.
    """

    def __init__(self, in_channels=1, n_outputs=3, bn_momentum=0.1, n_basefilters=4, ndim_non_img=10, bottleneck_dim=15) -> None:
        super().__init__()
        nb = n_basefilters
        self.conv1 = ConvBnReLU(in_channels, nb, bn_momentum=bn_momentum)
        self.pool1 = nn.MaxPool3d(2, stride=2)  # 32
        self.block1 = ResBlock(nb, nb, bn_momentum=bn_momentum)
        self.block2 = ResBlock(nb, 2 * nb, bn_momentum=bn_momentum, stride=2)  # 16
        self.block3 = ResBlock(2 * nb, 4 * nb, bn_momentum=bn_momentum, stride=2)  # 8
        self.block4 = ResBlock(4 * nb, 8 * nb, bn_momentum=bn_momentum, stride=2)  # 4
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(8 * nb + ndim_non_img, bottleneck_dim)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(bottleneck_dim, n_outputs)

    def forward(self, image, tabular):
        x = self.pool1(self.conv1(image))
        for stage in (self.block1, self.block2, self.block3, self.block4):
            x = stage(x)
        x = self.global_pool(x)
        joint = torch.cat((x.view(x.size(0), -1), tabular), dim=1)
        hidden = self.relu(self.fc(joint))
        return self.fc1(hidden)


class HeterogeneousResNet(nn.Module):
    """Image-only 3D ResNet classifier (no tabular input).

    conv1 -> max-pool -> four residual stages -> global average pool -> linear.
    """

    def __init__(self, in_channels=1, n_outputs=3, bn_momentum=0.1, n_basefilters=4) -> None:
        super().__init__()
        nb = n_basefilters
        self.conv1 = ConvBnReLU(in_channels, nb, bn_momentum=bn_momentum)
        self.pool1 = nn.MaxPool3d(2, stride=2)  # 32
        self.block1 = ResBlock(nb, nb, bn_momentum=bn_momentum)
        self.block2 = ResBlock(nb, 2 * nb, bn_momentum=bn_momentum, stride=2)  # 16
        self.block3 = ResBlock(2 * nb, 4 * nb, bn_momentum=bn_momentum, stride=2)  # 8
        self.block4 = ResBlock(4 * nb, 8 * nb, bn_momentum=bn_momentum, stride=2)  # 4
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(8 * nb, n_outputs)

    def forward(self, image):
        x = self.pool1(self.conv1(image))
        for stage in (self.block1, self.block2, self.block3, self.block4):
            x = stage(x)
        pooled = self.global_pool(x)
        return self.fc(pooled.view(pooled.size(0), -1))


class InteractiveHNN(nn.Module):
    """3D ResNet whose stages are gated channel-wise by an embedding of tabular data.

    Before each residual stage the feature map is multiplied by a
    per-channel vector produced by a chain of linear maps applied to
    ``tabular``. Unlike the sibling models in this file, ``forward`` returns
    softmax probabilities rather than raw logits.
    NOTE(review): if trained with CrossEntropyLoss, softmax is applied twice.
    """

    def __init__(self, in_channels=1, n_outputs=3, bn_momentum=0.1, n_basefilters=4, ndim_non_img=10) -> None:
        super().__init__()
        nb = n_basefilters

        # Image pathway (ResNet)
        self.conv1 = ConvBnReLU(in_channels, nb, bn_momentum=bn_momentum)
        self.pool1 = nn.MaxPool3d(2, stride=2)  # 32
        self.block1 = ResBlock(nb, nb, bn_momentum=bn_momentum)
        self.block2 = ResBlock(nb, 2 * nb, bn_momentum=bn_momentum, stride=2)  # 16
        self.block3 = ResBlock(2 * nb, 4 * nb, bn_momentum=bn_momentum, stride=2)  # 8
        self.block4 = ResBlock(4 * nb, 8 * nb, bn_momentum=bn_momentum, stride=2)  # 4
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(8 * nb, n_outputs)

        # Tabular pathway: each gate's width equals the channel count of the
        # feature map it multiplies (nb, nb, 2*nb, 4*nb).
        self.aux = nn.Sequential(OrderedDict([
            ("aux_base", nn.Linear(ndim_non_img, 15, bias=False)),
            ("aux_relu", nn.ReLU()),
            ("aux_1", nn.Linear(15, nb, bias=False)),
        ]))
        self.aux_2 = nn.Linear(nb, nb, bias=False)
        self.aux_3 = nn.Linear(nb, 2 * nb, bias=False)
        self.aux_4 = nn.Linear(2 * nb, 4 * nb, bias=False)

    def forward(self, image, tabular):
        features = self.pool1(self.conv1(image))
        gate = tabular.to(torch.float32)
        stages = (
            (self.aux, self.block1),
            (self.aux_2, self.block2),
            (self.aux_3, self.block3),
            (self.aux_4, self.block4),
        )
        for project, stage in stages:
            gate = project(gate)  # each gate is computed from the previous one
            batch_size, n_channels = features.size()[:2]
            features = stage(features * gate.view(batch_size, n_channels, 1, 1, 1))

        pooled = self.global_pool(features)
        logits = self.fc(pooled.view(pooled.size(0), -1))
        return F.softmax(logits, dim=1)

class DAFT2021(nn.Module):
    """3D ResNet whose last stage is a ``DAFTBlock`` conditioned on tabular data.

    Args:
        in_channels: image channels.
        n_outputs: number of output logits.
        bn_momentum: BatchNorm momentum for every block.
        n_basefilters: width of the first stage; later stages double it.
        filmblock_args: extra keyword arguments forwarded to ``DAFTBlock``
            (e.g. ``ndim_non_img``, ``location``, ``activation``).
    """

    def __init__(
            self,
            in_channels: int = 1,
            n_outputs: int = 3,
            bn_momentum: float = 0.1,
            n_basefilters: int = 4,
            filmblock_args: Optional[Dict[Any, Any]] = None,
    ) -> None:
        super().__init__()

        # Default to a fresh dict (the check was previously duplicated).
        if filmblock_args is None:
            filmblock_args = {}

        # NOTE(review): not read anywhere in this class; kept for attribute compatibility.
        self.split_size = 4 * n_basefilters
        self.conv1 = ConvBnReLU(in_channels, n_basefilters, bn_momentum=bn_momentum)
        self.pool1 = nn.MaxPool3d(2, stride=2)  # 32
        self.block1 = ResBlock(n_basefilters, n_basefilters, bn_momentum=bn_momentum)
        self.block2 = ResBlock(n_basefilters, 2 * n_basefilters, bn_momentum=bn_momentum, stride=2)  # 16
        self.block3 = ResBlock(2 * n_basefilters, 4 * n_basefilters, bn_momentum=bn_momentum, stride=2)  # 8
        self.blockX = DAFTBlock(4 * n_basefilters, 8 * n_basefilters, bn_momentum=bn_momentum, **filmblock_args)  # 4
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(8 * n_basefilters, n_outputs)

    def forward(self, image, tabular):
        """Return raw logits of shape ``(B, n_outputs)``."""
        out = self.conv1(image)
        out = self.pool1(out)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.blockX(out, tabular)  # tabular data conditions the last stage
        out = self.global_pool(out)
        out = out.view(out.size(0), -1)
        return self.fc(out)


class FilmBase(nn.Module, metaclass=ABCMeta):
    """Residual block whose features are recalibrated from auxiliary (tabular) input.

    Subclasses implement ``rescale_features``. ``location`` selects where in
    the block the recalibration is applied (see ``forward``):

    * 0 - on the block input, so both the shortcut and the conv path see it;
    * 1 - on the shortcut path only;
    * 2 - on the conv-path input only (rejected without downsampling);
    * 3 - after ``bn1``, before the first ReLU (``bn1`` then has no affine
      parameters);
    * 4 - after the first ReLU, before ``conv2``.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        bn_momentum: BatchNorm momentum.
        stride: stride of ``conv1`` and of the shortcut projection.
        ndim_non_img: number of tabular features (validated by subclasses;
            not used here).
        location: see above; must be in 0..4.
        activation: "sigmoid", "tanh" or "linear". Stored as
            ``scale_activation`` (``None`` for "linear") for subclasses.
        scale: whether a scale is learned (bool).
        shift: whether a shift is learned (bool); at least one must be True.

    Raises:
        ValueError: invalid ``location``/``activation``, non-bool or
            all-False ``scale``/``shift``, or ``location == 2`` without
            downsampling.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        bn_momentum: float,
        stride: int,
        ndim_non_img: int,
        location: int,
        activation: str,
        scale: bool,
        shift: bool,
    ) -> None:

        super().__init__()

        # sanity checks
        if location not in range(5):
            raise ValueError(f"Invalid location specified: {location}")
        if activation not in {"tanh", "sigmoid", "linear"}:
            # Previously this reported the location instead of the activation.
            raise ValueError(f"Invalid activation specified: {activation}")
        if not isinstance(scale, bool) or not isinstance(shift, bool):
            # Both halves are f-strings so the placeholders are interpolated.
            raise ValueError(
                f"scale and shift must be of type bool:\n    -> scale value: {scale}, "
                f"scale type: {type(scale)}\n    -> shift value: {shift}, shift type: {type(shift)}"
            )
        if not scale and not shift:
            raise ValueError("At least one of scale and shift must be True")

        # ResBlock
        self.conv1 = conv3d(in_channels, out_channels, stride=stride)
        # At location 3 the FiLM recalibration directly follows bn1, which is
        # created without affine parameters there.
        self.bn1 = nn.BatchNorm3d(out_channels, momentum=bn_momentum, affine=(location != 3))
        self.conv2 = conv3d(out_channels, out_channels)
        self.bn2 = nn.BatchNorm3d(out_channels, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                conv3d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm3d(out_channels, momentum=bn_momentum),
            )
        else:
            self.downsample = None

        # Film-specific variables
        self.location = location
        if self.location == 2 and self.downsample is None:
            raise ValueError("This is equivalent to location=1 and no downsampling!")
        # Channel count of the tensor being recalibrated at this location.
        self.film_dims = 0
        if location in {0, 1, 2}:
            self.film_dims = in_channels
        elif location in {3, 4}:
            self.film_dims = out_channels
        if activation == "sigmoid":
            self.scale_activation = nn.Sigmoid()
        elif activation == "tanh":
            self.scale_activation = nn.Tanh()
        elif activation == "linear":
            self.scale_activation = None

    def rescale_features(self, feature_map, x_aux):
        """Recalibrate ``feature_map`` using ``x_aux``; must be overridden.

        Raises:
            NotImplementedError: always, in this base class (previously it
                silently returned ``None``).
        """
        raise NotImplementedError(f"{type(self).__name__} must implement rescale_features")

    def forward(self, feature_map, x_aux):
        """Run the residual block, recalibrating at ``self.location``."""
        if self.location == 0:
            feature_map = self.rescale_features(feature_map, x_aux)
        residual = feature_map

        if self.location == 1:
            residual = self.rescale_features(residual, x_aux)

        if self.location == 2:
            # residual was captured above, so only the conv path is rescaled
            feature_map = self.rescale_features(feature_map, x_aux)
        out = self.conv1(feature_map)
        out = self.bn1(out)

        if self.location == 3:
            out = self.rescale_features(out, x_aux)
        out = self.relu(out)

        if self.location == 4:
            out = self.rescale_features(out, x_aux)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(residual)
        out += residual
        out = self.relu(out)

        return out


class DAFTBlock(FilmBase):
    """FiLM-style residual block whose scale/shift are computed from tabular data.

    ``rescale_features`` global-average-pools the feature map, concatenates
    the pooled vector with ``x_aux`` and feeds it to a small bottleneck MLP
    (``self.aux``). The MLP output provides a per-channel scale and/or shift,
    applied as ``v_scale * feature_map + v_shift``.

    After ``FilmBase.__init__`` the attributes ``self.scale`` / ``self.shift``
    are re-purposed as sentinels. ``None`` means "learned by self.aux"; a
    number is a fixed value (scale 1 or shift 0).
    """
    # Block for ZeCatNet
    def __init__(
        self,
        in_channels: int=1,
        out_channels: int=3,
        bn_momentum: float = 0.1,
        stride: int = 2,
        ndim_non_img: int = 10,
        location: int = 3,
        activation: str = "linear",
        scale: bool = True,
        shift: bool = True,
        bottleneck_dim: int = 15,
    ) -> None:

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            bn_momentum=bn_momentum,
            stride=stride,
            ndim_non_img=ndim_non_img,
            location=location,
            activation=activation,
            scale=scale,
            shift=shift,
        )

        self.bottleneck_dim = bottleneck_dim
        # Width of the pooled feature vector fed into self.aux (captured before
        # film_dims is possibly doubled below).
        aux_input_dims = self.film_dims
        # shift and scale decoding
        # split_size: width of the scale half of self.aux's output when both are learned.
        self.split_size = 0
        if scale and shift:
            # Both learned: self.aux emits [scale | shift], each film_dims wide.
            self.split_size = self.film_dims
            self.scale = None
            self.shift = None
            self.film_dims = 2 * self.film_dims
        elif not scale:
            # Fixed scale of 1; only the shift is learned.
            self.scale = 1
            self.shift = None
        elif not shift:
            # Fixed shift of 0; only the scale is learned.
            self.shift = 0
            self.scale = None

        # create aux net: [pooled features | tabular] -> bottleneck -> film_dims
        layers = [
            ("aux_base", nn.Linear(ndim_non_img + aux_input_dims, self.bottleneck_dim, bias=False)),
            ("aux_relu", nn.ReLU()),
            ("aux_out", nn.Linear(self.bottleneck_dim, self.film_dims, bias=False)),
        ]
        self.aux = nn.Sequential(OrderedDict(layers))

    def rescale_features(self, feature_map, x_aux):
        """Return ``v_scale * feature_map + v_shift`` with factors predicted from ``x_aux``."""

        # Pool to one value per channel, flatten to (batch, channels), then
        # append the tabular features.
        squeeze = self.global_pool(feature_map)
        squeeze = squeeze.view(squeeze.size(0), -1)
        squeeze = torch.cat((squeeze, x_aux), dim=1)

        attention = self.aux(squeeze)
        # Both None: scale and shift are both learned (the only configuration
        # produced by __init__ in which they compare equal).
        if self.scale == self.shift:
            v_scale, v_shift = torch.split(attention, self.split_size, dim=1)
            # (B, C) -> (B, C, 1, 1, 1), broadcast over the spatial axes
            v_scale = v_scale.view(*v_scale.size(), 1, 1, 1).expand_as(feature_map)
            v_shift = v_shift.view(*v_shift.size(), 1, 1, 1).expand_as(feature_map)
            if self.scale_activation is not None:
                v_scale = self.scale_activation(v_scale)
        elif self.scale is None:
            # learned scale, fixed shift (self.shift == 0)
            v_scale = attention
            v_scale = v_scale.view(*v_scale.size(), 1, 1, 1).expand_as(feature_map)
            v_shift = self.shift
            if self.scale_activation is not None:
                v_scale = self.scale_activation(v_scale)
        elif self.shift is None:
            # fixed scale (self.scale == 1), learned shift
            v_scale = self.scale
            v_shift = attention
            v_shift = v_shift.view(*v_shift.size(), 1, 1, 1).expand_as(feature_map)
        else:
            raise AssertionError(
                f"Sanity checking on scale and shift failed. Must be of type bool or None: {self.scale}, {self.shift}"
            )

        return (v_scale * feature_map) + v_shift


class Fusion2022(nn.Module):
    """3D ResNet whose last stage is an ``FBlock`` fusing image and tabular features.

    Head: global average pool -> ``Linear(8 * n_basefilters, 16)`` ->
    ``Linear(16, n_outputs)``; ``forward`` returns raw logits.
    """

    def __init__(self, in_channels=1, n_outputs=3, bn_momentum=0.1, n_basefilters=4, ndim_non_img=10) -> None:
        super().__init__()

        self.conv1 = ConvBnReLU(in_channels, n_basefilters, bn_momentum=bn_momentum)
        self.pool1 = nn.MaxPool3d(2, stride=2)  # 32
        self.block1 = ResBlock(n_basefilters, n_basefilters, bn_momentum=bn_momentum)
        self.block2 = ResBlock(n_basefilters, 2 * n_basefilters, bn_momentum=bn_momentum, stride=2)  # 16
        self.block3 = ResBlock(2 * n_basefilters, 4 * n_basefilters, bn_momentum=bn_momentum, stride=2)  # 8
        # FBlock's default stride is 1, so this stage keeps the spatial size.
        # The tabular width and base filter count are forwarded so FBlock's
        # auxiliary layers match this model (they were left at FBlock's defaults).
        self.blockX = FBlock(
            4 * n_basefilters,
            8 * n_basefilters,
            bn_momentum=bn_momentum,
            n_basefilters=n_basefilters,
            ndim_non_img=ndim_non_img,
        )
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        # Input width is blockX's channel count; the previous hard-coded 32
        # only matched n_basefilters=4.
        self.fc = nn.Linear(8 * n_basefilters, 16)
        self.fc1 = nn.Linear(16, n_outputs)

        # NOTE(review): self.aux and self.aux1 are never used in forward; they
        # are kept so the module's parameter names stay unchanged.
        layers = [
            ("aux_base", nn.Linear(ndim_non_img, 7, bias=False)),
            ("aux_relu", nn.ReLU()),
            ("aux_1", nn.Linear(7, 4, bias=False)),
        ]

        layers2 = [
            ("aux_base", nn.Linear(ndim_non_img, 15, bias=False)),
            ("aux_relu", nn.ReLU()),
            ("aux_1", nn.Linear(15, 32, bias=False)),
        ]
        self.aux = nn.Sequential(OrderedDict(layers))
        self.aux1 = nn.Sequential(OrderedDict(layers2))

    def forward(self, image, tabular):
        out = self.conv1(image)
        out = self.pool1(out)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.blockX(out, tabular)  # image/tabular fusion stage
        out = self.global_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        out = self.fc1(out)

        return out

#
class FBlock(nn.Module):
    """Residual block that gates its channels with an image/tabular attention fusion.

    After ``conv1``/``bn1`` the block builds two tokens of width ``n_basefilters``:

    * an image token: the global-average-pooled feature map through ``fc5``;
    * a tabular token: ``tabular`` through the ``aux`` MLP.

    Self-attention (``Attention``) mixes the two tokens. ``gate`` then maps the
    flattened result to one multiplier per output channel, applied before the
    first ReLU. The rest is a standard residual block (``conv2``/``bn2`` +
    shortcut).

    NOTE(review): the original ``forward`` could not run. It applied ``fc5``
    to a 5D map, misused ``torch.cat`` and constructed ``Attention`` instead of
    applying it. This is a minimal reading of that intent; confirm the design.
    """

    def __init__(self, in_channels, out_channels, bn_momentum=0.05, stride=1, n_basefilters=4, ndim_non_img=10):
        super().__init__()
        self.conv1 = conv3d(in_channels, out_channels, stride=stride)
        self.bn1 = nn.BatchNorm3d(out_channels, momentum=bn_momentum)
        self.conv2 = conv3d(out_channels, out_channels)
        self.bn2 = nn.BatchNorm3d(out_channels, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        self.global_pool1 = nn.AvgPool3d(5)  # NOTE(review): unused in forward
        # Pooled image descriptor -> token of the same width as the tabular token.
        # Equals the former Linear(32, 4) for the defaults used by Fusion2022.
        self.fc5 = nn.Linear(out_channels, n_basefilters)

        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                conv3d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm3d(out_channels, momentum=bn_momentum),
            )
        else:
            self.downsample = None

        layers = [
            ("aux_base", nn.Linear(ndim_non_img, 7, bias=False)),
            ("aux_relu", nn.ReLU()),
            ("aux_1", nn.Linear(7, n_basefilters, bias=False)),
        ]
        self.aux = nn.Sequential(OrderedDict(layers))

        # Attention over the [image, tabular] token pair, then one gate value per output channel.
        self.fusion = Attention(n_basefilters, heads=1)
        self.gate = nn.Linear(2 * n_basefilters, out_channels)

    def forward(self, x, tabular):
        residual = x
        out = self.bn1(self.conv1(x))
        batch_size, n_channels = out.size()[:2]

        img_token = self.fc5(self.global_pool(out).view(batch_size, n_channels))
        tab_token = self.aux(tabular.to(out.dtype))
        tokens = torch.stack((img_token, tab_token), dim=1)  # (B, 2, n_basefilters)
        fused = self.fusion(tokens).reshape(batch_size, -1)  # (B, 2 * n_basefilters)
        channel_gate = self.gate(fused).view(batch_size, n_channels, 1, 1, 1)

        out = self.relu(out * channel_gate)
        out = self.bn2(self.conv2(out))

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        return self.relu(out)

from sklearn.metrics import confusion_matrix
import numpy as np



def ACC(Y_test, Y_pred, n):
    """Per-class (one-vs-rest) accuracy (TP + TN) / total.

    Args:
        Y_test: ground-truth labels.
        Y_pred: predicted labels.
        n: number of classes to report; class ``i`` is row/column ``i`` of
            the confusion matrix returned by ``confusion_matrix``.

    Returns:
        List of ``n`` accuracies, one per class.
    """
    matrix = confusion_matrix(Y_test, Y_pred)
    total = np.sum(matrix)
    scores = []
    for cls in range(n):
        hits = matrix[cls][cls]
        missed = np.sum(matrix[cls, :]) - hits
        false_alarms = np.sum(matrix[:, cls]) - hits
        rejections = total - hits - missed - false_alarms
        scores.append((hits + rejections) / total)
    return scores
import os
from google.colab import drive
# Mount Google Drive at /content/drive so the dataset paths used below
# (/content/drive/MyDrive/...) are reachable from this Colab session.
drive.mount('/content/drive')


def creat_filelist(input_path, classes):
    """List every entry of each class folder as ``"<path> <class_index>"``.

    Args:
        input_path: directory containing one sub-folder per class.
        classes: class folder names; the class index is the position in
            this sequence. Names that are not directories under
            ``input_path`` are skipped but still consume an index.

    Returns:
        ``([], entries)``; the first element is always an empty list, kept
        for the existing two-value return shape.
    """
    entries = []
    for label, class_name in enumerate(classes):
        print('index', label)
        class_dir = input_path + '/' + class_name
        if not os.path.isdir(class_dir):
            continue
        entries.extend(
            class_dir + '/' + sample + ' ' + str(label)
            for sample in os.listdir(class_dir)
        )
    return [], entries

def creat_txtfile(output_path, file_list):
    """Write each item of ``file_list`` (via ``str``) on its own line.

    ``output_path`` is overwritten. Each item is also echoed to stdout.
    """
    with open(output_path, 'w') as out:
        for entry in file_list:
            print(entry)
            print(entry, file=out)  # writes str(entry) + "\n"


def main():
    """Write ``train.txt`` and ``val-oasis.txt`` sample lists.

    Each line is ``<sample_dir> <class_index>`` (see ``creat_filelist``).
    Class indices come from the sorted class sub-directories of
    ``dir_image0``.
    """
    dir_image0 = "/content/drive/MyDrive/train"
    # "train"/"val" are split output folders living inside dir_image0 (the
    # validation list below reads dir_image0/val), not diagnostic classes.
    # Treating them as classes would add entries from those folders to
    # train.txt under a bogus class index.
    split_dirs = {"train", "val"}
    # Sorted so the class -> index mapping is identical on every run (plain
    # os.listdir order is arbitrary); plain files are not classes.
    classes = sorted(
        name for name in os.listdir(dir_image0)
        if name not in split_dirs and os.path.isdir(os.path.join(dir_image0, name))
    )
    print(classes)
    _, file_list = creat_filelist(dir_image0, classes)

    # Create train.txt
    # NOTE(review): once the train/val split step has moved the samples they
    # live under dir_image0/train/<class>; confirm which folder should feed
    # train.txt.
    creat_txtfile('train.txt', file_list)

    # --- Create val-oasis.txt ---
    val_dir = "/content/drive/MyDrive/train/val"  # Adjust path if necessary
    _, val_file_list = creat_filelist(val_dir, classes)
    creat_txtfile('val-oasis.txt', val_file_list)

# Build the sample lists when run as a script; inside a Jupyter/Colab cell
# __name__ is also '__main__', so this executes when the cell runs.
if __name__ == '__main__':
    main()

import torch


def norm_image(image):
    """Min-max rescale a tensor to the range [0, 1].

    Args:
        image: tensor of any shape.

    Returns:
        ``(image - min) / (max - min)`` as a floating tensor. A constant
        tensor (max == min) would give 0/0 = NaN everywhere, so an all-zero
        tensor of the same floating dtype is returned instead.
    """
    lo = torch.min(image)
    hi = torch.max(image)
    span = hi - lo
    if span == 0:
        # Constant input: avoid NaN; keep the dtype true division would yield.
        return torch.zeros_like(image, dtype=torch.result_type(image, 1.0))
    return (image - lo) / span

!pip install SimpleITK
import SimpleITK as sitk


def resize_image(itkimage, newSize, resamplemethod):
    """Resample a 3D SimpleITK image onto a grid of ``newSize`` voxels.

    The output spacing is ``originSpacing * originSize / newSize`` so that
    size * spacing (the grid extent) is unchanged; origin and direction are
    taken from ``itkimage`` via ``SetReferenceImage``.

    Args:
        itkimage: source ``sitk.Image`` (the identity transform below is 3D).
        newSize: target voxel count per axis, in the same axis order as
            ``itkimage.GetSize()``.
        resamplemethod: SimpleITK interpolator (e.g. ``sitk.sitkLinear``).
            ``sitk.sitkNearestNeighbor`` produces a UInt8 image, any other
            interpolator a Float32 image.

    Returns:
        The resampled ``sitk.Image``.

    Raises:
        ValueError: if any entry of ``newSize`` is not positive.
    """
    origin_size = np.asarray(itkimage.GetSize(), dtype=float)
    origin_spacing = np.asarray(itkimage.GetSpacing(), dtype=float)
    target_size = np.asarray(newSize, dtype=float)
    if np.any(target_size <= 0):
        raise ValueError(f"newSize must be positive, got {newSize}")

    new_spacing = origin_spacing * (origin_size / target_size)
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the drop-in.
    target_size = target_size.astype(int)

    resampler = sitk.ResampleImageFilter()
    if resamplemethod == sitk.sitkNearestNeighbor:
        resampler.SetOutputPixelType(sitk.sitkUInt8)
    else:
        resampler.SetOutputPixelType(sitk.sitkFloat32)

    resampler.SetReferenceImage(itkimage)
    resampler.SetSize(target_size.tolist())
    resampler.SetOutputSpacing(new_spacing.tolist())
    resampler.SetTransform(sitk.Transform(3, sitk.sitkIdentity))
    resampler.SetInterpolator(resamplemethod)
    return resampler.Execute(itkimage)

from sklearn.metrics import confusion_matrix
import numpy as np

def specificity(Y_test, Y_pred, n):
    """Per-class (one-vs-rest) specificity TN / (TN + FP).

    Args:
        Y_test: ground-truth labels.
        Y_pred: predicted labels.
        n: number of classes to report; class ``i`` is row/column ``i`` of
            the confusion matrix returned by ``confusion_matrix``.

    Returns:
        List of ``n`` specificities, one per class.
    """
    matrix = confusion_matrix(Y_test, Y_pred)
    total = np.sum(matrix)
    scores = []
    for cls in range(n):
        true_pos = matrix[cls][cls]
        false_neg = np.sum(matrix[cls, :]) - true_pos
        false_pos = np.sum(matrix[:, cls]) - true_pos
        true_neg = total - true_pos - false_neg - false_pos
        scores.append(true_neg / (true_neg + false_pos))
    return scores

import SimpleITK as sitk
import numpy as np
import torch

def resize_image(itkimage, newSize, resamplemethod):
    """Resample a 3D SimpleITK image onto a grid of ``newSize`` voxels.

    The output spacing is ``originSpacing * originSize / newSize`` so that
    size * spacing (the grid extent) is unchanged; origin and direction are
    taken from ``itkimage`` via ``SetReferenceImage``.

    Args:
        itkimage: source ``sitk.Image`` (the identity transform below is 3D).
        newSize: target voxel count per axis, in the same axis order as
            ``itkimage.GetSize()``.
        resamplemethod: SimpleITK interpolator (e.g. ``sitk.sitkLinear``).
            ``sitk.sitkNearestNeighbor`` produces a UInt8 image, any other
            interpolator a Float32 image.

    Returns:
        The resampled ``sitk.Image``.

    Raises:
        ValueError: if any entry of ``newSize`` is not positive.
    """
    origin_size = np.asarray(itkimage.GetSize(), dtype=float)
    origin_spacing = np.asarray(itkimage.GetSpacing(), dtype=float)
    target_size = np.asarray(newSize, dtype=float)
    if np.any(target_size <= 0):
        raise ValueError(f"newSize must be positive, got {newSize}")

    new_spacing = origin_spacing * (origin_size / target_size)
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the drop-in.
    target_size = target_size.astype(int)

    resampler = sitk.ResampleImageFilter()
    if resamplemethod == sitk.sitkNearestNeighbor:
        resampler.SetOutputPixelType(sitk.sitkUInt8)
    else:
        resampler.SetOutputPixelType(sitk.sitkFloat32)

    resampler.SetReferenceImage(itkimage)
    resampler.SetSize(target_size.tolist())
    resampler.SetOutputSpacing(new_spacing.tolist())
    resampler.SetTransform(sitk.Transform(3, sitk.sitkIdentity))
    resampler.SetInterpolator(resamplemethod)
    return resampler.Execute(itkimage)

def norm_image(image):
    """Min-max rescale a tensor to the range [0, 1].

    Args:
        image: tensor of any shape.

    Returns:
        ``(image - min) / (max - min)`` as a floating tensor. A constant
        tensor (max == min) would give 0/0 = NaN everywhere, so an all-zero
        tensor of the same floating dtype is returned instead.
    """
    lo = torch.min(image)
    hi = torch.max(image)
    span = hi - lo
    if span == 0:
        # Constant input: avoid NaN; keep the dtype true division would yield.
        return torch.zeros_like(image, dtype=torch.result_type(image, 1.0))
    return (image - lo) / span

from sklearn.metrics import confusion_matrix

def specificity(Y_test, Y_pred, n):
    """Per-class (one-vs-rest) specificity TN / (TN + FP).

    Args:
        Y_test: ground-truth labels.
        Y_pred: predicted labels.
        n: number of classes to report; class ``i`` is row/column ``i`` of
            the confusion matrix returned by ``confusion_matrix``.

    Returns:
        List of ``n`` specificities, one per class.
    """
    matrix = confusion_matrix(Y_test, Y_pred)
    total = np.sum(matrix)
    scores = []
    for cls in range(n):
        true_pos = matrix[cls][cls]
        false_neg = np.sum(matrix[cls, :]) - true_pos
        false_pos = np.sum(matrix[:, cls]) - true_pos
        true_neg = total - true_pos - false_neg - false_pos
        scores.append(true_neg / (true_neg + false_pos))
    return scores


def ACC(Y_test, Y_pred, n):
    """Per-class (one-vs-rest) accuracy (TP + TN) / total.

    Args:
        Y_test: ground-truth labels.
        Y_pred: predicted labels.
        n: number of classes to report; class ``i`` is row/column ``i`` of
            the confusion matrix returned by ``confusion_matrix``.

    Returns:
        List of ``n`` accuracies, one per class.
    """
    matrix = confusion_matrix(Y_test, Y_pred)
    total = np.sum(matrix)
    scores = []
    for cls in range(n):
        hits = matrix[cls][cls]
        missed = np.sum(matrix[cls, :]) - hits
        false_alarms = np.sum(matrix[:, cls]) - hits
        rejections = total - hits - missed - false_alarms
        scores.append((hits + rejections) / total)
    return scores

# E:\projects\pythonProject6

import os
from shutil import copy,rmtree
import random
import shutil

def mk_file(file_path: str):
    """Ensure ``file_path`` exists as an EMPTY directory.

    An existing directory at ``file_path`` is deleted recursively (with all
    its contents) and recreated; missing parents are created.
    """
    if not os.path.exists(file_path):
        os.makedirs(file_path)
        return
    rmtree(file_path)
    os.makedirs(file_path)



# Randomly move a ``split_rate`` fraction of every class folder under
# ``data_root`` into ``data_root/val/<class>`` and the rest into
# ``data_root/train/<class>``. Samples are MOVED, not copied.
random.seed(0)
split_rate = 0.2
cwd = os.getcwd()
data_root = '/content/drive/MyDrive/train'
origin_flower_path = os.path.join(data_root)
if not os.path.exists(origin_flower_path):
    # ``assert`` is stripped under ``python -O``; raise explicitly instead.
    raise FileNotFoundError(f"Path does not exist: {origin_flower_path}")

train_root = os.path.join(data_root, "train")
val_root = os.path.join(data_root, "val")

# Class folders. The split output folders live inside data_root and must
# never be treated as classes. The list is sorted so the seeded split below
# is reproducible (os.listdir order is arbitrary).
flower_class = sorted(
    cla for cla in os.listdir(origin_flower_path)
    if cla not in {"train", "val"} and os.path.isdir(os.path.join(origin_flower_path, cla))
)

if os.path.exists(train_root) or os.path.exists(val_root):
    # On a re-run, mk_file() below would rmtree() the samples moved there
    # last time, so skip instead of deleting data.
    print(f"{train_root} or {val_root} already exists; skipping the train/val split.")
else:
    mk_file(train_root)
    for cla in flower_class:
        mk_file(os.path.join(train_root, cla))

    mk_file(val_root)
    for cla in flower_class:
        mk_file(os.path.join(val_root, cla))

    for cla in flower_class:
        # Folder of this class in the original dataset
        cla_path = os.path.join(origin_flower_path, cla)
        images = sorted(os.listdir(cla_path))
        num = len(images)
        # set: O(1) membership test per image
        eval_index = set(random.sample(images, k=int(num * split_rate)))

        for index, image in enumerate(images):
            image_path = os.path.join(cla_path, image)
            split_root = val_root if image in eval_index else train_root
            shutil.move(image_path, os.path.join(split_root, cla))
            print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")
        print()

    print("processing done!")

!pip install modAL
import torch
import torch.nn.functional as F

# In the file modal/__init__.py, add the following lines:

# Import the Fusion2022 class from the correct location (e.g., modal.py)
#from .modal import Fusion2022  # Assuming Fusion2022 is in modal.py

# Add Fusion2022 to the __all__ list to make it importable
#__all__ = ["Fusion2022", ...] # Add other members if necessary


class SaveValues:
    """Attach hooks to module *m* and keep its latest forward output and output gradient."""

    def __init__(self, m):
        # Both slots stay None until the first forward / backward pass through m.
        self.activations = None
        self.gradients = None
        # Keep the handles so remove() can detach the hooks again.
        # NOTE(review): register_backward_hook is deprecated in recent PyTorch in favour
        # of register_full_backward_hook; kept unchanged here to preserve behavior.
        self.forward_hook = m.register_forward_hook(self.hook_fn_act)
        self.backward_hook = m.register_backward_hook(self.hook_fn_grad)

    def hook_fn_act(self, module, input, output):
        """Forward hook: store the module's output."""
        self.activations = output

    def hook_fn_grad(self, module, grad_input, grad_output):
        """Backward hook: store the first entry of grad_output."""
        first_output_grad = grad_output[0]
        self.gradients = first_output_grad

    def remove(self):
        """Unregister the forward hook, then the backward hook."""
        for handle in (self.forward_hook, self.backward_hook):
            handle.remove()


class CAM(object):
    """ Class Activation Mapping

    Applies the weights of the model's ``fc1`` linear layer, as a 1x1x1 3-D
    convolution, to the activations captured at ``target_layer`` and returns
    the min-max normalized map of the predicted class.
    """

    def __init__(self, model, target_layer):
        """
        Args:
            model: ResNet_linear() -- called as model(x, y); must have an ``fc1`` submodule
            target_layer: conv_layer before Global Average Pooling
        """

        self.model = model
        self.target_layer = target_layer

        # save values of activations and gradients in target_layer
        self.values = SaveValues(self.target_layer)

    def forward(self, x, y):
        """
        Args:
            x: input image. shape => (N, 3, T, H, W)
            y: second model input, passed straight through as model(x, y)
        Return:
            heatmap: class activation map of the predicted class, shape (1, 1, T', H', W')
                     where T', H', W' are the spatial dims of the target layer's output.
                     ``idx.item()`` below requires a batch of size 1.
        """

        # object classification
        score = self.model(x, y)
        prob = F.softmax(score, dim=1)
        max_prob, idx = torch.max(prob, dim=1)
        print(
            "predicted action ids {}\t probability {}".format(
                idx.item(), max_prob.item()
            )
        )

        # cam can be calculated from the weights of linear layer and activations
        weight_fc = list(self.model._modules.get("fc1").parameters())[0].detach()
        cam = self.getCAM(self.values, weight_fc, idx.item())

        return cam

    def __call__(self, x, y):
        return self.forward(x, y)

    def getCAM(self, values, weight_fc, idx):
        """
        values: the activations and gradients of target_layer
            activations: feature map before GAP.  shape => (N, C, T, H, W)
        weight_fc: the weight of fully connected layer.  shape => (num_classes, C)
        idx: predicted class id
        cam: class activation map.  shape => (1, 1, T, H, W), values in [0, 1];
             all zeros when the map is constant.
        """

        activations = values.activations
        # Follow the activations' device/dtype so a model on GPU does not hit a
        # device mismatch inside conv3d.
        weight = weight_fc.to(device=activations.device, dtype=activations.dtype)
        cam = F.conv3d(activations, weight=weight[:, :, None, None, None])
        _, _, t, h, w = cam.shape

        # class activation mapping only for the predicted class
        # cam is normalized with min-max.
        cam = cam[:, idx, :, :, :]
        cam = cam - torch.min(cam)
        peak = torch.max(cam)
        # A constant map would otherwise be divided 0/0 and become all NaN.
        if peak > 0:
            cam = cam / peak
        cam = cam.view(1, 1, t, h, w)

        return cam.detach()


class GradCAM(CAM):
    """ Grad CAM

    Weights the activations captured at ``target_layer`` by the spatial mean of
    their gradients with respect to one class score, applies ReLU and min-max
    normalizes the result.
    """

    def __init__(self, model, target_layer):
        """
        Args:
            model: a base model to get CAM, which need not have global pooling and fully connected layer.
            target_layer: conv_layer you want to visualize
        """
        super().__init__(model, target_layer)

    def forward(self, x, idx=None):
        """
        Args:
            x: input image. shape =>(1, 3, T, H, W)
            idx: class index to explain; None explains the predicted class
        Return:
            (cam, idx): the Grad-CAM map, shape (1, 1, T', H', W'), and the class id used
        """

        # Grad-CAM needs a backward pass, so autograd is re-enabled even when the
        # caller is inside torch.no_grad().
        with torch.enable_grad():
            output = self.model(x)
            # Models returning a (score, ...) tuple are handled by taking the first
            # element; this replaces an isinstance check against a class that is not
            # defined in this notebook.
            score = output[0] if isinstance(output, (tuple, list)) else output

            prob = torch.softmax(score, dim=1)

            if idx is None:
                prob, idx = torch.max(prob, dim=1)
                idx = idx.item()
                prob = prob.item()
                print("predicted class ids {}\t probability {}".format(idx, prob))

            # caluculate cam of the predicted class
            cam = self.getGradCAM(self.values, score, idx)

        return cam, idx

    def __call__(self, x, idx=None):
        return self.forward(x, idx)

    def getGradCAM(self, values, score, idx):
        """
        values: the activations and gradients of target_layer
            activations: feature map before GAP.  shape => (1, C, T, H, W)
        score: the output of the model before softmax
        idx: predicted class id
        cam: class activation map.  shape=> (1, 1, T, H, W), values in [0, 1];
             all zeros when the map is constant.

        Raises:
            RuntimeError: if the backward pass did not reach target_layer.
        """

        self.model.zero_grad()
        # Clear any gradient left over from a previous call so a missing hook
        # invocation is detected instead of silently reusing stale values.
        values.gradients = None

        score[0, idx].backward(retain_graph=True)

        activations = values.activations
        gradients = values.gradients
        if gradients is None:
            raise RuntimeError("target_layer received no gradient during the backward pass")
        n, c, _, _, _ = gradients.shape
        alpha = gradients.view(n, c, -1).mean(2)
        alpha = alpha.view(n, c, 1, 1, 1)

        # shape => (1, 1, T', H', W')
        cam = (alpha * activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = cam - torch.min(cam)
        peak = torch.max(cam)
        # After ReLU an all-zero map is common; avoid turning it into NaN.
        if peak > 0:
            cam = cam / peak

        return cam.detach()

import torch
import torch.nn.functional as F

#from modal import HeterogeneousResNet


class SaveValues:
    """Hook module *m* and remember its most recent output and output gradient."""

    def __init__(self, m):
        # None until a forward / backward pass has gone through m.
        self.activations = None
        self.gradients = None
        # NOTE(review): register_backward_hook is deprecated in recent PyTorch in favour
        # of register_full_backward_hook; kept unchanged here to preserve behavior.
        self.forward_hook = m.register_forward_hook(self.hook_fn_act)
        self.backward_hook = m.register_backward_hook(self.hook_fn_grad)

    def hook_fn_act(self, module, input, output):
        """Forward hook: keep a reference to the output tensor."""
        self.activations = output

    def hook_fn_grad(self, module, grad_input, grad_output):
        """Backward hook: keep the gradient of the first output."""
        first_output_grad = grad_output[0]
        self.gradients = first_output_grad

    def remove(self):
        """Detach the forward hook first, then the backward hook."""
        for handle in (self.forward_hook, self.backward_hook):
            handle.remove()


class CAM(object):
    """ Class Activation Mapping

    Applies the weights of the model's final linear layer, as a 1x1x1 3-D
    convolution, to the activations captured at ``target_layer`` and returns
    the min-max normalized map of the predicted class.
    """

    # Submodule names tried, in order, for the final linear layer. "fc1" is the
    # name used by the two-input (image + tabular) models elsewhere in this file.
    fc_names = ("fc", "fc1")

    def __init__(self, model, target_layer):
        """
        Args:
            model: ResNet_linear()
            target_layer: conv_layer before Global Average Pooling
        """

        self.model = model
        self.target_layer = target_layer

        # save values of activations and gradients in target_layer
        self.values = SaveValues(self.target_layer)

    def forward(self, x, *inputs):
        """
        Args:
            x: input image. shape => (N, 3, T, H, W)
            *inputs: optional extra model inputs (e.g. tabular features), passed
                     through as model(x, *inputs)
        Return:
            heatmap: class activation map of the predicted class, shape (1, 1, T', H', W').
                     ``idx.item()`` below requires a batch of size 1.
        """

        # object classification
        score = self.model(x, *inputs)
        prob = F.softmax(score, dim=1)
        max_prob, idx = torch.max(prob, dim=1)
        print(
            "predicted action ids {}\t probability {}".format(
                idx.item(), max_prob.item()
            )
        )

        # cam can be calculated from the weights of linear layer and activations
        weight_fc = list(self._get_fc().parameters())[0].detach()
        cam = self.getCAM(self.values, weight_fc, idx.item())

        return cam

    def __call__(self, x, *inputs):
        return self.forward(x, *inputs)

    def _get_fc(self):
        """Return the first submodule of the model named in ``fc_names``."""
        for name in self.fc_names:
            layer = self.model._modules.get(name)
            if layer is not None:
                return layer
        raise AttributeError(
            "model has no final linear layer named any of {}".format(self.fc_names)
        )

    def getCAM(self, values, weight_fc, idx):
        """
        values: the activations and gradients of target_layer
            activations: feature map before GAP.  shape => (N, C, T, H, W)
        weight_fc: the weight of fully connected layer.  shape => (num_classes, C)
        idx: predicted class id
        cam: class activation map.  shape => (1, 1, T, H, W), values in [0, 1];
             all zeros when the map is constant.
        """

        activations = values.activations
        # Follow the activations' device/dtype so a model on GPU does not hit a
        # device mismatch inside conv3d.
        weight = weight_fc.to(device=activations.device, dtype=activations.dtype)
        cam = F.conv3d(activations, weight=weight[:, :, None, None, None])
        _, _, t, h, w = cam.shape

        # class activation mapping only for the predicted class
        # cam is normalized with min-max.
        cam = cam[:, idx, :, :, :]
        cam = cam - torch.min(cam)
        peak = torch.max(cam)
        # A constant map would otherwise be divided 0/0 and become all NaN.
        if peak > 0:
            cam = cam / peak
        cam = cam.view(1, 1, t, h, w)

        return cam.detach()


class GradCAM(CAM):
    """ Grad CAM

    Weights the activations captured at ``target_layer`` by the spatial mean of
    their gradients with respect to one class score, applies ReLU and min-max
    normalizes the result.
    """

    def __init__(self, model, target_layer):
        """
        Args:
            model: a base model to get CAM, which need not have global pooling and fully connected layer.
            target_layer: conv_layer you want to visualize
        """
        super().__init__(model, target_layer)

    def forward(self, x, idx=None, tabular=None):
        """
        Args:
            x: input image. shape =>(1, 3, T, H, W)
            idx: class index to explain; None explains the predicted class
            tabular: optional second model input; when given the model is called
                     as model(x, tabular), otherwise as model(x)
        Return:
            (cam, idx): the Grad-CAM map, shape (1, 1, T', H', W'), and the class id used
        """

        # Grad-CAM needs a backward pass, so autograd is re-enabled even when the
        # caller is inside torch.no_grad().
        with torch.enable_grad():
            output = self.model(x) if tabular is None else self.model(x, tabular)
            # Models returning a (score, ...) tuple are handled by taking the first element.
            score = output[0] if isinstance(output, (tuple, list)) else output

            if idx is None:
                prob = torch.softmax(score, dim=1)
                prob, idx = torch.max(prob, dim=1)
                idx = idx.item()
                prob = prob.item()
                print("predicted class ids {}\t probability {}".format(idx, prob))

            cam = self.getGradCAM(self.values, score, idx)

        return cam, idx

    def __call__(self, x, idx=None, tabular=None):
        return self.forward(x, idx, tabular)

    def getGradCAM(self, values, score, idx):
        """
        values: the activations and gradients of target_layer
            activations: feature map before GAP.  shape => (1, C, T, H, W)
        score: the output of the model before softmax
        idx: predicted class id
        cam: class activation map.  shape=> (1, 1, T, H, W), values in [0, 1];
             all zeros when the map is constant.

        Raises:
            RuntimeError: if the backward pass did not reach target_layer.
        """

        self.model.zero_grad()
        # Clear any gradient left over from a previous call so a missing hook
        # invocation is detected instead of silently reusing stale values.
        values.gradients = None

        score[0, idx].backward(retain_graph=True)

        activations = values.activations
        gradients = values.gradients
        if gradients is None:
            raise RuntimeError("target_layer received no gradient during the backward pass")
        n, c, _, _, _ = gradients.shape
        alpha = gradients.view(n, c, -1).mean(2)
        alpha = alpha.view(n, c, 1, 1, 1)

        # shape => (1, 1, T', H', W')
        cam = (alpha * activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = cam - torch.min(cam)
        peak = torch.max(cam)
        # After ReLU an all-zero map is common; avoid turning it into NaN.
        if peak > 0:
            cam = cam / peak

        return cam.detach()

import copy

import torch
#from loaddata import ADMdataset
from torch.utils.data.dataloader import DataLoader

#from model import ConcatHNN2FC
import numpy as np
from tqdm import *

from sklearn.metrics import roc_curve, auc
from sklearn.metrics import classification_report
from itertools import cycle

from torch.autograd import Variable

import matplotlib.pyplot as plt
import time

from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.metrics import roc_auc_score
#from tool import specificity, ACC

decay = 0.9  # NOTE(review): not used in this cell

train_txt = 'train.txt'
valid_txt = 'val-oasis.txt'

# Choose the device once; previously model.cuda() ran before checking availability.
Use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if Use_gpu else "cpu")

model = ConcatHNN2FC()
model = model.to(device)

traindata = ADMdataset(train_txt)
validdata = ADMdataset(valid_txt)

batch_size = 8
# drop_last on the training loader keeps every training batch full-sized.
train_loader = DataLoader(traindata, batch_size=batch_size, num_workers=0, pin_memory=Use_gpu, shuffle=True, drop_last=True)
# Validation must score every sample, so the last partial batch is kept.
valid_loader = DataLoader(validdata, batch_size=batch_size, num_workers=0, pin_memory=Use_gpu, shuffle=False, drop_last=False)

# With drop_last=True a training set smaller than one batch yields zero batches,
# which otherwise surfaces later as sklearn's "Found array with 0 sample(s)" error.
if len(train_loader) == 0:
    raise ValueError(
        f"training set '{train_txt}' has {len(traindata)} sample(s), fewer than "
        f"batch_size={batch_size}; no training batches would be produced"
    )
if len(valid_loader) == 0:
    raise ValueError(f"validation set '{valid_txt}' is empty")

loss_f = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0023, weight_decay=0.001)
train_data_size = len(traindata)
print(train_data_size)
valid_data_size = len(validdata)

epoch_n = 30
time_open = time.time()
best_acc = 0.0
best_epoch = 0
best_model_wts = copy.deepcopy(model.state_dict())

training_loss = []
val_loss = []
training_accuracy = []
val_accuracy = []
num_class = 3

score_list = []
label_list = []
epoch_record1 = []
epoch_record2 = []

classes = ('AD', 'CN', 'MCI')


def _to_onehot(labels, n_classes):
    """Return a (len(labels), n_classes) float32 one-hot array for integer class labels."""
    return np.eye(n_classes, dtype=np.float32)[np.asarray(labels, dtype=np.int64)]


for epoch in range(epoch_n):
    epoch_start = time.time()

    # Validation ground truth / predictions for this epoch's per-class metrics.
    epoch_record11 = []
    epoch_record12 = []

    model.train()

    train_loss = 0.0
    train_acc = 0.0
    valid_loss = 0.0
    valid_acc = 0.0
    # Averages are taken over samples actually seen (drop_last may skip some).
    n_train_seen = 0
    n_valid_seen = 0
    # Reset each epoch so the training AUC is not polluted by the previous
    # epoch's validation scores.
    score_list = []
    label_list = []

    print('epoch {}/{}'.format(epoch, epoch_n - 1))
    print('-' * 10)
    pbar = tqdm(train_loader, total=len(train_loader), desc=f'epoch {epoch + 1}/{epoch_n}', unit='batch')

    for image, label, tabular in pbar:
        X1 = image.to(device)
        X2 = tabular.to(device)
        Y = label.to(device)
        y_pred = model(X1, X2)
        score_list.extend(y_pred.detach().cpu().numpy())
        label_list.extend(Y.cpu().numpy())
        loss = loss_f(y_pred, Y.long())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_n = X1.size(0)
        train_loss += loss.item() * batch_n
        _, predictions = torch.max(y_pred.detach(), 1)
        train_acc += predictions.eq(Y.view_as(predictions)).sum().item()
        n_train_seen += batch_n
        pbar.set_postfix(acc=train_acc / n_train_seen)

    AUC = roc_auc_score(_to_onehot(label_list, num_class), np.array(score_list), multi_class='ovr')
    print(AUC)
    with torch.no_grad():
        model.eval()

        score_list = []
        label_list = []

        for inputs, labels, tabular in valid_loader:
            inputs1 = inputs.to(device)
            inputs2 = tabular.to(device)
            labels = labels.to(device)
            outputs = model(inputs1, inputs2)
            score_list.extend(outputs.cpu().numpy())
            label_list.extend(labels.cpu().numpy())
            loss = loss_f(outputs, labels.long())
            batch_n = inputs1.size(0)
            valid_loss += loss.item() * batch_n
            _, predictions = torch.max(outputs, 1)
            valid_acc += predictions.eq(labels.view_as(predictions)).sum().item()
            n_valid_seen += batch_n
            epoch_record11.extend(labels.cpu().numpy())
            epoch_record12.extend(predictions.cpu().numpy())
        score_array = np.array(score_list)
        label_onehot = _to_onehot(label_list, num_class)
        AUC = roc_auc_score(label_onehot, score_array, multi_class='ovr')
        print(AUC)

        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        for i in range(len(classes)):
            fpr[i], tpr[i], _ = roc_curve(label_onehot[:, i], score_array[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])
            print(roc_auc[i])
        lw = 2
        colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
        for i, color in zip(range(len(classes)), colors):
            plt.plot(fpr[i], tpr[i], color=color, lw=lw,
                     label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))
        plt.plot([0, 1], [0, 1], 'k--', lw=lw)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('epoch {}/{}'.format(epoch, epoch_n - 1))
        plt.legend(loc='lower right')
        plt.show()
        acc = ACC(epoch_record11, epoch_record12, 3)
        print(acc)
        SPE1 = specificity(epoch_record11, epoch_record12, 3)
        print(SPE1)
        cr = classification_report(epoch_record11, epoch_record12, target_names=classes)
        print(cr)

    avg_train_loss = train_loss / n_train_seen
    avg_train_acc = train_acc / n_train_seen
    training_loss.append(avg_train_loss)
    training_accuracy.append(avg_train_acc)

    avg_valid_loss = valid_loss / n_valid_seen
    avg_valid_acc = valid_acc / n_valid_seen
    val_loss.append(avg_valid_loss)
    val_accuracy.append(avg_valid_acc)

    if best_acc < avg_valid_acc:
        best_acc = avg_valid_acc
        best_epoch = epoch + 1
        best_model_wts = copy.deepcopy(model.state_dict())
        torch.save(model.state_dict(), 'best_model.pt')
    epoch_end = time.time()
    print(
        "Epoch: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \n\t\tValidation: Loss: {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s".format(
            epoch + 1, avg_train_loss, avg_train_acc * 100, avg_valid_loss, avg_valid_acc * 100, epoch_end - epoch_start
        ))
    print("Best Accuracy for validation : {:.4f} at epoch {:03d}".format(best_acc, best_epoch))


def _plot_train_val(val_series, train_series, val_label, train_label, marker, title, ylabel):
    """Draw a validation/training curve pair against the epoch index and show the figure."""
    epochs = np.arange(0, epoch_n)
    plt.plot(epochs, val_series, label=val_label, marker=marker)
    plt.plot(epochs, train_series, label=train_label, marker=marker)
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    plt.legend()
    plt.show()


_plot_train_val(val_loss, training_loss, 'val', 'train', 'o', 'loss per epoch 8', 'loss')
_plot_train_val(val_accuracy, training_accuracy, 'val_acc', 'train_acc', 'x', 'Accuracy per epoch 8', 'accuracy')

import argparse
from collections import OrderedDict
from matplotlib import pyplot as plt
import torch
#from CAM import GradCAM, CAM
#from modal import ConcatHNN1FC
from torch.utils.data.dataloader import DataLoader
import SimpleITK as sitk
#from tool import resize_image,norm_image
import pandas as pd
import os
from torchvision.utils import save_image
from matplotlib import pyplot
import cv2 as cv
import cv2
import numpy as np


def get_arguments(argv=None):
    """
    Parse the visualization command-line options.

    Args:
        argv: list of argument strings to parse; None means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``video_dir`` and ``save_dir``.

    Unrecognized arguments are ignored instead of aborting: a Jupyter/Colab
    kernel is started with extra flags such as ``-f <connection file>``, which
    ``parse_args`` would reject with SystemExit.
    """

    parser = argparse.ArgumentParser(description="visualization")
    parser.add_argument(
        "--video_dir", type=str, default="./videos", help="input directory"
    )
    parser.add_argument(
        "--save_dir", type=str, default="./cams", help="output directory"
    )

    args, _unknown = parser.parse_known_args(argv)
    return args



# image_path = 'C:/Users/medical/Desktop/seg/'




# def load_png(path, table_path):
#     d = 0
#     img_array = np.zeros([64, 64, 64])
#     fileList = os.listdir(path)
#     for file in fileList:
#         if file == '.DS_Store':
#             continue
#         fileName = path + file
#         png_data = cv.imread(fileName)
#         print(png_data.shape)
#         binary_data = cv.cvtColor(png_data, cv.COLOR_BGR2GRAY)
#         print(binary_data.shape)
#         # resize (64, 64)
#         resized_img = cv.resize(binary_data, (64, 64))
#         print(resized_img.shape)
#         img_array[:,:,d] = resized_img
#         d += 1
#     img_vol = torch.from_numpy(img_array)
#     img_vol = norm_image(img_vol)
#     clip = img_vol.unsqueeze(0).unsqueeze(0).float()
#
#
#     dir_name = os.path.dirname(os.path.dirname(os.path.dirname(table_path)))
#     txt_file_path = os.path.join(dir_name, 'tabular_vis.csv')
#
#     df = pd.read_csv(txt_file_path, header=None)
#     dataset = df.values
#     dataset = dataset.astype(float)
#     d = torch.from_numpy(dataset)
#
#     # [1,1,128,128,128]
#     return clip, d


def load_image(path):
    """
    Load a DICOM series as a normalized 64x64x64 volume plus the tabular features.

    Args:
        path: directory holding a DICOM series readable by SimpleITK's GDCM reader.

    Returns:
        clip: float32 tensor of shape (1, 1, 64, 64, 64), rescaled by ``norm_image``.
        d: float64 tensor with the rows of ``tabular_vis.csv`` (read without a header).

    Raises:
        FileNotFoundError: if no DICOM series is found in ``path`` or the CSV is missing.

    The CSV is looked up in the directory obtained by applying ``os.path.dirname``
    three times to ``path`` (a trailing slash on ``path`` counts as one level).
    """
    series_reader = sitk.ImageSeriesReader()
    fileNames = series_reader.GetGDCMSeriesFileNames(path)
    # An empty series otherwise fails inside Execute() with an opaque SimpleITK error.
    if not fileNames:
        raise FileNotFoundError(f"No DICOM series found in {path!r}")
    series_reader.SetFileNames(fileNames)
    images = series_reader.Execute()

    images = resize_image(images, (64, 64, 64), resamplemethod=sitk.sitkLinear)
    img_array = sitk.GetArrayFromImage(images)
    img_vol = norm_image(torch.from_numpy(img_array))
    # Add batch and channel dims for the 3-D network.
    clip = img_vol.unsqueeze(0).unsqueeze(0).float()

    dir_name = os.path.dirname(os.path.dirname(os.path.dirname(path)))
    txt_file_path = os.path.join(dir_name, 'tabular_vis.csv')

    df = pd.read_csv(txt_file_path, header=None)
    d = torch.from_numpy(df.to_numpy(dtype=float))

    return clip, d

image_path = '/content/drive/MyDrive/train'
tabular_path = 'OASIS-1.csv'

args = get_arguments()
model = ConcatHNN1FC()

# Load the checkpoint onto the CPU regardless of the device it was saved from.
state_dict = torch.load("best_model-fusion-0111.pt", map_location="cpu")
model.load_state_dict(state_dict)

target_layer = model.blockX.conv2

# NOTE(review): `CAM` resolves to the last CAM class defined in this notebook; its
# __call__ must accept the extra tabular input passed below.
wrapped_model = CAM(model, target_layer)

model.eval()

with torch.no_grad():
    clip, tabular = load_image(image_path)
    cam = wrapped_model(clip, tabular.to(torch.float32))
    # NOTE(review): `visualize` is not defined in this part of the notebook —
    # confirm it is defined in an earlier cell.
    heatmaps = visualize(clip, cam)

save_path = 'E:/projects/pythonProject9/output10'
# exist_ok so re-running the cell does not fail with FileExistsError.
os.makedirs(save_path, exist_ok=True)
for i in range(clip.shape[2]):
    heatmap = heatmaps[:, :, i].squeeze()
    save_image(heatmap, os.path.join(save_path, "{:0>3}.jpg".format(str(i))))
print("Done")







[1;30;43mStreaming output truncated to the last 5000 lines.[0m
index 573
index 574
index 575
index 576
index 577
index 578
index 579
index 580
index 581
index 582
index 583
index 584
index 585
index 586
index 587
index 588
index 589
index 590
index 591
index 592
index 593
index 594
index 595
index 596
index 597
index 598
index 599
index 600
index 601
index 602
index 603
index 604
index 605
index 606
index 607
index 608
index 609
index 610
index 611
index 612
index 613
index 614
index 615
index 616
index 617
index 618
index 619
index 620
index 621
index 622
index 623
index 624
index 625
index 626
index 627
index 628
index 629
index 630
index 631
index 632
index 633
index 634
index 635
index 636
index 637
index 638
index 639
index 640
index 641
index 642
index 643
index 644
index 645
index 646
index 647
index 648
index 649
index 650
index 651
index 652
index 653
index 654
index 655
index 656
index 657
index 658
index 659
index 660
index 661
index 662
index 663
index 664
index 665
index

epoch 1/30:   0%|          | 0/2 [00:00<?, ?patient/s]


ValueError: Found array with 0 sample(s) (shape=(0, 3)) while a minimum of 1 is required.