### Project Members
* Onur Şero
* Yunus Emre Aydın
* Yunus Emre Türkoğlu
* Ful Belin Korukoğlu


### 1. Initialize

In [None]:
# Mount Google Drive so the project paths configured in the next cell resolve.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# Project folder on the mounted Google Drive; CSV and image folder live under it.
_PROJECT_ROOT = "/content/drive/MyDrive/Colab Notebooks/CMPE58P/Project"

csv_dir = f"{_PROJECT_ROOT}/petct_data.csv"
# NOTE(review): unlike the other paths this file is on the local /content disk,
# not on Drive, so it must be re-uploaded after the Colab runtime is recycled.
xlsx_dir = "/content/petct_data.xlsx"
image_dir = f"{_PROJECT_ROOT}/PETCT_data"


In [None]:
# Runtime dependencies for the cells below (presumably run on Colab, where several
# of these are preinstalled). scikit-learn and torchvision are also imported later
# but are not listed here.
# TODO: pin versions (e.g. `%pip install -q pynrrd==X.Y`) so re-runs are reproducible.
%pip install pynrrd
%pip install torch
%pip install numpy
%pip install SimpleITK
%pip install matplotlib
%pip install pandas
%pip install torchio
%pip install torchmetrics

### 2. ImageDataset

In [None]:
import os
import numpy as np
import torch 
import torchvision.transforms as T
import torchio as tio
import nrrd
import SimpleITK as sitk
import matplotlib.pyplot as plt
from datetime import datetime
import gc
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # GPU when visible, else CPU


class ImageDatasetClinical(torch.utils.data.Dataset):
    """PET/CT dataset returning ``(volume, clinical_features, label)`` per patient.

    Every patient folder ``<data_dir>/<id>/`` must contain ``<id>_CT.nrrd``,
    ``<id>_PET.nrrd`` and ``<id>_SEG.nrrd``; folder names must parse as the
    integer patient ``Number``. All volumes are read, cropped, registered to the
    CT crop and resampled eagerly in ``__init__``; ``__getitem__`` only combines
    the cached tensors.

    Args:
        partition: which modalities form the model input (see ``_compose_image``).
        data_dir: directory containing one folder per patient.
        df: DataFrame with at least ``Number`` and ``node`` (0/1 label) columns.
        split: ``"train"`` (first ``train_fraction`` of the sorted folders) or
            ``"val"`` (the rest).
        clinical_path: clinical spreadsheet; defaults to the notebook-level ``xlsx_dir``.
        train_fraction: share of patient folders assigned to the train split.

    Raises:
        ValueError: for an unknown ``split``.
    """

    def __init__(self, partition, data_dir, df, split, clinical_path=None, train_fraction=0.7):
        self.partition = partition
        self.data_dir = data_dir if data_dir.endswith("/") else data_dir + "/"
        self.df = df
        self.split = split
        self.ct_images = []
        self.pet_images = []
        self.seg_images = []
        self.tabular_data = []
        self.labels = []

        # os.listdir() order is filesystem-dependent; sort so the train/val split
        # contains the same patients on every run and machine.
        list_dir = sorted(os.listdir(self.data_dir))
        n_train = int(len(list_dir) * train_fraction)
        if self.split == "train":
            split_dir = list_dir[:n_train]
            print("Train size: ", len(split_dir))
        elif self.split == "val":
            split_dir = list_dir[n_train:]
            print("Val size: ", len(split_dir))
        else:
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        if clinical_path is None:
            clinical_path = xlsx_dir  # notebook-level config cell
        clinical_table = self.get_clinical_data(clinical_path)
        for patient_dir in split_dir:
            clinical_data = self._clinical_row(clinical_table, patient_dir)
            img_ct, img_pet, img_seg, label = self.normalize_item(patient_dir)
            self.ct_images.append(img_ct)
            self.pet_images.append(img_pet)
            self.seg_images.append(img_seg)
            self.labels.append(label)
            self.tabular_data.append(clinical_data)

    @staticmethod
    def _clinical_row(clinical_table, patient_dir):
        """Return the feature row (ID column 0 dropped) whose ID equals ``int(patient_dir)``.

        Raises:
            KeyError: if the patient has no row in the clinical table.
        """
        matches = (clinical_table[:, 0] == int(patient_dir)).nonzero(as_tuple=True)[0]
        if len(matches) == 0:
            raise KeyError(f"No clinical data for patient {patient_dir!r}")
        return clinical_table[matches[0], 1:]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img = self._compose_image(self.ct_images[idx], self.pet_images[idx], self.seg_images[idx])
        # (dim0, dim1, z) -> (1, z, dim0, dim1): add a leading channel axis for Conv3d.
        img = img.permute(2, 0, 1).unsqueeze(0)
        return img, self.tabular_data[idx], self.labels[idx]

    def _compose_image(self, img_ct, img_pet, img_seg):
        """Combine the cached modalities according to ``self.split`` / ``self.partition``.

        Training may add the segmentation mask and random augmentation;
        validation never uses the mask (mask-based partitions fall back to
        their plain modality).

        Raises:
            ValueError: for a split/partition combination with no defined input
                (including ``"seg"`` on the validation split).
        """
        partition = self.partition
        if self.split == "train":
            if partition == "ct":
                return self.augmentation(img_ct)
            if partition == "pet":
                return self.augmentation(img_pet)
            if partition == "seg":
                return img_seg
            if partition == "ct_seg":
                return img_ct + img_seg
            if partition == "ct_mult_seg":
                # Attenuate CT outside the mask: factor 0.1 where seg == 0 and
                # ~0.93 where seg == 1 (assuming a 0/1 mask).
                return img_ct * (0.1 + img_seg / 1.2)
            if partition == "pet_seg":
                return img_pet + img_seg
            if partition == "pet_ct":
                return self.augmentation(img_pet + img_ct)
            if partition == "mixed":
                return img_pet + img_ct + img_seg
        elif self.split == "val":
            if partition in ("ct", "ct_seg", "ct_mult_seg"):
                return img_ct
            if partition in ("pet", "pet_seg"):
                return img_pet
            if partition in ("mixed", "pet_ct"):
                return img_pet + img_ct
        raise ValueError(f"Unsupported partition {partition!r} for split {self.split!r}")

    def normalize_item(self, path):
        """Load one patient's volumes and its binary ``node`` label.

        Returns:
            ``(img_ct, img_pet, img_seg, label)`` where ``label`` is ``np.ones(1)``
            for node == 1 and ``np.zeros(1)`` for node == 0.

        Raises:
            ValueError: if the ``node`` value is neither 0 nor 1 (e.g. missing).
        """
        print("Read image for ", path, ", Partition: ", self.partition)
        img_ct, img_pet, img_seg = self.read_image(path)

        node = self.df.loc[self.df['Number'] == int(path), 'node'].values[0]
        if node == 1:
            label = np.ones(1)
        elif node == 0:
            label = np.zeros(1)
        else:
            raise ValueError(f"Unexpected 'node' value {node!r} for patient {path}")

        return img_ct, img_pet, img_seg, label

    @staticmethod
    def _minmax(volume):
        """Min-max scale to [0, 1] as float32; a constant volume maps to zeros (no 0/0)."""
        lo, hi = np.min(volume), np.max(volume)
        if hi == lo:
            return np.zeros(volume.shape, dtype='float32')
        return ((volume - lo) / (hi - lo)).astype('float32')

    @staticmethod
    def _center_crop(volume, z_start, z_end):
        """Keep the central 40-60% of axes 0 and 1 and slices [z_start, z_end) of axis 2.

        Returns a copy so the crop does not keep the full-size array alive as a view.
        """
        rows, cols = volume.shape[0], volume.shape[1]
        return volume[int(rows * 0.4):int(rows * 0.6),
                      int(cols * 0.4):int(cols * 0.6),
                      z_start:z_end].copy()

    def read_image(self, path):
        """Load, crop, register and resample one patient's CT, PET and SEG volumes.

        CT and PET are min-max scaled to [0, 1]. All three are center-cropped on
        axes 0/1 and cut to a slab along axis 2 centred on the segmentation
        (+/- 10% of the axis length). PET and SEG are then registered onto the CT
        crop and axis 2 is resampled to 128 with ``interpolate``.

        Raises:
            ValueError: if the segmentation is empty (no slab can be located).
        """
        prefix = f"{self.data_dir}{path}/{path}"
        img_ct_raw = self._minmax(nrrd.read(prefix + "_CT.nrrd")[0])
        img_pet_raw = self._minmax(nrrd.read(prefix + "_PET.nrrd")[0])
        img_seg_raw = nrrd.read(prefix + "_SEG.nrrd")[0].astype('float32')

        depth = img_seg_raw.shape[2]
        z_mask = np.nonzero(img_seg_raw)[2]
        if z_mask.size == 0:
            raise ValueError(f"Segmentation for patient {path} is empty; cannot locate the slab")
        z_avg = ((z_mask.min() + z_mask.max()) / 2) / depth
        # Slices are end-exclusive, so the end is clamped to `depth` (not depth - 1)
        # to keep the last slice reachable.
        z_start = max(int((z_avg - 0.1) * depth), 0)
        z_end = min(int((z_avg + 0.1) * depth), depth)

        img_ct = self._center_crop(img_ct_raw, z_start, z_end)
        img_pet = self._center_crop(img_pet_raw, z_start, z_end)
        img_seg = self._center_crop(img_seg_raw, z_start, z_end)
        print("Shape: ", img_ct.shape)
        # The crops are copies, so releasing the full-size volumes here actually frees them.
        del img_ct_raw, img_pet_raw, img_seg_raw
        gc.collect()

        # Register PET and SEG onto the CT crop.
        img_pet = self.registration(img_ct, img_pet)
        img_seg = self.registration(img_ct, img_seg)

        # The tensors are 3-D, so interpolate treats them as (N, C, L) and only
        # resizes the last axis (z) to 128.
        img_ct, img_pet, img_seg = (
            torch.nn.functional.interpolate(torch.tensor(volume), 128)
            for volume in (img_ct, img_pet, img_seg)
        )
        return img_ct, img_pet, img_seg

    def registration(self, img1, img2, seed=42):
        """Rigidly register ``img2`` (moving) onto ``img1`` (fixed) and resample it onto img1's grid.

        Uses an Euler 3-D transform initialised by geometry, a correlation metric
        on a random 1% voxel sample, and 10 gradient-descent iterations.

        Args:
            img1: fixed volume (numpy array).
            img2: moving volume (numpy array).
            seed: seed for the random metric sampling so results are reproducible;
                keep it non-zero (SimpleITK's wall-clock sentinel is 0). Pass
                ``None`` to use SimpleITK's default wall-clock seeding.

        Returns:
            The resampled moving volume as a numpy array with ``img1``'s shape.
        """
        fixed = sitk.GetImageFromArray(img1)
        moving = sitk.GetImageFromArray(img2)

        initial_transform = sitk.CenteredTransformInitializer(
            fixed,
            moving,
            sitk.Euler3DTransform(),
            sitk.CenteredTransformInitializerFilter.GEOMETRY,
        )

        method = sitk.ImageRegistrationMethod()
        method.SetMetricAsCorrelation()
        method.SetMetricSamplingStrategy(method.RANDOM)
        if seed is None:
            method.SetMetricSamplingPercentage(0.01)
        else:
            method.SetMetricSamplingPercentage(0.01, seed)
        method.SetInterpolator(sitk.sitkLinear)
        method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=10)
        method.SetOptimizerScalesFromPhysicalShift()
        # Not in place, so initial_transform is left untouched.
        method.SetInitialTransform(initial_transform, inPlace=False)

        final_transform = method.Execute(fixed, moving)
        moving_resampled = sitk.Resample(
            moving,
            fixed,
            final_transform,
            sitk.sitkLinear,
            0.0,
            moving.GetPixelID(),
        )
        return sitk.GetArrayFromImage(moving_resampled)

    def augmentation(self, data):
        """Randomly apply torchio bias-field / ghosting / spike / rotation artefacts.

        With probability 0.8 the transform list is applied, and each transform then
        fires with its own p=0.25. The composed transform is built once per instance.
        """
        transform = getattr(self, "_augment_transform", None)
        if transform is None:
            transform = T.Compose([
                T.RandomApply([
                    tio.transforms.RandomBiasField(p=0.25),
                    tio.transforms.RandomGhosting(p=0.25),
                    tio.transforms.RandomSpike(p=0.25),
                    tio.transforms.RandomAffine(degrees=10, scales=0., translation=0., p=0.25)],
                    p=0.8)
            ])
            self._augment_transform = transform
        # NOTE(review): unsqueeze(-1) gives a 4-D (dim0, dim1, z, 1) tensor, which
        # torchio reads as (C, W, H, D), i.e. a 2-D image with dim0 channels.
        # Confirm this is intended rather than unsqueeze(0) -> (1, dim0, dim1, z).
        return transform(data.unsqueeze(-1)).squeeze(-1)

    def get_clinical_data(self, path):
        """Load the clinical spreadsheet as a float tensor ``[Number, features...]``.

        Column 0 is the patient ``Number`` (used by ``__init__`` for the lookup and
        then dropped). It is followed by the 5 selected feature columns,
        KNN-imputed (k=5) without the ID. As an arbitrary integer, the ID would
        otherwise dominate the neighbour distances.
        """
        import pandas as pd
        from sklearn.impute import KNNImputer

        clinical_data = pd.read_excel(path).drop(columns=['node'])
        # NOTE(review): 'SUV_prostate' is listed twice, so it is fed to the model
        # twice. It is kept to preserve the feature count; possibly a typo for
        # another column, so confirm.
        feature_colnames = ['metastasis', 'T_yale', 'SUV_prostate', 'psa', 'SUV_prostate']
        patient_ids = clinical_data['Number'].to_numpy(dtype='float64').reshape(-1, 1)
        features = clinical_data.loc[:, feature_colnames].to_numpy()
        features_imputed = KNNImputer(n_neighbors=5).fit_transform(features)

        return torch.Tensor(np.concatenate([patient_ids, features_imputed], axis=1))

