In [1]:
# Convenient jupyter setup
%load_ext autoreload
%autoreload 2

In [2]:
from dynamics.model.dms import DMSWrapper
from dynamics.model.pbmp import PBMP
from dynamics.data.datasets.greener.datamodule import GreenerDataModule
import torch
import pytorch_lightning as pl

from pytorch_lightning.callbacks import ModelCheckpoint

import hydra
from logging import log
from omegaconf import DictConfig, OmegaConf

In [3]:
hydra.initialize('../config')
hydra.main('../config')

<function hydra.main.main.<locals>.main_decorator(task_function: Callable[[Any], Any]) -> Callable[[], NoneType]>

In [4]:
config = hydra.compose("main")



In [5]:
data_module = GreenerDataModule(config.dataset.dir, config.training.batch_size, config.dataset.fraction)
data = next(iter(data_module.val_dataloader()))

In [6]:
config.model.params

{'input_dim': 24, 'hidden_dim': 128, 'num_hidden_layers': 4, 'temperature': 0.05, 'timestep': 0.05, 'n_steps': 100, 'k': 20}

In [7]:
model = PBMP(**config.model.params)

TypeError: can only concatenate str (not "int") to str

In [None]:
@hydra.main("../config/", config_name="main")
def train(config: DictConfig):
    print(OmegaConf.to_yaml(config))

    # Prepare dataloaders and model
    data_module = GreenerDataModule(config.dataset.dir, config.training.batch_size, config.dataset.fraction)
    model = DMSWrapper(config)

    # Configure Trainer
    logger = pl.loggers.WandbLogger(log_model='all', project="dynamics", config=config)
    logger.watch(model)

    checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")

    trainer = pl.Trainer(
                    logger=logger,
                    callbacks=[checkpoint_callback],
                    max_epochs=config.training.epochs,
                    log_every_n_steps=config.training.logging_freq,
                    flush_logs_every_n_steps=config.training.logging_freq,
                    val_check_interval=0.001,
                    )

    # Train
    trainer.fit(model, data_module)

train()