In [None]:
from fastai.tabular.all import *
import torch
from torch import nn
import numpy as np
from model import stagerNetAAE
from utils import UnfreezeFcCrit

import os

# Set your paths
github_path = 'fastAI/GitHub/'  # path to the GitHub Monitosed folder, relative to the notebook's start directory
eeg_dir = '/home/JennebauffeC/pytorchVAE/fastAI/data/LEMON/Preprocessed/'  # path to public dataset
biowin_file_path = '/home/JennebauffeC/pytorchVAE/fastAI/Monitosed/Biowin/sub-001/VR_1_VRH.vhdr'  # path to reference file for electrode positions
data_dir = '/home/JennebauffeC/pytorchVAE/fastAI/data/LEMON'

# Preferred GPU. Fall back to GPU 0 if fewer GPUs are visible, and to the CPU if CUDA is
# unavailable. torch.cuda.set_device only accepts CUDA devices, so it is called only on
# the CUDA path.
GPU_INDEX = 1
if torch.cuda.is_available():
    gpu_index = GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0
    dev = torch.device(f'cuda:{gpu_index}')
    torch.cuda.set_device(dev)
else:
    dev = torch.device('cpu')
device = dev
print('the current device is: ', device)

# Move into the repository folder so the project modules below become importable. This is
# the plain-Python equivalent of the `%cd` magic. Run it once per kernel: github_path is
# relative, so a second chdir would look for a nested folder.
os.chdir(github_path)
# These bindings replace any same-named imports made earlier in the notebook.
from model20 import stagerNetAAE
from utils1 import LossAttrMetric, GetLatentSpace, norm_batch, UnfreezeFcCrit, SwitchAttribute, distrib_regul_regression, hist_lab, plot_results

In [None]:
# Collect the recording files for each EEG condition.
# eeg_dir is configured as a plain string, which has no .glob(); Path() accepts either a
# str or a Path. Sorting makes the file order, and therefore the sample order, reproducible
# across filesystems.
from pathlib import Path

eeg_root = Path(eeg_dir)
hyp_files = sorted(eeg_root.glob('*_HYP.mat'))
rest_files = sorted(eeg_root.glob('*_Resting.mat'))
vrh_files = sorted(eeg_root.glob('*_VRH.mat'))

In [None]:
# Preprocess each condition with its integer class label:
#   1 -> *_HYP files, 0 -> *_Resting files, 2 -> *_VRH files.
# preprocess_files is called in that order (HYP, Resting, VRH); each call yields an (x, y) pair.
(x_hyp, y_hyp), (x_resting, y_resting), (x_vrh, y_vrh) = (
    preprocess_files(condition_files, class_label)
    for condition_files, class_label in (
        (hyp_files, 1),
        (rest_files, 0),
        (vrh_files, 2),
    )
)

### Step 3: Create the training/validation sets for different combinations of EEG states.

In [None]:
# Build epochs from the recordings and turn the ICA output into the model input tensor.
# NOTE(review): `data_path` is not defined in the config cell (only `data_dir` is);
# confirm which directory is intended.
epochs = create_epochs_from_tensors(data_path)
# torch.Tensor(data, device=...) is the legacy constructor and rejects non-CPU devices.
# torch.as_tensor accepts a device. dtype=float32 matches the legacy constructor's
# default float dtype.
x = torch.as_tensor(get_ica(epochs), dtype=torch.float32, device=device)

## Train the auto-encoder

In [None]:
# Stage 1: train the auto-encoder with its reconstruction loss (model.ae_loss_func).
# Checkpoint name. SaveModelCallback writes <learn.path>/<learn.model_dir>/<name>.pth,
# so the name must be non-empty (an empty one produces a file called ".pth").
ae_filename = 'monitosed_ae'
assert ae_filename, 'ae_filename must be a non-empty checkpoint name'
acc_factor = 1  # GradientAccumulation steps the optimizer every dls.bs * acc_factor samples
model = stagerNetAAE(latent_dim=64, channels=x.shape[1], timestamps=x.shape[-1], acc_factor=acc_factor, dropout_rate=.3)

metrics = [rmse]
learn = Learner(dls, model, loss_func=model.ae_loss_func, metrics=metrics, opt_func=ranger)
learning_rate = learn.lr_find()
print(f'learning rate: {learning_rate.valley}')
learn.fit_flat_cos(n_epoch=100, lr=learning_rate.valley,
                   cbs=[
                       GradientAccumulation(n_acc=dls.bs*acc_factor),
                       # Both callbacks below are TrackerCallbacks and already track valid_loss.
                       SaveModelCallback(fname=ae_filename),
                       EarlyStoppingCallback(min_delta=1e-4, patience=10)])

# Load the best weights from where Learner.save wrote them (relative to learn.path, not the cwd).
state_dict = torch.load(learn.path/learn.model_dir/f'{ae_filename}.pth', map_location=device)

## Train the adversarial auto-encoder

In [None]:
# Stage 2: adversarial training, initialised from the weights currently held in `state_dict`.
acc_factor = 1  # GradientAccumulation steps the optimizer every dls.bs * acc_factor samples
model = stagerNetAAE(latent_dim=64, channels=x.shape[1], timestamps=x.shape[-1],
                     acc_factor=acc_factor, dropout_rate=0.5,
                     adv_weight=0.9, recons_weight=0.1)
