In [1]:
%load_ext autoreload
%autoreload 2

import os
import shutil

import omegaconf
import hydra
import numpy as np
import pylab as plt
import swyft.lightning as sl
import torch
from lensx.logging_utils import log_post_plots, log_target_plots, log_train_plots
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from lensx.nn.subN.utils import print_dict

# plt.switch_backend("agg")

In [2]:
cfg = omegaconf.OmegaConf.load("config_uniform.yaml")
from lensx.nn.subN.plot import plt_imshow
imkwargs = dict(extent=(-2.5, 2.5, -2.5, 2.5), origin='lower') #left, right, bottom, top
from tqdm.notebook import tqdm as tqdm

In [3]:
def check_obs(cfg):
    """Probe whether the mock observation at ``cfg.inference.obs_path`` exists.

    The file is opened with ``torch.load`` and the loaded object is thrown
    away. When the file is missing, a notice is printed instead of raising.
    """
    try:
        torch.load(cfg.inference.obs_path)
    except FileNotFoundError:
        print('No mock generated!')
        
check_obs(cfg)

No mock generated!


In [4]:
def simulate(cfg):
    """Instantiate the simulator and build the training datamodule from ``cfg``.

    Returns
    -------
    tuple
        ``(datamodule, simulator)``.
    """
    # Simulator described by the config (potentially bounded).
    simulator = hydra.utils.instantiate(cfg.simulation.model)

    def draw_store_samples():
        # Sampler handed to ``sl.file_cache`` together with the cache path.
        return simulator.sample(cfg.simulation.store.store_size)

    # Generate or load the sample store, then keep only the training subset.
    stored_samples = sl.file_cache(draw_store_samples, cfg.simulation.store.path)
    train_samples = stored_samples[: cfg.simulation.store.train_size]

    loader_options = dict(
        batch_size=cfg.estimation.batch_size,
        num_workers=cfg.estimation.num_workers,
    )
    datamodule = sl.SwyftDataModule(
        store=train_samples,
        model=simulator,  # Adds noise on the fly. `None` uses noise in store.
        **loader_options,
    )

    return datamodule, simulator
datamodule, simulator = simulate(cfg)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [5]:
# # Setting up tensorboard logger, which defines also logdir (contains trained network)
# tbl = pl_loggers.TensorBoardLogger(
#     save_dir=cfg.tensorboard.save_dir,
#     name=cfg.tensorboard.name,
#     version=cfg.tensorboard.version,
#     default_hp_metric=False,
# )
# logdir = (
#     tbl.experiment.get_logdir()
# )  # Directory where all logging information and checkpoints etc are stored

# # Load network and train (or re-load trained network)
# network = hydra.utils.instantiate(cfg.estimation.network, cfg)
# #     network = ImgSegmNetwork(cfg, 1)

# lr_monitor = LearningRateMonitor(logging_interval="step")
# early_stop_callback = EarlyStopping(
#     monitor="val_loss",
#     min_delta=cfg.estimation.early_stopping.min_delta,
#     patience=cfg.estimation.early_stopping.patience,
#     verbose=False,
#     mode="min",
# )
# checkpoint_callback = ModelCheckpoint(
#     monitor="val_loss",
#     dirpath=logdir + "/checkpoint/",
#     filename="{epoch:02d}-{val_loss:.2f}",
#     save_top_k=3,
#     mode="min",
# )
# trainer = sl.SwyftTrainer(
#     accelerator=cfg.estimation.accelerator,
#     gpus=1,
#     max_epochs=cfg.estimation.max_epochs,
#     logger=tbl,
#     callbacks=[lr_monitor, early_stop_callback, checkpoint_callback],
# )
# best_checkpoint = logdir + "/checkpoint/best.ckpt"
# if not os.path.isfile(best_checkpoint):
#     trainer.fit(network, datamodule)
#     shutil.copy(checkpoint_callback.best_model_path, best_checkpoint)
#     trainer.test(network, datamodule)
# else:
#     print('realoding network?')
#     trainer.fit(network, datamodule, ckpt_path=best_checkpoint)

In [6]:
def _pick_best_checkpoint(checkpoints):
    """Choose which file name in ``checkpoints`` should be restored.

    ``best.ckpt`` wins when present. Otherwise names following the
    ``ModelCheckpoint`` pattern used in ``analyse``
    (e.g. ``epoch=09-val_loss=106464.16.ckpt``) are parsed and the one with
    the lowest validation loss is returned; ties go to the later epoch.
    Names that do not match the pattern are ignored.

    Raises
    ------
    FileNotFoundError
        If there is neither ``best.ckpt`` nor any parseable checkpoint name.
    """
    import re  # local import keeps this cell self-contained

    if 'best.ckpt' in checkpoints:
        return 'best.ckpt'

    name_pattern = re.compile(
        r'epoch=(\d+)-val_loss=(-?\d+(?:\.\d+)?)(?:-v\d+)?\.ckpt$'
    )
    candidates = []
    for name in checkpoints:
        match = name_pattern.match(name)
        if match is None:
            continue
        epoch, val_loss = int(match.group(1)), float(match.group(2))
        # Sort key: lowest loss first, then latest epoch.
        candidates.append((val_loss, -epoch, name))

    if not candidates:
        raise FileNotFoundError(
            f'No usable checkpoint found among {list(checkpoints)!r}'
        )
    return min(candidates)[2]


def load(cfg, simulator):
    """Restore the trained network from the run's checkpoint directory.

    Parameters
    ----------
    cfg
        Config providing ``tensorboard.*``, ``estimation.network``,
        ``estimation.accelerator`` and ``simulation.store.path``.
    simulator
        Simulator passed to ``sl.SwyftDataModule`` as ``model``.

    Returns
    -------
    tuple
        ``(network, trainer, tbl, datamodule)``.

    Raises
    ------
    FileNotFoundError
        If the checkpoint directory is missing or holds no usable checkpoint.
    """
    print('Loading trained network')
    tbl = pl_loggers.TensorBoardLogger(
        save_dir=cfg.tensorboard.save_dir,
        name=cfg.tensorboard.name,
        version=cfg.tensorboard.version,
        default_hp_metric=False,
    )
    # Directory where all logging information and checkpoints etc are stored
    logdir = tbl.experiment.get_logdir()

    checkpoint_dir = os.path.join(logdir, 'checkpoint')
    best_ckpt = _pick_best_checkpoint(os.listdir(checkpoint_dir))
    print(f'best checkpoint is {best_ckpt}')

    checkpoint = torch.load(
        os.path.join(checkpoint_dir, best_ckpt), map_location='cpu'
    )

    network = hydra.utils.instantiate(cfg.estimation.network, cfg)
    network.load_state_dict(checkpoint["state_dict"])

    train_samples = torch.load(cfg.simulation.store.path)

    trainer = sl.SwyftTrainer(accelerator=cfg.estimation.accelerator, gpus=1)
    trainer.setup(None)

    datamodule = sl.SwyftDataModule(store=train_samples, model=simulator)
    datamodule.setup()

    trainer.model = network

    return network, trainer, tbl, datamodule

def analyse(cfg, datamodule):
    """Train the network on ``datamodule`` and return it with its trainer and logger.

    When ``<logdir>/checkpoint/best.ckpt`` does not exist, the network is
    fitted, the callback's best checkpoint is copied to that path and the
    test loop is run. When it exists, ``trainer.fit`` is called with that
    file as ``ckpt_path``.

    Returns
    -------
    tuple
        ``(network, trainer, tbl)``.
    """
    # The tensorboard logger also defines logdir (contains the trained network).
    tb_cfg = cfg.tensorboard
    tbl = pl_loggers.TensorBoardLogger(
        save_dir=tb_cfg.save_dir,
        name=tb_cfg.name,
        version=tb_cfg.version,
        default_hp_metric=False,
    )
    # Directory where all logging information and checkpoints etc are stored
    logdir = tbl.experiment.get_logdir()

    network = hydra.utils.instantiate(cfg.estimation.network, cfg)

    lr_monitor = LearningRateMonitor(logging_interval="step")
    stopping_cfg = cfg.estimation.early_stopping
    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=stopping_cfg.min_delta,
        patience=stopping_cfg.patience,
        verbose=False,
        mode="min",
    )
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=logdir + "/checkpoint/",
        filename="{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    )
    callbacks = [lr_monitor, early_stop_callback, checkpoint_callback]

    trainer = sl.SwyftTrainer(
        accelerator=cfg.estimation.accelerator,
        gpus=1,
        max_epochs=cfg.estimation.max_epochs,
        logger=tbl,
        callbacks=callbacks,
    )

    best_checkpoint = logdir + "/checkpoint/best.ckpt"
    if os.path.isfile(best_checkpoint):
        print('realoding network?')
        trainer.fit(network, datamodule, ckpt_path=best_checkpoint)
    else:
        trainer.fit(network, datamodule)
        shutil.copy(checkpoint_callback.best_model_path, best_checkpoint)
        trainer.test(network, datamodule)

    return network, trainer, tbl

# network, trainer, tbl = analyse(cfg, datamodule)
network, trainer, tbl, datamodule = load(cfg, simulator)

Loading trained network
best checkpoint is best.ckpt


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


# Interpret again

In [7]:
import os
import numpy as np
import pylab as plt
import torch
import swyft.lightning as sl

from lensx.nn.subN.interpret import IsotonicRegressionCalibration
from lensx.nn.subN.logging_utils_subN import LogIRC, LogPost, LogObs, LogBounds, LogSingleSub
from lensx.nn.subN.inference import Infer, Prior

In [8]:
# Log directory of the TensorBoard run; the saved posteriors/targets are read from here below.
logdir = tbl.experiment.get_logdir()

# Calculate expected n_sub: mean over the dataset of a per-sample count
# taken from the first component of `z_sub`.
# NOTE(review): `Ms == 0` counts entries whose first `z_sub` component is
# exactly zero. `plotLava` in this notebook drops all-zero `z_sub` rows before
# plotting the true subhalos, so this looks like it counts *empty* slots
# rather than subhalos - confirm whether `Ms != 0` was intended.
Ms = datamodule.predict_dataloader().dataset[:]['z_sub'][:,:,0]
n_sub_expect = torch.mean( torch.sum(Ms == 0, dim = 1).type(torch.float) )

# Build the inference helper from the simulator, network, datamodule and expected n_sub
infer = Infer(simulator, network, datamodule, n_sub_expect)

Prior,    M_frac    in subhalo log10 mass range
3.25e-05, 8.33e-02:    [8.000 - 8.250]
3.25e-05, 8.33e-02:    [8.250 - 8.500]
3.25e-05, 8.33e-02:    [8.500 - 8.750]
3.25e-05, 8.33e-02:    [8.750 - 9.000]
3.25e-05, 8.33e-02:    [9.000 - 9.250]
3.25e-05, 8.33e-02:    [9.250 - 9.500]
3.25e-05, 8.33e-02:    [9.500 - 9.750]
3.25e-05, 8.33e-02:    [9.750 - 10.000]
3.25e-05, 8.33e-02:    [10.000 - 10.250]
3.25e-05, 8.33e-02:    [10.250 - 10.500]
3.25e-05, 8.33e-02:    [10.500 - 10.750]
3.25e-05, 8.33e-02:    [10.750 - 11.000]


