# <a href="https://mingxia.web.unc.edu/" target="_parent"><img src="https://mingxia.web.unc.edu/wp-content/uploads/sites/12411/2020/12/logo_MagicLab-horizontal-4.png" alt="MAGIC Lab"/></a>

# **Downstream classification model finetuning based on pretrained pretext model of BAR**
---

**Loading required libraries**
---

In [10]:
import os
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "5" 
import sys, argparse
import enum
import time
import datetime
import random
import json
import multiprocessing
import os.path as osp
import pandas as pd
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
import pylab as pl
import logging
import shutil
import tempfile
import gzip
from typing import Optional, Sequence, Tuple, Union
from urllib.request import urlretrieve
from PIL import Image

from pathlib import Path
from scipy import stats
from IPython import display
from tqdm import trange, tqdm
#from tqdm.notebook import tqdm

import copy
import pprint
import torchio as tio
import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
import torchsummary
from torch.nn import L1Loss
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR

from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor

from sklearn.metrics import confusion_matrix, multilabel_confusion_matrix, roc_auc_score, matthews_corrcoef
from sklearn.metrics import classification_report
from sklearn.model_selection import StratifiedKFold, train_test_split, KFold
from sklearn.manifold import TSNE
from sklearn import svm

from neuroCombat import neuroCombat


import monai
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import CacheDataset, DataLoader, Dataset, ImageDataset
from monai.networks.nets import VarAutoEncoder,ViTAutoEnc, AutoEncoder, Classifier
from monai.networks.layers.convutils import calculate_out_shape, same_padding
from monai.networks.layers.factories import Act, Norm
from monai.utils import set_determinism, first
from monai.utils.enums import MetricReduction
from monai.metrics import compute_hausdorff_distance, HausdorffDistanceMetric
from monai.losses import ContrastiveLoss, DiceLoss, DiceCELoss
from monai.transforms import (
    ConvertToMultiChannelBasedOnBratsClasses,
    AsDiscrete,
    Activations,

#    AddChannel,
#    Compose,
#    RandRotate90,
#    Resize,
#    ScaleIntensity,
#    EnsureType
    AddChannelD,
    Compose,
    LoadImageD,
    ScaleIntensityD,
    EnsureTypeD,
    LoadImaged,
    Compose,
    CropForegroundd,
    CopyItemsd,
    SpatialPadd,
    EnsureChannelFirstd,
    Spacingd,
    OneOf,
    ScaleIntensityRanged,
    RandSpatialCropSamplesd,
    RandCoarseDropoutd,
    RandCoarseShuffled
)

**Global Setting**
---

In [14]:
torch.manual_seed(0)
set_determinism(seed=0)

pin_memory = torch.cuda.is_available()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
pretrained_path = './pretrained/'
trained_path = './models/'
logdir_path = os.path.normpath('./log/')
if os.path.exists(logdir_path)==False:
    os.mkdir(logdir_path)
if os.path.exists(pretrained_path)==False:
    os.mkdir(pretrained_path)
if os.path.exists(trained_path)==False:
    os.mkdir(trained_path)

modelname = 'AEclf256';#resnet18,resnet34,resnet50,efficientnet-b0,DenseNet121
pretrained = False
downsampled = False
samplespace = 1
if downsampled:
    samplespace = 2 #2,4,8
max_epochs = 2
val_interval = 5
kfold = 5
categories = 2
Combat = False
if pretrained:
    savedir = trained_path+'Pretrained_'
else:
    savedir = trained_path

cpu


**Define Training Transforms**
---

