# Deep Learning Playground

We are interested in the following features:

- H1: region-averaged time-series
- H2: region-level connectivity (computed from H1; optional: triu-k1)
- H3: network-averaged time-series (computed from H1)
- H4: network connectivity computed from the network-averaged time-series (H3; optional: triu-k1)
- H5: network connectivity derived from the region-level connectivity (H2; optional: triu-k0)

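> Note: we do not take the upper-triangular part of the connectivity matrices; the full matrices are used instead, so the optional `triu-k1`/`triu-k0` reductions listed above (upper triangle with diagonal offset `k`, as in `numpy.triu`) are not applied.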

In [None]:
%reload_ext autoreload
%autoreload 2

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar

from src.acnets.deep import ACNetsDataModule, MultiHeadModel

In [None]:
# Smoke test: load one atlas/kind combination, build the multi-head model from
# the data's feature sizes, then train for a single epoch and evaluate.
datamodule = ACNetsDataModule(
    atlas='dosenbach2010',
    kind='partial correlation',
    batch_size=8,
)
datamodule.setup()

# Feature sizes come from axis 1 of elements 0 and 2 of the first training
# sample (presumably the region and network time-series, H1/H3 — confirm).
n_regions = datamodule.train[0][0].shape[1]
n_networks = datamodule.train[0][2].shape[1]
model = MultiHeadModel(n_regions, n_networks, n_embeddings=8)

# DEBUG: a single epoch, logging every step, with a rich progress bar.
trainer = pl.Trainer(
    accelerator='auto',
    max_epochs=1,
    log_every_n_steps=1,
    callbacks=[RichProgressBar()],
)

trainer.fit(model, datamodule=datamodule)
trainer.test(model, datamodule=datamodule)

In [None]:
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)


def train_func(config):
    """Build the data module and model from ``config`` and fit them with Ray's Lightning integration.

    Presumably used as the per-worker training loop of a Ray Train trainer
    (it wires in ``RayDDPStrategy``, ``RayLightningEnvironment`` and
    ``RayTrainReportCallback`` and wraps the trainer with ``prepare_trainer``).

    Args:
        config: Mapping with the keys ``'atlas'``, ``'kind'`` and
            ``'batch_size'`` (forwarded to ``ACNetsDataModule``) and
            ``'n_embeddings'`` (forwarded to ``MultiHeadModel``).
    """
    datamodule = ACNetsDataModule(atlas=config['atlas'],
                                  kind=config['kind'],
                                  batch_size=config['batch_size'])
    datamodule.setup()

    # Model input sizes are read from axis 1 of elements 0 and 2 of the first
    # training sample (presumably region / network time-series — confirm).
    n_regions = datamodule.train[0][0].shape[1]
    n_networks = datamodule.train[0][2].shape[1]
    model = MultiHeadModel(n_regions, n_networks, n_embeddings=config['n_embeddings'])

    trainer = pl.Trainer(
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        callbacks=[RayTrainReportCallback()],
        plugins=[RayLightningEnvironment()],
        enable_progress_bar=False,  # per-step progress bars are noise in distributed workers
    )
    trainer = prepare_trainer(trainer)
    # Fit on the data module constructed above in this function.
    trainer.fit(model, datamodule=datamodule)

In [None]:
# DEBUG: Testing Simple Recurrent AutoEncoder
# Encode a batch of network time-series into the encoder LSTM's final (h, c)
# state, decode over the same number of timesteps, project back to the input
# feature size, and check the reconstruction has the same shape as the input.

import torch
from torch import nn

n_timesteps = 124
n_networks = 6
batch_size = 8
hidden_size = 6

# (batch, time, features) layout, matching batch_first=True below.
x = torch.randn(batch_size, n_timesteps, n_networks)

encoder = nn.LSTM(n_networks, hidden_size, batch_first=True)
decoder = nn.LSTM(hidden_size, hidden_size, batch_first=True)
# Map decoder outputs (hidden_size features) back to n_networks features;
# without it the reconstruction only matches x when hidden_size == n_networks.
readout = nn.Linear(hidden_size, n_networks)

_, (h_enc, c_enc) = encoder(x)
# Feed the latent code (last layer's final hidden state) at every timestep
# rather than random noise, so the reconstruction is a function of x only.
x_dec = h_enc[-1].unsqueeze(1).repeat(1, n_timesteps, 1)
dec_out, _ = decoder(x_dec, (h_enc, c_enc))
x_recon = readout(dec_out)

assert x_recon.shape == x.shape, (x.shape, x_recon.shape)
x.shape, x_recon.shape

In [None]:
# Atlases to sweep over; each name is passed as the `atlas` argument of
# ACNetsDataModule (as 'dosenbach2010' is in the cells above).
atlases = [
    'dosenbach2010',
    'dosenbach2007',
    'difumo_64_2mm',
    'difumo_128_2mm',
    'difumo_256_2mm',
    'difumo_512_2mm',
    'difumo_1024_2mm',
    'gordon2014_2mm',
    'cort-maxprob-thr25-2mm',
    'seitzman2018',
    'friedman2020',
]

# Connectivity kinds to sweep over, passed as the `kind` argument of
# ACNetsDataModule. These mirror the `kind` values of nilearn's
# ConnectivityMeasure, whose valid spelling is 'precision' (singular).
# NOTE(review): assumes `kind` is forwarded to ConnectivityMeasure — confirm.
kinds = [
    'correlation',
    'partial correlation',
    'tangent',
    'covariance',
    'precision',
]