## Import Libraries

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras import backend as K
from keras import metrics
from keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau, Callback, LearningRateScheduler
import transformers as t
import tensorflow_addons as tfa

In [None]:
from datasets import load_dataset
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
from sklearn.metrics import classification_report, confusion_matrix, auc, roc_curve
from sklearn.utils import class_weight
from sklearn.metrics import roc_auc_score, confusion_matrix, plot_confusion_matrix, roc_curve
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import pickle
from matplotlib.ticker import MaxNLocator

## Global Variables

Check GPU availability

In [None]:
# List the GPU devices visible to TensorFlow (an empty list means none was detected).
tf.config.list_physical_devices('GPU')

In [None]:
# NOTE(review): presumably the number of target classes; the model built in this
# notebook uses a single sigmoid output and does not read this value — confirm.
num_classes = 2
# Per-sample shape given to the model's Input layer
# (assumed to be height, width, channels — TODO confirm against the image arrays).
input_shape = (224, 224, 3)
# NOTE(review): not applied to any random generator in the visible part of the notebook.
seed = 14

In [None]:
# Directory containing clean_data.csv and the pickled X_train / X_val / X_test arrays.
path_tabular = '/home/iexpress/TCGA/data/'
# Image directory. NOTE(review): not referenced in the visible part of the notebook.
path_images = '/home/iexpress/TCGA/data/images/'

Load the data

In [None]:
# Read the cleaned tabular dataset and keep only the columns whose name has no "|".
tab_dat = pd.read_csv(os.path.join(path_tabular, 'clean_data.csv'), sep=';')
selected_columns = ["|" not in name for name in tab_dat.columns]
tab_dat = tab_dat[tab_dat.columns[selected_columns]]

# Split the rows according to the value stored in the 'dataset' column.
tab_X_train = tab_dat[tab_dat['dataset'].values == 'train']
tab_X_val = tab_dat[tab_dat['dataset'].values == 'val']
tab_X_test = tab_dat[tab_dat['dataset'].values == 'test']

# Pull out the targets before the bookkeeping columns are removed.
y_train = tab_X_train['target']
y_val = tab_X_val['target']
y_test = tab_X_test['target']

# Leave only the feature columns in each split.
tab_X_train = tab_X_train.drop(columns=['dataset', 'target', 'index'])
tab_X_val = tab_X_val.drop(columns=['dataset', 'target', 'index'])
tab_X_test = tab_X_test.drop(columns=['dataset', 'target', 'index'])

In [None]:
# Load the pickled image arrays of each split. Every file is opened in a
# context manager so its handle is closed as soon as the array has been read.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open(os.path.join(path_tabular, 'X_train'), "rb") as _fh:
    X_train = pickle.load(_fh)
with open(os.path.join(path_tabular, 'X_val'), "rb") as _fh:
    X_val = pickle.load(_fh)
with open(os.path.join(path_tabular, 'X_test'), "rb") as _fh:
    X_test = pickle.load(_fh)

In [None]:
# Report the shape of every split as a quick sanity check.
print('Train tab data size = ', tab_X_train.shape, sep='')
print('Train image data size = ', X_train.shape, sep='')
print('Train target size = ', y_train.shape, sep='')
print('Validation tab data size = ', tab_X_val.shape, sep='')
print('Validation image data size = ', X_val.shape, sep='')
print('Validation target size = ', y_val.shape, sep='')
print('Test tab data size = ', tab_X_test.shape, sep='')
print('Test target size = ', y_test.shape, sep='')
print('Test image data size = ', X_test.shape, sep='')

## Hyperparameters