In [12]:
class MDD_Dataset(torch.utils.data.Dataset):
    def __init__(self, images, segs, labels, augment=False):
        subjects = []            
        for (image, seg, label) in zip(images, segs, labels):
            #print(image_path,label)
            subject = tio.Subject(
                mri=tio.ScalarImage(image),
                #rec=tio.ScalarImage(recon),
                #seg=tio.LabelMap(seg),
                #roi=tio.LabelMap(roi),
                labels=int(label),
            )
            subjects.append(subject)
        self.transform()
        if augment:
            self.dataset = tio.SubjectsDataset(subjects, transform=self.aug_transform)
        else:
            self.dataset = tio.SubjectsDataset(subjects, transform=self.preproc_transform)
            
    def transform(self):
        #mni = tio.datasets.Colin27()
        get_foreground = tio.ZNormalization.mean
        ADNI_landmarks = np.array([0., 0., 0., 0., 0., 0., 0., 0., 1.64676642, 26.44374401, 47.65044424, 78.28128123, 100.])#ADNI_iBEAT
        landmarks_dict = {'mri': ADNI_landmarks}
        preprocess = tio.Compose([
            tio.ToCanonical(),
            #tio.CropOrPad((176, 208, 176)),                                              # tight crop around brain
            #tio.RescaleIntensity(percentiles=(0.,99.5), out_min_max=(0, 1.0)),
        ])
        augment = tio.Compose([
            tio.RandomAffine(scales=0.1,degrees=20,translation=5,isotropic=True,center='image'),       # random affine
        ])

        self.aug_transform = tio.Compose([preprocess, augment])
        self.preproc_transform = preprocess


def get_loader(imagepaths,segpaths, labels, batch_size=1, augment=False):
    dataset = MDD_Dataset(images=imagepaths,segs=segpaths, labels=labels, augment=augment)
    #batch_size = 6
    if augment:
        #batch_size = 6
        loader = DataLoader(dataset.dataset,batch_size=batch_size,num_workers=batch_size,shuffle=True,pin_memory=pin_memory,drop_last=False)
    else:
        loader = DataLoader(dataset.dataset,batch_size=batch_size,num_workers=batch_size,shuffle=False,pin_memory=pin_memory)    
    return loader

**Model Defination**
---

In [6]:
class AutoEncoderClassifier(AutoEncoder):
    def __init__(
        self,
        spatial_dims: int,
        in_shape: Sequence[int],
        out_channels: int,
        num_classes: int,
        channels: Sequence[int],
        strides: Sequence[int],
        out_channels2: int = 4,
        kernel_size: Union[Sequence[int], int] = 3,
        up_kernel_size: Union[Sequence[int], int] = 3,
        num_res_units: int = 0,
        inter_channels: Optional[list] = None,
        inter_dilations: Optional[list] = None,
        num_inter_units: int = 2,
        act: Optional[Union[Tuple, str]] = Act.PRELU,
        norm: Union[Tuple, str] = Norm.INSTANCE,
        dropout: Optional[Union[Tuple, str, float]] = None,
        bias: bool = True,
        use_sigmoid: bool = False,
        use_softmax: bool = True,
        latent_size: int = 64,
    ) -> None:

        self.in_channels, *self.in_shape = in_shape
        self.use_sigmoid = use_sigmoid

        self.latent_size = latent_size
        self.final_size = np.asarray(self.in_shape, dtype=int)

        super().__init__(
            spatial_dims,
            self.in_channels,
            out_channels,
            channels,
            strides,
            kernel_size,
            up_kernel_size,
            num_res_units,
            inter_channels,
            inter_dilations,
            num_inter_units,
            act,
            norm,
            dropout,
            bias,
        )

        padding = same_padding(self.kernel_size)

        for s in strides:
            self.final_size = calculate_out_shape(self.final_size, self.kernel_size, s, padding)  # type: ignore

        linear_size = int(np.product(self.final_size)) * self.encoded_channels
        
        self.clf = Classifier(in_shape = (self.channels[-1], *self.final_size), 
                              classes = num_classes, 
                              channels = (256,),
                              strides = (2,), 
                              #num_res_units = 0,
                              norm='INSTANCE', 
                              dropout=None, 
                              last_act=None)
            
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.encode(x)
        x = self.intermediate(x)
        y = self.clf(x)
        return y

**Training Process Defination**
---

In [7]:
L1Loss  = torch.nn.L1Loss(reduction='sum')
MSELoss = torch.nn.MSELoss(reduction='sum')
BCELoss = torch.nn.BCELoss(reduction='sum')
loss_function = torch.nn.CrossEntropyLoss()

