In [None]:
import sys
sys.path.append("..")

import jax, os, corner
import jax.numpy as jnp
from jax import grad, config
import matplotlib.pyplot as plt
import numpy as np
config.update("jax_enable_x64", True)
config.update("jax_debug_nans", True)
from functools import partial

# Load reparameterization methods
from src.reparameterization import sigma, logistic_CDF, reparameterized_gradient

# Load birth/death method
from src.birth_death import birth_death

print(jax.devices())

In [None]:
def scan(f, init, xs, length=None):
  """Pure-Python stand-in for ``jax.lax.scan`` (handy for debugging outside jit).

  Threads a carry through ``f`` over the elements of ``xs``; when ``xs`` is None,
  ``f`` instead receives ``None`` ``length`` times. Returns the final carry and
  the per-step outputs stacked with ``np.stack`` along a new leading axis.
  """
  inputs = [None] * length if xs is None else xs
  state = init
  outputs = []
  for item in inputs:
    state, out = f(state, item)
    outputs.append(out)
  return state, np.stack(outputs)

# Mirrored Langevin Birth/Death (MLBD)

In [None]:
# Annealing schedule
def gamma(t):
    """Tempering factor used at iteration ``t``.

    Currently the constant 1 for every iteration, i.e. no annealing.
    """
    return 1

def ula_kernel(key, param, log_post, grad_log_post, dt, iteration, lower, upper, stride, rate, bandwidth, p):
    """One mirrored-Langevin step with an optional birth/death jump.

    Parameters
    ----------
    key : PRNG key for this step. A fresh, independent key is returned for the next step.
    param : particle ensemble in the bounded domain; ``param.shape[0]`` indexes particles.
    log_post : forwarded to ``birth_death`` (call sites in this notebook pass the
        potential, i.e. a negative log posterior).
    grad_log_post : gradient of the potential (despite the name), handed to
        ``reparameterized_gradient``.
    dt : step size; a scalar or an array broadcastable against ``param``.
    iteration : current iteration counter.
    lower, upper : bounds of the box-shaped domain.
    stride : birth/death runs whenever ``iteration % stride == 0``. ``stride = 1``
        runs it at every iteration, and a counter starting at 0 always triggers it once.
    rate, bandwidth, p : forwarded to ``birth_death``.

    Returns
    -------
    (key, param, iteration) : next PRNG key, updated particles mapped back to the
        bounded domain, and ``iteration + 1``.
    """
    # Three independent keys. Previously the key handed to birth_death was also the
    # one returned (and split) for the next step, i.e. a key was reused.
    key, bd_key, noise_key = jax.random.split(key, 3)

    temperature = gamma(iteration)

    # Transform to the unbounded domain; gmlpt_Y is the gradient there as returned
    # by reparameterized_gradient
    Y, gmlpt_Y = reparameterized_gradient(param, grad_log_post, lower, upper, temperature)

    # Every `stride` iterations resample particle indices via birth/death;
    # otherwise keep the identity permutation
    jumps = jax.lax.cond(
        jnp.mod(iteration, stride) == 0,
        lambda: birth_death(bd_key, param, log_post, bandwidth=bandwidth, p=p, rate=rate, a=lower, b=upper, gamma=temperature),
        lambda: jnp.arange(param.shape[0]),
    )

    # Langevin update in the unbounded domain; gradients follow the resampled particles
    noise = jax.random.normal(key=noise_key, shape=param.shape)
    Y = Y[jumps] - gmlpt_Y[jumps] * dt + jnp.sqrt(2 * dt) * noise

    # Convert samples back to the bounded domain
    param = sigma(logistic_CDF(Y), lower, upper)

    return key, param, iteration + 1

@partial(jax.jit, static_argnums=(1, 2, 3))
def ula_sampler_full_jax_jit(key, log_post, grad_log_post, n_iter, dt, x_0, lower, upper, stride=1, rate=1, bandwidth=0.01, p=2):
    """Run ``n_iter`` MLBD iterations of ``ula_kernel`` inside one jitted ``lax.scan``.

    ``log_post``, ``grad_log_post`` and ``n_iter`` are static (positions 1-3), so every
    new function object or iteration count triggers a fresh compilation. All other
    arguments are forwarded unchanged to ``ula_kernel`` at each step.

    Returns
    -------
    The particle ensemble after every iteration, stacked along a new leading axis
    of length ``n_iter``.
    """
    def step(state, _unused):
        rng, particles, counter = state
        new_state = ula_kernel(rng, particles, log_post, grad_log_post, dt, counter, lower, upper, stride, rate, bandwidth, p)
        return new_state, new_state[1]

    # Carry = (PRNG key, particles, iteration counter starting at 0)
    _, trajectory = jax.lax.scan(step, (key, x_0, 0), None, n_iter)
    return trajectory

