In [11]:
%load_ext autoreload
%autoreload 2

import os
import shutil

import omegaconf
import hydra
import numpy as np
import pylab as plt
import swyft.lightning as sl
import torch
from lensx.logging_utils import log_post_plots, log_target_plots, log_train_plots
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from lensx.nn.subN.utils import print_dict

# plt.switch_backend("agg")
plt.rcParams['figure.facecolor'] = 'white'


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
cfg = omegaconf.OmegaConf.load("config_uniform.yaml")
from lensx.nn.subN.plot import plt_imshow
imkwargs = dict(extent=(-2.5, 2.5, -2.5, 2.5), origin='lower') #left, right, bottom, top
from tqdm.notebook import tqdm as tqdm
import matplotlib.colors


In [13]:
def check_obs(cfg):
    """Report whether the mock observation file can be opened.

    Attempts ``torch.load`` on ``cfg.inference.obs_path``. The loaded object
    is discarded; the call only serves as an existence check. When the file
    is missing, a notice is printed instead of raising.

    Parameters
    ----------
    cfg : object
        Configuration exposing ``cfg.inference.obs_path``.

    Returns
    -------
    None
    """
    try:
        torch.load(cfg.inference.obs_path)
    except FileNotFoundError:
        # Only a missing file is tolerated; any other load error propagates.
        print('No mock generated!')
        
check_obs(cfg)

No mock generated!


In [14]:
def simulate(cfg):
    """Instantiate the simulator and wrap its training samples in a datamodule.

    Parameters
    ----------
    cfg : object
        Configuration providing ``simulation.model``, ``simulation.store``
        (``store_size``, ``path``, ``train_size``) and ``estimation``
        (``batch_size``, ``num_workers``).

    Returns
    -------
    tuple
        ``(datamodule, simulator)``.
    """
    # Simulator (potentially bounded) built from the config.
    simulator = hydra.utils.instantiate(cfg.simulation.model)

    store_cfg = cfg.simulation.store

    def _draw_store():
        # Evaluated lazily by `sl.file_cache`; presumably only when nothing is
        # cached at `store_cfg.path` yet (confirm against swyft).
        return simulator.sample(store_cfg.store_size)

    # Generate or load the sample store, then keep the first `train_size` entries.
    cached_store = sl.file_cache(_draw_store, store_cfg.path)
    train_samples = cached_store[: store_cfg.train_size]

    estimation_cfg = cfg.estimation
    datamodule = sl.SwyftDataModule(
        store=train_samples,
        model=simulator,  # Adds noise on the fly. `None` uses noise in store.
        batch_size=estimation_cfg.batch_size,
        num_workers=estimation_cfg.num_workers,
    )

    return datamodule, simulator
datamodule, simulator = simulate(cfg)

In [15]:
# # Setting up tensorboard logger, which defines also logdir (contains trained network)
# tbl = pl_loggers.TensorBoardLogger(
#     save_dir=cfg.tensorboard.save_dir,
#     name=cfg.tensorboard.name,
#     version=cfg.tensorboard.version,
#     default_hp_metric=False,
# )
# logdir = (
#     tbl.experiment.get_logdir()
# )  # Directory where all logging information and checkpoints etc are stored

# # Load network and train (or re-load trained network)
# network = hydra.utils.instantiate(cfg.estimation.network, cfg)
# #     network = ImgSegmNetwork(cfg, 1)

# lr_monitor = LearningRateMonitor(logging_interval="step")
# early_stop_callback = EarlyStopping(
#     monitor="val_loss",
#     min_delta=cfg.estimation.early_stopping.min_delta,
#     patience=cfg.estimation.early_stopping.patience,
#     verbose=False,
#     mode="min",
# )
# checkpoint_callback = ModelCheckpoint(
#     monitor="val_loss",
#     dirpath=logdir + "/checkpoint/",
#     filename="{epoch:02d}-{val_loss:.2f}",
#     save_top_k=3,
#     mode="min",
# )
# trainer = sl.SwyftTrainer(
#     accelerator=cfg.estimation.accelerator,
#     gpus=1,
#     max_epochs=cfg.estimation.max_epochs,
#     logger=tbl,
#     callbacks=[lr_monitor, early_stop_callback, checkpoint_callback],
# )
# best_checkpoint = logdir + "/checkpoint/best.ckpt"
# if not os.path.isfile(best_checkpoint):
#     trainer.fit(network, datamodule)
#     shutil.copy(checkpoint_callback.best_model_path, best_checkpoint)
#     trainer.test(network, datamodule)
# else:
#     print('realoding network?')
#     trainer.fit(network, datamodule, ckpt_path=best_checkpoint)

