# Setup

In [None]:
!pip install medmnist
!pip install git+https://github.com/qubvel/classification_models.git
!pip install tensorflow==2.16.1

In [None]:
from google.colab import drive
drive.mount('/content/drive/')
%cd drive/MyDrive/MLHM

In [4]:
import os
import medmnist
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers as layers
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Flatten, Dropout, BatchNormalization, GlobalAveragePooling2D, Input, ReLU
from tensorflow.keras.models import Model, Sequential, load_model
from tensorflow.keras.optimizers import Adam
from classification_models.keras import Classifiers
from shvit import SHVIT
from utils import load_data, cal_metrics, train, evaluate
import warnings
warnings.filterwarnings('ignore')

# Load and View Data

In [None]:
# Dataset selection, on-disk locations and metadata from the MedMNIST registry.
data_flag = 'chestmnist'
dataset_folder = f'./{data_flag}'
model_folder = './saved_models'
download = False
info = medmnist.INFO[data_flag]

label_dict = info['label']
task = info['task']
n_samples = info['n_samples']
n_channels = info['n_channels']
n_classes = len(label_dict)

print('Task:', task)
print('Number of samples:', n_samples)
print('Number of channels:', n_channels)
print('Number of classes:', n_classes)
print('Label Dict:', label_dict)

In [6]:
# Disabled one-off cell: creates dataset_folder and fetches the train/val/test splits at
# 64, 128 and 224 px. It is the only code that reads `download`.
# NOTE(review): `DataClass` is not defined anywhere in this notebook. Presumably it should be
# the medmnist dataset class for `data_flag`; define it before re-enabling this cell.
# if not os.path.exists(dataset_folder):
#     download = True
#     os.makedirs(dataset_folder)

# size_flags = [64, 128, 224]
# for size_flag in size_flags:
#     train_dataset = DataClass(root=dataset_folder, size=size_flag, split='train', download=download)
#     valid_dataset = DataClass(root=dataset_folder, size=size_flag, split='val', download=download)
#     test_dataset = DataClass(root=dataset_folder, size=size_flag, split='test', download=download)

In [7]:
# Disabled: loads the 224x224 splits used by the (also disabled) exploratory cells below,
# which read train_images_224 / train_labels_224 etc.
# data_224 = load_data(dataset_folder, data_flag, size_flag=224)
# train_images_224, train_labels_224, val_images_224, val_labels_224, test_images_224, test_labels_224 = data_224[0], data_224[1], data_224[2], data_224[3], data_224[4], data_224[5]

In [None]:
# Disabled: builds dataset_df with the number of positive labels per disease for each split,
# plus a 'whole' row per disease that sums the train/val/test counts.
# NOTE(review): eval(f'{sett}_labels_224') looks up the split arrays by variable name;
# a dict keyed by split name would avoid eval.
# df_rows = []

# def stats(sett):
#   for c, disease in label_dict.items():
#       count = np.sum(eval(f'{sett}_labels_224')[:, int(c)]==1)
#       df_rows.append([sett, disease, count])

# stats('train')
# stats('val')
# stats('test')
# dataset_df = pd.DataFrame(df_rows, columns=['set', 'disease', 'count'])

# for d in label_dict.values():
#   ids = dataset_df['disease']==d
#   df_rows.append(['whole', d, np.sum(dataset_df[ids]['count'])])

# dataset_df = pd.DataFrame(df_rows, columns=['set', 'disease', 'count'])
# dataset_df

In [None]:
# Disabled: pie chart of the per-disease label counts for one split of dataset_df, saved to
# '<split>.pdf'. Percentage text is drawn only for slices larger than 5%.
# def plot_pie(sett):
#     df = dataset_df[dataset_df['set']==sett]
#     plt.figure(figsize=(7, 7))
#     plt.pie(x=df['count'], labels=df['disease'], autopct=lambda pct: '{:1.1f}%'.format(pct) if pct > 5 else '',
#             pctdistance=0.7, labeldistance=1.1, textprops={'size': 'xx-large'})

#     plt.savefig(f'{sett}.pdf', bbox_inches='tight')
#     plt.show()

# plot_pie('whole')
# plot_pie('train')
# plot_pie('val')
# plot_pie('test')

In [None]:
# Disabled: 2x7 grid showing one training image per class, titled with all of that image's
# labels; saved to dataset.pdf.
# NOTE(review): two apparent bugs to fix before re-enabling:
#   * np.sum(images, axis=1) sums pixel rows, not labels, so it does not pick the image with
#     the most diseases (that would be np.sum(labels, axis=1)).
#   * image_id indexes the per-class subset `images`/`labels`, but imshow reads
#     train_images_224[image_id] from the full array; it should use images[image_id].
# label_texts = info['label']
# nr, nc = 2, 7
# fig, axes = plt.subplots(nrows=nr, ncols=nc, sharex=True, sharey=True, figsize=(20, 7))
# axes = axes.reshape(nr*nc)

# for ax, class_id in zip(axes, range(0, 14)):
#     title = ''
#     # get images that have current class_id
#     image_ids = np.where(train_labels_224[:, class_id]==1)[0]
#     images = train_images_224[image_ids]
#     labels = train_labels_224[image_ids]

