In [17]:
import sys
sys.path.append('../')

In [18]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch
import warnings
from scipy.optimize import minimize
from torch.utils.data import DataLoader

from src.model import SindyModel
from src.main import get_args
from src.dataloader import ScrewdrivingDataset

In [19]:
%matplotlib inline
plt.ioff()

<contextlib.ExitStack at 0x76a7be25f1f0>

In [20]:
ARGS = get_args()

singal_data_dir = os.path.join('..', 'out', 'sindy-data')
model_dir = os.path.join('..', 'out', 'sindy-model-out', 'checkpoints')
data_dir = os.path.join('../', 'out', 'sindy-out-processed')
out_dir = os.path.join('../', 'out', 'sindy-out-processed', 'segmented')

os.makedirs(out_dir, exist_ok=True)
eval_file = 'cumulative_results.csv'

segment_length = 150

In [21]:
df = pd.read_csv(os.path.join(data_dir, eval_file))
df.describe()

Unnamed: 0,vx_error,mean_adjustment_vx,time_shift_vx,vy_error,mean_adjustment_vy,time_shift_vy,l2_error
count,1441.0,1441.0,1441.0,1441.0,1441.0,1441.0,1441.0
mean,189.699371,-0.044134,0.524895,192.597056,-0.181346,0.375378,328.652159
std,269.732576,0.79406,2.144014,277.873437,0.780429,2.513176,339.134615
min,0.011174,-2.603501,-4.481685,0.031567,-2.587219,-5.695787,0.269293
25%,21.337781,-0.563363,-0.745945,17.048379,-0.717107,-1.49472,92.898624
50%,87.625094,-0.043391,0.417426,88.15903,-0.139241,0.036898,230.645305
75%,245.639275,0.49305,1.758729,261.133623,0.331338,2.692135,448.316043
max,2405.623606,2.830252,5.361715,2703.950757,3.001654,5.535521,2708.324862


In [22]:
tenth_percentile = df['l2_error'].quantile(0.05)
filtered_df = df[df['l2_error'] <= tenth_percentile]
filtered_df = filtered_df.sort_values(by=['l2_error'])
filtered_df = filtered_df.reset_index(drop=True)
# disblae model filtering
filtered_df = df
filtered_df.head()

Unnamed: 0,model_name,vx_error,mean_adjustment_vx,time_shift_vx,vy_error,mean_adjustment_vy,time_shift_vy,l2_error
0,model_2024-05-09 22:01:42.786493.pth,229.449344,0.874527,1.461312,342.509715,1.068415,-0.287391,412.261939
1,model_2024-05-09 22:12:13.978245.pth,3.825385,0.112007,-3.147911,113.881954,0.615231,-3.690688,113.946185
2,model_2024-05-09 22:14:46.276889.pth,92.930926,-0.556426,0.999976,470.415226,-1.252004,-1.815464,479.506665
3,model_2024-05-09 22:17:16.523394.pth,274.549062,0.956753,1.56664,259.057256,0.929026,-1.175424,377.475627
4,model_2024-05-09 22:19:51.606696.pth,58.717447,-0.442384,2.338684,1.169345,0.061039,-0.999958,58.72909


In [23]:
filtered_df.describe()

Unnamed: 0,vx_error,mean_adjustment_vx,time_shift_vx,vy_error,mean_adjustment_vy,time_shift_vy,l2_error
count,1441.0,1441.0,1441.0,1441.0,1441.0,1441.0,1441.0
mean,189.699371,-0.044134,0.524895,192.597056,-0.181346,0.375378,328.652159
std,269.732576,0.79406,2.144014,277.873437,0.780429,2.513176,339.134615
min,0.011174,-2.603501,-4.481685,0.031567,-2.587219,-5.695787,0.269293
25%,21.337781,-0.563363,-0.745945,17.048379,-0.717107,-1.49472,92.898624
50%,87.625094,-0.043391,0.417426,88.15903,-0.139241,0.036898,230.645305
75%,245.639275,0.49305,1.758729,261.133623,0.331338,2.692135,448.316043
max,2405.623606,2.830252,5.361715,2703.950757,3.001654,5.535521,2708.324862


