In [None]:
import os
import io, os,sys,types
# Select the GPU through environment variables. These are set ahead of the
# `torch` import below so they are in place before CUDA is initialised.
# PCI_BUS_ID ordering makes the index in CUDA_VISIBLE_DEVICES follow PCI bus order.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "2",
})
from torch import autograd
from PIL import Image
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim
from torch.utils.data import Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter

from IPython.display import clear_output

import torch.utils.data

import random
import scipy.io

import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from utils import self_attention,Fusion_block,CCA
import seaborn as sns
from sklearn.metrics import confusion_matrix

from sklearn.metrics import roc_auc_score
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import average_precision_score


from sklearn.metrics import classification_report

import optuna
import shutil
import pickle

import torch.optim
import torch.utils.data
import pretrainedmodels
from vit_pytorch import ViT
from pytorch_pretrained_vit import ViT
from networks import vit_seg_configs

from networks.model import get_Model

# from lion_pytorch import Lion
import timm
from tqdm import tqdm

import nni
from nni.experiment import Experiment


# Run on the (single visible) CUDA device when one is usable, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import warnings
warnings.filterwarnings("ignore")

### Model information

In [None]:
# Number of target classes; shared by 'num_classes' and the length of 'loss-weight'
# so the two cannot drift apart.
NUM_CLASSES = 2

# Central configuration dictionary read by the cells below.
# All '/path/to/...' values are placeholders and must be replaced before running.
param = {
    # --- locations -------------------------------------------------------
    'traindata_dir': '/path/to/train_img_dir',
    'valdata_dir': '/path/to/val_img_dir',
    'testdata_dir': '/path/to/test_img_dir',
    'modal_dir': '/path/to/model_parameters/',  # key spelling kept as-is; other cells may read it
    'log_dir': '/path/to/training_log/',
    # --- task ------------------------------------------------------------
    'num_classes': NUM_CLASSES,
    'contrastive_center_loss_lambda': 0.5,
    'contrastive_center_loss_delta': 0.001,
    'center_learning_rate': 0.00001,
    # --- batching / schedule ---------------------------------------------
    'train_batch_size': 16,
    'val_batch_size': 64,
    'test_batch_size': 32,
    'num_epochs': 200,
    'patient_epoch': 20,
    # --- optimiser settings ----------------------------------------------
    'l_lr': 0.00002,
    'initial_lr': 0.00001,
    'learning_rate': 0.00005,
    'S_S_GAP': [0.1,0.1,0.1,0.1,0.1],
    'w_momentum': 0.9,
    'w_weight_decay': 0.00001,
    # --- runtime -----------------------------------------------------------
    'workers': 4,
    'seed': 42,
    # --- classifier head (same values as the SegImgNet constructor defaults) ---
    'hidden_dim': 512,
    'n_layers': 2,
    'dropout': 0.5,
    # Per-class loss weights. The previous template value (a list of strings and
    # an Ellipsis) could not be converted to a tensor and raised as soon as this
    # cell ran. Uniform weights are a neutral default; replace them with the real
    # per-class weights (one entry per class) if the loss should be re-balanced.
    'loss-weight': torch.ones(NUM_CLASSES, dtype=torch.float32),
    # --- segmentation network --------------------------------------------
    'seg-pretrained-weight':'/path/to/seg/model_pretrained_weight/U_Net.pth',
    # Other names mentioned in this notebook: 'transUnet', 'AttentionUnet',
    # 'transUnet_feature_fuse' (validity depends on networks.model.get_Model).
    'seg_model_name':'U_Net_mid_output'
}

