In [1]:
%load_ext autoreload
%autoreload 2

import os
import shutil

import omegaconf
import hydra
import numpy as np
import matplotlib.pyplot as plt
import swyft.lightning as sl
import torch
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# Uncomment to switch matplotlib to the "agg" backend.
# plt.switch_backend("agg")
# Use a white face colour for figures created from here on.
plt.rcParams['figure.facecolor'] = 'white'

In [7]:
# Additional plotting / progress-bar imports for this notebook.
import matplotlib.colors
from tqdm.notebook import tqdm
from udens.plot import plt_imshow

# Load the experiment configuration from config_blobs.yaml.
cfg = omegaconf.OmegaConf.load("config_blobs.yaml")

# Shared imshow keyword arguments; `extent` is (left, right, bottom, top).
imkwargs = {"extent": (-2.5, 2.5, -2.5, 2.5), "origin": "lower"}

In [3]:
def simulate(cfg):
    """Instantiate the simulator and build the training datamodule.

    Reads ``cfg.simulation.model`` (instantiated through hydra),
    ``cfg.simulation.store.store_size`` / ``.path`` / ``.train_size`` and
    ``cfg.estimation.batch_size`` / ``.num_workers``.

    Returns:
        tuple: ``(datamodule, simulator)``.
    """
    # The simulator may be bounded, depending on the configuration.
    simulator = hydra.utils.instantiate(cfg.simulation.model)

    def draw_store_samples():
        # Draw `store_size` samples from the simulator.
        return simulator.sample(cfg.simulation.store.store_size)

    # `sl.file_cache` receives the sampling callback and the store path;
    # presumably it only invokes the callback when nothing is cached at that
    # path yet -- TODO confirm against swyft.
    stored_samples = sl.file_cache(draw_store_samples, cfg.simulation.store.path)

    # Only the first `train_size` stored samples are handed to the datamodule.
    train_samples = stored_samples[: cfg.simulation.store.train_size]

    datamodule = sl.SwyftDataModule(
        store=train_samples,
        simulator=simulator,  # Adds noise on the fly. `None` uses noise in store.
        batch_size=cfg.estimation.batch_size,
        num_workers=cfg.estimation.num_workers,
    )
    return datamodule, simulator
# Build the datamodule and simulator from the loaded configuration.
datamodule, simulator = simulate(cfg)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [4]:
def _checkpoint_epoch(filename):
    """Return the epoch encoded in an ``epoch=<N>...`` file name, else ``None``."""
    prefix = "epoch="
    if not filename.startswith(prefix):
        return None
    digits = ""
    for char in filename[len(prefix):]:
        if char not in "0123456789":
            break
        digits += char
    return int(digits) if digits else None


def _latest_epoch_checkpoint(filenames):
    """Return the ``epoch=<N>...`` file name with the highest epoch number.

    File names that do not follow the ``epoch=<N>`` pattern are ignored.
    On ties the first matching name wins.

    Raises:
        ValueError: if no file name encodes an epoch.
    """
    chosen_name, chosen_epoch = None, -1
    for name in filenames:
        epoch = _checkpoint_epoch(name)
        if epoch is not None and epoch > chosen_epoch:
            chosen_name, chosen_epoch = name, epoch
    if chosen_name is None:
        raise ValueError(
            f"no 'best.ckpt' or 'epoch=<N>...' checkpoint found among {list(filenames)!r}"
        )
    return chosen_name


