In [10]:
import sys, copy, torch, argparse
import numpy as np
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from torch.utils.data import DataLoader

from sEMG_transformer_loo import load_raw_signals, sEMGSignalDataset, sEMGtransformer, count_correct

In [2]:
parser = argparse.ArgumentParser(description="sEMG transformer training configurations")
# experiment config
parser.add_argument('--sub_idx', type=int, default=0, help="subject index")
# training config
parser.add_argument('--seed', type=int, default=0, help="random seed")
parser.add_argument('--epochs', type=int, default=1000, help="number of epochs")
parser.add_argument('--bsz', type=int, default=32, help="batch size")
# optimizer config
parser.add_argument('--lr', type=float, default=0.001, help="learning rate")
parser.add_argument('--wd', type=float, default=0.01, help="weight decay")
parser.add_argument('--step_size', type=int, default=250, help="lr scheduler step size")
parser.add_argument('--gamma', type=float, default=0.8, help="lr scheduler gamma")
# model config
parser.add_argument('--psz', type=int, default=64, help="signal patch size")
parser.add_argument('--d_model', type=int, default=512, help="transformer embedding dim")
parser.add_argument('--nhead', type=int, default=8, help="transformer number of attention heads")
parser.add_argument('--dim_feedforward', type=int, default=2048, help="transformer feed-forward dim")
parser.add_argument('--num_layers', type=int, default=1, help="number of transformer encoder layers")
parser.add_argument('--dropout', type=float, default=0.1, help="dropout rate")

sys.argv = ['f']
config = parser.parse_args()

In [3]:
signals, labels, vfi_1, sub_id, sub_skinfold = load_raw_signals("../data/subjects_40_v6.mat")
np.random.seed(config.seed)
torch.manual_seed(config.seed)

<torch._C.Generator at 0x7f8caeb4d790>

In [4]:
sub_test = config.sub_idx
print(f"Subject R{sub_id[config.sub_idx][0][0][0]}")

X, Y, C = [], [], []
for i in range(40):
  # stack all inputs into [N,C,L] format
  x = np.stack(signals[i], axis=1)

  # one-hot encode the binary labels
  N = labels[i][0].shape[0]
  mapped_indices = (labels[i][0] == 1).astype(int)
  y_onehot = np.zeros((N, 2))
  y_onehot[np.arange(N), mapped_indices.flatten()] = 1

  X.append(x)
  Y.append(y_onehot)
  C.append(sub_skinfold[i][0].mean(axis=1))


Subject R44


In [5]:

# normalize the signals channel-wise
X_means = np.mean(np.concatenate(X, axis=0), axis=(0,2))
X_stds = np.std(np.concatenate(X, axis=0), axis=(0,2))
for i in range(40):
  X[i] = (X[i] - X_means[np.newaxis,:,np.newaxis]) / X_stds[np.newaxis,:,np.newaxis]
print(f"X {np.concatenate(X, axis=0).shape}")

# leave-one-out split
X_test, Y_test, C_test = X[sub_test], Y[sub_test], C[sub_test]
X, Y, C = X[:sub_test] + X[sub_test+1:], Y[:sub_test] + Y[sub_test+1:], C[:sub_test] + C[sub_test+1:]
X, Y, C = np.concatenate(X, axis=0), np.concatenate(Y, axis=0), np.concatenate(C, axis=0)

num_samples = X.shape[0]
indices = np.arange(num_samples)
np.random.shuffle(indices)
split_idx = int(num_samples*0.9)
train_idx, valid_idx = indices[:split_idx], indices[split_idx:]

X_train, X_valid = X[train_idx], X[valid_idx]
Y_train, Y_valid = Y[train_idx], Y[valid_idx]
C_train, C_valid = C[train_idx], C[valid_idx]
print(f"X_train {X_train.shape}")
print(f"X_valid {X_valid.shape}")
print(f"X_test {X_test.shape}")

dataset_train = sEMGSignalDataset(X_train, Y_train)
dataset_valid = sEMGSignalDataset(X_valid, Y_valid)
dataset_test  = sEMGSignalDataset(X_test, Y_test)

dataloader_train = DataLoader(dataset_train, batch_size=config.bsz, shuffle=True)
dataloader_valid = DataLoader(dataset_valid, batch_size=config.bsz, shuffle=False)
dataloader_test  = DataLoader(dataset_test,  batch_size=config.bsz, shuffle=False)

