# In [1]:
import numpy as np
import matplotlib.pyplot as plt
import json
import os
from tqdm import tqdm, tqdm_notebook
import random

import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.applications import *
from tensorflow.keras.callbacks import *
from tensorflow.keras.initializers import *
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import keras.backend as K
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix


from numpy.random import seed
import pickle
from tensorflow import keras as ks
from tensorflow.keras import layers
import pathlib
# Seed every RNG this notebook can touch so runs are reproducible: Python's
# `random` module (imported above, previously left unseeded), NumPy and TensorFlow.
# NOTE(review): some GPU kernels are non-deterministic regardless of seeding.
SEED = 42
random.seed(SEED)
seed(SEED)
tf.random.set_seed(SEED)

# In [2]:
class DNN_model:
    """
    Wrap a Keras model with helpers to train it, checkpoint it and persist its history.

    Files written under `dir_name`:
        <name>.h5 : best checkpoint (lowest val_loss), written by a ModelCheckpoint callback.
        <name>    : pickled training history (dict mapping metric name -> per-epoch values).
    """

    def __init__(self, model, name, dir_name="/content/drive/Shareddrives/SEAI Project/Models"):
        """
        Class to automatize model definition and usage. An instance of DNN_model represents
        and contains all the methods to handle the model, train it, and save all results on disk.
        Parameters:
            model: the Keras model to manage (e.g. a keras.applications backbone topped by a
                custom head, or a fully custom architecture).
            name: name of the model; also the base file name of the checkpoint and history.
            dir_name: directory holding the checkpoint and history files (created if missing).
        Returns:
            an instance of DNN_model
        """
        self.model = model
        self.history = None
        self.name = name
        self.class_weights = None
        # Datasets used by the last compile_and_fit call; train_more continues on them.
        self.training_set = None
        self.val_set = None

        self.dir_name = dir_name
        os.makedirs(self.dir_name, exist_ok=True)
        self.save_path = os.path.join(self.dir_name, name + '.h5')

        self.callbacks_list = [self._make_checkpoint()]

    def _make_checkpoint(self):
        """Return a ModelCheckpoint that keeps only the lowest-val_loss model at save_path."""
        return ks.callbacks.ModelCheckpoint(
            filepath=self.save_path,
            monitor="val_loss",
            verbose=1,
            save_best_only=True)

    def _set_early_stopping(self, patience):
        """
        Install an EarlyStopping(val_loss, patience) callback, replacing any previous one.
        A falsy patience (None or 0) leaves the callback list untouched.
        """
        if not patience:
            return
        self.callbacks_list = [cb for cb in self.callbacks_list
                               if not isinstance(cb, ks.callbacks.EarlyStopping)]
        self.callbacks_list.append(ks.callbacks.EarlyStopping(
            monitor='val_loss', patience=patience, restore_best_weights=True))

    def _history_path(self):
        """Path of the pickled training history for this model."""
        return os.path.join(self.dir_name, self.name)

    def compile_and_fit(self, train, val, optimizer=None, loss='binary_crossentropy',
                        learning_rate=None, epochs=10, patience=None, class_weights=None):
        """
        Compile the model, train it and save the resulting history.
        Parameters:
            train, val: training / validation data passed to model.fit; remembered for train_more.
            optimizer: a Keras optimizer instance or identifier string (e.g. 'adam');
                None creates a fresh Adam optimizer for this call.
            loss: loss passed to model.compile.
            learning_rate: if given, overrides the optimizer's learning rate.
            epochs: number of epochs to train.
            patience: EarlyStopping patience on val_loss; None/0 keeps the current callbacks.
            class_weights: passed to model.fit as class_weight; also reused by train_more.
        """
        # A new optimizer per call: a shared default instance would leak its state and
        # learning-rate edits across calls and across models.
        if optimizer is None:
            optimizer = ks.optimizers.Adam()
        elif isinstance(optimizer, str):
            optimizer = ks.optimizers.get(optimizer)
        if learning_rate is not None:
            optimizer.learning_rate = learning_rate

        self._set_early_stopping(patience)

        # Record exactly the settings used here so train_more continues consistently.
        self.class_weights = class_weights
        self.training_set = train
        self.val_set = val

        self.model.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=['accuracy']
        )

        self.history = self.model.fit(
            train,
            epochs=epochs,
            validation_data=val,
            callbacks=self.callbacks_list,
            class_weight=class_weights
        )

        self.save_history()

    def plot_model(self, model_name=None):
        """Save a diagram of the model (with shapes) to '<model_name>.jpg' (default: self.name)."""
        if model_name is None:
            model_name = self.name
        ks.utils.plot_model(self.model, to_file=str(model_name) + ".jpg", show_shapes=True)

    def summary(self):
        """Print the Keras summary of the wrapped model."""
        self.model.summary()

    def unfreeze_layers(self, block_name):
        """
        Make `block_name` and every later top-level layer trainable; freeze the earlier ones.
        BatchNormalization layers are always kept frozen (as the Keras fine-tuning guide
        recommends). Keras only applies `trainable` changes when the model is compiled, so
        follow this with compile_and_fit (train_more does not recompile).
        Parameters:
            block_name: name of the first top-level layer to unfreeze.
        Returns:
            the wrapped model
        Raises:
            ValueError: if no top-level layer is called `block_name` (previously this
                silently froze every layer).
        """
        if all(layer.name != block_name for layer in self.model.layers):
            raise ValueError(f"No top-level layer named {block_name!r} in model {self.name!r}")

        self.model.trainable = True
        set_trainable = False
        for layer in self.model.layers:
            if layer.name == block_name:
                set_trainable = True
            layer.trainable = set_trainable and not isinstance(layer, layers.BatchNormalization)

        return self.model

    def train_more(self, epochs, patience=None):
        """
        Continue training on the datasets of the last compile_and_fit call and append the
        new epochs to the stored history (which is loaded from disk if not in memory).
        Parameters:
            epochs: number of additional epochs.
            patience: EarlyStopping patience on val_loss; None/0 keeps the current callbacks.
        Raises:
            RuntimeError: if compile_and_fit has not been called on this instance.
        """
        if getattr(self, 'training_set', None) is None:
            raise RuntimeError("train_more() needs the datasets of a previous compile_and_fit() call")

        self._set_early_stopping(patience)

        if self.history is None:
            self.load_history()

        self.new_history = self.model.fit(
            self.training_set,
            epochs=epochs,
            validation_data=self.val_set,
            callbacks=self.callbacks_list,
            class_weight=self.class_weights
        )
        # Iterate the new metrics so a metric missing from the stored history cannot KeyError.
        for key, values in self.new_history.history.items():
            self.history.history.setdefault(key, []).extend(values)

        self.save_history()

    def load_best_model(self):
        """Load and return the best checkpoint from save_path; print a message and return None if absent."""
        if os.path.exists(self.save_path):
            self.model = ks.models.load_model(self.save_path)
            return self.model
        else:
            print("No model has been saved")

    def save_history(self):
        """Pickle the per-epoch metrics dict of self.history to the history file."""
        # Only the metrics dict is stored: the Keras History object also references the
        # model, which either fails to pickle or drags the whole model into the file.
        history = getattr(self.history, 'history', self.history)
        with open(self._history_path(), 'wb') as file_pi:
            pickle.dump(history, file_pi)

    def load_history(self):
        """
        Load the history file into self.history, exposing the metrics dict as
        `self.history.history`. Accepts both the current format (a dict) and older
        files that pickled the whole History object.
        """
        from types import SimpleNamespace

        # NOTE(review): pickle.load executes arbitrary code from the file; only load
        # history files from a trusted (not world-writable) directory.
        with open(self._history_path(), 'rb') as file_pi:
            history = pickle.load(file_pi)
        if isinstance(history, dict):
            history = SimpleNamespace(history=history)
        self.history = history

    def evaluate(self, test_set):
        """
        Evaluate the model on test_set and print loss and accuracy.
        Returns:
            (test_loss, test_acc)
        """
        test_loss, test_acc = self.model.evaluate(test_set)
        print(f"Test accuracy: {test_acc:.3f}, test loss: {test_loss:.3f}")
        return test_loss, test_acc

    def reset_model(self):
        """
        Delete the saved checkpoint and history and re-initialize the model weights in place.
        Every top-level layer exposing `kernel`/`kernel_initializer` (and `bias`/`bias_initializer`)
        gets those variables re-drawn from its own initializer. Layers nested inside a sub-model
        (e.g. a backbone wrapped as a single layer), other variables such as BatchNormalization
        statistics, and the optimizer state are not touched.
        """
        for path in (self.save_path, self._history_path()):
            if os.path.exists(path):
                os.remove(path)
        self.history = None

        # ModelCheckpoint remembers the best val_loss across fit() calls; without a fresh
        # one, retraining after a reset might never write a checkpoint again.
        self.callbacks_list = [
            self._make_checkpoint() if isinstance(cb, ks.callbacks.ModelCheckpoint) else cb
            for cb in self.callbacks_list
        ]

        # TF2 replacement for the TF1 `initializer.run(session=...)` idiom: assign a fresh draw.
        for layer in self.model.layers:
            for var_name, init_name in (('kernel', 'kernel_initializer'),
                                        ('bias', 'bias_initializer')):
                variable = getattr(layer, var_name, None)
                initializer = getattr(layer, init_name, None)
                if variable is not None and initializer is not None:
                    variable.assign(initializer(variable.shape, variable.dtype))

    def plot_accuracy(self):
        """
        Plot training/validation accuracy and loss per epoch (epochs numbered from 1),
        mark the epoch with the lowest validation loss and print its validation accuracy.
        """
        if self.history is None:
            self.load_history()

        history = self.history.history
        acc = history['accuracy']
        val_acc = history['val_accuracy']
        loss = history['loss']
        val_loss = history['val_loss']
        best_idx = int(np.argmin(val_loss))
        best_loss = val_loss[best_idx]
        best_acc = val_acc[best_idx]

        # 1-based x axis so the plot agrees with Keras' progress log and the printed epoch.
        epochs = range(1, len(acc) + 1)

        plt.plot(epochs, acc, 'b.', label='Training accuracy')
        plt.plot(epochs, val_acc, 'g.', label='Validation accuracy')
        plt.title('Training and validation accuracy on ' + self.name)
        plt.legend(loc='upper right')

        plt.figure()

        plt.plot(epochs, loss, 'b', label='Training loss')
        plt.plot(epochs, val_loss, 'g', label='Validation loss')
        plt.plot(best_idx + 1, best_loss, 'ro', label='Selected model')
        plt.title('Training and validation loss on ' + self.name)
        plt.legend(loc='upper right')

        plt.show()
        print(f"Lowest validation loss has been reached at epoch {(best_idx + 1):d} with validation accuracy of {best_acc:.3f}")

    @staticmethod
    def _labels_to_classes(labels):
        """One-hot labels -> argmax; integer or binary (0/1) labels -> flat int class ids."""
        labels = np.asarray(labels)
        if labels.ndim > 1 and labels.shape[-1] > 1:
            return np.argmax(labels, -1)
        return np.rint(labels.reshape(-1)).astype(int)

    @staticmethod
    def _predictions_to_classes(preds, threshold=0.5):
        """Multi-unit outputs -> argmax; a single output unit -> (p > threshold)."""
        # NOTE(review): the threshold assumes probabilities (e.g. sigmoid), not logits.
        preds = np.asarray(preds)
        if preds.ndim > 1 and preds.shape[-1] > 1:
            return np.argmax(preds, -1)
        return (preds.reshape(-1) > threshold).astype(int)

    def plot_confusion_matrix(self, test_set):
        """
        Plot the confusion matrix of the model on test_set, an iterable of
        (data_batch, labels_batch) pairs.
        Raises:
            ValueError: if test_set yields no batches.
        """
        # Labels and predictions come from the same pass over test_set, so they stay
        # aligned even when the dataset reshuffles on every iteration.
        true_batches, pred_batches = [], []
        for data_batch, labels_batch in test_set:
            true_batches.append(np.asarray(labels_batch))
            pred_batches.append(np.asarray(self.model.predict_on_batch(data_batch)))
        if not true_batches:
            raise ValueError("test_set yielded no batches")

        y_true_cat = self._labels_to_classes(np.concatenate(true_batches, axis=0))
        pred_Y_cat = self._predictions_to_classes(np.concatenate(pred_batches, axis=0))

        plt.figure(figsize=(20, 10))
        sns.heatmap(confusion_matrix(y_true_cat, pred_Y_cat),
                    annot=True, fmt="d", cbar=False, cmap=plt.cm.Blues)

        plt.show()