In [None]:
import torch
import numpy as np

from prep import TimeWindowTransformer, LabelWindowExtractor
from models import ConvNN, CrossValidationManager

In [None]:
# === Dataset Class ===
class EMGDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing input windows ``X`` with targets ``y``.

    Both arrays are copied into float32 tensors. When ``standardize`` is true,
    ``X`` is z-scored with a mean and standard deviation reduced over axes 0
    and 2 (kept as size-1 dims). The layout is presumably
    (samples, channels, time), giving per-channel statistics; TODO confirm
    this against the windowing transformer's output layout.

    Args:
        X: Array-like of inputs. It must have at least 3 dimensions when
            ``standardize`` is true, because the reduction uses ``dim=(0, 2)``.
        y: Array-like of targets, indexed in parallel with ``X``.
        standardize: Whether to z-score ``X``.
        mean: Optional precomputed mean, broadcastable against ``X``
            (e.g. shape ``(1, C, 1)``). Pass a training dataset's ``.mean``
            here so validation/test data are scaled with training statistics
            rather than their own. When ``None``, it is computed from ``X``.
        std: Optional precomputed standard deviation; same rules as ``mean``.
        eps: Added to ``std`` before dividing, to avoid division by zero.

    Attributes:
        X, y: The float32 tensors served by ``__getitem__``.
        mean, std: The statistics actually applied, or ``None`` when
            ``standardize`` is false.
    """

    def __init__(self, X, y, standardize=True, mean=None, std=None, eps=1e-8):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
        self.mean = None
        self.std = None
        if standardize:
            # Fall back to this dataset's own statistics only when the caller
            # did not supply (e.g. training-set) statistics.
            if mean is None:
                mean = self.X.mean(dim=(0, 2), keepdim=True)
            if std is None:
                std = self.X.std(dim=(0, 2), keepdim=True)
            self.mean = torch.as_tensor(mean, dtype=torch.float32)
            self.std = torch.as_tensor(std, dtype=torch.float32)
            self.X = (self.X - self.mean) / (self.std + eps)

    def __len__(self):
        """Number of samples (length of the first axis of ``X``)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensor pair at position ``idx``."""
        return self.X[idx], self.y[idx]

# Training hyper-parameters passed to CrossValidationManager below.
# Seed the global torch/numpy RNGs so a Restart & Run All is reproducible.
# NOTE(review): this only helps if ConvNN / CrossValidationManager draw from
# these global generators (weight init, shuffling). TODO confirm.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

training_config = {
    "lr": 1e-3,
    "epochs": 100,
    "batch_size": 64,
    "log_every": 1
}

In [None]:
DATASET = 'freemoves'
X_PATH = f'data/{DATASET}/{DATASET}_dataset_X.npy'
Y_PATH = f'data/{DATASET}/{DATASET}_dataset_Y.npy'

# Window geometry shared by inputs and labels: both extractors must use the
# same size/step, otherwise the windows in X and Y stop lining up.
WINDOW_SIZE = 500
WINDOW_STEP = 100

X = np.load(X_PATH)
Y = np.load(Y_PATH)

tw_extractor = TimeWindowTransformer(size=WINDOW_SIZE, step=WINDOW_STEP)
label_extractor = LabelWindowExtractor(size=WINDOW_SIZE, step=WINDOW_STEP)

X_windows = tw_extractor.transform(X)
Y_labels = label_extractor.transform(Y)

# Hold out one entry along the leading axis as the test set. The entry is
# presumably one recording/session; TODO confirm. The remaining entries are
# pooled for cross-validation.
train_val_idx = [0, 1, 2, 4]
test_idx = 3
assert test_idx not in train_val_idx, "test entry must not leak into train/val"

# Merge the first two axes of the selected entries into one sample axis.
X_train_val = X_windows[train_val_idx].reshape(-1, *X_windows.shape[2:])
Y_train_val = Y_labels[train_val_idx].reshape(-1, *Y_labels.shape[2:])

X_test = X_windows[test_idx]
Y_test = Y_labels[test_idx]

# Every input window must be paired with exactly one label entry.
assert len(X_train_val) == len(Y_train_val), (X_train_val.shape, Y_train_val.shape)
assert len(X_test) == len(Y_test), (X_test.shape, Y_test.shape)

print('X_train_val:', X_train_val.shape)
print('Y_train_val:', Y_train_val.shape)
print('X_test:     ', X_test.shape)
print('Y_test:     ', Y_test.shape)

In [None]:
# Cross-validate ConvNN on the pooled train/val windows (the held-out test
# entry is not used here).
# NOTE(review): end_dim=51 is presumably the target dimensionality; confirm it
# matches Y_train_val.shape[-1].
# NOTE(review): windows are size 500 with step 100, so neighbouring windows
# overlap. If CrossValidationManager splits folds per window rather than per
# recording, validation scores may be optimistic. TODO confirm.
cv_kwargs = dict(
    model_class=ConvNN,
    model_config=dict(end_dim=51),
    data=X_train_val,
    labels=Y_train_val,
    training_config=training_config,
    dataset_class=EMGDataset,
    n_folds=4,
)
cv = CrossValidationManager(**cv_kwargs)

cv.run()