In [None]:
import tqdm
import os
import data
import nets
import torch
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import average_precision_score, roc_auc_score, confusion_matrix

# Disease label names, in column order. get_metrics() maps a name to column
# `diseases.index(name) + 1` of both the prediction and the label arrays, so
# the order of this list must match the label CSVs written by this notebook
# (which put 'Disease_Risk' first, followed by these names).
diseases = ['DR', 'ARMD', 'MH', 'DN', 'MYA', 'BRVO', 'TSLN', 'ERM', 'LS', 'MS',
            'CSR', 'ODC', 'CRVO', 'TV', 'AH', 'ODP', 'ODE', 'ST', 'AION', 'PT',
            'RT', 'RS', 'CRS', 'EDN', 'RPEC', 'MHL', 'RP', 'other']

In [None]:
class EnsembleModel(torch.nn.Module):
  """Averages the raw outputs of several models applied to the same input.

  The members are stored in a ``torch.nn.ModuleList`` so that they are
  registered as sub-modules: ``.to()``, ``.eval()``, ``.parameters()`` and
  ``state_dict()`` called on the ensemble reach every member. A plain Python
  list would hide them from ``torch.nn.Module``.

  Args:
    models: the member networks; all must accept the same input and return
      tensors of the same shape.
  """

  def __init__(self, models: list[torch.nn.Module]):
    super().__init__()
    # ModuleList still supports len(), iteration and integer indexing, so
    # `ensemble.models[0]` keeps working.
    self.models = torch.nn.ModuleList(models)

  def forward(self, x):
    """Return the element-wise mean of the members' outputs for ``x``."""
    ys = torch.stack([m(x) for m in self.models]).mean(0)
    return ys

def load_model(ckpt_path, arch, device='cuda:0'):
  """Create an ``arch`` network, load checkpoint weights and return it.

  The checkpoint is read onto the CPU first; its ``'state_dict'`` entries are
  loaded after replacing ``'vit.'`` with ``'model.'`` in every key. The
  returned network is in eval mode and has been moved to ``device``.
  """
  network = nets.load_model(arch)
  checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))

  # Rename the checkpoint keys so they match the network's parameter names.
  renamed_state = {}
  for key, tensor in checkpoint['state_dict'].items():
    renamed_state[key.replace('vit.', 'model.')] = tensor

  network.load_state_dict(renamed_state)
  network.eval()
  network.to(device)
  return network

def get_stats(pred, labels, index=0):
  """Score column ``index`` of ``pred`` against the same column of ``labels``.

  Returns a dict with the ROC AUC (``'AUC'``), the average precision
  (``'AP'``) and one confusion matrix per decision threshold
  (``'conf@0.5'``, ``'conf@0.25'``, ``'conf@0.10'``, ``'conf@0.05'``), where
  a prediction counts as positive when it is strictly above the threshold.
  """
  y_true = labels[:, index]
  y_score = pred[:, index]

  stats = {
      'AUC': roc_auc_score(y_true, y_score),
      'AP': average_precision_score(y_true, y_score),
  }
  thresholds = (('conf@0.5', 0.5), ('conf@0.25', 0.25),
                ('conf@0.10', 0.10), ('conf@0.05', 0.05))
  for key, threshold in thresholds:
    stats[key] = confusion_matrix(y_true, y_score > threshold)
  return stats

