In [1]:
import sys
import os
from training_utils import *

# Make the model definitions (vae, sdvae, ...) importable.
sys.path.append('models')

# Experiment configuration. The commented-out calls are alternative configs
# kept for quickly switching between experiments.
config = load_config_file("ETH-UCY/config_files/sdvae/sdvae_tf0_feat0.yaml", 1)
# config = load_config_file("ETH-UCY/config_files/vae_kl_025/vae_tf0_gt1_feat0.yaml", 1)

# config = load_config_file("crossroad/config_files/sdvae/sdvae_tf0_feat0.yaml", 1)
# config = load_config_file("crossroad/config_files/vae_kl_025/vae_tf1_gt1_feat0.yaml", 1)

# Put the dataset-specific module directory on the import path so the
# dataset-specific modules can be imported afterwards. The first directory
# name found in data_path wins ('crossroad' is checked before 'ETH-UCY').
# An explicit exception is used instead of `assert` so the check survives
# `python -O` and reports which path was rejected.
for dataset_dir in ('crossroad', 'ETH-UCY'):
    if dataset_dir in config['data_path']:
        sys.path.append(dataset_dir)
        break
else:
    raise ValueError(
        f"config['data_path'] must refer to 'crossroad' or 'ETH-UCY', "
        f"got {config['data_path']!r}"
    )

import tensorflow as tf
import numpy as np
import itertools
import yaml
from vae import *
from sdvae import *
from generator import Generator
from callback import Callback
from autoencoder_patches import *
from dataset_processing import *
        
    
def _pre_encode_features(config, data):
    """Return ``data`` augmented with autoencoder pre-encodings, building them if needed.

    The augmented dataset is cached per test split as
    ``<data_path>/pre_encodings_aug/<i_test>.npy``. When that file is missing,
    an ``Autoencoder`` is trained on ``data``. Its ``planar_encoder``,
    ``tran_dec`` and training history are saved next to the cache. Then
    ``features`` / ``features_dec`` are added to ``data.X_train`` and
    ``features`` to ``data.X_test``. Finally the whole ``data`` object is
    written with ``np.save`` and the Keras session is cleared.

    Args:
        config: run configuration dict; reads ``data_path`` and ``i_test``
            (and whatever ``Autoencoder`` / ``get_features`` read).
        data: dataset object for the current split (as returned by ``LoadData``).

    Returns:
        The object loaded back from the cache file.
    """
    save_dir = os.path.join(config['data_path'], 'pre_encodings_aug')
    os.makedirs(save_dir, exist_ok=True)
    split = str(config['i_test'])
    load_file = os.path.join(save_dir, split) + '.npy'

    if not os.path.exists(load_file):

        print('Training autoencoder')

        autoencoder = Autoencoder(config)
        h, _, _ = autoencoder.train(config, data)
        h = h.history

        autoencoder.planar_encoder.save(os.path.join(save_dir, 'enc_' + split + '.h5'))
        autoencoder.tran_dec.save(os.path.join(save_dir, 'dec_' + split + '.h5'))
        np.save(os.path.join(save_dir, 'h_' + split), h)

        print('Doing pre-encoding')

        # Sanity checks: one feature entry per encoder sample.
        data.X_train['features'] = get_features(autoencoder, data.X_train, 'obs_traj', config)
        assert len(data.X_train['features']) == len(data.X_train['encoder'])

        data.X_train['features_dec'] = get_features(autoencoder, data.X_train, 'decoder_traj', config)
        assert len(data.X_train['features_dec']) == len(data.X_train['encoder'])

        data.X_test['features'] = get_features(autoencoder, data.X_test, 'obs_traj', config)
        assert len(data.X_test['features']) == len(data.X_test['encoder'])

        # np.save wraps the non-array object in a 0-d object array (pickled).
        np.save(load_file, data)

        tf.keras.backend.clear_session()

    # NOTE(review): allow_pickle=True unpickles arbitrary objects. This is only
    # acceptable because the cache is produced by this script; never point
    # data_path at an untrusted directory.
    data = np.load(load_file, allow_pickle=True).item()

    print('Pre-encoding loaded!')
    return data


