In [9]:

import jax
jax.config.update('jax_enable_x64',True)
import jax.random as random
from jax import jit

import pickle

In [10]:


import numpyro
from NSF import NeuralSpline1D
from flow import Normal,Flow,transform,Serial
from SkewNormalPlus import SkewNormalPlus as snp

In [11]:
from jax import random
import jax_cosmo as jc
from jax_cosmo import Cosmology, background
import numpy as onp
import jax.numpy as np
from jax import grad, jit, vmap, value_and_grad
from jax.scipy.stats import norm

In [12]:

@jit
def make_training_set(key,z_s,M0=-19.5,sigma_int=0.1,h=0.7324,Om0=0.28,w=-1,mu_cut=21,sigma_cut=0.01):
    """Simulate magnitudes and a soft magnitude-limited selection at redshifts ``z_s``.

    Args:
        key: JAX PRNG key; split internally into independent scatter / selection keys.
        z_s: array of redshifts.
        M0: absolute magnitude added to the distance modulus.
        sigma_int: standard deviation of the Gaussian scatter added to the magnitudes.
        h, Om0, w: Hubble parameter, ``Omega_c`` and ``w0`` of the jax_cosmo cosmology
            (``Omega_b``, ``Omega_k`` and ``wa`` are fixed to 0).
        mu_cut, sigma_cut: the selection probability is ``Phi((mu_cut - m_s) / sigma_cut)``.

    Returns:
        ``(m0_s, m_s, sel_s)``: noiseless magnitudes ``M0 + mu``, scattered magnitudes
        (same shape as ``m0_s``), and boolean selection draws.
    """
    # Omega_k=0 -> flat; n_s / sigma8 are presumably irrelevant for background distances.
    cosmo_jax = Cosmology(Omega_c=Om0, h=h, w0=w, Omega_b=0, n_s=0.96, sigma8=200000, Omega_k=0, wa=0)

    # Evaluated at scale factor a = 1/(1+z); presumably returned in Mpc/h, hence the /h below.
    d_s = background.transverse_comoving_distance(cosmo_jax, 1/(1+z_s))

    # Distance modulus 5*log10(d_L / 10 pc) with d_L = (1+z) * d_M, Mpc converted to pc.
    mu_s = 5*np.log10((1+z_s)/h*d_s*1e6/10)

    m0_s = M0 + mu_s

    # Split once up front so the scatter and the selection use independent keys
    # (a key that has been passed to a sampler must not be reused or split afterwards).
    key_scatter, key_select = random.split(key)

    m_s = random.normal(key_scatter, m0_s.shape)*sigma_int + m0_s

    # P(selected) = Phi((mu_cut - m_s) / sigma_cut): falls off for m_s fainter than mu_cut.
    p_s = norm.cdf(-m_s, loc=-mu_cut, scale=sigma_cut)

    sel_s = random.bernoulli(key_select, p=p_s)

    return m0_s, m_s, sel_s


@jit
def minmax_fit_and_scale(X):
  """Fit column-wise min/max bounds on ``X`` and rescale ``X`` to [0, 1].

  Bounds are reduced over axis 0, so rows are samples and columns are features.
  A constant column yields 0/0 (nan); no guard is applied.

  Returns:
    ``(X_scaled, column_min, column_max)``.
  """
  lo = np.min(X, axis=0)
  hi = np.max(X, axis=0)
  scaled = (X - lo) / (hi - lo)
  return scaled, lo, hi

@jit
def minmax_scale(X,min,max):
  """Map ``X`` to [0, 1] using previously fitted bounds ``min`` / ``max``.

  The parameter names shadow the builtins inside this function only.
  """
  span = max - min
  return (X - min) / span

@jit
def minmax_unscale(X,min,max):
  """Inverse of ``minmax_scale``: map [0, 1]-scaled values back using ``min`` / ``max``."""
  span = max - min
  return X * span + min




In [13]:
from jax.example_libraries import stax, optimizers

from jax.example_libraries.stax import (Dense, Tanh, Flatten, Relu, LogSoftmax, Softmax, Exp,Sigmoid,Softplus,LeakyRelu)