dice_loss = DiceLoss(include_background=True ,to_onehot_y=True, softmax=True)

def train(model, max_epochs, learning_rate, savename):
    # Create optimiser
    #optimizer = torch.optim.Adam(model.parameters(), learning_rate, weight_decay=1e-5)
    model.to(device)
    avg_train_losses = []
    avg_train_dice_losses =[]
    avg_train_mse_losses = []
    avg_train_kld_losses = []
    test_losses = []
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    threshhold = 0.6

    
    t = trange(max_epochs, leave=True, desc="step: 0,  epoch: 0,   average train loss: ?, test loss: ?")
    
    for epoch in t:
        model.train()
        mse_losses = []
        dice_losses = []
        kld_losses = []
        epoch_losses = []
        epoch_loss = 0
        mse_loss = 0
        kld_loss = 0
        step = 0
        for batch_data in train_loader:
            step +=1
            inputs = batch_data['mri'][tio.DATA].to(device).float()
            labels = batch_data['labels'].to(device)
            #print(labels)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            
            epoch_losses.append(loss.item())
            t.set_description(f"train step: {step}")
        scheduler.step()
        avg_train_losses.append(np.mean(epoch_losses))

        if True and (epoch+1)%val_interval == 0:
            # Test
            test_pred = []
            test_prob = []
            test_label = []
            test_predict= []
            model.eval()
            test_loss = []
            step = 0
            for test_data in test_loader:
                step +=1
                test_inputs = test_data['mri'][tio.DATA].to(device)
                test_labels = test_data['labels'].to(device)
                with torch.no_grad():
                    test_outputs = model(test_inputs.float())
                test_loss.append(loss_function(test_outputs, test_labels).item())
                if (epoch+1) >= 30:
                    outprob = F.softmax(test_outputs, dim=1)
                    test_prob.append(outprob)
                    test_pred.append(outprob.max(dim=1)[1])
                    test_predict.append(torch.where(outprob>0.5,torch.ones_like(outprob),torch.zeros_like(outprob)))
                    test_label.append(test_labels)
                t.set_description(f"test  step: {step}")
            test_losses.append(np.mean(test_loss))

            if (epoch+1) > 30:
                y_prob = torch.cat(test_prob, dim=0).cpu().detach().numpy()
                #y_pred = torch.cat(pred, dim=0).cpu().detach().numpy()
                y_true = torch.cat(test_label, dim=0).cpu().detach().numpy()
                y_pred =  np.array(torch.cat(test_predict, dim=0).cpu().detach().numpy()[...,1],dtype=int)
                epoch_report = classification_report(y_true, y_pred, output_dict = True,target_names=['CN', 'AD'],zero_division=0)

                metric = epoch_report['accuracy']
                if ((metric > best_metric+0.005) and (metric > threshhold)) or ((epoch+1)%10 == 0):
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    best_model_name = savename+'_epoch%02i_acc%.2f.pth' % (best_metric_epoch,best_metric)
                    torch.save(model.state_dict(), best_model_name)
                    print(f"Best accuracy: {best_metric:.4f} at epoch {best_metric_epoch}")
                if (metric > 0.9):
                    break
        if len(test_losses)>0:
            t.set_postfix(avg_train_losses=avg_train_losses[-1], avg_test_losses=test_losses[-1])
    return model, best_model_name, avg_train_losses, test_losses

**Testing Process Defination**
---

