In [1]:
import numpy as onp
import pandas as pd
import torch
import sklearn

# visualization
import seaborn
import matplotlib as mpl
import matplotlib.pyplot as plt
import corner

%load_ext autoreload
%autoreload 2

In [2]:
from tqdm.auto import tqdm
import itertools
import numpy.random as npr
import jax
import optax 


In [3]:
import matplotlib.pyplot as plt
from sklearn import datasets, preprocessing
from IPython.display import clear_output
import jax.numpy as np

import flows

from jax import grad, jit, random
from jax.experimental import stax, optimizers
import numpyro

In [4]:
rng, LR_rng = random.split(random.PRNGKey(0))


def bce_w_logits(params, data, labels, average=True):
    """
    Binary Cross Entropy Loss
    Should be numerically stable, built based on: https://github.com/pytorch/pytorch/issues/751
    :param data: Input tensor
    :param labels: Target tensor
    :param average: Boolean to average resulting loss vector
    :return: Scalar value
    """
    max_val = np.clip(x, 0, None)
    loss = x - x * y + max_val + np.log(np.exp(-max_val) + np.exp((-x - max_val)))

    if weight is not None:
        loss = loss * weight

    if average:
        return loss.mean()
    else:
        return loss.sum()
    
def loss(params, batch):
    data, labels = batch
    return -bce_w_logits(params, data, labels)

@jit
def step(params, opt_state, batch):
    nll, grads = jax.value_and_grad(loss)(params.fast, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return nll, optax.apply_updates(params, updates), opt_state

In [5]:

def hmc(flow_params, obs, init_theta,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=True,
        step_size=1e0,
        max_tree_depth=12,
        num_warmup=100,
        num_samples=50,
        num_chains=1,
        seed=1234):
    
    def log_prior(theta):
        """unit gaussian about 0
        """
        dim = theta.shape[-1]
        return -0.5*(dim*np.log(2*np.pi) + np.log(1) + (theta)**2).sum()
    
    def posterior_wrapper(theta):
        if len(theta.shape) == 1:
            theta = theta[None, :]
        inputs = np.hstack([obs, theta])
        log_post = log_pdf(flow_params, inputs) # + log_prior(theta)
        return -log_post[0]

    hmc_key = random.PRNGKey(seed)
    nuts_kernel = numpyro.infer.NUTS(potential_fn=posterior_wrapper, 
                                     adapt_step_size=adapt_step_size, 
                                     adapt_mass_matrix=adapt_mass_matrix,
                                     dense_mass=dense_mass,
                                     step_size=step_size, 
                                     max_tree_depth=max_tree_depth)
    mcmc = numpyro.infer.MCMC(nuts_kernel, 
                              num_samples=num_samples, 
                              num_warmup=num_warmup, 
                              num_chains=num_chains)

    #TODO: make sure x0 is being used in posterior (because it's not being used in run)
    mcmc.run(hmc_key, init_params=init_theta)
    return mcmc

In [6]:
import dataset
import sklearn.preprocessing

N = 1500000

obs_mean, obs_std = dataset.obs.mean(dim=0), dataset.obs.std(dim=0)
params_mean, params_std = dataset.params.mean(dim=0), dataset.params.std(dim=0)

param_dim = dataset.params.shape[-1]
obs_dim = dataset.obs.shape[-1]

params = (dataset.params - params_mean)/params_std
obs = (dataset.obs - obs_mean)/obs_std


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  f['chg2'][np.where(np.isnan(f['chg2']))[0]]=-1


In [7]:
data = pd.DataFrame(np.hstack([obs[:N].numpy(), params[:N].numpy()]))
data.head()

Unnamed: 0,0,1,2,3,4,5
0,1.785297,0.075091,0.939147,-1.884812,0.691027,0.759124
1,-0.459451,0.131942,-0.748742,1.705266,0.770638,-0.543799
2,-1.230811,0.418187,1.139939,1.218724,0.747513,1.859668
3,1.213942,0.423448,1.633857,-0.320276,1.76339,1.04044
4,1.185564,-0.487451,1.317851,-0.897858,1.889171,-1.316012


In [8]:
training_dataloader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.tensor(data.values)), 
                                                  batch_size=128, 
                                                  shuffle=True)

In [37]:
init_random_params, predict = stax.serial(
                                          Dense(1024), Relu,
                                          Dense(1024), Relu,
                                          Dense(10), 
                                          LogSoftmax)

In [38]:
max_norm = 0.1
learning_rate = 1e-4
sync_period = 5
slow_step_size = 0.5

fast_optimizer = optax.chain(
    # Set the parameters of AdamW. 
    optax.adamw(learning_rate=learning_rate, b1=0.9, b2=0.999, eps=1e-8),
    optax.adaptive_grad_clip(max_norm),
)
optimizer = optax.lookahead(fast_optimizer, sync_period=sync_period, slow_step_size=slow_step_size)

