<a href="https://colab.research.google.com/github/jenieto/pollen-classification/blob/master/Experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the dataset archive stored under "My Drive" is reachable.
from google.colab import drive

drive_mount_point = '/content/drive/'
drive.mount(drive_mount_point)

In [None]:
# Extract data
!mkdir -p /data
!tar xvzf "/content/drive/My Drive/Datasets/anuka1200.tar.gz" --directory /data

In [None]:
# Create datasets: load the images from disk and split them into a training and
# a validation subset. Both calls must use the same seed and validation_split so
# that the two subsets come from the same split.
import tensorflow as tf
import numpy as np

training_size = 0.8  # informational only; the split is driven by validation_size
validation_size = 0.2
image_height = 96
image_width = 96
image_channels = 1
image_size = (image_height, image_width, image_channels)
directory = '/data/anuka1200'

# Options shared by both subsets; only `subset` differs between the two calls.
common_dataset_options = dict(
    directory=directory,
    labels='inferred',
    validation_split=validation_size,
    seed=123,
    color_mode='grayscale',
    image_size=(image_height, image_width),
)

train_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    subset='training', **common_dataset_options)
validation_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    subset='validation', **common_dataset_options)

class_names = train_dataset.class_names
print('Class Names', class_names)

In [None]:
# Explore data: plot the first 9 images of one training batch with their class names.
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
  for i in range(9):
    pixels = images[i].numpy().astype("uint8")
    plt.subplot(3, 3, i + 1)
    # imshow does not accept an (H, W, 1) array, so a single-channel image is
    # squeezed to (H, W) and drawn with a gray colormap.
    is_gray = image_channels == 1
    plt.imshow(pixels.squeeze() if is_gray else pixels,
               cmap='gray' if is_gray else None)
    plt.title(class_names[labels[i]])
    plt.axis("off")

In [84]:
# Create model
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers
from tensorflow.keras.applications import MobileNetV2


def create_models():
  """Build the candidate binary classifiers compared in this notebook.

  Every model starts by scaling pixel values by 1/255 and ends in a single
  sigmoid unit.

  Returns:
    Dict mapping 'model_0' .. 'model_3' to an uncompiled `Sequential` model
    whose input shape is the notebook-level `image_size`.
  """
  def rescale():
    # Each model gets its own Rescaling (input) layer.
    return layers.experimental.preprocessing.Rescaling(1./255, input_shape=image_size)

  def sigmoid_output():
    return layers.Dense(1, activation='sigmoid')

  # Logistic regression on the flattened pixels.
  linear = Sequential([rescale(), layers.Flatten(), sigmoid_output()])

  # One hidden fully-connected layer.
  mlp = Sequential([
    rescale(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    sigmoid_output(),
  ])

  # A single convolution + pooling stage followed by a dense layer.
  small_cnn = Sequential([
    rescale(),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    sigmoid_output(),
  ])

  # Three convolution + pooling stages with 16, 32 and 64 filters.
  conv_layers = [rescale(), layers.Conv2D(16, 3, activation='relu'), layers.MaxPooling2D()]
  for n_filters in (32, 64):
    conv_layers.append(layers.Conv2D(n_filters, 3, padding='same', activation='relu'))
    conv_layers.append(layers.MaxPooling2D())
  deeper_cnn = Sequential(conv_layers + [
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    sigmoid_output(),
  ])

  # NOTE: a frozen MobileNetV2 (weights='imagenet', include_top=False,
  # pooling='avg') + Dense(128) + sigmoid head was tried here. It is disabled:
  # ImageNet weights expect 3-channel input, while image_channels is 1.

  return {
      'model_0': linear,
      'model_1': mlp,
      'model_2': small_cnn,
      'model_3': deeper_cnn,
  }

In [88]:
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
# Use tf.keras's plot_model: the models are tf.keras models, and the standalone
# `keras.utils.vis_utils` module belongs to a different package (and may be
# missing in newer Keras releases).
from tensorflow.keras.utils import plot_model
from datetime import datetime
from IPython.display import Image, display

def display_model(model_name, model):
  """Render a Keras model's architecture diagram and show it inline.

  Args:
    model_name: Base name of the generated PNG, written as `<model_name>.png`
      in the current working directory.
    model: The Keras model to plot.
  """
  plot_file = f'{model_name}.png'
  plot_model(model, to_file=plot_file, show_shapes=True, show_layer_names=True)
  # An `Image` object is only rendered automatically when it is a cell's last
  # expression; inside a function it must be passed to `display` explicitly.
  display(Image(retina=True, filename=plot_file))

# Compile models
def compile_models(models_dict, show_summary=False):
  """Compile every model in `models_dict` for binary classification and plot it.

  The models built in this notebook end in a single sigmoid unit, so binary
  cross-entropy is the matching loss and `binary_accuracy` the matching metric.

  Args:
    models_dict: Mapping of model name -> Keras model; models are compiled in place.
    show_summary: If True, also print each model's `summary()`.
  """
  for model_name, model in models_dict.items():
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['binary_accuracy'])
    display_model(model_name, model)
    if show_summary:
      print(f'{model_name} summary:')
      model.summary()

  