def load(cfg, simulator):
    """Re-load a trained network from the checkpoint directory of the log dir.

    The log directory is derived from a ``TensorBoardLogger`` configured with
    ``cfg.tensorboard.save_dir`` / ``.name`` / ``.version``; checkpoints are
    looked up in its ``checkpoint`` sub-directory.  ``best.ckpt`` is used when
    present, otherwise the ``epoch=<N>...`` file with the highest epoch.

    Args:
        cfg: Configuration; besides ``cfg.tensorboard`` this reads
            ``cfg.estimation.network``, ``cfg.estimation.accelerator`` and
            ``cfg.simulation.store.path``.
        simulator: Simulator handed to ``sl.SwyftDataModule``.

    Returns:
        tuple: ``(network, trainer, tbl, datamodule)``.

    Raises:
        ValueError: if the checkpoint directory contains neither ``best.ckpt``
            nor any ``epoch=<N>...`` file.
    """
    print('Loading trained network')
    tbl = pl_loggers.TensorBoardLogger(
        save_dir=cfg.tensorboard.save_dir,
        name=cfg.tensorboard.name,
        version=cfg.tensorboard.version,
        default_hp_metric=False,
    )
    # Directory where all logging information and checkpoints etc are stored
    logdir = tbl.experiment.get_logdir()
    checkpoint_dir = os.path.join(logdir, 'checkpoint')

    checkpoints = os.listdir(checkpoint_dir)
    if 'best.ckpt' in checkpoints:
        best_ckpt = 'best.ckpt'
    else:
        # Fallback: take the most recent epoch.  The epoch is parsed from the
        # whole number after "epoch=" (a fixed two-character slice would
        # mis-read epochs >= 100 and fail on unrelated files).
        # NOTE(review): the highest epoch is not necessarily the checkpoint
        # with the lowest val_loss -- confirm this is the intended fallback.
        best_ckpt = _latest_epoch_checkpoint(checkpoints)
    print(f'best checkpoint is {best_ckpt}')

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(
        os.path.join(checkpoint_dir, best_ckpt), map_location='cpu'
    )

    network = hydra.utils.instantiate(cfg.estimation.network, cfg)
    network.load_state_dict(checkpoint["state_dict"])

    train_samples = torch.load(cfg.simulation.store.path)

    trainer = sl.SwyftTrainer(accelerator=cfg.estimation.accelerator, gpus=1)
    trainer.setup(None)

    # Same keyword as in `simulate`, which passes the simulator as `simulator=`.
    datamodule = sl.SwyftDataModule(store=train_samples, simulator=simulator)
    datamodule.setup(None)

    trainer.model = network

    return network, trainer, tbl, datamodule

def analyse(cfg, datamodule):
    """Train the network, or resume fitting from an existing ``best.ckpt``.

    A ``TensorBoardLogger`` built from ``cfg.tensorboard`` defines the log
    directory; checkpoints live in its ``checkpoint/`` sub-directory.

    * If ``checkpoint/best.ckpt`` does not exist: fit the network, copy the
      checkpoint callback's ``best_model_path`` to ``best.ckpt`` and run
      ``trainer.test``.
    * Otherwise: call ``trainer.fit`` with ``ckpt_path`` set to ``best.ckpt``.
      NOTE(review): this presumably resumes training rather than only
      re-loading weights -- confirm this is intended.

    Args:
        cfg: Configuration; reads ``cfg.tensorboard`` and ``cfg.estimation``
            (``network``, ``accelerator``, ``max_epochs``, ``early_stopping``).
        datamodule: Datamodule passed to ``trainer.fit`` / ``trainer.test``.

    Returns:
        tuple: ``(network, trainer, tbl)``.
    """
    tb_cfg = cfg.tensorboard
    tbl = pl_loggers.TensorBoardLogger(
        save_dir=tb_cfg.save_dir,
        name=tb_cfg.name,
        version=tb_cfg.version,
        default_hp_metric=False,
    )
    # All logging output and checkpoints are stored below this directory.
    logdir = tbl.experiment.get_logdir()
    checkpoint_dir = f"{logdir}/checkpoint/"

    est_cfg = cfg.estimation
    network = hydra.utils.instantiate(est_cfg.network, cfg)

    # Callbacks: LR logging per step, early stopping and top-3 checkpointing,
    # the latter two both monitoring the validation loss.
    lr_monitor = LearningRateMonitor(logging_interval="step")
    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=est_cfg.early_stopping.min_delta,
        patience=est_cfg.early_stopping.patience,
        verbose=False,
        mode="min",
    )
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=checkpoint_dir,
        filename="{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    )

    trainer = sl.SwyftTrainer(
        accelerator=est_cfg.accelerator,
        gpus=1,
        max_epochs=est_cfg.max_epochs,
        logger=tbl,
        callbacks=[lr_monitor, early_stop_callback, checkpoint_callback],
    )

    best_checkpoint = f"{checkpoint_dir}best.ckpt"
    if os.path.isfile(best_checkpoint):
        # A previous run already produced best.ckpt: fit starting from it.
        trainer.fit(network, datamodule, ckpt_path=best_checkpoint)
        return network, trainer, tbl

    # Fresh run: train, keep a stable copy of the best checkpoint, then test.
    trainer.fit(network, datamodule)
    shutil.copy(checkpoint_callback.best_model_path, best_checkpoint)
    trainer.test(network, datamodule)
    return network, trainer, tbl

