# TransUNet feature merging with different fusion methods, pretrained on the new dataset

In [None]:
import os
import io, os,sys,types
# Number GPUs by PCI bus ID and expose only device 2; set before torch is imported below.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "2"})
from torch import autograd
from PIL import Image
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim
from torch.utils.data import Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter

from IPython.display import clear_output

import torch.utils.data

import random
import scipy.io

import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from utils import self_attention,Fusion_block,CCA
import seaborn as sns
from sklearn.metrics import confusion_matrix

from sklearn.metrics import roc_auc_score
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import average_precision_score


from sklearn.metrics import classification_report

import optuna
import shutil
import pickle

import torch.optim
import torch.utils.data
import pretrainedmodels
from vit_pytorch import ViT
from pytorch_pretrained_vit import ViT
from networks import vit_seg_configs

from networks.model import get_Model

# from lion_pytorch import Lion
import timm
from tqdm import tqdm

import nni
from nni.experiment import Experiment


# Run on the (single visible) GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import warnings

# Silence every warning category for the rest of the session to keep notebook output short.
warnings.simplefilter("ignore")

### Model information

In [None]:
param = {
    # Dataset roots (placeholders -- replace with real paths).
    'traindata_dir': '/path/to/train_img_dir',
    'valdata_dir': '/path/to/val_img_dir',
    'testdata_dir': '/path/to/test_img_dir',
    # One sub-directory name per view, joined onto the dataset roots; length must equal 'num_view'.
    'views_img_dir_list': [
        'view_1_img_dir_name',
        'view_2_img_dir_name',
        'view_3_img_dir_name',
        'view_4_img_dir_name',
        'view_5_img_dir_name',
    ],
    'modal_dir': '/path/to/model_parameters/',
    'log_dir': '/path/to/training_log/',

    'probs_loss_weight': 2,
    # Mass construction mode for DSF: 'ebu' or 'discounted'.
    'mass_mode': 'ebu',
    'num_view': 5,
    'num_classes': 2,
    'ccc_loss_beta': 0.25,
    'ccc_loss_so_weight': 2,
    'contrastive_center_loss_lambda': 0.5,
    'contrastive_center_loss_delta': 0.001,
    'center_learning_rate': 0.00001,
    'train_batch_size': 16,
    'val_batch_size': 64,
    'test_batch_size': 32,
    'num_epochs': 200,
    'patient_epoch': 20,
    'l_lr': 0.00002,
    'initial_lr': 0.00001,
    'learning_rate': 0.00005,
    'S_S_GAP': [0.1, 0.1, 0.1, 0.1, 0.1],
    'w_momentum': 0.9,
    'w_weight_decay': 0.00001,
    'workers': 4,
    'seed': 42,
    'hidden_dim': 512,
    'n_layers': 2,
    'dropout': 0.5,
    # Per-class loss weights, one float per class ('num_classes' entries).
    # Placeholder: equal weights -- replace with the real class weights.
    # (torch.tensor cannot be built from strings, so the old string placeholders crashed.)
    'loss-weight': torch.ones(2, dtype=torch.float32),
    'seg-pretrained-weight': '/path/to/seg/model_pretrained_weight/U_Net.pth',
    # Segmentation model passed to get_Model, e.g. 'transUnet', 'AttentionUnet',
    # 'transUnet_feature_fuse' or 'U_Net_mid_output'.
    'seg_model_name': 'U_Net_mid_output',
}

