In [None]:
import ecog_data
import prediction
import torch

import wandb
import pytorch_lightning as ptl
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size

import os

In [None]:
# Hyperparameters for this run. They are handed to wandb.init below and read back
# through wandb.config by the data-module and model cells.
config_dict = dict(
    # learning-rate settings (passed to prediction.Lfads)
    lr=5e-4,
    lr_factor=0.5,
    # data settings (passed to ecog_data.GooseWireless250)
    src_len=50,
    trg_len=50,
    batch_size=1000,
    # model size settings (passed to prediction.Lfads)
    encoder_size=256,
    encoder_layers=1,
    generator_size=256,
    generator_layers=1,
    factor_size=64,
    # NOTE(review): placeholder entry -- replace with the real loss-term weights
    # that prediction.Lfads expects before relying on this run.
    loss_weight_dict=dict(ayy='lmao'),
    dropout=0.3,
)

In [None]:
# configure wandb
# Logging is currently switched off (mode='disabled').
# Lightning's WandbLogger reuses an already-active wandb run instead of starting its
# own, so the project/run name must also be given to wandb.init -- otherwise the
# values passed to WandbLogger are silently ignored.
WANDB_PROJECT = 'lfads'
WANDB_RUN_NAME = 'ah-jeez'

wandb.init(
    project=WANDB_PROJECT,
    name=WANDB_RUN_NAME,
    config=config_dict,
    mode='disabled',
)
wandb_logger = WandbLogger(name=WANDB_RUN_NAME, project=WANDB_PROJECT)

In [None]:
# configure data module
# Build the ECoG data module from the run config, then prepare and set it up
# eagerly so later cells can query it (e.g. `ldm.size()` when building the model).
cfg = wandb.config
ldm = ecog_data.GooseWireless250(cfg.src_len, cfg.trg_len, cfg.batch_size)
ldm.prepare_data()
ldm.setup()

In [None]:
# Instantiate prediction.Lfads: the input width comes from the last dimension
# reported by the data module; every other hyperparameter comes from wandb.config.
cfg = wandb.config
model = prediction.Lfads(
    # NOTE(review): LightningDataModule.size() is deprecated in newer PyTorch
    # Lightning releases -- confirm it is still supported by the pinned version.
    src_size=ldm.size()[-1],
    encoder_size=cfg.encoder_size,
    encoder_layers=cfg.encoder_layers,
    generator_size=cfg.generator_size,
    generator_layers=cfg.generator_layers,
    factor_size=cfg.factor_size,
    loss_weight_dict=cfg.loss_weight_dict,
    dropout=cfg.dropout,
    learning_rate=cfg.lr,  # config key is 'lr'; Lfads names it learning_rate
    lr_factor=cfg.lr_factor,
)

In [None]:
# Build the trainer. Use one GPU when CUDA is available and fall back to CPU
# otherwise, instead of failing at construction time on a CPU-only machine.
trainer = ptl.Trainer(
    max_epochs=100,
    logger=wandb_logger,
    gpus=1 if torch.cuda.is_available() else 0,
)
# Pass the datamodule by keyword: the second positional parameter of `tune` is the
# training dataloader, not a datamodule.
# NOTE(review): no auto_scale_batch_size / auto_lr_find flag is set on the Trainer,
# so tune() itself presumably performs no tuning here -- confirm it is still needed.
trainer.tune(model, datamodule=ldm)
# Search for a batch size starting at 1024 (not the configured 1000), for at most
# 3 trials. Keep the returned value instead of discarding it, so the result of the
# search is visible in the notebook.
tuned_batch_size = scale_batch_size(trainer, model, init_val=1024, max_trials=3)
tuned_batch_size

In [None]:
# Train on the datamodule explicitly, so this cell does not depend on a datamodule
# left attached to the model/trainer by earlier tune/scale cells.
trainer.fit(model, datamodule=ldm)

In [None]:
# Sanity-check the trained model on a single training batch.
# eval() disables dropout (config dropout=0.3) and no_grad() avoids building an
# autograd graph for this inspection-only forward pass.
model.eval()
# assumes each batch is a (src, trg) pair of tensors -- TODO confirm against ecog_data
src, trg = next(iter(ldm.train_dataloader()))
with torch.no_grad():
    # Move the batch to wherever the model ended up after training (CPU or GPU).
    pred_dict = model(src.to(model.device), trg.to(model.device))