In [1]:
#!/usr/bin/env python
# coding: utf-8

import os, datetime
import torch, pyro, numpy as np
torch.set_default_tensor_type(torch.cuda.FloatTensor)
from torch import tensor
import torch.nn as nn
import torchvision.transforms.functional as TF

import click


import swyft

DEVICE = 'cuda'

from utils import *

from swyft.utils import tensor_to_array, array_to_tensor
from toolz import compose
from pyrofit.lensing.distributions import get_default_shmf


In [2]:
m = 0
nsub = 3
nsim = 200

In [3]:
time_start = datetime.datetime.now()

# Set definitions (should go to click)
system_name = "ngc4414"

# Set utilities
sim_name, sim_path = get_sim_path(m, nsub, nsim, system_name)

torch.set_default_tensor_type(torch.cuda.FloatTensor)  # HACK
CONFIG = get_config(system_name, str(nsub), str(m))
torch.set_default_tensor_type(torch.FloatTensor)

torch.set_default_tensor_type(torch.cuda.FloatTensor)  # HACK
ppd = CONFIG.ppd()['model_trace'].nodes
torch.set_default_tensor_type(torch.FloatTensor)

def get_priorr(config: Clipppy):
    """
    Set up subhalo parameter priors using a config
    """
    main = config.umodel.alphas["main"]
    prior_p_sub = main.sub.pos_sampler.base_dist
    m_sub_grid = main.sub.mass_sampler.y
    lows = np.array(
        [
            prior_p_sub.low[0].item(),
            prior_p_sub.low[1].item(),
            m_sub_grid.min().log10().item(),
        ]
    )
    highs = np.array(
        [
            prior_p_sub.high[0].item(),
            prior_p_sub.high[1].item(),
            m_sub_grid.max().log10().item(),
        ]
    )
    
    
    
    nsub = main.sub.nsub
    
    
    
    
    
    lows_u = array_to_tensor(np.tile(lows, nsub))
    highs_u = array_to_tensor(np.tile(highs, nsub))
    
    uniform = torch.distributions.Uniform(lows_u, highs_u)
    
    prior = swyft.Prior(
        cdf = compose(tensor_to_array, lambda x: uniform.cdf(x), array_to_tensor),
        icdf = compose(tensor_to_array, lambda x: uniform.icdf(x), array_to_tensor),
        log_prob = compose(tensor_to_array, lambda x: uniform.log_prob(x), array_to_tensor),
        n_parameters = nsub*3,
    )

    
    uv = lambda u: (highs_u - lows_u) * u + lows_u
    return prior, uv, lows, highs

prior, uv, lows, highs = get_priorr(CONFIG)

Store _M_m0_nsub3_nsim200 exists!


In [4]:
main = CONFIG.umodel.alphas["main"]

In [5]:
main.sub.mass_sampler.sampling_distribution

Uniform(low: 0.0, high: 1.0)

In [6]:
def get_prior(config: Clipppy):
    """
    Set up subhalo parameter priors using a config
    """
    main = config.umodel.alphas["main"]
    prior_p_sub = main.sub.pos_sampler.base_dist
    m_sub_grid = main.sub.mass_sampler.y
    
    nsub = main.sub.nsub
    z_lens = CONFIG.kwargs['defs']['z_lens']
    
    lows = np.array(
        [
            prior_p_sub.low[0].item(),
            prior_p_sub.low[1].item(),
            m_sub_grid.min().log10().item(),
        ]
    )
    highs = np.array(
        [
            prior_p_sub.high[0].item(),
            prior_p_sub.high[1].item(),
            m_sub_grid.max().log10().item(),
        ]
    )
    
    n_parameters = len(lows)*nsub
    uniform = torch.distributions.Uniform(array_to_tensor(lows[:-1]), array_to_tensor(highs[:-1]))
    shmf = get_default_shmf(z_lens = z_lens, log_range = (lows[-1], highs[-1]))
    
    
    
    
    
