# 0 - Requirements

In [3]:
import tensorflow as tf
import matplotlib.pyplot as plt
from model.tf.ViT_model import ViT_UNet

In [4]:
# Show which GPUs TensorFlow can see; an empty list means no GPU is visible.
tf.config.list_physical_devices(device_type='GPU')

[]

# 1 - Data

In [None]:
# Data-loading parameters.
img_path = '/input/'  # NOTE(review): absolute path; consider making it configurable
img_shape = (128, 128)
batch_size = 32
seed = 123  # fixes the training-set shuffle order so runs are reproducible
class_names = ['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia']
num_classes = len(class_names)

# Load the images with Keras. train/ and val/ are already separate directories,
# so each one is read in full rather than re-split with `validation_split`.
# Re-splitting each directory used to discard 20% of train/ and 80% of val/.
# Because `shuffle=False` sorts files alphanumerically (i.e. grouped by class),
# the kept slices were also class-biased.
# `class_names` pins the label index order to the list above; Keras raises if
# it does not match the subdirectory names.
# `label_mode='categorical'` yields one-hot float32 labels directly, which is
# what the categorical_crossentropy loss expects.
train_data = tf.keras.preprocessing.image_dataset_from_directory(
  img_path + 'train/',
  labels='inferred',
  label_mode='categorical',
  class_names=class_names,
  color_mode='grayscale',
  shuffle=True,  # class-sorted batches would hurt SGD training
  seed=seed,
  image_size=img_shape,
  batch_size=batch_size)

val_data = tf.keras.preprocessing.image_dataset_from_directory(
  img_path + 'val/',
  labels='inferred',
  label_mode='categorical',
  class_names=class_names,
  color_mode='grayscale',
  shuffle=False,  # deterministic order for evaluation
  image_size=img_shape,
  batch_size=batch_size)

In [None]:
# Sanity-check the import: show up to 9 validation images with their class names.
plt.figure(figsize=(10, 10))
for images, labels in val_data.take(1):
  for i in range(min(9, images.shape[0])):
    plt.subplot(3, 3, i + 1)
    # Grayscale batches carry a trailing channel axis of size 1, which imshow
    # rejects; drop it and render with a gray colormap.
    plt.imshow(images[i].numpy().squeeze(-1).astype("uint8"), cmap="gray")
    # Labels are one-hot vectors, so recover the class index with argmax
    # before indexing `class_names`.
    plt.title(class_names[int(tf.argmax(labels[i]))])
    plt.axis("off")
plt.show()

# 2 - Fit model

In [None]:
# Training parameters.
epochs = 1
RUN_TRAINING = False  # set to True to fit the model at the end of this cell
patch_size = 16
# One patch per patch_size x patch_size tile of the input image (64 for 128x128).
num_patches = (img_shape[0] // patch_size) * (img_shape[1] // patch_size)

# NOTE(review): these layers are not passed to ViT_UNet below, and nothing in
# this cell rescales pixels from [0, 255]. Confirm whether ViT_UNet normalizes
# its input. They are kept in case later cells use them.
input_layer = [tf.keras.Input(shape=(*img_shape, 1))]
preprocessing_layers = [tf.keras.layers.experimental.preprocessing.Resizing(*img_shape),
                        tf.keras.layers.experimental.preprocessing.Rescaling(scale=1./255, offset=0.0)]

# Build the model under a distribution strategy: MirroredStrategy when GPUs are
# visible, otherwise the default single-device strategy. Previously the model was
# only built when a GPU existed, so on a CPU-only machine `model` and `callbacks`
# were never defined.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
  print("Número de GPUs disponibles: ", len(gpus))
  strategy = tf.distribute.MirroredStrategy()
else:
  strategy = tf.distribute.get_strategy()

with strategy.scope():
  model = ViT_UNet(depth=3,
                   depth_te=4,
                   linear_list=[4],
                   preprocessing='conv',
                   num_patches=num_patches,
                   patch_size=patch_size,
                   num_channels=1,
                   hidden_dim=64,
                   num_heads=8,
                   attn_drop=.2,
                   proj_drop=.2,
                   linear_drop=.2,
                   )
  model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

# Callbacks applied during training.
callbacks = [
  tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=2),
  # Patience must be below EarlyStopping's. With the default of 10, training
  # would always stop before the learning rate was ever reduced.
  tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=.1, patience=1),
]

# Train the model. `batch_size` is not passed because the tf.data datasets are
# already batched.
if RUN_TRAINING:
  history = model.fit(train_data, epochs=epochs, validation_data=val_data,
                      callbacks=callbacks, verbose=1)