def network(rng, conditional_dim, out_dim, hidden_dim, wide_dim=1024):
    """Build and initialise the Relu MLP handed to ``NeuralSpline1D`` as its network factory.

    Architecture (stax.serial):
        Dense(hidden_dim) -> Relu -> Dense(wide_dim) -> Relu -> Dense(wide_dim) -> Relu
        -> Dense(hidden_dim) -> Relu -> Dense(out_dim)   (linear output layer)

    Args:
        rng: PRNG key used to initialise the weights.
        conditional_dim: length of the (unbatched) input vector.
        out_dim: length of the output vector.
        hidden_dim: width of the first and last hidden layers.
        wide_dim: width of the two middle hidden layers (default 1024, the previously
            hard-coded value).

    Returns:
        ``(params, apply_fun)`` where ``apply_fun(params, inputs)`` evaluates the MLP.
    """
    init_fun, apply_fun = stax.serial(
        stax.Dense(hidden_dim), Relu,
        stax.Dense(wide_dim), Relu,
        stax.Dense(wide_dim), Relu,
        stax.Dense(hidden_dim), Relu,
        stax.Dense(out_dim),
    )
    # init_fun returns (output_shape, params); only params are needed.
    _, params = init_fun(rng, (conditional_dim,))

    return params, apply_fun
# NOTE: an alternative `network` with Tanh activations and a single 1024-wide
# hidden layer (Dense(hidden_dim) -> Tanh -> Dense(1024) -> Tanh ->
# Dense(hidden_dim) -> Tanh -> Dense(out_dim)) was tried before; the Relu
# version defined above is the one used to build the flow.

# Fixed seed so the flow initialisation is reproducible after a kernel restart.
rng, flow_rng = random.split(random.PRNGKey(0))

# Five NeuralSpline1D layers chained with Serial, on top of a Normal() prior.
# K and B are presumably the number of spline bins and the spline tail bound --
# TODO confirm in NSF.py.
# NOTE(review): `(layer,) * 5` repeats the *same* NeuralSpline1D object five
# times; this is only fine if it is a stateless init-function factory -- confirm.
init_fun = Flow(Serial(*(NeuralSpline1D(network, hidden_dim=256, K=20, B=3),) * 5), prior=Normal())

# NOTE(review): the meaning of the second argument (2) is defined by flow.Flow;
# training batches have 3 columns [m_s - m0, m0, sigma_int] -- confirm whether
# 2 is the conditioning dimension.
params, log_pdf, sample = init_fun(flow_rng, 2)

In [14]:
# Adam from jax.example_libraries.optimizers: an (init, update, get_params) triple.
opt_init, opt_update, get_params = optimizers.adam(5e-4)  # step size 5e-4
opt_state = opt_init(params)

In [15]:
def loss_fn(params, inputs):
    """Training loss: negative mean of the flow's ``log_pdf`` over the rows of ``inputs``."""
    log_density = log_pdf(params, inputs)
    return -np.mean(log_density)

@jit
def step(i, opt_state, inputs):
    """Run one optimiser update on a batch.

    Args:
        i: iteration counter forwarded to ``opt_update``.
        opt_state: current optimiser state.
        inputs: batch passed to ``loss_fn``.

    Returns:
        ``(loss, new_opt_state)``; ``loss`` is evaluated at the pre-update parameters.
    """
    loss_and_grads = value_and_grad(loss_fn)
    loss, grads = loss_and_grads(get_params(opt_state), inputs)
    new_opt_state = opt_update(i, grads, opt_state)
    return loss, new_opt_state



In [16]:
from tqdm.notebook import trange
import itertools
import numpy.random as npr
import logging
from IPython.display import clear_output

from matplotlib import pyplot as plt

# Global step counter fed to the optimiser update.
itercount = itertools.count()
# NOTE(review): sample_batch and batch_size are not used in this cell.
sample_batch = 1000
batch_size = 500

N_EPOCHS = 100000
N_SETS_PER_EPOCH = 5   # independent sigma_int draws per training batch
N_MU = 10              # m0 values drawn per sigma_int
N_REP = 20             # draws per m0 value
PRINT_EVERY = 100      # printing every epoch floods the notebook output
PLOT_EVERY = 1000
CHECKPOINT_PATH = 'flow2d_opt.pkl'

# Selection-function parameters shared by every simulated set (later cells use them too).
mu_cut = 21
sigma_cut = 0.01