#     def cat(uniform, shmf, x, nsub):
# #     x = x.to(DEVICE)
#         return torch.cat(
#             [torch.cat(
#                     [uniform(x[:,3*i:3*i+2]), shmf(x[:, 3*i+2]).unsqueeze(1)], dim = 1
#                 ) for i in range(nsub)]
#             , dim = 1)

#     def cdf(x):      return cat(uniform.cdf, shmf.cdf, x, nsub)
#     def icdf(x):     return cat(uniform.icdf, shmf.icdf, x, nsub)
#     def log_prob(x): return cat(uniform.log_prob, shmf.log_prob, x, nsub)
    
    
    def cdf(x):      
        return torch.cat(
            [torch.cat(
                    [uniform.cdf(x[:,3*i:3*i+2]), shmf.cdf(x[:, 3*i+2]).unsqueeze(1)], dim = 1
                ) for i in range(nsub)]
            , dim = 1)

    def icdf(x):     
        return torch.cat(
            [torch.cat(
                    [uniform.icdf(x[:,3*i:3*i+2]), shmf.icdf(x[:, 3*i+2]).unsqueeze(1)], dim = 1
                ) for i in range(nsub)]
            , dim = 1)
    
    def log_prob(x): 
        return torch.cat(
            [torch.cat(
                    [uniform.log_prob(x[:,3*i:3*i+2]), shmf.log_prob(x[:, 3*i+2]).unsqueeze(1)], dim = 1
                ) for i in range(nsub)]
            , dim = 1)

    prior = swyft.Prior(
        
        
        
            cdf = compose(tensor_to_array, cdf, array_to_tensor),
            icdf = compose(tensor_to_array, icdf, array_to_tensor),
            log_prob = compose(tensor_to_array, log_prob, array_to_tensor),
            n_parameters = n_parameters,
        )
    
    return prior, cdf, icdf, log_prob, lows, highs, n_parameters

prior, cdf, icdf, log_prob, lows, highs, n_parameters = get_prior(CONFIG)
    
    
    
print('p')
    
    
    
#     lows_u = array_to_tensor(np.tile(lows, nsub))
#     highs_u = array_to_tensor(np.tile(highs, nsub))
    
#     uniform = torch.distributions.Uniform(lows_u, highs_u)
    
#     prior = swyft.Prior(
#         cdf = compose(tensor_to_array, lambda x: uniform.cdf(x), array_to_tensor),
#         icdf = compose(tensor_to_array, lambda x: uniform.icdf(x), array_to_tensor),
#         log_prob = compose(tensor_to_array, lambda x: uniform.log_prob(x), array_to_tensor),
#         n_parameters = nsub*3,
#     )

    
#     uv = lambda u: (highs_u - lows_u) * u + lows_u
#     return prior, uv, lows, highs

p


In [7]:
# # lows_u = array_to_tensor(np.tile(lows, nsub))
# # highs_u = array_to_tensor(np.tile(highs, nsub))

# # # uniform = torch.distributions.Uniform(lows_u, highs_u)
# # # shmf = get_default_shmf(z_lens = 0.5, log_range = (8.5, 10.5))

# # def cat(uniform, shmf, x, nsub):
# #     x = x.to(DEVICE)
# #     return torch.cat(
# #         [torch.cat(
# #                 [uniform(x[:,3*i:3*i+2]), shmf(x[:, 3*i+2]).unsqueeze(1)], dim = 1
# #             ) for i in range(nsub)]
# #         , dim = 1)


# # uniform = torch.distributions.Uniform(array_to_tensor(lows[:-1]).to(DEVICE), array_to_tensor(highs[:-1]).to(DEVICE))
# # shmf = get_default_shmf(z_lens = 0.5, log_range = (10., 12.))


