In [283]:
from data_pipeline.np_dataset import NpDataset
import input_mapping.models_torch as models_torch
from data_pipeline.image_transforms import get_transforms

from data_pipeline.data_package import DataPackage





from PIL import Image
from pydicom import dcmread
import torch
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms
import numpy as np

from ai_backend.loggers.model_logger import is_min
from uuid import uuid4
import torch.nn as nn
from torch.optim import Adam
import torch
import json
import os
import re
import tqdm
from ai_backend.evaluators.metrics.multi_label_metrics import  multi_label_f_beta, multi_label_confusion_matrix, multi_label_accuracy, multi_label_precision, multi_label_recall
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay, multilabel_confusion_matrix
import matplotlib.pyplot as plt
import pandas as pd



In [284]:
# Identify the trained run to evaluate and load the run configuration saved with it.
model_id = '78cce3e7-29cf-4bf8-a557-fbf4c1ad8ec9'
model_key = 'resnet18'
model_folder = f'models/{model_key}/{model_id}'
path_to_model_config = f'{model_folder}/run_config.json'
# load the model's run configuration (read below for e.g. 'transform_type')
with open(path_to_model_config, 'r', encoding='utf-8') as f:
    run_config = json.load(f)

In [285]:
# Preprocessing transform for evaluation: its type comes from the run config,
# its parameters from the model registry entry for this architecture.
transform_type = run_config['transform_type']
transforms_config = models_torch.model_dict[model_key]['transforms_config']
transform = get_transforms(
    transform_name=transform_type,
    transforms_config=transforms_config,
)

In [286]:
# Load the configuration of the dataset this model was trained on.
dataset_name = '2024-06-05_16-22-01'
path_to_dataset_config = f'datasets/{dataset_name}/dataset_config.json'
with open(path_to_dataset_config, 'r') as f:
    dataset_config = json.load(f)
# Label list: its length sets the number of model outputs and its entries
# name the per-label metrics.
labels_to_encode = dataset_config['labels_to_encode']


In [287]:
# Weights and thresholds live in the same run folder as run_config.json.
best_model_save_folder = model_folder
best_model_save_path = f'{best_model_save_folder}/weights.pth'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# File readers handed to NpDataset: each turns a file path into an RGB PIL image.
def read_dicom(path):
    """Return the decoded pixel array of the DICOM file at ``path``."""
    return dcmread(path).pixel_array


def dicom_file_reader(path):
    """Read a DICOM file and return its pixel data as an RGB PIL image."""
    return Image.fromarray(read_dicom(path)).convert('RGB')


def default_file_reader(path):
    """Open any image file PIL can read and return it as an RGB PIL image."""
    # The context manager closes the file handle deterministically; convert()
    # returns a new, fully loaded image that no longer needs the file.
    with Image.open(path) as image:
        return image.convert('RGB')


In [288]:
def convert_package_to_dataset(package, augmentations=None):
    """Wrap a DataPackage in an NpDataset.

    Packages whose ``data_source_name`` is 'UKB' are read with
    ``dicom_file_reader``; all others use ``default_file_reader``. The global
    ``transform`` is applied to every sample, and ``augmentations`` (may be
    None) is passed through as the dataset's augmentation transform.
    """
    if package.data_source_name == 'UKB':
        reader = dicom_file_reader
    else:
        reader = default_file_reader
    return NpDataset(
        file_paths=package.get_data(),
        labels=package.get_labels(),
        file_reader=reader,
        transform=transform,
        augmentation_transform=augmentations,
    )

def convert_package_list_to_dataset(package_list, augmentations=None):
    """Convert every DataPackage in ``package_list`` to its own NpDataset.

    Returns a list with one dataset per package, in the same order (not a
    single concatenated dataset), so results can be reported per package.
    ``augmentations`` is forwarded to every dataset.
    """
    # Forward `augmentations`: it used to be accepted here but silently dropped.
    return [convert_package_to_dataset(package, augmentations=augmentations)
            for package in package_list]

In [289]:
#list the saved directories and load the datapackages
dataset_path = 'datasets/2024-06-05_16-22-01'
test_packages_path = f'{dataset_path}/test'
test_packages_paths = os.listdir(test_packages_path)
test_packages = []
for package_path in test_packages_paths:
    test_packages.append(DataPackage.load(f'{test_packages_path}/{package_path}'))
test_dataset = convert_package_list_to_dataset(test_packages)


In [290]:


# One DataLoader per test dataset so each package is evaluated on its own;
# shuffle=False keeps predictions aligned with the dataset order.
num_workers = 4
batch_size = 512
data_loaders = [
    DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    for ds in test_dataset
]

In [291]:
# create model
# Instantiate the architecture with one output per encoded label.
model = models_torch.get_model(
    model_name=model_key,
    num_classes=len(labels_to_encode),
)




In [292]:
# Add 2D dropout (p=0.2) after each top-level `layerN` submodule (layer1..layer4
# here). F.dropout2d returns its input unchanged when training=False, so these
# hooks have no effect once the model is in eval mode.
# NOTE(review): presumably mirrors hooks used at training time — TODO confirm.

# Remove hooks registered by a previous run of this cell so re-running it
# does not stack a second copy of every hook.
for _handle in globals().get('dropout_hook_handles', []):
    _handle.remove()
dropout_hook_handles = []

layer_name_pattern = re.compile(r'layer\d+')


def _dropout2d_hook(module, inputs, output):
    """Forward hook: apply dropout2d to the submodule's output while it is training."""
    return torch.nn.functional.dropout2d(output, p=0.2, training=module.training)