def get_metrics(model: EnsembleModel, valsets, batch_size=4, index=0):
  """Run ``model`` over a dataset and tabulate per-label statistics.

  Predictions are computed once, on ``valsets[0]`` only, and are then scored
  against the label table (``.df``) of every dataset in ``valsets``. The
  datasets are therefore expected to hold the same samples in the same order
  and to differ only in their labels.

  Args:
    model: ensemble whose ``models`` members produce the logits; a sigmoid is
      applied to the averaged output.
    valsets: a list/tuple of datasets, or a single dataset. Each must support
      ``len()``, yield ``(image, label)`` pairs and carry a ``df`` whose first
      column is skipped and whose remaining columns are the labels.
    batch_size: inference batch size; also caps the DataLoader worker count.
    index: label columns to score: disease names from ``diseases`` (mapped to
      column ``diseases.index(name) + 1``), integer column positions, or a
      list mixing both. A bare name or integer is accepted as well.

  Returns:
    DataFrame with one row per ``"<label>_<k>"`` (``k`` is the 1-based
    position of the dataset in ``valsets``) and one column per entry of the
    dict returned by ``get_stats``.
  """
  # Accept a bare dataset as well as a list of datasets.
  if not isinstance(valsets, (list, tuple)):
    valsets = [valsets]
  # Accept a single disease name / column index as well as a list of them.
  if isinstance(index, str) or not hasattr(index, '__iter__'):
    index = [index]

  N = len(valsets[0])
  # os.cpu_count() may return None; never ask for a negative worker count.
  num_workers = max(0, min((os.cpu_count() or 1) - 1, batch_size))
  dataloader = DataLoader(valsets[0], batch_size=batch_size, shuffle=False,
                          num_workers=num_workers)

  # One column per entry of `diseases` plus the leading column 0.
  outs = np.zeros((N, len(diseases) + 1))

  # Plain nn.Modules have no `.device` attribute; read it off the parameters.
  device = next(model.models[0].parameters()).device

  with torch.no_grad():
    for i, (imgs, label) in enumerate(tqdm.tqdm(dataloader)):
      # shuffle=False, so batch i covers rows [i * batch_size, ...).
      idx = i * batch_size
      imgs = imgs.to(device)
      out = torch.sigmoid(model(imgs)).detach().cpu().numpy()
      outs[idx:idx + len(out), :] = out

  stats = {}
  for v, valset in enumerate(valsets, 1):
    labels = valset.df.iloc[:, 1:].to_numpy()
    for i in index:
      if isinstance(i, str):
        name = i
        didx = diseases.index(i) + 1
      else:
        didx = int(i)
        # Column 0 (and anything out of range) has no entry in `diseases`.
        if 1 <= didx <= len(diseases):
          name = diseases[didx - 1]
        else:
          name = 'col' + str(didx)
      stats[name + '_' + str(v)] = get_stats(outs, labels, index=didx)

  df = pd.DataFrame(stats).T
  return df
# MODEL = '/usr/mvl2/itdfh/dev/retinal-disease-classification/weights-resnext512-init-update-itdfhnorm/checkpoints/best/boosting-resnext512-init-update-itdfhnorm-512x512-b06-epoch=036-val_s_score=0.8907.ckpt'
# model = load_model(MODEL, 'resnext')

In [None]:
import glob

# glob returns matches in arbitrary order. Sorting makes `single_model` (the
# first checkpoint) and the member order of each ensemble reproducible.
multi_expert_ckpts = sorted(glob.glob('/usr/mvl2/itdfh/dev/retinal-disease-classification/weights-resnext512-init-update-itdfhnorm/checkpoints/best/*.ckpt'))
assert multi_expert_ckpts, 'no multi-expert checkpoints found'
# Single model
single_model = EnsembleModel([load_model(multi_expert_ckpts[0], 'resnext')])

# Multi-expert
multi_expert_models = [load_model(m, 'resnext') for m in multi_expert_ckpts]
multi_expert = EnsembleModel(multi_expert_models)

# Bagging (loaded on the second GPU)
bagging_ckpts = sorted(glob.glob('/cluster/VAST/civalab/results/riadd/imad/final-resnext-512-noinit-noupdate-norm-newsplit/best/*.ckpt'))
assert bagging_ckpts, 'no bagging checkpoints found'
bagging_models = [load_model(m, 'resnext', device='cuda:1') for m in bagging_ckpts]
bagging = EnsembleModel(bagging_models)

In [None]:
# refuge_model = load_model('/usr/mvl2/itdfh/dev/retinal-disease-classification/weights-resnext512-finetune-refuge/checkpoints/boosting-resnext512-finetune-refuge-512x512-epoch=095-val_s_score=0.7074.ckpt', 'resnext')
# refuge_model = load_model('/usr/mvl2/itdfh/dev/retinal-disease-classification/weights-resnext512-finetune-refuge/checkpoints/boosting-resnext512-finetune-refuge-512x512-epoch=025-val_s_score=0.7072.ckpt', 'resnext')
# refuge_model = load_model('weights-resnext512-refuge/checkpoints/boosting-resnext512-refuge-512x512-b00-epoch=061-val_auc_glaucoma=0.9560.ckpt', 'resnext')
# only learn GL
# refuge_model = load_model('weights-resnext512-refuge/checkpoints/boosting-resnext512-refuge-512x512-b00-epoch=030-val_auc_glaucoma=0.9616.ckpt', 'resnext')
# frozen backbone
# refuge_model = load_model('weights-resnext512-refuge/checkpoints/boosting-resnext512-refuge-512x512-b00-epoch=045-val_auc_glaucoma=0.6358.ckpt', 'resnext')
# elham
# refuge_model = load_model('/cluster/VAST/civalab/results/elham_results/checkpoints_exp1/boosting-resnext512-init-update-itdfhnorm-512x512-b02-epoch=057-val_s_score=0.6710.ckpt', 'resnext')
# Currently selected checkpoint, wrapped in a one-member ensemble so it can be
# passed to get_metrics() like the other models.
refuge_model = EnsembleModel([
    load_model('/cluster/VAST/civalab/results/elham_results/checkpoints_exp1/boosting-resnext512-init-update-itdfhnorm-512x512-b00-epoch=018-val_s_score=0.6917.ckpt', 'resnext')
])
# refuge_model = load_model('/cluster/VAST/civalab/results/elham_results/checkpoints_exp1/boosting-resnext512-init-update-itdfhnorm-512x512-b00-epoch=020-val_s_score=0.6977.ckpt', 'resnext')

