In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Embedding, LSTM, Bidirectional, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import categorical_crossentropy
from data_generator import DataGenerator

In [None]:
class TrainerWrapper:
    """Repeatedly train a bidirectional-LSTM classifier with different seeds.

    For each ``seed`` in ``range(experiments)``, ``fit`` does the following:
    reloads the data via ``DataGenerator(seed=seed, ...)``, seeds
    TensorFlow/NumPy with the same value, builds and compiles a fresh network,
    trains it, and writes TensorBoard logs to ``logs/<name>/e<seed>/``.

    Args:
        name: Network name; used as the TensorBoard log sub-directory.
        input_shape: Shape of one input sample, excluding the batch axis.
            Defaults to ``(140, 4)`` (the value previously hard-coded).
        num_classes: Number of output classes. Defaults to 3.
        learning_rate: Adam learning rate. Defaults to 0.001.
    """

    def __init__(self, name, input_shape=(140, 4), num_classes=3, learning_rate=0.001):
        self._model = None
        self._is_summary_shown = False
        self._network_name = name
        # Per-experiment log name; overwritten by fit() before each run.
        self._name = name
        self._input_shape = tuple(input_shape)
        self._num_classes = num_classes
        self._learning_rate = learning_rate
        # Populated by _load_data().
        self._validation_X = None
        self._validation_y = None
        self._training_generator = None

    def _experiment_name(self, seed):
        """Return the TensorBoard sub-path used for experiment ``seed``."""
        return f'{self._network_name}/e{seed}/'

    def _load_data(self, seed):
        """Load the validation arrays and a training generator for ``seed``.

        NOTE(review): ``upsampling=5`` and the positional argument ``1000`` are
        forwarded unchanged; their exact meaning is defined by DataGenerator.
        """
        data_generator = DataGenerator(seed=seed, upsampling=5, validation_split=0.2)
        self._validation_X, self._validation_y = data_generator.get_validation_data()
        self._training_generator = data_generator.get_training_generator(1000)

    def _create_network(self):
        """Build a two-layer bidirectional LSTM with a softmax classification head."""
        self._model = Sequential()
        self._model.add(Input(shape=self._input_shape))
        self._model.add(Bidirectional(LSTM(100, return_sequences=True)))
        self._model.add(Bidirectional(LSTM(100)))
        self._model.add(Dense(32, activation='relu'))
        self._model.add(Dropout(0.3))
        # softmax (not sigmoid): categorical cross-entropy expects a probability
        # distribution over mutually exclusive classes.
        self._model.add(Dense(self._num_classes, activation='softmax'))

    def _compile_model(self):
        """Compile the model with categorical cross-entropy, Adam and accuracy."""
        self._model.compile(
            loss=categorical_crossentropy,
            optimizer=Adam(learning_rate=self._learning_rate),
            metrics=['accuracy'],
        )

    def _fit(self, epochs):
        """Train the current model for ``epochs`` epochs; return the Keras History."""
        return self._model.fit(
            self._training_generator,
            epochs=epochs,
            shuffle=True,
            validation_data=(self._validation_X, self._validation_y),
            verbose=True,
            callbacks=[
                TensorBoard(
                    log_dir=f'logs/{self._name}',
                    write_graph=False,
                    write_images=False,
                    update_freq="epoch",
                )
            ],
        )

    def fit(self, experiments=5, epochs=10):
        """Run ``experiments`` independent trainings, seeded 0..experiments-1.

        Args:
            experiments: Number of independently seeded runs.
            epochs: Epochs per run.

        Returns:
            list: The Keras ``History`` of each experiment, in seed order.
        """
        histories = []
        for seed in range(experiments):
            self._name = self._experiment_name(seed)
            self._load_data(seed=seed)
            # Seed TF/NumPy too (not just the data split) so that weight
            # initialisation is reproducible for a given experiment index.
            tf.random.set_seed(seed)
            np.random.seed(seed)
            self._create_network()
            self._compile_model()
            if not self._is_summary_shown:
                self._model.summary()
                self._is_summary_shown = True
            histories.append(self._fit(epochs))
            # Reset Keras global state before the next experiment builds a model.
            tf.keras.backend.clear_session()
        return histories

In [None]:
# Training-run configuration.
EXPERIMENTS = 1            # number of independently seeded runs (seeds 0..EXPERIMENTS-1)
EPOCHS = 100               # epochs per run
USE_GPU_IF_AVAILABLE = True

# Only target the GPU when TensorFlow actually sees one; otherwise use the CPU.
execute_on_gpu = USE_GPU_IF_AVAILABLE and bool(tf.config.list_physical_devices('GPU'))

with tf.device('/gpu:0' if execute_on_gpu else '/cpu:0'):
    # `model` is the TrainerWrapper (name kept for any later cells that use it).
    model = TrainerWrapper(name='blstm')
    model.fit(experiments=EXPERIMENTS, epochs=EPOCHS)