In [1]:
import sys
# Put the parent directory on the module search path; the `src.*` imports in the
# next cell presumably rely on this (assumes the notebook sits one level below
# the directory that contains `src/` — TODO confirm).
sys.path.append('../')

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch
import warnings
from scipy.optimize import minimize
from torch.utils.data import DataLoader

from src.model import SindyModel
from src.main import get_args
from src.dataloader import ScrewdrivingDataset

In [3]:
# Use the inline backend so figures render in the notebook output area.
%matplotlib inline
# Disable matplotlib's interactive mode; figures are drawn only when explicitly
# saved or shown (see plot_signals).
plt.ioff()

<contextlib.ExitStack at 0x714b48a04460>

In [4]:
# Arguments as returned by src.main.get_args(); forwarded to the dataset and model below.
ARGS = get_args()

# Number of samples in each comparison window (default for find_best_overlap).
segment_length = 150

# Name of the evaluation table that is read from data_dir.
eval_file = 'cumulative_results.csv'

# Input locations. NOTE: `singal_data_dir` (sic) is referenced under this exact
# spelling by later cells, so the name is kept as-is.
singal_data_dir = os.path.join('..', 'out', 'sindy-data')
model_dir = os.path.join('..', 'out', 'sindy-model-out', 'checkpoints')
data_dir = os.path.join('../', 'out', 'sindy-out-processed')

# Output location for the plots and the CSV written by this notebook.
out_dir = os.path.join(data_dir, 'segmented')
os.makedirs(out_dir, exist_ok=True)

In [5]:
# Read the evaluation table from data_dir and summarise its numeric columns.
df = pd.read_csv(
    os.path.join(data_dir, eval_file)
)
df.describe()

Unnamed: 0,vx_error,mean_adjustment_vx,time_shift_vx,vy_error,mean_adjustment_vy,time_shift_vy,l2_error
count,1441.0,1441.0,1441.0,1441.0,1441.0,1441.0,1441.0
mean,189.699371,-0.044134,0.524895,192.597056,-0.181346,0.375378,328.652159
std,269.732576,0.79406,2.144014,277.873437,0.780429,2.513176,339.134615
min,0.011174,-2.603501,-4.481685,0.031567,-2.587219,-5.695787,0.269293
25%,21.337781,-0.563363,-0.745945,17.048379,-0.717107,-1.49472,92.898624
50%,87.625094,-0.043391,0.417426,88.15903,-0.139241,0.036898,230.645305
75%,245.639275,0.49305,1.758729,261.133623,0.331338,2.692135,448.316043
max,2405.623606,2.830252,5.361715,2703.950757,3.001654,5.535521,2708.324862


In [6]:
# Keep only the best models: rows whose l2_error is at or below the given
# quantile of the l2_error column, ordered from best to worst.
L2_ERROR_QUANTILE = 0.05
l2_error_cutoff = df['l2_error'].quantile(L2_ERROR_QUANTILE)

# Backward-compatible alias. The original name said "tenth" although the
# quantile has always been 0.05 (5th percentile); prefer `l2_error_cutoff`.
tenth_percentile = l2_error_cutoff

filtered_df = (
    df[df['l2_error'] <= l2_error_cutoff]
    .sort_values(by=['l2_error'])
    .reset_index(drop=True)
)
filtered_df.head()

Unnamed: 0,model_name,vx_error,mean_adjustment_vx,time_shift_vx,vy_error,mean_adjustment_vy,time_shift_vy,l2_error
0,model_2024-05-12 08:04:03.127533.pth,0.085856,0.013773,-1.00002,0.25524,-0.027909,0.440825,0.269293
1,model_2024-05-11 20:30:03.833141.pth,0.209353,0.017972,0.266353,0.355279,0.032275,-1.15179,0.412374
2,model_2024-05-10 00:09:34.902008.pth,0.668582,-0.043107,3.521961,0.198171,0.001038,-2.751343,0.697333
3,model_2024-05-10 12:11:15.448987.pth,0.691643,0.046233,2.0,0.226919,0.026162,-0.041026,0.727917
4,model_2024-05-11 07:58:05.251983.pth,0.037243,0.008065,-0.52799,0.729373,0.047836,-0.464822,0.730323


In [7]:
# Summary statistics of the retained (lowest-l2_error) models.
filtered_df.describe()