In [16]:
def _pick_best_checkpoint(checkpoint_names):
    """Choose which checkpoint file name to restore.

    ``best.ckpt`` wins when present. Otherwise the name carrying the lowest
    ``val_loss=<number>`` is returned, matching the ``monitor="val_loss"``,
    ``mode="min"`` criterion that ``analyse`` uses to produce ``best.ckpt``.
    Names without a parsable ``val_loss`` are ignored.

    Raises
    ------
    FileNotFoundError
        If no usable checkpoint name is found.
    """
    import re  # local import: only this helper needs it

    if 'best.ckpt' in checkpoint_names:
        return 'best.ckpt'

    # Checkpoint names look like "epoch=09-val_loss=106464.16.ckpt".
    val_loss_pattern = re.compile(r'val_loss=(-?\d+(?:\.\d+)?)')
    scored = []
    for name in checkpoint_names:
        match = val_loss_pattern.search(name)
        if match is not None:
            scored.append((float(match.group(1)), name))

    if not scored:
        raise FileNotFoundError(
            f'No "best.ckpt" or "val_loss=..." checkpoint among {list(checkpoint_names)}'
        )
    return min(scored)[1]


def load(cfg, simulator):
    """Restore a trained network from the TensorBoard log directory.

    Parameters
    ----------
    cfg : object
        Configuration providing ``tensorboard`` (``save_dir``, ``name``,
        ``version``), ``estimation`` (``network``, ``accelerator``) and
        ``simulation.store.path``.
    simulator : object
        Simulator handed to ``sl.SwyftDataModule`` as ``model``.

    Returns
    -------
    tuple
        ``(network, trainer, tbl, datamodule)``.

    Raises
    ------
    FileNotFoundError
        If the checkpoint directory holds no usable checkpoint.
    """
    print('Loading trained network')
    tbl = pl_loggers.TensorBoardLogger(
        save_dir=cfg.tensorboard.save_dir,
        name=cfg.tensorboard.name,
        version=cfg.tensorboard.version,
        default_hp_metric=False,
    )
    # Directory where all logging information and checkpoints etc are stored
    logdir = tbl.experiment.get_logdir()

    checkpoint_dir = os.path.join(logdir, 'checkpoint')
    best_ckpt = _pick_best_checkpoint(os.listdir(checkpoint_dir))
    print(f'best checkpoint is {best_ckpt}')

    # Load on the CPU first so restoring does not depend on the saving device.
    checkpoint = torch.load(
        os.path.join(checkpoint_dir, best_ckpt), map_location='cpu'
    )

    network = hydra.utils.instantiate(cfg.estimation.network, cfg)
    network.load_state_dict(checkpoint["state_dict"])

    train_samples = torch.load(cfg.simulation.store.path)

    trainer = sl.SwyftTrainer(accelerator=cfg.estimation.accelerator, gpus=1)
    trainer.setup(None)

    datamodule = sl.SwyftDataModule(store=train_samples, model=simulator)
    datamodule.setup()

    # Attach the restored network without running `fit`.
    trainer.model = network

    return network, trainer, tbl, datamodule

def analyse(cfg, datamodule):
    """Train the network, or resume training from an existing ``best.ckpt``.

    A TensorBoard logger defines the log directory. If
    ``<logdir>/checkpoint/best.ckpt`` does not exist, the network is fitted,
    the checkpoint callback's best model is copied to that path and the test
    loop is run. Otherwise fitting resumes from that checkpoint.

    Parameters
    ----------
    cfg : object
        Configuration providing ``tensorboard`` and ``estimation`` settings.
    datamodule : object
        Datamodule passed to ``trainer.fit`` / ``trainer.test``.

    Returns
    -------
    tuple
        ``(network, trainer, tbl)``.
    """
    # The logger also determines where checkpoints are written.
    tbl = pl_loggers.TensorBoardLogger(
        save_dir=cfg.tensorboard.save_dir,
        name=cfg.tensorboard.name,
        version=cfg.tensorboard.version,
        default_hp_metric=False,
    )
    logdir = tbl.experiment.get_logdir()

    network = hydra.utils.instantiate(cfg.estimation.network, cfg)

    estimation_cfg = cfg.estimation
    callbacks = [
        LearningRateMonitor(logging_interval="step"),
        EarlyStopping(
            monitor="val_loss",
            min_delta=estimation_cfg.early_stopping.min_delta,
            patience=estimation_cfg.early_stopping.patience,
            verbose=False,
            mode="min",
        ),
    ]
    # Kept under its own name: its `best_model_path` is needed after fitting.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=logdir + "/checkpoint/",
        filename="{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    )
    callbacks.append(checkpoint_callback)

    trainer = sl.SwyftTrainer(
        accelerator=estimation_cfg.accelerator,
        gpus=1,
        max_epochs=estimation_cfg.max_epochs,
        logger=tbl,
        callbacks=callbacks,
    )

    best_checkpoint = logdir + "/checkpoint/best.ckpt"
    if os.path.isfile(best_checkpoint):
        # A previous run left a best checkpoint: continue from it.
        print('realoding network?')
        trainer.fit(network, datamodule, ckpt_path=best_checkpoint)
    else:
        trainer.fit(network, datamodule)
        shutil.copy(checkpoint_callback.best_model_path, best_checkpoint)
        trainer.test(network, datamodule)

    return network, trainer, tbl

