In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to local .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Train

In [2]:
#
import os

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
import tqdm
from datasets import train_dataset, test_dataset
from params_sssd import params
from SSSD import SSSD, align, loss_fn, eval_error

# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
lr = 1e-2
# NOTE(review): `batch_size` is declared but the loaders below have always been
# built with a hard-coded 16. The effective value (16) is kept so that results
# do not change silently -- confirm which one is intended.
batch_size = 256
epoch = 50

device = torch.device('cuda:2')

# Single source of truth for the checkpoint location (used for load and save).
ckpt_path = r'/home/wyl/projects/_BSS_hijack/_end_to_end_compare/02_SSSD/ckpt_sssd/ckpt.pth'

model = SSSD(params=params, mode=params.mode, measure=params.measure, bidirectional=params.bidirectional).to(device)

# Resume from an existing checkpoint if there is one. Only the "file is
# missing" case is treated as a fresh start: a corrupt or architecture-
# mismatched checkpoint now raises instead of being silently ignored and then
# overwritten at the end of the first epoch.
if os.path.isfile(ckpt_path):
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    print('model loaded')
else:
    print('model does not exise')

# Make sure the checkpoint directory exists up front, so a missing directory
# does not throw away a full epoch of training at the first torch.save().
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

train_data = train_dataset()
test_data = test_dataset()
data_loader = DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16, shuffle=True)

optimizer = AdamW(model.parameters(), lr=lr, eps=3e-4)

tqdm_epoch = tqdm.notebook.tqdm(range(epoch))

for _ in tqdm_epoch:
    # ---- training pass ----------------------------------------------------
    model.train()  # re-enable training-mode layers after the eval pass below
    avg_loss = 0.
    num_items = 0
    # Distinct names for the batches: the loop must not overwrite the
    # `train_data` / `test_data` dataset objects defined above.
    for batch_input, batch_truth in data_loader:
        batch_input = align(batch_input.squeeze(1)).to(device)
        batch_truth = align(batch_truth.squeeze(1)).to(device)

        output = model(batch_input)
        loss = loss_fn(output, batch_truth)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight each batch loss by its size so a short final batch does not
        # skew the epoch average.
        n_batch = batch_input.shape[0]
        avg_loss += loss.item() * n_batch
        num_items += n_batch

    train_loss = avg_loss / max(num_items, 1)
    tqdm_epoch.set_description('Average Loss:{:5f}'.format(train_loss))
    # Update the checkpoint after each epoch of training.
    torch.save(model.state_dict(), ckpt_path)

    # ---- evaluation pass --------------------------------------------------
    model.eval()  # put the model in inference mode for the held-out loss
    test_loss = 0.
    test_items = 0
    with torch.no_grad():
        for test_input, test_truth in test_loader:
            test_input = align(test_input.squeeze(1)).to(device)
            test_truth = align(test_truth.squeeze(1)).to(device)

            output = model(test_input)
            loss = loss_fn(output, test_truth)

            # Weight by the size of *this test batch*. The previous code used
            # the shape of the last training batch here, which mis-weighted
            # the test average whenever the two batch sizes differed.
            n_batch = test_input.shape[0]
            test_loss += loss.item() * n_batch
            test_items += n_batch

    # Show both the training and the test average for this epoch.
    tqdm_epoch.set_description(
        'Average Loss: {:5f},   Test Loss: {:5f}'.format(train_loss, test_loss / max(test_items, 1))
    )

model does not exise


  0%|          | 0/50 [00:00<?, ?it/s]

# Eval

In [7]:
import torch
from SSSD import SSSD
import matplotlib.pyplot as plt
from datasets import train_dataset, test_dataset
from torch.utils.data import DataLoader, RandomSampler
from params_sssd import params

# Seed the global torch RNG so the shuffled batches and the sample index drawn
# below are reproducible. torch.manual_seed returns the Generator object.
seed = torch.manual_seed(1001)

eval_batch_size = 20

