In [None]:
import os
import math
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from torchvision.models import resnet18
import timm
from torch.utils.data import DataLoader
from skimage.io import imread
import sklearn
from sklearn import metrics
from sklearn.metrics import f1_score
from sklearn.utils import class_weight
import pandas as pd
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import random
from early_stopping import EarlyStopping
import os
from airogs_dataset import Airogs
import wandb
import sys
from torchvision.datasets import ImageFolder
import sklearn.metrics
import yaml
# Share tensors between processes via shared-memory files instead of file descriptors
# (typically used to avoid "too many open files" errors with multi-process data loading;
# presumably has little effect while num_workers == 0).
torch.multiprocessing.set_sharing_strategy('file_system')

In [None]:
from skimage.exposure import equalize_adapthist
from skimage.transform import warp_polar

class CLAHE(torch.nn.Module):
    """Contrast Limited Adaptive Histogram Equalization usable inside a torchvision Compose.

    The input (PIL image or ndarray) is assumed to hold values on a 0-255 scale. It is
    rescaled to [0, 1], equalized with skimage's `equalize_adapthist`, and returned as a
    uint8 ndarray (values truncated, not rounded, when scaling back to 0-255).
    """

    def forward(self, img):
        unit_range = np.asarray(img, dtype=np.float64) / 255.0
        equalized = equalize_adapthist(unit_range)
        return (equalized * 255).astype(np.uint8)

class POLAR(torch.nn.Module):
    """Warp an image from Cartesian to polar coordinates with skimage's `warp_polar`.

    The input (PIL image or ndarray) is converted to a float64 ndarray and warped around
    its centre. The radius is half of the largest entry of `image.shape`. That includes
    the channel axis, which only matters for images smaller than their channel count.
    Returns the float ndarray produced by `warp_polar`.
    """

    def polar(self, image):
        """Polar-warp `image`, treating its last axis as the channel axis.

        Assumes an (H, W, C) array — TODO confirm for grayscale inputs.
        """
        radius = max(image.shape) // 2
        try:
            # `channel_axis` replaced the `multichannel` flag in scikit-image 0.19;
            # newer releases no longer accept `multichannel` at all.
            return warp_polar(image, radius=radius, channel_axis=-1)
        except TypeError:
            # Older scikit-image forwards the unknown `channel_axis` kwarg to `warp`,
            # which rejects it with TypeError; use the legacy flag instead.
            return warp_polar(image, radius=radius, multichannel=True)

    def forward(self, image):
        image = np.array(image, dtype=np.float64)
        return self.polar(image)

def set_seed(s):
    """Seed every RNG this notebook relies on and force deterministic cuDNN kernels.

    Args:
        s: integer seed applied to Python's `random`, NumPy's global RNG, and PyTorch
           (CPU generator plus every CUDA device).
    """
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)
    # Give up cuDNN autotuning in exchange for reproducible kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # Only child processes started afterwards see this; the running interpreter's
    # hash seed was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(s)


set_seed(0)

In [None]:
############ CONFIGS ############

num_workers = 0
batch_size = 8

# Locations of the trained checkpoints and of the hospital-partitioned RIM-ONE DL test split.
CHECKPOINT_DIR = '/home/wangqy/gardnet/Checkpoints'
TEST_DIR = '/home/wangqy/gardnet/RIM-ONE_DL_images/partitioned_by_hospital/test_set'


def load_efficientnet_b0(checkpoint_path, num_classes=2):
    """Create a timm EfficientNet-B0 classifier and load weights from a checkpoint file.

    Args:
        checkpoint_path: path to a file saved with a top-level 'state_dict' entry.
        num_classes: size of the classification head; must match the checkpoint.

    Returns:
        The model with the checkpoint weights loaded (parameters still on the CPU).
    """
    model = timm.create_model('efficientnet_b0', num_classes=num_classes)
    # map_location='cpu' loads the checkpoint regardless of the device it was saved from
    # and avoids a transient GPU copy; the model is moved to CUDA at inference time.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    return model