In [None]:
class SegImgNet(nn.Module):
    """Two-branch ConvNeXt classifier guided by a segmentation network.

    One ConvNeXt branch processes the input image as-is. The other processes
    the image multiplied by channel 1 of the segmentation output, and after
    each of its four stages the features are gated by a sigmoid of a
    transformed feature map returned by the segmentation network. Both branch
    embeddings are concatenated and classified by an MLP.

    Args:
        convNeXt_raw: backbone for the unmasked image. Must expose ``stem``,
            ``stages[0..3]`` and ``head`` (with a replaceable ``head.fc``).
            Its ``head.fc`` is overwritten in place.
        convNeXt_seg: backbone for the masked image; same requirements and the
            same in-place replacement of ``head.fc``.
        seg_Net: module whose forward returns five tensors: the segmentation
            output followed by four feature maps with 128, 256, 512 and 1024
            channels respectively.
        classes: number of output classes.
        hidden_dim: width of the fused MLP.
        n_layers: number of Linear/GELU/Dropout blocks in the fused MLP.
        dropout: dropout probability used inside those blocks.
    """

    def __init__(self, convNeXt_raw, convNeXt_seg, seg_Net, classes = 2, hidden_dim=512, n_layers=2, dropout=0.5):
        super().__init__()
        self.seg_Net = seg_Net

        # Replace each backbone's classification layer with a 1024 -> 1024 projection.
        for backbone in (convNeXt_raw, convNeXt_seg):
            backbone.head.fc = nn.Linear(1024, 1024)

        self.convNeXt_raw_stem = convNeXt_raw.stem
        self.convNeXt_seg_stem = convNeXt_seg.stem

        # One stride-2 3x3 convolution per segmentation feature map; attribute
        # names are kept stable so existing checkpoints keep their keys.
        for stage_idx, width in enumerate((128, 256, 512, 1024)):
            setattr(
                self,
                f"Unet_stage_{stage_idx}_feature_transform",
                nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3, stride=2, padding=1),
            )

        for stage_idx in range(4):
            setattr(self, f"convNeXt_seg_stage{stage_idx}", convNeXt_seg.stages[stage_idx])
        for stage_idx in range(4):
            setattr(self, f"convNeXt_raw_stage{stage_idx}", convNeXt_raw.stages[stage_idx])

        self.convNeXt_raw_head = convNeXt_raw.head
        self.convNeXt_seg_head = convNeXt_seg.head

        # Per-branch projection 1024 -> 512 -> 512 (the *_relu* attributes hold GELU modules).
        for branch in ("raw", "seg"):
            setattr(self, f"convNeXt_{branch}_linear1", nn.Linear(1024, 512))
            setattr(self, f"convNeXt_{branch}_relu1", nn.GELU())
            setattr(self, f"convNeXt_{branch}_linear2", nn.Linear(512, 512))
            setattr(self, f"convNeXt_{branch}_relu2", nn.GELU())

        # Fusion MLP over the concatenated (512 + 512) branch embeddings.
        self.linear1 = nn.Linear(1024, hidden_dim)
        self.relu1 = nn.GELU()

        mlp_blocks = []
        for _ in range(n_layers):
            mlp_blocks += [nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout)]
        self.mlp_layers = nn.Sequential(*mlp_blocks)

        self.linear2 = nn.Linear(hidden_dim, classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self,x):
        """Return ``(probs, logits, features)`` for an image batch ``x``.

        ``features`` is the output of the fusion MLP that feeds the final
        linear layer; ``probs`` is the softmax of ``logits`` over dim 1.
        """
        seg, skip0, skip1, skip2, skip3 = self.seg_Net(x)

        # Channel 1 of the segmentation output, repeated to three channels, masks the input.
        # NOTE(review): presumably the foreground map and x has 3 channels -- confirm.
        mask = seg[:, 1:2, :, :].repeat(1, 3, 1, 1)
        masked_x = x * mask

        raw_feat = self.convNeXt_raw_stem(x)
        seg_feat = self.convNeXt_seg_stem(masked_x)

        # Stage by stage: advance both branches, then gate the masked branch
        # with the sigmoid of the transformed segmentation feature map.
        for stage_idx, skip in enumerate((skip0, skip1, skip2, skip3)):
            raw_feat = getattr(self, f"convNeXt_raw_stage{stage_idx}")(raw_feat)
            seg_feat = getattr(self, f"convNeXt_seg_stage{stage_idx}")(seg_feat)
            gate = torch.sigmoid(getattr(self, f"Unet_stage_{stage_idx}_feature_transform")(skip))
            seg_feat = gate * seg_feat

        raw_vec = self.convNeXt_raw_head(raw_feat)
        seg_vec = self.convNeXt_seg_head(seg_feat)

        raw_vec = self.convNeXt_raw_relu1(self.convNeXt_raw_linear1(raw_vec))
        raw_vec = self.convNeXt_raw_relu2(self.convNeXt_raw_linear2(raw_vec))

        seg_vec = self.convNeXt_seg_relu1(self.convNeXt_seg_linear1(seg_vec))
        seg_vec = self.convNeXt_seg_relu2(self.convNeXt_seg_linear2(seg_vec))

        fused = torch.cat((raw_vec, seg_vec), 1)
        hidden = self.mlp_layers(self.relu1(self.linear1(fused)))
        logits = self.linear2(hidden)

        return self.softmax(logits), logits, hidden
        
        

