In [1]:
import sys

# Make the repository root importable so the `src.*` imports in the next cell resolve.
# '../' is relative to the kernel's working directory. The membership check keeps this
# cell idempotent: re-running it does not keep appending duplicate entries to sys.path.
if '../' not in sys.path:
    sys.path.append('../')

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch
import warnings
from scipy.optimize import minimize
from torch.utils.data import DataLoader

from src.model import SindyModel
from src.main import get_args
from src.dataloader import ScrewdrivingDataset

In [3]:
%matplotlib inline
plt.ioff()

<contextlib.ExitStack at 0x724a40e3fc40>

In [4]:
# Project arguments; their attributes are passed to both SindyModel and ScrewdrivingDataset.
# NOTE(review): if get_args() calls argparse's parse_args(), Jupyter's own argv
# (e.g. `-f kernel.json`) may need to be ignored — confirm get_args handles this.
ARGS = get_args()

# Every artefact used here lives under ../out, relative to the kernel's working directory.
# Using '..' (not '../') for every root keeps the joined paths consistent across platforms.
OUT_ROOT = os.path.join('..', 'out')
signal_data_dir = os.path.join(OUT_ROOT, 'sindy-data')                  # dataset input directory
singal_data_dir = signal_data_dir  # backward-compatible alias for the original misspelled name
model_dir = os.path.join(OUT_ROOT, 'sindy-model-out', 'checkpoints')    # model checkpoint files
data_dir = os.path.join(OUT_ROOT, 'sindy-out-processed')               # processed evaluation results

eval_file = 'cumulative_results.csv'

In [5]:
# Per-model evaluation results; `model_name` is later used as the checkpoint filename.
df = pd.read_csv(os.path.join(data_dir, eval_file))

# Fail fast if the results file lacks a column that the cells below rely on.
_required_columns = {'model_name', 'l2_error',
                     'mean_adjustment_vx', 'time_shift_vx',
                     'mean_adjustment_vy', 'time_shift_vy'}
_missing_columns = _required_columns - set(df.columns)
assert not _missing_columns, f"{eval_file} is missing columns: {sorted(_missing_columns)}"

df.describe()

Unnamed: 0,vx_error,mean_adjustment_vx,time_shift_vx,vy_error,mean_adjustment_vy,time_shift_vy,l2_error
count,1441.0,1441.0,1441.0,1441.0,1441.0,1441.0,1441.0
mean,189.699371,-0.044134,0.524895,192.597056,-0.181346,0.375378,328.652159
std,269.732576,0.79406,2.144014,277.873437,0.780429,2.513176,339.134615
min,0.011174,-2.603501,-4.481685,0.031567,-2.587219,-5.695787,0.269293
25%,21.337781,-0.563363,-0.745945,17.048379,-0.717107,-1.49472,92.898624
50%,87.625094,-0.043391,0.417426,88.15903,-0.139241,0.036898,230.645305
75%,245.639275,0.49305,1.758729,261.133623,0.331338,2.692135,448.316043
max,2405.623606,2.830252,5.361715,2703.950757,3.001654,5.535521,2708.324862


In [6]:
# Keep only the best-performing models: those whose L2 error is at or below the
# 5th percentile (quantile 0.05) of all evaluated models.
L2_ERROR_QUANTILE = 0.05
l2_error_threshold = df['l2_error'].quantile(L2_ERROR_QUANTILE)
# Backward-compatible alias. Despite its name this is the 5th percentile, not the 10th.
tenth_percentile = l2_error_threshold
filtered_df = df[df['l2_error'] <= l2_error_threshold]
filtered_df.describe()

Unnamed: 0,vx_error,mean_adjustment_vx,time_shift_vx,vy_error,mean_adjustment_vy,time_shift_vy,l2_error
count,73.0,73.0,73.0,73.0,73.0,73.0,73.0
mean,4.267601,-0.032983,0.813607,3.642892,0.018577,0.537357,6.663031
std,3.979119,0.114249,2.049457,3.42034,0.107478,2.547642,3.799983
min,0.011174,-0.210704,-3.999868,0.060308,-0.208007,-3.999916,0.269293
25%,0.691643,-0.130772,0.025,0.729373,-0.060382,-1.316123,3.909172
50%,3.410356,-0.043391,0.619876,2.625734,0.032275,0.009498,6.439906
75%,7.595191,0.046233,1.92174,5.658049,0.10426,3.182795,9.61176
max,13.41633,0.210375,5.001482,13.284853,0.191632,4.999972,13.670859


In [7]:
# Column labels of the filtered results (DataFrame.keys() returns the column Index).
filtered_df.keys()