# REFUGE

In [None]:
# REFUGE test ground truth -> label CSV with 'Disease_Risk' plus one column
# per entry of `diseases`. Glaucoma is written to both 'ODC' and
# 'Disease_Risk'; every other disease column is 0.
test_excel = '/cluster/VAST/civalab/public_datasets/refuge-full/Test400-GT/Glaucoma_label_and_Fovea_location.xlsx'
df_refuge = pd.read_excel(test_excel).drop(columns=['ID', 'Fovea_X', 'Fovea_Y'])
df_refuge = df_refuge.rename(columns={'ImgName': 'ID'})

df_refuge['Disease_Risk'] = df_refuge['Label(Glaucoma=1)']
for d in diseases:
  df_refuge[d] = df_refuge['Label(Glaucoma=1)'] if d == 'ODC' else 0

# The source label column is not part of the exported schema.
df_refuge.drop(columns=['Label(Glaucoma=1)']).to_csv(
    'datasets/refuge_test.csv', index=False)


In [None]:
# REFUGE validation ground truth -> label CSV with 'Disease_Risk' plus one
# column per entry of `diseases`. Glaucoma is written to both 'ODC' and
# 'Disease_Risk'; every other disease column is 0.
val_excel = '/cluster/VAST/civalab/public_datasets/refuge-full/Validation400-GT/Fovea_locations.xlsx'
df_refuge = pd.read_excel(val_excel).drop(columns=['ID', 'Fovea_X', 'Fovea_Y'])
df_refuge = df_refuge.rename(columns={'ImgName': 'ID'})

df_refuge['Disease_Risk'] = df_refuge['Glaucoma Label']
for d in diseases:
  df_refuge[d] = df_refuge['Glaucoma Label'] if d == 'ODC' else 0

# The source label column is not part of the exported schema.
df_refuge.drop(columns=['Glaucoma Label']).to_csv(
    'datasets/refuge_val.csv', index=False)


In [None]:
# TEST_IMG_PATH = '/cluster/VAST/civalab/public_datasets/REFUGE-llzqd/train'
# # TEST_CSV = '/cluster/VAST/civalab/public_datasets/REFUGE-llzqd/REFUGE_train.csv'
# TEST_CSV = 'refuge-train-full-label-riadd.csv'

# Image directory and label CSV for the REFUGE test evaluation. Both names are
# reassigned by later dataset sections of this notebook.
TEST_IMG_PATH = '/cluster/VAST/civalab/public_datasets/REFUGE/test/Images'
TEST_CSV = 'datasets/refuge_test.csv'


In [None]:
# REFUGE evaluation dataset built from the CSV / image directory set above.
valset = data.REFUGEDataset(TEST_CSV, TEST_IMG_PATH, testing=True,
                            input_size=512)
# Disabled: restrict the set to healthy and glaucoma rows only.
# healthy_df = valset.df[(valset.df['ODC'] == 0) & (valset.df.sum(axis=1) == 0)]
# glaucoma_df = valset.df[valset.df['ODC'] == 1]
# valset.df = pd.concat([
#   healthy_df,
#   glaucoma_df
# ])
# Visual sanity check: channel 0 of the first sample's image
# (permute moves the first axis of the image tensor to the end).
plt.imshow(valset[0][0].permute(1, 2, 0)[..., 0].detach().cpu().numpy())


In [None]:
# Multi-expert ensemble on REFUGE; the glaucoma label lives in the 'ODC' column.
get_metrics(multi_expert, [valset], index=['ODC'], batch_size=16)

In [None]:
# Bagging ensemble on REFUGE (default batch size).
get_metrics(bagging, [valset], index=['ODC'])

In [None]:
# Single-checkpoint baseline on REFUGE.
get_metrics(single_model, [valset], index=['ODC'])

In [None]:
# The separately loaded `refuge_model` checkpoint on REFUGE.
get_metrics(refuge_model, [valset], index=['ODC'])

# EyePACS

