In [1]:
import os
import torch
import lightning as L

from markov_bridges.models.generative_models.cmb import CMB
from markov_bridges.utils.experiment_files import ExperimentFiles
from markov_bridges.models.generative_models.cmb_lightning import MixedForwardMapL
from markov_bridges.configs.config_classes.generative_models.cmb_config import CMBConfig
from markov_bridges.configs.config_classes.data.basics_configs import IndependentMixConfig
from markov_bridges.configs.config_classes.trainers.trainer_config import CMBTrainerConfig

from markov_bridges.data.dataloaders_utils import get_dataloaders

In [2]:
# Model: a CMB configuration using the "drift" continuous loss type.
model_config = CMBConfig(continuous_loss_type="drift")

# Data: independent-mix configuration with the discrete-context flag enabled.
model_config.data = IndependentMixConfig(has_context_discrete=True)

# Trainer settings; number_of_epochs is reused below as the Lightning
# Trainer's max_epochs.
model_config.trainer = CMBTrainerConfig(
    number_of_epochs=10,
    scheduler="exponential",
    warm_up=1,
    clip_grad=True,
)

# Handle on the experiment's file layout (its experiment_dir is later used as
# the Lightning root directory).
# NOTE(review): `experiment_indentifier` is spelled exactly as the call is
# written in this notebook; presumably it matches the project API -- confirm
# before "fixing" the typo here.
experiment_files = ExperimentFiles(
    experiment_name="cmb",
    experiment_type="independent",
    experiment_indentifier="lightning_test",
    delete=True,
)

In [3]:
# Prepare the experiment's directories for this configuration.
# NOTE(review): presumably this also clears any previous run because the
# ExperimentFiles object was built with delete=True -- confirm in ExperimentFiles.
experiment_files.create_directories(model_config)
# Build the dataloaders described by model_config; the returned object exposes
# train_dataloader and test_dataloader, which are passed to trainer.fit below.
dataloaders = get_dataloaders(model_config)

In [4]:
# Lightning module wrapping the mixed (discrete + continuous) model.
mixed_model = MixedForwardMapL(model_config)
# Lightning writes its logs and checkpoints beneath default_root_dir, i.e. the
# experiment directory (lightning_logs/version_N/checkpoints/ in this run).
trainer = L.Trainer(default_root_dir=experiment_files.experiment_dir,
                    max_epochs=model_config.trainer.number_of_epochs)
# The third positional argument is the validation loader, so the *test*
# dataloader is being used for validation here.
trainer.fit(mixed_model, dataloaders.train_dataloader, dataloaders.test_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name               | Type             | Params | Mode 
----------------------------------------------------------------
0 | mixed_network      | MixedDeepMLP     | 241 K  | train
1 | discrete_loss_nn   | CrossEntropyLoss | 0      | train
2 | continuous_loss_nn | MSELoss          | 0      | train
----------------------------------------------------------------
241 K     Trainable params
0         Non-trainable params
241 K     Total params
0.965     Total estimated model params size (MB)
21        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\cesar\anaconda4\envs\rate_matching\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:475: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
c:\Users\cesar\anaconda4\envs\rate_matching\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
c:\Users\cesar\anaconda4\envs\rate_matching\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
c:\Users\cesar\anaconda4\envs\rate_matching\lib\site-packages\lightning\pytorch\loops\fit_loop.p

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


In [5]:
# Re-create the ExperimentFiles handle for the run trained above.
# NOTE(review): this repeats the earlier ExperimentFiles(...) call verbatim,
# including delete=True. If `delete` removes existing results when directories
# are created, reusing this object that way would wipe the checkpoint loaded
# in the next cell -- confirm the semantics or drop this duplicate cell.
experiment_files = ExperimentFiles(experiment_name="cmb",
                                    experiment_type="independent",
                                    experiment_indentifier="lightning_test",
                                    delete=True)

In [6]:
# Locate the checkpoint relative to the experiment directory rather than via a
# machine-specific absolute path. The Trainer above was created with
# default_root_dir=experiment_files.experiment_dir, and the checkpoint of this
# run was written to lightning_logs/version_0/checkpoints/ beneath it.
# NOTE(review): "version_0" and "epoch=9-step=320" are specific to this run
# (10 epochs, first version directory); adjust them if the training setup or
# the number of previous runs changes.
CKPT_PATH = os.path.join(experiment_files.experiment_dir,
                         "lightning_logs",
                         "version_0",
                         "checkpoints",
                         "epoch=9-step=320.ckpt")
# map_location="cpu" lets the checkpoint be inspected on any machine,
# regardless of the device its tensors were saved from.
checkpoint = torch.load(CKPT_PATH, map_location="cpu")

In [8]:
# Show the top-level entries stored in the loaded checkpoint dictionary.
checkpoint.keys()

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers'])