#     # get image that has most diseases
#     sum_by_row = np.sum(images, axis=1)
#     image_id = np.where(sum_by_row == np.max(sum_by_row))[0][0]
#     label_ids = np.where(labels[image_id] == 1)[0]
#     for i in label_ids:
#         title += label_texts[str(i)] + '\n'
#     title = title[:-1]

#     ax.imshow(train_images_224[image_id], cmap='gray')
#     ax.axis('off')
#     ax.set_title(title)

# plt.savefig('dataset.pdf', bbox_inches='tight')

In [None]:
# Disabled: saves one example training image per class as a JPEG.
# NOTE(review): this cell has the same issues as the grid cell. np.sum(images, axis=1) sums
# pixels rather than labels, and train_images_224[image_id] indexes the full array with a
# subset index (it should be images[image_id]). It also writes into ./images_64/ although
# the images are 224x224.
# label_texts = info['label']

# for class_id in range(0, 14):
#     # get images that have current class_id
#     image_ids = np.where(train_labels_224[:, class_id]==1)[0]
#     images = train_images_224[image_ids]
#     labels = train_labels_224[image_ids]

#     # get image that has most diseases
#     sum_by_row = np.sum(images, axis=1)
#     image_id = np.where(sum_by_row == np.max(sum_by_row))[0][0]
#     label_ids = np.where(labels[image_id] == 1)[0]

#     im = Image.fromarray(train_images_224[image_id])
#     im.save(f'./images_64/raw_example_{image_id}_{str(label_ids)}.jpg')

In [None]:
# Disabled: saves the first 224x224 training image as raw_example.jpg.
# plt.imshow(train_images_224[0], cmap='gray')
# plt.axis('off')
# plt.savefig('raw_example.jpg', bbox_inches='tight')
# plt.show()

# Models and Training Settings

In [None]:
def CNN(image_size, n_classes, n_channels=1):
    """Build a small three-stage convolutional baseline with one sigmoid output per label.

    Each stage is Conv2D(3x3, 'same' padding, ReLU) -> BatchNormalization -> 2x2 max-pool,
    with 32, 64 and 128 filters. The feature map is then flattened and passed through a
    128-unit ReLU layer and an ``n_classes``-unit sigmoid layer (independent per-label
    probabilities, as expected by the binary cross-entropy loss used in this notebook).

    Args:
        image_size: Height and width of the square input images.
        n_classes: Number of output labels.
        n_channels: Number of input channels (default 1, i.e. grayscale).

    Returns:
        An uncompiled ``Sequential`` model.
    """
    model = Sequential()
    # Keras 3 recommends an explicit Input layer instead of `input_shape` on the first layer.
    model.add(Input(shape=(image_size, image_size, n_channels)))

    for filters in (32, 64, 128):
        model.add(Conv2D(filters=filters, kernel_size=(3, 3), padding='same', activation='relu'))
        model.add(BatchNormalization())
        model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dense(n_classes, activation='sigmoid'))
    return model

In [None]:
# Directory for the .weights.h5 files passed to train(); exist_ok makes re-runs safe.
os.makedirs(model_folder, exist_ok=True)

# Seed Python's `random`, NumPy and TensorFlow so weight initialisation is repeatable.
# NOTE(review): bit-exact reruns also depend on how utils.train shuffles and on GPU kernels.
SEED = 42
keras.utils.set_random_seed(SEED)

BATCH_SIZE = 128
N_EPOCHS = 30
LR = 0.001
WEIGHT_DECAY = 0.001

In [14]:
# Random geometric augmentation pipeline: zoom, shift, rotation and flips.
# NOTE(review): `data_augmentation` is not referenced by any other cell in this notebook,
# so it currently has no effect on training.
data_augmentation = tf.keras.Sequential(
    [
        layers.RandomZoom(height_factor=0.2),
        layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
        layers.RandomRotation(factor=0.2),
        layers.RandomFlip(mode='horizontal_and_vertical'),
    ]
)

# 64x64

In [6]:
# Load the 64x64 splits and convert the images into model-ready tensors.
image_size_64 = 64
data_64 = load_data(dataset_folder, data_flag, size_flag=image_size_64)
X_train_64, Y_train_64, X_valid_64, Y_valid_64, X_test_64, Y_test_64 = (
    data_64[0], data_64[1], data_64[2], data_64[3], data_64[4], data_64[5]
)
# Append a channel axis for the (H, W, 1) model inputs and scale pixels to [0, 1].
# float32 (Keras' default float type) uses half the memory of the float64 produced by a
# plain `/ 255` on integer pixels.
# NOTE(review): assumes load_data returns integer images shaped (N, H, W); confirm.
# NOTE(review): X_valid_64/Y_valid_64 are not used by any later cell.
X_train_64, X_valid_64, X_test_64 = (
    np.expand_dims(images, axis=-1).astype(np.float32) / 255
    for images in (X_train_64, X_valid_64, X_test_64)
)

## CNN

In [None]:
# Baseline CNN for 64x64 inputs; cnn_64_path is the model_path handed to train().
cnn_64 = CNN(image_size_64, n_classes)
cnn_64_path = f'{model_folder}/cnn_64.weights.h5'
cnn_64.summary()

