In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from sklearn import metrics
import json

import tensorflow as tf
import tensorflow.keras
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications import inception_v3
from tensorflow.keras import backend as K

In [None]:
# Root folder that the per-row image file names are joined onto.
image_dir = '/path/to/image/folder/'

# Label tables for the train / validation / test splits, read in that order.
train_labels, val_labels, test_labels = map(pd.read_csv, (
    '/path/to/label/folder/train_data_new.csv',
    '/path/to/label/folder/val_data_new.csv',
    '/path/to/label/folder/test_data_new.csv',
))

In [None]:
# Split to evaluate: 'train', 'val' or 'test'.
eval_split = 'test'
# reset_index gives a clean 0..n-1 index, so row positions and `.loc` labels
# coincide for any code that walks rows with `range(len(dataframe))`.
dataframe = {'train': train_labels, 'val': val_labels, 'test': test_labels}[eval_split].reset_index(drop=True)

In [None]:
def preprocess_image(img_path, target_size=(512, 512)):
    """Load one image file and turn it into a single-image batch for the network.

    Args:
        img_path: Path to the image file.
        target_size: (width, height) passed to PIL ``Image.resize``. Defaults to
            512x512, matching the zero placeholder image inputs in this notebook.

    Returns:
        float32 array of shape (1, height, width, 3), transformed by
        ``inception_v3.preprocess_input``.
    """
    # Use a context manager so the file handle is released (Image.open is lazy
    # and otherwise keeps the file open). Convert to RGB so grayscale, RGBA or
    # palette files yield the same 3-channel layout as ordinary RGB images.
    with Image.open(img_path) as source:
        resized = source.convert('RGB').resize(target_size)
    # Explicit float32 copy: writable, and independent of whether
    # preprocess_input casts integer input itself.
    pixels = np.asarray(resized, dtype=np.float32)
    batch = np.expand_dims(pixels, axis=0)
    return inception_v3.preprocess_input(batch)

In [None]:
# Physical GPUs visible to TensorFlow (an empty list means none are visible).
gpu_devices = tf.config.list_physical_devices(device_type='GPU')
gpu_devices

In [None]:
# Which exported checkpoint to evaluate (a key of `model_paths`).
which_model = 'universal'

model_paths = {
    'diagnosis_only': '/path/to/pretrained/models/7pt_F_diag_T_mgmt_F_exported.h5',
    'universal': '/path/to/pretrained/models/7pt_T_diag_T_mgmt_T_exported.h5',
}
if which_model not in model_paths:
    # Keep the original exception type, but say what was wrong and what is allowed.
    raise NotImplementedError(
        'Unknown which_model {!r}; expected one of {}'.format(which_model, sorted(model_paths)))

# Note: `model` holds the checkpoint *path*, which is later passed to load_model().
model = model_paths[which_model]

In [None]:
# Load the checkpoint for inference only. compile=False skips restoring the
# training configuration (optimizer/loss/metrics), which prediction does not
# need and which would otherwise require any custom training objects here.
classifier = load_model(model, compile=False)
# Freeze all weights; this notebook only evaluates the model.
classifier.trainable = False
# No global learning-phase flag is needed: Keras predict()/predict_on_batch()
# run layers with training=False. (K.learning_phase() is deprecated in TF2
# and no longer exists in Keras 3, so the former debug print was removed.)

In [None]:
# Modality to evaluate, either 'clinic' or 'derm'. It selects the image
# column that is read and the '*_clinic' / '*_derm' output heads that are used.
which_modality = 'clinic'

In [None]:
# Every prediction call feeds both image inputs ('input_1_derm' and
# 'input_2_clinic'). The image of the modality that is NOT being evaluated is
# a constant all-zero placeholder. The evaluated one is set per image in the
# evaluation loop.
blank_image_shape = (1, 512, 512, 3)  # assumed to match the model's image inputs -- confirm against classifier.input
if which_modality == 'clinic':
    derm_img = np.zeros(blank_image_shape)
elif which_modality == 'derm':
    clinic_img = np.zeros(blank_image_shape)
else:
    raise NotImplementedError(
        "Unknown which_modality {!r}; expected 'clinic' or 'derm'".format(which_modality))

# Metadata / auxiliary inputs are fed as zero-filled placeholders.
# NOTE(review): shapes presumably mirror the model's aux input layers -- confirm.
meta = np.zeros((1, 1, 1, 14))
aux_input_derm = np.zeros((1, 14, 14, 14))
aux_input_clinic = np.zeros((1, 14, 14, 14))

In [None]:
# String label -> integer class index, numbered in listed order, for the
# diagnosis and each of the seven other label columns. Every dict is a
# separate object, even where the label sets coincide.
label_dict_diagnosis = {name: code for code, name in enumerate(('BCC', 'NEV', 'MEL', 'MISC', 'SK'))}
label_dict_pn = {name: code for code, name in enumerate(('ABS', 'TYP', 'ATP'))}
label_dict_bwv = {name: code for code, name in enumerate(('ABS', 'PRS'))}
label_dict_vs = {name: code for code, name in enumerate(('ABS', 'REG', 'IR'))}
label_dict_pig = {name: code for code, name in enumerate(('ABS', 'REG', 'IR'))}
label_dict_str = {name: code for code, name in enumerate(('ABS', 'REG', 'IR'))}
label_dict_dag = {name: code for code, name in enumerate(('ABS', 'REG', 'IR'))}
label_dict_rs = {name: code for code, name in enumerate(('ABS', 'PRS'))}

