In [1]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

import os
import copy
import csv
import types
import random
import argparse
import math
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.models import load_model
from keras import backend as K
from sklearn.model_selection import GridSearchCV
from keras.wrappers.scikit_learn import KerasRegressor
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras import utils
from sklearn.model_selection import train_test_split
from keras.models import Model, Sequential
from keras.layers.convolutional_recurrent import ConvLSTM2D
from keras.layers.normalization import BatchNormalization
from keras.layers import *
from data import *

Using TensorFlow backend.


In [2]:
def _parse_learning_rate(name, default=0.0001):
    """Extract the learning rate encoded in a run-directory name.

    The value is read starting three characters after the last occurrence of
    ``'lr'`` (i.e. skipping ``'lr'`` plus one separator character) and ending
    at the next ``'|'`` or at the end of the name. ``default`` is returned when
    the name has no ``'lr'`` token or the extracted text is not a number.
    """
    if 'lr' not in name:
        return default
    start = name.rindex('lr') + 3
    # Stop at the first '|' AFTER the value rather than the last '|' in the
    # whole name: the latter yields an empty/garbage slice when 'lr' is the
    # final field or when more than one field follows it.
    end = name.find('|', start)
    if end == -1:
        end = len(name)
    try:
        return float(name[start:end])
    except ValueError:
        # The optimizer is only attached so the model can be compiled for
        # evaluation, so a name that merely contains 'lr' should not abort.
        return default


def _evaluate_saved_model(model_path, label, learning_rate, x_test, y_test):
    """Load one saved model, compile it and return its test loss/MAE record."""
    # Local import: the visible import cell never binds the bare name `keras`,
    # so avoid depending on it arriving through a star import.
    from keras.optimizers import Adam

    model = load_model(model_path)
    model.compile(loss='mean_squared_error', optimizer=Adam(lr=learning_rate), metrics=['mae'])
    test_gen = val_generator(x_test, y_test)
    # With metrics=['mae'] the result is indexed as [loss, mae] below.
    loss = model.evaluate_generator(test_gen, steps=len(test_gen))
    return {'file': label, 'loss': loss[0], 'mae': loss[1]}


def get_model_loss(model_base_path, x_test, y_test):
    """Evaluate the saved model(s) found under ``model_base_path``.

    If ``model_base_path`` directly contains ``best_val_loss_model.h5`` only
    that model is evaluated. Otherwise every entry of the directory whose name
    does not contain ``'csv'`` is treated as a run directory holding its own
    ``best_val_loss_model.h5``; entries without that file are skipped with a
    message.

    Args:
        model_base_path: Directory holding one model file or several run
            sub-directories.
        x_test: Test inputs, passed unchanged to ``val_generator``.
        y_test: Test targets, passed unchanged to ``val_generator``.

    Returns:
        A list of dicts with keys ``'file'`` (run-directory name), ``'loss'``
        and ``'mae'``, one per evaluated model. Run directories are visited in
        sorted order so the result is reproducible.
    """
    model_filename = 'best_val_loss_model.h5'
    default_lr = 0.0001

    loss_list = []
    file_list = os.listdir(model_base_path)
    if model_filename in file_list:
        print('Testing 1 model')
        model_path = os.path.join(model_base_path, model_filename)
        # The original referenced an undefined loop variable here; label the
        # record with the run-directory name, as the multi-model branch does.
        label = os.path.basename(os.path.normpath(model_base_path))
        loss_list.append(
            _evaluate_saved_model(model_path, label, default_lr, x_test, y_test))
    else:
        for entry in sorted(file_list):
            if 'csv' in entry:
                continue
            model_path = os.path.join(model_base_path, entry, model_filename)
            if not os.path.isfile(model_path):
                print('Skipping {}: no {} found'.format(entry, model_filename))
                continue
            learning_rate = _parse_learning_rate(entry, default_lr)
            loss_list.append(
                _evaluate_saved_model(model_path, entry, learning_rate, x_test, y_test))
    return loss_list