# network, trainer, tbl = analyse(cfg, datamodule)
network, trainer, tbl, datamodule = load(cfg, simulator)

Loading trained network
best checkpoint is best.ckpt


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


# Interpret again

In [17]:
import os
import numpy as np
import pylab as plt
import torch
import swyft.lightning as sl

from lensx.nn.subN.interpret import IsotonicRegressionCalibration
from lensx.nn.subN.logging_utils_subN import LogIRC, LogPost, LogObs, LogBounds, LogSingleSub
from lensx.nn.subN.inference import Infer, Prior

In [18]:
# Log directory of the TensorBoard run; posterior/target files are saved here below.
logdir = tbl.experiment.get_logdir()

# Calculate expected n_sub: per-sample count of `z_sub` entries whose first
# component equals 0, averaged over the prediction dataset.
# NOTE(review): `lavalamp` further down drops all-zero `z_sub` rows as absent
# subhalos. If the same padding convention applies here, `Ms == 0` counts the
# empty slots rather than the subhalos (`Ms != 0`) -- confirm which is intended.
Ms = datamodule.predict_dataloader().dataset[:]['z_sub'][:,:,0]
n_sub_expect = torch.mean( torch.sum(Ms == 0, dim = 1).type(torch.float) )

# Build the inference helper from the simulator, the trained network, the
# datamodule and the expected subhalo count.
infer = Infer(simulator, network, datamodule, n_sub_expect)

Prior,    M_frac    in subhalo log10 mass range
3.25e-05, 8.33e-02:    [8.000 - 8.250]
3.25e-05, 8.33e-02:    [8.250 - 8.500]
3.25e-05, 8.33e-02:    [8.500 - 8.750]
3.25e-05, 8.33e-02:    [8.750 - 9.000]
3.25e-05, 8.33e-02:    [9.000 - 9.250]
3.25e-05, 8.33e-02:    [9.250 - 9.500]
3.25e-05, 8.33e-02:    [9.500 - 9.750]
3.25e-05, 8.33e-02:    [9.750 - 10.000]
3.25e-05, 8.33e-02:    [10.000 - 10.250]
3.25e-05, 8.33e-02:    [10.250 - 10.500]
3.25e-05, 8.33e-02:    [10.500 - 10.750]
3.25e-05, 8.33e-02:    [10.750 - 11.000]


In [19]:
# Prior information necessary for loggers
prior, prior_grid = infer.calc_prior()[0], infer.prior_grid()
grid_coords = infer.get_grid_coords()
grid_low, grid_high = infer.grid_low, infer.grid_high

In [10]:
# Simulations inference
posts_uncalib, targets = infer.get_posts(datamodule.predict_dataloader(), cfg.inference.n_infer)
torch.save(posts_uncalib, os.path.join(logdir, 'posts_uncalib.pt'))
torch.save(targets, os.path.join(logdir, 'targets.pt'))

Calculating posteriors:   2%|▏         | 5/313 [00:57<59:24, 11.57s/it]


KeyboardInterrupt: 

In [None]:
# Calibration
irc = IsotonicRegressionCalibration(posts_uncalib, targets)    
posts_calib = irc.calibrate(posts_uncalib)
torch.save(posts_calib, os.path.join(logdir, 'posts_calib.pt'))

