# Sequence Model Training
## Metadata Single-Record Embeddings to File

In [1]:
import sys
sys.path.append("../")

import pickle
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import metrics
from tensorflow.keras import callbacks
from modules import utils, models
from tqdm import tqdm

2021-07-22 09:27:52.495083: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1


In [None]:
def _build_meta_split(csv_path):
    """Create the metadata generator and the index map for one CSV split.

    The generator is built first and the index map second, both from the
    same ``csv_path``. Returns the pair ``(generator, index_map)``.
    """
    split_gen = utils.CombinedDataGen(
        data_file=csv_path,
        out_mode='meta',
        # Every split, train included, is read in 'test' mode and unshuffled.
        # NOTE(review): presumably so records come out in file order without
        # train-time randomness -- confirm against utils.CombinedDataGen.
        mode='test',
        shuffle=False,
        scaler_dir='../models/scalers',
        include_index=True,
    )
    return split_gen, utils.index_map(csv_path)


# One (generator, index map) pair per data split.
metadata_train_gen, train_index_map = _build_meta_split('../data/demo_train.csv')
metadata_valid_gen, valid_index_map = _build_meta_split('../data/demo_valid.csv')
metadata_test_gen, test_index_map = _build_meta_split('../data/demo_test.csv')

# Load the saved single-record metadata model and re-wire it so that the
# output of its second-to-last layer is what gets returned.
metadata_mdl = tf.keras.models.load_model('../models/metadata_single_rec.h5')
emb_inp, emb_out = metadata_mdl.inputs, metadata_mdl.layers[-2].output
embedding_mdl = tf.keras.models.Model(inputs=emb_inp, outputs=emb_out)

In [None]:
# One row per split: (generator, mode, index map, output pickle, nb_eps).
# Rows are processed top to bottom, so files are written train -> valid -> test.
# NOTE(review): nb_eps is 100 for train and 1 otherwise; presumably the number
# of passes made over the generator -- confirm in utils.write_embeddings.
_embedding_jobs = [
    (metadata_train_gen, 'train', train_index_map,
     '../data/metadata_demo_sequence_train.pkl', 100),
    (metadata_valid_gen, 'valid', valid_index_map,
     '../data/metadata_demo_sequence_valid.pkl', 1),
    (metadata_test_gen, 'test', test_index_map,
     '../data/metadata_demo_sequence_test.pkl', 1),
]

for _gen, _mode, _index_map, _out_path, _nb_eps in _embedding_jobs:
    # Every split goes through the same embedding model in 'meta' output mode.
    utils.write_embeddings(
        generator=_gen,
        out_mode='meta',
        mode=_mode,
        model=embedding_mdl,
        index_map=_index_map,
        save_dir=_out_path,
        nb_eps=_nb_eps,
    )

## Train Metadata Sequence Model

In [None]:
# Seed NumPy and TensorFlow *before* the model is built so that weight
# initialisation (and any NumPy/TF-driven sampling inside the generators)
# repeats from one Restart & Run All to the next. GPU kernels may still
# introduce some run-to-run nondeterminism.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Embedding width shared by the model input and both sequence generators
# (previously hard-coded as 33 in three places).
# NOTE(review): presumably the output width of the embedding model used to
# write the .pkl files -- confirm.
EMB_DIM = 33

metadata_seq_mdl = models.init_seq_model(inp_shape=EMB_DIM)

# Training generator: batches of 2, SequenceGen's default drop_prob.
train_metadata_seq_gen = utils.SequenceGen(
    data_file='../data/metadata_demo_sequence_train.pkl',
    batch_size=2,
    emb_shape=EMB_DIM,
)

# Validation generator: batches of 1 with drop_prob explicitly set to 0.
valid_metadata_seq_gen = utils.SequenceGen(
    data_file='../data/metadata_demo_sequence_valid.pkl',
    batch_size=1,
    emb_shape=EMB_DIM,
    drop_prob=0.,
)

In [None]:
# Binary cross-entropy objective optimised with Adam; AUC is tracked as the
# reported metric.
metadata_seq_mdl.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[metrics.AUC()],
)

# Early stopping watches val_loss: training halts after 30 epochs without an
# improvement of at least 1e-8, and the best weights seen are restored.
es_cb = callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=1e-8,
    patience=30,
    restore_best_weights=True,
)

# epochs=1000 is only an upper bound; the callback above normally ends the run.
metadata_seq_mdl.fit(
    train_metadata_seq_gen,
    validation_data=valid_metadata_seq_gen,
    epochs=1000,
    callbacks=[es_cb],
)

In [None]:
# Write the trained sequence model to disk as an .h5 file.
metadata_seq_mdl.save(filepath='../models/metadata_sequence_rec.h5')