In [None]:
# Imports
import sys

# Local Imports
sys.path.insert(0, 'python')
from lstm import lstm_graph
from read_data import read_data
from scrn import scrn_graph
from srn import srn_graph
from tokens import text_elements_to_tokens

In [None]:
# Flags
#
# rnn_flg selects which recurrent network graph is built further down:
#   1 -> SRN
#   2 -> LSTM
#   3 -> SCRN
rnn_flg = 3

# usecase_flg selects the prediction task:
#   1 -> predicting letters
#   2 -> predicting words with fixed vocabulary size
#   3 -> predicting words with cutoff for infrequent words
usecase_flg = 1

In [None]:
# Network-specific hyperparameters
#
# Each branch defines the values consumed by the matching graph constructor
# in the "Initiate graph" cell. Only rnn_flg values 1, 2 and 3 are supported.
if rnn_flg == 1:

    # SRN network hyperparameters
    hidden_size = 100          # Dimension of the hidden vector

    # SRN training hyperparameters
    num_unfoldings = 10

elif rnn_flg == 2:

    # LSTM network hyperparameters
    hidden_size = 100          # Dimension of the hidden vector

    # LSTM training hyperparameters
    num_unfoldings = 10

elif rnn_flg == 3:

    # SCRN network hyperparameters
    alpha = 0.95               # Passed as first argument to scrn_graph
    hidden_size = 100          # Dimension of the hidden vector
    state_size = 10            # Dimension of the state vector

    # SCRN training hyperparameters
    num_unfoldings = 50

else:
    # Fail fast with a clear message. Without this branch an unsupported flag
    # would leave hidden_size / num_unfoldings undefined on a fresh kernel, or
    # silently keep stale values from an earlier run of this cell.
    raise ValueError(
        'Unsupported rnn_flg %r: expected 1 (SRN), 2 (LSTM) or 3 (SCRN)' % (rnn_flg,))

# General network hyperparameters
# NOTE: vocabulary_size is rebound by the data preparation cell to the value
# returned by text_elements_to_tokens.
vocabulary_size = 10000    # Fixed vocabulary size for usecase_flg = 2
word_frequency_cutoff = 5  # Cutoff for infrequent words for usecase_flg = 3

# General training hyperparameters (all except batch_size are passed to
# graph.optimization; batch_size is passed to the graph constructor)
batch_size = 32
clip_norm = 1.25           # presumably the gradient clipping norm -- TODO confirm
learning_decay = 1/1.5     # Multiplier to decay the learn rate when required
learning_rate = 0.05       # Initial learning rate
num_epochs = 100           # Total number of epochs to run the algorithm
optimization_frequency = 5 # Passed to graph.optimization; meaning undocumented here -- TODO confirm
summary_frequency = 500    # Passed to graph.optimization; presumably steps between summaries -- TODO confirm
training_size = 6000000    # Size of training set
validation_size = 600000   # Size of validation set

# Data file
filename = 'data/text8.zip'

In [None]:
# Prepare training and validation batches
#
# Step 1: load the raw text elements from the data file.
raw_data = read_data(usecase_flg, filename, vocabulary_size)

# Step 2: convert the raw elements to tokens.
# NOTE: vocabulary_size is rebound here to the value returned by
# text_elements_to_tokens, so re-running this cell without first re-running
# the hyperparameter cell feeds the returned value back in.
(data,
 dictionary,
 reverse_dictionary,
 vocabulary_size) = text_elements_to_tokens(usecase_flg, raw_data, vocabulary_size)

# Step 3: the first training_size tokens form the training set and the next
# validation_size tokens form the validation set.
training_text, validation_text = (
    data[:training_size],
    data[training_size:training_size + validation_size],
)

In [None]:
# Initiate graph
#
# Build the graph object for the architecture selected by rnn_flg. Every
# constructor receives the hidden size, vocabulary size, number of unfoldings
# and batch size; the SCRN additionally takes alpha and state_size.
if rnn_flg == 1:
    # Use SRN
    graph = srn_graph(hidden_size, vocabulary_size, num_unfoldings, batch_size)
elif rnn_flg == 2:
    # Use LSTM
    graph = lstm_graph(hidden_size, vocabulary_size, num_unfoldings, batch_size)
elif rnn_flg == 3:
    # Use SCRN
    graph = scrn_graph(alpha, hidden_size, state_size, vocabulary_size, num_unfoldings, batch_size)
else:
    # Without this branch an unsupported flag would either hit a NameError on
    # `graph` below, or silently train a stale `graph` left in the kernel by
    # an earlier run of this cell.
    raise ValueError(
        'Unsupported rnn_flg %r: expected 1 (SRN), 2 (LSTM) or 3 (SCRN)' % (rnn_flg,))

# Optimize graph
# Runs the graph's optimization routine on the training text, passing the
# validation text alongside it.
graph.optimization(learning_rate, learning_decay, optimization_frequency, clip_norm, num_epochs,
                   summary_frequency, training_text, validation_text)