# To start from a saved checkpoint instead, e.g.:
# pretrained_filename = 'monitosed_pretrained_LEMON_aae_10s_shared'
# state_dict = torch.load(f'models/{pretrained_filename}.pth')

# strict=False tolerates parameters absent from the checkpoint. Report mismatches so that a
# key-name mismatch (weights silently not loaded) is visible.
load_report = model.load_state_dict(state_dict, strict=False)
print(f'{len(load_report.missing_keys)} model parameters not found in the checkpoint')
if load_report.unexpected_keys:
    print('checkpoint keys not used by the model:', load_report.unexpected_keys)
model = model.to(device)
model.gen_train = False

# Earlier reference run for the adversarial stage: 'monitosed_BIOWIN_aae_REST_VRH_05_24_16h'
adv_filename = 'monitosed_BIOWIN_aae_REST_VRH_05_24_17h_latent64'

metrics = [LossAttrMetric("recons_loss"),
           LossAttrMetric("discrim_loss"),
           LossAttrMetric("adv_loss")]
learn = Learner(dls, model, loss_func=model.aae_loss_func_monitosed,
                metrics=metrics, opt_func=ranger)

learning_rate = learn.lr_find()
print(f'learning rate: {learning_rate.valley}')
learn.fit_flat_cos(n_epoch=100, lr=learning_rate.valley,
                   cbs=[
                       GradientAccumulation(n_acc=dls.bs*acc_factor),
                       # Checkpointing and early stopping both track the adv_loss metric.
                       SaveModelCallback(fname=adv_filename, monitor='adv_loss'),
                       EarlyStoppingCallback(min_delta=1e-4, patience=20, monitor='adv_loss'),
                       UnfreezeFcCrit(switch_every=2)])

# Load the best weights from where Learner.save wrote them (relative to learn.path, not the cwd).
state_dict = torch.load(learn.path/learn.model_dir/f'{adv_filename}.pth', map_location=device)

In [None]:
# Extract latent codes and targets using the weights in `state_dict`, then plot them.
# NOTE(review): `save` is not defined in the visible cells; presumably a flag set elsewhere. Confirm.
save_path = ''

latents = extract_latent(state_dict, save_path, save)
z, target = latents
plot_results(
    z.to(device),
    target.cpu(),
    filename=save_path,
)

## Train the classifier

In [None]:
# Stage 3: joint reconstruction / adversarial / classification training, initialised from the
# weights currently held in `state_dict`.
acc_factor = 1  # GradientAccumulation steps the optimizer every dls.bs * acc_factor samples
model = stagerNetAAE(latent_dim=64, dropout_rate=.5, channels=x.shape[1],
                     timestamps=x.shape[-1], acc_factor=acc_factor,
                     recons_weight=.1, adv_weight=.4, classif_weight=.5)

# Alternative starting checkpoints (load one of these instead of the in-memory state_dict):
# pretrained_filename = 'monitosed_pretrained_LEMON_aae_10s_shared'  # best EC/EO model
# pretrained_filename = 'monitosed_BIOWIN_aae_classif_REST_VRH_05_24_17h_latent64'  # best model
# state_dict = torch.load(f'models/{pretrained_filename}.pth')

# strict=False tolerates parameters absent from the checkpoint. Report mismatches so that a
# key-name mismatch (weights silently not loaded) is visible.
load_report = model.load_state_dict(state_dict, strict=False)
print(f'{len(load_report.missing_keys)} model parameters not found in the checkpoint')
if load_report.unexpected_keys:
    print('checkpoint keys not used by the model:', load_report.unexpected_keys)
model = model.to(device)

# Checkpoint name. SaveModelCallback writes <learn.path>/<learn.model_dir>/<name>.pth,
# so the name must be non-empty (an empty one produces a file called ".pth").
classif_filename = 'monitosed_aae_classif'
assert classif_filename, 'classif_filename must be a non-empty checkpoint name'
CLASSIF_LR = 1e-2  # fixed learning rate for this stage (no lr_find here)

metrics = [LossAttrMetric("recons_loss"),
           LossAttrMetric("adv_loss"),
           LossAttrMetric("discrim_loss"),
           LossAttrMetric("classif_loss")]

learn = Learner(dls, model, loss_func=model.aae_classif_loss_func_monitosed, metrics=metrics, opt_func=ranger, wd=1e-1)

learn.fit_flat_cos(100, lr=CLASSIF_LR,
                   cbs=[
                       GradientAccumulation(n_acc=dls.bs*acc_factor),
                       # Checkpointing and early stopping both track valid_loss.
                       SaveModelCallback(fname=classif_filename, monitor='valid_loss'),
                       EarlyStoppingCallback(min_delta=1e-4, patience=20, monitor='valid_loss'),
                       UnfreezeFcCrit(switch_every=5)])

# Load the best weights from where Learner.save wrote them (relative to learn.path, not the cwd).
state_dict = torch.load(learn.path/learn.model_dir/f'{classif_filename}.pth', map_location=device)

In [None]:
# Plot the latent space produced by the weights in `state_dict`.
# save_path is passed to extract_latent and also used as plot_results' filename.
# NOTE(review): `save` is not defined in the visible cells; presumably a flag set elsewhere. Confirm.
save_path = ''
plot_kwargs = dict(filename=save_path)

z, target = extract_latent(state_dict, save_path, save)
plot_results(z.to(device), target.cpu(), **plot_kwargs)