In [None]:
# EyePACS test labels, extended with an all-zero 'Disease_Risk' column and one
# all-zero column per entry of `diseases`; the image column becomes 'ID'.
df_eyepacs = (
    pd.read_csv('/cluster/VAST/civalab/public_datasets/EyePacs-diabetic-retinopathy-detection/testLabels.csv')
    # pd.read_csv('/cluster/VAST/civalab/public_datasets/EyePacs-diabetic-retinopathy-detection/trainLabels.csv')
    .assign(Disease_Risk=0, **{d: 0 for d in diseases})
    .rename(columns={'image': 'ID'})
)
df_eyepacs

In [None]:
# pandas' to_csv does not create missing directories, and this is the only
# section that writes below datasets/test.
os.makedirs('datasets/test', exist_ok=True)

# Columns that must not end up in the exported CSV. 'Usage' is dropped only
# when the label file has it.
remove_cols = ['level']
if 'Usage' in df_eyepacs.columns:
  remove_cols.append('Usage')

# One CSV per non-zero grade v: a row is positive for 'DR' (and
# 'Disease_Risk') when its grade is at least v.
for v in df_eyepacs['level'].unique():
  if not v:
    continue
  df_eyepacs['DR'] = (df_eyepacs['level'] >= v).astype(int)
  df_eyepacs['Disease_Risk'] = (df_eyepacs['level'] >= v).astype(int)
  df_eyepacs.drop(remove_cols, axis=1).to_csv(f'datasets/test/eyepacs_test_level_{v}.csv', index=False)
  # break

In [None]:
# One EyePACS dataset per exported grade threshold (levels 1 to 4); all of them
# point at the same image directory and differ only in the label CSV.
TEST_IMG_PATH = '/cluster/VAST/civalab/public_datasets/EyePacs-diabetic-retinopathy-detection/test'
eyepacs = []
for i in range(1, 5):
  TEST_CSV = f'datasets/test/eyepacs_test_level_{i}.csv'
  eyepacs.append(data.EyePACSDataset(TEST_CSV, TEST_IMG_PATH, testing=True, input_size=512, ext='.jpeg'))
# Visual sanity check: channel 0 of sample 10 of the second dataset.
plt.imshow(eyepacs[1][10][0].permute(1, 2, 0)[..., 0].detach().cpu().numpy())

In [None]:
# Single model on EyePACS. get_metrics predicts on the first dataset only and
# scores those predictions against the labels of every dataset in the list.
df_results_eyepacs_single = get_metrics(single_model, eyepacs, batch_size=32, index=['DR'])
df_results_eyepacs_single

In [None]:
# Bagging ensemble on EyePACS, one result row per grade threshold.
df_results_eyepacs_bagging = get_metrics(bagging, eyepacs, batch_size=64, index=['DR'])
df_results_eyepacs_bagging

In [None]:
# Re-display the bagging results computed in the previous cell.
df_results_eyepacs_bagging

In [None]:
# Number of samples in the first EyePACS dataset.
len(eyepacs[0])

In [None]:
# Multi-expert ensemble on EyePACS, one result row per grade threshold.
df_results_eyepacs_multi_expert = get_metrics(multi_expert, eyepacs, batch_size=32, index=['DR'])
df_results_eyepacs_multi_expert

# IDRiD

In [None]:
# IDRiD disease-grading labels for the testing set.
df_idrid = pd.read_csv('/cluster/VAST/civalab/public_datasets/IDRiD/B-Disease-Grading/2-Groundtruths/b-IDRiD_Disease-Grading_Testing-Labels.csv')


In [None]:
# Add an all-zero 'Disease_Risk' column plus one all-zero column per entry of
# `diseases`, and expose the image name under 'ID'.
df_idrid = (
    df_idrid
    .assign(Disease_Risk=0, **{d: 0 for d in diseases})
    .rename(columns={'Image name': 'ID'})
)

# One CSV per non-zero grade v: a row is positive for 'DR' (and
# 'Disease_Risk') when its retinopathy grade is at least v.
for v in df_idrid['Retinopathy grade'].unique():
  if v:
    df_idrid['DR'] = (df_idrid['Retinopathy grade'] >= v).astype(int)
    df_idrid['Disease_Risk'] = df_idrid['DR']
    exported = df_idrid.drop(columns=['Retinopathy grade', 'Risk of macular edema '])
    exported.to_csv(f'datasets/idrid_test_level_{v}.csv', index=False)
    # break

