In [None]:
#import libraries
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import random
import math
import os
import cv2 as cv
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures, StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
import sys
import pickle
import shutil
from PIL import Image


# Data loading and preprocessing

In [None]:
DIR_TRASHNET = "../input/trashnet-limited/trashnet_limited/"
DIR_TRASHBOX = "../input/trashbox-limited/trashbox_limited/"
DIR_WILD = "../input/waste-wild/waste_wild/"
DIR_WEIGHTS = "./weights/"

def get_data_from_dir(dir, batch_size=128, validation_split=0.1, seed=42, image_size=(256, 256)):
  """Load an image-folder dataset and split it into prefetched train/validation sets.

  Args:
    dir: root folder passed to image_dataset_from_directory. The name shadows
      the builtin `dir`; it is kept so keyword callers keep working.
    batch_size: images per batch.
    validation_split: fraction of the images placed in the "validation" subset.
    seed: split seed. Both loader calls below share it (and validation_split)
      so that the two subsets are drawn from the same split.
    image_size: (height, width) every image is resized to, with smart_resize=True.

  Returns:
    (train_dataset, test_dataset, class_names): both datasets are prefetched
    with tf.data.AUTOTUNE; class_names comes from the training subset.
  """
  # One shared set of options guarantees both subsets use identical split settings.
  load_kwargs = dict(
      validation_split=validation_split,
      seed=seed,
      batch_size=batch_size,
      smart_resize=True,
      image_size=image_size,
  )
  train_dataset = tf.keras.preprocessing.image_dataset_from_directory(dir, subset="training", **load_kwargs)
  test_dataset = tf.keras.preprocessing.image_dataset_from_directory(dir, subset="validation", **load_kwargs)
  # Read class_names from the object returned by the loader, before wrapping it in prefetch().
  class_names = train_dataset.class_names
  AUTOTUNE = tf.data.AUTOTUNE
  train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
  test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)
  return train_dataset, test_dataset, class_names

In [None]:
trashbox_train, trashbox_test, trashbox_classes = get_data_from_dir(DIR_TRASHBOX)
trashnet_train, trashnet_test, trashnet_classes = get_data_from_dir(DIR_TRASHNET)
# The datasets are concatenated below, and their integer labels index into each
# dataset's own class_names. Concatenation is only meaningful if those lists match.
assert trashbox_classes == trashnet_classes, (
    "class mismatch between datasets: %s vs %s" % (trashbox_classes, trashnet_classes))
classes = trashnet_classes

In [None]:
# Unsplit, unlabelled-in-practice "in the wild" images, resized like the training data.
waste_wild = tf.keras.preprocessing.image_dataset_from_directory(
    DIR_WILD,
    image_size=(256, 256),
    smart_resize=True,
)

In [None]:
# Training stream: all TrashBox batches, followed by all TrashNet batches.
concat_train = trashbox_train.concatenate(dataset=trashnet_train)

In [None]:
# Validation stream: all TrashBox batches, followed by all TrashNet batches.
concat_test = trashbox_test.concatenate(dataset=trashnet_test)

In [None]:
# Prefetch on top of the concatenated pipelines as well.
AUTOTUNE = tf.data.AUTOTUNE
concat_train, concat_test = (
    split.prefetch(buffer_size=AUTOTUNE) for split in (concat_train, concat_test)
)

# Model definitions

In [None]:
def layers_for_data_augmentation():
  """Return the random-augmentation stage placed in front of the classifier.

  A Sequential of RandomFlip('horizontal_and_vertical') followed by
  RandomRotation(0.2).
  """
  return tf.keras.Sequential([
      tf.keras.layers.RandomFlip('horizontal_and_vertical'),
      tf.keras.layers.RandomRotation(0.2),
  ])