# Build each dataset once and share it between its sampler and its loader
# (previously every dataset was instantiated twice).
train_set = train_dataset()
test_set = test_dataset()

# Hand the samplers to the loaders instead of leaving them unused next to
# `shuffle=True`; a RandomSampler without replacement is what `shuffle=True`
# sets up internally.
train_sampler = RandomSampler(train_set, replacement=False)
test_sampler = RandomSampler(test_set, replacement=False)
train_loader = DataLoader(train_set, batch_size=eval_batch_size, sampler=train_sampler)
test_loader = DataLoader(test_set, batch_size=eval_batch_size, sampler=test_sampler)

# One batch from each split is enough for the qualitative comparison below.
train_data, train_truth = next(iter(train_loader))
test_data, test_truth = next(iter(test_loader))

# Index of the sample to inspect. It is used for both splits, so bound it by
# the smaller of the two batches actually loaded rather than a hard-coded 20
# (a dataset with fewer than 20 items would otherwise index out of range).
n_common = min(train_data.shape[0], test_data.shape[0])
randint = torch.randint(0, n_common, (1,))

# NOTE(review): the training cell passes inputs through `align(...)` before the
# model; that is not done here -- confirm whether it is needed for evaluation.
train_data = train_data.squeeze(1).to(torch.float32)
train_truth = train_truth.squeeze(1).to(torch.float32)

test_data = test_data.squeeze(1).to(torch.float32)
test_truth = test_truth.squeeze(1).to(torch.float32)

In [8]:
# end to end eval
# `loss_fn` is used below; import it here so the eval section does not depend
# on the training cell having been executed first.
from SSSD import loss_fn

ckpt_path = r'/home/wyl/projects/_BSS_hijack/_end_to_end_compare/02_SSSD/ckpt_sssd/ckpt.pth'

# Build the model with the same constructor arguments as the training cell so
# the saved state_dict matches this architecture.
model = SSSD(params=params, mode=params.mode, measure=params.measure, bidirectional=params.bidirectional)
try:
    # The model and the evaluation batches stay on the CPU in this cell, so
    # map the checkpoint tensors to the CPU regardless of where they were saved.
    model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    print('::: model loaded :::')
except FileNotFoundError:
    # Only a missing file is tolerated. Any other failure (corrupt file,
    # mismatched keys or shapes) now propagates instead of being reported as
    # "does not exist" while an untrained model is evaluated.
    print('::: ckpt does not exist :::')
model.eval()

# Forward one train batch and one test batch without tracking gradients.
with torch.no_grad():
    train_output = model(train_data)
    test_output = model(test_data)

# Pick the sample selected by `randint` from each batch; squeeze(0) drops the
# leading length-1 dimension left by indexing with a one-element index tensor.
result_model = train_output[randint].squeeze(0)
truth_model = train_truth[randint].squeeze(0)

result_test = test_output[randint].squeeze(0)
result_truth = test_truth[randint].squeeze(0)

# Loss over the whole batch for each split.
loss_model_train = loss_fn(train_output, train_truth)
loss_model_test = loss_fn(test_output, test_truth)

print(loss_model_train)
print(loss_model_test)

::: ckpt does not exist :::
tensor(0.7986)
tensor(0.7573)


In [None]:
# plot
# Overlay the reference signal and the model output for the selected test
# sample, one subplot per row of the sample (titled EEG / EMG / EOG).
fig, axs = plt.subplots(nrows=3, ncols=1)

channel_titles = ("EEG", "EMG", "EOG")
legend_row = 2  # only the bottom panel carries line labels and a legend

for row, title in enumerate(channel_titles):
    panel = axs[row]
    if row == legend_row:
        panel.plot(result_truth[row], label='truth')
        panel.plot(result_test[row], label='separated')
    else:
        panel.plot(result_truth[row])
        panel.plot(result_test[row])
    panel.set_title(title)

axs[legend_row].legend()

fig.suptitle('Comparison Method')

plt.show()