In [None]:
# One IDRiD dataset per exported grade threshold (levels 1 to 4), all reading
# images from the same testing-set directory.
TEST_IMG_PATH = '/cluster/VAST/civalab/public_datasets/IDRiD/B-Disease-Grading/1-Original-Images/b-Testing-Set'
idrid = []
for i in range(1, 5):
  TEST_CSV = f'datasets/idrid_test_level_{i}.csv'
  idrid.append(
    data.IDRiDDataset(TEST_CSV, TEST_IMG_PATH, testing=True, input_size=512))
# Visual sanity check: channel 1 of sample 1 of the first dataset.
plt.imshow(idrid[0][1][0].permute(1, 2, 0)[..., 1].detach().cpu().numpy())

In [None]:
# Single model on IDRiD, one result row per grade threshold.
get_metrics(single_model, idrid, index=['DR'])

In [None]:
# Multi-expert ensemble on IDRiD, one result row per grade threshold.
get_metrics(multi_expert, idrid, index=['DR'])

In [None]:
# Bagging ensemble on IDRiD, one result row per grade threshold.
get_metrics(bagging, idrid, index=['DR'])

# Messidor

In [None]:
import glob
# glob returns matches in arbitrary order; sort so the row order of the merged
# Messidor table (and therefore sample indices) is the same on every run.
excels = sorted(glob.glob('/cluster/VAST/civalab/public_datasets/MESSIDOR/excel/*.xls'))
# df_messidor = pd.read_excel('')

In [None]:
# Read every annotation sheet, prefix each image name with the part of the
# file name that follows 'Annotation_' (used as a sub-directory), and merge
# the sheets with a single concat instead of growing a frame inside the loop.
messidor_sheets = []
for excel in excels:
  base = excel.split('Annotation_')[-1].replace('.xls', '')
  df_base = pd.read_excel(excel)
  df_base['Image name'] = df_base['Image name'].apply(lambda x: base + '/' + x)
  messidor_sheets.append(df_base)

# pd.concat raises on an empty list; keep the previous "empty frame" result.
df_messidor = pd.concat(messidor_sheets, axis=0) if messidor_sheets else pd.DataFrame()
  

In [None]:
# Add an all-zero 'Disease_Risk' column plus one all-zero column per entry of
# `diseases`, and expose the image name under 'ID'.
df_messidor = (
    df_messidor
    .assign(Disease_Risk=0, **{d: 0 for d in diseases})
    .rename(columns={'Image name': 'ID'})
)

# One CSV per non-zero grade v: a row is positive for 'DR' (and
# 'Disease_Risk') when its retinopathy grade is at least v.
for v in df_messidor['Retinopathy grade'].unique():
  if v:
    df_messidor['DR'] = (df_messidor['Retinopathy grade'] >= v).astype(int)
    df_messidor['Disease_Risk'] = df_messidor['DR']
    exported = df_messidor.drop(
        columns=['Retinopathy grade', 'Risk of macular edema ', 'Ophthalmologic department'])
    exported.to_csv(f'datasets/messidor_level_{v}.csv', index=False)
    # break

In [None]:
# One Messidor dataset per exported grade threshold (levels 1 to 3). The
# IDRiD dataset class is reused here with '.tif' images.
TEST_IMG_PATH = '/cluster/VAST/civalab/public_datasets/MESSIDOR'
messidor = []
for i in range(1, 4):
  TEST_CSV = f'datasets/messidor_level_{i}.csv'
  messidor.append(data.IDRiDDataset(TEST_CSV, TEST_IMG_PATH, testing=True, input_size=512, ext='.tif'))
# Visual sanity check: channel 1 of sample 150 of the first dataset.
plt.imshow(messidor[0][150][0].permute(1, 2, 0)[..., 1].detach().cpu().numpy())

In [None]:
# Single model on Messidor, one result row per grade threshold.
get_metrics(single_model, messidor, index=['DR'], batch_size=32)

In [None]:
# Multi-expert ensemble on Messidor, one result row per grade threshold.
get_metrics(multi_expert, messidor, index=['DR'], batch_size=32)

In [None]:
# Bagging ensemble on Messidor. `index` is a list of disease names, as in the
# single-model and multi-expert cells; get_metrics maps 'DR' to column
# diseases.index('DR') + 1 itself.
get_metrics(bagging, messidor, index=['DR'], batch_size=32)

# DeepDRiD

In [None]:
df_deepdrid = pd.read_csv('/cluster/VAST/civalab/public_datasets/DeepDRiD/regular_fundus_images/regular-fundus-validation/regular-fundus-validation.csv')
# Drop the first character of each path and switch backslashes to forward
# slashes.
df_deepdrid['image_path'] = df_deepdrid['image_path'].apply(lambda x: x[1:].replace('\\', '/'))
# .copy(): the column subset is modified below; without it pandas may emit
# SettingWithCopyWarning.
df_deepdrid = df_deepdrid[['image_path', 'patient_DR_Level']].copy()