Index(['model_name', 'vx_error', 'mean_adjustment_vx', 'time_shift_vx',
       'vy_error', 'mean_adjustment_vy', 'time_shift_vy', 'l2_error'],
      dtype='object')

In [8]:
filtered_models = filtered_df['model_name']

model_final_results = []

# Build the test split once and read inputs and targets from the same pass.
# Previously the dataset was constructed and iterated twice: that doubled the loading
# cost, and test_x / test_y were not guaranteed to come from the same dataset instance.
test_loader = DataLoader(ScrewdrivingDataset(mode='test', **{**vars(ARGS), 'data_dir': singal_data_dir}))
test_batches = list(test_loader)

test_x = torch.cat([batch[0] for batch in test_batches], dim=0)

test_y = np.concatenate([batch[1].numpy() for batch in test_batches], axis=0)
# Collapse to (n_samples, n_outputs); only valid when any middle dimensions are singleton.
test_y = test_y.reshape(test_y.shape[0], test_y.shape[-1])

# The raw batch list is no longer needed once concatenated.
del test_batches

# Presumably the starting [mean_adjustment, time_shift] for fitting signal_error
# with scipy.optimize.minimize — confirm against the cells that use it.
initial_guess = [0., 0.]

In [9]:
def signal_error(params, reference_signal, target_signal):
    """Sum of squared residuals between a reference and an adjusted target signal.

    The ``(params, *args)`` signature matches the objective-function convention of
    ``scipy.optimize.minimize``.

    Args:
        params: Sequence ``(mean_adjustment, time_shift)`` forwarded to
            :func:`adjust_signal`.
        reference_signal: 1-D array the adjusted target is compared against.
        target_signal: 1-D array to shift and offset; may be complex.

    Returns:
        Real scalar ``sum(|reference - adjusted| ** 2)``. ``np.abs`` keeps the result
        real and non-negative for complex signals. For real inputs it equals the
        plain sum of squares.
    """
    residual = reference_signal - adjust_signal(target_signal, *params)
    return np.sum(np.abs(residual) ** 2)


def adjust_signal(target_signal, mean_adjustment, time_shift):
    """Shift a signal in time by linear interpolation, then add a constant offset.

    Output sample ``i`` is the target evaluated at fractional position
    ``i + time_shift``, so a positive shift moves features earlier. Positions
    outside the signal take its first/last value (``np.interp`` default clamping).

    Args:
        target_signal: 1-D array of samples on a unit-spaced grid (real or complex).
        mean_adjustment: Constant added after shifting.
        time_shift: Shift in samples; may be fractional.

    Returns:
        Array with the same length as ``target_signal``.
    """
    positions = np.arange(len(target_signal))
    shifted = np.interp(positions, positions - time_shift, target_signal)
    return shifted + mean_adjustment

In [10]:
# Re-evaluate each selected model on the test set, applying the mean adjustment and
# time shift recorded for it in the results CSV.
for index, row in filtered_df.iterrows():
    # Suppress RuntimeWarnings for this model's whole evaluation.
    # NOTE(review): which operation emits them is not visible here — confirm before
    # silencing, since this can also hide NaN/overflow problems in the predictions.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning)
        
        # Inference only: no autograd graph is built, so .numpy() can be called on the output.
        with torch.no_grad():
            # Fresh model per checkpoint, built from the same ARGS used for the dataset.
            model = SindyModel(**vars(ARGS))
            # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
            # Without map_location, GPU-saved checkpoints need a GPU; consider map_location='cpu'.
            model.load_state_dict(torch.load(os.path.join(model_dir, row.model_name)))
            model.eval()
    
            pred_y = model(test_x).numpy()
            # Collapse to (n_samples, n_outputs); requires any middle dims to be singleton.
            # NOTE(review): the complex64 cast makes adjust_signal (np.interp) return complex
            # values; the reason for it is not visible here — confirm it is intended.
            pred_y = pred_y.reshape(pred_y.shape[0], pred_y.shape[-1]).astype(np.complex64)
    
            # Column 0 is treated as vx and column 1 as vy, matching the *_vx / *_vy result columns.
            # NOTE(review): np.copy(test_y)[:, k] copies the whole array on every iteration just to
            # read one column; test_y[:, k] (or test_y[:, k].copy()) would suffice.
            reference_signal_x = np.copy(test_y)[:, 0]
            target_signal_x = pred_y[:, 0]
            adjusted_vx = adjust_signal(target_signal_x, row.mean_adjustment_vx, row.time_shift_vx)
            reference_signal_y = np.copy(test_y)[:, 1]
            target_signal_y = pred_y[:, 1]
            adjusted_vy = adjust_signal(target_signal_y, row.mean_adjustment_vy, row.time_shift_vy)
            