In [None]:
# Binary cross-entropy over the per-label sigmoid outputs.
cnn_64.compile(
    loss='binary_crossentropy',
    optimizer=Adam(weight_decay=WEIGHT_DECAY, learning_rate=LR),
)
train(cnn_64, X_train_64, Y_train_64, model_path=cnn_64_path, batch_size=BATCH_SIZE, n_epochs=N_EPOCHS)

In [None]:
# Restore the weights file that was passed to train() and score the test split.
cnn_64.load_weights(filepath=cnn_64_path)
test_metrics = evaluate(cnn_64, X_test_64, Y_test_64)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

## VGG-16

In [None]:
# ImageNet-pretrained VGG-16 backbone without its classifier head.
# NOTE(review): preprocess_input is never applied to the inputs in this notebook.
VGG16, preprocess_input = Classifiers.get('vgg16')
vgg16_64 = VGG16(weights='imagenet', include_top=False, input_shape=(image_size_64, image_size_64, 3))
# for layer in vgg16_64.layers:
#     layer.trainable = False

# Grayscale input -> learnable 3x3 conv producing the 3 channels the backbone expects.
new_input = Input(shape=(image_size_64, image_size_64, 1))
x = Conv2D(kernel_size=(3, 3), filters=3, activation=None, padding="same")(new_input)
x = vgg16_64(x)
# Pooled backbone features -> 256-unit hidden layer -> one sigmoid per label.
x = GlobalAveragePooling2D()(x)
x = Dense(units=256, activation='relu')(x)
prediction = Dense(units=n_classes, activation='sigmoid')(x)

# The name now refers to the end-to-end model; the backbone lives on as one of its layers.
vgg16_64 = Model(outputs=prediction, inputs=new_input)
vgg16_64.summary()

In [None]:
# Binary cross-entropy over the per-label sigmoid outputs.
vgg16_64_path = f'{model_folder}/vgg16_64.weights.h5'
vgg16_64.compile(
    loss='binary_crossentropy',
    optimizer=Adam(weight_decay=WEIGHT_DECAY, learning_rate=LR),
)
train(vgg16_64, X_train_64, Y_train_64, model_path=vgg16_64_path, batch_size=BATCH_SIZE, n_epochs=N_EPOCHS)

In [None]:
# Restore the weights file that was passed to train() and score the test split.
vgg16_64.load_weights(filepath=vgg16_64_path)
test_metrics = evaluate(vgg16_64, X_test_64, Y_test_64)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

## VGG-19

In [None]:
# ImageNet-pretrained VGG-19 backbone without its classifier head.
# NOTE(review): preprocess_input is never applied to the inputs in this notebook.
VGG19, preprocess_input = Classifiers.get('vgg19')
vgg19_64 = VGG19(weights='imagenet', include_top=False, input_shape=(image_size_64, image_size_64, 3))
# for layer in vgg19_64.layers:
#     layer.trainable = False

# Grayscale input -> learnable 3x3 conv producing the 3 channels the backbone expects.
new_input = Input(shape=(image_size_64, image_size_64, 1))
x = Conv2D(kernel_size=(3, 3), filters=3, activation=None, padding="same")(new_input)
x = vgg19_64(x)
# Pooled backbone features -> 256-unit hidden layer -> one sigmoid per label.
x = GlobalAveragePooling2D()(x)
x = Dense(units=256, activation='relu')(x)
prediction = Dense(units=n_classes, activation='sigmoid')(x)

# The name now refers to the end-to-end model; the backbone lives on as one of its layers.
vgg19_64 = Model(outputs=prediction, inputs=new_input)
vgg19_64.summary()

In [None]:
# Same loss/optimizer/batch size as the other models. The loss must be a str identifier;
# the bytes literal B'binary_crossentropy' is not a valid Keras loss identifier.
vgg19_64_path = f'{model_folder}/vgg19_64.weights.h5'
vgg19_64.compile(optimizer=Adam(learning_rate=LR, weight_decay=WEIGHT_DECAY), loss='binary_crossentropy')
train(vgg19_64, X_train_64, Y_train_64, n_epochs=N_EPOCHS, batch_size=BATCH_SIZE, model_path=vgg19_64_path)

In [None]:
# Restore the weights file that was passed to train() and score the test split.
vgg19_64.load_weights(filepath=vgg19_64_path)
test_metrics = evaluate(vgg19_64, X_test_64, Y_test_64)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

## ResNet-18

In [None]:
# ImageNet-pretrained ResNet-18 backbone without its classifier head.
# NOTE(review): preprocess_input is never applied to the inputs in this notebook.
ResNet18, preprocess_input = Classifiers.get('resnet18')
resnet18_64 = ResNet18(weights='imagenet', include_top=False, input_shape=(image_size_64, image_size_64, 3))
# for layer in resnet18_64.layers:
#     layer.trainable = False

# Grayscale input -> learnable 3x3 ReLU conv producing the 3 channels the backbone expects.
new_input = Input(shape=(image_size_64, image_size_64, 1))
x = Conv2D(kernel_size=(3, 3), filters=3, activation='relu', padding="same")(new_input)
x = resnet18_64(x)
# Pooled backbone features -> 256-unit hidden layer -> one sigmoid per label.
x = GlobalAveragePooling2D()(x)
x = Dense(units=256, activation='relu')(x)
prediction = Dense(units=n_classes, activation='sigmoid')(x)