In [None]:
def test(model, loader, test_prob_dir, test_label_dir):
    """Evaluate a binary classifier on ``loader``, print the metrics and return them.

    The model is put in eval mode and run under ``torch.no_grad()``. Its first
    return value is treated as a ``(batch, classes)`` score tensor; the argmax
    over dim 1 is the predicted label and column 1 is the positive-class score
    used for AUC / AUPRC. The metric calls below use scikit-learn's binary
    defaults, so labels are expected to be 0/1.

    Unlike earlier revisions this function no longer flips ``requires_grad`` on
    the model's parameters: ``no_grad`` already disables autograd for the
    evaluation, and freezing the caller's model was a lasting side effect.

    Args:
        model: callable returning a 3-tuple whose first element holds the scores.
        loader: iterable of ``(inputs, labels)`` tensor batches.
        test_prob_dir: kept for interface compatibility; currently unused
            because the code that exported the scores is disabled.
        test_label_dir: kept for interface compatibility; currently unused.

    Returns:
        ``(AUC, AUPRC, recall, specificity, F1, precision, accuracy)``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    batch_labels, batch_probs, batch_preds = [], [], []

    model.eval()
    with torch.no_grad():
        for x, y in loader:
            # Inputs go to the module-level `device`; results come back to the CPU.
            scores, _, _ = model(x.to(device))
            scores = scores.detach().cpu()
            batch_probs.append(scores.numpy())
            batch_preds.append(scores.argmax(dim=1).numpy())
            batch_labels.append(y.detach().cpu().numpy().reshape(-1))

    if not batch_labels:
        raise ValueError("loader yielded no batches; nothing to evaluate")

    # Concatenate once instead of growing arrays batch by batch.
    y_test = np.concatenate(batch_labels)
    y_pred = np.concatenate(batch_preds)
    # 1-D positive-class scores (column 1 of the per-class scores).
    y_probs = np.concatenate(batch_probs, axis=0)[:, 1]

    # Count what was actually evaluated, as plain ints so the summary line
    # prints "n/m" rather than a tensor repr.
    num_samples = int(y_test.shape[0])
    num_correct = int((y_pred == y_test).sum())
    acc = float(num_correct) / num_samples

    print('Test accuracy: {:.2f}% ({}/{})'.format(
        100.*acc,
        num_correct,
        num_samples,
        ))

    AUC_score = roc_auc_score(y_test, y_probs)
    AUPRC_score = average_precision_score(y_test, y_probs)
    f1 = f1_score(y_test, y_pred)
    accuracy = accuracy_score(y_test, y_pred)
    recall = recall_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred)
    # Fixing the label set guarantees a 2x2 matrix even if one class is
    # missing from the labels or the predictions.
    tn, fp, fn, tp = confusion_matrix(y_test, y_pred, labels=[0, 1]).ravel()
    # Specificity is undefined without negative samples; report NaN in that case.
    specificity = tn / (tn + fp) if (tn + fp) else float('nan')

    print('AUC:')
    print(AUC_score)
    print('AUPRC:')
    print(AUPRC_score)
    print('recall:')
    print(recall)
    print('specificity:')
    print(specificity)
    print('F1:')
    print(f1)
    print('precision:')
    print(precision)
    print('Accuracy:')
    print(accuracy)

    return AUC_score, AUPRC_score, recall, specificity, f1, precision, accuracy

### Model name options: transUnet: 'transUnet', transUnet_feature_fuse: 'transUnet_feature_fuse', AttentionUnet: 'AttentionUnet', duckNet: 'duckNet', Unet: 'Unet'


In [None]:
class CustomImageDataset(Dataset):
    """Image dataset that yields ``(image, class_index)`` pairs from a class-per-folder tree.

    ``torchvision.datasets.ImageFolder`` is used only to enumerate the files and
    their class indices; images are decoded lazily in ``__getitem__``.

    Args:
        root_dir: directory passed to ``ImageFolder`` as ``root``.
        transform: optional callable applied to the decoded PIL image. When it
            is ``None`` the PIL image itself is returned.
    """

    def __init__(self, root_dir, transform=None):
        # List of (path, class_index) tuples as produced by ImageFolder.
        self.data = datasets.ImageFolder(root=root_dir).imgs
        self.imgs = [path for path, _ in self.data]
        self.label = [class_index for _, class_index in self.data]
        self.transform = transform

    def __len__(self):
        """Number of image files found under ``root_dir``."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``."""
        # Decode inside a context manager so the file handle is released right
        # away. Forcing RGB gives every sample three channels regardless of the
        # file's mode (grayscale, RGBA, palette), which is what ImageFolder's
        # default loader does as well.
        with Image.open(self.imgs[idx]) as raw_image:
            img_seg = raw_image.convert("RGB")

        # `transform` defaults to None, so only call it when one was supplied.
        if self.transform is not None:
            img_seg = self.transform(img_seg)

        # Returns the (optionally transformed) image and its class index.
        return img_seg, self.label[idx]