In [None]:
import os
import numpy as np
import torch 
import torchvision.transforms as T
import torchio as tio
import nrrd
import SimpleITK as sitk
import matplotlib.pyplot as plt
from datetime import datetime
import gc
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # GPU when visible, else CPU


class ImageDataset(torch.utils.data.Dataset):
    """PET/CT dataset returning ``(volume, tabular_features, label)`` per patient.

    Same image pipeline as ``ImageDatasetClinical``, but the tabular features come
    from ``get_tabular_data``: every numeric column, imputed, plus one-hot
    encoded non-numeric columns.

    Every patient folder ``<data_dir>/<id>/`` must contain ``<id>_CT.nrrd``,
    ``<id>_PET.nrrd`` and ``<id>_SEG.nrrd``; folder names must parse as the
    integer patient ``Number``. All volumes are processed eagerly in ``__init__``.

    Args:
        partition: which modalities form the model input (see ``_compose_image``).
        data_dir: directory containing one folder per patient.
        df: DataFrame with at least ``Number`` and ``node`` (0/1 label) columns.
        split: ``"train"`` (first ``train_fraction`` of the sorted folders) or
            ``"val"`` (the rest).
        clinical_path: clinical spreadsheet; defaults to the notebook-level ``xlsx_dir``.
        train_fraction: share of patient folders assigned to the train split.

    Raises:
        ValueError: for an unknown ``split``.
    """

    def __init__(self, partition, data_dir, df, split, clinical_path=None, train_fraction=0.7):
        self.partition = partition
        self.data_dir = data_dir if data_dir.endswith("/") else data_dir + "/"
        self.df = df
        self.split = split
        self.ct_images = []
        self.pet_images = []
        self.seg_images = []
        self.tabular_data = []
        self.labels = []

        # os.listdir() order is filesystem-dependent; sort so the train/val split
        # contains the same patients on every run and machine.
        list_dir = sorted(os.listdir(self.data_dir))
        n_train = int(len(list_dir) * train_fraction)
        if self.split == "train":
            split_dir = list_dir[:n_train]
            print("Train size: ", len(split_dir))
        elif self.split == "val":
            split_dir = list_dir[n_train:]
            print("Val size: ", len(split_dir))
        else:
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        if clinical_path is None:
            clinical_path = xlsx_dir  # notebook-level config cell
        clinical_table = self.get_tabular_data(clinical_path)
        for patient_dir in split_dir:
            clinical_data = self._clinical_row(clinical_table, patient_dir)
            img_ct, img_pet, img_seg, label = self.normalize_item(patient_dir)
            self.ct_images.append(img_ct)
            self.pet_images.append(img_pet)
            self.seg_images.append(img_seg)
            self.labels.append(label)
            self.tabular_data.append(clinical_data)

    @staticmethod
    def _clinical_row(clinical_table, patient_dir):
        """Return the feature row (ID column 0 dropped) whose ID equals ``int(patient_dir)``.

        Raises:
            KeyError: if the patient has no row in the clinical table.
        """
        matches = (clinical_table[:, 0] == int(patient_dir)).nonzero(as_tuple=True)[0]
        if len(matches) == 0:
            raise KeyError(f"No clinical data for patient {patient_dir!r}")
        return clinical_table[matches[0], 1:]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img = self._compose_image(self.ct_images[idx], self.pet_images[idx], self.seg_images[idx])
        # (dim0, dim1, z) -> (1, z, dim0, dim1): add a leading channel axis for Conv3d.
        img = img.permute(2, 0, 1).unsqueeze(0)
        return img, self.tabular_data[idx], self.labels[idx]

    def _compose_image(self, img_ct, img_pet, img_seg):
        """Combine the cached modalities according to ``self.split`` / ``self.partition``.

        Training may add the segmentation mask and random augmentation;
        validation never uses the mask (mask-based partitions fall back to
        their plain modality).

        Raises:
            ValueError: for a split/partition combination with no defined input
                (including ``"seg"`` on the validation split).
        """
        partition = self.partition
        if self.split == "train":
            if partition == "ct":
                return self.augmentation(img_ct)
            if partition == "pet":
                return self.augmentation(img_pet)
            if partition == "seg":
                return img_seg
            if partition == "ct_seg":
                return img_ct + img_seg
            if partition == "ct_mult_seg":
                # Attenuate CT outside the mask: factor 0.1 where seg == 0 and
                # ~0.93 where seg == 1 (assuming a 0/1 mask).
                return img_ct * (0.1 + img_seg / 1.2)
            if partition == "pet_seg":
                return img_pet + img_seg
            if partition == "pet_ct":
                return self.augmentation(img_pet + img_ct)
            if partition == "mixed":
                return img_pet + img_ct + img_seg
        elif self.split == "val":
            if partition in ("ct", "ct_seg", "ct_mult_seg"):
                return img_ct
            if partition in ("pet", "pet_seg"):
                return img_pet
            if partition in ("mixed", "pet_ct"):
                return img_pet + img_ct
        raise ValueError(f"Unsupported partition {partition!r} for split {self.split!r}")

    def normalize_item(self, path):
        """Load one patient's volumes and its binary ``node`` label.

        Returns:
            ``(img_ct, img_pet, img_seg, label)`` where ``label`` is ``np.ones(1)``
            for node == 1 and ``np.zeros(1)`` for node == 0.

        Raises:
            ValueError: if the ``node`` value is neither 0 nor 1 (e.g. missing).
        """
        print("Read image for ", path, ", Partition: ", self.partition)
        img_ct, img_pet, img_seg = self.read_image(path)

        node = self.df.loc[self.df['Number'] == int(path), 'node'].values[0]
        if node == 1:
            label = np.ones(1)
        elif node == 0:
            label = np.zeros(1)
        else:
            raise ValueError(f"Unexpected 'node' value {node!r} for patient {path}")

        return img_ct, img_pet, img_seg, label

    @staticmethod
    def _minmax(volume):
        """Min-max scale to [0, 1] as float32; a constant volume maps to zeros (no 0/0)."""
        lo, hi = np.min(volume), np.max(volume)
        if hi == lo:
            return np.zeros(volume.shape, dtype='float32')
        return ((volume - lo) / (hi - lo)).astype('float32')

    @staticmethod
    def _center_crop(volume, z_start, z_end):
        """Keep the central 40-60% of axes 0 and 1 and slices [z_start, z_end) of axis 2.

        Returns a copy so the crop does not keep the full-size array alive as a view.
        """
        rows, cols = volume.shape[0], volume.shape[1]
        return volume[int(rows * 0.4):int(rows * 0.6),
                      int(cols * 0.4):int(cols * 0.6),
                      z_start:z_end].copy()

    def read_image(self, path):
        """Load, crop, register and resample one patient's CT, PET and SEG volumes.

        CT and PET are min-max scaled to [0, 1]. All three are center-cropped on
        axes 0/1 and cut to a slab along axis 2 centred on the segmentation
        (+/- 10% of the axis length). PET and SEG are then registered onto the CT
        crop and axis 2 is resampled to 128 with ``interpolate``.

        Raises:
            ValueError: if the segmentation is empty (no slab can be located).
        """
        prefix = f"{self.data_dir}{path}/{path}"
        img_ct_raw = self._minmax(nrrd.read(prefix + "_CT.nrrd")[0])
        img_pet_raw = self._minmax(nrrd.read(prefix + "_PET.nrrd")[0])
        img_seg_raw = nrrd.read(prefix + "_SEG.nrrd")[0].astype('float32')

        depth = img_seg_raw.shape[2]
        z_mask = np.nonzero(img_seg_raw)[2]
        if z_mask.size == 0:
            raise ValueError(f"Segmentation for patient {path} is empty; cannot locate the slab")
        z_avg = ((z_mask.min() + z_mask.max()) / 2) / depth
        # Slices are end-exclusive, so the end is clamped to `depth` (not depth - 1)
        # to keep the last slice reachable.
        z_start = max(int((z_avg - 0.1) * depth), 0)
        z_end = min(int((z_avg + 0.1) * depth), depth)

        img_ct = self._center_crop(img_ct_raw, z_start, z_end)
        img_pet = self._center_crop(img_pet_raw, z_start, z_end)
        img_seg = self._center_crop(img_seg_raw, z_start, z_end)
        print("Shape: ", img_ct.shape)
        # The crops are copies, so releasing the full-size volumes here actually frees them.
        del img_ct_raw, img_pet_raw, img_seg_raw
        gc.collect()

        # Register PET and SEG onto the CT crop.
        img_pet = self.registration(img_ct, img_pet)
        img_seg = self.registration(img_ct, img_seg)

        # The tensors are 3-D, so interpolate treats them as (N, C, L) and only
        # resizes the last axis (z) to 128.
        img_ct, img_pet, img_seg = (
            torch.nn.functional.interpolate(torch.tensor(volume), 128)
            for volume in (img_ct, img_pet, img_seg)
        )
        return img_ct, img_pet, img_seg

    def registration(self, img1, img2, seed=42):
        """Rigidly register ``img2`` (moving) onto ``img1`` (fixed) and resample it onto img1's grid.

        Uses an Euler 3-D transform initialised by geometry, a correlation metric
        on a random 1% voxel sample, and 10 gradient-descent iterations.

        Args:
            img1: fixed volume (numpy array).
            img2: moving volume (numpy array).
            seed: seed for the random metric sampling so results are reproducible;
                keep it non-zero (SimpleITK's wall-clock sentinel is 0). Pass
                ``None`` to use SimpleITK's default wall-clock seeding.

        Returns:
            The resampled moving volume as a numpy array with ``img1``'s shape.
        """
        fixed = sitk.GetImageFromArray(img1)
        moving = sitk.GetImageFromArray(img2)

        initial_transform = sitk.CenteredTransformInitializer(
            fixed,
            moving,
            sitk.Euler3DTransform(),
            sitk.CenteredTransformInitializerFilter.GEOMETRY,
        )

        method = sitk.ImageRegistrationMethod()
        method.SetMetricAsCorrelation()
        method.SetMetricSamplingStrategy(method.RANDOM)
        if seed is None:
            method.SetMetricSamplingPercentage(0.01)
        else:
            method.SetMetricSamplingPercentage(0.01, seed)
        method.SetInterpolator(sitk.sitkLinear)
        method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=10)
        method.SetOptimizerScalesFromPhysicalShift()
        # Not in place, so initial_transform is left untouched.
        method.SetInitialTransform(initial_transform, inPlace=False)

        final_transform = method.Execute(fixed, moving)
        moving_resampled = sitk.Resample(
            moving,
            fixed,
            final_transform,
            sitk.sitkLinear,
            0.0,
            moving.GetPixelID(),
        )
        return sitk.GetArrayFromImage(moving_resampled)

    def augmentation(self, data):
        """Randomly apply torchio bias-field / ghosting / spike / rotation artefacts.

        With probability 0.8 the transform list is applied, and each transform then
        fires with its own p=0.25. The composed transform is built once per instance.
        """
        transform = getattr(self, "_augment_transform", None)
        if transform is None:
            transform = T.Compose([
                T.RandomApply([
                    tio.transforms.RandomBiasField(p=0.25),
                    tio.transforms.RandomGhosting(p=0.25),
                    tio.transforms.RandomSpike(p=0.25),
                    tio.transforms.RandomAffine(degrees=10, scales=0., translation=0., p=0.25)],
                    p=0.8)
            ])
            self._augment_transform = transform
        # NOTE(review): unsqueeze(-1) gives a 4-D (dim0, dim1, z, 1) tensor, which
        # torchio reads as (C, W, H, D), i.e. a 2-D image with dim0 channels.
        # Confirm this is intended rather than unsqueeze(0) -> (1, dim0, dim1, z).
        return transform(data.unsqueeze(-1)).squeeze(-1)

    @staticmethod
    def _make_onehot_encoder():
        """Dense-output OneHotEncoder that ignores unknown categories, across sklearn versions.

        ``sparse`` was renamed ``sparse_output`` in scikit-learn 1.2 and later removed.
        """
        from sklearn.preprocessing import OneHotEncoder
        try:
            return OneHotEncoder(sparse_output=False, handle_unknown='ignore')
        except TypeError:  # scikit-learn < 1.2
            return OneHotEncoder(sparse=False, handle_unknown='ignore')

    def get_tabular_data(self, path):
        """Load the clinical spreadsheet as a float tensor ``[Number, numeric..., one-hot...]``.

        Column 0 is always the patient ``Number`` (used by ``__init__`` for the
        lookup and then dropped). The other numeric columns are KNN-imputed (k=5)
        without the ID. As an arbitrary integer, the ID would otherwise dominate
        the neighbour distances. Non-numeric columns are one-hot encoded. The
        ``node`` label column is excluded.
        """
        import pandas as pd
        from sklearn.impute import KNNImputer

        clinical_data = pd.read_excel(path).drop(columns=['node'])
        patient_ids = clinical_data['Number'].to_numpy(dtype='float64').reshape(-1, 1)
        features = clinical_data.drop(columns=['Number'])
        parts = [patient_ids]

        numerical_data = features.select_dtypes(include='number').to_numpy()
        if numerical_data.shape[1] > 0:
            parts.append(KNNImputer(n_neighbors=5).fit_transform(numerical_data))

        categorical_data = features.select_dtypes(exclude='number').to_numpy().astype(str)
        if categorical_data.shape[1] > 0:
            # OneHotEncoder rejects zero-column input, so only encode when present.
            parts.append(self._make_onehot_encoder().fit_transform(categorical_data))

        return torch.Tensor(np.concatenate(parts, axis=1))

    def get_clinical_data(self, path):
        """Legacy clinical loader; ``__init__`` of this class uses ``get_tabular_data`` instead.

        Returns the KNN-imputed (k=5) selected columns followed by the one-hot
        encoding of the same columns (as strings).
        NOTE(review): the one-hot block also encodes the patient ID and the raw
        numeric values, which is almost certainly unintended. Confirm before using.
        """
        import pandas as pd
        from sklearn.impute import KNNImputer

        clinical_data = pd.read_excel(path).drop(columns=['node'])
        colnames = ['Number', 'metastasis', 'T_yale', 'SUV_prostate', 'psa', 'SUV_prostate']
        clinical_columns = clinical_data.loc[:, colnames]

        numerical_imputed = KNNImputer(n_neighbors=5).fit_transform(clinical_columns.to_numpy())
        categorical_encoded = self._make_onehot_encoder().fit_transform(
            clinical_columns.values.astype(str))

        return torch.Tensor(np.concatenate([numerical_imputed, categorical_encoded], axis=1))