# The name now refers to the end-to-end model; the backbone lives on as one of its layers.
resnet18_64 = Model(outputs=prediction, inputs=new_input)
resnet18_64.summary()

In [None]:
# Binary cross-entropy over the per-label sigmoid outputs.
resnet18_64_path = f'{model_folder}/resnet18_64.weights.h5'
resnet18_64.compile(
    loss='binary_crossentropy',
    optimizer=Adam(weight_decay=WEIGHT_DECAY, learning_rate=LR),
)
train(resnet18_64, X_train_64, Y_train_64, model_path=resnet18_64_path, batch_size=BATCH_SIZE, n_epochs=N_EPOCHS)

In [None]:
# Restore the weights file that was passed to train() and score the test split.
resnet18_64.load_weights(filepath=resnet18_64_path)
test_metrics = evaluate(resnet18_64, X_test_64, Y_test_64)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

## ResNet-50

In [None]:
# ImageNet-pretrained ResNet-50 backbone without its classifier head.
# NOTE(review): preprocess_input is never applied to the inputs in this notebook.
ResNet50, preprocess_input = Classifiers.get('resnet50')
resnet50_64 = ResNet50(weights='imagenet', include_top=False, input_shape=(image_size_64, image_size_64, 3))
# for layer in resnet50_64.layers:
#     layer.trainable = False

# Grayscale input -> learnable 3x3 ReLU conv producing the 3 channels the backbone expects.
new_input = Input(shape=(image_size_64, image_size_64, 1))
x = Conv2D(kernel_size=(3, 3), filters=3, activation='relu', padding="same")(new_input)
x = resnet50_64(x)
# Pooled backbone features -> 1024-unit hidden layer -> one sigmoid per label.
x = GlobalAveragePooling2D()(x)
x = Dense(units=1024, activation='relu')(x)
prediction = Dense(units=n_classes, activation='sigmoid')(x)

# The name now refers to the end-to-end model; the backbone lives on as one of its layers.
resnet50_64 = Model(outputs=prediction, inputs=new_input)
resnet50_64.summary()

In [None]:
# Binary cross-entropy over the per-label sigmoid outputs.
resnet50_64_path = f'{model_folder}/resnet50_64.weights.h5'
resnet50_64.compile(
    loss='binary_crossentropy',
    optimizer=Adam(weight_decay=WEIGHT_DECAY, learning_rate=LR),
)
train(resnet50_64, X_train_64, Y_train_64, model_path=resnet50_64_path, batch_size=BATCH_SIZE, n_epochs=N_EPOCHS)

In [None]:
# Restore the weights file that was passed to train() and score the test split.
resnet50_64.load_weights(filepath=resnet50_64_path)
test_metrics = evaluate(resnet50_64, X_test_64, Y_test_64)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

## SHViT

In [None]:
# SHViT hyperparameters: each list holds one value per stage (3 stages).
shvit_settings = {
        'embed_dim': [128, 224, 320],
        'depth': [2, 4, 5],
        'partial_dim': [32, 48, 68],
        'types' : ['i', 's', 's']
    }
shvit = SHVIT(**shvit_settings)
# Named shvit_input/shvit_output so the Python builtin `input` is not shadowed.
shvit_input = Input(shape=(image_size_64, image_size_64, 1))
shvit_output = shvit(shvit_input)
# NOTE(review): SHVIT is not given n_classes here. Confirm its head emits n_classes sigmoid
# probabilities, as the binary_crossentropy loss used for training expects.
shvit_64 = Model(inputs=shvit_input, outputs=shvit_output)
shvit_64.summary()

In [None]:
# Binary cross-entropy over the per-label outputs.
shvit_64_path = f'{model_folder}/shvit_64.weights.h5'
shvit_64.compile(
    loss='binary_crossentropy',
    optimizer=Adam(weight_decay=WEIGHT_DECAY, learning_rate=LR),
)
train(shvit_64, X_train_64, Y_train_64, model_path=shvit_64_path, batch_size=BATCH_SIZE, n_epochs=N_EPOCHS)

In [None]:
# Restore the weights file that was passed to train() and score the test split.
shvit_64.load_weights(filepath=shvit_64_path)
test_metrics = evaluate(shvit_64, X_test_64, Y_test_64)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

# 128x128

In [None]:
# Load the 128x128 splits and convert the images into model-ready tensors.
image_size_128 = 128
data_128 = load_data(dataset_folder, data_flag, size_flag=image_size_128)
X_train_128, Y_train_128, X_valid_128, Y_valid_128, X_test_128, Y_test_128 = (
    data_128[0], data_128[1], data_128[2], data_128[3], data_128[4], data_128[5]
)
# Append a channel axis for the (H, W, 1) model inputs and scale pixels to [0, 1].
# float32 (Keras' default float type) uses half the memory of the float64 produced by a
# plain `/ 255` on integer pixels.
# NOTE(review): assumes load_data returns integer images shaped (N, H, W); confirm.
# NOTE(review): X_valid_128/Y_valid_128 are not used by any later cell.
X_train_128, X_valid_128, X_test_128 = (
    np.expand_dims(images, axis=-1).astype(np.float32) / 255
    for images in (X_train_128, X_valid_128, X_test_128)
)

## CNN

In [None]:
# Baseline CNN for 128x128 inputs.
cnn_128 = CNN(image_size=image_size_128, n_classes=n_classes)
cnn_128.summary()

