In [None]:
%reload_ext autoreload
%autoreload 2

from pytorch_lightning import seed_everything
from src.acnets.deep import WaveletModel, MaskedModel
from src.acnets.datasets import LEMONDataModule, Julia2018DataModule

# Sweep the embedding size of MaskedModel: for each size, pretrain on LEMON
# and then finetune on Julia2018.
#
# Settings shared by the pretraining and finetuning data modules. They are
# defined once so the two configurations cannot drift apart: the model's input
# width (n_features) is read from the finetuning data, but the same model is
# also fit on the pretraining data, so both must produce the same features.
atlas = 'dosenbach2010'
connectivity_kind = 'partial correlation'
feature_type = 'time_networks'
segment_length = 32
batch_size = 32
max_epochs = 2
random_seed = 42
embedding_sizes = [4, 8, 16, 32, 64, 128, 256]

for n_embeddings in embedding_sizes:

    # Re-seed at the start of every run so each embedding size starts from the
    # same global RNG state (presumably to keep data splits / weight init
    # comparable across runs — confirm against the data module implementations).
    seed_everything(random_seed)

    pretrain_datamodule = LEMONDataModule(
        atlas=atlas, kind=connectivity_kind,
        aggregation_strategy=feature_type,
        test_ratio=.1, val_ratio=.05,
        segment_length=segment_length,
        n_subjects=215, batch_size=batch_size, shuffle=True)
    pretrain_datamodule.setup()

    finetune_datamodule = Julia2018DataModule(
        atlas=atlas, kind=connectivity_kind,
        aggregation_strategy=feature_type,
        segment_length=segment_length,
        test_ratio=.5, batch_size=batch_size, shuffle=True)
    finetune_datamodule.setup()

    # Input width = size of the last axis of the first training item's first
    # element (presumably the feature tensor of a (features, target) pair).
    n_features = finetune_datamodule.train[0][0].shape[-1]

    model = MaskedModel(n_features, n_embeddings)

    # Two-phase training on the same model instance: pretrain, then finetune.
    trainer = model.fit(pretrain_datamodule, max_epochs=max_epochs, phase='pretrain')
    tuner = model.fit(finetune_datamodule, max_epochs=max_epochs, phase='finetune')