# Train LSTM sequence prediction models ------------------------------

Neural language models learn to predict the next word in a sequence from the words that precede it. In our case, the "words" are cluster assignments, one for each 100-ms segment of the pitch contour.
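The sketch below makes the "next word" framing concrete. It turns one sequence of cluster assignments into (context, next-token) training pairs using a sliding window with step 1, in the spirit of the Keras text-generation example. The function `make_next_token_pairs()`, the toy cluster IDs, and the window length of 3 are illustrative assumptions. They are not the helpers in `lstm-train-h.R`. The model summary printed later in this notebook shows an input length of 10, so the real windows are presumably longer.

```r
# Hypothetical helper: build (context, next-token) pairs from one sequence of
# cluster assignments (one integer ID per 100-ms pitch segment).
make_next_token_pairs <- function(tokens, maxlen = 3L) {
  n_pairs <- length(tokens) - maxlen
  if (n_pairs < 1L) stop("sequence must be longer than maxlen")
  # each row of x is a window of `maxlen` consecutive tokens (step of 1)
  x <- do.call(rbind, lapply(seq_len(n_pairs),
                             function(i) tokens[i:(i + maxlen - 1L)]))
  # y is the token that immediately follows each window
  y <- tokens[(maxlen + 1L):length(tokens)]
  list(x = x, y = y)
}

# toy sequence of cluster IDs (made up for illustration)
clusters <- c(2L, 5L, 5L, 1L, 3L, 8L)
pairs <- make_next_token_pairs(clusters, maxlen = 3L)
pairs$x  # rows: 2 5 5 / 5 5 1 / 5 1 3
pairs$y  # 1 3 8
```

Each row of `x` is a context the model sees, and the matching element of `y` is the cluster it should learn to predict.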

Our LSTM implementation is adapted from the Keras for R text-generation example: https://keras.rstudio.com/articles/examples/lstm_text_generation.html

In [18]:
library(janitor); library(here); library(keras); library(tidyverse)
source(here("code/00_config/lena-pred-dnn-config.R"))
source(here("code/00_helper_functions/lstm-train-h.R"))



In [3]:
# read the train/test data and flatten it into a one-level list of datasets to fit;
# the list should contain n_prop_cds x n_q_shapes datasets
d_list <- read_rds(here(paths_config$lstm_sum_path, 
                        "lena-pred-lstm-train-test.rds")) %>% 
  flatten() 
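The nested structure and expected length can be checked on a toy list. This sketch uses base R's `unlist(recursive = FALSE)`, which removes one level of nesting in the same way `purrr::flatten()` does. The level names and placeholder values are hypothetical. The element names that `purrr::flatten()` produces may also differ from base R's `outer.inner` convention shown here.

```r
# Toy nested list: outer level = proportion of CDS, inner level = number of
# q-shapes. All names and values are hypothetical placeholders.
nested <- list(
  prop_cds_0.5 = list(q_shapes_4 = "data A", q_shapes_8 = "data B"),
  prop_cds_1.0 = list(q_shapes_4 = "data C", q_shapes_8 = "data D")
)

# remove one level of nesting (base-R analogue of purrr::flatten())
flat <- unlist(nested, recursive = FALSE)

# expected length: n_prop_cds x n_q_shapes = 2 x 2
stopifnot(length(flat) == 2 * 2)
names(flat)  # "prop_cds_0.5.q_shapes_4" ... "prop_cds_1.0.q_shapes_8"
```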

In [19]:
# create one model per dataset (one per combination of prop CDS and number of q-shapes)
mods <- d_list %>% map(create_lstm, lstm_config)

# train each model on its own dataset and generate predictions;
# the training function also handles post-processing
# and tidying of the model predictions
results_obj <- pmap(list(mods, names(mods), d_list), 
                    .f = safe_train_lstm, 
                    lstm_config = lstm_config)

________________________________________________________________________________
Layer (type)                        Output Shape                    Param #     
embedding_26 (Embedding)            (None, 10, 30)                  270         
________________________________________________________________________________
lstm_26 (LSTM)                      (None, 30)                      7320        
________________________________________________________________________________
dense_26 (Dense)                    (None, 8)                       248         
Total params: 7,838
Trainable params: 7,838
Non-trainable params: 0
________________________________________________________________________________
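Reproducing the parameter counts in this summary by hand is a quick way to confirm the architecture is what we intended. The sketch assumes an embedding vocabulary of 9 tokens, inferred from 270 / 30 (one more than the 8 output units). The helper functions are hypothetical. The LSTM count uses the standard Keras form: four gates, each with an input kernel, a recurrent kernel, and a single bias vector.

```r
# Hypothetical helpers reproducing Keras parameter counts.
embedding_params <- function(vocab_size, embed_dim) vocab_size * embed_dim

# 4 gates x (input kernel + recurrent kernel + bias)
lstm_params <- function(input_dim, units) 4 * (units * (input_dim + units) + units)

# weight matrix + bias
dense_params <- function(input_dim, units) input_dim * units + units

vocab_size <- 9   # assumption: inferred from 270 / 30
embed_dim  <- 30
lstm_units <- 30
n_out      <- 8

counts <- c(
  embedding = embedding_params(vocab_size, embed_dim),
  lstm      = lstm_params(embed_dim, lstm_units),
  dense     = dense_params(lstm_units, n_out)
)
counts       # embedding 270, lstm 7320, dense 248
sum(counts)  # 7838
```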


New names:
* `` -> ...1
* `` -> ...2
* `` -> ...3
* `` -> ...4
* `` -> ...5
* … and 3 more problems
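`pmap()` calls `.f` once per position and passes element i of every input list together in the same call. Inputs of length one are recycled. The `safe_` prefix suggests that `safe_train_lstm` is wrapped with `purrr::safely()`, so a single failed fit does not stop the loop, but that wrapper is an assumption. The base-R sketch below reproduces the pattern with `Map()` and `tryCatch()`, using a hypothetical `fit_one()` in place of real training.

```r
# Wrap a function so it returns list(result, error) instead of throwing.
safely_base <- function(f) {
  function(...) {
    tryCatch(
      list(result = f(...), error = NULL),
      error = function(e) list(result = NULL, error = e)
    )
  }
}

# hypothetical stand-in for a training function
fit_one <- function(model, name, dataset) {
  if (length(dataset) == 0) stop("empty dataset: ", name)
  paste(model, "fit on", name, "with", length(dataset), "observations")
}

safe_fit <- safely_base(fit_one)

mods     <- list(a = "model_a", b = "model_b", c = "model_c")
datasets <- list(a = 1:5, b = integer(0), c = 1:3)

# like purrr::pmap(list(mods, names(mods), datasets), safe_fit):
# the i-th elements of each input are passed together in call i
results <- Map(safe_fit, mods, names(mods), datasets)

ok <- vapply(results, function(r) is.null(r$error), logical(1))
ok  # a TRUE, b FALSE, c TRUE
```

A failed fit can then be inspected through its `error` component while the other results remain available.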


In [None]:
# save predictions for later analysis
write_rds(results_obj, 
          here(paths_config$lstm_sum_path, "lena-pred-lstm-preds.rds"), 
          compress = "gz")  