In [None]:
# Optimizer settings passed to AdamW in run_experiment.
learning_rate = 0.001
weight_decay = 0.0001
# Multiplicative factor applied to the learning rate by the per-epoch scheduler.
decay = 0.9
# Training loop settings passed to model.fit.
batch_size = 256
num_epochs = 100
image_size = 72  # Input images are resized to image_size x image_size
patch_size = 6  # Side length of the square patches extracted from the resized images
num_patches = (image_size // patch_size) ** 2  # (72 // 6) ** 2 = 144 patches per image
projection_dim = 64  # Embedding size of each encoded patch
num_heads = 4  # Attention heads per MultiHeadAttention layer
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Widths of the dense layers in each transformer block's MLP
transformer_layers = 8  # Number of stacked transformer blocks
mlp_head_units = [2048, 1024]  # Widths of the dense layers of the final classifier

## Model Definition

In [None]:
def grafica_entrenamiento(tr_auc, val_auc, tr_loss, val_loss, best_i,
                          figsize=(10,5), path_results = None):
    """Plot training/validation loss and AUC curves in two side-by-side panels.

    The x-axis numbers the entries of each series starting at 1, and the
    validation value at position ``best_i`` is highlighted with a red dot in
    both panels.

    Args:
        tr_auc: Sequence of training AUC values.
        val_auc: Sequence of validation AUC values.
        tr_loss: Sequence of training loss values.
        val_loss: Sequence of validation loss values.
        best_i: Zero-based index of the entry to highlight in the validation
            series.
        figsize: Figure size forwarded to matplotlib.
        path_results: Optional directory. When given, the figure is saved
            there as ``auc_loss.png`` before it is shown.
    """
    fig, (ax_loss, ax_auc) = plt.subplots(1, 2, figsize=figsize)
    panels = (
        (ax_loss, tr_loss, val_loss, 'loss del modelo', 'loss'),
        (ax_auc, tr_auc, val_auc, 'AUC', 'AUC'),
    )
    for ax, train_values, val_values, title, ylabel in panels:
        ax.plot(1 + np.arange(len(train_values)), np.array(train_values))
        ax.plot(1 + np.arange(len(val_values)), np.array(val_values))
        ax.plot(1 + best_i, val_values[best_i], 'or')
        ax.set_title(title, fontsize=18)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_xlabel('época', fontsize=18)
        ax.legend(['entrenamiento', 'validación'], loc='upper left')
        # Force integer ticks on *each* panel; previously only the loss panel
        # received the locator because the first subplot's axes were reused.
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    if path_results is not None:
        fig.savefig(os.path.join(path_results, 'auc_loss.png'))
    plt.show()
    
class TrainingPlot(Callback):
    """Keras callback that records loss/AUC values and redraws the training curves.

    Values are read from the ``loss``, ``val_loss``, ``auc`` and ``val_auc``
    entries of the epoch logs, so the model must be compiled with a metric
    named ``auc`` and trained with validation data.
    """

    def on_train_begin(self, logs=None):
        """Reset the recorded history at the start of training."""
        self.losses = []
        self.auc = []
        self.val_losses = []
        self.val_auc = []
        self.logs = []

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's metrics and redraw the curves.

        Args:
            epoch: Zero-based epoch index supplied by Keras.
            logs: Metric dictionary for the epoch; ``None`` is treated as empty.
        """
        # Avoid a shared mutable default and tolerate Keras passing None.
        logs = logs or {}

        # The first three epochs (indices 0-2) are neither recorded nor plotted.
        # NOTE(review): as a result, position 1 on the plot's x-axis corresponds
        # to the fourth training epoch — confirm this offset is intended.
        if epoch > 2:
            self.logs.append(logs)
            self.auc.append(logs.get('auc'))
            self.val_auc.append(logs.get('val_auc'))
            self.losses.append(logs.get('loss'))
            self.val_losses.append(logs.get('val_loss'))
            # Highlight the recorded epoch with the highest validation AUC.
            best_i = np.argmax(self.val_auc)
            grafica_entrenamiento(self.auc, self.val_auc, self.losses, self.val_losses, best_i)

plot_losses = TrainingPlot()

In [None]:
class print_learning_rate(Callback):
    """Callback that prints the optimizer's learning rate when each epoch starts."""

    def on_epoch_begin(self, epoch, logs=None):
        # Evaluate the optimizer's `lr` attribute to a plain number before formatting.
        current_lr = K.eval(self.model.optimizer.lr)
        print(f'Learning rate = {current_lr:.5f}')
print_lr = print_learning_rate()

In [None]:
def mlp(x, hidden_units, dropout_rate):
    """Apply a stack of Dense(GELU) + Dropout pairs to ``x``.

    Args:
        x: Input tensor.
        hidden_units: Iterable of layer widths; one Dense/Dropout pair is
            added per entry, in order.
        dropout_rate: Rate used by every Dropout layer.

    Returns:
        The output tensor of the last Dropout layer (``x`` itself when
        ``hidden_units`` is empty).
    """
    output = x
    for width in hidden_units:
        dense_layer = layers.Dense(width, activation=tf.nn.gelu)
        output = dense_layer(output)
        dropout_layer = layers.Dropout(dropout_rate)
        output = dropout_layer(output)
    return output

class Patches(layers.Layer):
    """Layer that cuts images into non-overlapping square patches and flattens each one.

    Args:
        patch_size: Side length, in pixels, of each square patch.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        """Return the patches of ``images`` reshaped to (batch, n_patches, patch_dims)."""
        batch_size = tf.shape(images)[0]
        # Stride equals the patch size, so patches tile the image without overlap.
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        # Collapse the spatial grid of patches into a single sequence axis.
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config
    
class PatchEncoder(layers.Layer):
    """Layer that linearly projects patches and adds a learned position embedding.

    Args:
        num_patches: Number of patches per image; also the size of the
            position-embedding table.
        projection_dim: Output size of the projection and of each embedding.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        # Kept so that get_config can rebuild the layer.
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        """Return the projected patches plus the embedding of each patch position."""
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # The position embeddings broadcast over the batch axis.
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "projection_dim": self.projection_dim,
            }
        )
        return config