# The lookahead optimizer wrapper keeps a pair of slow and fast parameters. To
# initialize them, we create a pair of synchronized parameters from the
# initial model parameters. The first line below is only necessary for the
# lookahead wrapper; without it the initial parameters could be used in the
# initialization function of the optimizer directly.
params = optax.LookaheadParams.init_synced(initial_params)
opt_state = optimizer.init(params)
losses = []
results = []
best_results = np.inf


In [39]:
# create observation
observation = torch.tensor([[0.12, 125.]])
observation = (observation - obs_mean)/obs_std
observation = observation.numpy()

# find the most likely point in the training set to set initial params
temp = data.copy()
temp[[0, 1]] = observation

In [None]:
num_steps = 5000
batch_size = 128
itercount = itertools.count()
try:
    for step_num in tqdm(range(num_steps)):
        batch, = next(iter(training_dataloader))

        nll, params, opt_state = step(params, opt_state, batch.numpy())
        losses.append(nll)

        # see how well the observations are recovered
        if step_num % 500 == 0:
            likelihoods = log_pdf(params.slow, temp.values)
            res = data.iloc[np.argmax(likelihoods).item()].values[:obs_dim]
            if sum(res) < best_results:
                best_results = sum(res)
                trained_params = params.slow.copy()
            results.append((observation - res)/(observation + obs_mean.numpy()/obs_std.numpy()))
except KeyboardInterrupt:
    pass

  0%|          | 0/5000 [00:00<?, ?it/s]

In [None]:
plt.plot(losses)
plt.yscale('log')
plt.show()
plt.plot(np.abs(np.vstack(results)))
plt.yscale('log')

In [None]:
likelihoods = log_pdf(trained_params, temp.values)

plt.hist(likelihoods, bins='fd', range=(-20, likelihoods.max()));

In [None]:
# find the most likely point in the training set to set initial params
temp = data.copy()
temp[[0, 1]] = observation
likelihoods = log_pdf(trained_params, temp.values)

# put together init_params
num_chains = 1
init_theta = data.iloc[np.argmax(likelihoods).item()].values[obs_dim:]
init_theta = np.array([init_theta for _ in range(num_chains)])

mcmc = hmc(flow_params=trained_params, 
           obs=observation*1.0, 
           init_theta=init_theta,
           adapt_step_size=True,
           adapt_mass_matrix=True,
           dense_mass=True,
           step_size=4e-2,
           max_tree_depth=10,
           num_warmup=1000,
           num_samples=1000,
           num_chains=num_chains,
           seed=1234,
          )

In [None]:
mcmc.print_summary()

In [None]:
samples = mcmc.get_samples()
samples = samples*params_std.numpy() + params_mean.numpy()
# samples = (samples - dataset.params.min(axis=0).values.numpy())/(dataset.params.max(axis=0).values.numpy() - dataset.params.min(axis=0).values.numpy())

In [None]:
Omega, delta_Omega = 0.12, 0.0012
mh, delta_mh = 125., 2.

mean = np.hstack([obs_mean, params_mean])
std = np.hstack([obs_std, params_std])

temp = data*std + mean

correct_Omega_idx = (temp[0] <= Omega + delta_Omega*2) & (temp[0]  >= Omega - delta_Omega*2)
correct_mh_idx = (temp[1] <= mh + delta_mh*2) & (temp[1] >= mh - delta_mh*2)


correct_omega_cMSSM = temp[correct_Omega_idx]
correct_mh_cMSSM = temp[correct_mh_idx]
correct_combo_cMSSM = temp[correct_mh_idx & correct_Omega_idx]

dset_samples = correct_combo_cMSSM.values[:, obs_dim:]
# dset_samples = (dset_samples - dataset.params.min(axis=0).values.numpy())/(dataset.params.max(axis=0).values.numpy() - dataset.params.min(axis=0).values.numpy())


param_cols = ['m0', 'm12', 'a0', 'tanb']
fig = corner.corner(onp.array(samples), 
             labels=param_cols, 
#              range=[(0, 1) for i in range(4)],
             title='valid cMSSM', 
             color='C1',
             hist_kwargs={'color':'C1',"density":True});

fig = corner.corner(dset_samples, 
             title='Correct Omega and mh', 
#              range=[[0, 1] for i in param_cols], 
                     
             color='C0',
                    hist_kwargs={'color':'C0',"density":True}, 
                    fig=fig,
                   );


In [19]:
dset_samples.shape, dataset.params.shape

((107, 4), torch.Size([866710, 4]))

In [20]:
164/866746

0.00018921344892275247