def _simulate_batch(key):
    """Simulate one unscaled training batch with rows ``[m_s, m0, sigma_int]``.

    For each of ``N_SETS_PER_EPOCH`` sets: ``sigma_int = 0.2 * |N(0, 1)|``,
    ``N_MU`` values ``m0 ~ U(20.5, 21 + 2 * sigma_int)``, and an ``(N_REP, N_MU)``
    SkewNormalPlus draw with the fixed selection cut.
    Returns an array of shape ``(N_SETS_PER_EPOCH * N_REP * N_MU, 3)``.
    """
    n_rows = N_REP * N_MU
    blocks = []
    for set_key in random.split(key, N_SETS_PER_EPOCH):
        # One fresh subkey per random draw; no key is reused after being consumed.
        k_sigma, k_mu, k_samp = random.split(set_key, 3)
        sigma_int = np.absolute(random.normal(k_sigma, (1,))) * 0.2
        mu = random.uniform(k_mu, (N_MU,)) * (21 + 2 * sigma_int[0] - 20.5) + 20.5
        s = snp(m_int=mu, sigma_int=sigma_int, m_cut=mu_cut, sigma_cut=sigma_cut)
        samps = s.sample(k_samp, (N_REP, N_MU))
        theta = np.column_stack((np.repeat(mu, N_REP).reshape(n_rows, 1),
                                 np.repeat(sigma_int, n_rows).reshape(n_rows, 1)))
        blocks.append(np.column_stack((samps.T.reshape(n_rows, 1), theta)))
    # A single concatenate instead of growing the array with np.append inside the loop.
    return np.concatenate(blocks, axis=0)


def _plot_checkpoint(key, flow_params):
    """Overlay flow samples on the analytic SkewNormalPlus density (sigma_int = 0.1) for several m0."""
    colors = ['red', 'green', 'blue', 'yellow', 'pink']
    m0_grid = np.array([20.7, 20.8, 20.9, 21, 21.1])
    bins = np.linspace(20.3, 21.1, 100)
    no_samps = 50000

    plt.figure(figsize=(12, 6))
    for i, (m, curve_key) in enumerate(zip(m0_grid, random.split(key, len(m0_grid)))):
        # Conditioning rows [m0, sigma_int], scaled with the bounds of training columns 1 and 2.
        cond = np.column_stack((np.full((no_samps, 1), m), np.full((no_samps, 1), 0.1)))
        samp = sample(curve_key, flow_params, minmax_scale(cond, min[1:], max[1:]), no_samps)
        # The flow models the scaled offset m_s - m0: unscale, then add m0 back.
        samp = minmax_unscale(samp[:, 0], min[0], max[0]) + m

        plt.hist(samp, density=True, bins=bins, color=colors[i], histtype='step',
                 label=r'$m_0=$' + str(m) + r',$\sigma_{int}=0.1$,$m_{cut}=$' + str(mu_cut)
                       + r',$\sigma_{cut}=$' + str(sigma_cut))
        truth = snp(np.array([m]), np.array([0.1]), mu_cut, sigma_cut)
        plt.plot(bins, np.exp(truth.log_prob(bins)), color=colors[i])

    plt.legend(loc='center left', frameon=False, fontsize=12)
    plt.xlabel(r'$\hat{m}_s$', fontsize=11)
    plt.ylabel(r'$p(\hat{m}_s|m_0,\sigma_{int},m_{cut},\sigma_{cut})$', fontsize=12)
    plt.title('Normalising Flow', fontsize=12)
    plt.show()


def _save_opt_state(state, path=CHECKPOINT_PATH):
    """Pickle the unpacked optimiser state to ``path`` (handle closed by ``with``); return it."""
    unpacked = optimizers.unpack_optimizer_state(state)
    with open(path, 'wb') as fh:
        pickle.dump(unpacked, fh)
    return unpacked


for epoch in range(N_EPOCHS):
    # `rng` is only ever split; each consumer gets its own fresh subkey.
    rng, batch_key = random.split(rng)
    X = _simulate_batch(batch_key)

    # The flow models the offset m_s - m0 (column 1 holds m0).
    m0 = X[:, 1]
    X = X.at[:, 0].set(X[:, 0] - m0)

    if epoch == 0:
        # `min` / `max` shadow the builtins, but later cells read them under these names.
        # NOTE(review): the bounds are fitted on the first batch only, so later batches
        # (e.g. with larger sigma_int) can scale outside [0, 1].
        X, min, max = minmax_fit_and_scale(X)
    else:
        X = minmax_scale(X, min, max)

    loss, opt_state = step(next(itercount), opt_state, X)
    params = get_params(opt_state)

    if epoch % PRINT_EVERY == 0:
        print('epoch: ', epoch, ' loss: ', loss)

    if epoch % PLOT_EVERY == 0:
        rng, plot_key = random.split(rng)
        _plot_checkpoint(plot_key, params)
        _save_opt_state(opt_state)