df_deepdrid['Disease_Risk'] = 0
for d in diseases:
  df_deepdrid[d] = 0
df_deepdrid.rename({'image_path': 'ID'}, axis=1, inplace=True)

# One CSV per non-zero level v: a row is positive for 'DR' (and
# 'Disease_Risk') when its patient-level DR grade is at least v.
for v in df_deepdrid['patient_DR_Level'].unique():
  if not v:
    continue
  df_deepdrid['DR'] = (df_deepdrid['patient_DR_Level'] >= v).astype(int)
  df_deepdrid['Disease_Risk'] = (df_deepdrid['patient_DR_Level'] >= v).astype(int)
  df_deepdrid.drop(['patient_DR_Level'], axis=1).to_csv(
      f'datasets/deepdrid_val_level_{v}.csv', index=False)
  # break

In [None]:
# DeepDRiD validation dataset. Only the level-3 CSV is loaded here, i.e. a row
# is positive when its patient-level grade is at least 3.
TEST_IMG_PATH = '/cluster/VAST/civalab/public_datasets/DeepDRiD/regular_fundus_images/regular-fundus-validation'
TEST_CSV = 'datasets/deepdrid_val_level_3.csv'
deepdrid = data.IDRiDDataset(TEST_CSV, TEST_IMG_PATH, testing=True, input_size=512, ext='.jpg')
# Visual sanity check: channel 1 of sample 150.
plt.imshow(deepdrid[150][0].permute(1, 2, 0)[..., 1].detach().cpu().numpy())

In [None]:
# Single model on DeepDRiD. The dataset is wrapped in a list and the disease
# is given by name, matching the call form of the other evaluation cells.
get_metrics(single_model, [deepdrid], index=['DR'], batch_size=32)

In [None]:
# Multi-expert ensemble on DeepDRiD. The dataset is wrapped in a list and the
# disease is given by name, matching the other evaluation cells.
get_metrics(multi_expert, [deepdrid], index=['DR'], batch_size=32)

In [None]:
# Bagging ensemble on DeepDRiD. The dataset is wrapped in a list and the
# disease is given by name, matching the other evaluation cells.
get_metrics(bagging, [deepdrid], index=['DR'], batch_size=32)

# HRF: High Resolution Fundus

In [None]:
# sorted(): glob order is arbitrary; keep the CSV row order reproducible.
files = sorted(glob.glob('/cluster/VAST/civalab/public_datasets/high-resolution-fundus/images/*'))
ids = [os.path.basename(f) for f in files]
df_hrf = pd.DataFrame({
  'file': files,
  'ID': ids
})
df_hrf['Disease_Risk'] = 0
for d in diseases:
  df_hrf[d] = 0

# The class tag is looked up in the file name WITHOUT its extension.
# Matching on the full name would count the 'g' of a '.jpg' extension as a
# glaucoma tag.
# NOTE(review): assumes the class tags 'dr' / 'g' / 'h' appear in the stem of
# each file name - confirm against the directory listing.
stems = df_hrf['ID'].apply(lambda x: os.path.splitext(x)[0])
df_hrf['DR'] = stems.apply(lambda x: 'dr' in x).astype(int)
df_hrf['ODC'] = stems.apply(lambda x: 'g' in x).astype(int)
df_hrf['Disease_Risk'] = stems.apply(lambda x: 'h' not in x).astype(int)
df_hrf = df_hrf.drop('file', axis=1)
df_hrf.to_csv('datasets/hrf.csv', index=False)


In [None]:
# HRF dataset. ext='' because the IDs written to hrf.csv are full file names
# (basename including extension).
TEST_IMG_PATH = '/cluster/VAST/civalab/public_datasets/high-resolution-fundus/images'
TEST_CSV = 'datasets/hrf.csv'
hrf = data.REFUGEDataset(TEST_CSV, TEST_IMG_PATH, testing=True, input_size=512, ext='')
# Visual sanity check: channel 1 of the first sample.
plt.imshow(hrf[0][0].permute(1, 2, 0)[..., 1].detach().cpu().numpy())

In [None]:
# `refuge_model` on HRF, glaucoma ('ODC') only.
get_metrics(refuge_model, [hrf], index=['ODC'], batch_size=32)

In [None]:
# Bagging ensemble on HRF, glaucoma ('ODC') and diabetic retinopathy ('DR').
get_metrics(bagging, [hrf], index=['ODC', 'DR'], batch_size=32)

In [None]:
# Multi-expert ensemble on HRF, glaucoma ('ODC') and diabetic retinopathy ('DR').
get_metrics(multi_expert, [hrf], index=['ODC', 'DR'], batch_size=32)