# Unit tests: mixture-of-Gaussians benchmarks compared against rejection-sampled reference draws

In [None]:
def rejection_sampling(iid_samples, lower_bound, upper_bound):
    """Discard every sample that is not strictly inside the box (lower_bound, upper_bound).

    Parameters
    ----------
    iid_samples : (n, d) array of draws.
    lower_bound, upper_bound : per-coordinate bounds broadcastable against a row.

    Prints how many samples survive and returns them as a NumPy array.
    """
    above = iid_samples > lower_bound
    below = iid_samples < upper_bound
    keep = np.flatnonzero(np.all(above & below, axis=1))
    print('%i samples obtained from rejection sampling' % keep.shape[0])
    return np.array(iid_samples[keep])

In [None]:
""" 
Remarks
-------
(1) 2d MoG with increasing weights to the right
(2) Sampler teleports out all particles from smallest mode given enough time
(3) Can be mitigated using a small `rate`

"""

from models.mog_new import MoG 

k = 3
d = 2
weights = jnp.array([2, 4, 5])

mus = jnp.zeros((k, d))
mus = mus.at[0].set(jnp.array([-15, 0]))
# mus = mus.at[0].set(jnp.array([-10, 0]))
mus = mus.at[1].set(jnp.array([0, 0]))
# mus = mus.at[2].set(jnp.array([10, 0]))
mus = mus.at[2].set(jnp.array([15, 0]))

covs = jnp.zeros((k, d))
covs = covs.at[0].set(jnp.ones(d))
covs = covs.at[1].set(jnp.ones(d))
covs = covs.at[2].set(jnp.ones(d))

lower_bound = jnp.array([-15, -15])
upper_bound = jnp.array([15, 15])

model = MoG(weights, mus, covs, lower_bound, upper_bound)

iid_samples = model.newDrawFromPosterior(1000000)

bounded_iid_samples = rejection_sampling(iid_samples, model.lower_bound, model.upper_bound)

In [None]:
# Setup and run sampler
n_iter = 20000
n_particles = 500
eps = 1e-3
stride = 100
rate = 0.01
bandwidth = 0.001
p = 2
X0 = model._newDrawFromPrior(n_particles)
key = jax.random.PRNGKey(0)
# bandwidth and p must be forwarded explicitly, otherwise the sampler silently uses its defaults
sam = ula_sampler_full_jax_jit(key, jax.vmap(model.potential), jax.vmap(jax.jacfwd(model.potential)), n_iter, eps, X0, model.lower_bound, model.upper_bound, stride, rate, bandwidth=bandwidth, p=p)
# To continue from the final state, pass sam[-1] instead of X0 (and a freshly split key)

In [None]:
# Plot: MLBD samples (red) against the rejection-sampled reference (black)
import matplotlib.lines as mlines

# Flatten (iteration, particle, dim) -> (iteration * particle, dim)
reshaped_matrix = np.array(sam.reshape(-1, sam.shape[-1]))
labels = [r'$x_1$', r'$x_2$']
k_line = mlines.Line2D([], [], color='k', label='Truth')
r_line = mlines.Line2D([], [], color='r', label='MLBD')

# Separate hist_kwargs dicts on purpose: corner writes a colour into the dict it is given
fig = corner.corner(bounded_iid_samples[-20000:], hist_kwargs={'density': True}, truths=jnp.mean(bounded_iid_samples, axis=0), color='k')
corner.corner(reshaped_matrix[-20000:], color='r', fig=fig, hist_kwargs={'density': True}, labels=labels)
plt.legend(handles=[k_line, r_line], bbox_to_anchor=(0., 1.0, 1., .0), loc=4)

In [None]:
""" 
Remarks
-------
(1) 2d MoG with all modes of similar weight
"""

from models.mog_new import MoG 

k = 3
d = 2
weights = jnp.array([5, 4, 5])

mus = jnp.zeros((k, d))
mus = mus.at[0].set(jnp.array([-10, 0]))
mus = mus.at[1].set(jnp.array([0, 0]))
mus = mus.at[2].set(jnp.array([10, 0]))

covs = jnp.zeros((k, d))
covs = covs.at[0].set(jnp.ones(d))
covs = covs.at[1].set(jnp.ones(d))
covs = covs.at[2].set(jnp.ones(d))

lower_bound = jnp.array([-15, -15])
upper_bound = jnp.array([15, 15])

model = MoG(weights, mus, covs, lower_bound, upper_bound)


iid_samples = model.newDrawFromPosterior(1000000)

bounded_iid_samples = rejection_sampling(iid_samples, model.lower_bound, model.upper_bound)

