In [17]:
import sys

# Make the project root importable so `src.*` resolves. '../' is relative to the
# current working directory (normally this notebook's folder). Guarded so that
# re-running the cell does not keep appending duplicate entries to sys.path.
if '../' not in sys.path:
    sys.path.append('../')

In [18]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch
import warnings
from scipy.optimize import minimize
from torch.utils.data import DataLoader

from src.model import SindyModel
from src.main import get_args
from src.dataloader import ScrewdrivingDataset

In [19]:
%matplotlib inline
plt.ioff()

<contextlib.ExitStack at 0x7dc88db02590>

In [20]:
# Arguments from src.main.get_args(); reused below to construct SindyModel and
# ScrewdrivingDataset with the same settings.
ARGS = get_args()

# Input/output locations, relative to the notebook's working directory.
signal_data_dir = os.path.join('..', 'out', 'sindy-data')
singal_data_dir = signal_data_dir  # legacy (misspelled) name, still referenced by later cells
model_dir = os.path.join('..', 'out', 'sindy-model-out', 'checkpoints')
data_dir = os.path.join('..', 'out', 'sindy-out-processed')
out_dir = os.path.join(data_dir, 'segmented')

os.makedirs(out_dir, exist_ok=True)
eval_file = 'cumulative_results.csv'

# Window length (in samples) compared by find_best_overlap; also written to the results CSV.
segment_length = 150

In [21]:
# Per-checkpoint evaluation metrics: vx/vy errors, fitted mean adjustments and
# time shifts, and the combined l2_error used for ranking below.
df = pd.read_csv(filepath_or_buffer=os.path.join(data_dir, eval_file))
df.describe()

Unnamed: 0,vx_error,mean_adjustment_vx,time_shift_vx,vy_error,mean_adjustment_vy,time_shift_vy,l2_error
count,12153.0,12153.0,12153.0,12153.0,12153.0,12153.0,12153.0
mean,387.182464,0.015207,11.184881,338.518074,0.016635,11.70176,599.174658
std,1465.080793,1.135764,47.390325,1329.053537,1.061949,46.673992,1954.055202
min,0.00049,-7.878101,-9.490186,0.000505,-8.991674,-495.500415,0.058624
25%,5.03426,-0.290056,0.057228,4.878925,-0.276337,0.086722,21.542809
50%,23.73896,-0.007377,0.512012,22.465138,-0.007502,0.532314,57.068313
75%,88.8785,0.271416,0.694495,87.81165,0.268567,0.714787,162.727447
max,24511.492896,9.038507,1034.60733,24227.089358,7.818888,581.776419,27402.845358


In [22]:
# Keep only the best checkpoints: those whose combined L2 error lies in the lowest
# 1% of all evaluated models. Note this is the 1st percentile (quantile 0.01),
# not the 10th; the output below (122 of 12153 rows) confirms it.
L2_ERROR_QUANTILE = 0.01
l2_error_threshold = df['l2_error'].quantile(L2_ERROR_QUANTILE)
tenth_percentile = l2_error_threshold  # legacy alias (misnamed: holds the 1st-percentile value)

filtered_df = (
    df[df['l2_error'] <= l2_error_threshold]
    .sort_values(by=['l2_error'])
    .reset_index(drop=True)
)
assert not filtered_df.empty, 'no models passed the l2_error threshold'

print(filtered_df.shape)
filtered_df.head()

(122, 8)



Unnamed: 0,model_name,vx_error,mean_adjustment_vx,time_shift_vx,vy_error,mean_adjustment_vy,time_shift_vy,l2_error
0,model_2025-02-09 09:28:24.650612_epoch-35.pth,0.058622,0.009412,0.620081,0.000505,-0.001265,93.524543,0.058624
1,model_2025-02-09 06:26:35.178491_epoch-10.pth,0.040382,-0.008866,0.007851,0.06789,-0.005096,0.004171,0.078992
2,model_2025-02-09 05:14:53.091110_epoch-23.pth,0.054388,-0.012281,0.010408,0.078374,-0.015834,0.013298,0.095397
3,model_2025-02-09 09:33:49.962264_epoch-15.pth,0.055854,-0.010472,0.008649,0.079245,-0.016782,46.484123,0.096951
4,model_2025-02-09 10:20:57.266285_epoch-35.pth,0.073228,0.00437,0.001245,0.084577,-0.012226,0.775537,0.111873


In [23]:
# Summary statistics of the retained (lowest-error) models; quartiles spelled out explicitly.
filtered_df.describe(percentiles=[0.25, 0.5, 0.75])