model = sEMGtransformer(patch_size=config.psz, d_model=config.d_model, nhead=config.nhead, dim_feedforward=config.dim_feedforward,
                        dropout=config.dropout, num_layers=config.num_layers)
model.to("cuda")

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.step_size, gamma=config.gamma)
scaler = torch.cuda.amp.GradScaler()

accuracy_valid_best = 0
accuracy_test_best = 0
model_best = None
for epoch in tqdm(range(config.epochs), desc="Training"):
  loss_train = 0
  correct_train = 0
  model.train()
  for batch, (inputs, targets) in enumerate(dataloader_train):
    inputs, targets = inputs.to("cuda"), targets.to("cuda")
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
      outputs = model(inputs)
      loss = criterion(outputs, targets)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    correct_train += count_correct(outputs, targets)
    loss_train += loss.item()

  loss_valid = 0
  correct_valid = 0
  model.eval()
  for inputs, targets in dataloader_valid:
    inputs, targets = inputs.to("cuda"), targets.to("cuda")
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    correct_valid += count_correct(outputs, targets)
    loss_valid += loss.item()

  if correct_valid/len(dataset_valid) > accuracy_valid_best: 
    accuracy_valid_best = correct_valid/len(dataset_valid)
    print(f"accuracy_valid_best: {accuracy_valid_best}")
    model_best = copy.deepcopy(model)
    correct_test = 0
    for inputs, targets in dataloader_test:
      inputs, targets = inputs.to("cuda"), targets.to("cuda")
      outputs = model(inputs)
      correct_test += count_correct(outputs, targets)
    accuracy_test_best = correct_test/len(dataset_test)
    print(f"accuracy_test_best: {accuracy_test_best}")

  scheduler.step()

print(f"accuracy_valid_best: {accuracy_valid_best}")
print(f"accuracy_test_best: {accuracy_test_best}")

X (6472, 4, 4000)
X_train (5676, 4, 4000)
X_valid (631, 4, 4000)
X_test (165, 4, 4000)


Training:   0%|          | 1/1000 [00:01<21:10,  1.27s/it]

accuracy_valid_best: 0.5768621236133122
accuracy_test_best: 0.09696969696969697


Training:   0%|          | 2/1000 [00:02<18:40,  1.12s/it]

accuracy_valid_best: 0.6973058637083994
accuracy_test_best: 0.7393939393939394


Training:   0%|          | 3/1000 [00:03<17:53,  1.08s/it]

accuracy_valid_best: 0.7242472266244057
accuracy_test_best: 0.6666666666666666


Training:   1%|          | 6/1000 [00:06<16:59,  1.03s/it]

accuracy_valid_best: 0.7432646592709984
accuracy_test_best: 0.8


Training:   1%|          | 12/1000 [00:12<16:31,  1.00s/it]

accuracy_valid_best: 0.7591125198098256
accuracy_test_best: 0.6424242424242425


Training:   1%|▏         | 14/1000 [00:14<16:32,  1.01s/it]

accuracy_valid_best: 0.7638668779714739
accuracy_test_best: 0.6727272727272727


Training:   2%|▏         | 24/1000 [00:24<16:21,  1.01s/it]

accuracy_valid_best: 0.7702060221870047
accuracy_test_best: 0.7090909090909091


Training:   3%|▎         | 28/1000 [00:28<18:49,  1.16s/it]

accuracy_valid_best: 0.7765451664025357
accuracy_test_best: 0.8121212121212121


Training:   4%|▍         | 38/1000 [00:39<16:11,  1.01s/it]

accuracy_valid_best: 0.7781299524564184
accuracy_test_best: 0.7757575757575758


Training:   7%|▋         | 71/1000 [01:11<15:45,  1.02s/it]

accuracy_valid_best: 0.7844690966719493
accuracy_test_best: 0.8121212121212121


Training:   7%|▋         | 72/1000 [01:12<15:59,  1.03s/it]

accuracy_valid_best: 0.786053882725832
accuracy_test_best: 0.7575757575757576


Training:   7%|▋         | 74/1000 [01:14<16:15,  1.05s/it]

accuracy_valid_best: 0.8114104595879557
accuracy_test_best: 0.6666666666666666


Training:  12%|█▏        | 117/1000 [01:58<15:02,  1.02s/it]

accuracy_valid_best: 0.8288431061806656
accuracy_test_best: 0.5696969696969697