In [None]:
# Preprocessing pipeline applied to the images inside the model.
# Despite its name, it currently only normalizes and resizes the inputs:
# the random flip / rotation / zoom layers below are commented out.
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        #layers.RandomFlip("horizontal"),
        #layers.RandomRotation(factor=0.02),
        #layers.RandomZoom(
        #    height_factor=0.2, width_factor=0.2
        #),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization
# (layers[0] is the Normalization layer; only the training images are used).
data_augmentation.layers[0].adapt(X_train)

In [None]:
def create_vit_classifier():
    """Build the Vision Transformer classifier as a Keras functional model.

    The model preprocesses the input with ``data_augmentation``, splits it into
    patches, encodes them, runs ``transformer_layers`` pre-norm transformer
    blocks, flattens the result and feeds it to an MLP head ending in a single
    sigmoid unit. All sizes come from the module-level hyperparameters.

    Returns:
        An uncompiled ``keras.Model`` mapping images of shape ``input_shape``
        to one sigmoid output per sample.
    """
    image_input = layers.Input(shape=input_shape)
    # Preprocess the images, then turn them into a sequence of encoded patches.
    preprocessed = data_augmentation(image_input)
    patch_sequence = Patches(patch_size)(preprocessed)
    tokens = PatchEncoder(num_patches, projection_dim)(patch_sequence)

    # Stack of transformer blocks, each with two residual connections.
    for _ in range(transformer_layers):
        # Self-attention sub-block: normalize, attend, add the residual.
        normed_tokens = layers.LayerNormalization(epsilon=1e-6)(tokens)
        attended = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(normed_tokens, normed_tokens)
        after_attention = layers.Add()([attended, tokens])
        # Feed-forward sub-block: normalize, MLP, add the residual.
        hidden = layers.LayerNormalization(epsilon=1e-6)(after_attention)
        hidden = mlp(hidden, hidden_units=transformer_units, dropout_rate=0.1)
        tokens = layers.Add()([hidden, after_attention])

    # Flatten all patch tokens into one feature vector per image.
    flat_features = layers.LayerNormalization(epsilon=1e-6)(tokens)
    flat_features = layers.Flatten()(flat_features)
    flat_features = layers.Dropout(0.5)(flat_features)
    # Classification head.
    head_features = mlp(flat_features, hidden_units=mlp_head_units, dropout_rate=0.5)
    # The sigmoid activation makes this output a value in [0, 1], not a raw logit.
    probability = layers.Dense(1, activation='sigmoid')(head_features)
    return keras.Model(inputs=image_input, outputs=probability)