# # def cdf(x):      return cat(uniform.cdf, shmf.cdf, x, nsub)
# # def icdf(x):     return cat(uniform.icdf, shmf.icdf, x, nsub)
# # def log_prob(x): return cat(uniform.cdlog_probf, shmf.log_prob, x, nsub)

# # prior = swyft.Prior(
# #         cdf = compose(tensor_to_array, cdf, array_to_tensor),
# #         icdf = compose(tensor_to_array, icdf, array_to_tensor),
# #         log_prob = compose(tensor_to_array, log_prob, array_to_tensor),
# #         n_parameters = 3*nsub,
# #     )


# config = CONFIG
# main = config.umodel.alphas["main"]
# prior_p_sub = main.sub.pos_sampler.base_dist
# m_sub_grid = main.sub.mass_sampler.y

# nsub = main.sub.nsub
# z_lens = config.kwargs['defs']['z_lens']

# lows = np.array(
#     [
#         prior_p_sub.low[0].item(),
#         prior_p_sub.low[1].item(),
#         m_sub_grid.min().log10().item(),
#     ]
# )
# highs = np.array(
#     [
#         prior_p_sub.high[0].item(),
#         prior_p_sub.high[1].item(),
#         m_sub_grid.max().log10().item(),
#     ]
# )
# # lows = array_to_tensor(lows).to(DEVICE)
# # highs = array_to_tensor(highs).to(DEVICE)

# n_parameters = len(lows)*nsub
# uniform = torch.distributions.Uniform(array_to_tensor(lows[:-1]), array_to_tensor(highs[:-1]))
# shmf = get_default_shmf(z_lens = z_lens, log_range = (lows[-1], highs[-1]))

# def cdf(x):      
# #     x = x.to(DEVICE)
#     return torch.cat(
#         [torch.cat(
#                 [uniform.cdf(x[:,3*i:3*i+2]), shmf.cdf(x[:, 3*i+2]).unsqueeze(1)], dim = 1
#             ) for i in range(nsub)]
#         , dim = 1)

# def icdf(x):     
# #     x = x.to(DEVICE)
#     return torch.cat(
#         [torch.cat(
#                 [uniform.icdf(x[:,3*i:3*i+2]), shmf.icdf(x[:, 3*i+2]).unsqueeze(1)], dim = 1
#             ) for i in range(nsub)]
#         , dim = 1)

# def log_prob(x): 
# #     x = x.to(DEVICE)
#     return torch.cat(
#         [torch.cat(
#                 [uniform.log_prob(x[:,3*i:3*i+2]), shmf.log_prob(x[:, 3*i+2]).unsqueeze(1)], dim = 1
#             ) for i in range(nsub)]
#         , dim = 1)

# prior = swyft.Prior(
#         cdf = compose(tensor_to_array, cdf, array_to_tensor),
#         icdf = compose(tensor_to_array, icdf, array_to_tensor),
#         log_prob = compose(tensor_to_array, log_prob, array_to_tensor),
#         n_parameters = n_parameters,
#     )

# samples = swyft.PriorTruncator(prior, bound=None).sample(1000_000)
# for i in range(3):
#     plt.hist(samples[:, i], bins=100, alpha=0.5)
#     plt.show()

In [14]:
L = CONFIG.kwargs["defs"]["nx"]
print(f'Image has L = {L}.')

assert nsub == CONFIG.umodel.alphas["main"].sub.nsub
if m > 4:
    assert all([i == pow(10, m) for i in ppd['main/sub/m_sub']['value']])
else:
    print(f'm = {m} <= 0!', ppd['main/sub/m_sub']['value'])



    
# Create Store
pnames = [f'{z}_{i+1}' for i in range(nsub) for z in ['x', 'y', 'm']]
n_pars = int(nsub * 3)
print(n_pars)
simulator = swyft.Simulator(model = lambda v: simul(v, CONFIG), 
                            parameter_names = n_pars,
#                                 pnames = pnames,
                            sim_shapes={"image": (L, L)})

store = swyft.Store.directory_store(path = sim_path, simulator = simulator, overwrite = True)