In [None]:
# Weights path following the '<model>_<size>.weights.h5' convention of the other models.
cnn_128_path = f'{model_folder}/cnn_128.weights.h5'
cnn_128.compile(optimizer=Adam(learning_rate=LR, weight_decay=WEIGHT_DECAY), loss='binary_crossentropy')
train(cnn_128, X_train_128, Y_train_128, n_epochs=N_EPOCHS, batch_size=BATCH_SIZE, model_path=cnn_128_path)

In [None]:
# Evaluate the 128x128 CNN with its own weights. The 64x64 weights cannot be loaded into it,
# because its Flatten/Dense sizes differ. The path is derived here as well so that this cell
# does not rely on the training cell having defined it.
cnn_128_path = f'{model_folder}/cnn_128.weights.h5'
cnn_128.load_weights(cnn_128_path)
test_metrics = evaluate(cnn_128, X_test_128, Y_test_128)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

## VGG-16

In [None]:
# ImageNet-pretrained VGG-16 backbone without its classifier head.
# NOTE(review): preprocess_input is never applied to the inputs in this notebook.
VGG16, preprocess_input = Classifiers.get('vgg16')
vgg16_128 = VGG16(weights='imagenet', include_top=False, input_shape=(image_size_128, image_size_128, 3))
# for layer in vgg16_128.layers:
#     layer.trainable = False

# Grayscale input -> learnable 3x3 conv producing the 3 channels the backbone expects.
new_input = Input(shape=(image_size_128, image_size_128, 1))
x = Conv2D(kernel_size=(3, 3), filters=3, activation=None, padding="same")(new_input)
x = vgg16_128(x)
# Pooled backbone features -> 256-unit hidden layer -> one sigmoid per label.
x = GlobalAveragePooling2D()(x)
x = Dense(units=256, activation='relu')(x)
prediction = Dense(units=n_classes, activation='sigmoid')(x)

# The name now refers to the end-to-end model; the backbone lives on as one of its layers.
vgg16_128 = Model(outputs=prediction, inputs=new_input)
vgg16_128.summary()

In [None]:
# Train the 128x128 model that was just compiled, not the 64x64 one.
vgg16_128_path = f'{model_folder}/vgg16_128.weights.h5'
vgg16_128.compile(optimizer=Adam(learning_rate=LR, weight_decay=WEIGHT_DECAY), loss='binary_crossentropy')
train(vgg16_128, X_train_128, Y_train_128, n_epochs=N_EPOCHS, batch_size=BATCH_SIZE, model_path=vgg16_128_path)

In [None]:
# Evaluate the model whose weights were just restored (vgg16_128, not vgg16_64).
vgg16_128.load_weights(vgg16_128_path)
test_metrics = evaluate(vgg16_128, X_test_128, Y_test_128)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

## VGG-19

In [None]:
# ImageNet-pretrained VGG-19 backbone without its classifier head.
# NOTE(review): preprocess_input is never applied to the inputs in this notebook.
VGG19, preprocess_input = Classifiers.get('vgg19')
vgg19_128 = VGG19(input_shape=(image_size_128, image_size_128, 3), weights='imagenet', include_top=False)
# Optional backbone freezing (must refer to this cell's backbone, vgg19_128):
# for layer in vgg19_128.layers:
#     layer.trainable = False

# Grayscale input -> learnable 3x3 conv producing the 3 channels the backbone expects.
new_input = Input(shape=(image_size_128, image_size_128, 1))
x = Conv2D(filters=3, kernel_size=(3, 3), padding="same", activation=None)(new_input)
x = vgg19_128(x)
# Pooled backbone features -> 256-unit hidden layer -> one sigmoid per label.
x = GlobalAveragePooling2D()(x)
x = Dense(256, activation='relu')(x)
prediction = Dense(n_classes, activation='sigmoid')(x)
vgg19_128 = Model(inputs=new_input, outputs=prediction)
vgg19_128.summary()

In [None]:
# Train the 128x128 VGG-19 with a str loss identifier (not the bytes literal
# B'binary_crossentropy') and the shared batch size.
vgg19_128_path = f'{model_folder}/vgg19_128.weights.h5'
vgg19_128.compile(optimizer=Adam(learning_rate=LR, weight_decay=WEIGHT_DECAY), loss='binary_crossentropy')
train(vgg19_128, X_train_128, Y_train_128, n_epochs=N_EPOCHS, batch_size=BATCH_SIZE, model_path=vgg19_128_path)

In [None]:
# Restore the 128x128 weights into the 128x128 model and score it on the 128x128 test split.
vgg19_128.load_weights(vgg19_128_path)
test_metrics = evaluate(vgg19_128, X_test_128, Y_test_128)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

## ResNet-18