In [None]:
class DSF(nn.Module):
    """Dempster-Shafer fusion of per-view classification logits.

    Each view's logits are converted into belief masses ``b`` over the C singleton
    classes plus an uncertainty mass ``u`` on the whole frame Theta, optionally
    discounted, and all views are then combined at once with Dempster's rule.

    Args:
        mass_mode: how ``(b, u)`` becomes a basic belief assignment --
            ``"ebu"`` uses it as-is, ``"discounted"`` scales beliefs by ``(1 - u)``
            and renormalises.

    Raises:
        ValueError: if ``mass_mode`` is not a supported mode.
    """

    _MASS_MODES = ("ebu", "discounted")

    def __init__(self, mass_mode='ebu'):
        super(DSF, self).__init__()
        # Fail fast: an unknown mode previously made to_mass return None silently.
        if mass_mode not in self._MASS_MODES:
            raise ValueError(
                f"Unsupported mass_mode {mass_mode!r}; expected one of {self._MASS_MODES}"
            )
        self.mass_mode = mass_mode
        self.eps = 1e-8

    def ebu_from_logits(self, logits):
        """
        logits -> evidence(e) -> Dirichlet(alpha) -> EbU: (u, b)
        input: logits (B, c)
        output:
          b: (B, c)
          u: (B, 1)  and  sum(b, -1) + u = 1
        """
        # evidence >= 0
        e = F.softplus(logits)                    # (B, c)
        alpha = e + 1.0                           # (B, c)
        T = alpha.sum(dim=-1, keepdim=True)       # (B, 1)
        c = logits.size(-1)
        u = c / (T + self.eps)                    # (B, 1)
        b = (alpha - 1.0) / (T + self.eps)        # (B, c)
        # Clip away exact 0/1 so later products and divisions stay finite.
        u = torch.clamp(u, self.eps, 1.0 - self.eps)
        b = torch.clamp(b, self.eps, 1.0 - self.eps)
        # Renormalise so beliefs and uncertainty sum to 1 again after clipping.
        s = b.sum(dim=-1, keepdim=True) + u
        b = b / (s + self.eps)                    # (B, c)
        u = u / (s + self.eps)                    # (B, 1)
        return b, u

    def to_mass(self, b, u):
        """
        Convert (u, b) to a valid BBA: (m_single, m_theta)
        - "ebu":         m_i = b_i, m_Theta = u
        - "discounted":  first m_i ~= (1-u)*b_i, m_Theta ~= u, then normalize by Z = (1-u+u^2)

        Raises ValueError for any other ``self.mass_mode``.
        """
        if self.mass_mode == "ebu":
            return b, u
        if self.mass_mode == "discounted":
            m_single = (1 - u) * b
            m_theta = u
            m_single = torch.clamp(m_single, self.eps, 1.0 - self.eps)
            m_theta = torch.clamp(m_theta, self.eps, 1.0 - self.eps)
            s = m_single.sum(dim=-1, keepdim=True) + m_theta
            m_single = m_single / (s + self.eps)
            m_theta = m_theta / (s + self.eps)
            return m_single, m_theta
        raise ValueError(
            f"Unsupported mass_mode {self.mass_mode!r}; expected one of {self._MASS_MODES}"
        )

    def ds_fuse(self, b_list, u_list):
        """
        Multi-source DS fusion (all sources enter at once):
          U = prod_n u_n
          M_i = prod_n (u_n + b^{(n)}_i)
          Q = sum_i M_i - (C-1) * U
          m*(Theta) = U / Q
          m*({H_i}) = (M_i - U) / Q
        return: m_star (singleton-class masses, BxC), m_theta (Bx1)
        """
        # stack to (V,B,C) / (V,B,1)
        m_stack = torch.stack(b_list, dim=0)        # (V,B,C)
        u_stack = torch.stack(u_list, dim=0)        # (V,B,1)
        # U = prod u_n
        U = u_stack.prod(dim=0)                     # (B,1)

        # M_i = prod (u_n + b^{(n)}_i)
        M = (u_stack + m_stack).prod(dim=0)         # (B,C)

        # Q = sum_i M_i - (C-1) * U  (total non-conflicting mass)
        C = M.shape[-1]
        Q = M.sum(dim=-1, keepdim=True) - (C - 1) * U   # (B,1)
        Q = torch.clamp(Q, min=self.eps)

        m_theta = U / Q                             # (B,1)
        m_star = (M - U) / Q                        # (B,C)

        # Clip and renormalise so the fused BBA sums to 1.
        m_star = torch.clamp(m_star, self.eps, 1.0 - self.eps)
        m_theta = torch.clamp(m_theta, self.eps, 1.0 - self.eps)

        Z = m_star.sum(dim=-1, keepdim=True) + m_theta
        m_star = m_star / (Z + self.eps)
        m_theta = m_theta / (Z + self.eps)
        return m_star, m_theta

    def forward(self, x):
        """
        input:
          x: iterable of per-view logits tensors, each (B, C)
        output:
          probs: (B, C) with p_i = m*({H_i}); the fused uncertainty mass m*(Theta)
                 is not returned, so rows sum to less than 1.
        """
        b_list = []
        u_list = []
        for logits in x:
            b, u = self.to_mass(*self.ebu_from_logits(logits))
            b_list.append(b)
            u_list.append(u)

        probs, _ = self.ds_fuse(b_list, u_list)
        return probs