In [20]:
# Load the saved posterior and targets
posts_uncalib = torch.load(os.path.join(logdir, 'posts_uncalib.pt'))
targets       = torch.load(os.path.join(logdir, 'targets.pt'))
posts_calib = torch.load(os.path.join(logdir, 'posts_calib.pt'))

In [None]:
irc = IsotonicRegressionCalibration(posts_uncalib, targets)    

In [None]:
# Log simulation inference
LogPost(tbl, posts_uncalib, targets, fig_kwargs = dict(dpi = 250, tight_layout = True)).plot_all()
LogPost(tbl, posts_calib,   targets, fig_kwargs = dict(dpi = 250, tight_layout = True), calib = 'calibrated').plot_all()
LogIRC(tbl, irc, fig_kwargs = dict(dpi = 250, tight_layout = True)).plot()

In [None]:
tbl.experiment.flush()
print("logdir:", logdir)

In [None]:
# assert 1 == 2

# Lavalamp plot

In [None]:
# Draw single simulations until one has more than two `z_sub` entries whose
# first component exceeds 9.
for _ in range(10000):
    test_sim = simulator.sample(1)
    if (test_sim['z_sub'][0,:,0] > 9.).sum() > 2:
        break
else:
    # Without this guard the last, non-matching draw would be used silently below.
    raise RuntimeError(
        'No simulation with more than 2 z_sub entries above 9 found in 10000 draws'
    )
test_sim

In [None]:
test_post_uncalib = infer.get_post(test_sim).squeeze(0)
test_sim = infer.squeeze_obs(test_sim)
test_post = irc.calibrate(test_post_uncalib)

In [None]:
# logobs = LogObs(None, test_sim, test_post, prior, grid_coords)

# for zlog in [False, True]:
#     logobs.plot_msc(zlog = zlog, 
#                     plot_true = True,
#                     title = rf'Sum posterios $= {torch.sum(test_post).item():.2f}$',
#                     vminmax = True,
#                    );

In [None]:
# 
# m_centers, m_edges, xy_centers, xy_edges = grid_coords
# X, Y, Z = torch.meshgrid(xy_centers, xy_centers, m_centers)
# values = np.transpose(post, [2, 1, 0])

# z_sub = obs['z_sub'].numpy()
# z_sub = z_sub[np.sum(np.abs(z_sub), axis = 1) != 0] 
# M_sub, x_sub, y_sub = z_sub.T

# # values = values[:20, :20, :6]

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
def colormap(x):
    """Return one Plotly colorscale stop sampled from the viridis colormap.

    Parameters
    ----------
    x : float
        Scalar position in ``[0, 1]`` along the colormap.

    Returns
    -------
    list
        ``[x, 'rgb(r, g, b)']`` with integer channels in 0-255.
    """
    # `plt.get_cmap` replaces `matplotlib.cm.get_cmap`, which newer Matplotlib
    # releases no longer provide.
    rgba = plt.get_cmap('viridis')(x)
    # Matplotlib returns floats in [0, 1]; Plotly's `rgb(...)` uses 0-255.
    # Converting to plain ints also keeps NumPy scalar reprs out of the string.
    red, green, blue = (int(round(255 * float(channel))) for channel in rgba[:3])
    return [x, f'rgb({red}, {green}, {blue})']

norm = matplotlib.colors.Normalize()
colorscale = [colormap(i) for i in np.linspace(0, 1, 10)]

