In [None]:
import ecog_data
import prediction

import torch

import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

In [None]:
# W&B sweep definition: random search over the ConvAE hyperparameters below,
# minimizing the logged 'avg_valid_loss' metric, with Hyperband early
# termination (min_iter=10). Parameters listed with a single candidate value
# are effectively fixed for every run.
sweep_config = {
    'method': 'random',
    'metric': {'name': 'avg_valid_loss', 'goal': 'minimize'},
    'early_terminate': {'type': 'hyperband', 'min_iter': 10},
    # Every parameter uses the 'values' form: a list of candidates to sample from.
    'parameters': {
        param_name: {'values': candidates}
        for param_name, candidates in (
            ('latent_size',   [2, 4, 8, 16, 32]),
            ('n_kernels',     [2, 4, 8, 16]),
            ('kernel_size',   [3, 7, 15]),
            ('learning_rate', [1e-3, 5e-4, 1e-4, 5e-5]),
            ('src_len',       [50]),
            ('trg_len',       [50]),
            ('batch_size',    [4000]),
            ('pool_size',     [2]),
            ('dropout',       [0.1, 0.2, 0.3]),
        )
    },
}

In [None]:
# Identifiers for this sweep in the W&B UI.
name        = 'ecog_ConvAE-sweep_test'
project     = 'ecog-ConvAE-sweep'
# Register the sweep. `name` was previously defined but never used; pass it as
# the sweep's display name via the config's top-level 'name' key. The merge
# builds a new dict, so `sweep_config` itself is left unmodified.
sweep_id = wandb.sweep({**sweep_config, 'name': name}, project=project)

In [None]:
def sweep_iteration():
    """Run one W&B sweep trial: build the data module and ConvAE from the sampled config, then train.

    Intended as the ``function`` passed to ``wandb.agent``. It takes no
    arguments and reads every hyperparameter from ``wandb.config`` after
    ``wandb.init()``. Returns None; metrics are reported through the attached
    ``WandbLogger``.
    """
    # Start the W&B run first: the sampled values are read from wandb.config after init.
    wandb.init()
    wandb_logger = WandbLogger()
    config = wandb.config

    # LightningDataModule, constructed with positional (src_len, trg_len, batch_size).
    # NOTE(original author): this should be broken down into separate arguments.
    data = ecog_data.GooseWireless250(config.src_len, config.trg_len, config.batch_size)

    # Sequence geometry comes from the data module; everything else from the sweep config.
    # NOTE(review): data.dims[-1] is presumably the number of input channels — confirm.
    model_kwargs = dict(
        input_size=data.dims[-1],
        latent_size=config.latent_size,
        src_len=data.src_len,
        trg_len=data.trg_len,
        n_kernels=config.n_kernels,
        kernel_size=config.kernel_size,
        pool_size=config.pool_size,
        dropout=config.dropout,
        learning_rate=config.learning_rate,
    )
    model = prediction.ConvAE(**model_kwargs)

    # gpus=-1 requests all available GPUs; training is capped at 2000 epochs.
    trainer = pl.Trainer(logger=wandb_logger, max_epochs=2000, gpus=-1)
    trainer.fit(model, data)

In [None]:
# Maximum number of runs this agent executes. None (wandb.agent's default)
# keeps pulling runs until the sweep or process is stopped, which for a random
# search means indefinitely. Set an int for a bounded notebook session.
SWEEP_RUN_COUNT = None

# Pass `project` explicitly so the agent looks up `sweep_id` in the same
# project given to `wandb.sweep`, instead of relying on implicit environment state.
wandb.agent(sweep_id, function=sweep_iteration, project=project, count=SWEEP_RUN_COUNT)