In [None]:
%reload_ext autoreload
%autoreload 2

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar

from src.acnets.deep import ACNetsDataModule, MultiHeadModel
from src.acnets.pipeline import Parcellation

from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)
from ray import tune
from ray.train.torch import TorchTrainer
from ray.train import RunConfig, ScalingConfig, CheckpointConfig


In [None]:
# Candidate parcellation atlases (identifiers as used for `atlas=` elsewhere
# in this notebook; only one of them is exercised by the run below).
atlases = [
    # Dosenbach ROI sets
    'dosenbach2010',
    'dosenbach2007',
    # DiFuMo family: difumo_<n>_2mm for n = 64 ... 1024, in increasing order
    *(f'difumo_{n_dims}_2mm' for n_dims in (64, 128, 256, 512, 1024)),
    # remaining atlases
    'gordon2014_2mm',
    'cort-maxprob-thr25-2mm',
    'seitzman2018',
    'friedman2020',
]

# Candidate connectivity estimators (presumably values for
# ACNetsDataModule(kind=...) — TODO confirm).
kinds = [
    'correlation', 'partial correlation', 'tangent',
    # NOTE(review): nilearn's ConnectivityMeasure spells this kind
    # 'precision'; confirm which spelling the ACNets pipeline expects.
    'covariance', 'precisions',
]

In [None]:

def train_func(config):
    """Train and evaluate one MultiHeadModel on a Ray Train worker.

    Intended as the per-worker loop of ``TorchTrainer``. It builds the ACNets
    data module, sizes the model from the first training item, then fits and
    tests it with a Lightning ``Trainer`` that is wired into Ray (DDP
    strategy, Ray cluster environment, and the Ray report callback).

    Args:
        config: mapping with the keys ``'atlas'``, ``'kind'``,
            ``'batch_size'``, ``'n_embeddings'`` and ``'max_epochs'``.

    Returns:
        None. Logged metrics/checkpoints are handed to Ray via
        ``RayTrainReportCallback``.
    """
    dm = ACNetsDataModule(
        atlas=config['atlas'],
        kind=config['kind'],
        batch_size=config['batch_size'],
    )
    dm.setup()

    # Model dimensions are read from axis 1 of elements 0 and 2 of the first
    # training item (presumably regions and networks — TODO confirm).
    n_regions, n_networks = (dm.train[0][pos].shape[1] for pos in (0, 2))
    model = MultiHeadModel(n_regions, n_networks,
                           n_embeddings=config['n_embeddings'])

    # Ray-specific wiring: DDP across workers, Ray cluster environment, and a
    # callback that forwards what Lightning logs to Ray Train.
    trainer_kwargs = dict(
        strategy=RayDDPStrategy(),
        callbacks=[RayTrainReportCallback()],
        plugins=[RayLightningEnvironment()],
        enable_progress_bar=False,
        max_epochs=config['max_epochs'],
        log_every_n_steps=1,
    )
    trainer = prepare_trainer(pl.Trainer(**trainer_kwargs))

    trainer.fit(model, datamodule=dm)
    trainer.test(model, datamodule=dm)

# Launch one fixed-configuration distributed run (2 CPU workers).
#
# Search-space primitives such as tune.choice() are meant to be sampled by a
# Tuner (via its `param_space`); a bare TorchTrainer.fit() is not the
# documented place for them, and single-element choices add nothing anyway.
# The fixed values are therefore passed as plain literals. To sweep over the
# `atlases` / `kinds` candidates, wrap the trainer instead, e.g.
#   tune.Tuner(trainer, param_space={'train_loop_config': {
#       'atlas': tune.grid_search(atlases), ...}}).fit()
train_loop_config = {
    'batch_size': 32,
    'atlas': 'dosenbach2010',
    'kind': 'partial correlation',
    'n_embeddings': 8,
    'max_epochs': 1,
}

result = TorchTrainer(
    train_func,
    train_loop_config=train_loop_config,
    run_config=RunConfig(
        name='acnets_multihead',
        storage_path='~/ray_results',
        # Keep only the best checkpoint by 'val/accuracy' (assumes the model
        # logs this exact key — TODO confirm in MultiHeadModel).
        checkpoint_config=CheckpointConfig(
            num_to_keep=1,
            checkpoint_score_attribute='val/accuracy',
            checkpoint_score_order='max',
        ),
    ),
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
).fit()

# Display the Ray Train Result as the cell output (also kept in `result`).
result