<a href="https://mingxia.web.unc.edu/" target="_parent"><img src="https://mingxia.web.unc.edu/wp-content/uploads/sites/12411/2020/12/logo_MagicLab-horizontal-4.png" alt="MAGIC Lab"/></a>

# **BAPM Pretext Model Training**
---

**Loading required libraries**
---

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 
import sys, argparse
import enum
import time
import datetime
import random
import json
import multiprocessing
import os.path as osp
import pandas as pd
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
import pylab as pl
import logging
import shutil
import tempfile
import gzip
from typing import Optional, Sequence, Tuple, Union
from urllib.request import urlretrieve
from PIL import Image

from pathlib import Path
from scipy import stats
from IPython import display
from tqdm import trange, tqdm

import copy
import pprint
import torchio as tio
import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import L1Loss
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR

from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor

from sklearn.metrics import confusion_matrix, multilabel_confusion_matrix, roc_auc_score, matthews_corrcoef
from sklearn.metrics import classification_report
from sklearn.model_selection import StratifiedKFold, train_test_split, KFold
from sklearn.manifold import TSNE
from sklearn import svm

from neuroCombat import neuroCombat

import monai
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import CacheDataset, DataLoader, Dataset, ImageDataset
from monai.networks.nets import VarAutoEncoder,ViTAutoEnc, AutoEncoder
from monai.networks.layers.convutils import calculate_out_shape, same_padding
from monai.networks.layers.factories import Act, Norm
from monai.networks.utils import one_hot
from monai.utils import set_determinism, first
from monai.utils.enums import MetricReduction
from monai.metrics import compute_hausdorff_distance, HausdorffDistanceMetric
from monai.losses import ContrastiveLoss, DiceLoss, DiceCELoss
from monai.transforms import (
    ConvertToMultiChannelBasedOnBratsClasses,
    AsDiscrete,
    Activations,
    AddChannelD,
    Compose,
    LoadImageD,
    ScaleIntensityD,
    EnsureTypeD,
    LoadImaged,
    Compose,
    CropForegroundd,
    CopyItemsd,
    SpatialPadd,
    EnsureChannelFirstd,
    Spacingd,
    OneOf,
    ScaleIntensityRanged,
    RandSpatialCropSamplesd,
    RandCoarseDropoutd,
    RandCoarseShuffled
)


**Global Setting**
---

In [2]:
# Fix the random seeds up front: first PyTorch's global RNG, then MONAI's
# `set_determinism` helper (imported from monai.utils), both with seed 0.
torch.manual_seed(0)
set_determinism(seed=0)

# Use the first CUDA device when one is available, otherwise the CPU.
# pin_memory is only enabled together with CUDA.
pin_memory = torch.cuda.is_available()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Shape handed to the model as `in_shape`: channel count followed by the
# three spatial sizes.
im_shape = (1, 176, 208, 176)
modelname = 'AE2DecPriorV1'
train_size = 0.8      # fraction of the non-test files used for training
val_interval = 5      # checkpoints are only considered every `val_interval` epochs
pretrained = False
downsampled = False
samplespace = 2 if downsampled else 1

pretrained_path = './pretrained/'
trained_path = './models/'
logdir_path = os.path.normpath('./log/')

# os.makedirs(..., exist_ok=True) replaces the separate exists()/mkdir() pair:
# it is a single call, has no window between the check and the creation, and
# also creates missing parent folders.
for _required_dir in (logdir_path, pretrained_path, trained_path):
    os.makedirs(_required_dir, exist_ok=True)

# Checkpoints of runs started from pretrained weights get a filename prefix.
savedir = (trained_path + 'Pretrained_') if pretrained else trained_path

**File scanner**
---

In [3]:
class ScanFile(object):
    """Recursively list files or sub-directories below a root directory.

    File names can be filtered by a suffix (``postfix``) or by a leading
    string (``prefix``). When both are given, only ``postfix`` is applied;
    ``prefix`` is consulted only if ``postfix`` is empty or None.
    """

    def __init__(self, directory, prefix=None, postfix=None):
        self.directory = directory
        self.prefix = prefix
        self.postfix = postfix

    def _matches(self, filename):
        """Return True when ``filename`` passes the configured name filter."""
        if self.postfix:
            return filename.endswith(self.postfix)
        if self.prefix:
            return filename.startswith(self.prefix)
        # No filter configured: accept every file.
        return True

    def scan_files(self):
        """Return the joined paths of all matching files, in os.walk order."""
        matched = []
        # os.walk yields (folder path, sub-directory names, plain file names)
        # for the root and for every directory beneath it.
        for folder, _subdirs, names in os.walk(self.directory):
            matched.extend(
                os.path.join(folder, name) for name in names if self._matches(name)
            )
        return matched

    def scan_subdir(self):
        """Return every directory visited by os.walk, the root included."""
        return [folder for folder, _subdirs, _names in os.walk(self.directory)]

