# Libraries

# In [1]:
import numpy as np
import os
import itertools
import yaml

from dataset_processing import *
from vae_model import *


def load_config_file(file_path, create_folder_results):
    """Load a YAML experiment config and derive its results directory.

    Args:
        file_path: Path to the YAML configuration file.
        create_folder_results: If truthy, create ``config['r_path']`` (and,
            implicitly, its parent ``config['results_path']``) when missing.

    Returns:
        The parsed config dict, extended with ``'r_path'``: the directory
        ``<results_path>/<prefix>_<use_teacherf>_<use_gt_sampling>_<use_features>``
        where this run's outputs are saved.
    """
    # Context manager guarantees the handle is closed (it was previously leaked).
    with open(file_path, 'r') as stream:
        config = yaml.safe_load(stream)

    # Directory where the results are saved; the folder name encodes the
    # training flags so different configurations do not collide.
    run_name = "_".join([config['prefix'],
                         str(config['use_teacherf']),
                         str(config['use_gt_sampling']),
                         str(config['use_features'])])
    config["r_path"] = os.path.join(config['results_path'], run_name)

    if create_folder_results:
        # r_path lives inside results_path, so makedirs creates both.
        # exist_ok replaces the old bare try/except, which also hid genuine
        # failures such as permission errors.
        os.makedirs(config['r_path'], exist_ok=True)

    return config

            
def get_callbacks(config, vae_m, data):
    """Create the ``SomeCallback`` instances handed to ``fit``.

    With ``config['formal_training']`` truthy, two callbacks are returned in
    this order: one on the ``'val'`` split with ``save_model=True`` and one on
    the ``'test'`` split with ``save_model=False``. Otherwise a single callback
    on the ``'test'`` split with ``save_model=True`` is returned.

    Args:
        config: Experiment config dict, forwarded to every callback.
        vae_m: Model object forwarded to every callback.
        data: Mapping of split name to a dict with ``'X'`` and ``'Y'`` entries.

    Returns:
        A list of one or two callbacks. ``save_history`` relies on this
        length and order when labelling their histories.
    """
    # (split name, save_model flag) for each callback, in construction order.
    if config['formal_training']:
        split_specs = [('val', True), ('test', False)]
    else:
        split_specs = [('test', True)]

    return [
        SomeCallback(vae_m,
                     data[split]['X'], data[split]['Y'],
                     config,
                     save_model=keep_weights)
        for split, keep_weights in split_specs
    ]
    
    
def get_optimizer(config):
    """Build the training optimizer described by ``config``.

    Args:
        config: Dict with ``'sgd'`` (truthy selects SGD, falsy selects Adam)
            and ``'lr'``. The SGD branch also reads ``'momentum'`` and the
            optional ``'decay'`` (treated as 0 when absent).

    Returns:
        A ``tf.keras`` optimizer instance.
    """
    # NOTE(review): ``tf`` is not imported in this file; it presumably arrives
    # via ``from vae_model import *`` -- confirm.
    if not config['sgd']:
        return tf.keras.optimizers.Adam(learning_rate=config['lr'])

    sgd_kwargs = {'learning_rate': config['lr'],
                  'momentum': config['momentum']}
    # Forward ``decay`` only when it has an effect. The new Keras optimizers
    # (TF >= 2.11) raise ValueError for any ``decay`` keyword, even 0.0, while
    # legacy optimizers default it to 0, so omitting a zero is equivalent.
    decay = config.get('decay', 0.0)
    if decay:
        sgd_kwargs['decay'] = decay
    return tf.keras.optimizers.SGD(**sgd_kwargs)


def save_history(path, r, callbacks, name_model):
    """Attach callback metrics to a training history and save it with NumPy.

    ``r.history`` is modified in place. With exactly two callbacks, their
    histories are stored as ``'vali_metrics'`` and ``'test_metrics'`` (in that
    order). Otherwise only ``callbacks[0].history`` is stored, as
    ``'test_metrics'``. The dict is then written with ``np.save`` to
    ``<path>/<name_model>_history`` (NumPy appends ``.npy``). Reading it back
    needs ``np.load(..., allow_pickle=True).item()``.

    Args:
        path: Output directory.
        r: Object returned by ``fit``; must expose a dict ``history``.
        callbacks: Callbacks from ``get_callbacks``; each exposes ``history``.
        name_model: File-name stem for the saved history.
    """
    history = r.history
    if len(callbacks) != 2:
        history['test_metrics'] = callbacks[0].history
    else:
        vali_cb, test_cb = callbacks
        history['vali_metrics'] = vali_cb.history
        history['test_metrics'] = test_cb.history

    target = os.path.join(path, name_model + '_history')
    np.save(target, history)

    
if __name__ == '__main__':

    config = load_config_file("config_files/vae_tf1_gt1.yaml", True)

    # One full train/evaluate cycle per (test split, run id) combination.
    for i_test, run_id in itertools.product(config['test_list'], config['id_list']):

        # Reset Keras global state so earlier iterations' models do not pile up.
        tf.keras.backend.clear_session()

        # The helpers below read the current split/run from the shared config.
        config.update({'i_test': i_test, 'run_id': run_id})

        data = load_dataset(config, verbose=0)
        vae_m = VAE(config)
        train_gen = VAEDataGenerator(data['train']['X'], data['train']['Y'], config)
        callbacks = get_callbacks(config, vae_m, data)
        optimizer = get_optimizer(config)
        vae_m.compile_model(config, optimizer)

        r = vae_m.train_model.fit(train_gen,
                                  epochs=config['epochs'],
                                  callbacks=callbacks,
                                  use_multiprocessing=False,
                                  workers=1,
                                  verbose=1)

        # NOTE(review): r_path does not depend on i_test/run_id, so every
        # iteration writes into the same folder; presumably vae_m.name_model
        # keeps file names distinct per run -- confirm.
        save_history(config['r_path'], r, callbacks, vae_m.name_model)
        vae_m.load_prediction_weights(config['r_path'])
        pred_t = vae_m.decode_sequences(data['test']['X'],
                                        config['outputs_final_test'],
                                        config['batch_size_vali'])
        f_name = os.path.join(config['r_path'], vae_m.name_model) + '_preds'
        np.save(f_name, pred_t)
        print('Saved in', f_name)

        ades, fdes = get_metrics(pred_t, data['test']['Y'])
        print('ade', np.mean(ades))
        print('fde', np.mean(fdes))
        print()



# NOTE(review): this bare dict literal looks like pasted notebook output of the
# loaded config (its 'r_path' matches what load_config_file builds from these
# values). As a module-level expression it is evaluated and discarded, even on
# import. Consider deleting it or moving the values into the YAML config file.
{'test_list': ['eth', 'hotel', 'univ', 'zara1', 'zara2'], 'id_list': ['0'], 'prefix': 'vae', 'results_path': 'results_ethucy', 'data_path': 'datasets', 'formal_training': 1, 'use_teacherf': 1, 'use_gt_sampling': 1, 'use_features': 0, 'max_d': 20.0, 'reduce': 360, 'f_dis_D': 128, 'LSTM_D': 256, 'normal_D': 128, 'outputs_validation': 20, 'outputs_final_test': 100, 'sgd': 1, 'lr': 0.005, 'momentum': 0.9, 'decay': 0.0, 'kl_w': 0.25, 'epochs': 5, 'batch_size': 128, 'batch_size_vali': 2048, 'features': 0, 'obs_l': 8, 'pre_l': 12, 'tot_l': 20, 'D': 2, 'decoders': 1, 'shuffle': 1, 'r_path': 'results_ethucy/vae_1_1_0'}