In [None]:
# Single model on HRF, glaucoma ('ODC') and diabetic retinopathy ('DR').
# Superseded call form (bare dataset, integer column index):
# get_metrics(single_model, hrf, index=diseases.index('DR') + 1, batch_size=32)
get_metrics(single_model, [hrf], index=['ODC', 'DR'], batch_size=32)

In [None]:
# `refuge_model` on HRF again, this time with batch_size=2.
# Superseded call form (bare dataset, integer column index):
# get_metrics(single_model, hrf, index=diseases.index('ODC') + 1, batch_size=32)
get_metrics(refuge_model, [hrf], index=['ODC'], batch_size=2)

In [None]:
# Bagging ensemble on HRF, 'DR' only. The dataset is wrapped in a list and the
# disease is given by name, matching the other evaluation cells.
get_metrics(bagging, [hrf], index=['DR'], batch_size=32)

In [None]:
# Bagging ensemble on HRF, 'ODC' only. The dataset is wrapped in a list and
# the disease is given by name, matching the other evaluation cells.
get_metrics(bagging, [hrf], index=['ODC'], batch_size=32)

In [None]:
# Multi-expert ensemble on HRF, 'DR' only. The dataset is wrapped in a list
# and the disease is given by name, matching the other evaluation cells.
get_metrics(multi_expert, [hrf], index=['DR'], batch_size=32)

In [None]:
# Multi-expert ensemble on HRF, 'ODC' only. The dataset is wrapped in a list
# and the disease is given by name, matching the other evaluation cells.
get_metrics(multi_expert, [hrf], index=['ODC'], batch_size=32)

# ORIGA

In [None]:
# ORIGA image list with its glaucoma labels; displayed for inspection.
df_origa = pd.read_csv('/cluster/VAST/civalab/public_datasets/ORIGA/OrigaList.csv')
df_origa

In [None]:
# Label CSV for ORIGA: glaucoma is written to both 'ODC' and 'Disease_Risk';
# every other disease column is 0.
df_origa = df_origa.rename(columns={'Filename': 'ID'})
df_origa['Disease_Risk'] = df_origa['Glaucoma']
for d in diseases:
  df_origa[d] = df_origa['Glaucoma'] if d == 'ODC' else 0

# Columns of the source list that are not part of the exported schema.
df_origa.drop(columns=['Glaucoma', 'Eye', 'ExpCDR', 'Set']).to_csv(
    'datasets/origa.csv', index=False)


In [None]:
dl = DataLoader(origa, batch_size=4, shuffle=False,  num_workers=4)
for i, (x, y) in enumerate(dl):
  print(x.shape, y.shape)

In [None]:
TEST_IMG_PATH = '/cluster/VAST/civalab/public_datasets/ORIGA/Images_Square'
TEST_CSV = 'datasets/origa.csv'
origa = data.EyePACSDataset(TEST_CSV, TEST_IMG_PATH, testing=True, input_size=512, ext='')
plt.imshow(origa[0][0].permute(1, 2, 0)[..., 1].detach().cpu().numpy())

In [None]:
# Single model on ORIGA; the glaucoma label lives in the 'ODC' column.
get_metrics(single_model, [origa], index=['ODC'], batch_size=16)

In [None]:
# Bagging ensemble on ORIGA ('ODC').
get_metrics(bagging, [origa], index=['ODC'], batch_size=16)

In [None]:
# Multi-expert ensemble on ORIGA ('ODC').
get_metrics(multi_expert, [origa], index=['ODC'], batch_size=16)

In [None]:
# `refuge_model` on ORIGA ('ODC').
get_metrics(refuge_model, [origa], index=['ODC'], batch_size=16)

In [None]:
# NOTE(review): unlike the other evaluation cells, this call passes a bare
# dataset and an integer column index (0) instead of a list of datasets and a
# list of disease names - confirm get_metrics supports this call form, or
# remove the cell.
get_metrics(bagging, origa, index=0, batch_size=32)

# STARE

In [None]:
# Read the STARE diagnosis codes: every line is split at its first tab into
# an (id, '\t', remainder) triple.
with open('/cluster/VAST/civalab/public_datasets/STARE/all-mg-codes.txt') as f:
  lines = [row.partition('\t') for row in f.readlines()]

In [None]:
# IDs of the STARE images that are actually on disk (file name without '.ppm').
TEST_IMG_PATH = '/cluster/VAST/civalab/public_datasets/STARE/all-images'
existing = [fname.replace('.ppm', '') for fname in os.listdir(TEST_IMG_PATH)]

