## Load libraries

In [None]:
%load_ext autoreload
%autoreload 2

In [7]:
import sys
sys.path.append("../")

In [25]:
import os
from pathlib import Path
import json

from src.lstm import SimpleLSTM
from src.metrics import exact_match_metric
from src.callbacks import NValidationSetsCallback
from src.generators import DataGeneratorSeq
from src.utils import get_sequence_data

import tensorflow as tf
from tensorflow.keras.optimizers import Adam

print(f"Using TensorFlow version: {tf.__version__}")
print(f"GPU Available: {tf.test.is_gpu_available()}")

Using TensorFlow version: 2.0.0-beta1
GPU Available: False


## Paths

In [9]:
SETTINGS = Path('../settings/')
DATA = Path('../data/processed/')

## Load settings

In [12]:
settings_path = Path(SETTINGS/'settings_local.json')

In [13]:
with open(str(settings_path), 'r') as file:
    settings_dict = json.load(file)

In [14]:
settings_dict

{'math_module': 'arithmetic__add_sub',
 'train_level': '*',
 'batch_size': 1024,
 'thinking_steps': 16,
 'epochs': 1,
 'num_encoder_units': 512,
 'num_decoder_units': 2048,
 'embedding_dim': 2048,
 'save_path': '../data/',
 'data_path': '../data/'}

## Load data

In [15]:
data_gen_pars, input_texts, target_texts = get_sequence_data(settings_dict)

In [16]:
print('Number of training samples:', len(input_texts['train']))

Number of training samples: 1599998


In [17]:
print('Number of validation samples:', len(input_texts['interpolate']))

Number of validation samples: 10000


In [18]:
data_gen_pars.keys()

dict_keys(['batch_size', 'max_encoder_seq_length', 'max_decoder_seq_length', 'max_seq_length', 'num_encoder_tokens', 'num_decoder_tokens', 'num_tokens', 'input_token_index', 'target_token_index', 'token_index', 'num_thinking_steps'])

## Data generators

In [19]:
training_generator = DataGeneratorSeq(
    input_texts=input_texts["train"],
    target_texts=target_texts["train"],
    **data_gen_pars
)
validation_generator = DataGeneratorSeq(
    input_texts=input_texts["valid"],
    target_texts=target_texts["valid"],
    **data_gen_pars
)
interpolate_generator = DataGeneratorSeq(
    input_texts=input_texts["interpolate"],
    target_texts=target_texts["interpolate"],
    **data_gen_pars
)
extrapolate_generator = DataGeneratorSeq(
    input_texts=input_texts["extrapolate"],
    target_texts=target_texts["extrapolate"],
    **data_gen_pars
)

## Load model

In [20]:
lstm = SimpleLSTM(data_gen_pars['num_tokens'], settings_dict['embedding_dim'])
model = lstm.get_model()
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, None, 34)]        0         
_________________________________________________________________
lstm (LSTM)                  [(None, None, 2048), (Non 17063936  
_________________________________________________________________
dense (Dense)                (None, None, 34)          69666     
Total params: 17,133,602
Trainable params: 17,133,602
Non-trainable params: 0
_________________________________________________________________


In [21]:
adam = Adam(
    lr=6e-4,
    beta_1=0.9,
    beta_2=0.995,
    epsilon=1e-9,
    decay=0.0,
    amsgrad=False,
    clipnorm=0.1,
)
model.compile(
    optimizer=adam, loss="categorical_crossentropy", metrics=[exact_match_metric]
)

## Configure callbacks

In [22]:
valid_dict = {
    'validation':validation_generator,
    'interpolation': interpolate_generator,
    'extrapolation': extrapolate_generator
}

In [23]:
history = NValidationSetsCallback(valid_dict)

In [24]:
# directory where the checkpoints will be saved
checkpoint_dir = settings_dict["save_path"] + "training_checkpoints"
# name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix, save_weights_only=True
)

## Train model

In [None]:
train_hist = model.fit_generator(
    training_generator,
    epochs=settings_dict["epochs"],
    callbacks=[history, checkpoint_callback],
    verbose=1,
)