In [None]:
# ImageNet-pretrained ResNet-18 backbone without its classifier head.
# NOTE(review): preprocess_input is never applied to the inputs in this notebook.
ResNet18, preprocess_input = Classifiers.get('resnet18')
resnet18_128 = ResNet18(input_shape=(image_size_128, image_size_128, 3), weights='imagenet', include_top=False)
# Optional backbone freezing (must refer to this cell's backbone, resnet18_128):
# for layer in resnet18_128.layers:
#     layer.trainable = False
# Grayscale input -> learnable 3x3 ReLU conv producing the 3 channels the backbone expects.
new_input = Input(shape=(image_size_128, image_size_128, 1))
x = Conv2D(filters=3, kernel_size=(3, 3), padding="same", activation='relu')(new_input)
x = resnet18_128(x)
# Pooled backbone features -> 256-unit hidden layer -> one sigmoid per label.
x = GlobalAveragePooling2D()(x)
x = Dense(256, activation='relu')(x)
prediction = Dense(n_classes, activation='sigmoid')(x)
resnet18_128 = Model(inputs=new_input, outputs=prediction)
resnet18_128.summary()

In [None]:
# Train the 128x128 model that was just compiled, not the 64x64 one.
resnet18_128_path = f'{model_folder}/resnet18_128.weights.h5'
resnet18_128.compile(optimizer=Adam(learning_rate=LR, weight_decay=WEIGHT_DECAY), loss='binary_crossentropy')
train(resnet18_128, X_train_128, Y_train_128, n_epochs=N_EPOCHS, batch_size=BATCH_SIZE, model_path=resnet18_128_path)

In [None]:
# Restore the weights file that was passed to train() and score the test split.
resnet18_128.load_weights(filepath=resnet18_128_path)
test_metrics = evaluate(resnet18_128, X_test_128, Y_test_128)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

## ResNet-50

In [None]:
# ImageNet-pretrained ResNet-50 backbone without its classifier head.
# NOTE(review): preprocess_input is never applied to the inputs in this notebook.
ResNet50, preprocess_input = Classifiers.get('resnet50')
resnet50_128 = ResNet50(weights='imagenet', include_top=False, input_shape=(image_size_128, image_size_128, 3))
# for layer in resnet50_128.layers:
#     layer.trainable = False

# Grayscale input -> learnable 3x3 ReLU conv producing the 3 channels the backbone expects.
new_input = Input(shape=(image_size_128, image_size_128, 1))
x = Conv2D(kernel_size=(3, 3), filters=3, activation='relu', padding="same")(new_input)
x = resnet50_128(x)
# Pooled backbone features -> 1024-unit hidden layer -> one sigmoid per label.
x = GlobalAveragePooling2D()(x)
x = Dense(units=1024, activation='relu')(x)
prediction = Dense(units=n_classes, activation='sigmoid')(x)

# The name now refers to the end-to-end model; the backbone lives on as one of its layers.
resnet50_128 = Model(outputs=prediction, inputs=new_input)
resnet50_128.summary()

In [None]:
# Binary cross-entropy over the per-label sigmoid outputs.
resnet50_128_path = f'{model_folder}/resnet50_128.weights.h5'
resnet50_128.compile(
    loss='binary_crossentropy',
    optimizer=Adam(weight_decay=WEIGHT_DECAY, learning_rate=LR),
)
train(resnet50_128, X_train_128, Y_train_128, model_path=resnet50_128_path, batch_size=BATCH_SIZE, n_epochs=N_EPOCHS)

In [None]:
# Restore the weights file that was passed to train() and score the test split.
resnet50_128.load_weights(filepath=resnet50_128_path)
test_metrics = evaluate(resnet50_128, X_test_128, Y_test_128)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

## SHViT

In [None]:
# SHViT hyperparameters: each list holds one value per stage (3 stages).
shvit_settings = {
        'embed_dim': [128, 224, 320],
        'depth': [2, 4, 5],
        'partial_dim': [32, 48, 68],
        'types' : ['i', 's', 's']
    }
shvit = SHVIT(**shvit_settings)
# Named shvit_input/shvit_output so the Python builtin `input` is not shadowed.
shvit_input = Input(shape=(image_size_128, image_size_128, 1))
shvit_output = shvit(shvit_input)
# NOTE(review): SHVIT is not given n_classes here. Confirm its head emits n_classes sigmoid
# probabilities, as the binary_crossentropy loss used for training expects.
shvit_128 = Model(inputs=shvit_input, outputs=shvit_output)
shvit_128.summary()

In [None]:
# Train for the same N_EPOCHS as every other model so the comparison is like-for-like
# (this cell previously used n_epochs=1).
shvit_128_path = f'{model_folder}/shvit_128.weights.h5'
shvit_128.compile(optimizer=Adam(learning_rate=LR, weight_decay=WEIGHT_DECAY), loss='binary_crossentropy')
train(shvit_128, X_train_128, Y_train_128, n_epochs=N_EPOCHS, batch_size=BATCH_SIZE, model_path=shvit_128_path)

In [None]:
# Restore the weights file that was passed to train() and score the test split.
shvit_128.load_weights(filepath=shvit_128_path)
test_metrics = evaluate(shvit_128, X_test_128, Y_test_128)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

# 224x224

In [None]:
# Load the 224x224 splits and convert the images into model-ready tensors. Each array gets
# exactly one channel axis and exactly one /255 scaling.
image_size_224 = 224
data_224 = load_data(dataset_folder, data_flag, size_flag=image_size_224)
X_train_224, Y_train_224, X_valid_224, Y_valid_224, X_test_224, Y_test_224 = (
    data_224[0], data_224[1], data_224[2], data_224[3], data_224[4], data_224[5]
)
# Append a channel axis for the (H, W, 1) model inputs and scale pixels to [0, 1].
# float32 (Keras' default float type) uses half the memory of float64, which matters at 224px.
# NOTE(review): assumes load_data returns integer images shaped (N, H, W); confirm.
X_train_224, X_valid_224, X_test_224 = (
    np.expand_dims(images, axis=-1).astype(np.float32) / 255
    for images in (X_train_224, X_valid_224, X_test_224)
)