Unnamed: 0,vx_error,mean_adjustment_vx,time_shift_vx,vy_error,mean_adjustment_vy,time_shift_vy,l2_error
count,122.0,122.0,122.0,122.0,122.0,122.0,122.0
mean,0.312619,-0.001067,5.26868,0.314095,-0.003534,10.006842,0.499832
std,0.245163,0.029097,32.174487,0.233369,0.027178,53.400258,0.24633
min,0.001375,-0.053523,-0.40794,0.000505,-0.053083,-0.476429,0.058624
25%,0.088671,-0.027265,0.007784,0.121957,-0.024883,0.0087,0.298137
50%,0.267425,-0.002056,0.487429,0.24157,-0.005079,0.537015,0.49517
75%,0.482118,0.023856,0.6174,0.507535,0.01657,0.682043,0.692873
max,0.885158,0.052793,337.073295,0.861079,0.053199,555.337785,0.911001


In [24]:
# Column labels of the filtered results (DataFrame.keys() returns the columns Index).
filtered_df.keys()

Index(['model_name', 'vx_error', 'mean_adjustment_vx', 'time_shift_vx',
       'vy_error', 'mean_adjustment_vy', 'time_shift_vy', 'l2_error'],
      dtype='object')

In [25]:
# Placeholders kept for any later cells that reference them (unused in this notebook section).
filtered_models = filtered_df['model_name']
model_final_results = []

# Load the test split once and take inputs and targets from the same pass, so
# every test_x row is paired with the test_y row from the same batch.
test_dataset = ScrewdrivingDataset(mode='test', **{**vars(ARGS), 'data_dir': singal_data_dir})
test_loader = DataLoader(test_dataset)
test_batches = list(test_loader)

test_x = torch.cat([batch[0] for batch in test_batches], dim=0)

test_y = np.concatenate([batch[1].numpy() for batch in test_batches], axis=0)
# Flatten to (N, C); this reshape only succeeds if all middle dimensions have size 1.
test_y = test_y.reshape(test_y.shape[0], test_y.shape[-1])

# Not used by find_best_overlap (it starts from its own [0, 0]); kept for compatibility.
initial_guess = [0., 0.]

In [26]:
def signal_error(params, reference_signals, target_signals):
    """Sum-of-squared-error objective between reference and adjusted target channels.

    Args:
        params: ``(mean_adjustment, time_shift)``, applied identically to both
            target channels through ``adjust_signal``.
        reference_signals: sequence whose items 0 and 1 are the reference channels.
        target_signals: sequence whose items 0 and 1 are the channels to adjust.

    Returns:
        ``sum(|ref - adj|**2)`` over channels 0 and 1, as a real scalar.

    The squared *magnitude* is used instead of ``(ref - adj) ** 2`` so the result
    is real and non-negative even for complex-valued signals (predictions are
    cast to complex64 in this notebook). A complex objective makes scipy's
    optimizer silently drop the imaginary part with a ComplexWarning, and the
    reported minimum becomes complex. For real inputs the value is unchanged.
    """
    adjusted_signals = adjust_signal(target_signals, *params)
    error_vx = np.sum(np.abs(reference_signals[0] - adjusted_signals[0]) ** 2)
    error_vy = np.sum(np.abs(reference_signals[1] - adjusted_signals[1]) ** 2)
    return error_vx + error_vy

def adjust_signal(target_signals, mean_adjustment, time_shift):
    """Apply one shared offset and time shift to channels 0 and 1 of ``target_signals``.

    Returns a two-element list ``[adjusted_channel_0, adjusted_channel_1]``; any
    further channels are ignored.
    """
    channel_pair = (target_signals[0], target_signals[1])
    return [
        adjust_signal_single(channel, mean_adjustment, time_shift)
        for channel in channel_pair
    ]
    
def adjust_signal_single(target_signal, mean_adjustment, time_shift):
    """Shift a 1-D signal in time by linear interpolation, then add a constant offset.

    Output sample ``k`` takes the interpolated value of ``target_signal`` at
    position ``k + time_shift``. Positions outside the signal take the nearest
    endpoint value (``np.interp``'s default clamping).
    """
    sample_index = np.arange(len(target_signal))
    source_positions = sample_index - time_shift
    shifted = np.interp(sample_index, source_positions, target_signal)
    return shifted + mean_adjustment