def _features_shape(config, data):
    """Return the model's ``features_shape`` for ``config``, or ``None`` if undetermined.

    - ``reduce`` in config: last dimension of the pre-computed train features.
    - ``num_filters`` in config: the same when ``ae_pre_enc`` is truthy,
      otherwise a ``(fLGrid, fLGrid, n_classes)`` grid.
    - Neither key present: ``None`` (the caller leaves ``features_shape`` unset).

    ``ae_pre_enc`` is treated as optional, consistent with how the pre-encoding
    step checks it.
    """
    if 'reduce' in config:
        return (None, data.X_train['features'].shape[-1])
    if 'num_filters' in config:
        if config.get('ae_pre_enc', False):
            return (None, data.X_train['features'].shape[-1])
        return (None, config['fLGrid'], config['fLGrid'], config['n_classes'])
    return None


if __name__ == '__main__':

    # Train and evaluate one model per (test split, run id) combination.
    for i_test, run_id in itertools.product(config['test_list'], config['id_list']):

        # Start each combination from a clean Keras/TF state.
        tf.keras.backend.clear_session()

        # NOTE: the module-level `config` dict is updated in place, so keys set
        # during one iteration persist into later ones and into the dumped YAML.
        config.update({
            'i_test': i_test,
            'run_id': run_id,
        })
        data = LoadData(config)

        if config['use_features'] and 'ae_pre_enc' in config:
            config['get_semantic_batch_f'] = get_semantic_batch_f
            config['AutoencoderClass'] = Autoencoder

        if config['use_features'] and config.get('ae_pre_enc', False):
            data = _pre_encode_features(config, data)

        if config['use_features']:
            features_shape = _features_shape(config, data)
            if features_shape is not None:
                config['features_shape'] = features_shape

        vae_m = VAE(config) if config['model'] == 'vae' else SDVAE(config)

        if config.get('transform_data', False):
            vae_m.attach_data_transformer(data.transformer)

        train_gen = Generator(data.X_train, data.Y_train, config)
        callbacks = get_callbacks(config, vae_m, data)
        optimizer = get_optimizer(config)
        vae_m.compile_model(config, optimizer)

        r = vae_m.train_model.fit(train_gen,
                                  epochs = config['epochs'],
                                  callbacks = callbacks,
                                  use_multiprocessing = False,
                                  workers = 1,
                                  verbose = 1
                                  )

        # Persist history, reload the prediction weights, and save test predictions.
        save_history(config['r_path'], r, callbacks, vae_m.name_model)
        vae_m.load_prediction_weights(config['r_path'])
        pred_t = vae_m.decode_sequences(data.X_test)
        f_name = os.path.join(config['r_path'], vae_m.name_model) + '_preds'
        np.save(f_name, pred_t)
        print('Saved in', f_name)

        ades, fdes = get_metrics(pred_t, data.Y_test)
        print('ade', np.mean(ades))
        print('fde', np.mean(fdes))
        print()

        # Record the exact configuration used for this run; `with` guarantees
        # the file is closed even if yaml.dump raises.
        f_name = os.path.join(config['r_path'], vae_m.name_model) + '_config.yaml'
        with open(f_name, "w") as f:
            yaml.dump(config, f)


    train
        biwi_hotel_train.txt
        877
        crowds_zara01_train.txt
        1976
        crowds_zara02_train.txt
        4477
        uni_examples_train.txt
        538
        students001_train.txt
        11691
        crowds_zara03_train.txt
        1760
        students003_train.txt
        8988
    30307
    val
        students001_val.txt
        1887
        students003_val.txt
        834
        uni_examples_val.txt
        79
        crowds_zara03_val.txt
        708
        crowds_zara01_val.txt
        337
        crowds_zara02_val.txt
        1259
        biwi_hotel_val.txt
        318
    5422
    test
        biwi_eth.txt
        364
    364
SAVING MODEL: sdvae_eth_0
Epoch 1/50

ADE 0.40262356
FDE 0.75712895
BEST: None

MODEL sdvae_eth_0 SAVED!


ADE 0.7089329
FDE 1.3336841
BEST: None

Epoch 2/50
 46/236 [====>.........................] - ETA: 14s - loss: 0.0175

KeyboardInterrupt: 