# testing ANPE deployment on NSA observations

In [17]:
import os, sys
import numpy as np
from scipy import stats
from sedflow import obs as Obs
from sedflow import train as Train

# torch
import torch
from sbi import utils as Ut
from sbi import inference as Inference

In [14]:
# NSA observations the ANPE is deployed on; individual rows (y_nsa[i]) are
# passed to get_posterior further down as the conditioning observation.
y_nsa = Obs.load_nsa_data(test_set=False)

  return 22.5 - 2.5 * np.log10(flux)


In [15]:
# --- prior (fixed) ---------------------------------------------------
# Box-uniform prior over the 11 parameters: prior_low / prior_high hold the
# lower / upper edge of each dimension, in the same order. Entries 6 and 7
# are bounded in log10.
prior_low = [
    7, 0., 0., 0., 0., 1e-2,
    np.log10(4.5e-5), np.log10(4.5e-5),
    0, 0., -2.,
]
prior_high = [
    12.5, 1., 1., 1., 1., 13.27,
    np.log10(1.5e-2), np.log10(1.5e-2),
    3., 3., 1.,
]
lower_bounds, upper_bounds = torch.tensor(prior_low), torch.tensor(prior_high)
prior = Ut.BoxUniform(low=lower_bounds, high=upper_bounds, device='cpu')

In [19]:
####################################################################
# load trained ANPE
####################################################################
sample = 'toy'
nhidden = 500
nblocks = 15
itrain = 2
####################################################################

# Test set. Judging by how the two arrays are handed to append_simulations
# below, x_test plays the role of the parameters (theta) and y_test the
# role of the observables.
x_test, y_test = Train.load_data('test', version=1, sample=sample, params='thetas_unt')
# columns 6 and 7 are moved to log10, matching the log10 bounds at the same
# positions of the prior
x_test[:,6] = np.log10(x_test[:,6])
x_test[:,7] = np.log10(x_test[:,7])

# convert to float32 tensors once and reuse them, instead of re-casting the
# arrays for every call below
_theta_test = torch.as_tensor(x_test.astype(np.float32))
_obs_test = torch.as_tensor(y_test.astype(np.float32))

# checkpoint file name encodes sample, architecture (nhidden x nblocks) and
# training run index
fanpe = os.path.join(Train.data_dir(), 'anpe_thetaunt_magsigz.%s.%ix%i.%i.pt' % (sample, nhidden, nblocks, itrain))

anpe = Inference.SNPE(prior=prior,
                      density_estimator=Ut.posterior_nn('maf', hidden_features=nhidden, num_transforms=nblocks),
                      device='cpu')
anpe.append_simulations(_theta_test, _obs_test)

# NOTE(review): _build_neural_net and _x_shape are private sbi members; this
# rebuilds the network without training so the saved weights can be loaded,
# and may break with a different sbi version.
p_x_y_estimator = anpe._build_neural_net(_theta_test, _obs_test)
# map_location='cpu' so a checkpoint written from a GPU run still loads on a
# CPU-only machine (everything else here is pinned to device='cpu')
p_x_y_estimator.load_state_dict(torch.load(fanpe, map_location='cpu'))

anpe._x_shape = Ut.x_shape_from_simulation(_obs_test)

hatp_x_y = anpe.build_posterior(p_x_y_estimator)

In [29]:
# Number of trainable (requires_grad) elements in the density estimator.
# The builtin sum is used because np.sum over a generator is deprecated and
# emits a DeprecationWarning.
print('%i hyperparameters' % sum(p.numel() for p in p_x_y_estimator.parameters() if p.requires_grad))