In [8]:
def test(model, test_loader):
    min_v_loss = np.inf
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []


    print("-" * 10)
    model.to(device)
    model.eval()

    pred = []
    prob = []
    label = []
    predict= []

    num_correct = 0.0
    metric_count = 0
    for test_data in tqdm(test_loader):
        test_inputs = test_data['mri'][tio.DATA].to(device)
        test_labels = test_data['labels']
        with torch.no_grad():
            test_outputs = model(test_inputs.float())
        outprob = F.softmax(test_outputs, dim=1)
        prob.append(outprob)
        pred.append(outprob.max(dim=1)[1])
        predict.append(torch.where(outprob>0.5,torch.ones_like(outprob),torch.zeros_like(outprob)))
        label.append(test_labels)

    y_prob = torch.cat(prob, dim=0).cpu().detach().numpy()
    #y_pred = torch.cat(pred, dim=0).cpu().detach().numpy()
    y_true = torch.cat(label, dim=0).cpu().detach().numpy()
    y_pred =  np.array(torch.cat(predict, dim=0).cpu().detach().numpy()[...,1],dtype=int)

    AUC = roc_auc_score(y_true,y_prob[:,1]) *100
    MCC = matthews_corrcoef(y_true, y_pred)*100
    confus_mtrx = confusion_matrix(y_true, y_pred).ravel() #sample_weight=sw ravel:flatten
    #SPE  = confus_mtrx[0]/(confus_mtrx[0]+confus_mtrx[1])
    epoch_report = classification_report(y_true, y_pred, output_dict = True,target_names=['CND', 'CI'],zero_division=0)
    return AUC, MCC, epoch_report, confus_mtrx

**Model Training and Test**
---

In [13]:
im_shape = (1,176,208,176)
test_AUC =[]
test_ACC =[]
test_PRE =[]
test_SEN =[]
test_SPE =[]
test_F1s =[]
test_MCC =[]

finetuned = True
preAEmodelname = 'RecSeg'
savefile = savedir+modelname+'_'+preAEmodelname+'_fintune_cn_ci'
logfile = savefile.replace(trained_path,logdir_path)

out_ch = 1        
if preAEmodelname == 'RecSeg':
    preAE_model = './models/'+'RecSeg_AE512_seg_dice_epoch30_loss7.86'+'.pth'
    out_ch = 4    
if preAEmodelname == 'RecInput':
    preAE_model = './models/'+'RecInput_AE512_mri_l1loss_epoch30_loss170485.29'+'.pth'

for k in range(5):
    model = AutoEncoderClassifier(
        spatial_dims=3,
        in_shape=im_shape,
        num_classes=categories,
        out_channels=out_ch,
        channels=(64,128,256,512),
        strides=(2,2,2,2),
        inter_channels=(512, 512),
        num_inter_units=2,
    )    

    print('Loading pretrained pretext model!')
    if finetuned:
        pre_model = preAE_model
        print('loading pretrained auto encoder:'+pre_model)
        model.load_state_dict(torch.load(pre_model), strict = False)
        model_dict = model.state_dict()
        for name, p in model.named_parameters():
            if name.startswith('encode') or name.startswith('intermediate'): #or name.startswith('decode'):
                p.requires_grad = False

    train_listfile = './data/LLD/LLD_labels_cn_ci_train_balenced_'+str(k)+'.csv'
    test_listfile = './data/LLD/LLD_labels_cn_ci_test_'+str(k)+'.csv'
    train_csv_data = pd.read_csv(train_listfile)  # 读取训练数据
    test_csv_data = pd.read_csv(test_listfile)  # 读取训练数据

    train_mri_path_list = train_csv_data['mripath'].values.tolist()
    train_seg_path_list = train_mri_path_list
    train_label_list    = train_csv_data['labels'].values.tolist()
    test_mri_path_list = test_csv_data['mripath'].values.tolist()
    test_seg_path_list = test_mri_path_list
    test_label_list    = test_csv_data['labels'].values.tolist()

    train_loader = get_loader(train_mri_path_list,train_seg_path_list,train_label_list, batch_size=4, augment=True)
    test_loader  = get_loader(test_mri_path_list,test_seg_path_list,test_label_list, batch_size=2)

    print('Training start!')

    max_epochs = 90
    learning_rate = 1e-4

    optimizer = torch.optim.Adam(model.parameters(), learning_rate)
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
    savename = savefile+'-fold'+str(k)
    model, best_model_name, avg_train_losses, test_losses = train(model, max_epochs, learning_rate, savename)

    print("Training Finish!")

    print('Testing start!')

    AUC, MCC, epoch_report, confus_mtrx = test(model, test_loader)
    #print(epoch_report)
    ACC = epoch_report['accuracy']*100
    SPE = epoch_report['CND']['recall']*100
    PRE = epoch_report['CI']['precision']*100
    SEN = epoch_report['CI']['recall']*100
    F1s = epoch_report['CI']['f1-score']*100
    test_AUC.append(AUC)
    test_ACC.append(ACC)
    test_PRE.append(PRE)
    test_SEN.append(SEN)
    test_SPE.append(SPE)
    test_F1s.append(F1s)
    test_MCC.append(MCC)
    with open(logfile+'.txt', 'a') as f:
        f.writelines('fold'+str(k)+'\n')
        f.writelines(best_model_name+'\n')
        f.writelines(f'AUC:{AUC:.2f}, ACC:{ACC:.2f}, PRE:{PRE:.2f}, SEN:{SEN:.2f}, SPE:{SPE:.2f}, F1S:{F1s:.2f}, MCC:{MCC:.2f}\n')
        
    print(f'AUC:{AUC:.2f}, ACC:{ACC:.2f}, PRE:{PRE:.2f}, SEN:{SEN:.2f}, SPE:{SPE:.2f}, F1S:{F1s:.2f}, MCC:{MCC:.2f}')
    print(f'{AUC:.2f}\t{ACC:.2f}\t{PRE:.2f}\t{SEN:.2f}\t{SPE:.2f}\t{F1s:.2f}\t{MCC:.2f}')

    print('Testing finish!')