## CNN

In [None]:
# Baseline CNN for 224x224 inputs.
cnn_224 = CNN(image_size=image_size_224, n_classes=n_classes)
cnn_224.summary()

In [None]:
# Binary cross-entropy over the per-label sigmoid outputs.
cnn_224_path = f'{model_folder}/cnn_224.weights.h5'
cnn_224.compile(
    loss='binary_crossentropy',
    optimizer=Adam(weight_decay=WEIGHT_DECAY, learning_rate=LR),
)
train(cnn_224, X_train_224, Y_train_224, model_path=cnn_224_path, batch_size=BATCH_SIZE, n_epochs=N_EPOCHS)

In [None]:
# Keras equivalent of the leftover PyTorch code (torch, DEVICE and test_loader_224 do not
# exist in this notebook): restore the weights file and score the 224x224 test split.
cnn_224.load_weights(cnn_224_path)
test_metrics = evaluate(cnn_224, X_test_224, Y_test_224)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

## VGG-16

In [None]:
# ImageNet-pretrained VGG-16 backbone without its classifier head.
# NOTE(review): preprocess_input is never applied to the inputs in this notebook.
VGG16, preprocess_input = Classifiers.get('vgg16')
vgg16_224 = VGG16(weights='imagenet', include_top=False, input_shape=(image_size_224, image_size_224, 3))
# for layer in vgg16_224.layers:
#     layer.trainable = False

# Grayscale input -> learnable 3x3 conv producing the 3 channels the backbone expects.
new_input = Input(shape=(image_size_224, image_size_224, 1))
x = Conv2D(kernel_size=(3, 3), filters=3, activation=None, padding="same")(new_input)
x = vgg16_224(x)
# Pooled backbone features -> 256-unit hidden layer -> one sigmoid per label.
x = GlobalAveragePooling2D()(x)
x = Dense(units=256, activation='relu')(x)
prediction = Dense(units=n_classes, activation='sigmoid')(x)

# The name now refers to the end-to-end model; the backbone lives on as one of its layers.
vgg16_224 = Model(outputs=prediction, inputs=new_input)
vgg16_224.summary()

In [None]:
# Train the 224x224 model that was just compiled, not the 64x64 one.
vgg16_224_path = f'{model_folder}/vgg16_224.weights.h5'
vgg16_224.compile(optimizer=Adam(learning_rate=LR, weight_decay=WEIGHT_DECAY), loss='binary_crossentropy')
train(vgg16_224, X_train_224, Y_train_224, n_epochs=N_EPOCHS, batch_size=BATCH_SIZE, model_path=vgg16_224_path)

In [None]:
# Evaluate the model whose weights were just restored (vgg16_224, not vgg16_64).
vgg16_224.load_weights(vgg16_224_path)
test_metrics = evaluate(vgg16_224, X_test_224, Y_test_224)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

## VGG-19

In [None]:
# ImageNet-pretrained VGG-19 backbone without its classifier head.
# NOTE(review): preprocess_input is never applied to the inputs in this notebook.
VGG19, preprocess_input = Classifiers.get('vgg19')
vgg19_224 = VGG19(weights='imagenet', include_top=False, input_shape=(image_size_224, image_size_224, 3))
# for layer in vgg19_224.layers:
#     layer.trainable = False

# Grayscale input -> learnable 3x3 conv producing the 3 channels the backbone expects.
new_input = Input(shape=(image_size_224, image_size_224, 1))
x = Conv2D(kernel_size=(3, 3), filters=3, activation=None, padding="same")(new_input)
x = vgg19_224(x)
# Pooled backbone features -> 256-unit hidden layer -> one sigmoid per label.
x = GlobalAveragePooling2D()(x)
x = Dense(units=256, activation='relu')(x)
prediction = Dense(units=n_classes, activation='sigmoid')(x)

# The name now refers to the end-to-end model; the backbone lives on as one of its layers.
vgg19_224 = Model(outputs=prediction, inputs=new_input)
vgg19_224.summary()

In [None]:
# Train the 224x224 VGG-19 on the 224x224 data, with a str loss identifier (not the bytes
# literal B'binary_crossentropy') and the shared batch size.
vgg19_224_path = f'{model_folder}/vgg19_224.weights.h5'
vgg19_224.compile(optimizer=Adam(learning_rate=LR, weight_decay=WEIGHT_DECAY), loss='binary_crossentropy')
train(vgg19_224, X_train_224, Y_train_224, n_epochs=N_EPOCHS, batch_size=BATCH_SIZE, model_path=vgg19_224_path)

In [None]:
# Restore the 224x224 weights into the 224x224 model and score it on the 224x224 test split.
vgg19_224.load_weights(vgg19_224_path)
test_metrics = evaluate(vgg19_224, X_test_224, Y_test_224)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

## ResNet-18