def get_model(numClasses, data_augmentation_list = None, network = 'MobileNetV3', reg_method = 'l1'):
  """Build an image classifier on top of an ImageNet-pretrained Keras backbone.

  Layer stack:
    [optional augmentation] -> [backbone-specific input scaling]
    -> backbone (all but its last 6 layers frozen) -> Dropout(0.45)
    -> GlobalAveragePooling2D -> BatchNormalization
    -> Dense(256, elu, regularized, he_normal) -> Dropout(0.45)
    -> Dense(numClasses, softmax)

  Args:
    numClasses: number of output classes (width of the final softmax layer).
    data_augmentation_list: when non-empty, the flip/rotation stage from
      layers_for_data_augmentation() is prepended. Only emptiness is checked;
      the contents are ignored. None (the default) means no augmentation.
    network: backbone name: 'MobileNetV3' / 'm', 'ResNet50' / 'rn50',
      'ResNet152V2' / 'rn152v2' or 'inception_resnet'.
    reg_method: 'l1' or 'l2'. The regularizer (factor 0.045) is applied to both
      the kernel and the activity of the 256-unit Dense layer.

  Returns:
    An uncompiled tf.keras.Sequential for 256x256x3 inputs. Any input scaling
    the backbone needs is part of the model, so feed it unscaled pixel values.

  Raises:
    ValueError: if `network` or `reg_method` is not recognised.
  """
  # Choose the backbone and the input scaling it expects.
  # MobileNetV3Large already contains its own rescaling, so nothing is added for it.
  # Per the Keras Applications docs, ResNet50 expects its caffe-style
  # preprocess_input, while ResNet152V2 and InceptionResNetV2 expect pixels in [-1, 1].
  if network in ('MobileNetV3', 'm'):
    backbone_fn, input_scaling = tf.keras.applications.MobileNetV3Large, None
  elif network in ('ResNet50', 'rn50'):
    backbone_fn = tf.keras.applications.ResNet50
    input_scaling = tf.keras.layers.Lambda(tf.keras.applications.resnet50.preprocess_input)
  elif network in ('ResNet152V2', 'rn152v2'):
    backbone_fn = tf.keras.applications.ResNet152V2
    input_scaling = tf.keras.layers.Rescaling(1./127.5, offset=-1)
  elif network == 'inception_resnet':
    backbone_fn = tf.keras.applications.InceptionResNetV2
    input_scaling = tf.keras.layers.Rescaling(1./127.5, offset=-1)
  else:
    raise ValueError("unknown network type")

  # Validate the regularizer before building (and possibly downloading) the backbone.
  regularizer_factories = {'l1': tf.keras.regularizers.l1, 'l2': tf.keras.regularizers.l2}
  if reg_method not in regularizer_factories:
    raise ValueError("unknown/unsupported regularizer. Set to 'l1' or 'l2'")
  regularizer = regularizer_factories[reg_method](0.045)

  model = tf.keras.Sequential()

  if data_augmentation_list is not None and len(data_augmentation_list) > 0:
    model.add(layers_for_data_augmentation())

  if input_scaling is not None:
    model.add(input_scaling)

  # include_top=False: the classification head is built below, so no `classes` argument is needed.
  backbone = backbone_fn(input_shape=(256, 256, 3), weights='imagenet', include_top=False)
  model.add(backbone)

  # Freeze the backbone except for its last 6 layers, which are fine-tuned with the head.
  for layer in backbone.layers[:-6]:
    layer.trainable = False

  # Trainable classification head.
  model.add(tf.keras.layers.Dropout(0.45))
  model.add(tf.keras.layers.GlobalAveragePooling2D())
  model.add(tf.keras.layers.BatchNormalization())
  model.add(tf.keras.layers.Dense(256, activation=tf.keras.activations.elu, kernel_regularizer=regularizer, activity_regularizer=regularizer, kernel_initializer='he_normal'))
  model.add(tf.keras.layers.Dropout(0.45))
  model.add(tf.keras.layers.Dense(numClasses, activation=tf.keras.activations.softmax))

  return model

