In [1]:
import sys
sys.path.append('../')

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import warnings
from scipy.optimize import minimize
from torch.utils.data import DataLoader

from src.model import SindyModel
from src.main import get_args
from src.dataloader import ScrewdrivingDataset

In [3]:
ARGS = get_args()
data_dir = os.path.join('..', 'out', 'sindy-data')
model_dir = os.path.join('..', 'out', 'sindy-model-out', 'checkpoints')
out_dir = os.path.join('..', 'out', 'sindy-out-processed')
os.makedirs(out_dir, exist_ok=True)
model_checkpoints = sorted(filter(lambda f: f.endswith('.pth'), os.listdir(model_dir)))
print(f'found {len(model_checkpoints)} models')

found 12153 models


In [4]:
%matplotlib inline
plt.ioff()

<contextlib.ExitStack at 0x77e3c4091900>

In [5]:
def signal_error(params, reference_signal, target_signal):
    adjusted_signal = adjust_signal(target_signal, *params)
    return np.sum((reference_signal - adjusted_signal) ** 2)

def adjust_signal(target_signal, mean_adjustment, time_shift):
    shifted_target_signal = np.interp(np.arange(len(target_signal)), np.arange(len(target_signal)) - time_shift, target_signal)
    return shifted_target_signal + mean_adjustment

In [6]:
def plot_signals(
    pred_signals, 
    adjusted_signals, 
    reference_signals,
    model_name,
    show_plot=False,
    title="Signal Comparison"
):
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning)
        
        t = np.arange(len(pred_signals[0]))
        
        fig, axs = plt.subplots(1, 2, figsize=(15, 10))
        
        axs[0].plot(t, pred_signals[0], label='Predicted Signal Vx')
        axs[0].set_ylabel('Amplitude')
        axs[0].legend()
        axs[0].grid(True)
        
        axs[1].plot(t, pred_signals[1], label='Predicted Signal Vy')
        axs[1].set_ylabel('Amplitude')
        axs[1].legend()
        axs[1].grid(True)
        
        axs[0].plot(t, adjusted_signals[0], label='Adjusted Signal Vx', dashes=[6, 2])
        axs[0].set_ylabel('Amplitude')
        axs[0].legend()
        axs[0].grid(True)
        
        axs[1].plot(t, adjusted_signals[1], label='Adjusted Signal Vy', dashes=[6, 2])
        axs[1].set_ylabel('Amplitude')
        axs[1].legend()
        axs[1].grid(True)
        
        axs[0].plot(t, reference_signals[0], label='Reference Signal Vx')
        axs[0].set_xlabel('Time')
        axs[0].set_ylabel('Amplitude')
        axs[0].legend()
        axs[0].grid(True)
        
        axs[1].plot(t, reference_signals[1], label='Reference Signal Vy')
        axs[1].set_xlabel('Time')
        axs[1].set_ylabel('Amplitude')
        axs[1].legend()
        axs[1].grid(True)
        
        fig.suptitle(title)
        plt.tight_layout()

        plt.savefig(os.path.join(out_dir, f'{model_name}.png'))

        if show_plot:
            plt.show()


In [7]:
model_final_results = []

test_loader = DataLoader(ScrewdrivingDataset(mode='test', **{**vars(ARGS), 'data_dir': data_dir}))
test_x = torch.cat([batch[0] for batch in test_loader], dim=0)

test_y = [batch[1].numpy() for batch in DataLoader(ScrewdrivingDataset(mode='test', **{**vars(ARGS), 'data_dir': data_dir}))]
test_y = np.concatenate(test_y, axis=0)
test_y = test_y.reshape(test_y.shape[0], test_y.shape[-1])

initial_guess = [0., 0.]

In [8]:
for idx, model_name in enumerate(model_checkpoints):
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning)

        with torch.no_grad():
            model = SindyModel(**vars(ARGS))
            model.load_state_dict(torch.load(os.path.join(model_dir, model_name)))
            model.eval()

            pred_y = model(test_x).numpy()
            pred_y = pred_y.reshape(pred_y.shape[0], pred_y.shape[-1]).astype(np.complex64)

            reference_signal_x = np.copy(test_y)[:, 0]
            target_signal_x = pred_y[:, 0]
            result_vx = minimize(signal_error, initial_guess.copy(), args=(reference_signal_x, target_signal_x), method='Nelder-Mead')
            adjusted_vx = adjust_signal(target_signal_x, *result_vx.x)

            reference_signal_y = np.copy(test_y)[:, 1]
            target_signal_y = pred_y[:, 1]
            result_vy = minimize(signal_error, initial_guess.copy(), args=(reference_signal_y, target_signal_y), method='Nelder-Mead')
            adjusted_vy = adjust_signal(target_signal_y, *result_vy.x)

            # plot_signals([pred_y[:, 0], pred_y[:, 1]], [adjusted_vx, adjusted_vy], [reference_signal_x, reference_signal_y], model_name, False)

            final_x_error = signal_error(result_vx.x, reference_signal_x, adjusted_vx)
            final_y_error = signal_error(result_vy.x, reference_signal_y, adjusted_vy)
            l2_error = np.linalg.norm([final_x_error, final_y_error], ord=2)

            print(f'#: {idx}, model_name: {model_name}, vx_err: {final_x_error}, vx_params: {result_vx.x}, vy_err: {final_y_error}, vy_params: {result_vy.x}, l2_error: {l2_error}')

            model_final_results.append((model_name, np.real(final_x_error), *result_vx.x, np.real(final_y_error), *result_vy.x, l2_error))


#: 0, model_name: model_2025-02-08 22:09:57.915698_epoch-0.pth, vx_err: (9699.549090065788+0j), vx_params: [5.68611134 1.49485105], vy_err: (2163.0713355561466+0j), vy_params: [-2.68391416  2.5868308 ], l2_error: 9937.813147433428
#: 1, model_name: model_2025-02-08 22:10:05.722402_epoch-1.pth, vx_err: (10625.227542138124+0j), vx_params: [5.95123383 1.49556342], vy_err: (3100.2920287260467+0j), vy_params: [-3.21388018  2.70781025], l2_error: 11068.300275362646
#: 2, model_name: model_2025-02-08 22:10:10.925205_epoch-2.pth, vx_err: (5936.674489972649+0j), vx_params: [4.44836596 0.50370459], vy_err: (4412.883539897832+0j), vy_params: [-3.83668054 54.47568136], l2_error: 7397.137631313428
#: 3, model_name: model_2025-02-08 22:10:18.712723_epoch-3.pth, vx_err: (3331.63435166839+0j), vx_params: [3.33236725 0.62148162], vy_err: (5157.570084875036+0j), vy_params: [-4.14571389  3.49624298], l2_error: 6140.058357508888
#: 4, model_name: model_2025-02-08 22:10:23.953547_epoch-4.pth, vx_err: (1965

In [9]:
out_file_path = os.path.join(out_dir, 'cumulative_results.csv')
np.savetxt(out_file_path, model_final_results, fmt='%s', delimiter=',')

In [10]:
headers = ['model_name', 'vx_error', 'mean_adjustment_vx', 'time_shift_vx', 'vy_error', 'mean_adjustment_vy', 'time_shift_vy', 'l2_error']

with open(out_file_path, 'r+') as fh:
    content = fh.read()
    fh.seek(0, 0)
    fh.write(','.join(headers) + '\n' + content)