def model_callbacks(model_name):
  """Build the Keras callbacks used when fitting the model called `model_name`.

  Returns:
    A list with:
      - ModelCheckpoint: saves the model to `<model_name>.h5` only when
        `val_loss` improves (save_best_only=True), so the file holds the best
        epoch rather than the last one.
      - TensorBoard: writes training logs under `logs/fit/<model_name>/<timestamp>`.
      - EarlyStopping: stops training after 10 epochs without a `val_loss` improvement.
      - ReduceLROnPlateau: multiplies the learning rate by 0.1 after 7 epochs
        without a `val_loss` improvement of at least 1e-4.
  """
  filepath_mdl = f'{model_name}.h5'
  checkpoint = ModelCheckpoint(filepath_mdl, monitor='val_loss', verbose=1, save_best_only=True)
  log_dir = f"logs/fit/{model_name}/" + datetime.now().strftime("%Y%m%d-%H%M%S")
  tensorboard = TensorBoard(log_dir=log_dir, write_graph=True, write_images=True)
  # 'val_loss' is EarlyStopping's default monitor; it is stated explicitly to make
  # clear that it watches the validation loss, not the accuracy.
  earlystopping = EarlyStopping(monitor='val_loss', patience=10, verbose=1)
  # `min_delta` replaces the deprecated `epsilon` alias.
  reduce_lr_loss = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=7, verbose=1, min_delta=1e-4, mode='min')
  return [checkpoint, tensorboard, earlystopping, reduce_lr_loss]

In [86]:
# Train models
def train_models(models_dict, epochs=50):
  """Fit each model, reload its best checkpoint and evaluate it on validation data.

  Trains on the notebook-level `train_dataset` and validates on `validation_dataset`.

  Args:
    models_dict: Mapping of model name -> compiled Keras model.
    epochs: Maximum number of epochs per model (the EarlyStopping callback
      from `model_callbacks` may stop earlier).

  Returns:
    Dict of model name -> `model.evaluate(validation_dataset)` result
    (loss followed by the compiled metrics) for the best checkpoint.
  """
  model_results = {}
  for model_name, model in models_dict.items():
    print('-----------------------------------------------------------------------')
    print(f'Fitting model {model_name}')
    model.fit(
      train_dataset,
      validation_data=validation_dataset,
      epochs=epochs,
      callbacks=model_callbacks(model_name))
    # The checkpoint is only written when val_loss improves, so this restores
    # the best epoch instead of keeping the final weights.
    model.load_weights(f'{model_name}.h5')
    model_results[model_name] = model.evaluate(validation_dataset)
  return model_results


In [None]:
# Execute: build, compile and train every model, then report its validation scores.
models = create_models()
compile_models(models)
# `results` is required by the report below; leaving this call commented out
# makes the cell fail with a NameError on a fresh kernel.
results = train_models(models)

print('------------ Results -----------------')
for model_name, score in results.items():
  # score = [loss, binary_accuracy], as returned by model.evaluate
  print(f'{model_name}: validation_loss: {score[0]}, validation_accuracy:{score[1]}')

In [None]:
# Start TensorBoard inline, reading the `logs` directory that the TensorBoard
# callbacks write to (one run per logs/fit/<model_name>/<timestamp>).
%load_ext tensorboard
%tensorboard --logdir logs