In [24]:
filtered_df.columns

Index(['model_name', 'vx_error', 'mean_adjustment_vx', 'time_shift_vx',
       'vy_error', 'mean_adjustment_vy', 'time_shift_vy', 'l2_error'],
      dtype='object')

In [25]:
filtered_models = filtered_df['model_name']

model_final_results = []

test_loader = DataLoader(ScrewdrivingDataset(mode='test', **{**vars(ARGS), 'data_dir': singal_data_dir}))
test_x = torch.cat([batch[0] for batch in test_loader], dim=0)

test_y = [batch[1].numpy() for batch in DataLoader(ScrewdrivingDataset(mode='test', **{**vars(ARGS), 'data_dir': singal_data_dir}))]
test_y = np.concatenate(test_y, axis=0)
test_y = test_y.reshape(test_y.shape[0], test_y.shape[-1])

initial_guess = [0., 0.]

In [26]:
def signal_error(params, reference_signals, target_signals):
    adjusted_signals = adjust_signal(target_signals, *params)
    return np.sum((reference_signals[0] - adjusted_signals[0]) ** 2) + np.sum((reference_signals[1] - adjusted_signals[1]) ** 2)

def adjust_signal(target_signals, mean_adjustment, time_shift):
    return [
        adjust_signal_single(target_signals[0], mean_adjustment, time_shift),
        adjust_signal_single(target_signals[1], mean_adjustment, time_shift),
    ]
    
def adjust_signal_single(target_signal, mean_adjustment, time_shift):
    shifted_target_signal = np.interp(np.arange(len(target_signal)), np.arange(len(target_signal)) - time_shift, target_signal)
    return shifted_target_signal + mean_adjustment

