## Load libraries

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("../")

In [17]:
import os
from pathlib import Path
import json

from src.lstm import AttentionLSTM
from src.metrics import exact_match_metric
from src.callbacks import NValidationSetsCallback
from src.generators import DataGeneratorAttention
from src.utils import get_sequence_data

import tensorflow as tf
from tensorflow.keras.optimizers import Adam

print(f"Using TensorFlow version: {tf.__version__}")
print(f"GPU Available: {tf.test.is_gpu_available()}")

Using TensorFlow version: 2.0.0-beta1
GPU Available: False


## Paths

In [7]:
SETTINGS = Path('../settings/')
DATA = Path('../data/processed/')

## Load settings

In [8]:
settings_path = Path(SETTINGS / "settings_local.json")

with open(str(settings_path), "r") as file:
    settings_dict = json.load(file)

settings_dict

{'math_module': 'arithmetic__add_sub',
 'train_level': '*',
 'batch_size': 1024,
 'thinking_steps': 16,
 'epochs': 1,
 'num_encoder_units': 512,
 'num_decoder_units': 2048,
 'embedding_dim': 2048,
 'save_path': '../data/',
 'data_path': '../data/'}

## Load data

In [9]:
data_gen_pars, input_texts, target_texts = get_sequence_data(settings_dict)

In [10]:
print('Number of training samples:', len(input_texts['train']))

Number of training samples: 1599998


In [11]:
print('Number of validation samples:', len(input_texts['interpolate']))

Number of validation samples: 10000


## Data generators

In [13]:
training_generator = DataGeneratorAttention(
    input_texts=input_texts["train"],
    target_texts=target_texts["train"],
    **data_gen_pars
)
validation_generator = DataGeneratorAttention(
    input_texts=input_texts["valid"],
    target_texts=target_texts["valid"],
    **data_gen_pars
)
interpolate_generator = DataGeneratorAttention(
    input_texts=input_texts["interpolate"],
    target_texts=target_texts["interpolate"],
    **data_gen_pars
)
extrapolate_generator = DataGeneratorAttention(
    input_texts=input_texts["extrapolate"],
    target_texts=target_texts["extrapolate"],
    **data_gen_pars
)

## Load model

In [14]:
lstm = AttentionLSTM(
    data_gen_pars["num_encoder_tokens"],
    data_gen_pars["num_decoder_tokens"],
    data_gen_pars["max_encoder_seq_length"],
    data_gen_pars["max_decoder_seq_length"],
    settings_dict["num_encoder_units"],
    settings_dict["num_decoder_units"],
    settings_dict["embedding_dim"],
)
model = lstm.get_model()
model.summary()

W0921 18:33:44.146364 4557796800 deprecation.py:323] From /Users/lewtun/git/deep-math/env/lib/python3.7/site-packages/tensorflow/python/keras/backend.py:3868: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, None)]       0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, None)]       0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, None, 2048)   67584       input_1[0][0]                    
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, None, 2048)   28672       input_2[0][0]                    
______________________________________________________________________________________________

In [15]:
adam = Adam(
    lr=6e-4,
    beta_1=0.9,
    beta_2=0.995,
    epsilon=1e-9,
    decay=0.0,
    amsgrad=False,
    clipnorm=0.1,
)
model.compile(
    optimizer=adam, loss="categorical_crossentropy", metrics=[exact_match_metric]
)

## Configure callbacks

In [16]:
valid_dict = {
    'validation':validation_generator,
    'interpolation': interpolate_generator,
    'extrapolation': extrapolate_generator
}

In [25]:
history = NValidationSetsCallback(valid_dict)

In [26]:
# directory where the checkpoints will be saved
checkpoint_dir = settings_dict["save_path"] + "training_checkpoints"
# name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix, save_weights_only=True
)

## Train model

In [None]:
train_hist = model.fit_generator(
    training_generator,
    epochs=settings_dict["epochs"],
    callbacks=[history, checkpoint_callback],
    verbose=1,
)