# Train (or resume from an existing best.ckpt) via `analyse`.  The commented
# `load(...)` call is the alternative that re-loads a trained network instead.
# network, trainer, tbl, datamodule = load(cfg, simulator)
network, trainer, tbl = analyse(cfg, datamodule)

Prior,    M_frac    in subhalo log10 mass range
3.12e-04, 5.00e-01:    [9.000 - 9.500]
3.12e-04, 5.00e-01:    [9.500 - 10.000]


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type                         | Params
----------------------------------------------------------------
0 | online_z_score | OnlineDictStandardizingLayer | 0     
1 | classifier     | RatioEstimatorUNet           | 31.0 M
----------------------------------------------------------------
31.0 M    Trainable params
0         Non-trainable params
31.0 M    Total params
124.147   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        hp/JS-div            4434.78076171875
        hp/KL-div           -12.211993217468262
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [8]:
def interpret(cfg, simulator, network, trainer, datamodule, tbl):
    """Run the interpreter configured under ``cfg.inference.interpreter``.

    The target is invoked through ``hydra.utils.call`` and receives, in this
    order, ``cfg``, ``simulator``, ``network``, ``trainer``, ``datamodule``
    and ``tbl`` as positional arguments.  Nothing is returned.
    """
    interpreter_cfg = cfg.inference.interpreter
    positional_args = (cfg, simulator, network, trainer, datamodule, tbl)
    hydra.utils.call(interpreter_cfg, *positional_args)
# Run the configured interpreter on the trained network.
interpret(cfg, simulator, network, trainer, datamodule, tbl)

Prior,    M_frac    in subhalo log10 mass range
3.12e-04, 5.00e-01:    [9.000 - 9.500]
3.12e-04, 5.00e-01:    [9.500 - 10.000]


Calculating posteriors: 100%|██████████| 63/63 [00:24<00:00,  2.58it/s]
Calculating reliability curve: 100%|██████████| 63/63 [00:00<00:00, 2251.93it/s]
Calculating reliability curve: 100%|██████████| 21/21 [00:00<00:00, 1364.47it/s]
Calculating reliability curve: 100%|██████████| 21/21 [00:00<00:00, 1325.61it/s]
Calculating reliability curve: 100%|██████████| 63/63 [00:00<00:00, 2792.27it/s]
100%|██████████| 1/1 [00:00<00:00, 714.90it/s]
100%|██████████| 1/1 [00:00<00:00, 1909.11it/s]


logdir: ./lightning_logs_blobs/uniform_noise0.0_sub1-1_m9.0-10.0_pix40_msc2_sim10000/version_2


In [6]:
# Script-style entry point, kept for reference only (disabled in this notebook).
# NOTE(review): `print_dict` is not defined in the visible cells -- confirm
# where it comes from before re-enabling this.
# @hydra.main(config_path=".", config_name="config")
# def main(cfg):
#     print_dict(cfg)
#     datamodule, simulator = simulate(cfg)
    
#     if cfg.load:
#         network, trainer, tbl, datamodule = load(cfg, simulator)
#     else:
#         network, trainer, tbl = analyse(cfg, datamodule)
#     interpret(cfg, simulator, network, trainer, datamodule, tbl)


# if __name__ == "__main__":
#     main()