In [27]:
def find_best_overlap(reference_signals, target_signals, segment_length=segment_length):
    n = len(reference_signals[0])
    min_error = float('inf')
    best_params = None
    best_segments = (None, None)
    
    for i in range(n - segment_length + 1):
        reference_segment_x = reference_signals[0][i:i + segment_length]
        reference_segment_y = reference_signals[1][i:i + segment_length]
        
        for j in range(n - segment_length + 1):
            target_segment_x = target_signals[0][j:j + segment_length]
            target_segment_y = target_signals[1][j:j + segment_length]
            
            initial_params = [0, 0]
            bounds = [(-np.inf, np.inf), (-segment_length // 2, segment_length // 2)]
            
            result = minimize(signal_error, initial_params, args=([reference_segment_x, reference_segment_y], [target_segment_x, target_segment_y]), bounds=bounds)
            
            if result.fun < min_error:
                min_error = result.fun
                best_params = result.x
                best_segments = (i, j)
    
    return best_segments, best_params, min_error

In [28]:
def plot_signals( 
    adjusted_signals, 
    reference_signals,
    model_name,
    show_plot=False,
    title="Signal Comparison"
):
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning)
        
        t = np.arange(adjusted_signals[0].shape[0])
        
        fig, axs = plt.subplots(1, 2, figsize=(15, 10))
        
        axs[0].plot(t, adjusted_signals[0], label='Adjusted Signal Vx', dashes=[6, 2])
        axs[0].set_ylabel('Amplitude')
        axs[0].legend()
        axs[0].grid(True)
        
        axs[1].plot(t, adjusted_signals[1], label='Adjusted Signal Vy', dashes=[6, 2])
        axs[1].set_ylabel('Amplitude')
        axs[1].legend()
        axs[1].grid(True)
        
        axs[0].plot(t, reference_signals[0], label='Reference Signal Vx')
        axs[0].set_xlabel('Time')
        axs[0].set_ylabel('Amplitude')
        axs[0].legend()
        axs[0].grid(True)
        
        axs[1].plot(t, reference_signals[1], label='Reference Signal Vy')
        axs[1].set_xlabel('Time')
        axs[1].set_ylabel('Amplitude')
        axs[1].legend()
        axs[1].grid(True)
        
        fig.suptitle(title)
        plt.tight_layout()

        plt.savefig(os.path.join(out_dir, f'{model_name}.png'))

        if show_plot:
            plt.show()


In [29]:
data_store = []

In [30]:
for idx, row in filtered_df.iterrows():
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning)
        
        with torch.no_grad():
            model = SindyModel(**vars(ARGS))
            model.load_state_dict(torch.load(os.path.join(model_dir, row.model_name)))
            model.eval()
    
            pred_y = model(test_x).numpy()
            pred_y = pred_y.reshape(pred_y.shape[0], pred_y.shape[-1]).astype(np.complex64)
    
            reference_signal_x = np.copy(test_y)[:, 0]
            target_signal_x = pred_y[:, 0]
            adjusted_vx = adjust_signal_single(target_signal_x, row.mean_adjustment_vx, row.time_shift_vx)
            
            reference_signal_y = np.copy(test_y)[:, 1]
            target_signal_y = pred_y[:, 1]
            adjusted_vy = adjust_signal_single(target_signal_y, row.mean_adjustment_vy, row.time_shift_vy)

            best_segments, best_params, min_error = find_best_overlap([reference_signal_x, reference_signal_y], [adjusted_vx, adjusted_vy])

            print(f'#: {idx}, model_name: {row.model_name}, best_segments: {best_segments}, min_error: {min_error}, best_params: {best_params}')
            data_store.append([row.model_name, *best_segments, segment_length, min_error, *best_params])

            best_segmented_vx, best_segmented_vy = adjust_signal([adjusted_vx, adjusted_vy], *best_params)

            
            
            plot_signals(
                [
                    best_segmented_vx[best_segments[1]:best_segments[1] + segment_length],
                    best_segmented_vy[best_segments[1]:best_segments[1] + segment_length],
                ], 
                [
                    reference_signal_x[best_segments[0]:best_segments[0] + segment_length], 
                    reference_signal_y[best_segments[0]:best_segments[0] + segment_length]
                ], 
                row.model_name
            )

            
            

#: 0, model_name: model_2024-05-09 22:01:42.786493.pth, best_segments: (0, 57), min_error: (17053.79585236489+0j), best_params: [ -5.70030098 -75.        ]
#: 1, model_name: model_2024-05-09 22:12:13.978245.pth, best_segments: (149, 30), min_error: (1344.8437285859957+0j), best_params: [ -0.3574229 -75.       ]
#: 2, model_name: model_2024-05-09 22:14:46.276889.pth, best_segments: (149, 128), min_error: (38241.07259208623+0j), best_params: [12.52183028 75.        ]
#: 3, model_name: model_2024-05-09 22:17:16.523394.pth, best_segments: (149, 71), min_error: (1521.6705966708398+0j), best_params: [ -0.4412071 -75.       ]
#: 4, model_name: model_2024-05-09 22:19:51.606696.pth, best_segments: (149, 24), min_error: (671.1154340581083+0j), best_params: [ -0.68988518 -75.        ]
#: 5, model_name: model_2024-05-09 22:22:31.619788.pth, best_segments: (149, 126), min_error: (89233.80057961527+0j), best_params: [-2.23049433 75.        ]
#: 6, model_name: model_2024-05-09 22:25:01.747806.pth, be

In [31]:
out_file_path = os.path.join(out_dir, 'best_segmented_results.csv')
np.savetxt(out_file_path, data_store, fmt='%s', delimiter=',')

In [32]:
headers = ['model_name', 'start_ref', 'start_target', 'segment_length', 'min_error', 'mean_adjustment', 'time_shift']

with open(out_file_path, 'r+') as fh:
    content = fh.read()
    fh.seek(0, 0)
    fh.write(','.join(headers) + '\n' + content)