In [None]:
class EyeMVNet(nn.Module):
    """Multi-view classifier with a raw-image and a segmentation-guided ConvNeXt branch.

    For each view, ``seg_Net`` yields a segmentation output plus four intermediate
    feature maps. Channel 1 of the segmentation output masks the input image for the
    "seg" ConvNeXt branch. After every stage of that branch, the matching U-Net feature
    map gates the stage output; each map is first downsampled by a stride-2 conv and
    passed through a sigmoid. The raw- and seg-branch embeddings are concatenated and
    projected by a shared layer. Two per-view MLP branches (``view_mlps_s`` /
    ``view_mlps_o``) are applied and their outputs concatenated. The result is
    projected, scaled by a learned per-view gate in (0, 1), and classified per view.
    The per-view logits are fused with ``DSF``.

    Args:
        convNeXt_raw: backbone for the raw image; must expose ``stem``, ``stages[0..3]``
            and ``head.fc`` (its ``head.fc`` is replaced in place).
        convNeXt_seg: backbone for the masked image, same requirements.
        seg_Net: callable returning ``(seg, f0, f1, f2, f3)`` for an input batch.
        mass_mode: mass construction mode forwarded to ``DSF``.
        num_view: number of views ``forward`` expects.
        num_class: number of classes predicted per view.
        hidden_dim: width of the shared projection and per-view MLPs.
        n_layers: number of (Linear, GELU) pairs in each per-view MLP.
        dropout: accepted for API compatibility; currently unused (no Dropout layers).
    """

    def __init__(self, convNeXt_raw, convNeXt_seg, seg_Net, mass_mode='ebu', num_view=5,
                 num_class=2, hidden_dim=512, n_layers=2, dropout=0.5):
        super(EyeMVNet, self).__init__()

        # NOTE: attribute names and creation order are kept stable so saved state_dicts
        # load unchanged and seeded initialisation is reproducible.
        self.num_view = num_view

        self.seg_Net = seg_Net
        # Replace both classifier heads with 1024->1024 projections (mutates the passed-in backbones).
        convNeXt_raw.head.fc = nn.Linear(1024, 1024)
        convNeXt_seg.head.fc = nn.Linear(1024, 1024)
        self.convNeXt_raw_stem = convNeXt_raw.stem
        self.convNeXt_seg_stem = convNeXt_seg.stem

        # Stride-2 convs mapping each U-Net feature map onto the matching ConvNeXt stage.
        # Channel widths assume ConvNeXt-Base stages (128/256/512/1024) and U-Net maps at
        # twice the stage resolution -- TODO confirm against the seg network in use.
        self.Unet_stage_0_feature_transform = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.Unet_stage_1_feature_transform = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.Unet_stage_2_feature_transform = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=2, padding=1)
        self.Unet_stage_3_feature_transform = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=2, padding=1)

        self.convNeXt_seg_stage0 = convNeXt_seg.stages[0]
        self.convNeXt_seg_stage1 = convNeXt_seg.stages[1]
        self.convNeXt_seg_stage2 = convNeXt_seg.stages[2]
        self.convNeXt_seg_stage3 = convNeXt_seg.stages[3]

        self.convNeXt_raw_stage0 = convNeXt_raw.stages[0]
        self.convNeXt_raw_stage1 = convNeXt_raw.stages[1]
        self.convNeXt_raw_stage2 = convNeXt_raw.stages[2]
        self.convNeXt_raw_stage3 = convNeXt_raw.stages[3]

        self.convNeXt_raw_head = convNeXt_raw.head
        self.convNeXt_seg_head = convNeXt_seg.head

        # Per-branch projection 1024 -> 512 -> 512 (the "relu" names are historical; these are GELUs).
        self.convNeXt_raw_linear1 = nn.Linear(1024, 512)
        self.convNeXt_raw_relu1 = nn.GELU()
        self.convNeXt_raw_linear2 = nn.Linear(512, 512)
        self.convNeXt_raw_relu2 = nn.GELU()

        self.convNeXt_seg_linear1 = nn.Linear(1024, 512)
        self.convNeXt_seg_relu1 = nn.GELU()
        self.convNeXt_seg_linear2 = nn.Linear(512, 512)
        self.convNeXt_seg_relu2 = nn.GELU()

        # Shared projection of the concatenated raw (512) + seg (512) embedding.
        self.shared_linear1 = nn.Linear(1024, hidden_dim)
        self.shared_relu1 = nn.GELU()

        self.view_mlps_s = nn.ModuleList(
            [nn.Sequential(*self._mlp_layers(hidden_dim, n_layers)) for _ in range(num_view)]
        )
        self.view_mlps_o = nn.ModuleList(
            [nn.Sequential(*self._mlp_layers(hidden_dim, n_layers)) for _ in range(num_view)]
        )
        # Per-view scalar gate in (0, 1) computed from the view's projected feature.
        self.view_weights_mlps = nn.ModuleList(
            [
                nn.Sequential(
                    *self._mlp_layers(hidden_dim, n_layers),
                    nn.Linear(hidden_dim, 1),
                    nn.Sigmoid(),
                )
                for _ in range(num_view)
            ]
        )
        # Projects the concatenated s/o features (2 * hidden_dim) back to hidden_dim.
        self.views_last_linear_layers = nn.ModuleList(
            [nn.Sequential(nn.Linear(2 * hidden_dim, hidden_dim), nn.GELU()) for _ in range(num_view)]
        )
        self.views_cls_layers = nn.ModuleList(
            [nn.Linear(hidden_dim, num_class) for _ in range(num_view)]
        )

        self.DSF = DSF(mass_mode=mass_mode)

    @staticmethod
    def _mlp_layers(hidden_dim, n_layers):
        """Return ``n_layers`` flat (Linear(hidden_dim, hidden_dim), GELU) pairs."""
        layers = []
        for _ in range(n_layers):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.GELU())
        return layers

    def _encode_view(self, x):
        """Encode one view batch into a (B, hidden_dim) embedding via both ConvNeXt branches."""
        seg, unet_feature_map0, unet_feature_map1, unet_feature_map2, unet_feature_map3 = self.seg_Net(x)

        # Channel 1 of the segmentation output, repeated over the 3 input channels, masks the image.
        seg_mask = seg[:, 1, :, :].unsqueeze(1).repeat(1, 3, 1, 1)
        seg_raw = x * seg_mask

        raw_feature = self.convNeXt_raw_stem(x)
        seg_feature = self.convNeXt_seg_stem(seg_raw)

        stages = zip(
            (self.convNeXt_raw_stage0, self.convNeXt_raw_stage1, self.convNeXt_raw_stage2, self.convNeXt_raw_stage3),
            (self.convNeXt_seg_stage0, self.convNeXt_seg_stage1, self.convNeXt_seg_stage2, self.convNeXt_seg_stage3),
            (self.Unet_stage_0_feature_transform, self.Unet_stage_1_feature_transform,
             self.Unet_stage_2_feature_transform, self.Unet_stage_3_feature_transform),
            (unet_feature_map0, unet_feature_map1, unet_feature_map2, unet_feature_map3),
        )
        for raw_stage, seg_stage, unet_transform, unet_map in stages:
            raw_feature = raw_stage(raw_feature)
            # Gate each seg-branch stage output with the matching (downsampled) U-Net features.
            seg_feature = torch.sigmoid(unet_transform(unet_map)) * seg_stage(seg_feature)

        raw_output = self.convNeXt_raw_head(raw_feature)
        seg_output = self.convNeXt_seg_head(seg_feature)

        raw_output = self.convNeXt_raw_relu1(self.convNeXt_raw_linear1(raw_output))
        raw_output = self.convNeXt_raw_relu2(self.convNeXt_raw_linear2(raw_output))

        seg_output = self.convNeXt_seg_relu1(self.convNeXt_seg_linear1(seg_output))
        seg_output = self.convNeXt_seg_relu2(self.convNeXt_seg_linear2(seg_output))

        output = torch.cat((raw_output, seg_output), 1)
        return self.shared_relu1(self.shared_linear1(output))

    def forward(self, multi_view_images):
        """Classify a batch of multi-view samples.

        Args:
            multi_view_images: sequence of ``num_view`` image batches, one per view
                (presumably each (B, 3, H, W) -- TODO confirm against the data loader).

        Returns:
            Tuple ``(probs, final_view_output, shared_net_output, s_net_output, o_net_output)``:
            fused DSF masses (B, num_class); per-view logits; per-view gated embeddings
            (B, hidden_dim); per-view outputs of the ``s`` and ``o`` MLP branches.

        Raises:
            ValueError: if the number of views differs from ``num_view``.
        """
        if len(multi_view_images) != self.num_view:
            # Previously: too few views -> opaque IndexError, too many -> extras silently dropped.
            raise ValueError(
                f"expected {self.num_view} views, got {len(multi_view_images)}"
            )

        view_features = [self._encode_view(x) for x in multi_view_images]

        s_net_output = [mlp(f) for mlp, f in zip(self.view_mlps_s, view_features)]
        o_net_output = [mlp(f) for mlp, f in zip(self.view_mlps_o, view_features)]

        shared_net_output = []
        final_view_output = []
        for s_out, o_out, last_linear, weight_mlp, cls_layer in zip(
            s_net_output,
            o_net_output,
            self.views_last_linear_layers,
            self.view_weights_mlps,
            self.views_cls_layers,
        ):
            projected = last_linear(torch.cat((s_out, o_out), 1))
            gated = projected * weight_mlp(projected)
            shared_net_output.append(gated)
            final_view_output.append(cls_layer(gated))

        probs = self.DSF(final_view_output)

        return probs, final_view_output, shared_net_output, s_net_output, o_net_output
        
        