In [None]:
def lavalamp(post, obs, grid_coords):
    """Build a two-panel 3D Plotly figure of a voxelised posterior.

    The left panel shows the values on a linear scale, the right panel their
    ``log10``. Both panels mark the non-zero rows of ``obs['z_sub']`` as red
    crosses. The observed image is drawn in the left panel as a flat surface
    at height ``grid_low[0]``.

    Parameters
    ----------
    post : numpy.ndarray
        3D array of posterior values. It is transposed with axes ``(2, 1, 0)``
        before plotting, so it is presumably indexed ``(M, y, x)`` -- confirm.
    obs : mapping
        Provides ``'z_sub'`` (has ``.numpy()``; rows unpack to ``M, x, y``)
        and ``'img'`` (2D image).
    grid_coords : sequence
        ``(m_centers, m_edges, xy_centers, xy_edges)``; only the centers are used.

    Returns
    -------
    plotly.graph_objects.Figure

    Notes
    -----
    Reads the notebook-level names ``grid_low``, ``grid_high`` and
    ``colorscale``; they must be defined before calling.
    """
    m_centers, _m_edges, xy_centers, _xy_edges = grid_coords
    # 'ij' is what the implicit default produced; stating it avoids the
    # deprecation warning newer torch versions emit.
    X, Y, Z = torch.meshgrid(xy_centers, xy_centers, m_centers, indexing='ij')
    values = np.transpose(post, [2, 1, 0])

    # Keep only rows that are not entirely zero.
    z_sub = obs['z_sub'].numpy()
    z_sub = z_sub[np.sum(np.abs(z_sub), axis = 1) != 0]
    M_sub, x_sub, y_sub = z_sub.T

    im = np.array(obs['img'])
    n_rows, n_cols = im.shape
    # Plotly's Surface expects len(x) == columns and len(y) == rows of `z`.
    x = np.linspace(grid_low[1], grid_high[1], n_cols)
    y = np.linspace(grid_low[2], grid_high[2], n_rows)
    z = np.ones(im.shape) * grid_low[0]

    fig = make_subplots(rows=1, cols=2,
                        specs=[[{'is_3d': True}, {'is_3d': True}]],
                        subplot_titles=['Normal scale', 'Logarithmic scale'],
                        )

    # NOTE(review): log10 yields -inf for exact zeros in `values`; this assumes
    # the posterior is strictly positive -- confirm.
    panels = zip([1, 2], [values, np.log10(values)], [-0.10, None])
    for ncol, panel_values, cbar_x in panels:
        fig.add_trace(go.Volume(
            x=X.flatten(),
            y=Y.flatten(),
            z=Z.flatten(),
            value=panel_values.flatten(),
            surface_count = 20,
            opacityscale = [[0, 0], [1, 0.9]],
            colorbar_x=cbar_x,
        ), 1, ncol)

        fig.add_trace(go.Scatter3d(
            x = x_sub,
            y = y_sub,
            z = M_sub,
            mode ='markers',
            marker = dict(
                color = 'red',
                symbol = 'x',
                size = 5,
            ),
        ), 1, ncol)

    # NOTE(review): the image surface is only added to the left panel --
    # confirm this is intentional.
    fig.add_trace(go.Surface(x=x, y=y, z=z,
        surfacecolor=im,
        colorscale=colorscale,
        showscale=False,
    ), 1, 1)

    fig.update_layout(
        height = 800,
        width = 1600,
        title_text="Subhalo posteriors",
        showlegend=False
    )
    # `update_scenes` labels every 3D subplot; setting `scene=` in
    # `update_layout` only reached the first one.
    fig.update_scenes(
        xaxis=dict(title=r"x"),
        yaxis=dict(title=r"y"),
        zaxis=dict(title=r'M'),
    )

    return fig
    


version = 'nosum'
post, obs = test_post.cpu().numpy(), test_sim
fig = lavalamp(post, obs, grid_coords)
fig.show()
fig.write_html(f"lavalamp_v{version}.html")

# Logarithmic reliability curve

In [None]:
assert 1 == 2

In [None]:
# Device for the alpha grids and reliability-curve accumulators; falls back
# to the CPU so the cells below also run on machines without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

class Alpha():
    """Linearly spaced alpha grid on [0, 1].

    Attributes
    ----------
    n_alpha : number of edges.
    alpha_edges : ``n_alpha`` equally spaced edges from 0 to 1 on ``DEVICE``.
    alpha_centers : midpoints of neighbouring edges (``n_alpha - 1`` values).
    """
    def __init__(self, posts, n_alpha = 50):
        # `posts` is accepted but not used by this linear grid.
        edges = torch.linspace(0, 1, n_alpha, device = DEVICE)
        lower, upper = edges[:-1], edges[1:]

        self.n_alpha = n_alpha
        self.alpha_edges = edges
        self.alpha_centers = (lower + upper)/2
        
class LogAlpha():
    """Logarithmically spaced alpha grid spanning the range of ``posts``.

    Attributes
    ----------
    n_alpha : number of edges.
    alpha_edges : ``n_alpha`` log-spaced edges between ``posts.min()`` and
        ``posts.max()`` on ``DEVICE``.
    alpha_centers : arithmetic midpoints of neighbouring edges.
    """
    def __init__(self, posts, n_alpha = 50):
        self.n_alpha = n_alpha

        # Exponents (base 10) of the smallest and largest posterior value.
        log_lowest = torch.log10(posts.min())
        log_highest = torch.log10(posts.max())
        edges = torch.logspace(log_lowest, log_highest, self.n_alpha, device = DEVICE)

        self.alpha_edges = edges
        self.n_alpha = len(edges)

        lower, upper = edges[:-1], edges[1:]
        self.alpha_centers = (lower + upper)/2
        