# model_0: trained on original images; model_1: trained on polar-warped images.
model_0 = load_efficientnet_b0(os.path.join(CHECKPOINT_DIR, 'rimonedl_1.pt'))
model_1 = load_efficientnet_b0(os.path.join(CHECKPOINT_DIR, 'rimonedl_2.pt'))

models = [model_0, model_1]

# Preprocessing per model, index-aligned with `models`.
# NOTE(review): this rebinds `transforms`, shadowing the `torchvision.transforms` module
# alias from the imports cell; use the fully-qualified `torchvision.transforms` below.
transforms = [
    torchvision.transforms.Compose([
        CLAHE(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize((256, 256)),
    ]),
    torchvision.transforms.Compose([
        POLAR(),
        CLAHE(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize((256, 256)),
    ]),
]

# Both models are evaluated on the same split. With shuffle=False, both loaders should
# yield samples in the same order, which the ensemble averaging relies on.
path = [TEST_DIR, TEST_DIR]

test_datasets = [ImageFolder(root, transform=tf) for root, tf in zip(path, transforms)]

test_loader = [
    DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    for ds in test_datasets
]

In [None]:
# Per-model accumulators, keyed by the model's index in `models` / `test_loader`:
#   labels[i]      - ground-truth class index of every test sample (one 0-d tensor each)
#   predictions[i] - class index predicted by model i (argmax of logits, one Python int each)
#   probs[i]       - softmax class probabilities from model i (one 1-d CPU tensor per sample)
labels = {i: [] for i in range(len(models))}
predictions = {i: [] for i in range(len(models))}
probs = {i: [] for i in range(len(models))}

assert len(models) == len(test_loader), "each model needs its own test loader"

with torch.no_grad():
    for i in range(len(models)):
        # Module.cuda() moves parameters in place and returns the same module.
        model = models[i].cuda()
        models[i] = model
        model.eval()
        for inp, target in tqdm(test_loader[i]):
            # Iterating a 1-d label tensor yields one 0-d tensor per sample.
            labels[i] += target
            logits = model(inp.cuda())
            # Keep accumulated probabilities on the CPU so they do not hold GPU memory
            # for the whole test set.
            probs[i] += torch.softmax(logits, dim=1).cpu()
            predictions[i] += logits.argmax(dim=1).tolist()

In [None]:
# Convert the per-sample tensors gathered during inference into NumPy arrays, keyed by
# model index. _probs[k] stacks the probability rows; _labels[k] holds the class indices.
_probs = {k: np.asarray([row.cpu().numpy() for row in probs[k]]) for k in (0, 1)}
_labels = {k: np.asarray([lbl.cpu().numpy() for lbl in labels[k]]) for k in (0, 1)}


In [None]:
# Ensemble weights: w_1 for model_0 (original images), w_2 for model_1 (polar images).
w_1 = 1
w_2 = 1
# Weighted mean of the two softmax outputs. Dividing by the weight sum keeps every row a
# probability distribution for any choice of weights; this is identical to the plain mean
# when w_1 == w_2.
avg_probs = (w_1 * _probs[0] + w_2 * _probs[1]) / (w_1 + w_2)

In [None]:
# Ensemble decision: for each sample, the class with the highest averaged probability.
preds = avg_probs.argmax(axis=1)

In [None]:
# Averaging per-sample probabilities across models is only valid if both loaders yielded
# the same samples in the same order, so check that before using model 0's labels as truth.
assert np.array_equal(_labels[0], _labels[1]), "test loaders yielded labels in different orders"
gt = _labels[0]

In [None]:
# Macro-averaged F1 of the ensemble predictions (shown as the cell output).
f1_score(y_true=gt, y_pred=preds, average="macro")

In [None]:
# ROC AUC needs a continuous score, not hard argmax labels, which would collapse the curve
# to a single operating point. For binary targets sklearn expects the score of the greater
# label, i.e. the ensemble probability of class 1.
sklearn.metrics.roc_auc_score(gt, avg_probs[:, 1])

In [None]:
# Confusion matrix of the ensemble: rows are true classes, columns are predicted classes.
# `labels` / `predictions` are per-model dicts, not label arrays, so use gt / preds.
confusion = metrics.confusion_matrix(gt, preds)
print(confusion)