In [None]:

# ---- Test data -------------------------------------------------------------
# Images are resized to 256x256 and converted to tensors; order is kept fixed.
test_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
test_loader = torch.utils.data.DataLoader(
    CustomImageDataset(param['testdata_dir'], test_transform),
    batch_size=param['test_batch_size'],  # was 'val_batch_size'; the test split has its own setting
    shuffle=False,
    num_workers=param['workers'],
    pin_memory=True,
)

# ---- Segmentation network --------------------------------------------------
# The configuration key is 'seg_model_name' (there is no 'model_name' entry).
model = get_Model(model_name=param['seg_model_name'])
# map_location lets weights saved on a GPU load on any machine; the module is
# moved to `device` together with the full network below.
model.load_state_dict(torch.load(param['seg-pretrained-weight'], map_location='cpu'))

# ---- Classifier ------------------------------------------------------------
ConvNext_raw = timm.create_model("hf_hub:timm/convnext_base.fb_in22k_ft_in1k", pretrained=True)
ConvNext_seg = timm.create_model("hf_hub:timm/convnext_base.fb_in22k_ft_in1k", pretrained=True)
net = SegImgNet(
    ConvNext_raw,
    ConvNext_seg,
    model,
    classes=param['num_classes'],
    hidden_dim=param['hidden_dim'],
    n_layers=param['n_layers'],
    dropout=param['dropout'],
)

# Load the trained classifier. Previously `net_dir` was computed but never
# used, so the freshly initialised classification layers were evaluated.
# Assumes the file holds a SegImgNet state_dict (the same convention as the
# segmentation weights above) -- TODO confirm against the training notebook.
net_dir = os.path.join(param['modal_dir'], 'SegImgNet' + '_model_parameter2.pkl')
net.load_state_dict(torch.load(net_dir, map_location='cpu'))

net.to(device)
# Evaluation only: no parameter needs gradients.
net.requires_grad_(False)

# ---- Evaluation ------------------------------------------------------------
test_prob_dir = os.path.join(param['log_dir'], 'SegImgNet_test_prob.txt')
test_label_dir = os.path.join(param['log_dir'], 'SegImgNet_test_label.txt')
AUC_score, AUPRC_score, recall, specificity, f1, precision, accuracy = test(net, test_loader, test_prob_dir, test_label_dir)
print('AUC_score:', AUC_score)
print('AUPRC_score:', AUPRC_score)
print('recall:', recall)
print('specificity:', specificity)
print('f1:', f1)
print('precision:', precision)
print('accuracy:', accuracy)