In [None]:
def run_experiment(model, X_train, y_train, X_val, y_val, class_weights = None):
    """Compile ``model`` and train it with early stopping and per-epoch LR decay.

    Training settings (``learning_rate``, ``weight_decay``, ``decay``,
    ``batch_size``, ``num_epochs``) and the ``plot_losses`` / ``print_lr``
    callbacks are read from module-level names.

    Args:
        model: Keras model to compile and fit (modified in place).
        X_train: Training inputs.
        y_train: Training targets.
        X_val: Validation inputs.
        y_val: Validation targets.
        class_weights: Optional dict mapping class index to loss weight,
            forwarded to ``model.fit(class_weight=...)``. ``None`` (default)
            trains without class weighting.

    Returns:
        A ``(history, model)`` tuple: the ``History`` object returned by
        ``fit`` and the trained model.
    """
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    # The metric names "acc" and "auc" define the log keys ("auc", "val_auc")
    # that the callbacks below monitor.
    model.compile(
        loss='binary_crossentropy',
        optimizer=optimizer,
        metrics=[metrics.BinaryAccuracy(name = "acc"),
                 metrics.AUC(name = "auc")])

    def lr_time_based_decay(epoch, lr):
        # Multiply the current learning rate by `decay` on every call.
        return lr * decay

    # Stop after 10 epochs without val_auc improvement; restore_best_weights
    # asks Keras to roll back to the best weights when stopping triggers.
    early_stopping = EarlyStopping(monitor='val_auc', patience=10, verbose=1, mode='max', restore_best_weights=True)

    lr_decay = LearningRateScheduler(lr_time_based_decay, verbose=1)

    history = model.fit(
        x=X_train,
        y=y_train,
        validation_data=(X_val, y_val),
        batch_size=batch_size,
        epochs=num_epochs,
        class_weight=class_weights,
        callbacks=[early_stopping,  lr_decay, plot_losses, print_lr],
        verbose = 1,
    )

    return history, model

In [None]:
# Balanced class weights computed from the training targets, stored as a
# {position: weight} dict following the sorted order of the unique labels.
class_weights = dict(
    enumerate(
        class_weight.compute_class_weight(
            class_weight='balanced',
            classes=np.unique(y_train),
            y=y_train,
        )
    )
)

In [None]:
# Build the (still uncompiled) ViT classifier; uncomment the next line to inspect its layers.
model = create_vit_classifier()
# print(model.summary())

## Model Training

In [None]:
# Train on the image arrays, validating on the validation split and
# weighting the loss with the balanced class weights.
history, model = run_experiment(
    model=model,
    X_train=X_train,
    y_train=y_train,
    X_val=X_val,
    y_val=y_val,
    class_weights=class_weights,
)

## Predict

In [None]:
# Score the train, validation and test images (in that order) with the trained model.
pred_train, pred_val, pred_test = (
    model.predict(split_images) for split_images in (X_train, X_val, X_test)
)

## Metrics

In [None]:
# ROC AUC of the predicted scores against the true targets, per split.
auc_train, auc_val, auc_test = (
    roc_auc_score(y_true=targets, y_score=scores)
    for targets, scores in (
        (y_train, pred_train),
        (y_val, pred_val),
        (y_test, pred_test),
    )
)
print('AUC train = %s - AUC val = %s - AUC test = %s' % tuple(map(str, (auc_train, auc_val, auc_test))))