In [3]:
def get_model_prediction(model_dir, x_test, y_test):
    """Return the predictions of the model saved in ``model_dir`` on the test data.

    Loads ``best_val_loss_model.h5`` from ``model_dir``, compiles it with MSE
    loss, Adam (lr=0.0001) and an MAE metric, then predicts over the generator
    built by ``val_generator(x_test, y_test)`` for ``len(generator)`` steps.

    Args:
        model_dir: Directory containing ``best_val_loss_model.h5``.
        x_test: Test inputs, handed to ``val_generator``.
        y_test: Test targets, handed to ``val_generator``.

    Returns:
        Whatever ``model.predict_generator`` returns for the test generator.
    """
    checkpoint = os.path.join(model_dir, 'best_val_loss_model.h5')
    net = load_model(checkpoint)

    adam = keras.optimizers.Adam(lr=0.0001)
    net.compile(loss='mean_squared_error', optimizer=adam, metrics=['mae'])

    # One pass over the generator: the step count is the generator's length.
    batches = val_generator(x_test, y_test)
    return net.predict_generator(batches, steps=len(batches))

In [5]:
# Notes carried over from the hyper-parameter search:
#   For stft count_dense1 = 128
#   Best: -1.2366702235959561 using {'count_dense1': 128, 'count_dense2': 16, 'count_hidden1': 64, 'dropout': 0.2, 'epochs': 50, 'hidden1': 32, 'init': 'normal', 'lr': 0.0001, 'optimizer': 'adam'}

# Run configuration: which representation to evaluate and where things live.
data_type = 'time'
data_base_path = '/scratch/sk7898/pedbike/window_256'
best_model = 'count_dense1=64|count_dense2=16|count_hidden1=32|hidden1=64'

# Reject unknown representations up front.
if data_type not in ('stft', 'time'):
    raise ValueError('Data type not supported!')

# Each representation has its own model directory and downstream data folder.
model_base_path = (
    '/scratch/sk7898/radar_counting/models/lstm_stft'
    if data_type == 'stft'
    else '/scratch/sk7898/radar_counting/models/lstm_time'
)
fileloc = os.path.join(
    data_base_path,
    'downstream_stft' if data_type == 'stft' else 'downstream_time',
)

# Directory of the selected run, then the train/val/test splits from get_data.
model_path = os.path.join(model_base_path, best_model)
(x_train, x_val, x_test,
 y_train, y_val, y_test,
 seqs_train, seqs_val, seqs_test) = get_data(fileloc)

In [None]:
# Evaluate the selected run on the test split. get_model_loss returns a list
# of {'file', 'loss', 'mae'} records, printed here as-is.
loss_list = get_model_loss(model_path, x_test, y_test)
print(loss_list)

In [6]:
# Predict on the test split with the selected run and print the raw array.
# NOTE(review): the output saved with this notebook is almost constant
# (about 2.18 for every sample) — worth checking whether the model has
# collapsed to a single value before trusting the loss figures.
predictions = get_model_prediction(model_path, x_test, y_test)
print(predictions)

[[2.1938884]
 [2.1851707]
 [2.1765296]
 [2.1703002]
 [2.187777 ]
 [2.1821437]
 [2.1786397]
 [2.1838152]
 [2.1825333]
 [2.1833713]
 [2.1812406]
 [2.1847894]
 [2.1796305]
 [2.189432 ]
 [2.188066 ]
 [2.178646 ]
 [2.178604 ]
 [2.1855125]
 [2.1791244]
 [2.1841063]
 [2.1897428]
 [2.1801343]
 [2.1814184]
 [2.1896682]
 [2.1903288]
 [2.178945 ]
 [2.1851506]
 [2.1795402]
 [2.181309 ]
 [2.1784322]
 [2.1903834]
 [2.1760602]
 [2.1791923]
 [2.1790123]
 [2.1821125]
 [2.1796112]
 [2.1809201]
 [2.1789713]
 [2.1812344]
 [2.193906 ]
 [2.1776106]
 [2.1792014]
 [2.1825385]
 [2.177925 ]
 [2.1838446]
 [2.178438 ]
 [2.1808608]
 [2.179047 ]
 [2.1927092]
 [2.178223 ]
 [2.17806  ]
 [2.1776273]
 [2.1791666]
 [2.1819766]
 [2.176402 ]
 [2.1791189]
 [2.1846566]
 [2.1756952]
 [2.179805 ]
 [2.1892962]
 [2.1928377]
 [2.1779919]
 [2.1802113]
 [2.1864772]
 [2.183382 ]
 [2.179324 ]
 [2.1788485]
 [2.180514 ]
 [2.1794214]
 [2.1813946]
 [2.1914454]
 [2.1896718]
 [2.1806152]
 [2.178083 ]
 [2.178766 ]
 [2.177303 ]
 [2.1822422]