In [1]:
%load_ext autoreload
%autoreload 2

import os
import shutil

import omegaconf
import hydra
import numpy as np
import matplotlib.pyplot as plt
import swyft.lightning as sl
import torch
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# plt.switch_backend("agg")
plt.rcParams['figure.facecolor'] = 'white'

In [2]:
cfg = omegaconf.OmegaConf.load("config_blobs.yaml")
from udens.plot import plt_imshow
imkwargs = dict(extent=(-2.5, 2.5, -2.5, 2.5), origin='lower') #left, right, bottom, top
from tqdm.notebook import tqdm as tqdm
import matplotlib.colors

In [3]:
def simulate(cfg):
    # Loading simulator (potentially bounded)
    simulator = hydra.utils.instantiate(cfg.simulation.model)

    # Generate or load training data & generate datamodule
    train_samples = sl.file_cache(
        lambda: simulator.sample(cfg.simulation.store.store_size),
        cfg.simulation.store.path,
    )[: cfg.simulation.store.train_size]
    datamodule = sl.SwyftDataModule(
        store=train_samples,
        model=simulator,  # Adds noise on the fly. `None` uses noise in store.
        batch_size=cfg.estimation.batch_size,
        num_workers=cfg.estimation.num_workers,
    )

    return datamodule, simulator
datamodule, simulator = simulate(cfg)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [4]:
def load(cfg, simulator):
    print('Loading trained network')
    tbl = pl_loggers.TensorBoardLogger(
        save_dir=cfg.tensorboard.save_dir,
        name=cfg.tensorboard.name,
        version=cfg.tensorboard.version,
        default_hp_metric=False,
    )
    logdir = (
        tbl.experiment.get_logdir()
    )  # Directory where all logging information and checkpoints etc are stored
    
    checkpoints = os.listdir( os.path.join(logdir, 'checkpoint') )
    if 'best.ckpt' in checkpoints:
        best_ckpt = 'best.ckpt'
    else:
        best_idx = np.argmax(list(map(int, [checkpoint[6:8] for checkpoint in checkpoints])))
        best_ckpt = checkpoints[best_idx]
    print(f'best checkpoint is {best_ckpt}')
    
    checkpoint = torch.load(
        os.path.join(logdir, f'checkpoint/{best_ckpt}'), map_location='cpu'
    )

    network = hydra.utils.instantiate(cfg.estimation.network, cfg)
    network.load_state_dict(checkpoint["state_dict"])

    train_samples = torch.load(cfg.simulation.store.path)
    
    trainer = sl.SwyftTrainer(accelerator=cfg.estimation.accelerator, gpus=1)
    trainer.setup(None)
    
    datamodule = sl.SwyftDataModule(store=train_samples, model=simulator)
    datamodule.setup(None)
    
    trainer.model = network
    
    return network, trainer, tbl, datamodule

def analyse(cfg, datamodule):
    # Setting up tensorboard logger, which defines also logdir (contains trained network)
    tbl = pl_loggers.TensorBoardLogger(
        save_dir=cfg.tensorboard.save_dir,
        name=cfg.tensorboard.name,
        version=cfg.tensorboard.version,
        default_hp_metric=False,
    )
    logdir = (
        tbl.experiment.get_logdir()
    )  # Directory where all logging information and checkpoints etc are stored

    # Load network and train (or re-load trained network)
    network = hydra.utils.instantiate(cfg.estimation.network, cfg)
    lr_monitor = LearningRateMonitor(logging_interval="step")
    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=cfg.estimation.early_stopping.min_delta,
        patience=cfg.estimation.early_stopping.patience,
        verbose=False,
        mode="min",
    )
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=logdir + "/checkpoint/",
        filename="{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    )
    trainer = sl.SwyftTrainer(
        accelerator=cfg.estimation.accelerator,
        gpus=1,
        max_epochs=cfg.estimation.max_epochs,
        logger=tbl,
        callbacks=[lr_monitor, early_stop_callback, checkpoint_callback],
    )
    best_checkpoint = logdir + "/checkpoint/best.ckpt"
    if not os.path.isfile(best_checkpoint):
        trainer.fit(network, datamodule)
        shutil.copy(checkpoint_callback.best_model_path, best_checkpoint)
        trainer.test(network, datamodule)
    else:
        trainer.fit(network, datamodule, ckpt_path=best_checkpoint)

    return network, trainer, tbl

# network, trainer, tbl, datamodule = load(cfg, simulator)
network, trainer, tbl = analyse(cfg, datamodule)

Prior,    M_frac    in subhalo log10 mass range
3.12e-04, 5.00e-01:    [9.000 - 9.500]
3.12e-04, 5.00e-01:    [9.500 - 10.000]


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
Restoring states from the checkpoint path at ./lightning_logs_blobs/uniform_noise0.0_sub1-1_m9.0-10.0_pix40_msc2_sim10000/version_0/checkpoint/best.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type                         | Params
----------------------------------------------------------------
0 | online_z_score | OnlineDictStandardizingLayer | 0     
1 | classifier     | RatioEstimatorUNet           | 31.0 M
----------------------------------------------------------------
31.0 M    Trainable params
0         Non-trainable params
31.0 M    Total params
124.147   Total estimated model params size (MB)
Restored all states from the checkpoint file at ./lightning_logs_blobs/uniform_noise0.0_sub1-1_m9.0-10.0_pix40_msc2_sim10000/version_0/checkpoint

In [5]:
def interpret(cfg, simulator, network, trainer, datamodule, tbl):
    hydra.utils.call(
        cfg.inference.interpreter, cfg, simulator, network, trainer, datamodule, tbl
    )