In [None]:
df_stare = pd.DataFrame(lines, columns=['ID', 'sep', 'diag']).drop('sep', axis=1)
# Keep only rows whose image exists on disk. .copy(): the filtered frame is
# modified below; without it pandas may emit SettingWithCopyWarning.
df_stare = df_stare[df_stare['ID'].isin(existing)].copy()

# NOTE(review): every row is marked Disease_Risk=1, including any whose
# diagnosis text names no disease - confirm this is intended.
df_stare['Disease_Risk'] = 1
for d in diseases:
  df_stare[d] = 0

def is_dr(x):
  """Return 1 if the diagnosis text mentions 'Diabetic Retinopathy', else 0."""
  return int('Diabetic Retinopathy' in x)

def is_armd(x):
  """Return 1 if the diagnosis text mentions 'Age Related Macular Degeneration', else 0."""
  return int('Age Related Macular Degeneration' in x)


df_stare['DR'] = df_stare['diag'].apply(is_dr)
df_stare['ARMD'] = df_stare['diag'].apply(is_armd)
# The free-text diagnosis is not part of the exported schema.
df_stare.drop(['diag'], axis=1).to_csv(
    'datasets/stare.csv', index=False)


In [None]:
# Inspect the derived diabetic-retinopathy labels.
df_stare['DR']

In [None]:
# STARE dataset built from the exported CSV; images are '.ppm' files.
TEST_IMG_PATH = '/cluster/VAST/civalab/public_datasets/STARE/all-images'
TEST_CSV = 'datasets/stare.csv'
stare = data.EyePACSDataset(TEST_CSV, TEST_IMG_PATH, testing=True, input_size=512, ext='.ppm',
                            noisy_student=False)

In [None]:
!cat $TEST_CSV

In [None]:
# Re-running this cell steps through the STARE samples one at a time. The
# position is kept in a dedicated counter that starts at 0 on a fresh kernel
# and wraps around at the end of the dataset, instead of relying on a loop
# variable left over from an earlier cell.
stare_preview_idx = globals().get('stare_preview_idx', 0) % len(stare)
plt.imshow(stare[stare_preview_idx][0].permute(1, 2, 0)[..., 1].detach().cpu().numpy())
stare_preview_idx += 1

In [None]:
# Single model on STARE, diabetic retinopathy and ARMD.
get_metrics(single_model, [stare], index=['DR', 'ARMD'], batch_size=32)

In [None]:
# Bagging ensemble on STARE, diabetic retinopathy and ARMD.
get_metrics(bagging, [stare], index=['DR', 'ARMD'], batch_size=32)

In [None]:
# Multi-expert ensemble on STARE, diabetic retinopathy and ARMD.
get_metrics(multi_expert, [stare], index=['DR', 'ARMD'], batch_size=32)

# G1020

In [None]:
# G1020 image list with its binary glaucoma labels; displayed for inspection.
df_g1020 = pd.read_csv('/cluster/VAST/civalab/public_datasets/G1020/G1020.csv')

df_g1020

In [None]:
# Label CSV for G1020: the binary label is written to both 'ODC' and
# 'Disease_Risk'; every other disease column is 0.
df_g1020 = df_g1020.rename(columns={'imageID': 'ID'}).assign(
    Disease_Risk=lambda frame: frame['binaryLabels'],
    **{d: 0 for d in diseases},
)
df_g1020['ODC'] = df_g1020['binaryLabels']

# The source label column is not part of the exported schema.
df_g1020.drop(columns=['binaryLabels']).to_csv(
    'datasets/g1020.csv', index=False)


In [None]:
# G1020 dataset built from the exported CSV. ext='' is passed, so the IDs are
# presumably full file names - confirm against G1020.csv.
TEST_IMG_PATH = '/cluster/VAST/civalab/public_datasets/G1020/Images'
TEST_CSV = 'datasets/g1020.csv'
g1020 = data.EyePACSDataset(TEST_CSV, TEST_IMG_PATH, testing=True, input_size=512, ext='')
# Visual sanity check: channel 1 of the first sample.
plt.imshow(g1020[0][0].permute(1, 2, 0)[..., 1].detach().cpu().numpy())

In [None]:
# Single model on G1020; the glaucoma label lives in the 'ODC' column.
get_metrics(single_model, [g1020], index=['ODC'], batch_size=32)

In [None]:
# Bagging ensemble on G1020 ('ODC').
get_metrics(bagging, [g1020], index=['ODC'], batch_size=32)

In [None]:
# Multi-expert ensemble on G1020 ('ODC').
get_metrics(multi_expert, [g1020], index=['ODC'], batch_size=32)