trained_params = _save_opt_state(opt_state)


TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [None]:
from matplotlib import pyplot as plt

# Trained flow vs. analytic SkewNormalPlus density at fixed sigma_int, for several m0.
sigma_int_plot = 0.1
c = ['red', 'green', 'blue', 'yellow', 'pink']
mm = np.array([20.7, 20.8, 20.9, 21, 21.1])
bins = np.linspace(20.3, 21.1, 100)
no_samps = 100000

plt.figure(figsize=(12, 6))

# One key per curve, derived from (not overwriting) the global `rng`, so the cell is idempotent.
for i, (m, curve_key) in enumerate(zip(mm, random.split(rng, len(mm)))):
    # Conditioning rows [m0, sigma_int], scaled with the bounds of training columns 1 and 2.
    cond = np.column_stack((np.full((no_samps, 1), m), np.full((no_samps, 1), sigma_int_plot)))
    samp = sample(curve_key, params, minmax_scale(cond, min[1:], max[1:]), no_samps)
    # Training subtracts m0 from column 0, so the flow models m_s - m0: unscale, then
    # add m0 back to compare on the same magnitude axis as `bins` and the analytic density.
    samp = minmax_unscale(samp[:, 0], min[0], max[0]) + m

    plt.hist(samp, density=True, bins=bins, color=c[i], histtype='step',
             label=r'$m_0=$' + str(m) + r',$\sigma_{int}=$' + str(sigma_int_plot)
                   + r',$m_{cut}=$' + str(mu_cut) + r',$\sigma_{cut}=$' + str(sigma_cut))
    s = snp(np.array([m]), np.array([sigma_int_plot]), mu_cut, sigma_cut)
    plt.plot(bins, np.exp(s.log_prob(bins)), color=c[i])

    print(np.min(samp), np.max(samp))

plt.legend(loc='upper right', frameon=False, fontsize=12)
plt.xlabel(r'$\hat{m}_s$', fontsize=11)
plt.ylabel(r'$p(\hat{m}_s|m_0,\sigma_{int},m_{cut},\sigma_{cut})$', fontsize=12)
plt.title('Normalising Flow', fontsize=12)
plt.xlim(21.1, 20.3)
plt.show()

In [None]:
# Column-wise scaling bounds fitted on the first training batch, columns [m_s - m0, m0, sigma_int].
print(min, max)

In [None]:
from matplotlib import pyplot as plt

# Trained flow vs. analytic SkewNormalPlus density at fixed sigma_int, for several m0.
sigma_int_plot = 0.2
c = ['red', 'green', 'blue', 'yellow', 'pink']
mm = np.array([20.7, 20.8, 20.9, 21, 21.1])
bins = np.linspace(20.3, 21.1, 100)
no_samps = 100000

plt.figure(figsize=(12, 6))

# One key per curve, derived from (not overwriting) the global `rng`, so the cell is idempotent.
for i, (m, curve_key) in enumerate(zip(mm, random.split(rng, len(mm)))):
    # Conditioning rows [m0, sigma_int], scaled with the bounds of training columns 1 and 2.
    cond = np.column_stack((np.full((no_samps, 1), m), np.full((no_samps, 1), sigma_int_plot)))
    samp = sample(curve_key, params, minmax_scale(cond, min[1:], max[1:]), no_samps)
    # Training subtracts m0 from column 0, so the flow models m_s - m0: unscale, then
    # add m0 back to compare on the same magnitude axis as `bins` and the analytic density.
    samp = minmax_unscale(samp[:, 0], min[0], max[0]) + m

    plt.hist(samp, density=True, bins=bins, color=c[i], histtype='step',
             label=r'$m_0=$' + str(m) + r',$\sigma_{int}=$' + str(sigma_int_plot)
                   + r',$m_{cut}=$' + str(mu_cut) + r',$\sigma_{cut}=$' + str(sigma_cut))
    s = snp(np.array([m]), np.array([sigma_int_plot]), mu_cut, sigma_cut)
    plt.plot(bins, np.exp(s.log_prob(bins)), color=c[i])

    print(np.min(samp), np.max(samp))

plt.legend(loc='upper right', frameon=False, fontsize=12)
plt.xlabel(r'$\hat{m}_s$', fontsize=11)
plt.ylabel(r'$p(\hat{m}_s|m_0,\sigma_{int},m_{cut},\sigma_{cut})$', fontsize=12)
plt.title('Normalising Flow', fontsize=12)
plt.xlim(21.1, 20.3)
plt.show()