Unnamed: 0,vx_error,mean_adjustment_vx,time_shift_vx,vy_error,mean_adjustment_vy,time_shift_vy,l2_error
count,73.0,73.0,73.0,73.0,73.0,73.0,73.0
mean,4.267601,-0.032983,0.813607,3.642892,0.018577,0.537357,6.663031
std,3.979119,0.114249,2.049457,3.42034,0.107478,2.547642,3.799983
min,0.011174,-0.210704,-3.999868,0.060308,-0.208007,-3.999916,0.269293
25%,0.691643,-0.130772,0.025,0.729373,-0.060382,-1.316123,3.909172
50%,3.410356,-0.043391,0.619876,2.625734,0.032275,0.009498,6.439906
75%,7.595191,0.046233,1.92174,5.658049,0.10426,3.182795,9.61176
max,13.41633,0.210375,5.001482,13.284853,0.191632,4.999972,13.670859


In [8]:
# Column names available on each row of filtered_df.
filtered_df.columns

Index(['model_name', 'vx_error', 'mean_adjustment_vx', 'time_shift_vx',
       'vy_error', 'mean_adjustment_vy', 'time_shift_vy', 'l2_error'],
      dtype='object')

In [9]:
# Names kept from the original cell; nothing in this cell reads them.
filtered_models = filtered_df['model_name']
model_final_results = []
initial_guess = [0., 0.]

# Build the test dataset/loader once and materialise its batches a single time,
# so inputs and targets are guaranteed to come from the same pass over the data
# (previously the dataset was constructed and iterated twice).
test_dataset = ScrewdrivingDataset(mode='test', **{**vars(ARGS), 'data_dir': singal_data_dir})
test_loader = DataLoader(test_dataset)
test_batches = list(test_loader)

# batch[0] holds the model inputs, batch[1] the targets.
test_x = torch.cat([batch[0] for batch in test_batches], dim=0)

test_y = np.concatenate([batch[1].numpy() for batch in test_batches], axis=0)
# Collapse any intermediate axes: keep the first (sample) and last axes only.
test_y = test_y.reshape(test_y.shape[0], test_y.shape[-1])

In [10]:
def signal_error(params, reference_signals, target_signals):
    """Sum of squared differences between reference and adjusted target signals.

    Args:
        params: ``(mean_adjustment, time_shift)`` unpacked into ``adjust_signal``.
        reference_signals: Pair of signals indexed as ``[0]`` and ``[1]``.
        target_signals: Pair of signals that are adjusted before comparison.

    Returns:
        The squared-error sum of the first pair plus that of the second pair.
    """
    adjusted_first, adjusted_second = adjust_signal(target_signals, *params)
    residual_first = reference_signals[0] - adjusted_first
    residual_second = reference_signals[1] - adjusted_second
    return np.sum(residual_first ** 2) + np.sum(residual_second ** 2)

def adjust_signal(target_signals, mean_adjustment, time_shift):
    """Apply the same offset and time shift to both signals of a pair.

    Args:
        target_signals: Sequence whose items ``[0]`` and ``[1]`` are the signals.
        mean_adjustment: Constant added to every sample.
        time_shift: Shift passed through to ``adjust_signal_single``.

    Returns:
        A two-element list with the adjusted first and second signal.
    """
    first_signal, second_signal = target_signals[0], target_signals[1]
    return [
        adjust_signal_single(component, mean_adjustment, time_shift)
        for component in (first_signal, second_signal)
    ]
    
def adjust_signal_single(target_signal, mean_adjustment, time_shift):
    """Shift one signal along the sample axis and add a constant offset.

    The signal is treated as sampled at positions ``0..len-1``; those positions
    are moved by ``-time_shift`` and the signal is linearly re-interpolated back
    onto the original grid. Positions that fall outside the shifted range take
    the nearest end value (``np.interp`` default).

    Args:
        target_signal: 1-D sequence of samples.
        mean_adjustment: Constant added to every re-sampled value.
        time_shift: Shift in samples (may be fractional or negative).

    Returns:
        Array of the same length as ``target_signal``.
    """
    sample_idx = np.arange(len(target_signal))
    resampled = np.interp(sample_idx, sample_idx - time_shift, target_signal)
    return resampled + mean_adjustment