**Preparing for data reading**
---

In [4]:
# Root folder that is searched recursively for the ADNI files.
ADNI_dir = './data/ADNI_iBEAT_linearReg/'
scan1 = ScanFile(ADNI_dir, postfix='n3.nii.gz')    # files later loaded as the 'mri' image
scan2 = ScanFile(ADNI_dir, postfix='-seg.nii.gz')  # files later loaded as the 'seg' label map

# Both lists are sorted so that entry i of ADNI_mri is paired with entry i of
# ADNI_seg further down; this relies on matching file naming in both sets.
ADNI_mri = sorted(scan1.scan_files())
ADNI_seg = sorted(scan2.scan_files())

# Fail fast instead of silently training on nothing or on mis-paired data.
if not ADNI_mri:
    raise FileNotFoundError(
        f"No '*n3.nii.gz' files found under {ADNI_dir!r}"
    )
if len(ADNI_mri) != len(ADNI_seg):
    raise ValueError(
        f"Found {len(ADNI_mri)} MRI files but {len(ADNI_seg)} segmentation "
        f"files under {ADNI_dir!r}; index-based pairing would be wrong"
    )

**Define Dataset with Augmentation**
---

In [8]:
class MDD_Dataset(torch.utils.data.Dataset):
    """Paired MRI / segmentation dataset backed by a ``tio.SubjectsDataset``.

    Each subject holds the image under key ``'mri'`` (``tio.ScalarImage``) and
    the segmentation under key ``'seg'`` (``tio.LabelMap``). The wrapped
    TorchIO dataset is exposed as ``self.dataset`` and the augmentation
    pipeline as ``self.aug_transform``. Indexing and ``len()`` on this object
    delegate to ``self.dataset``.

    Args:
        images: iterable of image file paths.
        segs: iterable of segmentation file paths, in the same order as
            ``images``.

    Raises:
        ValueError: if ``images`` and ``segs`` differ in length.
    """

    def __init__(self, images, segs):
        super().__init__()
        images, segs = list(images), list(segs)
        # zip() would silently drop the unpaired tail of the longer list.
        if len(images) != len(segs):
            raise ValueError(
                f"images ({len(images)}) and segs ({len(segs)}) must have the same length"
            )
        subjects = [
            tio.Subject(mri=tio.ScalarImage(image), seg=tio.LabelMap(seg))
            for image, seg in zip(images, segs)
        ]
        self.transform()
        self.dataset = tio.SubjectsDataset(subjects, transform=self.aug_transform)

    def __len__(self):
        """Number of subjects in the wrapped TorchIO dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the (transformed) subject at ``index`` from the wrapped dataset."""
        return self.dataset[index]

    def transform(self):
        """Build the augmentation pipeline and store it in ``self.aug_transform``."""
        # Three fixed histogram-standardization landmark sets for the 'mri' image.
        # NOTE(review): the mpr / gre / uchc suffixes presumably name the
        # acquisition protocol or site the landmarks were trained on - confirm.
        MDD_FSL_t1_mpr_landmarks       = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 19.15165923, 43.90807752, 75.49163173, 100.])
        MDD_FSL_t1_gre_landmarks       = np.array([0., 0.38107146, 0.38107149, 0.38107149, 0.38107149, 0.38107149, 0.38107149, 0.3810715, 0.64192136, 48.8195904, 71.07814235, 89.23094854, 100.])
        MDD_FSL_t1_uchc_landmarks      = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 39.38873128, 52.8757862, 71.36841272, 100.])
        landmarks_dict1 = {'mri': MDD_FSL_t1_mpr_landmarks}
        landmarks_dict2 = {'mri': MDD_FSL_t1_gre_landmarks}
        landmarks_dict3 = {'mri': MDD_FSL_t1_uchc_landmarks}
        prior_augment = tio.Compose([
            # Crop or pad every volume to 176 x 208 x 176.
            tio.CropOrPad((176, 208, 176)),
            # Apply exactly one of the three histogram standardizations.
            # A list (not a set) keeps the candidate order fixed, so the choice
            # is reproducible under a fixed random seed.
            tio.OneOf([
                tio.HistogramStandardization(landmarks_dict1),
                tio.HistogramStandardization(landmarks_dict2),
                tio.HistogramStandardization(landmarks_dict3),
            ]),
            # Apply exactly one degradation; all four carry the same weight
            # (1.0), i.e. each is equally likely to be picked.
            tio.OneOf({
                tio.RandomBlur(std = (3,3,3)):1.0,          # blur
                tio.RandomNoise(std=6):1.0,                 # additive noise
                tio.RandomBiasField():1.0,                  # bias field
                tio.RandomMotion(degrees = 2, translation = 0, num_transforms = 2): 1.0,    # motion artifact
            }),
            # Map the 0.5-99.5 percentile intensity range onto [0, 1].
            tio.RescaleIntensity(percentiles=(0.5,99.5), out_min_max=(0, 1.0)),
        ])

        self.aug_transform = prior_augment