mean_AUC =round(np.mean(test_AUC),2)
mean_ACC =round(np.mean(test_ACC),2)
mean_PRE =round(np.mean(test_PRE),2)
mean_SEN =round(np.mean(test_SEN),2)
mean_SPE =round(np.mean(test_SPE),2)
mean_F1s =round(np.mean(test_F1s),2)
mean_MCC =round(np.mean(test_MCC),2)

std_AUC =round(np.std(test_AUC,ddof=1),2)
std_ACC =round(np.std(test_ACC,ddof=1),2)
std_PRE =round(np.std(test_PRE,ddof=1),2)
std_SEN =round(np.std(test_SEN,ddof=1),2)
std_SPE =round(np.std(test_SPE,ddof=1),2)
std_F1s =round(np.std(test_F1s,ddof=1),2)
std_MCC =round(np.std(test_MCC,ddof=1),2)
log1 = f'AUC:{mean_AUC}\xB1{std_AUC}, ACC:{mean_ACC}\xB1{std_ACC}, PRE:{mean_PRE}\xB1{std_PRE}, SEN:{mean_SEN}\xB1{std_SEN}, SPE:{mean_SPE}\xB1{std_SPE}, F1S:{mean_F1s}\xB1{std_F1s}, MCC:{mean_MCC}\xB1{std_MCC}\n'
log2 = f'{mean_AUC}\xB1{std_AUC}\t{mean_ACC}\xB1{std_ACC}\t{mean_PRE}\xB1{std_PRE}\t{mean_SEN}\xB1{std_SEN}\t{mean_SPE}\xB1{std_SPE}\t{mean_F1s}\xB1{std_F1s}\t{mean_MCC}\xB1{std_MCC}\n'
              
log3 = f'&{mean_AUC}\xB1{std_AUC} &{mean_ACC}\xB1{std_ACC} &{mean_SEN}\xB1{std_SEN} &{mean_SPE}\xB1{std_SPE} &{mean_F1s}\xB1{std_F1s}\n'
log4 = f'&{mean_AUC:.1f}({std_AUC:.1f}) &{mean_ACC:.1f}({std_ACC:.1f}) &{mean_SEN:.1f}({std_SEN:.1f}) &{mean_SPE:.1f}({std_SPE:.1f}) &{mean_F1s:.1f}({std_F1s:.1f})\n'
              
with open(logfile+'.txt', 'a') as f:
    f.writelines(log1)
    f.writelines(log2)
    f.writelines(log3)
    f.writelines(log4)

print(log1)

Loading pretrained pretext model!
loading pretrained auto encoder:./models/RecSeg_AE512_seg_dice_epoch30_loss7.86.pth


RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.