# Deep Learning Playground

We are interested in the following features:

- H1: region-averaged time-series
- H2: region-level connectivities (from H1, optional: triu-k1)
- H3: network-averaged time-series (from H1)
- H4: network connectivity (from H3, optional: triu-k1)
- H5: network connectivity derived from region-level connectivities (from H2, optional: triu-k0)

> Note: we do not extract the upper-triangular part of the connectivity matrices; the full matrices are used instead.

In [5]:
%reload_ext autoreload
%autoreload 2

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar

from src.acnets.deep import ACNetsDataModule, MultiHeadModel

In [6]:
# Quick single-epoch local run (no Ray) to check that data loading, the model
# and the train/test loops work end to end before launching the distributed run.

# Seed before building the datamodule so that weight init, shuffling, dropout
# and any random split performed in `setup()` are reproducible across reruns.
pl.seed_everything(42, workers=True)

datamodule = ACNetsDataModule(atlas='dosenbach2010', kind='partial correlation', batch_size=8)
datamodule.setup()

# Infer the model's input sizes from the first training sample.
# NOTE(review): assumes sample[0] is the region-level input and sample[2] the
# network-level input, with the feature dimension on axis 1 — TODO confirm.
n_regions = datamodule.train[0][0].shape[1]
n_networks = datamodule.train[0][2].shape[1]
model = MultiHeadModel(n_regions, n_networks, n_embeddings=8)

trainer = pl.Trainer(accelerator='auto', max_epochs=1,
                     log_every_n_steps=1, callbacks=[RichProgressBar()])

# Pass the datamodule by keyword (as `test` does) rather than positionally
# into `train_dataloaders`.
trainer.fit(model, datamodule=datamodule)
trainer.test(model, datamodule=datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

`Trainer.fit` stopped: `max_epochs=1` reached.


Output()

[{'test/loss_cls': 0.6940850019454956,
  'test/accuracy': 0.5,
  'test/dropout_accuracy': 0.5}]

In [28]:
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)
from ray import tune
from ray.train.torch import TorchTrainer
from ray.train import RunConfig, ScalingConfig, CheckpointConfig

def train_func(config):
    """Per-worker training loop run by Ray's ``TorchTrainer``.

    Builds the datamodule and the multi-head model from ``config``, then fits
    and tests a Lightning ``Trainer`` configured for Ray DDP.

    Args:
        config: mapping with keys ``'atlas'``, ``'kind'``, ``'batch_size'``,
            ``'n_embeddings'`` and ``'max_epochs'``. An optional ``'seed'``
            key controls seeding (defaults to 42).
    """
    # Every DDP worker runs this function and calls `datamodule.setup()` on its
    # own. Seeding identically keeps init, shuffling and any random split in
    # setup consistent across workers and across reruns.
    pl.seed_everything(config.get('seed', 42), workers=True)

    datamodule = ACNetsDataModule(atlas=config['atlas'],
                                  kind=config['kind'],
                                  batch_size=config['batch_size'])
    datamodule.setup()

    # Infer input sizes from the first training sample.
    # NOTE(review): assumes sample[0] is region-level and sample[2] is
    # network-level input, feature dimension on axis 1 — TODO confirm.
    n_regions = datamodule.train[0][0].shape[1]
    n_networks = datamodule.train[0][2].shape[1]
    model = MultiHeadModel(n_regions, n_networks, n_embeddings=config['n_embeddings'])

    trainer = pl.Trainer(
        strategy=RayDDPStrategy(),
        callbacks=[RayTrainReportCallback()],  # forwards metrics/checkpoints to Ray
        plugins=[RayLightningEnvironment()],
        enable_progress_bar=False,
        max_epochs=config['max_epochs'],
        log_every_n_steps=1
    )
    trainer = prepare_trainer(trainer)
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)

# Each `tune.choice` below currently holds a single value, so this launches a
# single trial on two CPU workers.
loop_config = {
    'batch_size': tune.choice([32]),
    'atlas': tune.choice(['dosenbach2010']),
    'kind': tune.choice(['partial correlation']),
    'n_embeddings': tune.choice([8]),
    'max_epochs': 1,
}

# Keep only the best checkpoint, ranked by validation accuracy.
checkpointing = CheckpointConfig(
    num_to_keep=1,
    checkpoint_score_attribute='val/accuracy',
    checkpoint_score_order='max',
)

ray_trainer = TorchTrainer(
    train_func,
    train_loop_config=loop_config,
    run_config=RunConfig(name='acnets_multihead',
                         storage_path='~/ray_results',
                         checkpoint_config=checkpointing),
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
)

result = ray_trainer.fit()
result