In [None]:
# ImageNet-pretrained ResNet-18 backbone without its classifier head.
# NOTE(review): preprocess_input is never applied to the inputs in this notebook.
ResNet18, preprocess_input = Classifiers.get('resnet18')
resnet18_224 = ResNet18(input_shape=(image_size_224, image_size_224, 3), weights='imagenet', include_top=False)
# Optional backbone freezing (must refer to this cell's backbone, resnet18_224):
# for layer in resnet18_224.layers:
#     layer.trainable = False
# Grayscale input -> learnable 3x3 ReLU conv producing the 3 channels the backbone expects.
new_input = Input(shape=(image_size_224, image_size_224, 1))
x = Conv2D(filters=3, kernel_size=(3, 3), padding="same", activation='relu')(new_input)
x = resnet18_224(x)
# Pooled backbone features -> 256-unit hidden layer -> one sigmoid per label.
x = GlobalAveragePooling2D()(x)
x = Dense(256, activation='relu')(x)
prediction = Dense(n_classes, activation='sigmoid')(x)
resnet18_224 = Model(inputs=new_input, outputs=prediction)
resnet18_224.summary()

In [None]:
# Binary cross-entropy over the per-label sigmoid outputs.
resnet18_224_path = f'{model_folder}/resnet18_224.weights.h5'
resnet18_224.compile(
    loss='binary_crossentropy',
    optimizer=Adam(weight_decay=WEIGHT_DECAY, learning_rate=LR),
)
train(resnet18_224, X_train_224, Y_train_224, model_path=resnet18_224_path, batch_size=BATCH_SIZE, n_epochs=N_EPOCHS)

In [None]:
# Evaluate the model whose weights were just restored (resnet18_224, not resnet18_128).
resnet18_224.load_weights(resnet18_224_path)
test_metrics = evaluate(resnet18_224, X_test_224, Y_test_224)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

## ResNet-50

In [None]:
# ImageNet-pretrained ResNet-50 backbone without its classifier head.
# NOTE(review): preprocess_input is never applied to the inputs in this notebook.
ResNet50, preprocess_input = Classifiers.get('resnet50')
resnet50_224 = ResNet50(weights='imagenet', include_top=False, input_shape=(image_size_224, image_size_224, 3))
# for layer in resnet50_224.layers:
#     layer.trainable = False

# Grayscale input -> learnable 3x3 ReLU conv producing the 3 channels the backbone expects.
new_input = Input(shape=(image_size_224, image_size_224, 1))
x = Conv2D(kernel_size=(3, 3), filters=3, activation='relu', padding="same")(new_input)
x = resnet50_224(x)
# Pooled backbone features -> 1024-unit hidden layer -> one sigmoid per label.
x = GlobalAveragePooling2D()(x)
x = Dense(units=1024, activation='relu')(x)
prediction = Dense(units=n_classes, activation='sigmoid')(x)

# The name now refers to the end-to-end model; the backbone lives on as one of its layers.
resnet50_224 = Model(outputs=prediction, inputs=new_input)
resnet50_224.summary()

In [None]:
# Train the 224x224 model that was just compiled, not the 128x128 one.
resnet50_224_path = f'{model_folder}/resnet50_224.weights.h5'
resnet50_224.compile(optimizer=Adam(learning_rate=LR, weight_decay=WEIGHT_DECAY), loss='binary_crossentropy')
train(resnet50_224, X_train_224, Y_train_224, n_epochs=N_EPOCHS, batch_size=BATCH_SIZE, model_path=resnet50_224_path)

In [None]:
# Evaluate the model whose weights were just restored (resnet50_224, not resnet50_128).
resnet50_224.load_weights(resnet50_224_path)
test_metrics = evaluate(resnet50_224, X_test_224, Y_test_224)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

## SHViT

In [None]:
# SHViT hyperparameters: each list holds one value per stage (3 stages).
shvit_settings = {
        'embed_dim': [128, 224, 320],
        'depth': [2, 4, 5],
        'partial_dim': [32, 48, 68],
        'types' : ['i', 's', 's']
    }
shvit = SHVIT(**shvit_settings)
# Named shvit_input/shvit_output so the Python builtin `input` is not shadowed.
shvit_input = Input(shape=(image_size_224, image_size_224, 1))
shvit_output = shvit(shvit_input)
# NOTE(review): SHVIT is not given n_classes here. Confirm its head emits n_classes sigmoid
# probabilities, as the binary_crossentropy loss used for training expects.
shvit_224 = Model(inputs=shvit_input, outputs=shvit_output)
shvit_224.summary()

In [None]:
# Weights path following the '<model>_<size>.weights.h5' convention of the other models.
# Training uses the shared N_EPOCHS so results are comparable with the other models.
shvit_224_path = f'{model_folder}/shvit_224.weights.h5'
shvit_224.compile(optimizer=Adam(learning_rate=LR, weight_decay=WEIGHT_DECAY), loss='binary_crossentropy')
train(shvit_224, X_train_224, Y_train_224, n_epochs=N_EPOCHS, batch_size=BATCH_SIZE, model_path=shvit_224_path)

In [None]:
# Evaluate the 224x224 SHViT (not shvit_128). The path is derived here as well so that this
# cell does not rely on the training cell having defined it.
shvit_224_path = f'{model_folder}/shvit_224.weights.h5'
shvit_224.load_weights(shvit_224_path)
test_metrics = evaluate(shvit_224, X_test_224, Y_test_224)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')