In [37]:
%reload_ext autoreload
%autoreload 2

from src.acnets.deep.lemon_data import LEMONDataModule
import numpy as np
import torch
from torch import nn
import torch.nn.functional  as F
import pywt
import matplotlib.pyplot as plt
import seaborn as sns

from src.acnets.deep.cvae import CVAE

In [38]:
# Quick shape check: push a handful of LEMON subjects through an untrained CVAE.
datamodule = LEMONDataModule(n_subjects=8, test_ratio=None)
datamodule.setup()

# Slice the whole training set at once; item 5 of the result is the wavelet tensor.
batch = datamodule.train[:]
x_wavelets = batch[5]  # -> shape: (subjects, wavelets, regions)

n_wavelets = 32
n_embeddings = 16
n_channels = x_wavelets.shape[2]

# Keep only the first `n_wavelets` coefficients, then swap the last two axes.
x = x_wavelets[:, :n_wavelets].transpose(1, 2)  # -> shape: (subjects, regions, wavelets)

# Single forward pass through a freshly constructed model.
cvae = CVAE(n_channels, n_embeddings)
h, x_recon, loss = cvae(x)

# Last expression of the cell: shown as the cell output.
x.shape, h.shape, x_recon.shape, loss


(torch.Size([8, 160, 32]),
 torch.Size([8, 16]),
 torch.Size([8, 160, 32]),
 tensor(3.2291, grad_fn=<AddBackward0>))

In [39]:
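# # DEBUG PLOTS (disabled)
# # NOTE(review): `ts_regions` / `wt_regions` are not defined in the cells shown here;
# # define them before re-enabling this cell.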
# plt.subplots(1, 2, figsize=(10, 3))
# plt.subplot(1, 2, 1)
# plt.title('original')
# sns.heatmap(ts_regions[0].T, cmap='viridis')

# # plot 100 wavelet coefficients of the first subject
# plt.subplot(1, 2, 2)
# plt.title('coefs')
# sns.heatmap(wt_regions[0][:100,:].T, cmap='viridis', label='coefs')
# plt.show()

In [40]:
%reload_ext autoreload
%autoreload 2

import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import TensorBoardLogger

from src.acnets.deep import MultiHeadWaveletModel, LEMONDataModule, Julia2018DataModule

# Seed the Python, NumPy and PyTorch RNGs (and dataloader workers) up front so that
# "Restart & Run All" repeats the same run instead of drawing a new random one.
SEED = 42
pl.seed_everything(SEED, workers=True)

# Where both training stages write their TensorBoard logs.
LOG_DIR = 'lightning_logs'
RUN_NAME = 'mh_wvt'


def _make_trainer(max_epochs, accumulate_grad_batches, version=None):
    """Build a Trainer with the settings shared by pre-training and fine-tuning.

    Args:
        max_epochs: Epoch budget for this stage.
        accumulate_grad_batches: Number of batches to accumulate gradients over.
        version: TensorBoard run version; ``None`` lets the logger pick one.

    Returns:
        A ``pl.Trainer`` that logs to ``LOG_DIR/RUN_NAME`` on every step.
    """
    return pl.Trainer(
        accelerator='auto',
        max_epochs=max_epochs,
        accumulate_grad_batches=accumulate_grad_batches,
        # gradient_clip_val=.5,
        logger=TensorBoardLogger(LOG_DIR, name=RUN_NAME, version=version),
        log_every_n_steps=1,
        callbacks=[RichProgressBar()])


# Pre-training data. setup() is called by hand because `.train` is indexed below,
# before any Trainer has had the chance to set the datamodule up.
lemon_datamodule = LEMONDataModule(
    atlas='dosenbach2010', kind='partial correlation',
    n_subjects=201, batch_size=32)
lemon_datamodule.setup()

# Fine-tuning / test data; its setup is left to the Trainer.
julia2018_datamodule = Julia2018DataModule(
    atlas='dosenbach2010', kind='partial correlation',
    test_ratio=.5, batch_size=8)

# Region count is read from axis 1 of the first item of the first training sample.
# NOTE(review): presumably that item is (timepoints, regions) — confirm against the datamodule.
n_regions = lemon_datamodule.train[0][0].shape[1]
n_embeddings = 64

model = MultiHeadWaveletModel(n_regions, n_embeddings=n_embeddings)

# pre-train
model.disable_finetune()
trainer = _make_trainer(max_epochs=200, accumulate_grad_batches=5)
trainer.fit(model, datamodule=lemon_datamodule)
run_version = f'version_{trainer.logger.version}'

# fine-tune (logged next to the pre-training run, with an `_ft` suffix on the version)
model.enable_finetune()
tuner = _make_trainer(max_epochs=1000, accumulate_grad_batches=2,
                      version=f'{run_version}_ft')
tuner.fit(model, datamodule=julia2018_datamodule)

# test
# NOTE(review): the last recorded run reported loss_cls/test = 0.6931 (~ln 2) and
# accuracy/test = 0.5, which is what a two-class classifier emitting uniform
# probabilities would score. Verify that the classification head actually trains
# once fine-tuning is enabled.
tuner.test(model, datamodule=julia2018_datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

`Trainer.fit` stopped: `max_epochs=200` reached.


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

`Trainer.fit` stopped: `max_epochs=1000` reached.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

[{'loss_recon/test': 0.5521789193153381,
  'loss_cls/test': 0.6931471824645996,
  'accuracy/test': 0.5}]