7890330 hyperparameters


  """Entry point for launching an IPython kernel.


In [None]:
def count_parameters(model):
    ''' total number of elements over all entries of `model.parameters()`
    that have `requires_grad` set.
    '''
    n_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    return n_trainable

In [22]:
def get_posterior(y_nsa_i, nmcmc=10000):
    ''' draw posterior samples from the loaded ANPE (`hatp_x_y`) for a single
    galaxy.

    y_nsa_i : [mag, uncertainty, redshift] of the galaxy
    nmcmc : number of samples to draw (default 10000)

    returns the samples as a numpy array
    '''
    observation = torch.as_tensor(y_nsa_i)
    draws = hatp_x_y.sample((nmcmc,), x=observation, show_progress_bars=True)
    return np.array(draws)

In [23]:
# posterior samples for the first 10 NSA galaxies, one array per galaxy;
# each observation is printed before it is sampled
mcmcs = []
for igal in range(10):
    print(y_nsa[igal])
    mcmcs.append(get_posterior(y_nsa[igal]))

[16.887062   15.21636    14.365496   14.013487   13.719948    0.05103989
  0.02009773  0.02006188  0.0200624   0.03011609  0.02122228]


Drawing 10000 posterior samples:   0%|          | 0/10000 [00:00<?, ?it/s]

[17.542395   16.210804   15.504496   15.134066   14.883438    0.05418632
  0.02033832  0.02022357  0.0202281   0.03081442  0.04808582]


Drawing 10000 posterior samples:   0%|          | 0/10000 [00:00<?, ?it/s]

[18.724672   17.836632   17.48639    17.34852    17.25803     0.08334129
  0.02342823  0.02330136  0.02555846  0.0504447   0.02196459]


Drawing 10000 posterior samples:   0%|          | 0/10000 [00:00<?, ?it/s]

[17.03926    16.0428     15.576517   15.317082   15.157151    0.05287183
  0.02037018  0.02023692  0.02025547  0.03092435  0.03580831]


Drawing 10000 posterior samples:   0%|          | 0/10000 [00:00<?, ?it/s]

[18.777647   17.285528   16.502043   16.070406   15.781486    0.07161
  0.02104094  0.02056519  0.02051609  0.03152744  0.05248393]


Drawing 10000 posterior samples:   0%|          | 0/10000 [00:00<?, ?it/s]

[17.30959    16.590717   16.423431   16.44567    16.356396    0.06063322
  0.02149479  0.02150889  0.02329385  0.04339407  0.02170876]


Drawing 10000 posterior samples:   0%|          | 0/10000 [00:00<?, ?it/s]

[18.011139   16.898869   16.385351   16.133677   16.082958    0.06666344
  0.02111826  0.02089052  0.0210252   0.03584207  0.03645184]


Drawing 10000 posterior samples:   0%|          | 0/10000 [00:00<?, ?it/s]

[17.999893   16.201597   15.335937   14.95873    14.678669    0.05635887
  0.02043227  0.02025319  0.02026885  0.03052281  0.03673035]


Drawing 10000 posterior samples:   0%|          | 0/10000 [00:00<?, ?it/s]

[16.337324   14.524672   13.599525   13.20084    12.906909    0.05323015
  0.02013912  0.02007489  0.02006928  0.03015652  0.02160817]


Drawing 10000 posterior samples:   0%|          | 0/10000 [00:00<?, ?it/s]

[17.072449   16.039417   15.612441   15.384242   15.329765    0.05636014
  0.02039787  0.02040346  0.02051542  0.03212593  0.02165231]


Drawing 10000 posterior samples:   0%|          | 0/10000 [00:00<?, ?it/s]

In [26]:
# samples of the second-to-last galaxy; index the list directly rather than
# stacking every chain into a new array just to pick one entry
mcmcs[-2]

array([[11.117858  ,  0.6711677 ,  0.40776756, ...,  0.8489683 ,
         0.5990936 , -1.0390607 ],
       [10.769588  ,  0.23237726,  0.27514288, ...,  0.05772961,
         1.0135201 , -0.9459942 ],
       [11.037788  ,  0.15098369,  0.5215622 , ...,  1.0278299 ,
         0.3157818 , -1.7855717 ],
       ...,
       [11.054487  ,  0.88227826,  0.43112966, ...,  1.6429694 ,
         0.6993135 , -0.6503178 ],
       [10.905409  ,  0.94002116,  0.3469195 , ...,  1.8281566 ,
         0.66839224, -0.7238141 ],
       [10.663375  ,  0.3489581 ,  0.1223333 , ...,  0.63686687,
         0.756173  , -1.1358496 ]], dtype=float32)