In [27]:
def find_best_overlap(reference_signals, target_signals, segment_length=segment_length, max_shift=None):
    """Exhaustively search window pairs for the best-aligned reference/target segments.

    For every window start ``i`` in the reference and every start ``j`` in the
    target, fit a shared ``(mean_adjustment, time_shift)`` with
    ``scipy.optimize.minimize`` on ``signal_error`` and keep the lowest residual.
    This makes O(n_ref * n_target) optimizer calls, so it is slow for long signals.

    Args:
        reference_signals: sequence whose items 0 and 1 are the reference channels.
        target_signals: sequence whose items 0 and 1 are the target channels; may
            differ in length from the reference.
        segment_length: window length in samples.
        max_shift: bound on ``|time_shift|``; defaults to ``segment_length // 2``.

    Returns:
        ``((i, j), best_params, min_error)``. If no window pair yields a finite
        error (e.g. either signal is shorter than ``segment_length``), returns
        ``((None, None), None, inf)``.
    """
    n_reference = len(reference_signals[0])
    n_target = len(target_signals[0])
    if max_shift is None:
        max_shift = segment_length // 2

    # Optimizer settings do not depend on the window, so build them once.
    # Mean adjustment is unbounded; time shift is bounded symmetrically.
    # NOTE(review): np.interp clamps at the window edges, so shifts near the bound
    # compare against a flat, extrapolated tail of the target window.
    initial_params = [0, 0]
    bounds = [(-np.inf, np.inf), (-max_shift, max_shift)]

    min_error = float('inf')
    best_params = None
    best_segments = (None, None)

    for i in range(n_reference - segment_length + 1):
        reference_segments = [
            reference_signals[0][i:i + segment_length],
            reference_signals[1][i:i + segment_length],
        ]

        for j in range(n_target - segment_length + 1):
            target_segments = [
                target_signals[0][j:j + segment_length],
                target_signals[1][j:j + segment_length],
            ]

            result = minimize(signal_error, initial_params, args=(reference_segments, target_segments), bounds=bounds)

            if result.fun < min_error:
                min_error = result.fun
                best_params = result.x
                best_segments = (i, j)

    return best_segments, best_params, min_error

In [28]:
def plot_signals( 
    adjusted_signals, 
    reference_signals,
    model_name,
    show_plot=False,
    title="Signal Comparison"
):
    """Plot adjusted vs. reference Vx/Vy side by side and save to ``out_dir/{model_name}.png``.

    Args:
        adjusted_signals: ``[vx, vy]`` adjusted channels (drawn dashed). The time
            axis is ``0..len(adjusted_signals[0])-1``.
        reference_signals: ``[vx, vy]`` reference channels; must match the
            adjusted channels' length.
        model_name: file stem for the saved PNG. NOTE(review): model names in this
            notebook contain spaces and ':' characters, which are not valid in
            Windows file names.
        show_plot: also display the figure inline.
        title: figure super-title.

    The figure is always closed afterwards, so repeated calls do not accumulate
    open figures.
    """
    with warnings.catch_warnings():
        # Ignore RuntimeWarning; this includes numpy's ComplexWarning (a RuntimeWarning
        # subclass), which is raised when complex-valued arrays are plotted.
        warnings.filterwarnings("ignore", category=RuntimeWarning)

        t = np.arange(adjusted_signals[0].shape[0])
        fig, axs = plt.subplots(1, 2, figsize=(15, 10))

        for ax, channel, adjusted, reference in zip(axs, ('Vx', 'Vy'), adjusted_signals, reference_signals):
            ax.plot(t, adjusted, label=f'Adjusted Signal {channel}', dashes=[6, 2])
            ax.plot(t, reference, label=f'Reference Signal {channel}')
            ax.set_xlabel('Time')
            ax.set_ylabel('Amplitude')
            ax.legend()
            ax.grid(True)

        fig.suptitle(title)
        fig.tight_layout()
        fig.savefig(os.path.join(out_dir, f'{model_name}.png'))

        if show_plot:
            plt.show()
        plt.close(fig)


In [29]:
# One row per model: [model_name, start_ref, start_target, segment_length, min_error, mean_adjustment, time_shift]
data_store = list()

In [30]:
# Set to True to also save a plot of each model's best-matching segments to out_dir.
save_segment_plots = False

# Ground-truth channels are model-independent and never modified below, so take
# them once instead of copying all of test_y twice per model.
reference_signal_x = test_y[:, 0]
reference_signal_y = test_y[:, 1]