In [9]:
# Prior information necessary for loggers
prior, prior_grid = infer.calc_prior()[0], infer.prior_grid()
grid_coords = infer.get_grid_coords()
grid_low, grid_high = infer.grid_low, infer.grid_high

In [10]:
# # Simulations inference
# posts_uncalib, targets = infer.get_posts(datamodule.predict_dataloader(), cfg.inference.n_infer)
# torch.save(posts_uncalib, os.path.join(logdir, 'posts_uncalib.pt'))
# torch.save(targets, os.path.join(logdir, 'targets.pt'))

In [11]:
# # Calibration
# irc = IsotonicRegressionCalibration(posts_uncalib, targets)    
# posts_calib = irc.calibrate(posts_uncalib)
# torch.save(posts_calib, os.path.join(logdir, 'posts_calib.pt'))

In [12]:
# Load the saved posterior and targets
posts_uncalib = torch.load(os.path.join(logdir, 'posts_uncalib.pt'))
targets       = torch.load(os.path.join(logdir, 'targets.pt'))
posts_calib = torch.load(os.path.join(logdir, 'posts_calib.pt'))

irc = IsotonicRegressionCalibration(posts_uncalib, targets)    


Calculating reliability curve: 100%|██████████| 626/626 [00:02<00:00, 250.39it/s]


In [13]:
len(posts_uncalib)

10016

In [14]:
# Log simulation inference
LogPost(tbl, posts_uncalib, targets, fig_kwargs = dict(dpi = 100, figsize = (4,3))).plot_all()
LogPost(tbl, posts_calib,   targets, fig_kwargs = dict(dpi = 100, figsize = (4,3)), calib = 'calibrated').plot_all()
LogIRC(tbl, irc).plot()

Calculating reliability curve: 100%|██████████| 201/201 [00:02<00:00, 82.19it/s]
Calculating reliability curve: 100%|██████████| 201/201 [00:02<00:00, 82.35it/s]
Calculating reliability curve: 100%|██████████| 626/626 [00:02<00:00, 251.41it/s]


In [15]:
tbl.experiment.flush()
print("logdir:", logdir)

logdir: ./lightning_logs3/uniform_noise0.25_sub0-5_m8.0-11.0_pix80_msc12_sim200000/version_0


In [16]:
assert 1 == 2

AssertionError: 

# Lavalamp plot

In [None]:
for _ in range(10000):
    test_sim = simulator.sample(1)
    if (test_sim['z_sub'][0,:,0] > 10.).sum() > 3:
        break
test_sim

In [None]:

test_post_uncalib = infer.get_post(test_sim).squeeze(0)
test_sim = infer.squeeze_obs(test_sim)
test_post = irc.calibrate(test_post_uncalib)



In [None]:
logobs = LogObs(None, test_sim, test_post, prior, grid_coords)

for zlog in [False, True]:
    logobs.plot_msc(zlog = zlog, 
                    plot_true = True,
                    title = rf'Sum posterios $= {torch.sum(test_post).item():.2f}$',
                    vminmax = True,
                   );




In [None]:
import matplotlib.colors
def get_alphas(post, a = 10, b = 1):
    """Map posterior values to per-cell opacities in ``[0, 0.8]``.

    ``post`` is min-max normalised (its minimum maps to 0, its maximum to 1)
    and scaled by 0.8, so even the most probable cell stays slightly
    transparent. A constant input (including all zeros) yields all-zero
    opacities instead of NaN.

    Parameters
    ----------
    post
        Array-like with ``min``/``max`` methods (e.g. a NumPy array).
    a, b
        Unused; kept so existing calls keep working.

    Returns
    -------
    Array of the same shape as ``post``.
    """
    shifted = post - post.min()
    span = shifted.max()
    if span == 0:
        # Constant input: avoid 0/0 and draw everything fully transparent.
        return shifted * 0.8
    return shifted / span * 0.8

# fig = plt.figure()
# x = np.linspace(0, 1e-4)
# y = get_alphas(x, a = a)
# plt.plot(x, y, label = a)
# plt.legend()    
# plt.show()

def normalize(d):
    """Return ``d`` rescaled to the range ``[0, 1]`` (min-max scaling).

    The input is left untouched; a new array is returned. Integer input is
    promoted to float, and a constant input yields all zeros rather than a
    division by zero.
    """
    shifted = d - d.min()
    peak = shifted.max()
    if peak == 0:
        # Constant input: every value sits at the minimum.
        return shifted * 0.0
    return shifted / peak

def cuboid_data(l, h):
    """Return surface coordinates of the axis-aligned box spanning ``l`` to ``h``.

    ``l`` and ``h`` are the low and high corners, indexed as ``[0]``, ``[1]``,
    ``[2]``. The result is an array of shape ``(3, 4, 5)``: coordinate axis,
    then four closed 5-point rings (bottom, upper, outside, inside surface).
    """
    x_lo, x_hi = l[0], h[0]
    y_lo, y_hi = l[1], h[1]
    z_lo, z_hi = l[2], h[2]

    # Every ring walks the same path along the first axis.
    xs = [[x_lo, x_hi, x_hi, x_lo, x_lo] for _ in range(4)]
    ys = [
        [y_lo, y_lo, y_hi, y_hi, y_lo],  # bottom surface
        [y_lo, y_lo, y_hi, y_hi, y_lo],  # upper surface
        [y_lo] * 5,                      # outside surface
        [y_hi] * 5,                      # inside surface
    ]
    zs = [
        [z_lo] * 5,                      # bottom surface
        [z_hi] * 5,                      # upper surface
        [z_lo, z_lo, z_hi, z_hi, z_lo],  # outside surface
        [z_lo, z_lo, z_hi, z_hi, z_lo],  # inside surface
    ]
    return np.array((xs, ys, zs))

def plotCubeAt(low=(0,0,0), high=(1,1,1), c="b", alpha=1, ax=None):
    """Draw one box spanning ``low`` to ``high`` on the 3D axes ``ax``.

    ``c`` and ``alpha`` are forwarded to ``ax.plot_surface`` as colour and
    opacity. Nothing is drawn when ``ax`` is ``None``.
    """
    # Identity check: an axes object must never be compared with ``!=``.
    if ax is None:
        return
    X, Y, Z = cuboid_data(low, high)
    ax.plot_surface(X, Y, Z, color=c, alpha=alpha)

# def plotMatrix(ax, x, y, z, data, cmap=plt.cm.viridis, cax=None, alpha=1):
#     # plot a Matrix 
#     norm = matplotlib.colors.Normalize(vmin=data.min(), vmax=data.max())
#     colors = lambda i,j,k : matplotlib.cm.ScalarMappable(norm=norm,cmap = cmap).to_rgba(data[i,j,k]) 
#     alphas = lambda i,j,k : normalize(data)[i,j,k] 
#     for k, zi, in enumerate(tqdm(x)):
#         for j, yi in enumerate(y):
#             for i, xi in enumerate(z):
# #                 print(i, j, k, data.shape)
#                 if data[i,j,k] > 1e-3:
#                     plotCubeAt(low=full_grid_low[i,j,k], high=full_grid_high[i,j,k],
#                                c=colors(i,j,k), alpha=alphas(i,j,k),  ax=ax)
            
def plotMatrix(ax, x, y, z, data, threshold = 0., cmap=plt.cm.viridis, cax=None, alpha=1):
    """Draw every cell of a 3D grid whose value exceeds ``threshold`` as a box.

    Parameters
    ----------
    ax
        3D axes to draw on.
    x, y, z
        Bin edges along the three axes; consecutive pairs bound one cell.
    data
        Values indexed ``[i, j, k]`` in the order of ``x``, ``y``, ``z``.
    threshold
        Cells with ``data[i, j, k] <= threshold`` are skipped.
    cmap
        Colormap used for the cell colour.
    cax, alpha
        Unused; kept so existing calls keep working.
    """
    norm = matplotlib.colors.Normalize(vmin=data.min(), vmax=data.max())
    # Loop-invariant work: one colour mapper and one opacity array for the
    # whole grid instead of rebuilding both for every single cell.
    mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
    alphas = get_alphas(data)

    for i, (x_low, x_high) in tqdm(enumerate(zip(x[:-1], x[1:])), total = len(x[1:])):
        for j, (y_low, y_high) in enumerate(zip(y[:-1], y[1:])):
            for k, (M_low, M_high) in enumerate(zip(z[:-1], z[1:])):
                value = data[i,j,k]
                if value > threshold:
                    plotCubeAt(low=(x_low,y_low,M_low), high=(x_high,y_high,M_high),
                               c=mappable.to_rgba(value),
                               alpha=alphas[i,j,k],
                               ax=ax)
                    
def plotMatrix2(ax, x, y, z, data, cmap=plt.cm.viridis, cax=None, alpha=1):
    """Draw a 2D array as a layer of thin coloured boxes at height ``z``.

    Parameters
    ----------
    ax
        3D axes to draw on.
    x, y
        Bin edges along the two in-plane axes.
    z
        Scalar height of the top of the layer; each box spans
        ``z - 0.01`` to ``z``.
    data
        Values indexed ``[i, j]`` in the order of ``x``, ``y``.
    cmap
        Colormap used for the box colour.
    cax, alpha
        Unused; kept so existing calls keep working.
    """
    norm = matplotlib.colors.Normalize(vmin=data.min(), vmax=data.max())
    # Build the colour mapper once instead of once per cell.
    mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)

    for i, (x_low, x_high) in tqdm(enumerate(zip(x[:-1], x[1:])), total = len(x[1:])):
        for j, (y_low, y_high) in enumerate(zip(y[:-1], y[1:])):
            plotCubeAt(low=(x_low,y_low,z-0.01), high=(x_high,y_high,z),
                       c=mappable.to_rgba(data[i,j]),
                       ax=ax)
                    