In [None]:
def specificity_score(y_true, y_pred, average='macro'):
    """
    Compute the macro-averaged specificity (true-negative rate) of a multi-class problem.

    For every class i (treated as positive, all others negative),
    specificity_i = TN_i / (TN_i + FP_i); the macro average is the unweighted mean
    over all classes present in ``y_true`` or ``y_pred``.

    Parameters:
    y_true: ground-truth labels
    y_pred: predicted labels
    average: only 'macro' is supported

    Returns:
    macro_specificity: macro-averaged specificity (float)

    Raises:
    ValueError: if ``average`` is not 'macro'
    """
    # Validate before doing any work (the old check sat inside the loop body).
    if average != 'macro':
        raise ValueError("Only 'macro' are supported.")

    cm = confusion_matrix(y_true, y_pred)
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - tp - fp - fn

    negatives = tn + fp
    # A class with no negative samples would give 0/0; report 0.0 instead of NaN.
    specificity = np.divide(
        tn, negatives, out=np.zeros(len(tp), dtype=float), where=negatives > 0
    )
    # Average over ALL classes (previously the function returned after the first class).
    return float(np.mean(specificity))

In [None]:
def test(model, loader, test_prob_dir, test_label_dir):
    """Evaluate ``model`` on ``loader``, print the metrics and return them.

    Args:
        model: network whose forward returns a tuple whose first element is the
            (B, C) per-class score matrix (EyeMVNet returns its fused DSF masses there).
        loader: DataLoader yielding ``(imgs, y)``; ``imgs`` is either a tensor or a
            list of per-view tensors, ``y`` holds integer class labels.
        test_prob_dir: path of a text file that receives the (N, C) score matrix.
        test_label_dir: path of a text file that receives the N ground-truth labels.

    Returns:
        (AUC, AUPRC, recall, specificity, F1, precision, accuracy), macro-averaged
        where applicable.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    # Local import: label_binarize was previously used without ever being imported.
    from sklearn.preprocessing import label_binarize

    model.eval()
    label_batches, prob_batches, pred_batches = [], [], []
    with torch.no_grad():
        for imgs, y in loader:
            if torch.is_tensor(imgs):
                imgs_var = imgs.to(device)
            else:
                # Default collate turns each sample's list of views into a list of batched views.
                imgs_var = [img.to(device) for img in imgs]
            scores = model(imgs_var)[0].cpu()
            label_batches.append(y.cpu().numpy())
            prob_batches.append(scores.numpy())
            pred_batches.append(scores.argmax(dim=1).numpy())

    if not label_batches:
        raise ValueError("loader yielded no batches; cannot compute test metrics")

    # Concatenate once instead of growing arrays per batch (np.append/vstack is quadratic).
    y_test = np.concatenate(label_batches)
    y_probs = np.concatenate(prob_batches, axis=0)
    y_pred = np.concatenate(pred_batches)

    num_samples = len(y_test)
    num_correct = int((y_pred == y_test).sum())
    acc = num_correct / num_samples

    print('Test accuracy: {:.2f}% ({}/{})'.format(
        100. * acc,
        num_correct,
        num_samples,
        ))

    num_classes = y_probs.shape[1]
    if num_classes == 2:
        # Binary: sklearn expects a 1-D score for the positive class.
        y_true_rank, y_score = y_test, y_probs[:, 1]
    else:
        # Multi-class: macro one-vs-rest over binarized labels. DSF scores exclude the
        # uncertainty mass and need not sum to 1, which multi_class='ovr' would reject.
        y_true_rank = label_binarize(y_test, classes=np.arange(num_classes))
        y_score = y_probs
    AUC_score = roc_auc_score(y_true_rank, y_score, average='macro')
    AUPRC_score = average_precision_score(y_true_rank, y_score, average='macro')
    f1 = f1_score(y_test, y_pred, average='macro')
    accuracy = accuracy_score(y_test, y_pred)
    recall = recall_score(y_test, y_pred, average='macro')
    precision = precision_score(y_test, y_pred, average='macro')
    specificity = specificity_score(y_test, y_pred, average='macro')

    print('AUC:')
    print(AUC_score)
    print('AUPRC:')
    print(AUPRC_score)
    print('recall:')
    print(recall)
    print('specificity:')
    print(specificity)
    print('F1:')
    print(f1)
    print('precision:')
    print(precision)
    print('Accuracy:')
    print(accuracy)

    # Persist raw scores and labels to the paths supplied by the caller (previously ignored).
    np.savetxt(test_prob_dir, y_probs)
    np.savetxt(test_label_dir, y_test, fmt='%d')

    return AUC_score, AUPRC_score, recall, specificity, f1, precision, accuracy

### Model names: transUnet: 'transUnet', transUnet_feature_fuse: 'transUnet_feature_fuse', AttentionUnet: 'AttentionUnet', duckNet: 'duckNet', Unet: 'Unet'


In [None]:
class MultiviewDataset(Dataset):
    """Multi-view image dataset organised as ``<view_dir>/<class_dir>/<image file>``.

    Samples are enumerated from the first view directory. For each other view the
    matching image is ``<view_dir>/<class_dir>/<id>.png``, where ``<id>`` is the part of
    the first-view file name before the first ``_``. The label is the integer after the
    first ``_`` in the class directory name (e.g. ``class_1`` -> 1).

    Args:
        views_img_dir_list: one root directory per view; the first one drives sampling.
        transform: optional callable applied to every view image.
        target_transform: optional callable applied to the label tensor.
    """

    def __init__(self, views_img_dir_list, transform=None, target_transform=None):
        self.views_img_dir_list = views_img_dir_list
        self.transform = transform
        self.target_transform = target_transform
        first_view_dir = self.views_img_dir_list[0]
        # Sorted so sample order is deterministic (os.listdir order is arbitrary).
        self.disease_class_list = sorted(os.listdir(first_view_dir))
        self.image_dir_list = []
        for disease_class in self.disease_class_list:
            class_dir = os.path.join(first_view_dir, disease_class)
            # Fixed: previously concatenated an undefined attribute (DRTiD_image_dir_list).
            self.image_dir_list.extend(
                os.path.join(class_dir, name) for name in sorted(os.listdir(class_dir))
            )

    def __len__(self):
        return len(self.image_dir_list)

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(list of per-view images, label tensor)`` for sample ``idx``."""
        first_view_path = self.image_dir_list[idx]
        image_id = os.path.basename(first_view_path).split('_')[0]
        image_class = os.path.basename(os.path.dirname(first_view_path))

        view_paths = [first_view_path] + [
            os.path.join(view_img_dir, image_class, image_id + '.png')
            for view_img_dir in self.views_img_dir_list[1:]
        ]
        image_list = [self._load_rgb(path) for path in view_paths]
        if self.transform:
            # Keep the transformed images (the old loop rebound a local and discarded them).
            image_list = [self.transform(image) for image in image_list]

        label = torch.tensor(int(image_class.split('_')[1]), dtype=torch.long)
        if self.target_transform:
            label = self.target_transform(label)
        return image_list, label