In [None]:
class CustomHistoryCheckPoint(tf.keras.callbacks.History):
    """History callback that also pickles the recorded history to disk.

    The inherited History bookkeeping is kept. In addition, `self.history` is
    written to `<filepath>/hist.pickle` whenever the 0-based epoch index is a
    multiple of `save_freq`, and once more when training ends. This means the
    final epochs are saved even when save_freq > 1.

    Args:
        filepath: directory to write `hist.pickle` into. It is created on
            demand; an empty string means the current working directory.
        save_freq: epoch interval between writes.
    """

    def __init__(self, filepath = '', save_freq = 1):
        super(CustomHistoryCheckPoint, self).__init__()
        # Not populated by this class; kept so external code that reads them does not break.
        self.epoch_accuracy = {}  # accuracy at given epoch (unused)
        self.epoch_loss = {}  # loss at given epoch (unused)
        self.filepath = filepath
        self.save_freq = save_freq

    def on_epoch_end(self, epoch, logs=None):
        # Let History record this epoch's logs first, then persist them.
        super().on_epoch_end(epoch, logs)
        if epoch % self.save_freq == 0:
            self.write_hist()

    def on_train_end(self, logs=None):
        super().on_train_end(logs)
        # Epochs after the last periodic write would otherwise be lost.
        self.write_hist()

    def write_hist(self):
        """Pickle `self.history` to `<filepath>/hist.pickle`, creating the directory if needed."""
        # os.makedirs('') raises, so only create a directory when one is given.
        if self.filepath:
            os.makedirs(self.filepath, exist_ok=True)
        with open(os.path.join(self.filepath, 'hist.pickle'), 'wb') as file_pi:
            pickle.dump(self.history, file_pi)

In [None]:
def run_model(model, train_dataset, test_dataset, epochs = 10, model_name = '', save_weights = True, save_history = True, save_freq = 20, batch_size = 128):
  """Compile and train `model`, optionally checkpointing, exporting and plotting.

  The model is compiled with Adam(learning_rate=0.00075) and
  SparseCategoricalCrossentropy(from_logits=False), so it must end in a softmax
  and the labels must be integer class indices.

  Args:
    model: uncompiled tf.keras model (e.g. from get_model).
    train_dataset: batched dataset used for fitting.
    test_dataset: batched dataset passed as validation_data.
    epochs: number of training epochs.
    model_name: base name of the exported `<model_name>.tflite`. Defaults to
      'model' when empty.
    save_weights: if True, weights are checkpointed to ./weights/cp-XXXX.ckpt
      every `save_freq` epochs. After training, the full model is saved to
      ./saved and a TFLite export is written.
    save_history: if True, the history is pickled to ./history/hist.pickle
      during training.
    save_freq: checkpoint interval, in epochs.
    batch_size: unused. The datasets are already batched; kept for backward
      compatibility.

  Returns:
    The History object returned by model.fit.
  """
  # Previously an empty model_name left the output paths undefined (NameError).
  if not model_name:
    model_name = 'model'
  model_dir = './'
  checkpoint_path = model_dir + "weights/cp-{epoch:04d}.ckpt"
  history_path = model_dir + "history/"

  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.00075), loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=['accuracy'])

  callbacks = []
  if save_history:
    callbacks.append(CustomHistoryCheckPoint(history_path))

  if save_weights:
    # An integer ModelCheckpoint save_freq counts batches, not epochs or samples.
    # Convert the epoch interval using the number of batches per epoch, and fall
    # back to per-epoch saving when that number is unknown.
    steps_per_epoch = int(train_dataset.cardinality())
    checkpoint_freq = save_freq * steps_per_epoch if steps_per_epoch > 0 else 'epoch'
    callbacks.append(tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        verbose=1,
        save_weights_only=True,
        save_freq=checkpoint_freq))

  train_hist = model.fit(train_dataset, validation_data=test_dataset, epochs=epochs, verbose=1, callbacks=callbacks)

  if save_weights:
    model.save("saved")
    tflite_model = tf.lite.TFLiteConverter.from_keras_model(model).convert()
    with open('%s.tflite' % (model_name), 'wb') as f:
      f.write(tflite_model)

  show_evaluation_graph(train_hist)
  return train_hist

