# TransUnet feature merging with different methods, pretrained on the new dataset

In [None]:
import os
import io, os,sys,types
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from torch import autograd
from PIL import Image
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim
from torch.utils.data import Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau

from IPython.display import clear_output

import torch.utils.data

import random
import scipy.io

import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import seaborn as sns
from sklearn.metrics import confusion_matrix

from sklearn.metrics import roc_auc_score
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score

from sklearn.metrics import classification_report

import torch.optim
import torch.utils.data

from networks.model import get_Model

# from lion_pytorch import Lion
import timm
from tqdm import tqdm




# Run on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Model information

In [None]:
# Locations used by the rest of the notebook: `root_dir` holds the per-fold
# datasets, `model_dir` receives the trained classifier weights.
root_dir = '/path/to/data/'
model_dir = '/path/to/model_weight/'

# Sub-directory names of the three splits inside every fold directory
# (rename freely; the order must stay train, validation, test).
dataset_categories = [
    'train_set',
    'val_set',
    'test_set',
]

In [None]:
class SegImgNet(nn.Module):
    """Two-branch classifier whose second branch is gated by a segmentation network.

    One backbone ("raw") processes the input image as-is; the other ("seg")
    processes the image multiplied by channel 1 of the segmentation output.
    After every backbone stage the seg-branch feature map is multiplied by the
    sigmoid of a strided convolution of the matching segmentation feature map.
    Both branches are reduced to 512-d vectors, concatenated and classified by
    an MLP.

    Args:
        convNeXt_raw: backbone for the raw image. Must expose ``stem``,
            ``stages[0..3]`` and ``head`` (with a replaceable ``head.fc``).
            NOTE: ``head.fc`` is replaced in place on the object passed in.
        convNeXt_seg: backbone for the masked image, same requirements.
        seg_Net: callable returning a 5-tuple ``(seg, f0, f1, f2, f3)``; the
            feature maps must have 128/256/512/1024 channels respectively
            (fixed by the transform convolutions below).
        num_classes: size of the output layer.
        hidden_dim: width of the fusion MLP.
        n_layers: number of Linear-GELU-Dropout blocks in the fusion MLP.
        dropout: dropout probability used in those blocks.
    """

    def __init__(self, convNeXt_raw, convNeXt_seg, seg_Net, num_classes=2, hidden_dim=512, n_layers=2, dropout=0.5):
        super().__init__()
        self.seg_Net = seg_Net

        # Swap the classifier layer of both backbone heads for a 1024 -> 1024
        # projection (raw first, then seg, so seeded initialisation is stable).
        for backbone in (convNeXt_raw, convNeXt_seg):
            backbone.head.fc = nn.Linear(1024, 1024)

        self.convNeXt_raw_stem = convNeXt_raw.stem
        self.convNeXt_seg_stem = convNeXt_seg.stem

        # One stride-2 3x3 convolution per segmentation feature map; it keeps
        # the channel count and halves the spatial size.
        for stage_idx, channels in enumerate((128, 256, 512, 1024)):
            setattr(
                self,
                'Unet_stage_%d_feature_transform' % stage_idx,
                nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, stride=2, padding=1),
            )

        # Register the four stages of each backbone under flat attribute names
        # (seg branch first, matching the original registration order).
        for stage_idx in range(4):
            setattr(self, 'convNeXt_seg_stage%d' % stage_idx, convNeXt_seg.stages[stage_idx])
        for stage_idx in range(4):
            setattr(self, 'convNeXt_raw_stage%d' % stage_idx, convNeXt_raw.stages[stage_idx])

        self.convNeXt_raw_head = convNeXt_raw.head
        self.convNeXt_seg_head = convNeXt_seg.head

        # Per-branch projection: 1024 -> 512 -> 512 with GELU activations.
        for branch in ('raw', 'seg'):
            setattr(self, 'convNeXt_%s_linear1' % branch, nn.Linear(1024, 512))
            setattr(self, 'convNeXt_%s_relu1' % branch, nn.GELU())
            setattr(self, 'convNeXt_%s_linear2' % branch, nn.Linear(512, 512))
            setattr(self, 'convNeXt_%s_relu2' % branch, nn.GELU())

        # Fusion classifier over the concatenated (512 + 512) branch vectors.
        self.linear1 = nn.Linear(1024, hidden_dim)
        self.relu1 = nn.GELU()

        mlp_blocks = []
        for _ in range(n_layers):
            mlp_blocks += [nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout)]
        self.mlp_layers = nn.Sequential(*mlp_blocks)

        self.linear2 = nn.Linear(hidden_dim, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def _project_branch(self, branch, features):
        """Apply the Linear-GELU-Linear-GELU projection of the named branch."""
        hidden = getattr(self, 'convNeXt_%s_linear1' % branch)(features)
        hidden = getattr(self, 'convNeXt_%s_relu1' % branch)(hidden)
        hidden = getattr(self, 'convNeXt_%s_linear2' % branch)(hidden)
        return getattr(self, 'convNeXt_%s_relu2' % branch)(hidden)

    def forward(self, x):
        """Return ``(probs, logits)`` for the image batch ``x``.

        ``probs`` is the softmax of ``logits`` over dimension 1.
        """
        seg, unet_map0, unet_map1, unet_map2, unet_map3 = self.seg_Net(x)

        # Channel 1 of the segmentation output, repeated three times along the
        # channel axis, is multiplied into the input image.
        mask = seg[:, 1, :, :].unsqueeze(1).repeat(1, 3, 1, 1)
        masked_input = x * mask

        raw_feature = self.convNeXt_raw_stem(x)
        seg_feature = self.convNeXt_seg_stem(masked_input)

        # Per stage: advance both branches, then gate the seg branch with the
        # sigmoid of the transformed segmentation feature map.
        for stage_idx, unet_map in enumerate((unet_map0, unet_map1, unet_map2, unet_map3)):
            raw_feature = getattr(self, 'convNeXt_raw_stage%d' % stage_idx)(raw_feature)
            seg_feature = getattr(self, 'convNeXt_seg_stage%d' % stage_idx)(seg_feature)
            gate = getattr(self, 'Unet_stage_%d_feature_transform' % stage_idx)(unet_map)
            seg_feature = F.sigmoid(gate) * seg_feature

        raw_vector = self.convNeXt_raw_head(raw_feature)
        seg_vector = self.convNeXt_seg_head(seg_feature)

        raw_vector = self._project_branch('raw', raw_vector)
        seg_vector = self._project_branch('seg', seg_vector)

        fused = torch.cat((raw_vector, seg_vector), 1)
        fused = self.relu1(self.linear1(fused))
        fused = self.mlp_layers(fused)
        logits = self.linear2(fused)

        return self.softmax(logits), logits
        
        

In [None]:
# Single-loss training loop with validation-AUC checkpointing and early stopping.
def train(model, loss_fn, optimizer, scheduler, param, loader_train, loader_val, modal_dir):
    """Train ``model`` and keep the weights of the best validation epoch.

    After every epoch the model is scored with the module-level ``validate``
    function. Whenever that score exceeds the best score so far, the model's
    ``state_dict`` is written to ``modal_dir`` and the patience counter is
    reset. Training stops early once the counter exceeds
    ``param['patient_epoch']``.

    Args:
        model: network whose forward pass returns ``(probs, logits)``; the
            logits are passed to ``loss_fn``.
        loss_fn: callable ``loss_fn(logits, targets)`` returning a scalar loss.
        optimizer: optimizer stepped once per batch.
        scheduler: scheduler stepped once per epoch with the best validation
            score seen so far.
        param: dict providing ``'num_epochs'``, ``'S_S_GAP'`` and
            ``'patient_epoch'``.
        loader_train: iterable of ``(inputs, targets)`` training batches.
        loader_val: loader forwarded to ``validate``.
        modal_dir: file path the best ``state_dict`` is saved to.

    Returns:
        Zero-based index of the epoch whose weights were saved last. The value
        is 0 both when epoch 0 was the best one and when no epoch ever scored
        above 0 (in which case nothing was written to ``modal_dir``).
    """
    max_auc = 0
    checkpoint_epoch = 0
    patient_epoch = 0
    num_epochs = param['num_epochs']

    for epoch in range(num_epochs):
        model.train()
        print('Starting epoch %d / %d' % (epoch + 1, num_epochs))

        # Re-enable autograd explicitly in case the caller disabled it.
        with torch.enable_grad():
            for t, (x, y) in enumerate(loader_train):
                x_var, y_var = x.to(device), y.to(device)
                _, scores = model(x_var)
                loss = loss_fn(scores, y_var)

                if (t + 1) % 100 == 0:
                    print('t = %d, loss = %.8f' % (t + 1, loss.item()))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        model.eval()
        epoch_auc = validate(model, loader_val, param['S_S_GAP'])
        patient_epoch += 1

        if epoch_auc > max_auc:
            max_auc = epoch_auc
            # Record which epoch produced the checkpoint so it can be returned.
            checkpoint_epoch = epoch
            torch.save(model.state_dict(), modal_dir)
            patient_epoch = 0
        scheduler.step(max_auc)

        if patient_epoch > param['patient_epoch']:
            print('reach patient epoch')
            break

    return checkpoint_epoch


In [None]:
def validate(model, loader, gap):
    """Score ``model`` on ``loader`` and return its AUC, or 0 if it is unbalanced.

    The first output of the model is used as per-class scores; the predicted
    class is the arg-max over dimension 1 and the score column(s) from index 1
    onwards are passed to ``roc_auc_score``. AUC, recall and specificity are
    printed.

    Args:
        model: network whose forward pass returns a pair; the first element is
            used as class scores.
        loader: iterable of ``(inputs, labels)`` batches.
        gap: maximum tolerated ``|recall - specificity|``.

    Returns:
        0 when ``abs(recall - specificity) >= gap``, otherwise the AUC score.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    with torch.no_grad():
        model.eval()

        # Collect per-batch arrays and concatenate once at the end instead of
        # re-copying the growing arrays on every batch.
        label_batches, prob_batches, pred_batches = [], [], []
        for x, y in loader:
            scores, _ = model(x.to(device))
            _, preds = scores.data.cpu().max(1)
            label_batches.append(y.cpu().detach().numpy())
            prob_batches.append(scores.cpu().detach().numpy())
            pred_batches.append(preds.numpy())

        if not label_batches:
            raise ValueError('validate() received a loader that yields no batches')

        y_test = np.concatenate(label_batches)
        y_probs = np.concatenate(prob_batches, axis=0)
        y_pred = np.concatenate(pred_batches)

        # Drop the class-0 column; what remains is the score passed to the AUC.
        y_probs = y_probs[..., 1:]

        # Unpacking four values assumes a 2x2 (binary) confusion matrix.
        tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
        specificity = tn / (tn + fp)
        AUC_score = roc_auc_score(y_test, y_probs, multi_class='ovr')
        recall = recall_score(y_test, y_pred)
        print('AUC_score:', AUC_score)
        print('recall:', recall)
        print('specificity:', specificity)

        # Reject epochs whose recall and specificity are too far apart.
        return 0 if abs(recall - specificity) >= gap else AUC_score

In [None]:
def test(model, loader, test_prob_dir, test_label_dir):
    """Evaluate ``model`` on ``loader``, save scores/labels and print metrics.

    The first output of the model is used as per-class scores; the predicted
    class is the arg-max over dimension 1. The score column(s) from index 1
    onwards are written to ``test_prob_dir`` and the labels to
    ``test_label_dir`` (both with ``np.savetxt``).

    Unlike a training step, this function does not modify the model beyond
    switching it to eval mode: gradients are disabled with ``torch.no_grad()``
    only, so the parameters' ``requires_grad`` flags are left untouched.

    Args:
        model: network whose forward pass returns a pair; the first element is
            used as class scores.
        loader: data loader of ``(inputs, labels)`` batches; ``loader.dataset``
            must support ``len``.
        test_prob_dir: output path for the saved scores.
        test_label_dir: output path for the saved labels.

    Returns:
        Tuple ``(AUC_score, recall, specificity, f1, precision, accuracy)``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    with torch.no_grad():
        model.eval()

        num_samples = len(loader.dataset)
        num_correct = 0

        # Collect per-batch arrays and concatenate once at the end instead of
        # re-copying the growing arrays on every batch.
        label_batches, prob_batches, pred_batches = [], [], []
        for x, y in loader:
            scores, _ = model(x.to(device))
            _, preds = scores.data.cpu().max(1)
            label_batches.append(y.cpu().detach().numpy())
            prob_batches.append(scores.cpu().detach().numpy())
            pred_batches.append(preds.numpy())
            # Keep the running count as a plain int so it prints as a number.
            num_correct += int((preds == y.cpu()).sum())

        if not label_batches:
            raise ValueError('test() received a loader that yields no batches')

        y_test = np.concatenate(label_batches)
        y_probs = np.concatenate(prob_batches, axis=0)
        y_pred = np.concatenate(pred_batches)

        acc = float(num_correct) / num_samples

        # Drop the class-0 column; what remains is saved and used for the AUC.
        y_probs = y_probs[..., 1:]

        np.savetxt(test_prob_dir, y_probs, fmt='%.10e', delimiter=" ")
        np.savetxt(test_label_dir, y_test, fmt='%.10e', delimiter=" ")

        print('Test accuracy: {:.2f}% ({}/{})'.format(
            100. * acc,
            num_correct,
            num_samples,
            ))

        AUC_score = roc_auc_score(y_test, y_probs, multi_class='ovr')
        f1 = f1_score(y_test, y_pred)
        accuracy = accuracy_score(y_test, y_pred)
        recall = recall_score(y_test, y_pred)
        precision = precision_score(y_test, y_pred)
        # Unpacking four values assumes a 2x2 (binary) confusion matrix.
        tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
        specificity = tn / (tn + fp)

        print('AUC:')
        print(AUC_score)
        print('recall:')
        print(recall)
        print('specificity:')
        print(specificity)
        print('F1:')
        print(f1)
        print('precision:')
        print(precision)
        print('Accuracy:')
        print(accuracy)

    return AUC_score, recall, specificity, f1, precision, accuracy

In [None]:
# Segmentation model name, segmentation weight path, and the hyperparameters
# used to train the classification model.
param = {
    # Segmentation network
    'seg-pretrained-weight': '/path/to/pretrained_seg_model_weight',
    'seg_model_name': 'U_Net_mid_output',

    # Classifier head
    'num_classes': 2,
    'hidden_dim': 512,
    'n_layers': 2,
    'dropout': 0.5,

    # Data loading
    'train_batch_size': 16,
    'val_batch_size': 64,
    'test_batch_size': 64,
    'workers': 4,

    # Optimisation
    'num_epochs': 200,
    'patient_epoch': 20,      # early-stopping patience, in epochs
    'learning_rate': 0.00005,
    'initial_lr': 0.000001,   # passed as the scheduler's min_lr
    'S_S_GAP': 0.1,           # balance sensitivity and specificity: their gap must stay below S_S_GAP
    'w_weight_decay': 0.001,
    'seed': 42,

    # Per-class loss weights. Stored as the tensor itself (not wrapped in a
    # list) because the consumer calls `.to(device)` directly on this value.
    'loss-weight': torch.tensor([1.15, 1.85], dtype=torch.float32),
}

###model name: transUnet:'transUnet', transUnet_feature_fuse:'transUnet_feature_fuse', AttentionUnet:'AttentionUnet', duckNet:'duckNet', Unet:'Unet'


In [None]:
class CustomImageDataset(Dataset):
    """Image-classification dataset backed by an ``ImageFolder`` directory layout.

    The file list and labels are taken from ``datasets.ImageFolder(root_dir).imgs``;
    images are opened lazily in ``__getitem__``.

    Args:
        root_dir: directory handed to ``datasets.ImageFolder``.
        transform: optional callable applied to the loaded PIL image. When it
            is ``None`` the RGB PIL image is returned unchanged.
    """

    def __init__(self, root_dir, transform=None):
        self.data = datasets.ImageFolder(root=root_dir).imgs

        self.imgs = [path for path, _ in self.data]
        self.label = [label for _, label in self.data]
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``."""
        # Open inside a context manager so the file handle is released, and
        # force three channels so grayscale / RGBA / palette files match the
        # 3-channel input the model expects.
        with Image.open(self.imgs[idx]) as img_file:
            img_seg = img_file.convert('RGB')

        if self.transform is not None:
            img_seg = self.transform(img_seg)
        return img_seg, self.label[idx]

In [None]:
# Per-fold test metrics collected across the 5-fold cross-validation run.
fold_AUC = []
fold_F1 = []
fold_Accuracy = []
fold_Recall = []
fold_Precision = []
# fold_Sensitivity = []
fold_Specificity = []
checkpoint_epoch_list = []

# NOTE(review): `normalize` is defined but not part of any transform pipeline
# below, so images reach the network as produced by ToTensor. Confirm whether
# normalisation was intentionally left out before wiring it in.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Training pipeline: resize plus random augmentation.
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomRotation(30),
    transforms.RandomVerticalFlip(0.2),
    transforms.RandomHorizontalFlip(0.2),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
])

# Validation / test pipeline: resize only.
eval_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])


def _make_loader(data_dir, transform, batch_size, shuffle):
    """Build a DataLoader over ``data_dir`` with the shared worker settings."""
    return torch.utils.data.DataLoader(
        CustomImageDataset(data_dir, transform),
        batch_size=batch_size, shuffle=shuffle,
        num_workers=param['workers'], pin_memory=True)


for i in range(5):
    fold_dir = os.path.join(root_dir, 'fold_' + str(i))
    param['trainset_dir'] = os.path.join(fold_dir, dataset_categories[0])
    param['valset_dir'] = os.path.join(fold_dir, dataset_categories[1])
    param['testset_dir'] = os.path.join(fold_dir, dataset_categories[2])

    train_loader = _make_loader(param['trainset_dir'], train_transform, param['train_batch_size'], True)
    val_loader = _make_loader(param['valset_dir'], eval_transform, param['val_batch_size'], False)
    test_loader = _make_loader(param['testset_dir'], eval_transform, param['test_batch_size'], False)

    # Segmentation network with its pretrained weights. `map_location` lets
    # weights saved on a GPU load on whatever `device` resolved to.
    model = get_Model(model_name=param['seg_model_name'])
    model.load_state_dict(torch.load(param['seg-pretrained-weight'], map_location=device))

    ConvNext_raw = timm.create_model("hf_hub:timm/convnext_base.fb_in22k_ft_in1k", pretrained=True)
    ConvNext_seg = timm.create_model("hf_hub:timm/convnext_base.fb_in22k_ft_in1k", pretrained=True)
    net = SegImgNet(ConvNext_raw, ConvNext_seg, model,
                    num_classes=param['num_classes'], hidden_dim=param['hidden_dim'],
                    n_layers=param['n_layers'], dropout=param['dropout'])
    net.to(device)

    # Freeze the segmentation network; only the remaining parameters are optimised.
    for name, parameters in net.named_parameters():
        if "seg_Net" in name:
            parameters.requires_grad = False

    # CrossEntropyLoss expects `weight` to be a 1-D tensor with one entry per class.
    criterion = nn.CrossEntropyLoss(weight=param['loss-weight'].to(device))
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()),
                                 lr=param['learning_rate'], weight_decay=param['w_weight_decay'])
    plateau_scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, threshold=1e-5,
                                          threshold_mode='abs', patience=10, min_lr=param['initial_lr'])

    # Start every fold from a clean checkpoint file.
    net_dir = os.path.join(model_dir, 'SegImgNet' + 'fold_' + str(i) + dataset_categories[0] + '_model_parameter.pkl')
    if os.path.exists(net_dir):
        os.remove(net_dir)

    checkpoint_epoch = train(net, criterion, optimizer, plateau_scheduler, param, train_loader, val_loader, net_dir)
    checkpoint_epoch_list.append(checkpoint_epoch)

    # Evaluate the best checkpoint of this fold on its test split.
    net.load_state_dict(torch.load(net_dir, map_location=device))
    test_prob_dir = os.path.join(fold_dir, 'SegImgNet' + dataset_categories[0] + '_test_prob.txt')
    test_label_dir = os.path.join(fold_dir, 'SegImgNet' + dataset_categories[0] + '_test_label.txt')
    AUC_score, recall, specificity, f1, precision, accuracy = test(net, test_loader, test_prob_dir, test_label_dir)

    fold_AUC.append(AUC_score)
    fold_F1.append(f1)
    fold_Accuracy.append(accuracy)
    fold_Recall.append(recall)
    fold_Precision.append(precision)
    fold_Specificity.append(specificity)


In [None]:
# Sanity check: every metric list should hold one entry per fold.
for fold_values in (fold_AUC, fold_F1, fold_Accuracy, fold_Recall, fold_Precision, fold_Specificity):
    print(len(fold_values))

# Per-fold metrics as arrays so mean / std can be computed below.
AUC_result = np.array(fold_AUC)
F1_result = np.array(fold_F1)
Accuracy_result = np.array(fold_Accuracy)
Recall_result = np.array(fold_Recall)
Precision_result = np.array(fold_Precision)
Specificity_result = np.array(fold_Specificity)

print(checkpoint_epoch_list)

# Metrics in the order they are reported.
report_order = (
    ('AUC', AUC_result),
    ('Recall', Recall_result),
    ('Specificity', Specificity_result),
    ('F1', F1_result),
    ('Precision', Precision_result),
    ('Accuracy', Accuracy_result),
)

# Raw per-fold values.
for metric_name, metric_values in report_order:
    print(metric_name + '_result')
    print(metric_values)

# Mean and standard deviation across folds.
for metric_name, metric_values in report_order:
    print('Avg_' + metric_name + ':')
    print(np.mean(metric_values))
    print('std_' + metric_name + ':')
    print(np.std(metric_values))