class PostData(LogAlpha):
    """Posteriors and targets binned on the logarithmic alpha grid of ``LogAlpha``.

    Parameters
    ----------
    posts : torch.Tensor
        Posterior values; the first dimension is iterated in batches.
    targets : torch.Tensor
        Targets, multiplied element-wise with the per-bin membership mask of
        ``posts`` (so they are assumed to share its shape -- confirm).
    n_alpha : int
        Number of bin edges, forwarded to ``LogAlpha``.
    """
    def __init__(self, posts, targets, n_alpha = 50):
        super().__init__(posts = posts, n_alpha = n_alpha)
        self.posts = posts
        self.targets = targets

    def get_histogram(self):
        """Count posterior values per alpha bin (computed on CPU, returned on ``DEVICE``)."""
        hist = torch.histogram(self.posts.flatten().cpu(), bins = self.alpha_edges.cpu())[0].to(DEVICE)
        return hist

    def get_relicurve(self, batch_size = 16):
        """Compute the reliability curve over the alpha bins.

        For every bin, the targets of all posterior entries falling into that
        bin are summed and divided by the number of entries in the bin.

        Parameters
        ----------
        batch_size : int
            Number of leading-dimension entries of ``posts`` processed at once.

        Returns
        -------
        tuple
            ``(relicurve, is_between_sum)``, one value per bin. Bins without
            entries give 0 in ``relicurve`` (NaN from 0/0 is replaced).
        """
        lower_edges = self.alpha_edges[:-1]
        upper_edges = self.alpha_edges[1:]
        is_between_sum = torch.zeros_like(self.alpha_centers)

        n_batches = int(np.ceil(len(self.posts) / batch_size))
        for batch_idx in tqdm(range(n_batches), desc='Calculating reliability curve'):
            i, j = batch_idx*batch_size, (batch_idx+1)*batch_size
            posts_batch = self.posts[i:j]
            # A trailing singleton axis broadcasts against the bin edges; no
            # need to materialise one copy of the batch per bin.
            posts_alpha = posts_batch.unsqueeze(-1)
            targets_alpha = self.targets[i:j].unsqueeze(-1)

            # Same bin convention as `torch.histogram` in `get_histogram`:
            # [lower, upper) for every bin, with the last bin also including
            # its right edge. Numerator and denominator then count the same
            # entries.
            is_between = (posts_alpha >= lower_edges) & (posts_alpha < upper_edges)
            is_between[..., -1] |= posts_batch == upper_edges[-1]

            # Reduce over every axis except the trailing bin axis.
            reduce_dims = tuple(range(is_between.dim() - 1))
            is_between_sum += torch.sum(targets_alpha * is_between, dim = reduce_dims)
        hist = self.get_histogram()
        relicurve = is_between_sum/hist
        relicurve = torch.nan_to_num(relicurve)
        return relicurve, is_between_sum

In [None]:
postdata = PostData(posts_uncalib, targets)

In [None]:
relicurve, is_between_sum = postdata.get_relicurve()

In [None]:
# `hist` only exists as a local inside `PostData.get_relicurve`; compute it at
# notebook scope so this cell and the histogram plot further down work on a
# fresh kernel.
hist = postdata.get_histogram()
hist

In [None]:
is_between_sum

In [None]:
plt.stairs(is_between_sum.cpu(), postdata.alpha_edges.cpu())
# plt.plot((posts_min, posts_max ), (posts_min, posts_max ), 'k:')
plt.xscale('log')
plt.yscale('log')

In [None]:
posts_min, posts_max = postdata.posts.min().cpu(), postdata.posts.max().cpu()

In [None]:
plt.stairs(relicurve.cpu(), postdata.alpha_edges.cpu())
plt.plot((posts_min, posts_max ), (posts_min, posts_max ), 'k:')
plt.xscale('log')
plt.yscale('log')

In [None]:
a   = torch.tensor([1e-10])
eps = torch.tensor([1e-17])

a + eps == a

In [None]:
print(f'smallest posterior = {posts_uncalib.min()}, largest posterior = {posts_uncalib.max()}')


plt.stairs(hist.cpu(), postdata.alpha_edges.cpu())
plt.xscale('log')
plt.yscale('log')
plt.xlabel('Predicted pixel posterior')
plt.ylabel('Counts')
plt.show()

print('>>> torch.finfo(torch.float32)', torch.finfo(torch.float32), sep = '\n')
