In [None]:
import ecog_data
import prediction

import torch

import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

In [None]:
# Hyper-parameters for this run. The dict is passed to `wandb.init` in a later cell,
# so every value here (including the seed) is recorded with the W&B run.
config_dict = {
    'src_len': 50,
    'trg_len': 50,
    'batch_size': 1000,
    'latent_size': 10,
    'n_kernels': 10,
    'kernel_size': 9,
    'pool_size': 2,
    'dropout': 0.2,
    'learning_rate': 1e-3,
    'seed': 42,
}

# Seed the global RNGs before the data module and model are created so that
# weight initialisation and dropout are reproducible across Restart & Run All.
pl.seed_everything(config_dict['seed']);

In [None]:
# Run identity shared by the explicit W&B run and the Lightning logger.
name = 'ecog_conv_ae_test_run'
project = 'ecog-ConvAE-test'
run_ids = {'name': name, 'project': project}

# Start the W&B run with the hyper-parameters; later cells read them back via `wandb.config`.
wandb.init(config=config_dict, **run_ids)
# Lightning logger pointed at the same project/run name.
wandb_logger = WandbLogger(**run_ids)

In [None]:
# Build the data module from the values recorded in the W&B run config.
run_cfg = wandb.config
data = ecog_data.GooseWireless250(
    src_len=run_cfg.src_len,
    trg_len=run_cfg.trg_len,
    batch_size=run_cfg.batch_size,
)

In [None]:
# Input width and sequence lengths are taken from the data module; the remaining
# architecture/optimiser hyper-parameters are read from the W&B run config.
# NOTE(review): assumes `data.dims` is already populated after construction and that
# its last entry is the per-timestep feature count — confirm against GooseWireless250.
model_hparams = {
    key: getattr(wandb.config, key)
    for key in ('latent_size', 'n_kernels', 'kernel_size', 'pool_size', 'dropout', 'learning_rate')
}
model = prediction.ConvAE(
    input_size=data.dims[-1],
    src_len=data.src_len,
    trg_len=data.trg_len,
    **model_hparams,
)

In [None]:
# Keep checkpoints ranked by the lowest 'avg_valid_loss'.
# NOTE(review): assumes ConvAE logs a metric named 'avg_valid_loss' — confirm in prediction.py.
# The filename template must reference the *monitored* metric; the previous
# '{val_loss}' placeholder named a metric that is not monitored here, so the
# number in the filename would not reflect the tracked loss.
# Forward slashes keep the path valid on both Windows and POSIX.
ckpt_cb = pl.callbacks.ModelCheckpoint(
    monitor='avg_valid_loss',
    mode='min',
    dirpath='./models/ConvAE',
    filename='conv_ae-{epoch:03d}-{avg_valid_loss:.3f}',
)
trainer = pl.Trainer(
    max_epochs=100,
    logger=wandb_logger,
    # Without registering it here, the checkpoint callback above is never used.
    callbacks=[ckpt_cb],
    gpus=1,
)

In [None]:
# Run the training loop on the model using the data module built above.
trainer.fit(model, datamodule=data)