for name, module in model.named_modules():
    if layer_name_pattern.fullmatch(name):
        print('Adding forward hook for:', name)
        dropout_hook_handles.append(module.register_forward_hook(_dropout2d_hook))

Adding forward hook for: layer1
Adding forward hook for: layer2
Adding forward hook for: layer3
Adding forward hook for: layer4


In [293]:
#load the best model
# Load the trained weights. map_location remaps tensors saved on another device
# (e.g. CUDA) onto `device`, so the notebook also runs on CPU-only machines.
# The return value of load_state_dict is displayed as the cell output.
model.load_state_dict(torch.load(best_model_save_path, map_location=device))

<All keys matched successfully>

In [294]:
# Read the best decision thresholds saved next to the weights and keep them
# as a float32 tensor.
with open(f'{best_model_save_folder}/best_thresholds.json', 'r') as f:
    best_thresholds = torch.tensor(json.load(f), dtype=torch.float32)

In [295]:
#evaluate the model
y_true_list = []
y_pred_list = []
for test_loader in data_loaders:
    model.to(device)
    model.eval()
    y_true = []
    y_pred = []
    for i, data in enumerate(test_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        y_true.append(labels.cpu().detach())
        y_pred.append(outputs.cpu().detach())
    y_true = torch.concat(y_true, dim=0)
    y_pred = torch.concat(y_pred, dim=0)
    y_true_list.append(y_true)
    y_pred_list.append(y_pred)

In [296]:
# Show the thresholds alongside the labels they apply to.
print(best_thresholds)
print(labels_to_encode)
# NOTE(review): every saved threshold is 0.0 in this run — TODO confirm that is
# intended and whether the metric applies it to logits or probabilities.
# Convert to NumPy for the metric calls. The isinstance guard keeps the cell
# idempotent: re-running it on an ndarray would otherwise fail on `.numpy()`.
if isinstance(best_thresholds, torch.Tensor):
    best_thresholds = best_thresholds.numpy()

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
['Age-related Macular Degeneration', 'Best Disease', 'Bietti crystalline dystrophy', 'cataract', 'Cone Dystrophy or Cone-rod Dystrophy', 'Diabetic Retinopathy', 'glaucoma', 'Maculopathy', 'Myopia', 'Normal', 'Retinitis Pigmentosa', 'Stargardt Disease', 'Macular dystrophy', 'Pseudoxanthoma elasticum', 'Retinal Dystrophy', 'Optic atrophy', 'Usher-Syndrom', 'Drusen', 'Leber Hereditary Optic Neuropathy', 'Choroideremia', 'Sorsby Fundus Dystrophy']


In [297]:
#calculate the metrics
f1_dicts = []
for i in range(len(y_true_list)):
    y_true = y_true_list[i]
    y_pred = y_pred_list[i]
    f1 = multi_label_f_beta(y_true, y_pred, threshold=best_thresholds, beta=1.0)
    f1_dict = {label: f1_value for label, f1_value in zip(labels_to_encode, f1)}
    f1_dicts.append(f1_dict)

  recall = true_positives / (true_positives + false_negatives)
  f_beta = (1 + beta**2) * (precision * recal) / ((beta**2 * precision) + recal)
  recall = true_positives / (true_positives + false_negatives)
  f_beta = (1 + beta**2) * (precision * recal) / ((beta**2 * precision) + recal)
  recall = true_positives / (true_positives + false_negatives)
  f_beta = (1 + beta**2) * (precision * recal) / ((beta**2 * precision) + recal)
  recall = true_positives / (true_positives + false_negatives)
  f_beta = (1 + beta**2) * (precision * recal) / ((beta**2 * precision) + recal)
  recall = true_positives / (true_positives + false_negatives)
  f_beta = (1 + beta**2) * (precision * recal) / ((beta**2 * precision) + recal)
  recall = true_positives / (true_positives + false_negatives)
  f_beta = (1 + beta**2) * (precision * recal) / ((beta**2 * precision) + recal)
  recall = true_positives / (true_positives + false_negatives)
  f_beta = (1 + beta**2) * (precision * recal) / ((beta**2 * precision) +

In [298]:
# Retinitis Pigmentosa F1 for each test loader, in the same order as f1_dicts.
rp_f1_scores = [scores['Retinitis Pigmentosa'] for scores in f1_dicts]
print(f'RP F1 scores: {rp_f1_scores}')

RP F1 scores: [nan, 0.01639344262295082, 0.02985074626865672, 0.1492537313432836, 1.0, 0.7804878048780487, 0.550561797752809]


In [299]:
# One-row DataFrame of RP F1 scores with one column per test package, named
# after the package's data source.
# NOTE(review): packages sharing a data_source_name would collapse into a
# single column here (the last one wins).
rp_f1_scores = np.array(rp_f1_scores)
column_names = [package.data_source_name for package in test_packages]
rp_f1_dict = dict(zip(column_names, rp_f1_scores))
rp_f1_df = pd.DataFrame(rp_f1_dict, index=[0])

In [300]:
# Exclude the ODIR-5K column from the summary. errors='ignore' makes the cell
# idempotent: re-running it after the column is gone is a no-op, not a KeyError.
rp_f1_df = rp_f1_df.drop(columns='ODIR-5K', errors='ignore')


In [301]:

# Report the SES source as BAU, order the columns alphabetically and display
# the scores rounded to two decimals.
rp_f1_df = (
    rp_f1_df
    .rename(columns={'SES': 'BAU'})
    .pipe(lambda frame: frame.reindex(columns=sorted(frame.columns)))
)
rp_f1_df.round(2).head()

Unnamed: 0,1000images,BAU,RFMiD,RFMiD2,RIPS,UKB
0,0.15,0.78,0.02,0.03,1.0,0.55
