In [1]:
import torch
from src.utils.eegimu_dataset import EEGIMUDataset
from src.lstm_train.models import LSTMRegressor
from src.utils.utils import plot_imu_reconstruction, reconstruct_signal
from pathlib import Path
import json

modelname = "run_20250607_1945"
weights_path = f"../../saved_models/weights/{modelname}.pt"
config_path = f"../../saved_models/configs/{modelname}.json"

with open(config_path, "r") as f:
    config = json.load(f)


In [2]:
# Define Dataset
csv_folder = "../../data/EEG_IMU/"
csv_path = sorted(Path(csv_folder).glob("*.csv"))[0]
ds = EEGIMUDataset(csv_path, window=config["window"], stride=config["stride"], bandpass = (1,30))

# TRAIN/TEST/VAL SPLIT
train_loader, val_loader, test_loader = ds.train_test_val_split(config["batch_size"])

# LOAD MODEL
net = LSTMRegressor(in_dim=16, config=config, out_dim=12)
net.load_state_dict(torch.load(weights_path, map_location="cpu"))
net.eval()

# TESTING
test_loss = 0.0
with torch.no_grad():
    for x_test, y_test in test_loader:
        preds = net(x_test)
        loss = net.crit(preds, y_test)
        test_loss += loss.item() * x_test.size(0)
test_loss /= len(test_loader.dataset)
print(f"Test loss: {test_loss:.3f}")

Test loss: 0.598


In [4]:
full_pred, full_true = reconstruct_signal(
    net     = net,
    loader  = train_loader
)

# Plot an entire IMU channel (e.g. x-axis)
fig = plot_imu_reconstruction(
    true    = full_true,
    pred    = full_pred,
    channel = 8,
    timesteps= (1300,2000),
    title   = "Test-segment IMU X-axis reconstruction"
)


KeyboardInterrupt: 

[[-3.9432794e-03  2.2080325e-02  4.5805170e-03 ...  7.3078856e-02
   6.9365976e-03  1.6468369e-03]
 [-5.4118247e-04  4.3996884e-03  7.6756900e-04 ... -3.3106312e-02
   4.9341930e-03  1.1426837e-02]
 [ 4.6584120e-03  1.5170259e-02 -1.7349355e-02 ... -6.9723316e-02
  -1.2804978e-03 -5.2166265e-03]
 ...
 [ 1.4914966e-01  1.0748724e-01  1.6005800e-02 ... -1.6658472e-02
   2.8512442e-01  1.9448730e-01]
 [ 3.5194140e-02  1.7595879e-04 -6.9403592e-03 ...  1.9673896e-01
   1.1952744e-02  3.9801410e-03]
 [-3.2141530e-03  1.6117412e+00  1.1049990e-01 ...  6.9124080e-02
  -2.4186051e-03  1.4344696e-02]]


In [1]:
import time

start = time.time()
for i in range(100000):
    x = 100
print(time.time()-start)

0.019251108169555664
