# Example Weights & Biases (W&B) hyperparameter sweep with PyTorch Lightning (ptl)

In [1]:
import prediction
import ecog_data

import wandb
import pytorch_lightning as ptl
from pytorch_lightning.loggers import WandbLogger
import torch

In [None]:
# W&B sweep definition: random search over the parameters below.
# TODO: move this definition into a *.yaml sweep file.
sweep_config = {
    "method": "random",
    "metric": {
        # The sweep tries to maximize the metric named "valid_acc".
        # NOTE(review): presumably this must match a metric the model logs — confirm.
        "name": "valid_acc",
        "goal": "maximize",
    },
    "parameters": {
        # Categorical choices: each trial picks one entry from each list.
        "n_layer_1": {"values": [32, 64, 128, 256, 512]},
        "n_layer_2": {"values": [32, 64, 128, 256, 512, 1024]},
        # "log_uniform" bounds are natural-log exponents:
        # exp(-9.21) ~= 1e-4 and exp(-4.61) ~= 1e-2.
        "lr": {"distribution": "log_uniform", "min": -9.21, "max": -4.61},
    },
    # NOTE(review): sweep_iteration reads wandb.config.src_len, .trg_len and
    # .batch_size, none of which are defined here — confirm where they come from.
}

In [None]:
# Register the sweep definition with W&B; the returned id is what the
# wandb.agent call at the end of this notebook runs against.
sweep_id = wandb.sweep(
    sweep_config,
    # NOTE(review): "MNIST" looks like a leftover from an MNIST example while
    # this notebook trains on ECoG data — confirm the intended project name.
    project="MNIST",
)

In [2]:
def sweep_iteration(defaults=None):
    """Run one W&B sweep trial: build data and model from the config, then fit.

    Starts a W&B run, creates a ``WandbLogger``, builds the
    ``GooseWireless250`` datamodule and an ``EcogPredictionModel`` from
    ``wandb.config``, and trains them with a Lightning ``Trainer``.

    Args:
        defaults: Optional mapping forwarded to ``wandb.init(config=...)``.
            Use it to supply settings the sweep does not search over. Per the
            W&B sweep docs, values chosen by the sweep take precedence over
            these. ``wandb.agent`` calls this function with no arguments, so
            the default ``None`` keeps the original behaviour. To pass values,
            bind them, e.g. ``functools.partial(sweep_iteration, defaults={...})``.

    Note:
        This function reads ``src_len``, ``trg_len`` and ``batch_size`` from
        ``wandb.config``. The ``sweep_config`` in this notebook does not
        define them, so they must be provided (e.g. via ``defaults``) or the
        config attribute lookup fails.
    """
    # Start this trial's W&B run, seeding its config with `defaults` (if any).
    wandb.init(config=defaults)
    # Created after wandb.init so the logger can use the run that is now active.
    wandb_logger = WandbLogger()
    config = wandb.config

    # LightningDataModule
    gw250 = ecog_data.GooseWireless250(
        config.src_len,
        config.trg_len,
        config.batch_size,
    )

    # model
    # NOTE(review): an earlier comment said this "should be broken down into
    # separate arguments". The datamodule above already receives separate
    # values, so that presumably referred to the model taking the whole
    # config object here — confirm.
    model = prediction.EcogPredictionModel(config)

    # trainer
    trainer = ptl.Trainer(
        logger=wandb_logger,
        max_epochs=2000,
        # NOTE(review): `gpus=` is deprecated in newer Lightning releases and
        # removed in 2.0 (use accelerator="gpu", devices=-1); kept as-is for
        # the Lightning version this notebook was written against.
        gpus=-1,
    )

    # Train on this trial's parameterization.
    trainer.fit(model, gw250)

In [None]:
# Start a W&B sweep agent for `sweep_id`, using `sweep_iteration` as the
# per-trial training function.
# NOTE(review): no `count` is passed, so (per the wandb.agent docs) the agent
# keeps launching trials until the sweep is stopped — confirm that is intended.
wandb.agent(
    sweep_id,
    function=sweep_iteration,
)