## Load libraries

In [None]:
# Automatically re-import edited modules (e.g. the project's src/ package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
# Put the parent of the current working directory on the import path so that the
# project's `src` package (imported in the next cell) can be resolved from here.
sys.path.append("../")

In [None]:
import os
from pathlib import Path
import json

from src.lstm import LSTMSimple
from src.metrics import exact_match_metric
from src.callbacks import NValidationSetsCallback, GradientLogger
from src.generator import DataGenerator, DataGeneratorSeq
from src.utils import get_sequence_data

import tensorflow as tf
from tensorflow.keras.optimizers import Adam

print(tf.__version__)
# `tf.test.is_gpu_available()` is deprecated; instead ask TensorFlow which physical GPU
# devices it can see. The `experimental` alias is available from TF 2.0 onward.
# `bool(...)` keeps the printed value a True/False flag as before.
print("GPU Available: ", bool(tf.config.experimental.list_physical_devices("GPU")))

## Paths

In [None]:
# Project directories, relative to the notebook's working directory:
# SETTINGS holds the JSON settings files, DATA the processed datasets.
SETTINGS, DATA = Path('../settings/'), Path('../data/processed/')

## Load settings

In [None]:
# SETTINGS is already a Path, so joining with `/` yields a Path directly.
settings_path = SETTINGS / 'settings_local.json'

In [None]:
# Read the local settings. JSON text is UTF-8 by specification, so pass the encoding
# explicitly rather than relying on the platform's locale default (e.g. cp1252 on Windows).
# `open` accepts a Path directly, so no `str()` conversion is needed.
with open(settings_path, 'r', encoding='utf-8') as settings_file:
    settings_dict = json.load(settings_file)

In [None]:
# Display the loaded settings (last expression in the cell) for a quick sanity check.
settings_dict

## Load data

In [None]:
# Load the data described by the settings. Below, `data_gen_pars` is unpacked as keyword
# arguments into DataGeneratorSeq and read for 'num_tokens'. `input_texts` / `target_texts`
# are indexed by split name: 'train', 'valid', 'interpolate', 'extrapolate'.
data_gen_pars, input_texts, target_texts = get_sequence_data(settings_dict)

In [None]:
# Report the size of the training split.
print(f"Number of training samples: {len(input_texts['train'])}")

In [None]:
# `validation_generator` is built from the 'valid' split, so count that split here.
# Previously this printed the 'interpolate' split under the "validation" label.
# The interpolate/extrapolate splits are evaluated separately (see valid_dict below),
# so report them under their own labels.
print('Number of validation samples:', len(input_texts['valid']))
print('Number of interpolation samples:', len(input_texts['interpolate']))
print('Number of extrapolation samples:', len(input_texts['extrapolate']))

In [None]:
# Show which keyword parameters will be forwarded to every DataGeneratorSeq below.
data_gen_pars.keys()

## Data generators

In [None]:
def _build_seq_generator(split):
    """Return a DataGeneratorSeq over the named data split, using the shared generator parameters."""
    return DataGeneratorSeq(
        input_texts=input_texts[split],
        target_texts=target_texts[split],
        **data_gen_pars
    )


# One generator per split: training data plus the three evaluation sets.
training_generator = _build_seq_generator("train")
validation_generator = _build_seq_generator("valid")
interpolate_generator = _build_seq_generator("interpolate")
extrapolate_generator = _build_seq_generator("extrapolate")

## Build model

In [None]:
# Build a fresh (untrained) LSTMSimple model. It is sized from data_gen_pars['num_tokens']
# (presumably the token vocabulary size; confirm in src.lstm) and the configured embedding
# dimension. No weights are loaded from a checkpoint here.
lstm = LSTMSimple(data_gen_pars['num_tokens'], settings_dict['embedding_dim'])
model = lstm.get_model()
model.summary()

In [None]:
# Adam with a constant learning rate of 6e-4. Each weight's gradient is clipped to an
# L2 norm of at most 0.1.
# - `learning_rate` replaces the deprecated `lr` alias. The Keras optimizers shipped
#   from TF 2.11 drop `lr` with only a warning, which would silently fall back to the
#   default rate.
# - `decay=0.0` was removed: zero was the legacy default (no decay), and the newer
#   optimizers reject the `decay` argument outright.
adam = Adam(
    learning_rate=6e-4,
    beta_1=0.9,
    beta_2=0.995,
    epsilon=1e-9,
    amsgrad=False,
    clipnorm=0.1,
)
model.compile(
    optimizer=adam, loss="categorical_crossentropy", metrics=[exact_match_metric]
)

## Configure callbacks

In [None]:
# Evaluation sets handed to NValidationSetsCallback, keyed by name.
valid_dict = dict(
    validation=validation_generator,
    interpolation=interpolate_generator,
    extrapolation=extrapolate_generator,
)

In [None]:
# Custom callback given the three evaluation generators. Since fit() below receives no
# validation_data, this is presumably where per-epoch evaluation on these sets happens
# (confirm in src.callbacks).
history = NValidationSetsCallback(valid_dict)

In [None]:
# Directory where the checkpoints will be saved. os.path.join inserts the path separator
# when `save_path` has no trailing slash; plain `+` would produce
# e.g. "outputtraining_checkpoints".
checkpoint_dir = os.path.join(settings_dict["save_path"], "training_checkpoints")
# Name of the checkpoint files; ModelCheckpoint fills in `{epoch}` when it saves.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# Save only the model weights (not the full model) at each save point.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix, save_weights_only=True
)

## Train model

In [None]:
# Train on the training generator. `Model.fit` accepts generator / keras Sequence inputs
# directly; `fit_generator` is deprecated since TF 2.1 and removed in Keras 3.
# No validation_data is passed: evaluation on the other splits is left to the
# `history` callback.
train_hist = model.fit(
    training_generator,
    epochs=settings_dict["epochs"],
    callbacks=[history, checkpoint_callback],
    verbose=1,
)