0,1
Current time:,2024-01-07 19:05:04
Running for:,00:01:39.16
Memory:,5.5/16.0 GiB

Trial name,status,loc,train_loop_config/at las,train_loop_config/ba tch_size,train_loop_config/ki nd,train_loop_config/n_ embeddings,iter,total time (s),train/loss_cls,train/loss_recon,train/loss
TorchTrainer_0d7c6_00000,TERMINATED,127.0.0.1:3290,dosenbach2010,32,partial correlation,8,1,57.8792,0.686403,0.200263,0.886666


[36m(TorchTrainer pid=3290)[0m Starting distributed worker processes: ['3298 (127.0.0.1)', '3299 (127.0.0.1)']
[36m(RayTrainWorker pid=3298)[0m Setting up process group for: env:// [rank=0, world_size=2]
[36m(RayTrainWorker pid=3298)[0m GPU available: False, used: False
[36m(RayTrainWorker pid=3298)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=3298)[0m IPU available: False, using: 0 IPUs
[36m(RayTrainWorker pid=3298)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=3299)[0m Missing logger folder: /Users/morteza/ray_results/acnets_multihead/TorchTrainer_0d7c6_00000_0_atlas=dosenbach2010,batch_size=32,kind=partial_correlation,n_embeddings=8_2024-01-07_19-03-25/lightning_logs
[36m(RayTrainWorker pid=3298)[0m 
[36m(RayTrainWorker pid=3298)[0m   | Name           | Type               | Params
[36m(RayTrainWorker pid=3298)[0m ------------------------------------------------------
[36m(RayTrainWorker pid=3298)[0m 0 | train_accuracy | M

[36m(RayTrainWorker pid=3298)[0m ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
[36m(RayTrainWorker pid=3298)[0m ┃[1m [0m[1m       Test metric       [0m[1m [0m┃[1m [0m[1m      DataLoader 0       [0m[1m [0m┃
[36m(RayTrainWorker pid=3298)[0m ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
[36m(RayTrainWorker pid=3298)[0m │[36m [0m[36m      test/accuracy      [0m[36m [0m│[35m [0m[35m          0.25           [0m[35m [0m│
[36m(RayTrainWorker pid=3298)[0m │[36m [0m[36m  test/dropout_accuracy  [0m[36m [0m│[35m [0m[35m          0.25           [0m[35m [0m│
[36m(RayTrainWorker pid=3298)[0m │[36m [0m[36m      test/loss_cls      [0m[36m [0m│[35m [0m[35m   0.7183140516281128    [0m[35m [0m│
[36m(RayTrainWorker pid=3298)[0m └───────────────────────────┴───────────────────────────┘


2024-01-07 19:05:04,530	INFO tune.py:1047 -- Total run time: 99.19 seconds (99.16 seconds for the tuning loop).


Result(
  metrics={'train/loss_cls': 0.6864034533500671, 'train/loss_recon': 0.20026294887065887, 'train/loss': 0.8866664171218872, 'train/accuracy': 0.5833333134651184, 'val/loss_cls': 0.7183140516281128, 'val/accuracy': 0.25, 'val/dropout_accuracy': 0.25, 'epoch': 0, 'step': 1},
  path='/Users/morteza/ray_results/acnets_multihead/TorchTrainer_0d7c6_00000_0_atlas=dosenbach2010,batch_size=32,kind=partial_correlation,n_embeddings=8_2024-01-07_19-03-25',
  filesystem='local',
  checkpoint=Checkpoint(filesystem=local, path=/Users/morteza/ray_results/acnets_multihead/TorchTrainer_0d7c6_00000_0_atlas=dosenbach2010,batch_size=32,kind=partial_correlation,n_embeddings=8_2024-01-07_19-03-25/checkpoint_000000)
)

In [None]:
# Candidate brain atlases (presumably for a sweep over the `atlas` config key).
atlases = [
    'dosenbach2010',
    'dosenbach2007',
    'difumo_64_2mm',
    'difumo_128_2mm',
    'difumo_256_2mm',
    'difumo_512_2mm',
    'difumo_1024_2mm',
    'gordon2014_2mm',
    'cort-maxprob-thr25-2mm',
    'seitzman2018',
    'friedman2020',
]

# Candidate connectivity estimators. The names follow nilearn's
# `ConnectivityMeasure(kind=...)` options, where the precision-matrix estimator
# is spelled 'precision' (singular), not 'precisions'.
# NOTE(review): assumes ACNetsDataModule forwards `kind` to nilearn — confirm.
kinds = [
    'correlation',
    'partial correlation',
    'tangent',
    'covariance',
    'precision',
]