### 3. Network

In [None]:

import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

#reference: https://github.com/kenshohara/3D-ResNets-PyTorch/blob/master/models/resnet.py#L31


def get_inplanes():
    """Base channel width of each of the four ResNet stages (doubling per stage)."""
    return [64 * 2 ** stage for stage in range(4)]


def conv3x3x3(in_planes, out_planes, stride=1):
    """3x3x3 convolution with padding 1 (keeps size at stride 1) and no bias."""
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv3d(in_planes, out_planes, **conv_kwargs)


def conv1x1x1(in_planes, out_planes, stride=1):
    """Pointwise (1x1x1) convolution without bias, used to change channel count."""
    conv_kwargs = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv3d(in_planes, out_planes, **conv_kwargs)


class BasicBlock(nn.Module):
    """Two-convolution residual block (ResNet-10/18/34); output has ``planes`` channels."""
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        # Main path: only the first 3x3x3 convolution may stride.
        self.conv1 = conv3x3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        # Optional projection of the identity branch when shape/channels change.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """1x1x1 reduce -> 3x3x3 (possibly strided) -> 1x1x1 expand residual block.

    Output has ``planes * expansion`` channels (ResNet-50 and deeper).
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        out_planes = planes * self.expansion
        self.conv1 = conv1x1x1(in_planes, planes)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = conv3x3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = conv1x1x1(planes, out_planes)
        self.bn3 = nn.BatchNorm3d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        # Optional projection of the identity branch when shape/channels change.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)


