In [None]:
# Preprocess the recordings referenced by `eeg_dir` into `x` and `y`, then
# build `dls`, which every training cell below hands to its Learner.
# NOTE(review): `eeg_dir` and `bs` are not defined in this section; they must
# come from an earlier cell.
x, y = preprocess_files_eoec(eeg_dir)
dls = create_dls_eoec(
    x,
    y,
    bs,
)

## Train the auto-encoder

In [None]:
# Stage 1: train the model with its auto-encoder loss only and keep the best
# checkpoint so the later stages can warm-start from it.
#
# The checkpoint name must not be empty: with '' SaveModelCallback writes a
# stem-less hidden file ('models/.pth') that is easy to overlook or clobber.
ae_filename = 'pretrained_LEMON_ae_10s_EC_EO'
acc_factor = 1  # reused by the later training cells
model = stagerNetAAE(
    latent_dim=64,
    channels=x.shape[1],
    timestamps=x.shape[-1],
    acc_factor=acc_factor,
    dropout_rate=.3,
)

metrics = [rmse]
learn = Learner(dls, model, loss_func=model.ae_loss_func, metrics=metrics, opt_func=ranger)

# lr_find() returns a suggestion object; its `valley` field is the rate used.
learning_rate = learn.lr_find()
print('learning rate: ' + str(learning_rate.valley))
learn.fit_flat_cos(
    n_epoch=100,
    lr=learning_rate.valley,
    cbs=[
        # n_acc is counted in samples, so bs * acc_factor means
        # "step the optimizer every acc_factor batches".
        GradientAccumulation(n_acc=dls.bs * acc_factor),
        TrackerCallback(),
        SaveModelCallback(fname=ae_filename),
        EarlyStoppingCallback(min_delta=1e-4, patience=10),
    ],
)

# Load the best weights from the directory the Learner actually saved to
# (learn.path / learn.model_dir) rather than a path relative to the CWD.
state_dict = torch.load(learn.path/learn.model_dir/f'{ae_filename}.pth')

## Train the adversarial part

In [None]:
# Stage 2: adversarial training. A fresh model is warm-started from whatever
# `state_dict` currently holds (the auto-encoder checkpoint when the notebook
# is run top to bottom).
model = stagerNetAAE(
    latent_dim=64,
    channels=x.shape[1],
    timestamps=x.shape[-1],
    acc_factor=acc_factor,
    adv_weight=0.9,
    recons_weight=0.1,
    dropout_rate=.3,
)
# strict=False tolerates missing/unexpected keys without raising, so print the
# result: otherwise a mismatch would silently leave layers at random init.
print(model.load_state_dict(state_dict, strict=False))
adv_filename = 'monitosed_pretrained_LEMON_aae_10s_EC_EO'

metrics = [
    LossAttrMetric("recons_loss"),
    LossAttrMetric("adv_loss"),
]
learn = Learner(dls, model, loss_func=model.aae_loss_func_monitosed,
                metrics=metrics, opt_func=ranger)

learn.fit_flat_cos(
    50,
    lr=2e-3,
    cbs=[
        # n_acc is counted in samples: step every acc_factor batches.
        GradientAccumulation(n_acc=dls.bs * acc_factor),
        TrackerCallback(monitor='valid_loss'),
        SaveModelCallback(fname=adv_filename, monitor='valid_loss'),
        EarlyStoppingCallback(min_delta=1e-4, patience=10, monitor='valid_loss'),
        UnfreezeFcCrit(switch_every=2),  # project-defined callback
    ],
)

# Load the best weights from the directory the Learner actually saved to
# (learn.path / learn.model_dir) rather than a path relative to the CWD.
state_dict = torch.load(learn.path/learn.model_dir/f'{adv_filename}.pth')

In [None]:
# Extract the latent codes `z` and their targets from the current
# `state_dict`, then plot them.
# `state_dict` is whatever the most recently executed training cell loaded, so
# this cell's output depends on execution order.
# NOTE(review): `save` and `device` are not defined in this section; they must
# come from an earlier cell.
save_path = ''

z, target = extract_latent(
    state_dict,
    save_path,
    save,
)
plot_results(
    z.to(device),
    target.cpu(),
    filename=save_path,
)

## Train the classifier part

In [None]:
# Stage 3: joint reconstruction / adversarial / classification training. A
# fresh model is warm-started from whatever `state_dict` currently holds (the
# adversarial-stage checkpoint when the notebook is run top to bottom).
model = stagerNetAAE(
    latent_dim=64,
    channels=x.shape[1],
    timestamps=x.shape[-1],
    acc_factor=acc_factor,
    dropout_rate=.3,
    recons_weight=.1,
    adv_weight=.4,
    classif_weight=.5,
)
# strict=False tolerates missing/unexpected keys without raising, so print the
# result: otherwise a mismatch would silently leave layers at random init.
print(model.load_state_dict(state_dict, strict=False))
classif_filename = 'monitosed_pretrained_LEMON_aae_classif_10s_EC_EO'

metrics = [
    accuracy,
    LossAttrMetric("recons_loss"),
    LossAttrMetric("adv_loss"),
    LossAttrMetric("classif_loss"),
]
learn = Learner(dls, model, loss_func=model.aae_classif_loss_func_monitosed,
                metrics=metrics, opt_func=ranger)

learn.fit_flat_cos(
    50,
    lr=1e-3,
    cbs=[
        # n_acc is counted in samples: step every acc_factor batches.
        GradientAccumulation(n_acc=dls.bs * acc_factor),
        TrackerCallback(monitor='valid_loss'),
        SaveModelCallback(fname=classif_filename, monitor='valid_loss'),
        EarlyStoppingCallback(min_delta=1e-4, patience=10, monitor='valid_loss'),
        UnfreezeFcCrit(switch_every=5),  # project-defined callback
    ],
)

# Load the best weights from the directory the Learner actually saved to
# (learn.path / learn.model_dir) rather than a path relative to the CWD.
state_dict = torch.load(learn.path/learn.model_dir/f'{classif_filename}.pth')

In [None]:
# Repeat the latent-space extraction and plot using the current `state_dict`
# (the classifier-stage checkpoint when the notebook is run top to bottom).
# NOTE(review): `save` and `device` are not defined in this section; they must
# come from an earlier cell.
save_path = ''

z, target = extract_latent(
    state_dict,
    save_path,
    save,
)
plot_results(
    z.to(device),
    target.cpu(),
    filename=save_path,
)