In [11]:
def find_best_overlap(reference_signals, target_signals, segment_length=segment_length, max_reference_offsets=1):
    """Find the pair of windows (reference, target) that match best.

    For each candidate reference window start ``i`` and each target window start
    ``j``, a bounded ``scipy.optimize.minimize`` fit of
    ``(mean_adjustment, time_shift)`` is run on ``signal_error`` and the
    combination with the smallest error is kept.

    Args:
        reference_signals: Pair of 1-D signals (``[0]`` and ``[1]``).
        target_signals: Pair of 1-D signals compared against the reference;
            window starts are derived from the reference length.
        segment_length: Number of samples per window.
        max_reference_offsets: How many reference start offsets (starting at 0)
            to scan. The default ``1`` pins the reference window to the start of
            the signal, which is what the recorded results were produced with.
            Pass ``None`` to scan every possible offset; note that this runs one
            optimisation per (i, j) pair and can be very slow.

    Returns:
        ``(best_segments, best_params, min_error)`` where ``best_segments`` is
        ``(i, j)``, ``best_params`` is the optimiser solution and ``min_error``
        its objective value. If no window fits (signal shorter than
        ``segment_length``) the result is ``((None, None), None, inf)``.
    """
    n = len(reference_signals[0])
    n_offsets = max(n - segment_length + 1, 0)
    if max_reference_offsets is None:
        n_reference_offsets = n_offsets
    else:
        n_reference_offsets = min(max_reference_offsets, n_offsets)

    # Loop-invariant optimiser setup. The shift bound is symmetric; the former
    # `-segment_length // 2` floored towards -inf and was off by one for odd lengths.
    half_window = segment_length // 2
    bounds = [(-np.inf, np.inf), (-half_window, half_window)]
    initial_params = [0, 0]

    min_error = float('inf')
    best_params = None
    best_segments = (None, None)

    for i in range(n_reference_offsets):
        reference_segments = [
            reference_signals[0][i:i + segment_length],
            reference_signals[1][i:i + segment_length],
        ]

        for j in range(n_offsets):
            target_segments = [
                target_signals[0][j:j + segment_length],
                target_signals[1][j:j + segment_length],
            ]

            result = minimize(signal_error, initial_params, args=(reference_segments, target_segments), bounds=bounds)

            if result.fun < min_error:
                min_error = result.fun
                best_params = result.x
                best_segments = (i, j)

    return best_segments, best_params, min_error

In [12]:
def plot_signals(
    adjusted_signals,
    reference_signals,
    model_name,
    show_plot=False,
    title="Signal Comparison"
):
    """Plot adjusted vs. reference signals side by side and save the figure.

    The left panel shows the first (Vx) pair, the right panel the second (Vy)
    pair. The figure is written to ``<out_dir>/<model_name>.png`` and is always
    closed afterwards so that calling this in a loop does not accumulate open
    figures.

    Args:
        adjusted_signals: Pair of arrays; ``[0]`` is plotted as Vx, ``[1]`` as Vy.
        reference_signals: Pair of arrays plotted on top of the adjusted ones.
        model_name: Used as the stem of the output file name.
        show_plot: If true, also display the figure via ``plt.show()``.
        title: Figure-level title.
    """
    with warnings.catch_warnings():
        # RuntimeWarnings raised while plotting are deliberately silenced here.
        warnings.filterwarnings("ignore", category=RuntimeWarning)

        t = np.arange(adjusted_signals[0].shape[0])

        fig, axs = plt.subplots(1, 2, figsize=(15, 10))
        try:
            for k, component in enumerate(('Vx', 'Vy')):
                ax = axs[k]
                # Adjusted first, reference second: keeps the colour assignment
                # of the default property cycle unchanged.
                ax.plot(t, adjusted_signals[k], label=f'Adjusted Signal {component}', dashes=[6, 2])
                ax.plot(t, reference_signals[k], label=f'Reference Signal {component}')
                ax.set_xlabel('Time')
                ax.set_ylabel('Amplitude')
                ax.legend()
                ax.grid(True)

            fig.suptitle(title)
            fig.tight_layout()

            fig.savefig(os.path.join(out_dir, f'{model_name}.png'))

            if show_plot:
                plt.show()
        finally:
            # Release the figure; without this every call leaks one open figure.
            plt.close(fig)


In [13]:
# Accumulates one result row per model; filled by the loop in the next cell and
# written to CSV afterwards.
data_store = []

In [14]:
# Start from a clean slate so re-running this cell does not append duplicate rows.
data_store.clear()