In [None]:
from matplotlib import pyplot as plt

# Trained flow vs. analytic SkewNormalPlus density at fixed sigma_int, for several m0.
sigma_int_plot = 0.05
c = ['red', 'green', 'blue', 'yellow', 'pink']
mm = np.array([20.7, 20.8, 20.9, 21, 21.1])
bins = np.linspace(20.3, 21.1, 100)
no_samps = 100000

plt.figure(figsize=(12, 6))

# One key per curve, derived from (not overwriting) the global `rng`, so the cell is idempotent.
for i, (m, curve_key) in enumerate(zip(mm, random.split(rng, len(mm)))):
    # Conditioning rows [m0, sigma_int], scaled with the bounds of training columns 1 and 2.
    cond = np.column_stack((np.full((no_samps, 1), m), np.full((no_samps, 1), sigma_int_plot)))
    samp = sample(curve_key, params, minmax_scale(cond, min[1:], max[1:]), no_samps)
    # Training subtracts m0 from column 0, so the flow models m_s - m0: unscale, then
    # add m0 back to compare on the same magnitude axis as `bins` and the analytic density.
    samp = minmax_unscale(samp[:, 0], min[0], max[0]) + m

    plt.hist(samp, density=True, bins=bins, color=c[i], histtype='step',
             label=r'$m_0=$' + str(m) + r',$\sigma_{int}=$' + str(sigma_int_plot)
                   + r',$m_{cut}=$' + str(mu_cut) + r',$\sigma_{cut}=$' + str(sigma_cut))
    s = snp(np.array([m]), np.array([sigma_int_plot]), mu_cut, sigma_cut)
    plt.plot(bins, np.exp(s.log_prob(bins)), color=c[i])

    print(np.min(samp), np.max(samp))

plt.legend(loc='upper right', frameon=False, fontsize=12)
plt.xlabel(r'$\hat{m}_s$', fontsize=11)
plt.ylabel(r'$p(\hat{m}_s|m_0,\sigma_{int},m_{cut},\sigma_{cut})$', fontsize=12)
plt.title('Normalising Flow', fontsize=12)
plt.xlim(21.1, 20.3)
plt.show()

In [None]:
from matplotlib import pyplot as plt

# Trained flow vs. analytic SkewNormalPlus density at fixed sigma_int, for several m0.
sigma_int_plot = 0.01
c = ['red', 'green', 'blue', 'yellow', 'pink']
mm = np.array([20.7, 20.8, 20.9, 21, 21.1])
bins = np.linspace(20.3, 21.1, 100)
no_samps = 100000

plt.figure(figsize=(12, 6))

# One key per curve, derived from (not overwriting) the global `rng`, so the cell is idempotent.
for i, (m, curve_key) in enumerate(zip(mm, random.split(rng, len(mm)))):
    # Conditioning rows [m0, sigma_int], scaled with the bounds of training columns 1 and 2.
    cond = np.column_stack((np.full((no_samps, 1), m), np.full((no_samps, 1), sigma_int_plot)))
    samp = sample(curve_key, params, minmax_scale(cond, min[1:], max[1:]), no_samps)
    # Training subtracts m0 from column 0, so the flow models m_s - m0: unscale, then
    # add m0 back to compare on the same magnitude axis as `bins` and the analytic density.
    samp = minmax_unscale(samp[:, 0], min[0], max[0]) + m

    plt.hist(samp, density=True, bins=bins, color=c[i], histtype='step',
             label=r'$m_0=$' + str(m) + r',$\sigma_{int}=$' + str(sigma_int_plot)
                   + r',$m_{cut}=$' + str(mu_cut) + r',$\sigma_{cut}=$' + str(sigma_cut))
    s = snp(np.array([m]), np.array([sigma_int_plot]), mu_cut, sigma_cut)
    plt.plot(bins, np.exp(s.log_prob(bins)), color=c[i])

    print(np.min(samp), np.max(samp))

plt.legend(loc='upper right', frameon=False, fontsize=12)
plt.xlabel(r'$\hat{m}_s$', fontsize=11)
plt.ylabel(r'$p(\hat{m}_s|m_0,\sigma_{int},m_{cut},\sigma_{cut})$', fontsize=12)
plt.title('Normalising Flow', fontsize=12)
plt.xlim(21.1, 20.3)
plt.show()