In [None]:
# One single-output sub-model per task head. Each shares classifier's inputs
# (and therefore its weights) and exposes the output of the layer named
# '<PREFIX>_<which_modality>', e.g. 'DIAG_clinic' or 'DaG_derm'.
if which_modality not in ('clinic', 'derm'):
    # Without this guard the my_model_* names would be left undefined (or stale
    # from an earlier run) and fail later with a confusing error.
    raise NotImplementedError(
        "Unknown which_modality {!r}; expected 'clinic' or 'derm'".format(which_modality))

(my_model_diagnosis, my_model_pn, my_model_bwv, my_model_vs,
 my_model_pig, my_model_str, my_model_dag, my_model_rs) = (
    tf.keras.Model(classifier.input,
                   classifier.get_layer('{}_{}'.format(prefix, which_modality)).output)
    # Layer-name stems, in the same order as the names being assigned above.
    for prefix in ('DIAG', 'PN', 'BWV', 'VS', 'PIG', 'STR', 'DaG', 'RS')
)

In [None]:
# Eight independent empty lists each for ground-truth (gts_*) and predicted
# (preds_*) class indices, one per task.
gts_diagnosis, gts_pn, gts_bwv, gts_vs, gts_pig, gts_str, gts_dag, gts_rs = [[] for _ in range(8)]
preds_diagnosis, preds_pn, preds_bwv, preds_vs, preds_pig, preds_str, preds_dag, preds_rs = [[] for _ in range(8)]

In [None]:
# Run the network over every row of `dataframe` and record, per task, the
# ground-truth class index (from the label column) and the predicted index
# (argmax of that head's output).
#
# (task key, label column, label -> index mapping, single-head sub-model)
tasks = [
    ('diagnosis', 'diagnosis', label_dict_diagnosis, my_model_diagnosis),
    ('pn', 'pigment_network', label_dict_pn, my_model_pn),
    ('bwv', 'blue_whitish_veil', label_dict_bwv, my_model_bwv),
    ('vs', 'vascular_structures', label_dict_vs, my_model_vs),
    ('pig', 'pigmentation', label_dict_pig, my_model_pig),
    ('str', 'streaks', label_dict_str, my_model_str),
    ('dag', 'dots_and_globules', label_dict_dag, my_model_dag),
    ('rs', 'regression_structures', label_dict_rs, my_model_rs),
]

# All heads are built on classifier's inputs, so combine them into one
# multi-output model: one forward pass per image instead of eight.
multi_head_model = tf.keras.Model(classifier.input, [head.output for _, _, _, head in tasks])

# Fresh accumulators on every run, so re-executing this cell does not append
# to results from a previous run.
gts = {key: [] for key, _, _, _ in tasks}
preds = {key: [] for key, _, _, _ in tasks}

n_rows = len(dataframe)
for idx, (_, row) in enumerate(dataframe.iterrows()):
    # Look up labels first so a bad label fails before any inference work.
    row_gts = [label_dict[row[column]] for _, column, label_dict, _ in tasks]

    img = preprocess_image(os.path.join(image_dir, row[which_modality]))
    # Only the evaluated modality receives the real image; the other image
    # input keeps its zero placeholder.
    if which_modality == 'clinic':
        clinic_img = img
    elif which_modality == 'derm':
        derm_img = img

    model_inputs = {'input_1_derm': derm_img, 'input_2_clinic': clinic_img, 'aux_input': meta,
                    'aux_input_derm': aux_input_derm, 'aux_input_clinic': aux_input_clinic}
    # Inference-mode forward pass; returns one array per head, in `tasks` order.
    soft_preds = multi_head_model.predict_on_batch(model_inputs)

    # Append ground truth and prediction together so the lists stay aligned.
    for (key, _, _, _), gt, soft_pred in zip(tasks, row_gts, soft_preds):
        gts[key].append(gt)
        preds[key].append(int(np.argmax(soft_pred, axis=-1).item()))

    print(idx + 1, '/', n_rows)

# Expose the per-task lists under the names the metrics cell uses.
gts_diagnosis, gts_pn, gts_bwv, gts_vs, gts_pig, gts_str, gts_dag, gts_rs = (
    gts[key] for key, _, _, _ in tasks)
preds_diagnosis, preds_pn, preds_bwv, preds_vs, preds_pig, preds_str, preds_dag, preds_rs = (
    preds[key] for key, _, _, _ in tasks)

In [None]:
# Balanced accuracy (mean of per-class recall) for each task, computed from
# the ground-truth / prediction lists gathered by the evaluation loop.
(balacc_diagnosis, balacc_pn, balacc_bwv, balacc_vs,
 balacc_pig, balacc_str, balacc_dag, balacc_rs) = (
    metrics.balanced_accuracy_score(y_true, y_pred)
    for y_true, y_pred in (
        (gts_diagnosis, preds_diagnosis),
        (gts_pn, preds_pn),
        (gts_bwv, preds_bwv),
        (gts_vs, preds_vs),
        (gts_pig, preds_pig),
        (gts_str, preds_str),
        (gts_dag, preds_dag),
        (gts_rs, preds_rs),
    )
)

for heading, score in (
    ('balacc_diagnosis = ', balacc_diagnosis),
    ('balacc_pn = ', balacc_pn),
    ('balacc_bwv = ', balacc_bwv),
    ('balacc_vs = ', balacc_vs),
    ('balacc_pig = ', balacc_pig),
    ('balacc_str = ', balacc_str),
    ('balacc_dag = ', balacc_dag),
    ('balacc_rs = ', balacc_rs),
):
    print(heading, score)