# Masked U-net based Cycle-consistent Adversarial Networks

In [1]:
import sys

import numpy as np
import torch
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split

sys.path.append('./utils/')
from datakits import get_dataset
from decoder import RidgeRegression as WienerFilter
from model_mucan import Aligner

In [2]:
data_dir = './dataset/Jango_ISO_2015/'
device = torch.device('cuda:0')

---
### Load data

In [3]:
d0_x, d0_y = get_dataset(data_dir, 'Jango_20150730_001.npz', 'EMG')
dk_x, dk_y = get_dataset(data_dir, 'Jango_20150731_001.npz', 'EMG')

---
### Train day-0 decoder

In [4]:
d0_decoder = WienerFilter()
d0_decoder.fit_with_kfold(d0_x, d0_y, n_lags=4, n_splits=4)

Training: 100%|██████████| 4/4 [00:04<00:00,  1.24s/it, [32mCPU=100.0% | 2.5/125.6G[0m, [35mNVIDIA GeForce RTX 4090=0% | 3/24564M[0m]


0.7251491892525731

---
### Pretrain aligner

In [5]:
d0_x_fh, dk_x_sh, _, _ = train_test_split(
    d0_x, d0_y, train_size=0.5, shuffle=False
)

In [6]:
input_dim = d0_x_fh[0].shape[-1]
aligner = Aligner(input_dim=input_dim)
aligner.fit(
    d0_x_fh, dk_x_sh, device, param={'n_epochs': 400, 'n_masks': 30}
)

Training: 100%|██████████| 400/400 [01:10<00:00,  5.66it/s, [32mCPU=4.5% | 4.5/125.6G[0m, [35mNVIDIA GeForce RTX 4090=36% | 1478/24564M[0m]


71.47916566114873

---
### Finetune aligner

In [7]:
d0_x_train = d0_x.copy()
dk_x_train, dk_x_test, _, dk_y_test = train_test_split(
    dk_x, dk_y, train_size=120, shuffle=False
)

In [8]:
aligner.fit(
    d0_x_train, dk_x_train, device, param={'n_epochs': 200}
)

Training: 100%|██████████| 200/200 [00:50<00:00,  3.93it/s, [32mCPU=4.2% | 4.5/125.6G[0m, [35mNVIDIA GeForce RTX 4090=37% | 1926/24564M[0m]


50.8984676906839

### Evaluation

In [9]:
dk_x_test_aligned = aligner.transform(dk_x_test, device)
dk_y_test_pred = d0_decoder.predict(dk_x_test_aligned)
r2_aligned = r2_score(
    np.concatenate(dk_y_test), np.concatenate(dk_y_test_pred),
    multioutput='variance_weighted'
)
print(f'The R\u00b2 achieved {r2_aligned:.4f} on day-k test-set')

The R² achieved 0.7412 on day-k test-set