In [None]:
# Setup and run sampler
n_iter = 20000
n_particles = 200
eps = 1e-2
stride = 100
rate = 0.01
bandwidth = 0.1
p = 0.5
X0 = model._newDrawFromPrior(n_particles)
key = jax.random.PRNGKey(0)
# bandwidth and p must be forwarded explicitly, otherwise the sampler silently uses its defaults
sam = ula_sampler_full_jax_jit(key, jax.vmap(model.potential), jax.vmap(jax.jacfwd(model.potential)), n_iter, eps, X0, model.lower_bound, model.upper_bound, stride, rate, bandwidth=bandwidth, p=p)

In [None]:
# Plot: MLBD samples (red) against the rejection-sampled reference (black)
import matplotlib.lines as mlines

# Flatten (iteration, particle, dim) -> (iteration * particle, dim)
reshaped_matrix = np.array(sam.reshape(-1, sam.shape[-1]))
labels = [r'$x_1$', r'$x_2$']
k_line = mlines.Line2D([], [], color='k', label='Truth')
r_line = mlines.Line2D([], [], color='r', label='MLBD')

# Separate hist_kwargs dicts on purpose: corner writes a colour into the dict it is given
fig = corner.corner(bounded_iid_samples[-20000:], hist_kwargs={'density': True}, truths=jnp.mean(bounded_iid_samples, axis=0), color='k')
corner.corner(reshaped_matrix[-20000:], color='r', fig=fig, hist_kwargs={'density': True}, labels=labels)
plt.legend(handles=[k_line, r_line], bbox_to_anchor=(0., 1.0, 1., .0), loc=4)

In [None]:
""" 
Remarks
-------
(1) 15d MoG with all modes of similar weight
"""

from models.mog_new import MoG 

k = 3
d = 15
weights = jnp.array([5, 4, 5])

mus = jnp.zeros((k, d))

