In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0" # Choose which GPUs by checking current use with nvidia-smi
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications import ResNet152V2
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
from tensorflow.keras import metrics
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
from tensorflow.keras.preprocessing import image

# Sanity check: list the GPUs TensorFlow can see after CUDA_VISIBLE_DEVICES
# was applied above. Changing the selection requires a kernel restart.
gpus = tf.config.list_physical_devices(device_type='GPU')
print(gpus)

# Define function to preprocess images as required by ResNet
def preprocess(images, labels):
    return tf.keras.applications.resnet_v2.preprocess_input(images), labels

# Data import. Each split lives in its own directory with one sub-directory
# per class, which is the layout image_dataset_from_directory reads.
traindir = '<path_to_train_data>'
valdir = '<path_to_validation_data>'
testdir = '<path_to_test_data>'

buffersize = 3   # number of batches to prefetch
im_dim = 512     # images are resized to im_dim x im_dim
batch_size = 16

# label_mode='binary' yields float 0/1 labels of shape (batch, 1). This matches
# the single sigmoid output and BinaryCrossentropy loss used by the model. It
# also raises if a split does not contain exactly two class directories,
# instead of silently producing multi-class integer labels.
train = tf.keras.preprocessing.image_dataset_from_directory(
    traindir, label_mode='binary', image_size=(im_dim, im_dim),
    batch_size=batch_size)
# Evaluation metrics are aggregated over the whole split, so order is
# irrelevant. Skip the (default) per-iteration reshuffle for val/test.
val = tf.keras.preprocessing.image_dataset_from_directory(
    valdir, label_mode='binary', image_size=(im_dim, im_dim),
    batch_size=batch_size, shuffle=False)
test = tf.keras.preprocessing.image_dataset_from_directory(
    testdir, label_mode='binary', image_size=(im_dim, im_dim),
    batch_size=batch_size, shuffle=False)

# Apply ResNetV2 preprocessing and overlap input loading with compute.
train_ds = train.map(preprocess).prefetch(buffer_size=buffersize)
val_ds = val.map(preprocess).prefetch(buffer_size=buffersize)
test_ds = test.map(preprocess).prefetch(buffer_size=buffersize)

epochs = 50
lr = 0.00588461  # learning rate chosen with Bayesian optimization

# Build and compile the model inside MirroredStrategy so its variables are
# mirrored across every GPU left visible by CUDA_VISIBLE_DEVICES.
mirrored_strategy = tf.distribute.MirroredStrategy()
# NOTE(review): the author left this line commented out. It presumably works
# around a thread-pool shutdown hang in MirroredStrategy. Re-enabling it also
# needs `import atexit`. Confirm whether it is still needed before deleting.
#atexit.register(mirrored_strategy._extended._collective_ops._pool.close) # type: ignore

with mirrored_strategy.scope():
    # restore_best_weights=True: when early stopping triggers, roll back to the
    # best-val_accuracy weights. The model evaluated and saved after fit() is
    # then the best one, not the one from `patience` epochs later.
    # mode='max' is stated explicitly rather than inferred from the metric name.
    cb1 = EarlyStopping(monitor='val_accuracy', mode='max', patience=5,
                        restore_best_weights=True)
    cb2 = ReduceLROnPlateau(monitor='val_accuracy', mode='max', factor=0.2,
                            patience=2, min_lr=0.00001)
    opt = keras.optimizers.Adam(learning_rate=lr)
    # model.evaluate() reports the loss followed by these metrics in order.
    metr = [
        metrics.BinaryAccuracy(name='accuracy'),
        metrics.AUC(name='auc'),
        metrics.Recall(name='recall'),
        metrics.Precision(name='precision'),
        metrics.TruePositives(name='TP'),
        metrics.TrueNegatives(name='TN'),
        metrics.FalsePositives(name='FP'),
        metrics.FalseNegatives(name='FN'),
    ]
    # ImageNet-pretrained backbone without its classifier head. The input size
    # follows im_dim so it cannot drift from the dataset image_size.
    # (`classes` only applies when include_top=True, so it is omitted.)
    ptmodel = ResNet152V2(include_top=False, weights='imagenet',
                          input_shape=(im_dim, im_dim, 3), pooling='avg')

    # Freeze the backbone, then re-enable training for its BatchNorm layers only.
    ptmodel.trainable = False
    for layer in ptmodel.layers:
        if "BatchNormalization" in layer.__class__.__name__:
            layer.trainable = True

    # pooling='avg' already yields a (batch, features) tensor, so the classifier
    # head attaches directly without a Flatten layer.
    x = ptmodel.output
    x = tf.keras.layers.Dense(2048, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.5, seed=717)(x)
    x = tf.keras.layers.Dense(512, activation='relu')(x)
    x = tf.keras.layers.Dense(128, activation='relu')(x)
    x = tf.keras.layers.Dense(32, activation='relu')(x)
    # Single sigmoid unit: probability of the positive class.
    x = tf.keras.layers.Dense(1, activation='sigmoid')(x)
    model = tf.keras.Model(ptmodel.input, x)
    model.compile(optimizer=opt, loss=keras.losses.BinaryCrossentropy(),
                  metrics=metr)
    

# Train; both callbacks (early stopping, LR reduction) monitor val_accuracy.
history = model.fit(train_ds, validation_data=val_ds, epochs=epochs,
                    callbacks=[cb1, cb2])

# Evaluate on the held-out test set. return_dict=True keys every result by its
# metric name. These assignments therefore cannot silently shift if the metric
# list passed to compile() is reordered or extended.
test_results = model.evaluate(test_ds, return_dict=True)
testloss = test_results['loss']
testaccuracy = test_results['accuracy']
testauc = test_results['auc']
testrecall = test_results['recall']
testprecision = test_results['precision']
testTP = test_results['TP']
testTN = test_results['TN']
testFP = test_results['FP']
testFN = test_results['FN']

# Save the trained model.
# NOTE(review): in TF 2.x a suffix-less path is exported as a SavedModel
# directory, and missing parent directories appear to be created. If not, run
# `mkdir -p saved_model` first. Keras 3 instead requires a `.keras` suffix.
model.save('saved_model/birads0model_resnet152v2_pt_a1928')