for idx, row in filtered_df.iterrows():
    with warnings.catch_warnings():
        # Ignore RuntimeWarning subclasses (e.g. numpy ComplexWarning) raised while
        # working with the complex-valued predictions below.
        warnings.filterwarnings("ignore", category=RuntimeWarning)
        
        with torch.no_grad():
            model = SindyModel(**vars(ARGS))
            # map_location='cpu' lets GPU-saved checkpoints load on a CPU-only machine.
            # NOTE(review): assumes the model runs on CPU (the .numpy() call below needs CPU output).
            model.load_state_dict(torch.load(os.path.join(model_dir, row.model_name), map_location='cpu'))
            model.eval()
    
            pred_y = model(test_x).numpy()
            # Flatten to (N, C). NOTE(review): the complex64 cast makes downstream
            # errors complex-valued; confirm whether the cast is needed.
            pred_y = pred_y.reshape(pred_y.shape[0], pred_y.shape[-1]).astype(np.complex64)

            # Re-apply the per-channel offset/shift found by the earlier evaluation.
            target_signal_x = pred_y[:, 0]
            adjusted_vx = adjust_signal_single(target_signal_x, row.mean_adjustment_vx, row.time_shift_vx)
            
            target_signal_y = pred_y[:, 1]
            adjusted_vy = adjust_signal_single(target_signal_y, row.mean_adjustment_vy, row.time_shift_vy)

            best_segments, best_params, min_error = find_best_overlap([reference_signal_x, reference_signal_y], [adjusted_vx, adjusted_vy])

            if best_params is None:
                # No window pair produced a finite error (e.g. signals shorter than segment_length).
                print(f'#: {idx}, model_name: {row.model_name}, skipped: no valid overlap for segment_length={segment_length}')
                continue

            print(f'#: {idx}, model_name: {row.model_name}, best_segments: {best_segments}, min_error: {min_error}, best_params: {best_params}')
            # Store the real part so the CSV holds plain numbers rather than "(x+0j)"
            # strings; the printed errors have a zero imaginary part.
            data_store.append([row.model_name, *best_segments, segment_length, np.real(min_error), *best_params])

            if save_segment_plots:
                # NOTE(review): best_params were fitted on the isolated windows (with edge
                # clamping); applying them to the full signal before slicing uses real
                # neighbouring samples instead, so the plotted alignment can differ slightly.
                best_segmented_vx, best_segmented_vy = adjust_signal([adjusted_vx, adjusted_vy], *best_params)
                reference_start, target_start = best_segments
                plot_signals(
                    [
                        best_segmented_vx[target_start:target_start + segment_length],
                        best_segmented_vy[target_start:target_start + segment_length],
                    ],
                    [
                        reference_signal_x[reference_start:reference_start + segment_length],
                        reference_signal_y[reference_start:reference_start + segment_length],
                    ],
                    row.model_name,
                )

            
            

#: 0, model_name: model_2025-02-09 09:28:24.650612_epoch-35.pth, best_segments: (74, 26), min_error: (0.0006304606432739946+0j), best_params: [2.00638038e-03 7.35868576e+01]
#: 1, model_name: model_2025-02-09 06:26:35.178491_epoch-10.pth, best_segments: (28, 142), min_error: (0.0006377767736374539+0j), best_params: [-0.00472966  0.50895494]
#: 2, model_name: model_2025-02-09 05:14:53.091110_epoch-23.pth, best_segments: (7, 150), min_error: (0.000552768704535379+0j), best_params: [-1.06317734e-03 -6.65628388e+01]
#: 3, model_name: model_2025-02-09 09:33:49.962264_epoch-15.pth, best_segments: (106, 27), min_error: (0.002446936961256473+0j), best_params: [-2.04647518e-03  7.45638093e+01]
#: 4, model_name: model_2025-02-09 10:20:57.266285_epoch-35.pth, best_segments: (130, 140), min_error: (0.00019616062757491957+0j), best_params: [-0.00536603  0.5094581 ]
#: 5, model_name: model_2025-02-09 06:23:38.698451_epoch-15.pth, best_segments: (115, 39), min_error: (0.0048123218568794125+0j), best_

In [31]:
# Write raw result rows (no header; the header line is prepended separately).
out_file_path = os.path.join(out_dir, 'best_segmented_results.csv')
np.savetxt(fname=out_file_path, X=data_store, fmt='%s', delimiter=',')

In [32]:
headers = ['model_name', 'start_ref', 'start_target', 'segment_length', 'min_error', 'mean_adjustment', 'time_shift']
header_line = ','.join(headers)

with open(out_file_path, 'r') as fh:
    content = fh.read()

# Prepend the header only if it is not already there, so re-running this cell
# does not stack duplicate header rows at the top of the CSV.
if not content.startswith(header_line + '\n'):
    with open(out_file_path, 'w') as fh:
        fh.write(header_line + '\n' + content)