interpret(cfg, simulator, network, trainer, datamodule, tbl)

Prior,    M_frac    in subhalo log10 mass range
3.12e-04, 5.00e-01:    [9.000 - 9.500]
3.12e-04, 5.00e-01:    [9.500 - 10.000]


Calculating posteriors: 100%|██████████| 63/63 [00:25<00:00,  2.51it/s]
Calculating reliability curve: 100%|██████████| 63/63 [00:00<00:00, 2759.44it/s]
Calculating reliability curve: 100%|██████████| 63/63 [00:00<00:00, 2728.92it/s]
Calculating reliability curve: 100%|██████████| 21/21 [00:00<00:00, 1313.79it/s]
Calculating reliability curve: 100%|██████████| 21/21 [00:00<00:00, 1346.51it/s]
Calculating reliability curve: 100%|██████████| 63/63 [00:00<00:00, 2815.75it/s]
Calculating reliability curve: 100%|██████████| 21/21 [00:00<00:00, 1293.07it/s]
Calculating reliability curve: 100%|██████████| 21/21 [00:00<00:00, 1306.77it/s]
Calculating reliability curve: 100%|██████████| 63/63 [00:00<00:00, 2810.18it/s]
100%|██████████| 1/1 [00:00<00:00, 493.56it/s]
100%|██████████| 1/1 [00:00<00:00, 2432.89it/s]


logdir: ./lightning_logs_blobs/uniform_noise0.0_sub1-1_m9.0-10.0_pix40_msc2_sim10000/version_0


In [6]:
# @hydra.main(config_path=".", config_name="config")
# def main(cfg):
#     print_dict(cfg)
#     datamodule, simulator = simulate(cfg)
    
#     if cfg.load:
#         network, trainer, tbl, datamodule = load(cfg, simulator)
#     else:
#         network, trainer, tbl = analyse(cfg, datamodule)
#     interpret(cfg, simulator, network, trainer, datamodule, tbl)


# if __name__ == "__main__":
#     main()

In [7]:
# Make the scripts stop here
assert 1 == 2

AssertionError: 

# Interpret

In [None]:
import os
import numpy as np
import pylab as plt
import torch
import swyft.lightning as sl

from udens.interpret import IsotonicRegressionCalibration
from udens.logging_utils_subN import LogIRC, LogPost, LogObs, LogBounds, LogSingleSub
from udens.inference import Infer, Prior

In [None]:
logdir = tbl.experiment.get_logdir()

# Calculate expected n_sub
Ms = datamodule.predict_dataloader().dataset[:]['z_sub'][:,:,0]
n_sub_expect = cfg.simulation.model.nsub_expect #torch.mean(torch.tensor(np.count_nonzero(Ms.numpy(), axis  = 1), dtype = torch.float32))

# Loading the inference class and 
infer = Infer(simulator, network, datamodule, n_sub_expect)

# Prior information necessary for loggers
prior, prior_grid = infer.calc_prior()[0], infer.prior_grid()
grid_coords = infer.get_grid_coords()
grid_low, grid_high = infer.grid_low, infer.grid_high
    

In [None]:
# Simulations inference
posts_norm, posts_unnorm, targets = infer.get_posts(datamodule.predict_dataloader(), cfg.inference.n_infer)
torch.save(posts_norm, os.path.join(logdir, 'posts_norm.pt'))
torch.save(posts_unnorm, os.path.join(logdir, 'posts_unnorm.pt'))
torch.save(targets, os.path.join(logdir,'targets.pt'))
# posts_norm = torch.load(os.path.join(logdir, 'posts_norm.pt'))
# posts_unnorm = torch.load(os.path.join(logdir, 'posts_unnorm.pt'))
# targets       = torch.load(os.path.join(logdir, 'targets.pt'))

In [None]:
# Calibration
irc_norm = IsotonicRegressionCalibration(posts_norm, targets)    
posts_norm_calib = irc_norm.calibrate(posts_norm)
torch.save(posts_norm_calib, os.path.join(logdir, 'posts_norm_calib.pt'))

irc_unnorm = IsotonicRegressionCalibration(posts_unnorm, targets)    
posts_unnorm_calib = irc_unnorm.calibrate(posts_unnorm)
torch.save(posts_unnorm_calib, os.path.join(logdir, 'posts_unnorm_calib.pt'))

In [None]:
# Log simulation inference
LogPost(tbl, posts_norm, targets, title = 'norm_uncalib').plot_all()
LogPost(tbl, posts_norm_calib, targets, title = 'norm_calib').plot_all()
LogIRC(tbl, irc_norm, title = 'norm_calibration').plot()

LogPost(tbl, posts_unnorm, targets, title = 'unnorm_uncalib').plot_all()
LogPost(tbl, posts_unnorm_calib, targets, title = 'unnorm_calib').plot_all()
LogIRC(tbl, irc_unnorm, title = 'unnorm_calibration').plot()

test_sim = simulator.sample(1)


test_post_uncalib = infer.get_post(test_sim)[0] #[0] refers to normalized case
test_sim = infer.squeeze_obs(test_sim)
test_post = irc_norm.calibrate(test_post_uncalib.squeeze(0))

log_obs = LogObs(tbl, test_sim, test_post, prior, grid_coords)
log_obs.plot_obs()

In [None]:
sim_id = ''
fig = lavalamp(test_post, test_sim, grid_coords, grid_low, grid_high)
fig.write_html(os.path.join(logdir, f'lavalamp{sim_id}.html'))

In [None]:
tbl.experiment.flush()
print("logdir:", tbl.experiment.get_logdir())