In [None]:
# One test directory per view under the test root (previously range() was called on the list).
all_views_testdata_dir_list = [
    os.path.join(param['testdata_dir'], views_img_dir)
    for views_img_dir in param['views_img_dir_list']
]

test_loader = torch.utils.data.DataLoader(
    MultiviewDataset(
        all_views_testdata_dir_list,
        transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ]),
    ),
    batch_size=param['test_batch_size'],
    shuffle=False,
    num_workers=param['workers'],
    pin_memory=True,
)


seg_model = get_Model(model_name=param['seg_model_name'])

ConvNext_raw = timm.create_model("hf_hub:timm/convnext_base.fb_in22k_ft_in1k", pretrained=True)
ConvNext_seg = timm.create_model("hf_hub:timm/convnext_base.fb_in22k_ft_in1k", pretrained=True)
net = EyeMVNet(
    ConvNext_raw,
    ConvNext_seg,
    seg_model,
    mass_mode=param['mass_mode'],
    num_view=param['num_view'],
    num_class=param['num_classes'],
    hidden_dim=param['hidden_dim'],
    n_layers=param['n_layers'],
    dropout=param['dropout'],
)

# model_dir / log_dir were referenced but never defined; take them from param.
model_dir = param['modal_dir']
log_dir = param['log_dir']
net_dir = os.path.join(model_dir, 'EyeMVNet_model_parameter.pkl')
# Load the trained weights (the checkpoint path was computed but never used before).
# NOTE(review): assumes the file holds a state_dict saved with torch.save(net.state_dict(), ...)
# -- confirm against the training code. torch.load unpickles, so only load trusted checkpoints.
net.load_state_dict(torch.load(net_dir, map_location=device))
net.to(device)
for parameters in net.parameters():
    parameters.requires_grad = False

test_prob_dir = os.path.join(log_dir, 'EyeMVNet_test_prob.txt')
test_label_dir = os.path.join(log_dir, 'EyeMVNet_test_label.txt')
AUC_score, AUPRC_score, recall, specificity, f1, precision, accuracy = test(net, test_loader, test_prob_dir, test_label_dir)
print('AUC_score:', AUC_score)
print('AUPRC_score:', AUPRC_score)
print('recall:', recall)
print('specificity:', specificity)
print('f1:', f1)
print('precision:', precision)
print('accuracy:', accuracy)