mus = mus.at[0].set(jnp.array([-10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
mus = mus.at[1].set(jnp.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
mus = mus.at[2].set(jnp.array([10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))

covs = jnp.zeros((k, d))
covs = covs.at[0].set(jnp.ones(d))
covs = covs.at[1].set(jnp.ones(d))
covs = covs.at[2].set(jnp.ones(d))

lower_bound = jnp.ones(d) * -15 
upper_bound = jnp.ones(d) * 15 

model = MoG(weights, mus, covs, lower_bound, upper_bound)


iid_samples = model.newDrawFromPosterior(1000000)

bounded_iid_samples = rejection_sampling(iid_samples, model.lower_bound, model.upper_bound)

In [None]:
# Setup and run sampler
n_iter = 20000
n_particles = 200
eps = 1e-3
stride = 100
rate = 0.01
bandwidth = 100
p = 2
X0 = model._newDrawFromPrior(n_particles)
key = jax.random.PRNGKey(1)
# bandwidth and p must be forwarded explicitly, otherwise the sampler silently uses its defaults
sam = ula_sampler_full_jax_jit(key, jax.vmap(model.potential), jax.vmap(jax.jacfwd(model.potential)), n_iter, eps, X0, model.lower_bound, model.upper_bound, stride, rate, bandwidth=bandwidth, p=p)

In [None]:
# Plot: MLBD samples (red) against the rejection-sampled reference (black); no axis labels in 15d
import matplotlib.lines as mlines

reshaped_matrix = np.array(sam.reshape(-1, sam.shape[-1]))  # (iteration * particle, dim)
truth_mean = jnp.mean(bounded_iid_samples, axis=0)
k_line = mlines.Line2D([], [], color='k', label='Truth')
r_line = mlines.Line2D([], [], color='r', label='MLBD')
# Separate hist_kwargs dicts on purpose: corner writes a colour into the dict it is given
fig = corner.corner(bounded_iid_samples[-20000:], hist_kwargs={'density': True}, truths=truth_mean, color='k')
corner.corner(reshaped_matrix[-20000:], color='r', fig=fig, hist_kwargs={'density': True})
plt.legend(handles=[k_line, r_line], bbox_to_anchor=(0., 1.0, 1., .0), loc=4)

In [None]:
""" 
Remarks
-------
(1) 15d MoG with some mass at the boundary
"""

from models.mog_new import MoG 

k = 3
d = 15
weights = jnp.array([0.5, 1, 2])

mus = jnp.zeros((k, d))

mus = mus.at[0].set(jnp.array([-15, 0, 0, 9, 0, 0, 0, 5, 0, 0, 0, 12, 0, 7, 0]))
mus = mus.at[1].set(jnp.array([0, 0, 0, -15, 0, 0, 4, 0, 10, 0, 0, 15, 0, 8, 0]))
mus = mus.at[2].set(jnp.array([10, 0, 0, 0, 3, 0, 2, 0, 0, 0, 0, 0, 0, 0, 15]))

covs = jnp.zeros((k, d))
covs = covs.at[0].set(jnp.ones(d))
covs = covs.at[1].set(jnp.ones(d))
covs = covs.at[2].set(jnp.ones(d))

lower_bound = jnp.ones(d) * -15 
upper_bound = jnp.ones(d) * 15 

model = MoG(weights, mus, covs, lower_bound, upper_bound)


iid_samples = model.newDrawFromPosterior(1000000)

bounded_iid_samples = rejection_sampling(iid_samples, model.lower_bound, model.upper_bound)

In [None]:
# Setup and run sampler
n_iter = 50000
n_particles = 500
eps = 1e-2
stride = 100
rate = 0.001
bandwidth = 1
p = 2
X0 = model._newDrawFromPrior(n_particles)
key = jax.random.PRNGKey(1)
# bandwidth and p must be forwarded explicitly, otherwise the sampler silently uses its defaults
sam = ula_sampler_full_jax_jit(key, jax.vmap(model.potential), jax.vmap(jax.jacfwd(model.potential)), n_iter, eps, X0, model.lower_bound, model.upper_bound, stride, rate, bandwidth=bandwidth, p=p)

In [None]:
# Plot: MLBD samples (red) against the rejection-sampled reference (black); no axis labels in 15d
import matplotlib.lines as mlines

reshaped_matrix = np.array(sam.reshape(-1, sam.shape[-1]))  # (iteration * particle, dim)
truth_mean = jnp.mean(bounded_iid_samples, axis=0)
k_line = mlines.Line2D([], [], color='k', label='Truth')
r_line = mlines.Line2D([], [], color='r', label='MLBD')
# Separate hist_kwargs dicts on purpose: corner writes a colour into the dict it is given
fig = corner.corner(bounded_iid_samples[-20000:], hist_kwargs={'density': True}, truths=truth_mean, color='k')
corner.corner(reshaped_matrix[-20000:], color='r', fig=fig, hist_kwargs={'density': True})
plt.legend(handles=[k_line, r_line], bbox_to_anchor=(0., 1.0, 1., .0), loc=4)

# Truncated Gaussian Kernel tests

In [None]:
# def multivariate_normal(X, mu, sigma):
#     separation_vectors = X[:, jnp.newaxis, :] - mu[jnp.newaxis, :, :]
#     arg_exp = 

#     -jnp.sum((jnp.abs(separation_vectors) ** p) / (p * bandwidth), axis=-1)
#     tmp = jnp.exp(-(separation_vectors ** 2) / (2 * sigma ** 2)) / (sigma * jnp.sqrt(2 * jnp.pi))

In [None]:
import jax
import jax.numpy as jnp

def indicator(x, a, b):
    """1 where every coordinate of a row of ``x`` lies in the closed box [a, b], else 0.

    The per-coordinate indicators are multiplied over the last axis.
    """
    inside = jnp.heaviside(x - a, 1) * jnp.heaviside(b - x, 1)
    return jnp.prod(inside, axis=-1)

def F(arg, mu, sigma):
    """Elementwise univariate Gaussian CDF of N(mu, sigma**2) at ``arg`` (trivially extends to d > 1)."""
    z = (arg - mu) / (sigma * jnp.sqrt(2))
    return 0.5 * (1 + jax.scipy.special.erf(z))

def trunc_gaussian(x, mu, sigma, a, b):
    """Density of N(mu, diag(sigma**2)) truncated to the box [a, b], evaluated at the rows of ``x``."""
    box_mass = jnp.prod(F(b, mu, sigma) - F(a, mu, sigma), axis=-1)
    untruncated = jax.scipy.stats.multivariate_normal.pdf(x, mu, jnp.diag(sigma ** 2))
    return indicator(x, a, b) * untruncated / box_mass

# Batch over a leading axis of means; output column j corresponds to mu[j]
trunc_gaussian_batch = jax.vmap(trunc_gaussian, in_axes=(None, 0, None, None, None), out_axes=1)

In [None]:
# 1d plots of the truncated Gaussian on [a, b] for three means
a = jnp.array([-1])
b = jnp.array([2])
mu = jnp.array([0, 1, 2])[..., None]
sigma = jnp.ones(len(a))
xs = jnp.linspace(a - 0.5, b + 0.5, 200)       # (200, 1) grid extending past the box
ys = trunc_gaussian_batch(xs, mu, sigma, a, b)  # (200, 3): one column per mean
plt.plot(xs, ys)

# Confirm normalization in 1D. Integrate over [a, b] itself so the jumps at the box
# edges do not bias the trapezoid rule. The grid has a single column, hence xs_box[:, 0]
# (JAX silently clamps out-of-range indices such as [:, 1]).
from scipy.integrate import trapezoid  # `trapz` was removed in SciPy 1.14

xs_box = jnp.linspace(a, b, 2001)
ys_box = trunc_gaussian_batch(xs_box, mu, sigma, a, b)
for i in range(mu.shape[0]):
    mass = trapezoid(ys_box[:, i], xs_box[:, 0])
    print(mass)
    assert abs(mass - 1) < 1e-3, f"column {i} integrates to {mass}"

In [None]:
# Shapes of the 1d truncated-Gaussian inputs defined in the plotting cell above
print(mu.shape, xs.shape, sigma.shape, a.shape, b.shape)

In [None]:
mu.shape

In [None]:
# NOTE(review): `test` is not defined anywhere in the visible notebook, so this cell raises
# NameError on a fresh kernel; the stored output below is presumably from an earlier session.
test(xs, mu, sigma, a, b)

1.0060972

In [49]:
xs.shape

(200, 1)

In [None]:
# NOTE(review): relies on the undefined `test` (NameError on a fresh kernel).
test(xs, mu, sigma, a, b).shape

In [None]:
# Shape experiment: prepend a batch axis to `mu`.
# NOTE(review): also relies on the undefined `test`.
test(xs, mu[None,...], sigma, a, b).shape

In [None]:
print(mu[None,...].shape, xs.shape)

In [None]:
# 2d check: a, b are the box corners, xs is a (200, 2) grid along the box diagonal,
# and trunc_gaussian returns one density per row of xs.
# NOTE(review): rebinds the globals a, b, mu, sigma, xs used by the cells above.
a = jnp.array([-1, -1])
b = jnp.array([2, 2])
mu = jnp.array([0, 0])
sigma = jnp.ones(2)
xs = jnp.linspace(a, b, 200)
trunc_gaussian(xs, mu, sigma, a, b)

In [None]:
# Broadcasting check for multivariate_normal.pdf with a batch of points `a` (3, 2) and a
# batch of means `b` (3, 2); presumably row i of `a` is evaluated under mean row i of `b`.
# NOTE(review): np.random is unseeded, and `a`/`b` are rebound (shadowing the box bounds).
import numpy as np
a = jnp.array(np.random.rand(3, 2))
b = jnp.array(np.random.rand(3, 2))

jax.scipy.stats.multivariate_normal.pdf(a, b, jnp.diag(jnp.ones(2)))

In [None]:
a

In [None]:
import matplotlib.pyplot as plt
a = 0.
b = 3.
# Mean as a length-1 vector: with a scalar mean multivariate_normal.pdf keeps the trailing
# axis of xs, and the product with the (N,) indicator then broadcasts to (N, N).
mu = jnp.array([2.])
sigma = jnp.ones(1)
xs = jnp.linspace(a-1, b+1, 1000)[..., None]  # (1000, 1) evaluation grid
ys = trunc_gaussian(xs, mu, sigma, a, b)      # already batched over the rows of xs

mus = jnp.array([-0.5, 1, 2])[..., None]      # (3, 1): one mean per row

# Batch over the means as well: map axis 0 of `mu` only and broadcast everything else.
# (in_axes=1 would try to map axis 1 of every argument, including the scalar bounds
# a and b, which vmap rejects.)
test_func_batch = jax.vmap(trunc_gaussian, in_axes=(None, 0, None, None, None), out_axes=1)

batch_evals = test_func_batch(xs, mus, sigma, a, b)  # (1000, 3): column j uses mus[j]

# plt.plot(xs, ys.squeeze())

In [None]:
mus.shape

In [None]:
# Indicator of the [a, b] box for grid points shifted right by 5; with the grid above
# these fall outside [a, b]
indicator(xs + 5, a, b).shape

In [None]:
import matplotlib.pyplot as plt

plt.plot(xs, ys.squeeze())

In [None]:
a = jnp.array([-2.])
b = jnp.array([3.])
xs = jnp.linspace(a, b, 100)
mu = jnp.array([3.])
sigma = jnp.array([1.])

# NOTE(review): `test` is not defined anywhere in the visible notebook, so this raises
# NameError on a fresh kernel; the single-mean shapes above match trunc_gaussian's signature.
output = test(xs, mu, sigma, a, b)
# plt.plot(xs, )



In [None]:
print(a, b)

In [None]:
xs.shape

In [None]:
# Batched multivariate normal for diagonal cov
def indicator(X, a, b):
    """1.0 where every coordinate of a row of X lies in the closed box [a, b], else 0.0.

    Reduces over the last axis, matching the `indicator` defined next to `trunc_gaussian`,
    so re-running this cell does not change that function's behaviour.
    """
    return jnp.prod(jnp.heaviside(X - a, 1) * jnp.heaviside(b - X, 1), axis=-1)

def _as_points(Z):
    """View a 1-D array as N points in d = 1; (N, d) arrays are returned unchanged."""
    Z = jnp.asarray(Z)
    return Z[:, jnp.newaxis] if Z.ndim == 1 else Z

def gaussian_batch(X, Y, sigma):
    """
    Pairwise Gaussian densities with diagonal covariance diag(sigma**2).

    X     - (N, d) evaluation points (a 1-D array is treated as N points with d = 1)
    Y     - (M, d) means (same 1-D convention)
    sigma - (d,) per-coordinate standard deviations
    Returns: (N, M) matrix with entry [n, m] = N(X[n]; Y[m], diag(sigma**2))
    """
    X, Y = _as_points(X), _as_points(Y)
    separation_vectors = X[:, jnp.newaxis, :] - Y[jnp.newaxis, :, :]  # (N, M, d)
    per_coordinate = jnp.exp(-(separation_vectors ** 2) / (2 * sigma ** 2)) / (sigma * jnp.sqrt(2 * jnp.pi))
    return jnp.prod(per_coordinate, axis=-1)

# Elementwise univariate Gaussian CDF (erf lives in jax.scipy.special; jnp has no `scipy`)
F = lambda arg, Y, sigma: 0.5 * (1 + jax.scipy.special.erf((arg - Y) / (sigma * jnp.sqrt(2))))

def truncated_gaussian_kernel(X, Y, sigma, a, b):
    """Pairwise Gaussian kernel truncated to the box [a, b] and renormalised per mean.

    Returns an (N, M) matrix: entry [n, m] is the density at X[n] of N(Y[m], diag(sigma**2))
    restricted to [a, b], and 0 when X[n] lies outside the box.
    """
    X, Y = _as_points(X), _as_points(Y)
    inside = indicator(X, a, b)[:, jnp.newaxis]                    # (N, 1)
    box_mass = jnp.prod(F(b, Y, sigma) - F(a, Y, sigma), axis=-1)  # (M,) mass of each Gaussian inside the box
    return inside * gaussian_batch(X, Y, sigma) / box_mass
# In principle, this should agree with the scipy implementation (scipy.stats.truncnorm, coordinate-wise)

In [None]:
# Broadcasting experiments with random arrays of mismatched ranks.
# NOTE(review): np.random is unseeded, and `a`, `b` are rebound (shadowing earlier globals).
import numpy as np
a = jnp.array(np.random.rand(3, 3, 2))
b = jnp.array(np.random.rand(2))
c = jnp.array(np.random.rand(3,2))
# (a / b).shape

In [None]:
print(a.shape, b.shape)

In [None]:
# (3, 2) against (2,): b broadcasts across the rows of c
(c-b) / b

In [None]:
len(sigma)

In [None]:
# NOTE(review): rebinds the globals mu and sigma used by the truncated-Gaussian cells.
import jax
import jax.numpy as jnp
mu = jnp.array([2, 3])
sigma = jnp.array([0.5, 0.1])
x = jnp.array([[1, 1], [2,2], [3,3]])
y = jnp.array([[0.1, 0.2], [0.1,0.2], [0.3,0.4]])
# jax.scipy.stats.norm.pdf(x, mu, sigma)
# Elementwise normal pdf: x and y are both (3, 2); sigma[None, ...] is (1, 2) and broadcasts over rows
jax.scipy.stats.norm.pdf(x, y, sigma[None, ...])

In [None]:
print(x.shape, y.shape)

# GW150914

In [None]:
from models.gw150914 import gwfast_LVGW150914

# Initialize model (use wf_model='TaylorF2' for the alternative waveform)
model = gwfast_LVGW150914(wf_model='IMRPhenomD', verbose=True, nbins=100)

# Optional: centre the periodic coordinates on their injected values and copy the shifted
# prior ranges into the sampler bounds (indices: phi -> 4, psi -> 6, Phicoal -> 8).
# for param in ['Phicoal', 'psi', 'phi']:
#     x = model.injParams[param][0]
#     delta = x - (model.priorDict[param][1] + model.priorDict[param][0]) / 2
#     model.priorDict[param][0] += delta
#     model.priorDict[param][1] += delta
#
# model.lower_bound = model.lower_bound.at[4].set(model.priorDict['phi'][0])
# model.upper_bound = model.upper_bound.at[4].set(model.priorDict['phi'][1])
# model.lower_bound = model.lower_bound.at[8].set(model.priorDict['Phicoal'][0])
# model.upper_bound = model.upper_bound.at[8].set(model.priorDict['Phicoal'][1])
# model.lower_bound = model.lower_bound.at[6].set(model.priorDict['psi'][0])
# model.upper_bound = model.upper_bound.at[6].set(model.priorDict['psi'][1])

In [None]:
# Setup and run sampler
n_iter = 20000
n_particles = 200

# Per-coordinate step sizes
eps = 1e-6 * jnp.ones(model.DoF)
eps = eps.at[1].set(1e-5)
eps = eps.at[7].set(1e-5)
eps = eps.at[9].set(1e-7)
eps = eps.at[10].set(1e-7)

# stride > n_iter: birth/death fires only when iteration % stride == 0, i.e. just once at
# iteration 0 (the sampler's counter starts at 0), so this is essentially plain mirrored Langevin.
stride = n_iter + 1
rate = 1e-6
bandwidth = 100
p = 2
X0 = model._newDrawFromPrior(n_particles)
key = jax.random.PRNGKey(0)

# TODO: per-dimension bandwidth (if birth_death supports an array), e.g.
# bandwidth = jnp.ones(model.DoF) * 100 with bandwidth.at[9].set(1) and bandwidth.at[10].set(1).

# bandwidth and p must be forwarded explicitly, otherwise the sampler silently uses its defaults
sam = ula_sampler_full_jax_jit(key, model.minusLogLikelihood, model.gradient_minusLogLikelihood, n_iter, eps, X0, model.lower_bound, model.upper_bound, stride, rate, bandwidth=bandwidth, p=p)

In [None]:
# Flatten (iteration, particle, dim) -> (iteration * particle, dim); the dimension is read
# from the sampler output instead of a hard-coded 11
reshaped_matrix = np.array(sam.reshape((sam.shape[0] * sam.shape[1], sam.shape[-1])))
fig = corner.corner(reshaped_matrix[-5000:], hist_kwargs={'density':True}, labels=model.gwfast_param_order, truths=model.true_params)

In [None]:

n_iter = 200000
n_particles = 200

# Coordinates (in gwfast order) that are sampled; all others stay at their injected values
subset_params = jnp.array([0, 1, 2, 3, 5, 6, 7, 9, 10])
lower = model.lower_bound[subset_params]
upper = model.upper_bound[subset_params]

# Per-coordinate step sizes, sized from the subset instead of a hard-coded 9
eps = jnp.ones(len(subset_params)) * 1e-7
eps = eps.at[1].set(1e-5)

# injection wrapper: injected value of every parameter, in gwfast order.
# np.squeeze accepts both scalars and length-1 arrays for the injected values; a
# comprehension avoids clobbering the global loop variable `d`.
DoF = len(model.gwfast_param_order)
injection = np.array([float(np.squeeze(model.injParams[name])) for name in model.gwfast_param_order])

# One copy of the injection per particle; the sampled columns are overwritten on every call
X_injection = np.tile(injection, (n_particles, 1))
X_injection_gpu = jnp.array(X_injection)

# Initial draw from the full prior, restricted to the sampled coordinates
X0_all = model._newDrawFromPrior(n_particles)
X0_subset = X0_all[:, subset_params]

# Fix eta coordinate for testing!
# X0_subset = X0_subset.at[:,1].set(jnp.ones(n_particles) * injection[1])

def potential_subset(X_red):
    """Potential with only the `subset_params` columns free.

    X_red must have n_particles rows (one per particle) and len(subset_params) columns.
    """
    X_ = X_injection_gpu.at[:, subset_params].set(X_red)
    return model.minusLogLikelihood(X_)

def gradient_subset(X_red):
    """Full-model gradient at the padded parameters, restricted to the sampled columns."""
    X_ = X_injection_gpu.at[:, subset_params].set(X_red)
    return model.gradient_minusLogLikelihood(X_)[:, subset_params]
    


In [None]:
import matplotlib.pyplot as plt
n_iter = 100000
# Visualise the annealing schedule. `gamma` takes only the iteration index (no T/c/p
# keywords), so evaluate it pointwise over the iterations.
iterations = np.arange(n_iter)
plt.plot(iterations, [gamma(t) for t in iterations])

In [None]:
from jax_tqdm import scan_tqdm

In [None]:
# Settings for run
n_iter = 100000
n_particles = 200
eps = 1e-6
X0 = model._newDrawFromPrior(n_particles)
key = jax.random.PRNGKey(0)

# Alternative per-coordinate step sizes:
# eps = jnp.ones(11) * 5e-7
# eps = eps.at[1].set(5e-5)
# eps = eps.at[4].set(1e-4)
# eps = eps.at[8].set(1e-4)

# Other targets (both also need the lower/upper bound arguments):
#   MoG:    ula_sampler_full_jax_jit(key, jax.vmap(model.potential), jax.vmap(jax.jacfwd(model.potential)), n_iter, eps, X0, model.lower_bound, model.upper_bound)
#   subset: ula_sampler_full_jax_jit(key, potential_subset, gradient_subset, n_iter, eps, X0_subset, lower, upper)

# Whole GW150914 parameter space.
# NOTE(review): stride/rate/bandwidth/p are not passed, so the sampler's defaults apply
# (stride=1 means birth/death at every iteration).
sam = ula_sampler_full_jax_jit(key, model.minusLogLikelihood, model.gradient_minusLogLikelihood, n_iter, eps, X0, model.lower_bound, model.upper_bound)

In [None]:
# Flatten (iteration, particle, dim) -> (iteration * particle, dim); the dimension is read
# from the sampler output instead of a hard-coded 11
reshaped_matrix = np.array(sam.reshape((sam.shape[0] * sam.shape[1], sam.shape[-1])))
fig = corner.corner(reshaped_matrix[-50000:], hist_kwargs={'density':True}, labels=model.gwfast_param_order, truths=model.true_params)

In [None]:
# Continue sampling from the final ensemble of the previous run. Split the key first:
# reusing the same key would replay exactly the same noise sequence as that run.
key, restart_key = jax.random.split(key)
sam = ula_sampler_full_jax_jit(restart_key, model.minusLogLikelihood, model.gradient_minusLogLikelihood, n_iter, eps, sam[-1], model.lower_bound, model.upper_bound)

In [None]:
# Gradient check: compare the model's hand-written gradient with JAX autodiff on a few prior draws
X = model._newDrawFromPrior(3)

# Francesco derivative (analytic gradient supplied by the model)
test1 = model.gradient_minusLogLikelihood(X)

# Autodiff reference on the whole batch. Assuming minusLogLikelihood maps an (N, DoF) batch
# to N values (as the sampler uses it), jacfwd gives an (N, N, DoF) Jacobian whose diagonal
# [i, i, :] is particle i's gradient. Defined here so later cells using `f` / `test2` run in order.
f = jax.jit(jax.jacfwd(model.minusLogLikelihood))
test2 = f(X)

In [None]:
test1

In [None]:
# Compare the diagonal [i, i, :] of the autodiff Jacobian with the model's gradient `test1`
# for the 3 drawn particles.
# NOTE(review): relies on `test2` already existing; it is (re)assigned in the cell below,
# so this depends on out-of-order execution on a fresh kernel.
np.allclose(test2[np.arange(3), np.arange(3), :], test1)

In [None]:
# NOTE(review): `f` (presumably the autodiff Jacobian of model.minusLogLikelihood) must already be defined.
test2 = f(X)

In [None]:
# NOTE(review): `a` is whichever scratch array was bound last in the truncated-Gaussian experiments.
a.shape

In [None]:
print(model.gwfast_param_order)

In [None]:
sam.shape

In [None]:
# Flatten (iteration, particle, dim) -> (iteration * particle, dim); the dimension is read
# from the sampler output instead of a hard-coded 11
reshaped_matrix = np.array(sam.reshape((sam.shape[0] * sam.shape[1], sam.shape[-1])))
fig = corner.corner(reshaped_matrix[-50000:], hist_kwargs={'density':True}, labels=model.gwfast_param_order, truths=model.true_params)

In [None]:
# Continue the subset run from its final ensemble. The subset bounds are required positional
# arguments, and the key is advanced so the previous noise sequence is not replayed.
# NOTE(review): sam[-1] must have len(subset_params) columns, i.e. come from a subset run.
key, restart_key = jax.random.split(key)
sam = ula_sampler_full_jax_jit(restart_key, potential_subset, gradient_subset, n_iter, eps, sam[-1], lower, upper)

In [None]:
# Corner plot of the sampled subset against the injected values
labels = np.array(model.gwfast_param_order)[subset_params]
reshaped_matrix = np.array(sam.reshape((n_iter * n_particles, len(subset_params))))  # (samples, n_subset)
fig = corner.corner(reshaped_matrix[-50000:], hist_kwargs={'density': True}, labels=labels, truths=injection[subset_params])

In [None]:
import corner
# Overlay the sampler output (black) on the rejection-sampled reference (red); the original
# comments refer to a Rosenbrock run.
# NOTE(review): after the GW cells `model` is the GW model, while `bounded_iid_samples`
# still comes from the MoG cells above.
reshaped_matrix = np.array(sam.reshape((n_iter * n_particles, model.DoF)))
fig = corner.corner(reshaped_matrix[-20000:], hist_kwargs={'density':True}, truths=jnp.mean(bounded_iid_samples, axis=0))
corner.corner(bounded_iid_samples[-20000:], color='r', fig=fig, hist_kwargs={'density':True})


# GW PROBLEM