**Data Loader**
---

In [9]:
def get_loader(imagepaths, segpaths, batch_size=1, shuffle=False, num_workers=0):
    """Build a DataLoader over paired image / segmentation files.

    Args:
        imagepaths: paths of the images (subject key ``'mri'``).
        segpaths: paths of the segmentations (subject key ``'seg'``), in the
            same order as ``imagepaths``.
        batch_size: number of subjects per batch.
        shuffle: forwarded to the DataLoader; the default ``False`` keeps the
            original fixed iteration order.
        num_workers: forwarded to the DataLoader; the default ``0`` keeps
            loading in the main process as before.

    Returns:
        A DataLoader over the ``tio.SubjectsDataset`` held by ``MDD_Dataset``.
        ``pin_memory`` is taken from the module-level setting.
    """
    # NOTE(review): `DataLoader` is imported from both torch.utils.data and
    # monai.data at the top of the file; the later (MONAI) import wins.
    dataset = MDD_Dataset(images=imagepaths, segs=segpaths)
    loader = DataLoader(
        dataset.dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return loader

# The last `test_number` files are reserved for testing; `train_size` of the
# remaining files (minus one) form the training set.
test_number = 200
train_number = int((len(ADNI_mri) - test_number) * train_size) - 1

# With too few files train_number becomes <= 0. Used as a slice bound, a
# negative value would count from the END of the list and silently pull
# held-out files into the training set, so stop here instead.
if train_number <= 0:
    raise ValueError(
        f"Only {len(ADNI_mri)} MRI files available; need more than "
        f"{test_number} (reserved for testing) to build a training set"
    )

train_MRI_loader = get_loader(ADNI_mri[:train_number], ADNI_seg[:train_number], batch_size=2)
#test_loader  = get_loader(ADNI_mri[-test_number:],ADNI_seg[-test_number:])


**Model definition**
---

In [10]:
class AE2DecoderV1(AutoEncoder):
    """Autoencoder with one shared encoder and two parallel decoders.

    The encoder and intermediate layers come from the MONAI ``AutoEncoder``
    base class. Their output is split in half along the channel axis; the
    first half is decoded by ``decode1`` (``out_channels`` channels, the MRI
    reconstruction) and the second half by ``decode2`` (``out_channels2``
    channels, the segmentation output).

    Args:
        spatial_dims: number of spatial dimensions.
        in_shape: input shape as (channels, *spatial sizes).
        out_channels: output channels of the first decoder.
        out_channels2: output channels of the second decoder.
        channels, strides, kernel_size, up_kernel_size, num_res_units,
        inter_channels, inter_dilations, num_inter_units, act, norm, dropout,
        bias: forwarded unchanged to ``AutoEncoder.__init__``.

    Raises:
        ValueError: if the number of encoded channels is odd and therefore
            cannot be split evenly between the two decoders.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_shape: Sequence[int],
        out_channels: int,
        out_channels2: int,
        channels: Sequence[int],
        strides: Sequence[int],
        kernel_size: Union[Sequence[int], int] = 3,
        up_kernel_size: Union[Sequence[int], int] = 3,
        num_res_units: int = 0,
        inter_channels: Optional[list] = None,
        inter_dilations: Optional[list] = None,
        num_inter_units: int = 2,
        act: Optional[Union[Tuple, str]] = Act.PRELU,
        norm: Union[Tuple, str] = Norm.INSTANCE,
        dropout: Optional[Union[Tuple, str, float]] = None,
        bias: bool = True,
    ) -> None:

        # First entry of in_shape is the channel count, the rest is spatial.
        self.in_channels, *self.in_shape = in_shape
        self.final_size = np.asarray(self.in_shape, dtype=int)

        super().__init__(
            spatial_dims,
            self.in_channels,
            out_channels,
            channels,
            strides,
            kernel_size,
            up_kernel_size,
            num_res_units,
            inter_channels,
            inter_dilations,
            num_inter_units,
            act,
            norm,
            dropout,
            bias,
        )

        # Track the spatial size after each strided encoder layer.
        padding = same_padding(self.kernel_size)
        for s in strides:
            self.final_size = calculate_out_shape(self.final_size, self.kernel_size, s, padding)  # type: ignore

        # forward() splits the encoded feature map into two equal halves, so
        # reject configurations where that split would be uneven (they would
        # otherwise only fail later, inside forward()).
        if self.encoded_channels % 2 != 0:
            raise ValueError(
                f"encoded_channels must be even to feed two decoders, got {self.encoded_channels}"
            )
        half_channels = self.encoded_channels // 2

        # Both decoders mirror the encoder channels and differ only in their
        # final output channel count.
        # NOTE(review): the base class presumably still creates its own
        # `decode` module; it is not used by forward() here.
        mirrored_channels = list(channels[-2::-1])
        decode_strides = strides[::-1] or [1]
        self.decode1, _ = self._get_decode_module(half_channels, mirrored_channels + [out_channels], decode_strides)
        self.decode2, _ = self._get_decode_module(half_channels, mirrored_channels + [out_channels2], decode_strides)

    def decode_seg(self, x: torch.Tensor) -> torch.Tensor:
        """Run the second decoder on ``x`` and apply softmax over dim 1."""
        x = self.decode2(x)
        x = torch.softmax(x, dim=1)
        return x

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``x`` and return ``(mri, seg)`` from the two decoders.

        ``seg`` is the raw output of ``decode2``; no softmax is applied here.
        """
        x = self.encode(x)
        feature = self.intermediate(x)
        # Split along the channel axis: first half -> decode1, second -> decode2.
        feature1, feature2 = torch.chunk(feature, 2, dim=1)
        mri = self.decode1(feature1)
        seg = self.decode2(feature2)
        return mri, seg

# Configuration of the two-decoder autoencoder used for training.
# A smaller variant that was tried before: channels=(16,32,32,64),
# inter_channels=(64, 64), inter_dilations=(1, 2).
_ae2dec_kwargs = dict(
    spatial_dims=3,
    in_shape=im_shape,
    out_channels=1,                  # first decoder: MRI
    out_channels2=4,                 # second decoder: segmentation
    channels=(64, 128, 256, 512),
    strides=(2, 2, 2, 2),
    inter_channels=(512, 512 * 2),
    inter_dilations=(1, 1),
    num_inter_units=2,
)
model_AE2Dec = AE2DecoderV1(**_ae2dec_kwargs)


**Training Function**
---

In [11]:
# The loss *instances* get their own names. Rebinding `L1Loss` / `DiceLoss`
# (the imported classes) to instances made this cell fail when run a second
# time, because the "class" being called was by then already an instance.
l1_criterion = torch.nn.L1Loss(reduction='sum')
dice_criterion = DiceLoss(to_onehot_y=True, softmax=True)


def train(model, max_epochs, learning_rate, modelname, epoch_offset=25):
    """Train ``model`` on ``train_MRI_loader`` with a weighted L1 + Dice loss.

    Relies on module-level state: ``device``, ``train_MRI_loader``,
    ``optimizer``, ``scheduler``, ``val_interval`` and ``savedir``.

    Args:
        model: network returning ``(recon_mri, recon_seg)`` for an input batch.
        max_epochs: number of epochs to run.
        learning_rate: not used inside this function (the module-level
            ``optimizer`` already carries the learning rate); kept so existing
            calls stay valid.
        modelname: name stem used for checkpoint files.
        epoch_offset: value added to the epoch number in checkpoint file
            names. The default 25 reproduces the previously hard-coded offset.

    Returns:
        ``(model, avg_train_losses, avg_train_dice_losses, avg_train_l1_losses)``
        where each list holds one mean value per epoch.

    Raises:
        RuntimeError: if the loader yields no batch in an epoch.
    """
    model.to(device)
    avg_train_losses = []
    avg_train_dice_losses = []
    avg_train_l1_losses = []
    t = trange(max_epochs, leave=True, desc="step: 0,  epoch: 0,   average train loss: ?, test loss: ?")
    bestloss = float('inf')

    for epoch in t:
        model.train()
        dice_losses, l1_losses, epoch_losses = [], [], []

        for step, batch_data in enumerate(train_MRI_loader, start=1):
            seg = batch_data['seg'][tio.DATA].to(device).float()
            # The network reconstructs its own input, so input == target.
            target = inputs = batch_data['mri'][tio.DATA].to(device).float()
            optimizer.zero_grad()

            recon_mri, recon_seg = model(inputs)

            diceLoss = 100 * dice_criterion(recon_seg, seg)
            # Normalize the summed L1 error by the voxel count of one volume
            # (176*208*176 with the current CropOrPad setting).
            voxels_per_volume = target[0, 0].numel()
            L1_Loss = 10000 * l1_criterion(recon_mri, target) / voxels_per_volume
            loss = L1_Loss + diceLoss
            loss.backward()

            optimizer.step()

            dice_losses.append(diceLoss.item())
            l1_losses.append(L1_Loss.item())
            epoch_losses.append(loss.item())
            t.set_description(f"step: {step}, epoch: {epoch + 1}")

        if not epoch_losses:
            # np.mean([]) would only warn and return NaN; stop with a clear error.
            raise RuntimeError("train_MRI_loader yielded no batches; nothing to train on")

        scheduler.step()
        epoch_loss = np.mean(epoch_losses)
        dice_loss = np.mean(dice_losses)
        l1_loss = np.mean(l1_losses)
        avg_train_losses.append(epoch_loss)
        avg_train_dice_losses.append(dice_loss)
        avg_train_l1_losses.append(l1_loss)

        # Every `val_interval` epochs, save a checkpoint if the mean training
        # loss improved on the best value saved so far.
        if (epoch + 1) % val_interval == 0 and epoch_loss < bestloss:
            bestloss = epoch_loss
            AE_model_name = savedir+modelname +f"_mri_seg_epoch{epoch_offset+epoch+1}_diceloss{dice_loss:.2f}_l1loss{l1_loss:.2f}.pth"
            print(AE_model_name)
            torch.save(model.state_dict(), AE_model_name)

        t.set_postfix(avg_train_losses=avg_train_losses[-1], avg_train_dice_losses=avg_train_dice_losses[-1], avg_train_l1_losses=avg_train_l1_losses[-1])

    return model, avg_train_losses, avg_train_dice_losses, avg_train_l1_losses

**Training Process**
---

In [13]:
max_epochs = 10
learning_rate = 1e-4
# NOTE(review): `savedir` was derived from the earlier `pretrained` value in
# the global settings; changing the flag here does not update `savedir`.
pretrained = False
preAE_model = './models/XXX.pth'
model = model_AE2Dec
if pretrained:
    print('loading: '+preAE_model)
    # map_location lets a checkpoint saved on a GPU be loaded on whatever
    # device this run uses (including CPU-only machines).
    model.load_state_dict(torch.load(preAE_model, map_location=device), strict=True)

# train() picks up `optimizer` and `scheduler` from module scope.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
model, avg_train_losses, avg_train_dice_losses, avg_train_l1_losses = train(model, max_epochs, learning_rate, modelname+'_trainsize'+str(train_size))

# Save the final weights regardless of the periodic best-loss checkpoints.
# NOTE(review): `modelname` is extended in place, so re-running this cell
# appends the suffix again.
modelname = modelname+'_mri_l1_seg_dice'+'_epoch'+str(max_epochs)
AE_model_name = savedir+modelname+'.pth'
print(AE_model_name)
torch.save(model.state_dict(), AE_model_name)
print("Training Finish!")

step: 4, epoch: 1:   0%|          | 0/10 [00:31<?, ?it/s]                                        


KeyboardInterrupt: 