def show_evaluation_graph(history, from_pickle = False, save_path = './history.png'):
  """Plot training vs. validation accuracy, then loss, each in its own figure.

  Args:
    history: the object returned by model.fit (its `.history` dict is used), or,
      when from_pickle=True, that dict itself (e.g. as loaded from hist.pickle).
    from_pickle: treat `history` as the metrics dict directly.
    save_path: file the accuracy figure is saved to. The loss figure is only shown.

  The dict must contain 'accuracy', 'val_accuracy', 'loss' and 'val_loss'.
  """
  hist = history if from_pickle else history.history
  # Use explicit figures so the loss curves are never drawn onto the accuracy axes
  # on backends where plt.show() leaves the current figure open.
  for metric, out_path in (('accuracy', save_path), ('loss', None)):
    fig, ax = plt.subplots()
    ax.plot(hist[metric])
    ax.plot(hist['val_' + metric])
    ax.set_title('model ' + metric)
    ax.set_ylabel(metric)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'test'], loc='upper left')
    if out_path:
      fig.savefig(out_path)
    plt.show()
    plt.close(fig)
  

def single_picture_prediction(model, path_to_image):
  """Run `model` on a single image file and return model.predict's output.

  The image is loaded at 256x256 and converted to an array. A leading batch
  axis of size 1 is added before predicting.
  """
  pixels = tf.keras.preprocessing.image.img_to_array(
      tf.keras.preprocessing.image.load_img(path_to_image, target_size=(256, 256)))
  single_batch = tf.expand_dims(pixels, 0)
  return model.predict(single_batch)

In [None]:
# Size the softmax from the loaded class list instead of a hard-coded count.
# Any non-empty data_augmentation_list enables the flip/rotation stage.
mv3_concat_100 = get_model(len(classes), reg_method='l2', data_augmentation_list=['yes'])

In [None]:
# Train on the concatenated TrashBox+TrashNet data; save_freq is the checkpoint
# interval handed to run_model, and model_name names the exported files.
run_model(
    mv3_concat_100,
    concat_train,
    concat_test,
    model_name='mv3_tb_200',  # NOTE(review): name says "tb"/"200", but this run uses the concat data for 100 epochs
    epochs=100,
    save_freq=25,
)

In [None]:
#mv3_concat_l2_200.predict(waste_wild)

In [None]:
def evaluate_on_wild_dumb(model, directory=DIR_WILD, class_names=None):
    """Display every image under `directory` together with the model's prediction.

    Walks `directory` recursively in sorted order, so repeated runs print in the
    same order. Each file is loaded at 256x256 and predicted as a batch of one.
    The function prints the image, the arg-max class name and the full
    prediction vector.

    Args:
        model: trained classifier whose output positions index `class_names`.
        directory: root folder to scan; defaults to DIR_WILD.
        class_names: label list; defaults to the notebook-level `classes`.
    """
    if class_names is None:
        class_names = classes
    for dirpath, dirnames, filenames in os.walk(directory):
        # Sorting dirnames in place makes os.walk descend in a deterministic order.
        dirnames.sort()
        for filename in sorted(filenames):
            infilename = os.path.join(dirpath, filename)
            if not os.path.isfile(infilename):
                continue
            img = tf.keras.preprocessing.image.load_img(infilename, target_size=(256, 256))
            img_array = tf.expand_dims(tf.keras.preprocessing.image.img_to_array(img), 0)
            pred = model.predict(img_array)
            print('image:')
            display(img)
            print('predicted: ' + class_names[np.argmax(pred)])
            print('full prediction: ' + str(pred))
            

In [None]:
# Visual sanity check of the trained model on the "in the wild" images.
evaluate_on_wild_dumb(model=mv3_concat_100)