def plotLava(post, obs, threshold, azim = 0):
    """Show the 3D "lava lamp" view of a posterior with the true subhalos.

    Each grid cell whose posterior exceeds ``threshold`` is drawn as a
    coloured, semi-transparent box; the subhalos in ``obs`` are overlaid as
    red crosses.

    Parameters
    ----------
    post
        NumPy array; it is transposed with ``[2, 1, 0]`` so that it can be
        indexed in the order (x, y, mass).
    obs
        Mapping with a ``'z_sub'`` tensor whose columns are unpacked as
        (mass, x, y). Rows that are entirely zero are dropped.
    threshold
        Cells with a posterior value at or below this are not drawn.
    azim
        Azimuth of the camera in degrees (elevation is fixed at 30).

    Notes
    -----
    Reads the bin edges from the notebook-level ``grid_coords``.
    """
    _, m_edges, _, xy_edges = grid_coords

    # Keep only the rows of z_sub that hold an actual subhalo.
    z_sub = obs['z_sub'].numpy()
    z_sub = z_sub[np.sum(np.abs(z_sub), axis = 1) != 0]
    M_sub, x_sub, y_sub = z_sub.T

    labelsize = 15
    fig = plt.figure(figsize=(18,15))
    ax = fig.add_subplot(111, projection='3d')
    ax.view_init(elev=30., azim=azim)

    plotMatrix(ax, xy_edges, xy_edges, m_edges, np.transpose(post, [2,1,0]), threshold = threshold)
    ax.scatter(x_sub, y_sub, M_sub, s = 200, c = 'red', marker = 'x')

    ax.set_xlim(xy_edges.min(), xy_edges.max())
    ax.set_ylim(xy_edges.min(), xy_edges.max())
    ax.set_zlim(m_edges.min(),  m_edges.max())

    ax.set_xlabel(r"$x\ ['']$", fontsize = labelsize)
    ax.set_ylabel(r"$y\ ['']$", fontsize = labelsize)
    ax.set_zlabel(r'$log_{10}(M_{sub}/M_{\odot})$', fontsize = labelsize)

    plt.show()
    
plotLava(test_post.cpu().numpy(), test_sim, 
         threshold = prior.min() ,
         azim = 270
        )

# plotLava(test_post.cpu().numpy(), test_sim, 
#          threshold = prior.min() ,
#          azim = 45
#         )

In [None]:
post, obs = test_post.cpu().numpy(), test_sim
m_centers, m_edges, xy_centers, xy_edges = grid_coords
X, Y, Z = torch.meshgrid(xy_centers, xy_centers, m_centers)
values = np.transpose(post, [2, 1, 0])

z_sub = obs['z_sub'].numpy()
z_sub = z_sub[np.sum(np.abs(z_sub), axis = 1) != 0] 
M_sub, x_sub, y_sub = z_sub.T

# values = np.log10(values)


# x1 = np.linspace(2, 4, 3) 
# y1 = np.linspace(2, 5, 4) 
# z1 = np.linspace(2, 5, 2) 
# X, Y, Z = np.meshgrid(x1, y1, z1)
# values = (np.sin(X**2 + Y**2))/(X**2 + Y**2)

# i, j, k = 6, 6, 2
# X = X[i,j,k]
# Y = Y[i,j,k]
# Z = Z[i,j,k]
# values = values[i,j,k]

In [None]:
x, y, z = cuboid_data([0.1, 0.1, 0.1], [0.5, 0.5, 0.5])


In [None]:
def draw_mesh(fig, low, high, opacity):

In [None]:
import matplotlib.colors

In [None]:
norm = matplotlib.colors.Normalize()
im = np.array(obs['img'])
colorscale = [colormap(i) for i in np.linspace(0, 1, 10)]

im_x, im_y = im.shape
x = np.linspace(-2.5,2.5, im_x)
y = np.linspace(-2.5,2.5, im_y)
z = np.ones(im.shape[:2]) * grid_low[0]

import plotly.graph_objects as go
from plotly.subplots import make_subplots

fig = make_subplots(rows=1, cols=1,
                    specs=[[{'is_3d': True}]],#, {'is_3d': True}]],
#                     subplot_titles=['Normal scale', 'Logarithmic scale'],
                    )

def cuboid_data(l, h):
    """Return the 8 corner vertices of the box spanning ``l`` to ``h``.

    The result has shape ``(3, 8)`` (rows x, y, z): the four corners of the
    low-z face followed by the four corners of the high-z face.

    NOTE: this redefines the earlier ``cuboid_data`` that ``plotCubeAt``
    unpacks into surface arrays; re-run that cell before using
    ``plotCubeAt`` again.
    """
    x_lo, y_lo, z_lo = l[0], l[1], l[2]
    x_hi, y_hi, z_hi = h[0], h[1], h[2]
    # One footprint of four corners, repeated for the lower and upper face.
    xs = [x_lo, x_lo, x_hi, x_hi] * 2
    ys = [y_lo, y_hi, y_hi, y_lo] * 2
    zs = [z_lo] * 4 + [z_hi] * 4
    return np.array((xs, ys, zs))

x, y, z = cuboid_data([0.1, 0.1, 0.1], [0.5, 0.5, 0.5])
fig.add_trace(go.Mesh3d(
    x = x, y = y, z = z,
    i = [7, 0, 0, 0, 4, 4, 6, 6, 4, 0, 3, 2],
    j = [3, 4, 1, 2, 5, 6, 5, 2, 0, 1, 6, 3],
    k = [0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6],
    opacity = 0.4,
    flatshading = True

))

fig.show()
# fig.write_html("plotly.html")

In [None]:
x == [0.608, 0.608, 0.998, 0.998, 0.608, 0.608, 0.998, 0.998]
y == [0.091, 0.963, 0.963, 0.091, 0.091, 0.963, 0.963, 0.091]

In [None]:
def cuboid_data(l, h):
    """Return the 8 corner vertices of the box spanning ``l`` to ``h``.

    Rows of the ``(3, 8)`` result are the x, y and z coordinates; the first
    four columns are the low-z face, the last four the high-z face.

    NOTE: this shadows the earlier face-based ``cuboid_data`` that
    ``plotCubeAt`` relies on.
    """
    lower_z, upper_z = l[2], h[2]
    footprint_x = [l[0], l[0], h[0], h[0]]
    footprint_y = [l[1], h[1], h[1], l[1]]
    return np.array((
        footprint_x + footprint_x,
        footprint_y + footprint_y,
        [lower_z] * 4 + [upper_z] * 4,
    ))

x, y, z = cuboid_data([.608, .091, .140], [.998, .963, .571])

fig = go.Figure(data=[
     go.Mesh3d(
         x = x, y = y, z = z,
        # 8 vertices of a cube
#         x=[0.608, 0.608, 0.998, 0.998, 0.608, 0.608, 0.998, 0.998],
#         y=[0.091, 0.963, 0.963, 0.091, 0.091, 0.963, 0.963, 0.091],
#         z=[0.140, 0.140, 0.140, 0.140, 0.571, 0.571, 0.571, 0.571],

        i = [7, 0, 0, 0, 4, 4, 6, 6, 4, 0, 3, 2],
        j = [3, 4, 1, 2, 5, 6, 5, 2, 0, 1, 6, 3],
        k = [0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6],
        opacity=0.6,
        color='#DC143C',
        flatshading = True
    )                    
    ])

fig.show()

In [None]:
def colormap(x):
    """Return one Plotly colorscale entry ``[x, 'rgb(r, g, b)']`` from viridis.

    ``x`` is a position in ``[0, 1]``. The colour channels are emitted as
    integers in 0-255, the scale Plotly uses for ``rgb(...)`` strings; the
    0-1 floats returned by Matplotlib would all render as near-black.
    """
    # plt.cm.viridis is used instead of matplotlib.cm.get_cmap, which newer
    # Matplotlib releases no longer provide.
    red, green, blue, _alpha = plt.cm.viridis(x)
    channels = tuple(int(round(255 * channel)) for channel in (red, green, blue))
    return [x, f'rgb{channels}']
    

In [None]:
norm = matplotlib.colors.Normalize()
im = np.array(obs['img'])
colorscale = [colormap(i) for i in np.linspace(0, 1, 10)]

im_x, im_y = im.shape
x = np.linspace(-2.5,2.5, im_x)
y = np.linspace(-2.5,2.5, im_y)
z = np.ones(im.shape[:2]) * grid_low[0]

import plotly.graph_objects as go
from plotly.subplots import make_subplots

fig = make_subplots(rows=1, cols=2,
                    specs=[[{'is_3d': True}, {'is_3d': True}]],
                    subplot_titles=['Normal scale', 'Logarithmic scale'],
                    )

for ncol, v, cbar_x in zip([1, 2], [values, np.log10(values)], [-0.10, None]):
    fig.add_trace(go.Volume(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=v.flatten(),
        surface_count = 20,
    #     opacity = 0.1,
        opacityscale = [[0, 0], [1, 0.9]],
        colorbar_x=cbar_x,
    ), 1, ncol)

    fig.add_trace(go.Scatter3d(
        x = x_sub,
        y = y_sub,
        z = M_sub,
        mode ='markers',
        marker = dict(
            color = 'red',
            symbol = 'x',
            size = 5,
        ),
    ), 1, ncol)

fig.add_trace(go.Surface(x=x, y=y, z=z,
    surfacecolor=im, 
    colorscale=colorscale,
    showscale=False,
#     lighting_diffuse=1,
#     lighting_ambient=1,
#     lighting_fresnel=1,
#     lighting_roughness=1,
#     lighting_specular=0.5,

), 1, 1)


fig.update_layout(
    height = 800, 
    width = 1600, 
    title_text="Subhalo posteriors test",
    scene = dict(
        xaxis=dict(title=r"x"),
        yaxis=dict(title=r"y"),
        zaxis=dict(title=r'M'),
    ),
    showlegend=False
)
# fig.show()
fig.write_html("plotly.html")

In [None]:
import numpy as np
import pandas as pd
from PIL import Image
import plotly.graph_objects as go
from scipy import misc

# im = misc.face()
# im_x, im_y, im_layers = im.shape
# eight_bit_img = Image.fromarray(im).convert('P', palette='WEB', dither=None)
# dum_img = Image.fromarray(np.ones((3,3,3), dtype='uint8')).convert('P', palette='WEB')
# idx_to_color = np.array(dum_img.getpalette()).reshape((-1, 3))
# colorscale=[[i/255.0, "rgb({}, {}, {})".format(*rgb)] for i, rgb in enumerate(idx_to_color)]

# Sample data: 3 trajectories
t = np.linspace(0, 10, 200)
df = pd.concat([pd.DataFrame({'x': 400 * (1 + np.cos(t + 5 * i)), 'y': 400 * (1 + np.sin(t)), 't': t, 'id': f'id000{i}'}) for i in [0, 1, 2]])
# im = im.swapaxes(0, 1)[:, ::-1]
colors=df['t'].to_list()

# # 3d scatter plot
x = np.linspace(0,404.8, im_x)
y = np.linspace(0, 504.4, im_y)
z = np.zeros(im.shape[:2])

# x = np.linspace(-2.5,2.5, im_x)
# y = np.linspace(-2.5,2.5, im_y)
# z = np.ones(im.shape) * 8.

fig = go.Figure()

fig.add_trace(go.Scatter3d(
    x=df['x'], 
    y=df['y'], 
    z=df['t'],
    marker=dict(
        color=colors,
        size=4,
    )
    ))

fig.add_trace(go.Surface(x=x, y=y, z=z,
    surfacecolor=eight_bit_img, 
    cmin=0, 
    cmax=255,
    colorscale=colorscale,
    showscale=False,
    lighting_diffuse=1,
    lighting_ambient=1,
    lighting_fresnel=1,
    lighting_roughness=1,
    lighting_specular=0.5,

))

