In [1]:
#!/usr/bin/env python
# coding: utf-8

import os, datetime
import torch, pyro, numpy as np
# torch.set_default_tensor_type(torch.cuda.FloatTensor)
from torch import tensor
import torch.nn as nn
import torchvision.transforms.functional as TF

import click


import swyft

DEVICE = 'cuda'

from utils import *

from swyft.utils import tensor_to_array, array_to_tensor
from toolz import compose
from pyrofit.lensing.distributions import get_default_shmf


In [2]:
# Run parameters shared by the cells below.
#   m    -- forwarded (as a string) to get_config and, as-is, to get_sim_path;
#           NOTE(review): its physical meaning is not visible here -- confirm.
#   nsub -- later asserted equal to CONFIG.umodel.alphas["main"].sub.nsub.
#   nsim -- number of samples requested from store.add(); the store output
#           below reports 107 added for nsim = 100, so the count is not exact.
m, nsub, nsim = 1, 3, 100

In [3]:
# Wall-clock reference; the last cell reports the elapsed time against it.
time_start = datetime.datetime.now()

# Set definitions (should go to click)
system_name = "ngc4414"

# Set utilities
sim_name, sim_path = get_sim_path(m, nsub, nsim, system_name)

# HACK: the config and its `ppd` trace are built while CUDA float tensors are
# the process-wide default (presumably so tensors created inside land on the
# GPU -- TODO confirm). The default is global kernel state, so it is restored
# in a `finally`: if either call raises, later cells must not silently keep
# running with the CUDA default.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch
# releases; consider torch.set_default_device / set_default_dtype when the
# environment allows it.
torch.set_default_tensor_type(torch.cuda.FloatTensor)
try:
    CONFIG = get_config(system_name, str(nsub), str(m))
    # Nodes of the model trace returned by CONFIG.ppd(); indexed by site name
    # further down (e.g. 'main/sub/m_sub').
    ppd = CONFIG.ppd()['model_trace'].nodes
finally:
    torch.set_default_tensor_type(torch.FloatTensor)

Store _M_m1_nsub3_nsim100 exists!


In [4]:
# Notebook-level shortcuts used by the inspection cells below:
# `config` is just another name for CONFIG, and `main` is the "main" entry
# of its `umodel.alphas` mapping.
config = CONFIG
main = config.umodel.alphas["main"]

In [5]:
# Inspection only: display the repr of the `sub` component's mass sampler.
main.sub.mass_sampler

InverseTransformDistribution()

In [6]:
# Inspection only: display the distribution the mass sampler draws from
# before applying its transform (repr shown as the cell output).
main.sub.mass_sampler.sampling_distribution

Uniform(low: 0.0, high: 1.0)

In [7]:

def get_prior(CONFIG):
    """Build the composite swyft prior for every `sub` object in the model.

    Each of the `nsub` objects contributes three parameters: a 2-D position
    drawn from a uniform box and one log10-mass drawn from the default SHMF
    returned by `get_default_shmf`.

    Parameters
    ----------
    CONFIG
        Configuration object exposing `umodel.alphas["main"].sub` (with
        `pos_sampler.base_dist`, `mass_sampler.y` and `nsub`) and
        `kwargs['defs']['z_lens']`.

    Returns
    -------
    tuple
        `(prior, n_pars, lows, highs)` where `prior` is the result of
        `swyft.Prior.composite_prior`, `n_pars` is the total number of
        parameters (`3 * nsub`), and `lows` / `highs` are length-3 numpy
        arrays holding the bounds of the two position components followed by
        log10 of the smallest / largest value of the mass grid.
    """
    sub = CONFIG.umodel.alphas["main"].sub
    position_box = sub.pos_sampler.base_dist
    mass_grid = sub.mass_sampler.y
    nsub = sub.nsub
    z_lens = CONFIG.kwargs['defs']['z_lens']

    # Bounds for a single object: two position components, then log10(mass).
    lows = np.array([
        *(position_box.low[axis].item() for axis in range(2)),
        mass_grid.min().log10().item(),
    ])
    highs = np.array([
        *(position_box.high[axis].item() for axis in range(2)),
        mass_grid.max().log10().item(),
    ])

    # One distribution per parameter group of a single object.
    position_dist = torch.distributions.Uniform(
        array_to_tensor(lows[:-1]),
        array_to_tensor(highs[:-1]),
    )
    shmf = get_default_shmf(z_lens=z_lens, log_range=(lows[-1], highs[-1]))

    # The per-object layout (2 position dims, 1 mass dim) repeated nsub times.
    parameter_dimensions = [2, 1] * nsub
    n_pars = sum(parameter_dimensions)

    def _wrapped(method_name):
        # Same (position, mass) ordering as parameter_dimensions.
        return [
            swyft.Prior.conjugate_tensor_func(getattr(dist, method_name))
            for dist in (position_dist, shmf) * nsub
        ]

    prior = swyft.Prior.composite_prior(
        cdfs=_wrapped("cdf"),
        icdfs=_wrapped("icdf"),
        log_probs=_wrapped("log_prob"),
        parameter_dimensions=parameter_dimensions,
    )
    return prior, n_pars, lows, highs

# Build the prior with the regular (CPU) default tensor type: the CUDA
# default-tensor-type toggle used for the config setup is deliberately left
# disabled around this call.
# torch.set_default_tensor_type(torch.cuda.FloatTensor)  
prior, n_pars, lows, highs = get_prior(CONFIG)
# torch.set_default_tensor_type(torch.FloatTensor)

# Disabled sanity check: draw 1,000,000 samples from the untruncated prior
# and histogram each of the first six parameter columns.
# samples = swyft.PriorTruncator(prior, bound=None).sample(1000_000)
# for i in range(6):
#     plt.hist(samples[:, i], bins=100, alpha=0.5)
#     plt.show()

In [8]:
# Side length of the simulated image: the store is declared below with an
# "image" entry of shape (L, L).
L = CONFIG.kwargs["defs"]["nx"]
print(f'Image has L = {L}.')

# Guard against hidden state: the notebook-level nsub must match the value the
# loaded config was built with.
assert nsub == CONFIG.umodel.alphas["main"].sub.nsub
print('m samples:', [f"{i:.2}" for i in ppd['main/sub/m_sub']['value']])

# Create Store
# The simulator wraps `simul` with CONFIG bound; parameter_names receives the
# parameter count produced by get_prior rather than explicit names.
simulator = swyft.Simulator(
    model=lambda v: simul(v, CONFIG),
    parameter_names=n_pars,
    sim_shapes={"image": (L, L)},
)
# NOTE(review): overwrite=True -- anything already stored at sim_path appears
# to be replaced ("Creating new store." in the output); confirm this is wanted.
store = swyft.Store.directory_store(
    overwrite=True,
    path=sim_path,
    simulator=simulator,
)
# Register parameter draws from the prior, then run the pending simulations.
store.add(nsim, prior)
store.simulate()

print('Done!')
# Elapsed time since `time_start`, with the fractional seconds cut off.
print(f"Total creating time is {str(datetime.datetime.now() - time_start).split('.')[0]}!")

Image has L = 40.
m samples: ['1.3e+10', '1.2e+10', '1.1e+10']
Creating new store.
Store: Adding 107 new samples to simulator store.
Done!
Total creating time is 0:00:09!