# def get_prior(CONFIG):
config = CONFIG
main = config.umodel.alphas["main"]
prior_p_sub = main.sub.pos_sampler.base_dist
m_sub_grid = main.sub.mass_sampler.y

nsub = main.sub.nsub
z_lens = config.kwargs['defs']['z_lens']

lows = np.array([
        prior_p_sub.low[0].item(),
        prior_p_sub.low[1].item(),
        m_sub_grid.min().log10().item(),
    ])
highs = np.array([
        prior_p_sub.high[0].item(),
        prior_p_sub.high[1].item(),
        m_sub_grid.max().log10().item(),
    ])

n_parameters = len(lows)*nsub
uniform = torch.distributions.Uniform(array_to_tensor(lows[:-1]), array_to_tensor(highs[:-1]))
shmf = get_default_shmf(z_lens = z_lens, log_range = (lows[-1], highs[-1]))

def cat(uniform, shmf, x, nsub):
    return torch.cat(
        [torch.cat(
                [uniform(x[:,3*i:3*i+2]), shmf(x[:, 3*i+2]).unsqueeze(1)], dim = 1
            ) for i in range(nsub)]
        , dim = 1)

def cdf(x):      return cat(uniform.cdf, shmf.cdf, x, nsub)
def icdf(x):     return cat(uniform.icdf, shmf.icdf, x, nsub)
def log_prob(x): return cat(uniform.log_prob, shmf.log_prob, x, nsub)
    

# def cdf(x):      
# #     x = x.to(DEVICE)
#     return torch.cat(
#         [torch.cat(
#                 [uniform.cdf(x[:,3*i:3*i+2]), shmf.cdf(x[:, 3*i+2]).unsqueeze(1)], dim = 1
#             ) for i in range(nsub)]
#         , dim = 1)

# def icdf(x):     
# #     x = x.to(DEVICE)
#     return torch.cat(
#         [torch.cat(
#                 [uniform.icdf(x[:,3*i:3*i+2]), shmf.icdf(x[:, 3*i+2]).unsqueeze(1)], dim = 1
#             ) for i in range(nsub)]
#         , dim = 1)

# def log_prob(x): 
# #     x = x.to(DEVICE)
#     return torch.cat(
#         [torch.cat(
#                 [uniform.log_prob(x[:,3*i:3*i+2]), shmf.log_prob(x[:, 3*i+2]).unsqueeze(1)], dim = 1
#             ) for i in range(nsub)]
#         , dim = 1)

prior = swyft.Prior(
        cdf = compose(tensor_to_array, cdf, array_to_tensor),
        icdf = compose(tensor_to_array, icdf, array_to_tensor),
        log_prob = compose(tensor_to_array, log_prob, array_to_tensor),
        n_parameters = n_parameters,
    )
#     return prior

# prior = get_prior(CONFIG)

store.add(nsim, prior)
store.simulate()

print('Done!')
print(f"Total creating time is {str(datetime.datetime.now() - time_start).split('.')[0]}!")

Image has L = 40.
m = 0 <= 0! tensor([8.5089e+08, 4.6566e+08, 3.2541e+08], device='cuda:0')
9
Creating new store.
Store: Adding 184 new samples to simulator store.
Done!
Total creating time is 0:05:21!


In [15]:
store[0]

({'image': array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]])},
 array([-2.48098803e+00, -1.00415218e+00,  3.45343712e+08,  3.99880886e-01,
         8.26270819e-01,  7.85417408e+08,  2.06034136e+00,  9.18305159e-01,
         5.75203584e+08]))

In [9]:
assert 1 == 2

AssertionError: 

In [None]:
import torch
import scipy.stats
import numpy as np

import pylab as plt

from toolz import compose
from swyft.prior import Prior, PriorTruncator
from swyft.utils import tensor_to_array, array_to_tensor