fig.update_layout(
    title="My 3D scatter plot",
    width=800,
    height=800,
    scene=dict(xaxis_visible=True,
                yaxis_visible=True, 
                zaxis_visible=True, 
                xaxis_title="X",
                yaxis_title="Y",
                zaxis_title="Z" ,

    ))

In [None]:
plt.hist(vol.flatten())

In [None]:
plt.hist(vol.flatten())

In [None]:
np.log10(post)

In [None]:
import numpy as np
import plotly.graph_objects as go

# Generate a smooth random 3D field: 15 random unit spikes in an l^3 grid,
# blurred with a Gaussian filter and rescaled so the maximum is 1.
np.random.seed(0)
l = 30
X, Y, Z = np.mgrid[:l, :l, :l]
vol = np.zeros((l, l, l))
# Built-in int: the `np.int` alias no longer exists in current NumPy.
pts = (l * np.random.rand(3, 15)).astype(int)
vol[tuple(pts)] = 1
from scipy import ndimage
vol = ndimage.gaussian_filter(vol, 4)
vol /= vol.max()

fig = go.Figure(data=go.Volume(
    x=X.flatten(), y=Y.flatten(), z=Z.flatten(),
    value=vol.flatten(),
    isomin=0.2,
    isomax=0.7,
    opacity=0.1,
    surface_count=100,
    ))
fig.update_layout(scene_xaxis_showticklabels=False,
                  scene_yaxis_showticklabels=False,
                  scene_zaxis_showticklabels=False)
fig.show()

In [None]:
fig = go.Figure(data=[
     go.Scatter3d(x=x, y=y, z=z,
                  mode='markers',
                  marker=dict(size=2)
                 ),
     go.Mesh3d(
        # 8 vertices of a cube
        x=[0.608, 0.608, 0.998, 0.998, 0.608, 0.608, 0.998, 0.998],
        y=[0.091, 0.963, 0.963, 0.091, 0.091, 0.963, 0.963, 0.091],
        z=[0.140, 0.140, 0.140, 0.140, 0.571, 0.571, 0.571, 0.571],

        i = [7, 0, 0, 0, 4, 4, 6, 6, 4, 0, 3, 2],
        j = [3, 4, 1, 2, 5, 6, 5, 2, 0, 1, 6, 3],
        k = [0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6],
        opacity=0.6,
        color='#DC143C',
        flatshading = True
    )                    
    ])

In [None]:
import plotly

In [None]:
plotly.plot(fig)

In [None]:
trace1 = go.Scatter3d(
    x = df['XXXX'],
    y = df['XXXX'],
    z = df['XXXX'],
    mode='markers',
    marker=dict(
        size=12,
        color=z,                # set color to an array/list of desired values
        colorscale='Viridis',   # choose a colorscale
        opacity=0.8
    )
)

data = [trace1]
layout = go.Layout(
    scene = dict(
                    xaxis = dict(
                        title='XXXX-XXXXXX'),
                    yaxis = dict(
                        title='XXXX-XXXXXX'),
                    zaxis = dict(
                        title='XXXX-XXXXXX'),),
    margin=dict(
        r=20, b=10, l=10, t=10
    )
)
fig = go.Figure(data=data, layout=layout)
#py.iplot(fig, filename='3d-scatter-colorscale')
plot(fig, filename='D:\\plots\\3dplots\\xx.html')

In [None]:
z[1:]

In [None]:
ax.plot_surface(X, Y, np.atleast_2d(m_edges.min()), facecolors = plt.cm.viridis(norm(obs['img'])), 
                alpha = 0.1, shade = False, zorder = -100)#, 1, zdir = 'z', offset = 0, cmap = 'viridis')

In [None]:
plt.cm.viridis(norm(obs['img']))

In [None]:
plt.imshow(obs['img'])

In [None]:
for p in cmap(obs['img']).T:
    plt.imshow(p)
    plt.show()

In [None]:
m_centers, m_edges, xy_centers, xy_edges = grid_coords


In [None]:
x = np.ones((1, 2, 3))
np.transpose(x, (1, 0, 2)).shape


In [None]:
post.shape

In [None]:
np.transpose(post, (1, 2, 0)).shape

In [None]:
obs['z_sub'].T.shape