# The reference components are the same for every model; extract them once
# instead of copying test_y twice per iteration.
reference_signal_x = np.copy(test_y)[:, 0]
reference_signal_y = np.copy(test_y)[:, 1]

for idx, row in filtered_df.iterrows():
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning)

        with torch.no_grad():
            model = SindyModel(**vars(ARGS))
            # Load onto the CPU: the prediction below is converted with .numpy(),
            # which only works for CPU tensors, so checkpoints saved from another
            # device must be remapped.
            state_dict = torch.load(os.path.join(model_dir, row.model_name), map_location='cpu')
            model.load_state_dict(state_dict)
            model.eval()

            pred_y = model(test_x).numpy()
            # NOTE(review): the complex64 cast makes every downstream error value
            # complex (e.g. "(0.0187+0j)" in the log and CSV) — confirm whether it
            # is still needed.
            pred_y = pred_y.reshape(pred_y.shape[0], pred_y.shape[-1]).astype(np.complex64)

            # Apply the per-component alignment stored in the evaluation table.
            target_signal_x = pred_y[:, 0]
            adjusted_vx = adjust_signal_single(target_signal_x, row.mean_adjustment_vx, row.time_shift_vx)

            target_signal_y = pred_y[:, 1]
            adjusted_vy = adjust_signal_single(target_signal_y, row.mean_adjustment_vy, row.time_shift_vy)

            best_segments, best_params, min_error = find_best_overlap([reference_signal_x, reference_signal_y], [adjusted_vx, adjusted_vy])

            print(f'#: {idx}, model_name: {row.model_name}, best_segments: {best_segments}, min_error: {min_error}, best_params: {best_params}')
            data_store.append([row.model_name, *best_segments, segment_length, min_error, *best_params])

            # Re-apply the fitted joint adjustment to the full signals, then cut
            # out the winning windows for plotting.
            best_segmented_vx, best_segmented_vy = adjust_signal([adjusted_vx, adjusted_vy], *best_params)

            ref_start, target_start = best_segments
            plot_signals(
                [
                    best_segmented_vx[target_start:target_start + segment_length],
                    best_segmented_vy[target_start:target_start + segment_length],
                ],
                [
                    reference_signal_x[ref_start:ref_start + segment_length],
                    reference_signal_y[ref_start:ref_start + segment_length],
                ],
                row.model_name
            )

            
            

#: 0, model_name: model_2024-05-12 08:04:03.127533.pth, best_segments: (0, 0), min_error: (0.018748543289825087+0j), best_params: [-7.05087709e-04  6.28141441e-09]
#: 1, model_name: model_2024-05-11 20:30:03.833141.pth, best_segments: (0, 0), min_error: (0.08030003518868929+0j), best_params: [0.00037435 0.03471394]
#: 2, model_name: model_2024-05-10 00:09:34.902008.pth, best_segments: (0, 0), min_error: (0.034178148003067864+0j), best_params: [-0.0003111   0.09750595]
#: 3, model_name: model_2024-05-10 12:11:15.448987.pth, best_segments: (0, 0), min_error: (0.02596881854052996+0j), best_params: [ 1.04642361e-03 -1.00820857e-09]
#: 4, model_name: model_2024-05-11 07:58:05.251983.pth, best_segments: (0, 0), min_error: (0.031275314386600875+0j), best_params: [0.00016013 0.06504032]
#: 5, model_name: model_2024-05-10 05:42:01.116412.pth, best_segments: (0, 0), min_error: (0.022407772576094012+0j), best_params: [0.00011739 0.06834951]
#: 6, model_name: model_2024-05-11 09:49:47.644479.pth, 

In [15]:
# Dump the collected rows as comma-separated text (no header row yet; every
# value is written via its string representation).
out_file_path = os.path.join(out_dir, 'best_segmented_results.csv')
np.savetxt(
    out_file_path,
    data_store,
    delimiter=',',
    fmt='%s',
)

In [16]:
# Column names for the rows written by the previous cell, in append order.
headers = ['model_name', 'start_ref', 'start_target', 'segment_length', 'min_error', 'mean_adjustment', 'time_shift']
header_line = ','.join(headers) + '\n'

with open(out_file_path, 'r+') as fh:
    content = fh.read()
    # Prepend the header only when it is missing, so re-running this cell does
    # not stack additional header rows on top of the file.
    if not content.startswith(header_line):
        fh.seek(0, 0)
        fh.write(header_line + content)