In [None]:
loc = torch.tensor([1.0, -3.2])
scale = torch.tensor([0.1, 2.1])

upper = torch.tensor([2.5, 3.1])
lower = torch.tensor([0.1, -2.0])

n_parameters = len(loc) + len(upper)

In [None]:
normal = torch.distributions.Normal(loc, scale)
uniform = torch.distributions.Uniform(lower, upper)

composite_prior_torch = Prior(
    cdf=compose(tensor_to_array, lambda x: torch.cat([normal.cdf(x[:, :2]), uniform.cdf(x[:, 2:4])], dim=1), array_to_tensor),
    icdf=compose(tensor_to_array, lambda x: torch.cat([normal.icdf(x[:, :2]), uniform.icdf(x[:, 2:4])], dim=1), array_to_tensor),
    log_prob=compose(tensor_to_array, lambda x: torch.cat([normal.log_prob(x[:, :2]), uniform.log_prob(x[:, 2:4])], dim=1), array_to_tensor),
    n_parameters=n_parameters,
)

In [None]:
samples = PriorTruncator(composite_prior_torch, bound=None).sample(10_000)
for i in range(n_parameters):
    _ = plt.hist(samples[:, i], bins=100, alpha=0.5)

In [None]:
normal

In [None]:
h = plt.hist(plot, bins = 100)
def cdf(x, y): return uniform.cdf(x)
def icdf(x, y): return uniform.icdf(x)
def log_prob(x, y): return uniform.log_prob(x)


prior = swyft.Prior(
        cdf = compose(tensor_to_array, cdf, array_to_tensor),
        icdf = compose(tensor_to_array, icdf, array_to_tensor),
        log_prob = compose(tensor_to_array, log_prob, array_to_tensor),
        n_parameters = 3*nsub,
    )

In [None]:
cdf(torch.tensor([[0., 0., 10., 1., 1., 9., -1., -1., 10.]]))

In [None]:
samples = swyft.PriorTruncator(prior, bound=None).sample(10_000)
for i in range(3*nsub):
    _ = plt.hist(samples[:, i], bins=100, alpha=0.5)

In [None]:
lows_u = array_to_tensor(np.tile(lows, nsub))
highs_u = array_to_tensor(np.tile(highs, nsub))

uniform = torch.distributions.Uniform(lows_u, highs_u)

prior = swyft.Prior(
        cdf = compose(tensor_to_array, lambda x: uniform.cdf(x), array_to_tensor),
        icdf = compose(tensor_to_array, lambda x: uniform.icdf(x), array_to_tensor),
        log_prob = compose(tensor_to_array, lambda x: uniform.log_prob(x), array_to_tensor),
        n_parameters = nsub*3,
    )

In [None]:
lows_2 = torch.full((6,), -2.5)
highs_2 = torch.full((6,), 2.5)


In [None]:
def cat(uniform, shmf, x, nsub):
    return torch.cat([torch.cat([uniform(x[:,3*i:3*i+2]).squeeze(), shmf(x[:, 3*i+2])]) for i in range(nsub)])


uniform = torch.distributions.Uniform(array_to_tensor(lows[:-1]), array_to_tensor(highs[:-1]))
shmf = get_default_shmf(z_lens = 0.5, log_range = (8.5, 10.5))

s = torch.cat([uniform.sample(), shmf.sample().unsqueeze(0), uniform.sample(), shmf.sample().unsqueeze(0), uniform.sample(), shmf.sample().unsqueeze(0)]).unsqueeze(0)
cat(uniform.cdf, shmf.cdf, s, 3)

In [None]:
[(uniform(x[:,3*i:3*i+2]), shmf(x[:, 3*i+2])) for i in range(nsub)]

In [None]:
lows

In [None]:
ones = torch.ones((4, 2))
zeros = torch.zeros((4, 1))
ones.shape, zeros.shape

In [None]:
torch.cat([ones, zeros], dim = 1).shape

In [None]:
.to(DEVICE)