In [None]:
np.swapaxes(obs['z_sub'].T.numpy()

In [None]:
np.transpose(obs['z_sub'].numpy(), axes = [0, 1, 2])

In [None]:
m_edges

In [None]:
full_grid_low

In [None]:
grid = torch.stack(torch.meshgrid((m_edges, xy_edges, xy_edges)), dim = -1)
full_grid_low  = grid[:-1,:-1,:-1]
full_grid_high = grid[1:,1:,1:]

In [None]:
x, y, z

In [None]:
cuboid_data([0,0,0])

In [None]:
fig = plt.figure(figsize=(18,15))
ax = fig.add_subplot(111, projection='3d')
ax.view_init(elev=10., azim=10)
plotMatrix(ax, x, y, z, data_value)
plt.show()

In [None]:
#         azims = [10., 100., 190., 250.]
azims = [10.]

fig = plt.figure(figsize=(18,15))

for i, azim in enumerate(azims):
    ax = fig.add_subplot(2, 2, i+1, projection='3d')
#         plotMatrix(ax, x, y, z, data_value)

#         ax.scatter(*scatter, marker = 'x', color = 'red', s = 100)

    ax.view_init(elev=10., azim=azim)

    labelsize = 15
    ax.set_xlabel(r'$x\ [\deg]$', fontsize = labelsize)
    ax.set_ylabel(r'$y\ [\deg]$', fontsize = labelsize)
    ax.set_zlabel(r'$log_{10}(M_{sub}/M_{\odot})$', fontsize = labelsize)

    ax.set_xticks(np.linspace(0, L, 11)[1::2])
    ax.set_xticklabels(np.linspace(-2.5, 2.5, 11)[1::2])
    ax.set_yticks(np.linspace(0, L, 11)[1::2])
    ax.set_yticklabels(np.linspace(-2.5, 2.5, 11)[1::2])
    ax.set_zticks(z)
    ax.set_zticklabels(np.log10(m_centers.numpy()))





fig.subplots_adjust(right=0.8)
ax_cb = fig.add_axes([0.85, 0.15, 0.02, 0.7])

norm = matplotlib.colors.Normalize(vmin=data_value.min(), vmax=data_value.max())
cbar = matplotlib.colorbar.ColorbarBase(ax_cb, cmap=plt.cm.viridis,
                                norm=norm,
                                orientation='vertical')  
#     return fig
#     plt.savefig(f'figs/lava_{plot_name}.png',bbox_inches='tight')
plt.show()

#     print('Done!')
    
# plotLava(post)

In [None]:

plt.show()

In [None]:


    def get_alphas(post):
        """Map posterior values to opacities with a saturating exponential.

        ``post`` is min-max normalised to ``[0, 1]`` and passed through
        ``1 - b * exp(-a * post)`` with ``a = 50`` and ``b = 1``, so the
        minimum maps to 0. A constant input (including all zeros) yields
        all-zero opacities instead of NaN.
        """
        shifted = post - post.min()
        span = shifted.max()
        # Guard the constant-input case, which would otherwise divide by zero.
        scaled = shifted / span if span != 0 else shifted

        a, b = 50, 1
        alphas = 1 - b * np.exp( - a * scaled)

        return alphas

    def normalize(d):
        """Return ``d`` rescaled to ``[0, 1]`` without modifying the input.

        Integer input is promoted to float; a constant input yields all
        zeros rather than a division by zero.
        """
        shifted = d - d.min()
        peak = shifted.max()
        if peak == 0:
            # Constant input: every value sits at the minimum.
            return shifted * 0.0
        return shifted / peak

    def cuboid_data(center, size=(1,1,1)):
        """Return surface coordinates of a box given its centre and edge lengths.

        The result has shape ``(3, 4, 5)``: coordinate axis, then four closed
        5-point rings (bottom, upper, outside, inside surface).
        """
        # code taken from
        # http://stackoverflow.com/questions/30715083/python-plotting-a-wireframe-3d-cuboid?noredirect=1&lq=1
        # suppose axis direction: x: to left; y: to inside; z: to upper
        # (left, outside, bottom) corner of the box
        origin = [c - s / 2 for c, s in zip(center, size)]
        length, width, height = size

        x_near, x_far = origin[0], origin[0] + length
        y_near, y_far = origin[1], origin[1] + width
        z_near, z_far = origin[2], origin[2] + height

        xs = [[x_near, x_far, x_far, x_near, x_near] for _ in range(4)]
        ys = [
            [y_near, y_near, y_far, y_far, y_near],  # bottom surface
            [y_near, y_near, y_far, y_far, y_near],  # upper surface
            [y_near] * 5,                            # outside surface
            [y_far] * 5,                             # inside surface
        ]
        zs = [
            [z_near] * 5,                            # bottom surface
            [z_far] * 5,                             # upper surface
            [z_near, z_near, z_far, z_far, z_near],  # outside surface
            [z_near, z_near, z_far, z_far, z_near],  # inside surface
        ]
        return np.array((xs, ys, zs))

    def plotCubeAt(pos=(0,0,0), c="b", alpha=1, ax=None):
        """Draw one unit cube centred on ``pos`` on the 3D axes ``ax``.

        ``c`` and ``alpha`` are forwarded to ``ax.plot_surface`` as colour and
        opacity. Nothing is drawn when ``ax`` is ``None``.
        """
        # Identity check: an axes object must never be compared with ``!=``.
        if ax is None:
            return
        X, Y, Z = cuboid_data( (pos[0],pos[1],pos[2]) )
        ax.plot_surface(X, Y, Z, color=c, rstride=1, cstride=1, alpha=alpha)

    def plotMatrix(ax, x, y, z, data, cmap=plt.cm.viridis, cax=None, alpha=1):
        """Draw every cell of a 3D grid as a unit cube coloured by its value.

        Parameters
        ----------
        ax
            3D axes to draw on.
        x, y, z
            Cell positions along the three axes.
        data
            Values indexed ``[i, j, k]`` in the order of ``x``, ``y``, ``z``.
            It is not modified.
        cmap
            Colormap used for the cube colour.
        cax, alpha
            Unused; kept so existing calls keep working.
        """
        norm = matplotlib.colors.Normalize(vmin=data.min(), vmax=data.max())
        mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
        # Opacities are computed once, from a private float copy: calling
        # normalize(data) per cell would rescale ``data`` itself and so
        # corrupt every colour lookup after the first cube.
        alphas = normalize(np.array(data, dtype=float))
        for i, xi in enumerate(tqdm(x)):
            for j, yi in enumerate(y):
                for k, zi in enumerate(z):
                    plotCubeAt(pos=(xi, yi, zi),
                               c=mappable.to_rgba(data[i,j,k]),
                               alpha=alphas[i,j,k],
                               ax=ax)


    def plotLava(obs0_i = -1):
        """Plot the posterior for one observation as coloured cubes from four view angles.

        The posterior and scatter points come from ``get_pred(obs0_i=obs0_i)``.
        The figure is saved to ``figs/lava_<plot_name>.png`` and then shown.

        NOTE(review): ``get_pred``, ``L``, ``mre_name`` and ``m_centers`` are
        not defined in this block; they must exist in the enclosing scope.
        """
        post, target_coords, scatter, obs0_i = get_pred(obs0_i = obs0_i)

        # x and y and z coordinates
        # NOTE(review): these four assignments are placeholders that are
        # overwritten just below; the np.random.rand call still advances the
        # global NumPy random state.
        x = np.array(range(21)) #np.linspace(0,9,11) #
        y = np.array(range(10,15))
        z = np.array(range(15,20))
        # data_value = np.random.randint(1,4, size=(len(x), len(y), len(z)) )
        data_value = np.random.rand(len(x), len(y), len(z))

        # Actual grid: integer cell indices along x/y (L cells) and one index per mass bin.
        x = y = np.arange(L) #np.linspace(0, 1, L)
        z = np.arange(len(m_centers)) #m_centers.numpy()
        # Move the first axis of the posterior to the end so it is indexed last.
        data_value = np.transpose(post, [1,2,0])

        # One subplot per camera azimuth (degrees).
        azims = [10., 100., 190., 250.]
        # azims = [10.]

        print(data_value.shape)
        plot_name = f'{mre_name}_obs0_i={obs0_i}'
        print(f'plot_name {plot_name}')

        fig = plt.figure(figsize=(18,15))

        for i, azim in enumerate(azims):
            ax = fig.add_subplot(2, 2, i+1, projection='3d')
            plotMatrix(ax, x, y, z, data_value)

            # Overlay the scatter points returned by get_pred as red crosses.
            ax.scatter(*scatter, marker = 'x', color = 'red', s = 100)

            ax.view_init(elev=10., azim=azim)

            labelsize = 15
            ax.set_xlabel(r'$x\ [\deg]$', fontsize = labelsize)
            ax.set_ylabel(r'$y\ [\deg]$', fontsize = labelsize)
            ax.set_zlabel(r'$log_{10}(M_{sub}/M_{\odot})$', fontsize = labelsize)

            # Relabel the index-based x/y ticks with values from -2.5 to 2.5.
            ax.set_xticks(np.linspace(0, L, 11)[1::2])
            ax.set_xticklabels(np.linspace(-2.5, 2.5, 11)[1::2])
            ax.set_yticks(np.linspace(0, L, 11)[1::2])
            ax.set_yticklabels(np.linspace(-2.5, 2.5, 11)[1::2])
            ax.set_zticks(z)
            ax.set_zticklabels(np.log10(m_centers.numpy()))





        # Reserve the right edge of the figure for a shared colorbar.
        fig.subplots_adjust(right=0.8)
        ax_cb = fig.add_axes([0.85, 0.15, 0.02, 0.7])

        norm = matplotlib.colors.Normalize(vmin=data_value.min(), vmax=data_value.max())
        cbar = matplotlib.colorbar.ColorbarBase(ax_cb, cmap=plt.cm.viridis,
                                        norm=norm,
                                        orientation='vertical')  

        plt.savefig(f'figs/lava_{plot_name}.png',bbox_inches='tight')
        plt.show()
        
        print('Done!')

# Extra Plots

In [None]:
def get_testdata(simulator, infer, irc, n_test = 4, calibrate = True):
    """Draw `n_test` fresh simulations and compute one posterior per draw.

    Parameters
    ----------
    simulator : object with ``sample(n)``; called as ``sample(1)`` once per draw.
    infer : object with ``get_post(sim)`` and ``squeeze_obs(sim)``.
    irc : object with ``calibrate(post)``; only used when `calibrate` is true.
    n_test : int
        Number of simulations to draw.
    calibrate : bool
        When true (default, the previous hard-coded behaviour) each posterior
        is passed through ``irc.calibrate`` before being collected.

    Returns
    -------
    (test_sims, test_posts)
        `test_sims` is a list holding the ``squeeze_obs`` result of every draw;
        `test_posts` is the ``torch.stack`` of the per-draw posteriors.

    Raises
    ------
    RuntimeError
        From ``torch.stack`` when `n_test` is 0 (nothing to stack).
    """
    test_posts, test_sims = [], []

    for _ in range(n_test):
        test_sim = simulator.sample(1)

        # The posterior is computed on the batched sample; the leading axis is
        # dropped afterwards, and the sample itself is squeezed for plotting.
        test_post = infer.get_post(test_sim).squeeze(0)
        test_sim = infer.squeeze_obs(test_sim)
        if calibrate:
            test_post = irc.calibrate(test_post)

        test_sims.append(test_sim)
        test_posts.append(test_post)

    return test_sims, torch.stack(test_posts)
# Draw four test simulations and their posteriors into the kernel-global
# `test_sims` / `test_posts`.
# NOTE(review): in the visible part of the notebook `infer` and `irc` are first
# assigned further down (Interpret section) — confirm this cell runs on a fresh kernel.
test_sims, test_posts = get_testdata(simulator, infer, irc, n_test = 4)

In [None]:
def plot_testdata(test_sims, test_posts, zlog = True, vminmax = True):
    """Call ``LogObs.plot_msc`` once for every (simulation, posterior) pair.

    Parameters
    ----------
    test_sims : iterable of simulations, paired element-wise with `test_posts`.
    test_posts : iterable of posterior tensors (e.g. a stacked tensor).
    zlog : forwarded unchanged to ``plot_msc``.
    vminmax : forwarded unchanged to ``plot_msc``. Previously this argument was
        ignored and a literal ``True`` was always passed; the default keeps
        that behaviour for existing callers.

    Notes
    -----
    Reads the kernel-global `prior` and `grid_coords`.
    """
    for test_sim, test_post in zip(test_sims, test_posts):
        logobs = LogObs(None, test_sim, test_post, prior, grid_coords)
        logobs.plot_msc(zlog = zlog,
                        plot_true = True,
                        title = rf'Sum posterios $= {torch.sum(test_post).item():.2f}$',
                        vminmax = vminmax,
                       )
#     for test_sim, test_post in zip(test_sims, test_posts):
#         logobs = LogObs(None, test_sim, test_post, prior, grid_coords)
#         logobs.plot_obs();
        
# plot_testdata(test_sims, test_posts, zlog = True)
# Plot the drawn test set with zlog disabled; `vminmax` is left at its default (True).
plot_testdata(test_sims, test_posts, zlog = False)#, vminmax = False)

In [None]:

    
    
#     posts_uncalib = torch.load(os.path.join(tbl.experiment.get_logdir(), 'posts_uncalib.pt'))
#     targets       = torch.load(os.path.join(tbl.experiment.get_logdir(), 'targets.pt'))
    

    
    
    

    
#     # Observation inference
#     obs = torch.load(cfg.inference.obs_path)
#     obs_post = infer.get_post( dict(img=obs['img'].unsqueeze(0).cpu()))
    
#     # Log observation inference
#     log_obs = LogObs(tbl, obs, obs_post, prior, grid_coords, fig_kwargs = dict(dpi = 250, figsize = (8, 5)))
#     log_obs.plot_all()   
    


#     if (cfg.simulation.model.n_sub, cfg.estimation.network.n_msc) == (1, 1):
#         log_single_sub = LogSingleSub(tbl, obs, obs_post, prior_grid, grid_coords)
#         log_single_sub.plot_all()

    # Log bounds
#     log_bounds = LogBounds(tbl, obs, obs_post, grid_coords, grid_low, grid_high)
#     log_bounds.plot_all()
                              


# Interpret

In [None]:
import os
import numpy as np
import pylab as plt
import torch
import swyft.lightning as sl

from lensx.nn.subN.interpret import IsotonicRegressionCalibration
from lensx.nn.subN.logging_utils_subN import LogIRC, LogPost, LogObs, LogBounds, LogSingleSub
from lensx.nn.subN.inference import Infer, Prior

In [None]:
# Directory of the current logger run; `tbl` must already exist in the kernel.
logdir = tbl.experiment.get_logdir()

# Calculate expected n_sub
# `Ms` is component 0 of every 'z_sub' entry in the predict dataset.
# NOTE(review): the sum below counts entries that are *equal to* zero per sample —
# confirm that this (rather than `!= 0`) is the intended "expected n_sub".
Ms = datamodule.predict_dataloader().dataset[:]['z_sub'][:,:,0]
n_sub_expect = torch.mean( torch.sum(Ms == 0, dim = 1).type(torch.float) )

# Build the inference helper; this call site passes n_sub_expect as a fourth argument.
infer = Infer(simulator, network, datamodule, n_sub_expect)

# Prior information necessary for loggers
# Only element [0] of infer.prior_grid() is kept here.
prior, prior_grid = infer.calc_prior(), infer.prior_grid()[0]
grid_coords = infer.get_grid_coords()
grid_low, grid_high = infer.grid_low, infer.grid_high

In [None]:
# Reload cached tensors from `logdir`. Files with these names are written by the
# torch.save calls in the "Uniform" section further down, so they must already
# exist from an earlier run.
posts_uncalib = torch.load(os.path.join(logdir, 'posts_uncalib.pt'))
posts_calib = torch.load(os.path.join(logdir, 'posts_calib.pt'))
targets       = torch.load(os.path.join(logdir, 'targets.pt'))

In [None]:
# # Simulations inference
# posts_uncalib, targets = infer.get_posts(datamodule.predict_dataloader(), cfg.inference.n_infer)
# torch.save(posts_uncalib, os.path.join(logdir, 'posts_uncalib.pt'))
# torch.save(targets, os.path.join(logdir, 'targets.pt'))

In [None]:
# Calibration: build the calibrator from the uncalibrated posteriors and their targets.
# The calibrated posteriors are not recomputed in this cell (the two lines below are disabled).
irc = IsotonicRegressionCalibration(posts_uncalib, targets)    
# posts_calib = irc.calibrate(posts_uncalib)
# torch.save(posts_calib, os.path.join(logdir, 'posts_calib.pt'))

In [None]:
# Render the LogIRC plot for `irc`; here `tbl` is passed as the first argument,
# whereas other cells in this notebook pass None.
LogIRC(tbl, irc).plot()

In [None]:

def get_testdata(simulator, infer, irc, n_test = 4, calibrate = True):
    """Draw `n_test` fresh simulations and compute one posterior per draw.

    Parameters
    ----------
    simulator : object with ``sample(n)``; called as ``sample(1)`` once per draw.
    infer : object with ``get_post(sim)`` and ``squeeze_obs(sim)``.
    irc : object with ``calibrate(post)``; only used when `calibrate` is true.
    n_test : int
        Number of simulations to draw.
    calibrate : bool
        When true (default, the previous hard-coded behaviour) each posterior
        is passed through ``irc.calibrate`` before being collected.

    Returns
    -------
    (test_sims, test_posts)
        `test_sims` is a list holding the ``squeeze_obs`` result of every draw;
        `test_posts` is the ``torch.stack`` of the per-draw posteriors.

    Raises
    ------
    RuntimeError
        From ``torch.stack`` when `n_test` is 0 (nothing to stack).
    """
    test_posts, test_sims = [], []

    for _ in range(n_test):
        test_sim = simulator.sample(1)

        # The posterior is computed on the batched sample; the leading axis is
        # dropped afterwards, and the sample itself is squeezed for plotting.
        test_post = infer.get_post(test_sim).squeeze(0)
        test_sim = infer.squeeze_obs(test_sim)
        if calibrate:
            test_post = irc.calibrate(test_post)

        test_sims.append(test_sim)
        test_posts.append(test_post)

    return test_sims, torch.stack(test_posts)
# Draw four test simulations with the `infer` / `irc` objects built in the cells above.
test_sims, test_posts = get_testdata(simulator, infer, irc, n_test = 4)

In [None]:
def plot_testdata(test_sims, test_posts, zlog = True, vminmax = True):
    """Plot every test case twice: first all ``plot_msc`` figures, then all ``plot_obs`` figures.

    Parameters
    ----------
    test_sims, test_posts : paired iterables of simulations and posteriors.
    zlog : forwarded to ``LogObs.plot_msc``.
    vminmax : only the literal ``True`` enables a shared colour range; in that
        case the min/max over all of `test_posts` are passed as ``vmin``/``vmax``.

    Notes
    -----
    Reads the kernel-global `prior` and `grid_coords`.
    """
    # The extrema are evaluated unconditionally, regardless of `vminmax`.
    upper, lower = test_posts.max(), test_posts.min()

    # Pass 1: posterior maps.
    for test_sim, test_post in zip(test_sims, test_posts):
        panel = LogObs(None, test_sim, test_post, prior, grid_coords)
        if vminmax is True:
            limits = dict(vmin = lower, vmax = upper)
        else:
            limits = {}
        panel.plot_msc(
            zlog = zlog,
            plot_true = True,
            title = rf'Sum posterios $= {torch.sum(test_post).item():.2f}$',
            **limits,
        )

    # Pass 2: the observations themselves, from freshly built LogObs objects.
    for test_sim, test_post in zip(test_sims, test_posts):
        LogObs(None, test_sim, test_post, prior, grid_coords).plot_obs()

In [None]:
# Plot the test set with zlog enabled and the default shared colour range.
plot_testdata(test_sims, test_posts, zlog = True)
# plot_testdata(test_sims, test_posts, zlog = False, vminmax = False)

In [None]:
# Deliberate stop: this assertion always fails, so "Run All" halts here and the
# sections below only run when executed by hand.
assert 1 == 2

# Uniform

In [None]:
import os
import numpy as np
import pylab as plt
import torch
import swyft.lightning as sl

from lensx.nn.subN.interpret import IsotonicRegressionCalibration
from lensx.nn.subN.logging_utils_subN import LogIRC, LogPost, LogObs, LogBounds, LogSingleSub
from lensx.nn.subN.inference import Infer, Prior

In [None]:
# `Ms` is component 0 of every 'z_sub' entry in the predict dataset.
# NOTE(review): n_sub_expect is the mean per-sample count of entries *equal to* zero —
# confirm that this is the intended quantity.
Ms = datamodule.predict_dataloader().dataset[:]['z_sub'][:,:,0]
n_sub_expect = torch.mean( torch.sum(Ms == 0, dim = 1).type(torch.float) )

In [None]:
# Directory of the current logger run; `tbl` must already exist in the kernel.
logdir = tbl.experiment.get_logdir()

# Build the inference helper, passing n_sub_expect from the previous cell.
infer = Infer(simulator, network, datamodule, n_sub_expect)

In [None]:
# Prior information necessary for loggers
# NOTE(review): prior_grid keeps the full return value of infer.prior_grid() here,
# while the "Interpret" section takes element [0] — confirm which one is intended.
prior, prior_grid = infer.calc_prior(), infer.prior_grid()
grid_coords = infer.get_grid_coords()
grid_low, grid_high = infer.grid_low, infer.grid_high

# Simulations inference: run over the predict dataloader and cache both results in logdir.
posts_uncalib, targets = infer.get_posts(datamodule.predict_dataloader(), cfg.inference.n_infer)
torch.save(posts_uncalib, os.path.join(logdir, 'posts_uncalib.pt'))
torch.save(targets, os.path.join(logdir, 'targets.pt'))
# posts_uncalib = torch.load(os.path.join(tbl.experiment.get_logdir(), 'posts.pt'))
# targets       = torch.load(os.path.join(tbl.experiment.get_logdir(), 'targets.pt'))

# Calibration: build the calibrator, apply it to the same posteriors, cache the result.
irc = IsotonicRegressionCalibration(posts_uncalib, targets)    
posts_calib = irc.calibrate(posts_uncalib)
torch.save(posts_calib, os.path.join(logdir, 'posts_calib.pt'))

In [None]:
# Render the LogIRC plot for `irc`; None is passed as the first argument.
LogIRC(None, irc).plot()

# Single

In [None]:
def interpret(cfg, simulator, network, trainer, datamodule, tbl):
    """Invoke the callable configured at ``cfg.inference.interpreter``.

    All six arguments are handed on positionally, in signature order, through
    ``hydra.utils.call``. Nothing is returned.
    """
    interpreter_cfg = cfg.inference.interpreter
    positional = (cfg, simulator, network, trainer, datamodule, tbl)
    hydra.utils.call(interpreter_cfg, *positional)

# def interpret(cfg, simulator, network, trainer, datamodule, tbl):
#     hydra.utils.instantiate(cfg.inference.interpreter, cfg, simulator, network, trainer, datamodule, tbl)

# Run the configured interpreter.
# NOTE(review): `network`, `trainer` and `tbl` are not assigned in the visible part
# of the notebook — they must already exist in the kernel.
interpret(cfg, simulator, network, trainer, datamodule, tbl)

In [None]:
import os
import numpy as np
import pylab as plt
import torch
import swyft.lightning as sl

from lensx.nn.subN.interpret import IsotonicRegressionCalibration
from lensx.nn.subN.logging_utils_subN import LogIRC, LogPost, LogObs, LogBounds, LogSingleSub
from lensx.nn.subN.inference import Infer, Prior

In [None]:
# Directory of the current logger run; `tbl` must already exist in the kernel.
logdir = tbl.experiment.get_logdir()

# Build the inference helper; unlike the "Uniform" section, no n_sub_expect is passed here.
infer = Infer(simulator, network, datamodule)

# Prior information necessary for loggers
prior, prior_grid = infer.calc_prior(), infer.prior_grid()
grid_coords = infer.get_grid_coords()
grid_low, grid_high = infer.grid_low, infer.grid_high

In [None]:
# Simulations inference (disabled below); the cached results are loaded from disk instead.
#     posts_uncalib, targets = infer.get_posts(datamodule.predict_dataloader(), cfg.inference.n_infer)
#     torch.save(posts_uncalib, os.path.join(logdir, 'posts_uncalib.pt'))
#     torch.save(targets, os.path.join(logdir, 'targets.pt'))
# NOTE(review): this loads 'posts.pt', while the save calls in this notebook write
# 'posts_uncalib.pt' — confirm the file name.
posts_uncalib = torch.load(os.path.join(tbl.experiment.get_logdir(), 'posts.pt'))
targets       = torch.load(os.path.join(tbl.experiment.get_logdir(), 'targets.pt'))

In [None]:
def normalize(posts_uncalib, mult = 1):
    """Rescale posteriors so that each sample sums to `mult`.

    Parameters
    ----------
    posts_uncalib : torch.Tensor
        Either a batch with four axes (sample first) or a single sample with
        three axes.
    mult : number
        Target sum of every sample after rescaling.

    Returns
    -------
    torch.Tensor
        Same shape as the input. A 3-D input stays 3-D even when one of its
        axes has length 1: only the temporary batch axis is removed again.

    Notes
    -----
    A sample whose entries sum to zero is divided by zero; this is not
    guarded against.
    """
    is_single = posts_uncalib.ndim == 3
    batch = posts_uncalib.unsqueeze(0) if is_single else posts_uncalib

    # keepdim leaves the three reduced axes in place with length 1, so the
    # division broadcasts without an explicit tile.
    totals = torch.sum(batch, dim = (1, 2, 3), keepdim = True)
    posts_uncalib_norm = mult * batch / totals

    # squeeze(0), not squeeze(): only undo the batch axis added above.
    return posts_uncalib_norm.squeeze(0) if is_single else posts_uncalib_norm

In [None]:
# Flag is set here but not read by any live line of this cell (its only use below is disabled).
NORMALIZE = True
# posts_uncalib = normalize(posts_uncalib) if NORMALIZE is True else posts_uncalib

# Keep both variants side by side: the loaded posteriors as they are, and a copy
# rescaled by normalize() with its default mult.
posts_uncalib_unnorm = posts_uncalib
posts_uncalib_norm   = normalize(posts_uncalib)

In [None]:
# Calibration on the un-normalized posteriors: fit, apply to the same data, plot.
irc_unnorm = IsotonicRegressionCalibration(posts_uncalib_unnorm, targets)    
posts_calib_unnorm = irc_unnorm.calibrate(posts_uncalib_unnorm)
# torch.save(posts_calib, os.path.join(logdir, 'posts_calib.pt'))
LogIRC(None, irc_unnorm).plot()

In [None]:
# Calibration on the normalized posteriors: fit, apply to the same data, plot.
irc_norm = IsotonicRegressionCalibration(posts_uncalib_norm, targets)    
posts_calib_norm = irc_norm.calibrate(posts_uncalib_norm)
# torch.save(posts_calib, os.path.join(logdir, 'posts_calib.pt'))
LogIRC(None, irc_norm).plot()

# Calibration

In [None]:
# from sklearn.isotonic import IsotonicRegression

# class IsotonicRegressionCalibration():
#     def __init__(self, posts_uncalib, targets):
# #         super().__init__()
#         self.posts_uncalib = posts_uncalib
#         self.targets = targets
        
#         print('initialzing')
        
#         assert torch.sum(torch.isnan(posts_uncalib.view(-1))).item() == 0
        
        
#         self.post_data_uncalib = PostData(posts_uncalib, targets)
    
#         self.relicurve_uncalib = self.post_data_uncalib.get_relicurve().cpu()
        
#         self.ir = self.get_ir(self.relicurve_uncalib, self.post_data_uncalib.alpha_centers.cpu())
    
    
#     def get_ir(self, relicurve, alpha_centers):
#         alpha_centers_zero = torch.cat((torch.tensor([0]), alpha_centers))
#         relicurve_zero     = torch.cat((torch.tensor([0]), relicurve))

#         ir = IsotonicRegression(out_of_bounds = 'clip')
        
#         ir.fit(alpha_centers_zero, relicurve_zero);
#         return ir
    
#     def calibrate(self, posts):
#         posts_calib = self.ir.predict(posts.cpu().flatten()).reshape(posts.shape)
#         posts_calib = torch.tensor(posts_calib, device = posts.device, dtype = posts.dtype)
#         return posts_calib

# class LogIRC:
#     def __init__(self, irc):
                
#         self.posts_uncalib = irc.posts_uncalib
#         self.targets = irc.targets
#         self.relicurve_uncalib = irc.relicurve_uncalib
#         self.alpha_edges = irc.post_data_uncalib.alpha_edges.cpu()
#         self.alpha_centers = irc.post_data_uncalib.alpha_centers.cpu()
#         self.ir = irc.ir

        
#         self.posts_calib = irc.calibrate(self.posts_uncalib)
#         self.relicurve_calib = PostData(self.posts_calib, self.targets).get_relicurve().cpu()
        
#     def plot(self):
#         fig = plt.figure()
#         plt.stairs(self.relicurve_uncalib, self.alpha_edges, label = 'Uncalibrated')
#         plt.stairs(self.relicurve_calib, self.alpha_edges, label = 'Calibrated')
#         plt.plot(self.alpha_centers, self.ir.predict(self.alpha_centers), label = 'Fit')
#         plt.legend(loc = 2)
#         plt.plot((0, 1), (0, 1), 'k:')
#         plt.show()




In [None]:
# Fit the calibrator on the uncalibrated posteriors and their targets.
irc = IsotonicRegressionCalibration(posts_uncalib, targets)    

# Apply it to the same posteriors and cache the result in the logger directory.
posts_calib = irc.calibrate(posts_uncalib)
torch.save(posts_calib, os.path.join(tbl.experiment.get_logdir(), 'posts_calib.pt'))

# Note: `log_irc` receives the return value of plot(), not the LogIRC object itself.
log_irc = LogIRC(None, irc).plot()

In [None]:
# Log simulation inference: build a LogPost from the calibrated posteriors
# (no logger is passed, first argument is None) and produce all of its plots.
log_post = LogPost(None, posts_calib, targets, fig_kwargs = dict(dpi = 100, figsize = (4,3)), calib = 'calibrated')
log_post.plot_all()

# Test with original simulator

In [None]:
NORMALIZE = False

def get_testdata(simulator, infer, irc, mult, n_test = 4):
    """Draw `n_test` simulations and return them with their calibrated posteriors.

    For every draw the posterior of a one-sample batch is computed, its leading
    axis is dropped, and the result is calibrated with ``irc.calibrate``. Only
    when the module-level flag ``NORMALIZE`` is exactly ``True`` is the
    posterior first rescaled with ``normalize(post, mult)``; otherwise `mult`
    is unused.

    Returns
    -------
    (list, torch.Tensor)
        The ``squeeze_obs`` result of each draw, and the stacked posteriors.
    """
    collected_sims = []
    collected_posts = []

    for _draw in range(n_test):
        batched_sim = simulator.sample(1)
        raw_post = infer.get_post(batched_sim).squeeze(0)
        collected_sims.append(infer.squeeze_obs(batched_sim))

        if NORMALIZE is True:
            raw_post = normalize(raw_post, mult)
        collected_posts.append(irc.calibrate(raw_post))

    return collected_sims, torch.stack(collected_posts)

def plot_testdata(test_sims, test_posts, zlog = True):
    """Call ``LogObs.plot_msc`` for each test case with one shared colour range.

    The ``vmin``/``vmax`` passed to every plot are the minimum and maximum over
    all of `test_posts`. Reads the kernel-global `prior` and `grid_coords`.
    """
    upper, lower = test_posts.max(), test_posts.min()

    for test_sim, test_post in zip(test_sims, test_posts):
        panel = LogObs(None, test_sim, test_post, prior, grid_coords)
        panel.plot_msc(
            zlog = zlog,
            plot_true = True,
            title = rf'Sum posterios $= {torch.sum(test_post).item():.2f}$',
            vmin = lower,
            vmax = upper,
        )
# Draw four cases from the original simulator and plot them.
# `mult` only matters when NORMALIZE is True; it is set to False at the top of this cell.
test_sims, test_posts = get_testdata(simulator, infer, irc, mult = 1, n_test = 4)
plot_testdata(test_sims, test_posts, zlog = True)

In [None]:
# Same calls as the end of the previous cell: draws a fresh set of four cases and plots them.
test_sims, test_posts = get_testdata(simulator, infer, irc, mult = 1, n_test = 4)
plot_testdata(test_sims, test_posts, zlog = True)

# External simulator

In [None]:
class Ext_post(Infer):
    """``Infer`` variant that keeps its simulator and a fixed sample count `n_ext`."""

    def __init__(self, simulator, network, datamodule, n_ext):
        """Initialise the ``Infer`` base with the first three arguments, then store
        `simulator` and `n_ext` on the instance."""
        super().__init__(simulator, network, datamodule)
        self.simulator, self.n_ext = simulator, n_ext

    def get_uncalib_posts(self):
        """Sample `n_ext` simulations and return ``(posts, targets)`` from ``get_posts2``."""
        drawn = self.simulator.sample(self.n_ext)
        ext_posts, ext_targets = self.get_posts2(drawn, max_n_test = self.n_ext)
        return ext_posts, ext_targets

In [None]:
# Settings for the external simulator cells below: `n_sub` is forwarded to the
# simulator config and reused as `mult`; `n_ext` is the number of samples drawn.
n_sub = 4
n_ext = 5_000

In [None]:
# Instantiate the configured simulator model with n_sub and part_empty overridden,
# and wrap it in a plain Infer (the Ext_post class defined above is not used here).
ext_simulator = hydra.utils.instantiate(cfg.simulation.model, n_sub = n_sub, part_empty = 0.)
ext_infer = Infer(ext_simulator, network, datamodule)

# Draw n_ext samples and run them through get_posts2, with n_ext as second argument.
ext_sims = ext_simulator.sample(n_ext)
ext_posts_uncalib, ext_targets = ext_infer.get_posts2(ext_sims, n_ext)

In [None]:
# Keep both variants: the raw posteriors, and a copy rescaled so each sample sums to n_sub.
ext_posts_uncalib_unnorm = ext_posts_uncalib
ext_posts_uncalib_norm   = normalize(ext_posts_uncalib, mult = n_sub)

In [None]:
# Calibration on the un-normalized external posteriors: fit, apply to the same data, plot.
ext_irc_unnorm = IsotonicRegressionCalibration(ext_posts_uncalib_unnorm, ext_targets)    
ext_posts_calib_unnorm = ext_irc_unnorm.calibrate(ext_posts_uncalib_unnorm)
LogIRC(None, ext_irc_unnorm).plot()

In [None]:
# Calibration on the normalized external posteriors: fit, apply to the same data, plot.
ext_irc_norm  = IsotonicRegressionCalibration(ext_posts_uncalib_norm , ext_targets)    
ext_posts_calib_norm = ext_irc_norm.calibrate(ext_posts_uncalib_norm)
LogIRC(None, ext_irc_norm).plot()

In [None]:
# Draw three cases from the external simulator, calibrate them with the
# un-normalized calibrator, and plot them.
# NOTE(review): `mult = n_sub` only takes effect when NORMALIZE is True; with the
# NORMALIZE = False set next to get_testdata the posteriors are not rescaled.
ext_test_sims, ext_test_posts = get_testdata(ext_simulator, ext_infer, ext_irc_unnorm, mult = n_sub, n_test = 3)
plot_testdata(ext_test_sims, ext_test_posts, zlog = True)

# Old stuff

In [None]:
# Deliberate stop: this assertion always fails, so "Run All" never reaches the cells below.
assert 1 == 2

In [None]:
# NOTE(review): `test_posts_uncalib` is not assigned anywhere in the visible part of
# the notebook — this cell relies on leftover kernel state.
test_posts_calib = irc.calibrate(test_posts_uncalib)

In [None]:
# Display both shapes side by side.
# NOTE(review): in the visible notebook `test_targets` is only assigned further down ("Test 2").
test_posts_calib.shape, test_targets.shape

In [None]:
# Build a LogPost from the calibrated test posteriors (no logger passed) and produce all plots.
log_post = LogPost(None, test_posts_calib, test_targets, fig_kwargs = dict(dpi = 100, figsize = (4,3)), calib = 'calibrated')
log_post.plot_all()

In [None]:
# def get_test_post(n_test, test_simulator):
#     test_sim = test_simulator.sample(n_test)
    
#     test_post = infer.get_post(test_sim).squeeze(0)
#     test_sim = infer.squeeze_obs(test_sim)
    
#     return test_post, test_sim

In [None]:
def get_testdata(simulator, infer, irc, n_test = 4):
    """Draw `n_test` simulations and return them with their calibrated posteriors.

    Parameters
    ----------
    simulator : object with ``sample(n)``; called as ``sample(1)`` once per draw.
    infer : object with ``get_post(sim)`` and ``squeeze_obs(sim)``.
    irc : object with ``calibrate(post)``.
    n_test : int
        Number of simulations to draw.

    Returns
    -------
    (test_sims, test_posts)
        List of ``squeeze_obs`` results and the ``torch.stack`` of the
        calibrated posteriors.
    """
    test_posts, test_sims = [], []

    for _ in range(n_test):
        test_sim = simulator.sample(1)

        # Name the uncalibrated posterior explicitly: the calibrate() call
        # previously referenced a variable that was never assigned here.
        test_post_uncalib = infer.get_post(test_sim).squeeze(0)
        test_sim = infer.squeeze_obs(test_sim)
        test_post = irc.calibrate(test_post_uncalib)

        test_sims.append(test_sim)
        test_posts.append(test_post)

    test_posts = torch.stack(test_posts)

    return test_sims, test_posts

def plot_testdata(test_sims, test_posts, zlog = True):
    """Call ``LogObs.plot_msc`` for each test case with one shared colour range.

    The ``vmin``/``vmax`` passed to every plot are the minimum and maximum over
    all of `test_posts`. Reads the kernel-global `prior` and `grid_coords`.
    """
    upper, lower = test_posts.max(), test_posts.min()

    for test_sim, test_post in zip(test_sims, test_posts):
        panel = LogObs(None, test_sim, test_post, prior, grid_coords)
        panel.plot_msc(
            zlog = zlog,
            plot_true = True,
            title = rf'Sum posterios $= {torch.sum(test_post).item():.2f}$',
            vmin = lower,
            vmax = upper,
        )


In [None]:
# Plot every (simulation, posterior) pair with zlog enabled.
# NOTE(review): `vmin` / `vmax` are read from kernel globals set by another cell.
for zlog in [True]:
    for test_sim, test_post in zip(test_sims, test_posts):
#         print(f'Msub = {coord[0,0].item()}')
        logobs = LogObs(None, test_sim, test_post, prior, grid_coords)
        kwargs = dict(vmin = vmin, vmax = vmax)

        logobs.plot_msc(zlog = zlog, 
                        plot_true = True,
                        title = rf'Sum posterios $= {torch.sum(test_post).item():.2f}$',
                       **kwargs, 
                       );


In [None]:
# Deliberate stop: this assertion always fails, so "Run All" never reaches the cells below.
assert 1 == 2

## Mock

In [None]:
def get_mock_post(mock_z_sub):
    """Generate one mock with ``z_sub_true = mock_z_sub`` and compute its posterior.

    A simulator is instantiated from ``cfg.simulation.model`` with `mock_z_sub`
    as ``z_sub_true``; its ``gen_mock()`` output is passed through
    ``infer.unsqueeze_obs`` before inference.

    Returns
    -------
    (mock_post, mock_sim)
        The posterior with its leading axis dropped, and the mock after
        ``infer.squeeze_obs``.

    Notes
    -----
    Reads the kernel-global `cfg` and `infer`.
    """
    simulator_for_mock = hydra.utils.instantiate(cfg.simulation.model, z_sub_true = mock_z_sub)
    batched_mock = infer.unsqueeze_obs(simulator_for_mock.gen_mock())

    posterior = infer.get_post(batched_mock).squeeze(0)
    return posterior, infer.squeeze_obs(batched_mock)

In [None]:
# Values scanned by the mock loop below. Note that this rebinds the global `Ms`
# used in earlier sections.
# Both linspace calls use steps=1, so `xs` and `ys` are single-element tensors
# holding just the start value (-1.5 and 0.5).
Ms = torch.tensor([8.5, 9.5, 10.5])
xs = torch.linspace(-1.5, 1.5, 1)
ys = torch.linspace(0.5, -2., 1)

# Ms = torch.tensor([8.5])
# xs = torch.linspace(-1.5, 1.5, 3)
# ys = torch.linspace(2, -2., 3)

In [None]:
# First entry: a mock generated with z_sub = (0, 0, 0).
mock_z_sub = torch.tensor([[0., 0., 0.]])
mock_post, mock_sim = get_mock_post(mock_z_sub)

mock_posts, mock_sims, coords = [mock_post], [mock_sim], [mock_z_sub]
# This assignment is immediately overwritten by the loop variable below.
M = Ms[0]
# One mock per (M, y, x) combination, appended after the first entry.
for M in tqdm(Ms):
    for y in ys:
        for x in xs:
            mock_z_sub = torch.tensor([[M, x, y]])
            mock_post, mock_sim = get_mock_post(mock_z_sub)
            mock_posts.append(mock_post)
            mock_sims.append(mock_sim)
            coords.append(mock_z_sub)
mock_posts = torch.stack(mock_posts)       
coords = torch.stack(coords)       
# Global colour range over all mock posteriors; read by the plotting cell below.
vmax, vmin = mock_posts.max(), mock_posts.min()

In [None]:
# plot_true is False for the first entry (the z_sub = 0 mock) and True for all others.
plot_trues = [False] + [True]*(len(mock_posts)-1)

# Every mock is plotted twice: once with zlog off, once with zlog on.
for zlog in [False, True]:
    for mock_sim, mock_post, coord, plot_true in zip(mock_sims, mock_posts, coords, plot_trues):
#         print(f'Msub = {coord[0,0].item()}')
        logobs = LogObs(None, mock_sim, mock_post, prior, grid_coords)
        kwargs = dict(vmin = vmin, vmax = vmax)
        logobs.plot_msc(zlog = zlog, 
                        plot_true = plot_true,
                        title = rf'Sum posterios $= {torch.sum(mock_post).item():.2f}$, $M_{{ \rm{{sub}} }} = {coord[0,0].item():.2f}$',
#                         title = rf'$M_{{ \rm{{sub}} }} = {coord[0,0].item():.2f}$',
                       **kwargs, 
                       );


## Test

In [None]:
def get_test_post(n_test, test_simulator):
    """Sample `n_test` simulations from `test_simulator` and compute their posterior.

    Returns ``(test_post, test_sim)``: the ``infer.get_post`` result after
    ``squeeze(0)``, and the sample after ``infer.squeeze_obs``. Reads the
    kernel-global `infer`.
    """
    drawn = test_simulator.sample(n_test)
    posterior = infer.get_post(drawn).squeeze(0)
    return posterior, infer.squeeze_obs(drawn)

In [None]:
def get_test_post(n_test, n_sub_test):
    """Sample `n_test` simulations from a freshly instantiated simulator and infer on them.

    The simulator is built from ``cfg.simulation.model`` with ``n_sub`` set to
    `n_sub_test` and ``part_empty`` set to 0 on every call. This definition
    replaces the earlier ``get_test_post(n_test, test_simulator)``.

    Returns ``(test_post, test_sim)``: the ``infer.get_post`` result after
    ``squeeze(0)``, and the sample after ``infer.squeeze_obs``. Reads the
    kernel-global `cfg` and `infer`.
    """
    fresh_simulator = hydra.utils.instantiate(cfg.simulation.model, n_sub = n_sub_test, part_empty = 0.)
    drawn = fresh_simulator.sample(n_test)

    posterior = infer.get_post(drawn).squeeze(0)
    return posterior, infer.squeeze_obs(drawn)

In [None]:
# Collect seven single-sample test cases (n_sub_test = 4), keeping each case's
# 'z_sub' entry in `coords`.
test_posts, test_sims, coords = [], [], []

for _ in range(7):
    test_post, test_sim = get_test_post(n_test = 1, n_sub_test = 4)

    test_posts.append(test_post)
    test_sims.append(test_sim)
    coords.append(test_sim['z_sub'])
test_posts = torch.stack(test_posts)       
coords = torch.stack(coords)       
# Global colour range over all test posteriors; read by the plotting cell below.
vmax, vmin = test_posts.max(), test_posts.min()

In [None]:
# Plot every test case with zlog enabled and the shared vmin/vmax from the previous cell.
# `coord` is unpacked but only referenced by the disabled print below.
for zlog in [True]:
    for test_sim, test_post, coord in zip(test_sims, test_posts, coords):
#         print(f'Msub = {coord[0,0].item()}')
        logobs = LogObs(None, test_sim, test_post, prior, grid_coords)
        kwargs = dict(vmin = vmin, vmax = vmax)

        logobs.plot_msc(zlog = zlog, 
                        plot_true = True,
                        title = rf'Sum posterios $= {torch.sum(test_post).item():.2f}$',
                       **kwargs, 
                       );


## Test 2

In [None]:
# Instantiate the configured simulator with n_sub = 4 and part_empty = 0, then draw 10 000 samples.
# Note: this rebinds `test_sim`, which earlier cells use for a single test case.
test_simulator = hydra.utils.instantiate(cfg.simulation.model, n_sub = 4, part_empty = 0.)
test_sim = test_simulator.sample(10_000)

In [None]:
# NOTE(review): the literal `10_00` is 1000 — confirm whether 1_000 or 10_000
# (the number of samples drawn above) was intended.
test_posts, test_targets = infer.get_posts2(test_sim, max_n_test = 10_00)

In [None]:
# Build a LogPost from the test posteriors (no logger, no `calib` argument) and
# plot only its relicurve.
log_post = LogPost(None, test_posts, test_targets, fig_kwargs = dict(dpi = 100, figsize = (4, 3)) )
log_post.plot_relicurve()

## Calibration

In [None]:
# NOTE(review): `PostData` is not imported in any import cell visible in this part
# of the notebook — confirm where it comes from.
n_alpha = 50
post_data = PostData(test_posts, test_targets, n_alpha = n_alpha)

# Bin edges and centres, moved to the CPU for plotting and fitting below.
alpha_edges, alpha_centers = post_data.get_alpha(n_alpha = n_alpha)
alpha_edges, alpha_centers = alpha_edges.cpu(), alpha_centers.cpu()

In [None]:
# Relicurve for the same n_alpha, moved to the CPU.
relicurve = post_data.get_relicurve(n_alpha = n_alpha).cpu()

In [None]:
# NOTE(review): `cir` is not assigned anywhere in the visible part of the notebook,
# and calibrate() is called without arguments — this cell looks like a leftover.
cir.calibrate()


In [None]:
# Prepend a 0 to both the bin centres and the relicurve before fitting.
alpha_centers_zero = torch.cat((torch.tensor([0]), alpha_centers))
relicurve_zero     = torch.cat((torch.tensor([0]), relicurve))

# NOTE(review): the only visible import of IsotonicRegression (sklearn.isotonic)
# is commented out — confirm it is imported elsewhere.
ir = IsotonicRegression(out_of_bounds = 'clip')
ir.fit(alpha_centers_zero, relicurve_zero);

# Fitted curve evaluated at the bin centres (without the prepended 0).
relicurve_ir = ir.predict(alpha_centers)

# Binned relicurve as stairs, the fit as a line, and the dotted diagonal from (0, 0) to (1, 1).
plt.stairs(relicurve, alpha_edges)
plt.plot(alpha_centers, relicurve_ir )
plt.plot((0, 1), (0, 1), 'k:')

relicurve_ir

In [None]:
# NOTE(review): no free function `calibrate(posts, ir)` is defined in the visible part
# of the notebook — confirm where it comes from. This also rebinds the global `posts_calib`.
posts_calib = calibrate(test_posts, ir)

In [None]:
# Build a LogPost from the manually calibrated posteriors (no logger passed) and produce all plots.
log_post = LogPost(None, posts_calib, test_targets, fig_kwargs = dict(dpi = 100, figsize = (4, 3)) )
log_post.plot_all()

## Plotting after calibration

In [None]:
# Collect four single-sample test cases (n_sub_test = 4), calibrating each posterior
# with `calibrate(post, ir)` before storing it.
test_posts_calib, test_sims_calib, coords_calib = [], [], []

for _ in range(4):
    test_post_calib, test_sim_calib = get_test_post(n_test = 1, n_sub_test = 4)
    test_post_calib = calibrate(test_post_calib, ir)

    test_posts_calib.append(test_post_calib)
    test_sims_calib.append(test_sim_calib)
    coords_calib.append(test_sim_calib['z_sub'])
test_posts_calib = torch.stack(test_posts_calib)       
coords_calib = torch.stack(coords_calib)       
# Global colour range over the calibrated posteriors; overwrites the earlier vmax/vmin.
vmax, vmin = test_posts_calib.max(), test_posts_calib.min()

In [None]:
# Plot every calibrated test case with zlog enabled and the shared vmin/vmax from the previous cell.
# `coord` is unpacked but only referenced by the disabled print below.
for zlog in [True]:
    for test_sim, test_post, coord in zip(test_sims_calib, test_posts_calib, coords_calib):
#         print(f'Msub = {coord[0,0].item()}')
        logobs = LogObs(None, test_sim, test_post, prior, grid_coords)
        kwargs = dict(vmin = vmin, vmax = vmax)

        logobs.plot_msc(zlog = zlog, 
                        plot_true = True,
                        title = rf'Sum posterios $= {torch.sum(test_post).item():.2f}$',
                       **kwargs, 
                       );