class ResNet(nn.Module):
    """3-D ResNet producing a sigmoid output per class (binary classifier by default).

    Adapted from kenshohara/3D-ResNets-PyTorch (reference link above).

    Args:
        block: residual block class exposing ``expansion`` (``BasicBlock``/``Bottleneck``).
        layers: number of blocks in each of the four stages.
        block_inplanes: base channel width of each stage.
        drop_rate: dropout probability applied to the pooled features.
        n_input_channels: channels of the input volume.
        conv1_t_size: stem kernel size along the depth (first spatial) axis; the others use 7.
        conv1_t_stride: stem stride along the depth axis; the others use 2.
        no_max_pool: skip the stem max-pool.
        shortcut_type: ``'A'`` = parameter-free zero-padded shortcut,
            anything else = 1x1x1 conv + BatchNorm projection.
        widen_factor: multiplier applied to every stage width.
        n_classes: output units of the final linear layer.
    """

    def __init__(self,
                 block,
                 layers,
                 block_inplanes,
                 drop_rate,
                 n_input_channels=1,
                 conv1_t_size=7,
                 conv1_t_stride=1,
                 no_max_pool=False,
                 shortcut_type='B',
                 widen_factor=1.0,
                 n_classes=1):
        super().__init__()
        self.drop_rate = drop_rate
        self.sigmoid = nn.Sigmoid()
        block_inplanes = [int(x * widen_factor) for x in block_inplanes]

        self.in_planes = block_inplanes[0]
        self.no_max_pool = no_max_pool

        self.conv1 = nn.Conv3d(n_input_channels,
                               self.in_planes,
                               kernel_size=(conv1_t_size, 7, 7),
                               stride=(conv1_t_stride, 2, 2),
                               padding=(conv1_t_size // 2, 3, 3),
                               bias=False)
        self.bn1 = nn.BatchNorm3d(self.in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, block_inplanes[0], layers[0],
                                       shortcut_type)
        self.layer2 = self._make_layer(block,
                                       block_inplanes[1],
                                       layers[1],
                                       shortcut_type,
                                       stride=2)
        self.layer3 = self._make_layer(block,
                                       block_inplanes[2],
                                       layers[2],
                                       shortcut_type,
                                       stride=2)
        self.layer4 = self._make_layer(block,
                                       block_inplanes[3],
                                       layers[3],
                                       shortcut_type,
                                       stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(block_inplanes[3] * block.expansion, n_classes)
        self.dropout = torch.nn.Dropout(drop_rate)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _downsample_basic_block(self, x, planes, stride):
        """Type-'A' shortcut: subsample spatially, then zero-pad channels up to ``planes``.

        The padding is created with ``x``'s dtype and device (CPU, CUDA, half
        precision, ...), and the subsampled tensor itself, not ``.data``, is
        concatenated so gradients flow back through the shortcut.
        """
        out = F.avg_pool3d(x, kernel_size=1, stride=stride)
        zero_pads = out.new_zeros(out.size(0), planes - out.size(1), out.size(2),
                                  out.size(3), out.size(4))
        return torch.cat([out, zero_pads], dim=1)

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        """Build one stage of ``blocks`` residual blocks; only the first may stride/project."""
        downsample = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(self._downsample_basic_block,
                                     planes=planes * block.expansion,
                                     stride=stride)
            else:
                downsample = nn.Sequential(
                    conv1x1x1(self.in_planes, planes * block.expansion, stride),
                    nn.BatchNorm3d(planes * block.expansion))

        layers = []
        layers.append(
            block(in_planes=self.in_planes,
                  planes=planes,
                  stride=stride,
                  downsample=downsample))
        self.in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_planes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an ``(N, C, D, H, W)`` volume batch to ``(N, n_classes)`` sigmoid outputs."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if not self.no_max_pool:
            x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = self.dropout(x)

        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return torch.sigmoid(x)


def generate_model(model_depth, drop_rate, **kwargs):
    """Build a 3-D ResNet of the given depth; extra kwargs are forwarded to ``ResNet``."""
    assert model_depth in [10, 18, 34, 50, 101, 152, 200]

    # depth -> (residual block type, blocks per stage)
    architectures = {
        10: (BasicBlock, [1, 1, 1, 1]),
        18: (BasicBlock, [2, 2, 2, 2]),
        34: (BasicBlock, [3, 4, 6, 3]),
        50: (Bottleneck, [3, 4, 6, 3]),
        101: (Bottleneck, [3, 4, 23, 3]),
        152: (Bottleneck, [3, 8, 36, 3]),
        200: (Bottleneck, [3, 24, 36, 3]),
    }
    block, layers = architectures[model_depth]
    return ResNet(block, layers, get_inplanes(), drop_rate, **kwargs)

In [None]:
import torch
import torch.nn as nn
from torchvision.models import resnet18
import pandas as pd
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import KNNImputer

import numpy as np


class MultimodalFusionModel(nn.Module):
    """Late-fusion binary classifier: 3-D ResNet image embedding + clinical-feature MLP.

    The pooled ResNet features (the backbone's own ``fc``/sigmoid head is bypassed)
    are concatenated with a 64-d embedding of the clinical vector. The result
    goes through a 128-unit layer and a single sigmoid output.

    Args:
        n_clinical_features: width of the clinical/tabular input vector.
        model_depth: ResNet depth passed to ``generate_model``.
        drop_rate: dropout rate of the ResNet backbone.
    """

    def __init__(self, n_clinical_features=20, model_depth=50, drop_rate=0.2):
        super().__init__()

        # Part 1: 3-D ResNet backbone for image feature extraction.
        self.resnet = generate_model(model_depth, drop_rate)
        # Width of the pooled image embedding (2048 for ResNet-50: 512 * Bottleneck.expansion).
        image_features = self.resnet.fc.in_features

        # Part 2: clinical-feature embedding and fusion layers.
        self.fc_clinical = nn.Sequential(
            nn.Linear(n_clinical_features, 64),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        self.fc_concat = nn.Sequential(
            nn.Linear(image_features + 64, 128),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # Part 3: final layer producing the prediction probability.
        self.fc_final = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_image, x_clinical):
        """Return ``(N, 1)`` probabilities from a ``(N, C, D, H, W)`` image batch and ``(N, F)`` clinical batch."""
        resnet = self.resnet
        # Part 1: run the backbone up to (not including) its own classification head.
        x_image = resnet.relu(resnet.bn1(resnet.conv1(x_image)))
        if not resnet.no_max_pool:
            x_image = resnet.maxpool(x_image)
        for stage in (resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4):
            x_image = stage(x_image)
        x_image = torch.flatten(resnet.dropout(resnet.avgpool(x_image)), 1)

        # Part 2: embed clinical data and fuse with the image features.
        x_clinical = self.fc_clinical(x_clinical)
        x_fusion = self.fc_concat(torch.cat((x_image, x_clinical), dim=1))

        # Part 3: final prediction.
        return self.sigmoid(self.fc_final(x_fusion))

### 4. AverageMeter

In [None]:
class AverageMeter:
    """Computes and stores the weighted average and the most recent value.

    Attributes:
        val:   last value passed to ``update``.
        sum:   list of ``value * size`` for every update (history).
        size:  list of the ``size`` passed to every update.
        avg:   size-weighted mean of all values so far (0 until a non-empty update).
        count: total size seen so far.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget all recorded updates."""
        self.val = 0
        self.sum = []
        self.size = []
        self.avg = 0
        self.count = 0
        self._total = 0  # running sum of value * size

    def update(self, value, size):
        """Record ``value`` as the mean over ``size`` samples and refresh ``avg``/``count``.

        Running totals keep this O(1) per call instead of re-summing the whole
        history. A zero-size update is recorded but leaves ``avg`` unchanged
        rather than dividing by zero.
        """
        self.val = value
        weighted = value * size
        self.sum.append(weighted)
        self.size.append(size)
        self._total += weighted
        self.count += size
        if self.count:
            self.avg = self._total / self.count

In [None]:
# Sanity check: two equally sized batches with means 100 and 50 should average to 75.
avg_meter = AverageMeter()
for value, size in ((100, 5), (50, 5)):
    avg_meter.update(value, size)

print(avg_meter.avg, avg_meter.count)

75.0 10


### 5. Training

In [None]:
# Fall back to CPU when no GPU is visible; a hard-coded 'cuda' device makes
# every `.to(device)` in the training loop fail on CPU-only runtimes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from torchmetrics import MetricCollection, Accuracy, Precision, Recall, F1Score
from torchmetrics.classification import BinaryAUROC 

def training(train_loader, model, criterion, optimizer):
  """Train the fusion ``model`` for one epoch and return epoch-level metrics.

  Args:
    train_loader: iterable of ``(inputs, tabular_data, labels)`` batches.
    model: network called as ``model(inputs, tabular_data)``.
    criterion: loss applied as ``criterion(output.float(), labels.float())``.
    optimizer: optimizer over ``model``'s parameters, stepped once per batch.

  Returns:
    dict with 'loss' and 'acc' (sample-weighted averages over the epoch) and
    'f1', 'prec', 'rec', 'auc' (torchmetrics computed over the whole epoch).
  """
  # validation() leaves the model in eval mode; switch back so dropout and
  # BatchNorm statistics behave as in training.
  model.train()

  avg_meters = {'loss': AverageMeter(),
                'acc': AverageMeter()}

  metric_collection = MetricCollection({ #reference:https://pub.towardsai.net/improve-your-model-validation-with-torchmetrics-b457d3954dcd
    'prec': Precision(task='binary', average='macro').to(device),
    'rec': Recall(task='binary', average='macro').to(device),
    'f1': F1Score(task='binary').to(device),
    'auc': BinaryAUROC().to(device)})

  for inputs, tabular_data, labels in train_loader:
    inputs, tabular_data, labels = inputs.to(device), tabular_data.to(device), labels.to(device)

    optimizer.zero_grad()
    output = model(inputs, tabular_data)

    loss = criterion(output.float(), labels.float())

    # Weight the running averages by the number of samples in the batch
    # (len() of the batch tuple is always 3, not the batch size).
    batch_size = labels.size(0)
    accuracy = (output.round() == labels).sum().item() / batch_size
    avg_meters['loss'].update(loss.item(), batch_size)
    avg_meters['acc'].update(accuracy, batch_size)

    # Metrics only read the predictions; detach so stored metric state cannot
    # hold on to the autograd graph.
    metric_collection.update(output.detach(), labels)

    loss.backward()
    optimizer.step()

  epoch_metrics = metric_collection.compute()
  metric_collection.reset()

  return {'loss': avg_meters['loss'].avg,
          'acc': avg_meters['acc'].avg,
          'f1': epoch_metrics['f1'].item(),
          'prec': epoch_metrics['prec'].item(),
          'rec': epoch_metrics['rec'].item(),
          'auc': epoch_metrics['auc'].item()}

def validation(val_loader, model, criterion):
  """Evaluate the fusion ``model`` on ``val_loader`` without gradient tracking.

  Puts the model in eval mode (callers must switch back with ``model.train()``
  before training again).

  Args:
    val_loader: iterable of ``(inputs, tabular_data, labels)`` batches.
    model: network called as ``model(inputs, tabular_data)``.
    criterion: loss applied as ``criterion(output.float(), labels.float())``.

  Returns:
    dict with 'loss' and 'acc' (sample-weighted averages) and 'f1', 'prec',
    'rec', 'auc' (torchmetrics computed over the whole loader).
  """
  avg_meters = {'loss': AverageMeter(),
                'acc': AverageMeter()}
  metric_collection = MetricCollection({
    'prec': Precision(task='binary', average='macro').to(device),
    'rec': Recall(task='binary', average='macro').to(device),
    'f1': F1Score(task='binary').to(device),
    'auc': BinaryAUROC().to(device)})

  model.eval()

  with torch.no_grad():
    for inputs, tabular_data, labels in val_loader:
      inputs, tabular_data, labels = inputs.to(device), tabular_data.to(device), labels.to(device)

      output = model(inputs, tabular_data)

      loss = criterion(output.float(), labels.float())

      # Weight by the number of samples in the batch (not len() of the batch tuple).
      batch_size = labels.size(0)
      accuracy = (output.round() == labels).sum().item() / batch_size
      avg_meters['loss'].update(loss.item(), batch_size)
      avg_meters['acc'].update(accuracy, batch_size)

      metric_collection.update(output, labels)

  epoch_metrics = metric_collection.compute()
  metric_collection.reset()

  return {'loss': avg_meters['loss'].avg,
          'acc': avg_meters['acc'].avg,
          'f1': epoch_metrics['f1'].item(),
          'prec': epoch_metrics['prec'].item(),
          'rec': epoch_metrics['rec'].item(),
          'auc': epoch_metrics['auc'].item()}

### 6. Run test

In [None]:
## Global hyper-parameters shared by every run_test() call below.
# NOTE(review): the *_batch_size entries are not read by run_test(), which
# builds its DataLoaders with batch size 1.
initial_config = dict(
    data_dir=image_dir,
    image_size=128,
    train_batch_size=64,
    val_batch_size=32,
    test_batch_size=1,
    activation="relu",
    drop_rate=0.2,
    optimizer="Adam",
    learning_rate=1e-5,
    l2_reg=1e-6,          # weight decay handed to the optimizer
    nb_epoch=10,
    early_stopping=8,     # patience (epochs without improvement) for early stopping
)
run_dir = '/content/drive/MyDrive/CMPE_runs/cmpe58p_project'

# Module-level aliases read by run_test().
(learning_rate, drop_rate, weight_decay,
 activation, optimizer_name, patience) = (
    initial_config[key]
    for key in ("learning_rate", "drop_rate", "l2_reg",
                "activation", "optimizer", "early_stopping")
)


In [None]:
import pandas as pd

df = pd.read_csv(csv_dir)
data_dir = initial_config['data_dir']

# Build both splits once; run_test() reuses data_dict_train and only changes
# its `partition` attribute. 'ct_seg' is just the initial partition.
# NOTE(review): data_dict_val is not used by run_test(), which cross-validates
# on data_dict_train only.
data_dict_train, data_dict_val = (
    ImageDataset('ct_seg', data_dir, df, split) for split in ("train", "val")
)
print("Data is loaded")

Train size:  187
Read image for  0268 , Partition:  ct_seg




Shape:  (103, 103, 74)
Read image for  0005 , Partition:  ct_seg
Shape:  (103, 103, 73)
Read image for  0100 , Partition:  ct_seg
Shape:  (103, 103, 73)
Read image for  0210 , Partition:  ct_seg
Shape:  (103, 103, 74)
Read image for  0211 , Partition:  ct_seg
Shape:  (103, 103, 83)
Read image for  0104 , Partition:  ct_seg
Shape:  (103, 103, 73)
Read image for  0041 , Partition:  ct_seg
Shape:  (103, 103, 66)
Read image for  0231 , Partition:  ct_seg
Shape:  (103, 103, 74)
Read image for  0077 , Partition:  ct_seg
Shape:  (103, 103, 66)
Read image for  0128 , Partition:  ct_seg
Shape:  (103, 103, 82)
Read image for  0182 , Partition:  ct_seg
Shape:  (103, 103, 75)
Read image for  0016 , Partition:  ct_seg
Shape:  (103, 103, 72)
Read image for  0283 , Partition:  ct_seg
Shape:  (103, 103, 107)
Read image for  0168 , Partition:  ct_seg
Shape:  (103, 103, 75)
Read image for  0101 , Partition:  ct_seg
Shape:  (103, 103, 66)
Read image for  0176 , Partition:  ct_seg
Shape:  (103, 103, 82)
R



Shape:  (103, 103, 74)
Read image for  0114 , Partition:  ct_seg
Shape:  (103, 103, 66)
Read image for  0278 , Partition:  ct_seg
Shape:  (103, 103, 75)
Read image for  0081 , Partition:  ct_seg
Shape:  (103, 103, 67)
Read image for  0032 , Partition:  ct_seg
Shape:  (103, 103, 65)
Read image for  0152 , Partition:  ct_seg
Shape:  (103, 103, 74)
Read image for  0040 , Partition:  ct_seg
Shape:  (103, 103, 65)
Read image for  0123 , Partition:  ct_seg
Shape:  (103, 103, 65)
Read image for  0281 , Partition:  ct_seg
Shape:  (103, 103, 90)
Read image for  0125 , Partition:  ct_seg
Shape:  (103, 103, 65)
Read image for  0085 , Partition:  ct_seg
Shape:  (103, 103, 65)
Read image for  0307 , Partition:  ct_seg
Shape:  (103, 103, 75)
Read image for  0253 , Partition:  ct_seg
Shape:  (103, 103, 75)
Read image for  0013 , Partition:  ct_seg
Shape:  (103, 103, 72)
Read image for  0194 , Partition:  ct_seg
Shape:  (103, 103, 91)
Read image for  0098 , Partition:  ct_seg
Shape:  (103, 103, 65)
Re

### Fusion Test
The image partition (`seg_type`) passed to `run_test` can be changed on every call, and a different network can be plugged in. Create the `ImageDataset` only once (above) and reuse it across runs.

In [None]:
from torch import optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import ResNet18_Weights, ResNet50_Weights
from torch import nn
import os
from torch.utils.data import DataLoader
import pandas as pd
from sklearn.model_selection import KFold
from torch.utils.data import Subset

def run_test(seg_type, n_splits=5, random_state=None):
  """Cross-validate the image + clinical fusion model on one image partition.

  ``data_dict_train`` is switched to ``seg_type`` and split with ``KFold``.
  For every fold a fresh ``MultimodalFusionModel`` is trained for up to
  ``initial_config["nb_epoch"]`` epochs with early stopping on validation
  accuracy. The validation metrics of each fold's best epoch are collected,
  and their mean/std over the folds are printed.

  Args:
    seg_type: value assigned to ``data_dict_train.partition``
      (e.g. 'ct', 'pet', 'ct_seg', 'ct_mult_seg', 'pet_seg', 'pet_ct', 'mixed').
    n_splits: number of cross-validation folds (5, as before).
    random_state: seed for the fold shuffle; ``None`` keeps the previous
      non-reproducible split.
  """
  import gc
  import os
  import torch.optim as optim
  from torch.utils.data import DataLoader

  profile_dir = '/content/drive/MyDrive/CMPE_runs/'
  run_name = "cmpe58p_project"
  full_dir = profile_dir + run_name
  os.makedirs(full_dir, exist_ok=True)
  # Best model of the current fold; each fold overwrites the previous one.
  checkpoint_path = os.path.join(full_dir, "model.pt")

  data_dict_train.partition = seg_type
  print("Data is loaded, training starting")

  kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
  fold_results = {'acc': [], 'f1': [], 'prec': [], 'rec': [], 'auc': []}
  for fold, (train_idx, val_idx) in enumerate(kfold.split(data_dict_train)):
    print('Fold [%d/%d]' % (fold + 1, n_splits))
    train_loader = DataLoader(Subset(data_dict_train, train_idx), 1, shuffle=True)
    test_loader = DataLoader(Subset(data_dict_train, val_idx), 1)

    # Fresh model/optimizer per fold so folds do not leak into each other.
    print("learning_rate: ", learning_rate, ", drop_rate: ", drop_rate, ", weight_decay: ", weight_decay, ", activation: ", activation, ", optimizer: ", optimizer_name)
    model = torch.nn.DataParallel(MultimodalFusionModel().to(device))
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # min_lr must lie below the initial lr; the former min_lr=0.001 (> 1e-5)
    # meant ReduceLROnPlateau could never lower the learning rate.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.1, patience=2, min_lr=learning_rate * 1e-3)
    criterion = nn.BCELoss()

    best = {'acc': 0, 'f1': 0, 'prec': 0, 'rec': 0, 'auc': 0}
    trigger_times = 0
    for e in range(initial_config["nb_epoch"]):  # loop over the dataset multiple times
      train_dict = training(train_loader, model, criterion, optimizer)
      val_dict = validation(test_loader, model, criterion)
      print("Epoch: ", e, ', train_loss: ', train_dict['loss'], ", train_acc: ",  train_dict['acc'], ',val_loss: ', val_dict['loss'], ", val_acc: ",  val_dict['acc'])
      print("Epoch: ", e, ', train_f1_score: ', train_dict['f1'], ", train_precision: ",  train_dict['prec'], ',train_recall: ', train_dict['rec'], ", train_auc: ",  train_dict['auc'])
      print("Epoch: ", e, ', val_f1_score: ', val_dict['f1'], ", val_precision: ",  val_dict['prec'], ',val_recall: ', val_dict['rec'], ", val_auc: ",  val_dict['auc'])

      scheduler.step(val_dict['loss'])

      # An epoch counts as an improvement only if it beats the best validation
      # accuracy by at least 0.01; otherwise the early-stopping counter grows.
      if val_dict['acc'] - 0.01 >= best['acc']:
        best = {key: val_dict[key] for key in best}
        trigger_times = 0
        print("Save Model, best_accuracy: ", best['acc'])
        torch.save(model, checkpoint_path)
      else:
        trigger_times += 1
        if trigger_times >= patience:
          print("Early stopping in epoch: ", e, "best_accuracy: ", best['acc'], "current_acc: ", val_dict['acc'])
          break

    # Exactly one entry per fold (previously one per epoch, which skewed mean/std).
    for key, value in best.items():
      fold_results[key].append(value)

    # Release this fold's model before the next one is built.
    del model, optimizer, scheduler
    gc.collect()
    torch.cuda.empty_cache()

  summary = {key: (np.mean(values), np.std(values)) for key, values in fold_results.items()}
  print("F1 Score Mean: ", summary['f1'][0], ', F1 Score Std: ', summary['f1'][1], ", Precision Mean: ",  summary['prec'][0], ',Precision Std: ', summary['prec'][1])
  # AUC is now the real validation AUC; previously these figures were accuracies.
  print("Recall Mean: ", summary['rec'][0], "Recall Std: ", summary['rec'][1], "AUC Mean: ", summary['auc'][0], "AUC Std: ", summary['auc'][1])
  print("Accuracy Mean: ", summary['acc'][0], "Accuracy Std: ", summary['acc'][1])

In [None]:
# Cross-validate with data_dict_train.partition = 'ct' (fusion model).
run_test(seg_type='ct')

In [None]:
# Cross-validate with data_dict_train.partition = 'pet' (fusion model).
run_test(seg_type='pet')

In [None]:
# Cross-validate with data_dict_train.partition = 'ct_seg' (fusion model).
run_test(seg_type='ct_seg')

learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Data is loaded, training starting
Fold [1/5]
learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Epoch:  0 , train_loss:  0.6904420698649131 , train_acc:  0.6845637583892618 ,val_loss:  0.6503450799929468 , val_acc:  0.7368421052631579
Epoch:  0 , train_f1_score:  0.11320754885673523 , train_precision:  1.0 ,train_recall:  0.05999999865889549 , train_auc:  0.5002020001411438
Epoch:  0 , val_f1_score:  0.375 , val_precision:  1.0 ,val_recall:  0.23076923191547394 , val_auc:  0.4738461375236511
Save Model, best_accuracy:  0.7368421052631579
Epoch:  1 , train_loss:  0.6146966489219425 , train_acc:  0.7181208053691275 ,val_loss:  0.4652065403918785 , val_acc:  0.8421052631578947
Epoch:  1 , train_f1_score:  0.4324324429035187 , train_precision:  0.6666666865348816 ,train_recall:  0.3199999928474426 , train_auc:  0.704242467880249
Epoch:  1 

In [None]:
# Cross-validate with data_dict_train.partition = 'ct_mult_seg' (fusion model).
run_test(seg_type='ct_mult_seg')

learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Data is loaded, training starting
Fold [1/5]
learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Epoch:  0 , train_loss:  0.6296973705616773 , train_acc:  0.6577181208053692 ,val_loss:  0.8605739148823839 , val_acc:  0.7368421052631579
Epoch:  0 , train_f1_score:  0.4137931168079376 , train_precision:  0.4864864945411682 ,train_recall:  0.36000001430511475 , train_auc:  0.6412121057510376
Epoch:  0 , val_f1_score:  0.375 , val_precision:  1.0 ,val_recall:  0.23076923191547394 , val_auc:  0.5169230699539185
Save Model, best_accuracy:  0.7368421052631579
Epoch:  1 , train_loss:  0.5700314129208959 , train_acc:  0.6912751677852349 ,val_loss:  0.41058373790369124 , val_acc:  0.8421052631578947
Epoch:  1 , train_f1_score:  0.4390243887901306 , train_precision:  0.5625 ,train_recall:  0.36000001430511475 , train_auc:  0.6993939280509949
Epoch

In [None]:
# Cross-validate with data_dict_train.partition = 'pet_seg' (fusion model).
run_test(seg_type='pet_seg')

learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Data is loaded, training starting
Fold [1/5]
learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Epoch:  0 , train_loss:  0.5804490283281611 , train_acc:  0.7583892617449665 ,val_loss:  0.5892073521880727 , val_acc:  0.7105263157894737
Epoch:  0 , train_f1_score:  0.6326530575752258 , train_precision:  0.6458333134651184 ,train_recall:  0.6200000047683716 , train_auc:  0.7612121105194092
Epoch:  0 , val_f1_score:  0.4761904776096344 , val_precision:  0.625 ,val_recall:  0.38461539149284363 , val_auc:  0.7230769395828247
Save Model, best_accuracy:  0.7105263157894737
Epoch:  1 , train_loss:  0.5018892890004903 , train_acc:  0.7785234899328859 ,val_loss:  0.5005916923186496 , val_acc:  0.7631578947368421
Epoch:  1 , train_f1_score:  0.6117647290229797 , train_precision:  0.7428571581840515 ,train_recall:  0.5199999809265137 , train_auc:  

In [None]:
# Cross-validate with data_dict_train.partition = 'pet_ct' (fusion model).
run_test(seg_type='pet_ct')

In [None]:
# Cross-validate with data_dict_train.partition = 'mixed' (fusion model).
run_test(seg_type='mixed')

learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Data is loaded, training starting
Fold [1/5]
learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Epoch:  0 , train_loss:  0.597491244366705 , train_acc:  0.7114093959731543 ,val_loss:  0.5184268156115553 , val_acc:  0.7631578947368421
Epoch:  0 , train_f1_score:  0.5376344323158264 , train_precision:  0.5813953280448914 ,train_recall:  0.5 , train_auc:  0.6913130879402161
Epoch:  0 , val_f1_score:  0.5263158082962036 , val_precision:  0.8333333134651184 ,val_recall:  0.38461539149284363 , val_auc:  0.7723076939582825
Save Model, best_accuracy:  0.7631578947368421
Epoch:  1 , train_loss:  0.5521226927693527 , train_acc:  0.7315436241610739 ,val_loss:  0.5417314867271145 , val_acc:  0.6578947368421053
Epoch:  1 , train_f1_score:  0.5121951103210449 , train_precision:  0.65625 ,train_recall:  0.41999998688697815 , train_auc:  0.72767674922

### CNN Test (image-only 3-D ResNet baseline, clinical data ignored)

In [None]:
# Fall back to CPU when no GPU is available; a hard-coded 'cuda' fails there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from torchmetrics import MetricCollection, Accuracy, Precision, Recall, F1Score
from torchmetrics.classification import BinaryAUROC 

def training(train_loader, model, criterion, optimizer):
  """Train an image-only ``model`` for one epoch and return epoch-level metrics.

  Args:
    train_loader: iterable of ``(inputs, tabular_data, labels)`` batches; the
      tabular part is ignored by this image-only variant.
    model: network called as ``model(inputs)``.
    criterion: loss applied as ``criterion(output.float(), labels.float())``.
    optimizer: optimizer over ``model``'s parameters, stepped once per batch.

  Returns:
    dict with 'loss' and 'acc' (sample-weighted averages over the epoch) and
    'f1', 'prec', 'rec', 'auc' (torchmetrics computed over the whole epoch).
  """
  # validation() leaves the model in eval mode; switch back so dropout and
  # BatchNorm statistics behave as in training.
  model.train()

  avg_meters = {'loss': AverageMeter(),
                'acc': AverageMeter()}

  metric_collection = MetricCollection({ #reference:https://pub.towardsai.net/improve-your-model-validation-with-torchmetrics-b457d3954dcd
    'prec': Precision(task='binary', average='macro').to(device),
    'rec': Recall(task='binary', average='macro').to(device),
    'f1': F1Score(task='binary').to(device),
    'auc': BinaryAUROC().to(device)})

  for inputs, _tabular_data, labels in train_loader:
    inputs, labels = inputs.to(device), labels.to(device)

    optimizer.zero_grad()
    output = model(inputs)

    loss = criterion(output.float(), labels.float())

    # Weight the running averages by the number of samples in the batch
    # (len() of the batch tuple is always 3, not the batch size).
    batch_size = labels.size(0)
    accuracy = (output.round() == labels).sum().item() / batch_size
    avg_meters['loss'].update(loss.item(), batch_size)
    avg_meters['acc'].update(accuracy, batch_size)

    # Metrics only read the predictions; detach so stored metric state cannot
    # hold on to the autograd graph.
    metric_collection.update(output.detach(), labels)

    loss.backward()
    optimizer.step()

  epoch_metrics = metric_collection.compute()
  metric_collection.reset()

  return {'loss': avg_meters['loss'].avg,
          'acc': avg_meters['acc'].avg,
          'f1': epoch_metrics['f1'].item(),
          'prec': epoch_metrics['prec'].item(),
          'rec': epoch_metrics['rec'].item(),
          'auc': epoch_metrics['auc'].item()}

def validation(val_loader, model, criterion):
  """Evaluate an image-only ``model`` on ``val_loader`` without gradient tracking.

  Puts the model in eval mode (callers must switch back with ``model.train()``
  before training again).

  Args:
    val_loader: iterable of ``(inputs, tabular_data, labels)`` batches; the
      tabular part is ignored by this image-only variant.
    model: network called as ``model(inputs)``.
    criterion: loss applied as ``criterion(output.float(), labels.float())``.

  Returns:
    dict with 'loss' and 'acc' (sample-weighted averages) and 'f1', 'prec',
    'rec', 'auc' (torchmetrics computed over the whole loader).
  """
  avg_meters = {'loss': AverageMeter(),
                'acc': AverageMeter()}
  metric_collection = MetricCollection({
    'prec': Precision(task='binary', average='macro').to(device),
    'rec': Recall(task='binary', average='macro').to(device),
    'f1': F1Score(task='binary').to(device),
    'auc': BinaryAUROC().to(device)})

  model.eval()

  with torch.no_grad():
    for inputs, _tabular_data, labels in val_loader:
      inputs, labels = inputs.to(device), labels.to(device)

      output = model(inputs)

      loss = criterion(output.float(), labels.float())

      # Weight by the number of samples in the batch (not len() of the batch tuple).
      batch_size = labels.size(0)
      accuracy = (output.round() == labels).sum().item() / batch_size
      avg_meters['loss'].update(loss.item(), batch_size)
      avg_meters['acc'].update(accuracy, batch_size)

      metric_collection.update(output, labels)

  epoch_metrics = metric_collection.compute()
  metric_collection.reset()

  return {'loss': avg_meters['loss'].avg,
          'acc': avg_meters['acc'].avg,
          'f1': epoch_metrics['f1'].item(),
          'prec': epoch_metrics['prec'].item(),
          'rec': epoch_metrics['rec'].item(),
          'auc': epoch_metrics['auc'].item()}

In [None]:
from torch import optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import ResNet18_Weights, ResNet50_Weights
from torch import nn
import os
from torch.utils.data import DataLoader
import pandas as pd
from sklearn.model_selection import KFold
from torch.utils.data import Subset

def run_test(seg_type, n_splits=5, random_state=None):
  """Cross-validate the image-only 3-D ResNet baseline on one image partition.

  ``data_dict_train`` is switched to ``seg_type`` and split with ``KFold``.
  For every fold a fresh ``generate_model(50, 0.2)`` network is trained for up
  to ``initial_config["nb_epoch"]`` epochs with early stopping on validation
  accuracy. The validation metrics of each fold's best epoch are collected,
  and their mean/std over the folds are printed.

  Args:
    seg_type: value assigned to ``data_dict_train.partition``
      (e.g. 'ct_seg', 'ct_mult_seg', 'pet_seg', 'mixed').
    n_splits: number of cross-validation folds (5, as before).
    random_state: seed for the fold shuffle; ``None`` keeps the previous
      non-reproducible split.
  """
  import gc
  import os
  import torch.optim as optim
  from torch.utils.data import DataLoader

  profile_dir = '/content/drive/MyDrive/CMPE_runs/'
  run_name = "cmpe58p_project"
  full_dir = profile_dir + run_name
  os.makedirs(full_dir, exist_ok=True)
  # Best model of the current fold; each fold overwrites the previous one.
  checkpoint_path = os.path.join(full_dir, "model.pt")

  data_dict_train.partition = seg_type
  print("Data is loaded, training starting")

  kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
  fold_results = {'acc': [], 'f1': [], 'prec': [], 'rec': [], 'auc': []}
  for fold, (train_idx, val_idx) in enumerate(kfold.split(data_dict_train)):
    print('Fold [%d/%d]' % (fold + 1, n_splits))
    train_loader = DataLoader(Subset(data_dict_train, train_idx), 1, shuffle=True)
    test_loader = DataLoader(Subset(data_dict_train, val_idx), 1)

    # Fresh model/optimizer per fold so folds do not leak into each other.
    print("learning_rate: ", learning_rate, ", drop_rate: ", drop_rate, ", weight_decay: ", weight_decay, ", activation: ", activation, ", optimizer: ", optimizer_name)
    # generate_model is defined elsewhere in the notebook; 50 presumably selects
    # the ResNet depth and 0.2 its dropout rate — confirm.
    model = torch.nn.DataParallel(generate_model(50, 0.2).to(device))
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # min_lr must lie below the initial lr; the former min_lr=0.001 (> 1e-5)
    # meant ReduceLROnPlateau could never lower the learning rate.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.1, patience=2, min_lr=learning_rate * 1e-3)
    criterion = nn.BCELoss()

    best = {'acc': 0, 'f1': 0, 'prec': 0, 'rec': 0, 'auc': 0}
    trigger_times = 0
    for e in range(initial_config["nb_epoch"]):  # loop over the dataset multiple times
      train_dict = training(train_loader, model, criterion, optimizer)
      val_dict = validation(test_loader, model, criterion)
      print("Epoch: ", e, ', train_loss: ', train_dict['loss'], ", train_acc: ",  train_dict['acc'], ',val_loss: ', val_dict['loss'], ", val_acc: ",  val_dict['acc'])
      print("Epoch: ", e, ', train_f1_score: ', train_dict['f1'], ", train_precision: ",  train_dict['prec'], ',train_recall: ', train_dict['rec'], ", train_auc: ",  train_dict['auc'])
      print("Epoch: ", e, ', val_f1_score: ', val_dict['f1'], ", val_precision: ",  val_dict['prec'], ',val_recall: ', val_dict['rec'], ", val_auc: ",  val_dict['auc'])

      scheduler.step(val_dict['loss'])

      # An epoch counts as an improvement only if it beats the best validation
      # accuracy by at least 0.01; otherwise the early-stopping counter grows.
      if val_dict['acc'] - 0.01 >= best['acc']:
        best = {key: val_dict[key] for key in best}
        trigger_times = 0
        print("Save Model, best_accuracy: ", best['acc'])
        torch.save(model, checkpoint_path)
      else:
        trigger_times += 1
        if trigger_times >= patience:
          print("Early stopping in epoch: ", e, "best_accuracy: ", best['acc'], "current_acc: ", val_dict['acc'])
          break

    # Exactly one entry per fold (previously one per epoch, which skewed mean/std).
    for key, value in best.items():
      fold_results[key].append(value)

    # Release this fold's model before the next one is built.
    del model, optimizer, scheduler
    gc.collect()
    torch.cuda.empty_cache()

  summary = {key: (np.mean(values), np.std(values)) for key, values in fold_results.items()}
  print("F1 Score Mean: ", summary['f1'][0], ', F1 Score Std: ', summary['f1'][1], ", Precision Mean: ",  summary['prec'][0], ',Precision Std: ', summary['prec'][1])
  # AUC is now the real validation AUC; previously these figures were accuracies.
  print("Recall Mean: ", summary['rec'][0], "Recall Std: ", summary['rec'][1], "AUC Mean: ", summary['auc'][0], "AUC Std: ", summary['auc'][1])
  print("Accuracy Mean: ", summary['acc'][0], "Accuracy Std: ", summary['acc'][1])

In [None]:
# Cross-validate with data_dict_train.partition = 'ct_seg' (image-only 3-D ResNet).
run_test(seg_type='ct_seg')

learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Data is loaded, training starting
Fold [1/5]
learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Epoch:  0 , train_loss:  0.6968355592865272 , train_acc:  0.5704697986577181 ,val_loss:  0.6709546807565188 , val_acc:  0.7631578947368421
Epoch:  0 , train_f1_score:  0.20000000298023224 , train_precision:  0.3076923191547394 ,train_recall:  0.14814814925193787 , train_auc:  0.446393758058548
Epoch:  0 , val_f1_score:  0.0 , val_precision:  0.0 ,val_recall:  0.0 , val_auc:  0.26819923520088196
Save Model, best_accuracy:  0.7631578947368421
Epoch:  1 , train_loss:  0.7307991857496684 , train_acc:  0.6375838926174496 ,val_loss:  0.4965349327540025 , val_acc:  0.8421052631578947
Epoch:  1 , train_f1_score:  0.3863636255264282 , train_precision:  0.5 ,train_recall:  0.31481480598449707 , train_auc:  0.5943470001220703
Epoch:  1 , val_f1_score: 

In [None]:
# Cross-validate with data_dict_train.partition = 'ct_mult_seg' (image-only 3-D ResNet).
run_test(seg_type='ct_mult_seg')

learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Data is loaded, training starting
Fold [1/5]
learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Epoch:  0 , train_loss:  0.6569252082165455 , train_acc:  0.6644295302013423 ,val_loss:  0.8412543899918857 , val_acc:  0.6052631578947368
Epoch:  0 , train_f1_score:  0.0 , train_precision:  0.0 ,train_recall:  0.0 , train_auc:  0.40820956230163574
Epoch:  0 , val_f1_score:  0.0 , val_precision:  0.0 ,val_recall:  0.0 , val_auc:  0.3246377110481262
Save Model, best_accuracy:  0.6052631578947368
Epoch:  1 , train_loss:  0.7367859329815779 , train_acc:  0.6577181208053692 ,val_loss:  0.6971753746628039 , val_acc:  0.631578947368421
Epoch:  1 , train_f1_score:  0.3199999928474426 , train_precision:  0.4444444477558136 ,train_recall:  0.25 , train_auc:  0.5693069696426392
Epoch:  1 , val_f1_score:  0.5 , val_precision:  0.5384615659713745 ,val_

In [None]:
# Cross-validate with data_dict_train.partition = 'pet_seg' (image-only 3-D ResNet).
run_test(seg_type='pet_seg')

learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Data is loaded, training starting
Fold [1/5]
learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Epoch:  0 , train_loss:  0.6575936291041791 , train_acc:  0.6644295302013423 ,val_loss:  0.7284844235370034 , val_acc:  0.6578947368421053
Epoch:  0 , train_f1_score:  0.0 , train_precision:  0.0 ,train_recall:  0.0 , train_auc:  0.44727271795272827
Epoch:  0 , val_f1_score:  0.0 , val_precision:  0.0 ,val_recall:  0.0 , val_auc:  0.23692309856414795
Save Model, best_accuracy:  0.6578947368421053
Epoch:  1 , train_loss:  0.7307806586559187 , train_acc:  0.6174496644295302 ,val_loss:  0.5636892855951661 , val_acc:  0.7105263157894737
Epoch:  1 , train_f1_score:  0.3132530152797699 , train_precision:  0.39393940567970276 ,train_recall:  0.25999999046325684 , train_auc:  0.5373737215995789
Epoch:  1 , val_f1_score:  0.2666666805744171 , val_pre

In [None]:
# Cross-validate with data_dict_train.partition = 'mixed' (image-only 3-D ResNet).
run_test(seg_type='mixed')

learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Data is loaded, training starting
Fold [1/5]
learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Epoch:  0 , train_loss:  0.7617426696639733 , train_acc:  0.4697986577181208 ,val_loss:  1.0132714196255332 , val_acc:  0.631578947368421
Epoch:  0 , train_f1_score:  0.2752293646335602 , train_precision:  0.25 ,train_recall:  0.30612245202064514 , train_auc:  0.4306122660636902
Epoch:  0 , val_f1_score:  0.0 , val_precision:  0.0 ,val_recall:  0.0 , val_auc:  0.1607142835855484
Save Model, best_accuracy:  0.631578947368421
Epoch:  1 , train_loss:  0.781084462800282 , train_acc:  0.6174496644295302 ,val_loss:  0.6253805442860252 , val_acc:  0.7368421052631579
Epoch:  1 , train_f1_score:  0.19718310236930847 , train_precision:  0.3181818127632141 ,train_recall:  0.1428571492433548 , train_auc:  0.5
Epoch:  1 , val_f1_score:  0.5 , val_precisi

### Clinical: fusion model with clinical features (`ImageDatasetClinical`)

In [None]:
import torch
import torch.nn as nn
from torchvision.models import resnet18
import pandas as pd
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import KNNImputer

import numpy as np


class MultimodalFusionModel(nn.Module):
    """Late-fusion binary classifier combining 3-D ResNet image features with clinical data.

    The image branch is ``generate_model(50, 0.2)``, used without its own
    classification head. Its pooled, flattened features are concatenated with
    an embedding of the clinical features and mapped to one sigmoid probability.

    Args:
        n_clinical_features: number of clinical (tabular) input features (default 5).
        image_feature_dim: length of the flattened backbone feature vector after
            ``avgpool`` (default 2048, the value previously hard-coded).
        dropout: dropout probability of the clinical and fusion FC blocks.
    """

    def __init__(self, n_clinical_features=5, image_feature_dim=2048, dropout=0.1):
        super().__init__()

        # Part 1: 3-D ResNet backbone for image feature extraction.
        # generate_model is defined elsewhere in the notebook; the second
        # argument presumably is its dropout rate — confirm.
        self.resnet = generate_model(50, 0.2)

        # Part 2: FC layers for the clinical information and the fused features
        self.fc_clinical = nn.Sequential(
            nn.Linear(n_clinical_features, 64),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        self.fc_concat = nn.Sequential(
            nn.Linear(image_feature_dim + 64, 128),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Part 3: final FC layer producing a single probability
        self.fc_final = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_image, x_clinical):
        """Return fused probabilities for a batch.

        Args:
            x_image: image batch accepted by the backbone's ``conv1``
                (presumably (B, C, D, H, W) — TODO confirm).
            x_clinical: (B, n_clinical_features) tensor of clinical features.

        Returns:
            (B, 1) tensor of sigmoid probabilities.
        """
        # Part 1: run the 3-D ResNet backbone up to (and including) avgpool/dropout
        x_image = self.resnet.conv1(x_image)
        x_image = self.resnet.bn1(x_image)
        x_image = self.resnet.relu(x_image)
        x_image = self.resnet.maxpool(x_image)

        x_image = self.resnet.layer1(x_image)
        x_image = self.resnet.layer2(x_image)
        x_image = self.resnet.layer3(x_image)
        x_image = self.resnet.layer4(x_image)

        x_image = self.resnet.avgpool(x_image)
        x_image = self.resnet.dropout(x_image)
        x_image = torch.flatten(x_image, 1)

        # Part 2: embed the clinical features and fuse them with the image features
        x_clinical = self.fc_clinical(x_clinical)
        x_fusion = torch.cat((x_image, x_clinical), dim=1)
        x_fusion = self.fc_concat(x_fusion)

        # Part 3: final prediction
        output = self.fc_final(x_fusion)
        output = self.sigmoid(output)

        return output

In [None]:
import pandas as pd

# Clinical table (CSV) and the image directory taken from the shared run config.
df = pd.read_csv(csv_dir)
data_dir = initial_config['data_dir']

# Build the train (first 70% of patient folders) and validation (remaining 30%)
# datasets for the 'ct_seg' partition; keywords mirror ImageDatasetClinical.__init__.
data_dict_train = ImageDatasetClinical(partition='ct_seg', data_dir=data_dir, df=df, split="train")
data_dict_val = ImageDatasetClinical(partition='ct_seg', data_dir=data_dir, df=df, split="val")
print("Data is loaded")

Train size:  187




Read image for  0268 , Partition:  ct_seg
Shape:  (103, 103, 74)
Read image for  0005 , Partition:  ct_seg
Shape:  (103, 103, 73)
Read image for  0100 , Partition:  ct_seg
Shape:  (103, 103, 73)
Read image for  0210 , Partition:  ct_seg
Shape:  (103, 103, 74)
Read image for  0211 , Partition:  ct_seg
Shape:  (103, 103, 83)
Read image for  0104 , Partition:  ct_seg
Shape:  (103, 103, 73)
Read image for  0041 , Partition:  ct_seg
Shape:  (103, 103, 66)
Read image for  0231 , Partition:  ct_seg
Shape:  (103, 103, 74)
Read image for  0077 , Partition:  ct_seg
Shape:  (103, 103, 66)
Read image for  0128 , Partition:  ct_seg
Shape:  (103, 103, 82)
Read image for  0182 , Partition:  ct_seg
Shape:  (103, 103, 75)
Read image for  0016 , Partition:  ct_seg
Shape:  (103, 103, 72)
Read image for  0283 , Partition:  ct_seg
Shape:  (103, 103, 107)
Read image for  0168 , Partition:  ct_seg
Shape:  (103, 103, 75)
Read image for  0101 , Partition:  ct_seg
Shape:  (103, 103, 66)
Read image for  0176 , P



Shape:  (103, 103, 74)
Read image for  0114 , Partition:  ct_seg
Shape:  (103, 103, 66)
Read image for  0278 , Partition:  ct_seg
Shape:  (103, 103, 75)
Read image for  0081 , Partition:  ct_seg
Shape:  (103, 103, 67)
Read image for  0032 , Partition:  ct_seg
Shape:  (103, 103, 65)
Read image for  0152 , Partition:  ct_seg
Shape:  (103, 103, 74)
Read image for  0040 , Partition:  ct_seg
Shape:  (103, 103, 65)
Read image for  0123 , Partition:  ct_seg
Shape:  (103, 103, 65)
Read image for  0281 , Partition:  ct_seg
Shape:  (103, 103, 90)
Read image for  0125 , Partition:  ct_seg
Shape:  (103, 103, 65)
Read image for  0085 , Partition:  ct_seg
Shape:  (103, 103, 65)
Read image for  0307 , Partition:  ct_seg
Shape:  (103, 103, 75)
Read image for  0253 , Partition:  ct_seg
Shape:  (103, 103, 75)
Read image for  0013 , Partition:  ct_seg
Shape:  (103, 103, 72)
Read image for  0194 , Partition:  ct_seg
Shape:  (103, 103, 91)
Read image for  0098 , Partition:  ct_seg
Shape:  (103, 103, 65)
Re

In [None]:
# Train/evaluate on the 'ct_seg' partition. run_test is defined in an earlier cell;
# the recorded output shows the hyperparameters, a 5-fold loop, and per-epoch
# loss/accuracy/F1/precision/recall/AUC for train and validation.
run_test('ct_seg')

learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Data is loaded, training starting
Fold [1/5]
learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Epoch:  0 , train_loss:  0.5576787336498359 , train_acc:  0.697986577181208 ,val_loss:  0.8435406260453563 , val_acc:  0.6578947368421053
Epoch:  0 , train_f1_score:  0.5161290168762207 , train_precision:  0.5333333611488342 ,train_recall:  0.5 , train_auc:  0.7320544719696045
Epoch:  0 , val_f1_score:  0.43478259444236755 , val_precision:  0.625 ,val_recall:  0.3333333432674408 , val_auc:  0.5246376991271973
Save Model, best_accuracy:  0.6578947368421053
Epoch:  1 , train_loss:  0.48948470205847044 , train_acc:  0.8120805369127517 ,val_loss:  0.762664062513283 , val_acc:  0.7105263157894737
Epoch:  1 , train_f1_score:  0.6410256624221802 , train_precision:  0.8333333134651184 ,train_recall:  0.5208333134651184 , train_auc:  0.79084157943725

In [None]:
# Same experiment as above but on the 'pet_seg' partition (run_test defined in an
# earlier cell; output format identical to the 'ct_seg' run).
run_test('pet_seg')

learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Data is loaded, training starting
Fold [1/5]
learning_rate:  1e-05 , drop_rate:  0.2 , weight_decay:  1e-06 , activation:  relu , optimizer:  Adam
Epoch:  0 , train_loss:  0.5973182066725602 , train_acc:  0.7516778523489933 ,val_loss:  0.6723664538248589 , val_acc:  0.6842105263157895
Epoch:  0 , train_f1_score:  0.602150559425354 , train_precision:  0.6829268336296082 ,train_recall:  0.5384615659713745 , train_auc:  0.7242268323898315
Epoch:  0 , val_f1_score:  0.3333333432674408 , val_precision:  0.4285714328289032 ,val_recall:  0.27272728085517883 , val_auc:  0.6666666865348816
Save Model, best_accuracy:  0.6842105263157895
Epoch:  1 , train_loss:  0.5320723220869331 , train_acc:  0.7785234899328859 ,val_loss:  0.6320238262460885 , val_acc:  0.7631578947368421
Epoch:  1 , train_f1_score:  0.6206896305084229 , train_precision:  0.7714285850524902 ,train_recall:  0.5192307829856873 , 

In [None]:
# 'mixed' partition run (run_test defined in an earlier cell).
# NOTE(review): no output is recorded for this cell in the saved notebook.
run_test('mixed')

In [None]:
# 'ct_mult_seg' partition run (run_test defined in an earlier cell).
run_test('ct_mult_seg')