Training:  12%|█▎        | 125/1000 [02:06<15:25,  1.06s/it]

accuracy_valid_best: 0.8415213946117274
accuracy_test_best: 0.6181818181818182


Training:  21%|██        | 209/1000 [03:29<13:39,  1.04s/it]

accuracy_valid_best: 0.8462757527733756
accuracy_test_best: 0.6848484848484848


Training:  22%|██▏       | 216/1000 [03:36<14:12,  1.09s/it]

accuracy_valid_best: 0.8510301109350238
accuracy_test_best: 0.6363636363636364


Training:  25%|██▌       | 251/1000 [04:12<12:06,  1.03it/s]

accuracy_valid_best: 0.8557844690966719
accuracy_test_best: 0.5515151515151515


Training:  34%|███▍      | 345/1000 [05:45<10:36,  1.03it/s]

accuracy_valid_best: 0.8589540412044374
accuracy_test_best: 0.5878787878787879


Training:  42%|████▏     | 417/1000 [06:54<08:31,  1.14it/s]

accuracy_valid_best: 0.8621236133122029
accuracy_test_best: 0.6606060606060606


Training:  49%|████▉     | 488/1000 [07:57<07:37,  1.12it/s]

accuracy_valid_best: 0.8652931854199684
accuracy_test_best: 0.6363636363636364


Training:  49%|████▉     | 492/1000 [08:00<07:35,  1.11it/s]

accuracy_valid_best: 0.8716323296354992
accuracy_test_best: 0.6363636363636364


Training:  53%|█████▎    | 530/1000 [08:34<06:56,  1.13it/s]

accuracy_valid_best: 0.873217115689382
accuracy_test_best: 0.6303030303030303


Training:  70%|██████▉   | 699/1000 [11:04<04:28,  1.12it/s]

accuracy_valid_best: 0.8763866877971473
accuracy_test_best: 0.6909090909090909


Training:  72%|███████▏  | 717/1000 [11:21<04:19,  1.09it/s]

accuracy_valid_best: 0.8811410459587956
accuracy_test_best: 0.5757575757575758


Training:  76%|███████▋  | 763/1000 [12:02<03:37,  1.09it/s]

accuracy_valid_best: 0.8827258320126783
accuracy_test_best: 0.6242424242424243


Training:  77%|███████▋  | 770/1000 [12:09<03:28,  1.11it/s]

accuracy_valid_best: 0.884310618066561
accuracy_test_best: 0.6545454545454545


Training:  79%|███████▉  | 788/1000 [12:25<03:10,  1.12it/s]

accuracy_valid_best: 0.8858954041204438
accuracy_test_best: 0.6909090909090909


Training:  79%|███████▉  | 791/1000 [12:28<03:16,  1.06it/s]

accuracy_valid_best: 0.8890649762282092
accuracy_test_best: 0.593939393939394


Training:  84%|████████▍ | 839/1000 [13:11<02:26,  1.10it/s]

accuracy_valid_best: 0.8922345483359746
accuracy_test_best: 0.5818181818181818


Training: 100%|██████████| 1000/1000 [15:34<00:00,  1.07it/s]

accuracy_valid_best: 0.8922345483359746
accuracy_test_best: 0.5818181818181818





In [6]:
C_train = C[train_idx]
Y_train_cpt = np.argmax(Y_train, axis=1)
Y_pred = []
dataloader_train = DataLoader(dataset_train, batch_size=config.bsz, shuffle=False)
for inputs, targets in dataloader_train:
  inputs, targets = inputs.to("cuda"), targets.to("cuda")
  outputs = model_best(inputs)
  _, predicted = torch.max(F.softmax(outputs, dim=1), 1)
  Y_pred.append(predicted.cpu().numpy())
Y_pred_cpt = np.concatenate(Y_pred, axis=0)

In [26]:
from sklearn.metrics import accuracy_score

accuracy = accuracy_score(Y_train_cpt, Y_pred_cpt)
print(f"Accuracy: {accuracy:.4f}")

Accuracy: 1.0000


In [29]:
from mlconfound.stats import partial_confound_test

ret = partial_confound_test(Y_train_cpt, Y_pred_cpt, C_train, cat_y=True, cat_yhat=True, cat_c=False)
print(ret.p)

Permuting: 100%|██████████| 1000/1000 [00:23<00:00, 42.65it/s]

0.0



