# Necessary Libraries / Imports

In [None]:
import sys
# Append the parent of the current working directory to the module search
# path. NOTE(review): presumably this is what makes the project-local
# `data_generator` module importable from this notebook — confirm.
sys.path.append("..")

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Conv1D, Dense, Dropout, Activation, Input, MaxPool1D, Flatten, Dropout
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.callbacks import TensorBoard
from data_generator import DataGenerator
import os

In [None]:
class TrainerWrapper:
    """Run repeated, seeded training experiments of a small 1-D CNN classifier.

    Every experiment loads its data with its own seed, builds and compiles a
    fresh network, trains it, and writes TensorBoard logs below
    ``../logs/<name>/e<seed>/``.
    """

    def __init__(self, name):
        """Create a trainer.

        Args:
            name: Network name; used as the first component of the TensorBoard
                log directory of every experiment.
        """
        self._model = None
        self._is_summary_shown = False
        self._network_name = name
        # Per-experiment state, declared up front so that every attribute the
        # helpers read exists from construction on. ``_name`` falls back to
        # the bare network name until ``fit`` assigns a per-experiment one.
        self._name = name
        self._validation_X = None
        self._validation_y = None
        self._training_generator = None

    def _load_data(self, seed):
        """Load the validation data and the training generator for ``seed``.

        Args:
            seed: Seed forwarded to ``DataGenerator``.
        """
        data_generator = DataGenerator(seed=seed, upsampling=20, validation_split=0.2)
        self._validation_X, self._validation_y = data_generator.get_validation_data()
        # NOTE(review): 2000 is presumably the batch size — confirm against
        # DataGenerator.get_training_generator.
        self._training_generator = data_generator.get_training_generator(2000)

    def _create_network(self):
        """Build a fresh, uncompiled model and store it in ``self._model``."""
        # Conv1D with the default channels-last layout reads this as
        # 140 steps with 4 features each.
        inputs = Input(shape=(140, 4))
        conv = Conv1D(filters=4, kernel_size=3, strides=1, padding='same', activation='relu')(inputs)
        # 'same' padding with stride 1 keeps 140 steps; pooling by 7 leaves
        # 20 steps x 4 filters, i.e. 80 values after flattening.
        max_pooled = MaxPool1D(pool_size=7)(conv)
        flatten = Flatten()(max_pooled)
        classifier = Dense(units=5, activation='relu')(flatten)
        # Three softmax outputs: one probability per class.
        outputs = Dense(units=3, activation='softmax')(classifier)
        self._model = Model(inputs=inputs, outputs=outputs)

    def _compile_model(self):
        """Compile ``self._model`` with SGD, categorical cross-entropy and accuracy."""
        self._model.compile(
            loss=categorical_crossentropy,
            optimizer=SGD(learning_rate=0.01),
            metrics=['accuracy']
        )

    def _fit(self, epochs):
        """Train the current model and log each epoch to TensorBoard.

        Args:
            epochs: Number of epochs to train for.
        """
        self._model.fit(
            self._training_generator,
            epochs=epochs,
            # NOTE(review): Keras ignores ``shuffle`` for generator inputs;
            # kept in case the training data is not a generator — confirm.
            shuffle=True,
            validation_data=(self._validation_X, self._validation_y),
            # Silent training (0 is the documented spelling of "no output");
            # progress is followed through TensorBoard instead.
            verbose=0,
            callbacks=[
                TensorBoard(
                    log_dir=f'../logs/{self._name}',
                    write_graph=False,
                    write_images=False,
                    update_freq="epoch"
                )
            ]
        )

    def fit(self, experiments=5, epochs=10):
        """Run ``experiments`` independent training runs of ``epochs`` epochs.

        Experiment ``i`` uses ``i`` as its data seed and logs to
        ``../logs/<name>/e<i>/``. The model summary is printed once, for the
        first model that gets built.

        Args:
            experiments: Number of runs; seeds are ``0 .. experiments - 1``.
            epochs: Number of training epochs per run.
        """
        for seed in range(experiments):
            self._name = f'{self._network_name}/e{seed}/'
            try:
                self._load_data(seed=seed)
                self._create_network()
                self._compile_model()
                if not self._is_summary_shown:
                    self._model.summary()
                    self._is_summary_shown = True
                self._fit(epochs)
            finally:
                # Always reset Keras' global state, even when a run fails, so
                # that leftovers of one experiment never leak into the next.
                tf.keras.backend.clear_session()

# Loading the data and training the model

In [None]:
# Switch between GPU and CPU execution for the whole run.
execute_on_gpu = True

# Everything created inside the device scope is placed on the chosen device.
device_name = '/gpu:0' if execute_on_gpu else '/cpu:0'
with tf.device(device_name):
    model = TrainerWrapper(name='cnn_exp_1')
    # 10 independent experiments, 300 epochs each.
    model.fit(experiments=10, epochs=300)