In [1]:
import numpy as np
import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp
import jax
from tensorflow_probability.substrates.jax import distributions as tfd
numpyro.set_host_device_count(4)

In [2]:
import os
import pickle
from pathlib import Path

# Location of the pickled inputs. The default is the original development
# path; set the CAD_VARS_PKL environment variable to run the notebook on a
# different machine without editing this cell.
VARS_PKL_PATH = Path(
    os.environ.get(
        "CAD_VARS_PKL",
        "/Users/christopher/git/ComputableAstronomicalDiaries/PythonExperiments/vars.pkl",
    )
)

# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only point VARS_PKL_PATH at a file you produced yourself or fully trust.
with VARS_PKL_PATH.open("rb") as f:
    loaded_variables = pickle.load(f)

# Expose every stored entry as a notebook-level variable. The assignments are
# kept explicit (rather than globals().update) so the names stay greppable
# and a missing key fails with a KeyError naming exactly what is absent.
n_objects = loaded_variables["n_objects"]
n_references = loaded_variables["n_references"]
n_times = loaded_variables["n_times"]
time_range = loaded_variables["time_range"]
cubits = loaded_variables["cubits"]
cubits_mask = loaded_variables["cubits_mask"]
objects = loaded_variables["objects"]
objects_mask = loaded_variables["objects_mask"]
references = loaded_variables["references"]
references_mask = loaded_variables["references_mask"]
axes = loaded_variables["axes"]
axes_mask = loaded_variables["axes_mask"]
signs = loaded_variables["signs"]
signs_mask = loaded_variables["signs_mask"]
years = loaded_variables["years"]
years_mask = loaded_variables["years_mask"]
months = loaded_variables["months"]
months_mask = loaded_variables["months_mask"]
earliest_days = loaded_variables["earliest_days"]
latest_days = loaded_variables["latest_days"]
year_month_julian_dates = loaded_variables["year_month_julian_dates"]
julian_date_index = loaded_variables["julian_date_index"]
object_reference_axis_date_positions = loaded_variables["object_reference_axis_date_positions"]
times = loaded_variables["times"]
times_mask = loaded_variables["times_mask"]

In [3]:
from numpyro.ops.indexing import Vindex
def model():
    """NumPyro model over all observations, with latent categories for masked fields.

    Global priors are drawn for the cubit length, the inlier scale, the outlier
    component, the outlier probability ``q`` and one Dirichlet per categorical
    field (object, reference, axis, sign). Inside the ``observations`` plate each
    categorical field is sampled from a one-hot distribution on its recorded
    value, except at positions flagged by the matching ``*_mask`` array, where
    the shared Dirichlet draw is used instead. The recorded ``cubits`` are then
    scored against a two-component Normal mixture (inlier / outlier); positions
    flagged by ``cubits_mask`` are excluded from the likelihood.

    Reads notebook-level globals loaded from the pickle; takes no arguments and
    returns nothing (it only registers NumPyro sample sites).
    """

    def sample_partially_observed(site, recorded, missing, n_categories, prior):
        # One-hot on the recorded category; rows flagged `missing` fall back
        # to the shared Dirichlet draw `prior`.
        probs = jax.nn.one_hot(recorded, n_categories).at[missing, :].set(prior)
        return numpyro.sample(site, dist.Categorical(probs=probs))

    n_obs = len(objects)

    # --- global parameters -------------------------------------------------
    length_cubit = numpyro.sample(
        'length_cubit', dist.TruncatedNormal(2.0, 1.0, low=0.0)
    )
    # NOTE(review): named "variance" but passed to Normal as its scale below.
    distance_variance = numpyro.sample('distance_variance', dist.Gamma(0.5, 0.5))

    mu_outlier = numpyro.sample('mu_outlier', dist.Normal(0, 1))
    sigma_outlier = numpyro.sample('sigma_outlier', dist.Gamma(2.0, 0.5))

    object_dist = numpyro.sample('object_dist', dist.Dirichlet(jnp.ones(n_objects)))
    reference_dist = numpyro.sample('reference_dist', dist.Dirichlet(jnp.ones(n_references)))
    axis_dist = numpyro.sample('axis_dist', dist.Dirichlet(jnp.ones(2)))
    sign_dist = numpyro.sample('sign_dist', dist.Dirichlet(jnp.ones(2)))

    q = numpyro.sample('q', dist.Beta(0.5, 1))

    # --- per-observation part ----------------------------------------------
    with numpyro.plate('observations', n_obs):
        latent_object = sample_partially_observed(
            'latent_object', objects, objects_mask, n_objects, object_dist
        )
        latent_reference = sample_partially_observed(
            'latent_reference', references, references_mask, n_references, reference_dist
        )
        latent_axis = sample_partially_observed(
            'latent_axis', axes, axes_mask, 2, axis_dist
        )
        sign_id = sample_partially_observed(
            'latent_sign', signs, signs_mask, 2, sign_dist
        )
        # category id 0/1 -> multiplicative sign -1/+1
        latent_sign = sign_id * 2 - 1

        # Index into the date axis of the position table, always using the
        # earliest candidate day of each observation.
        julian_offset = year_month_julian_dates[years, months] + earliest_days - 1
        date = julian_date_index[julian_offset]

        time = numpyro.sample('time', dist.Normal(0.0, 6))

        # Vindex keeps the lookup valid when the latent category sites carry
        # extra (enumeration) dimensions.
        distance_range = Vindex(object_reference_axis_date_positions)[
            latent_object, latent_reference, latent_axis, date
        ]
        start_distance = distance_range[..., 0]
        end_distance = distance_range[..., 1]

        # Straight line through (time_range[0], start) and (time_range[1], end),
        # evaluated at the sampled `time`.
        slope = (end_distance - start_distance) / (time_range[1] - time_range[0])
        true_distance = (slope * (time - time_range[1])) + end_distance

        # Two-component mixture: index 0 = inlier, index 1 = outlier.
        component_weights = dist.Categorical(probs=jnp.array([1 - q, q]))
        components = [
            dist.Normal(latent_sign * true_distance / length_cubit, distance_variance),
            dist.Normal(mu_outlier, sigma_outlier),
        ]
        mix = dist.Mixture(component_weights, components)

        # Positions flagged in cubits_mask contribute nothing to the likelihood.
        with numpyro.handlers.mask(mask=jnp.logical_not(cubits_mask)):
            numpyro.sample('c', mix, obs=cubits)

## Skellam experiments

In [10]:
class Skellam(dist.Distribution):
    """Skellam distribution: difference of two independent Poisson variables.

    ``log_prob`` is delegated to ``tfd.Skellam``; sampling draws one Poisson
    variate per rate and subtracts them.

    Args:
        u1: rate of the Poisson that is counted positively.
        u2: rate of the Poisson that is counted negatively.

    Both rates are broadcast to the common batch shape up front. Without
    that, ``sample`` produced two draws of different shapes whenever ``u1``
    and ``u2`` had different shapes, and the subtraction then lined the
    smaller draw up against the wrong (trailing) axes.

    NOTE(review): no ``support`` is declared on this class, so NumPyro code
    that inspects ``support`` / ``is_discrete`` may not work with it.
    """

    def __init__(self, u1, u2):
        batch_shape = jax.lax.broadcast_shapes(jnp.shape(u1), jnp.shape(u2))
        rate1 = jnp.broadcast_to(jnp.asarray(u1, dtype=float), batch_shape)
        rate2 = jnp.broadcast_to(jnp.asarray(u2, dtype=float), batch_shape)
        # dtype in which values are handed to the TFP log-pmf
        self._value_dtype = rate1.dtype
        self._tf_skellam = tfd.Skellam(rate1, rate2)
        self._poisson1 = dist.Poisson(rate1)
        self._poisson2 = dist.Poisson(rate2)
        super().__init__(batch_shape)

    def log_prob(self, value):
        """Return the Skellam log-pmf at ``value`` (broadcast against the batch shape)."""
        # Draws from `sample` are integers; cast so they match the rate dtype.
        return self._tf_skellam.log_prob(jnp.asarray(value, dtype=self._value_dtype))

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape``."""
        k1, k2 = jax.random.split(key)
        positive = self._poisson1.sample(k1, sample_shape=sample_shape)
        negative = self._poisson2.sample(k2, sample_shape=sample_shape)
        return positive - negative


In [None]:
class TruncatedSkellam(dist.Distribution):
    """Skellam distribution restricted to the integers ``low..high`` (inclusive).

    The pmf is the ``tfd.Skellam`` pmf renormalised over the window, so it
    sums to one on ``low..high`` and is zero elsewhere.

    Args:
        u1: rate of the Poisson that is counted positively.
        u2: rate of the Poisson that is counted negatively.
        low: smallest supported integer (static Python int, default -2).
        high: largest supported integer (static Python int, default 2).

    Fixes relative to the previous draft of this class: the base class was
    misspelled, ``low``/``high`` were written in the class header instead of
    ``__init__``, ``super()`` named a different class, and neither
    ``log_prob`` nor ``sample`` applied any truncation.
    """

    def __init__(self, u1, u2, low=-2, high=2):
        if high < low:
            raise ValueError("TruncatedSkellam requires low <= high")
        batch_shape = jax.lax.broadcast_shapes(jnp.shape(u1), jnp.shape(u2))
        rate1 = jnp.broadcast_to(jnp.asarray(u1, dtype=float), batch_shape)
        rate2 = jnp.broadcast_to(jnp.asarray(u2, dtype=float), batch_shape)
        self.low = low
        self.high = high
        self._value_dtype = rate1.dtype
        self._tf_skellam = tfd.Skellam(rate1, rate2)

        # Untruncated log-pmf at every integer of the window; a trailing axis
        # is added to the rates so the result has shape
        # batch_shape + (high - low + 1,).
        window = jnp.arange(low, high + 1).astype(rate1.dtype)
        self._window_log_probs = tfd.Skellam(
            rate1[..., None], rate2[..., None]
        ).log_prob(window)
        # log of the probability mass that the window captures
        self._log_norm = jax.nn.logsumexp(self._window_log_probs, axis=-1)
        super().__init__(batch_shape)

    @dist.constraints.dependent_property(is_discrete=True, event_dim=0)
    def support(self):
        return dist.constraints.integer_interval(self.low, self.high)

    def log_prob(self, value):
        """Return the renormalised log-pmf; ``-inf`` outside ``low..high``."""
        value = jnp.asarray(value, dtype=self._value_dtype)
        inside = (value >= self.low) & (value <= self.high)
        # Evaluate at a clipped value so out-of-window inputs cannot inject
        # NaNs into the branch that jnp.where discards.
        clipped = jnp.clip(value, self.low, self.high)
        log_p = self._tf_skellam.log_prob(clipped) - self._log_norm
        return jnp.where(inside, log_p, -jnp.inf)

    def sample(self, key, sample_shape=()):
        """Draw integers in ``low..high`` with shape ``sample_shape + batch_shape``."""
        # Categorical draw over the window (logits need not be normalised),
        # then shift the 0-based index back onto low..high.
        index = jax.random.categorical(
            key,
            self._window_log_probs,
            axis=-1,
            shape=tuple(sample_shape) + self.batch_shape,
        )
        return index + self.low


In [44]:
# Experiment: emulate a Skellam(0.5, 1.2) truncated to -2..2 by feeding its pmf
# at those five integers into a Categorical, drawing 1000 indices (0..4) and
# shifting them by -2.
# NOTE(review): the five pmf values are passed as-is and do not sum to 1 (mass
# outside -2..2 is dropped) -- confirm Categorical tolerates unnormalised probs.
dist.Categorical(tfd.Skellam(0.5,1.2).prob(jnp.arange(-2,3,dtype=float))).sample(jax.random.PRNGKey(1234), sample_shape=(1000,)) - 2

Array([ 0, -1, -1, -1, -1,  0, -1, -1, -1,  0,  0, -1,  0, -2, -2, -1,  1,
        0,  0, -2, -2,  0, -1,  0,  2, -2,  1,  1, -1, -2, -2, -2, -1,  1,
        1,  0, -2, -1,  0, -2,  2,  0,  0,  2,  0,  0,  1,  0,  0, -1,  0,
       -2, -1, -1, -1,  0, -2,  0, -1, -1,  0, -1,  1, -2, -1,  2, -2,  1,
       -1,  0, -1, -1,  0, -2, -2, -1, -1, -1,  0,  1, -1, -2,  2,  0, -1,
       -1,  2,  0,  0,  0, -1,  0,  2, -2, -2,  0,  2,  0, -1, -2,  0, -1,
        0,  0, -2,  0, -1,  1, -1,  1,  0,  0,  0,  0, -1,  1,  1, -1,  0,
        2, -2, -2, -1,  0, -2, -2, -1, -1,  1,  1,  0,  0,  0,  0,  0,  0,
        1, -2,  0,  0,  0,  1,  0, -1,  0,  0,  0,  0,  0, -2, -1,  1,  0,
        0, -2, -1,  0,  0,  0, -2,  0,  0,  0, -1,  0,  2, -1, -2,  1,  1,
       -2,  0,  0,  1, -1,  0, -1, -2, -1,  1, -1,  0, -1, -2,  0,  0,  0,
       -1,  0, -1,  0, -1,  1,  0,  0, -1,  1, -1, -1,  0,  2, -1, -1,  1,
       -2,  0,  1, -2,  0, -1, -1, -1, -2, -2, -1, -1,  0, -1,  0, -1,  1,
       -2,  1, -1, -2,  1

In [23]:
# Experiment: draw 1000 samples from a Categorical given the weights
# [0.33, 0.66] (which sum to 0.99, not 1) and add the drawn indices up.
# NOTE(review): presumably a check of how unnormalised probs are handled.
dist.Categorical(jnp.array([0.33,0.66])).sample(jax.random.PRNGKey(1234), sample_shape=(1000,)).sum()

Array(677, dtype=int32)

In [3]:
# Synthetic data for the Skellam experiments: 1000 integer draws formed as
# Poisson(1.2) - Poisson(3.4), each then observed through Normal(., 1) noise.
# Every draw uses its own fixed PRNG key so the cell is reproducible.
data_sk = (
    dist.Poisson(1.2).sample(jax.random.PRNGKey(1234), sample_shape=(1000,))
    - dist.Poisson(3.4).sample(jax.random.PRNGKey(4321), sample_shape=(1000,))
)
data = dist.Normal(data_sk, 1).sample(jax.random.PRNGKey(5255))

In [5]:
def model():
    """Toy model: Gamma-distributed rates, one latent Skellam draw, Normal noise.

    Registers the sites ``u1`` and ``u2`` (Gamma(0.5, 0.5) priors on the two
    Skellam rates), ``sk`` (a single latent ``Skellam(u1, u2)`` draw) and
    ``data`` (the notebook-level ``data`` array, observed as Normal(sk, 1)).
    """
    rate_plus = numpyro.sample('u1', dist.Gamma(0.5, 0.5))
    rate_minus = numpyro.sample('u2', dist.Gamma(0.5, 0.5))

    latent_count = numpyro.sample('sk', Skellam(rate_plus, rate_minus))

    numpyro.sample('data', dist.Normal(latent_count, 1), obs=data)

In [70]:
def gibbs_fn(rng_key, gibbs_sites, hmc_sites, half_width=10):
    """Gibbs update for the discrete site ``sk`` given the current HMC sites.

    HMCGibbs expects this callable to return a dict holding a new value for
    every Gibbs site; the previous body returned ``None``.

    The conditional of ``sk`` is proportional to
    ``Skellam(sk; u1, u2) * prod_i Normal(data_i; sk, 1)``. It is evaluated on
    a fixed grid of ``2 * half_width + 1`` integers centred on the rounded
    mean of the notebook-level ``data`` array, and a new value is drawn from
    that grid. The grid does not depend on the current state, so this is an
    exact draw from the conditional restricted to the grid; mass outside the
    grid is ignored.

    Args:
        rng_key: PRNG key for the draw.
        gibbs_sites: current values of the Gibbs sites (only ``'sk'`` is read,
            for its shape and dtype). Assumes ``sk`` is a scalar site, as in
            the toy Skellam model of this notebook.
        hmc_sites: current values of the HMC sites; ``'u1'`` and ``'u2'`` are read.
        half_width: number of integers on each side of the grid centre.

    Returns:
        ``{'sk': new_value}`` with the same shape and dtype as the current ``sk``.
    """
    u1 = hmc_sites['u1']
    u2 = hmc_sites['u2']
    current = jnp.asarray(gibbs_sites['sk'])

    centre = jnp.round(jnp.mean(data))
    candidates = centre + jnp.arange(-half_width, half_width + 1, dtype=float)

    # log prior of each candidate under the Skellam, shape (K,)
    log_prior = Skellam(u1, u2).log_prob(candidates)
    # log likelihood of all data points for each candidate, summed to shape (K,)
    log_lik = dist.Normal(candidates[:, None], 1).log_prob(data).sum(axis=-1)

    # categorical accepts unnormalised logits
    choice = jax.random.categorical(rng_key, log_prior + log_lik)
    new_sk = jnp.reshape(candidates[choice], jnp.shape(current)).astype(current.dtype)
    return {'sk': new_sk}

# NUTS for the continuous sites, user-supplied Gibbs step for the discrete
# site 'sk'; single chain, 1000 warm-up + 4000 kept draws, fixed seed.
kernel = numpyro.infer.HMCGibbs(
    numpyro.infer.NUTS(model),
    gibbs_fn=gibbs_fn,
    gibbs_sites=['sk'],
)
mcmc = numpyro.infer.MCMC(
    kernel,
    num_warmup=1000,
    num_samples=4000,
    num_chains=1,
    progress_bar=True,
)
mcmc.run(jax.random.PRNGKey(1234))

SyntaxError: expected argument value expression (2938217441.py, line 1)

In [18]:
# Smoke test of the custom Skellam wrapper: log-pmf of Skellam(0.2, 0.3) at
# 3, 0 and 3 (the repeated value should give the same number twice).
Skellam(0.2,0.3).log_prob(jnp.array([3.0,0.,3.]))

Array([-7.1050954, -0.4408767, -7.1050954], dtype=float32)

In [30]:
# Experiment: wrap a TFP Normal in NumPyro's TruncatedDistribution. The
# recorded output for this cell is an AssertionError.
# NOTE(review): presumably the base must be a NumPyro distribution -- the same
# call with dist.Normal in the following cell succeeds.
dist.TruncatedDistribution(tfd.Normal(0,1), low=-1,high=1)

AssertionError: 

In [23]:
# Same truncation to [-1, 1], this time with a NumPyro Normal as the base;
# the recorded output is a TwoSidedTruncatedDistribution instance.
dist.TruncatedDistribution(dist.Normal(0,1), low=-1,high=1)

<numpyro.distributions.truncated.TwoSidedTruncatedDistribution at 0x38426c980>

In [4]:
# NUTS for the continuous sites combined with NumPyro's built-in Gibbs updates
# for discrete sites (modified=False); single chain, 1000 warm-up + 4000 kept
# draws, fixed seed. Runs whichever `model` was defined most recently.
kernel = numpyro.infer.DiscreteHMCGibbs(
    numpyro.infer.NUTS(model),
    modified=False,
)
mcmc = numpyro.infer.MCMC(
    kernel,
    num_warmup=1000,
    num_samples=4000,
    num_chains=1,
    progress_bar=True,
)
mcmc.run(jax.random.PRNGKey(1234))

warmup:   0%|          | 14/5000 [23:54<141:57:11, 102.49s/it, 1023 steps of size 3.88e-03. acc. prob=0.62]


KeyboardInterrupt: 

In [35]:
# Plain NUTS on whichever `model` was defined most recently; single chain,
# 1000 warm-up + 4000 kept draws, fixed seed.
kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(
    kernel,
    num_warmup=1000,
    num_samples=4000,
    num_chains=1,
    progress_bar=True,
)
mcmc.run(jax.random.PRNGKey(1234))

  mcmc.run(jax.random.PRNGKey(1234))
  mcmc.run(jax.random.PRNGKey(1234))
  mcmc.run(jax.random.PRNGKey(1234))
  mcmc.run(jax.random.PRNGKey(1234))
warmup:   0%|          | 4/5000 [00:03<44:06,  1.89it/s, 1 steps of size 1.07e-03. acc. prob=0.00]  

: 

: 

In [None]:
import datetime
# Wall-clock timestamp taken when this cell runs.
# NOTE(review): `now` is not read anywhere in the visible part of the notebook.
now = datetime.datetime.now()

## Complete observations only

In [5]:
# An observation is kept only if none of the following holds:
#   * its object / reference / axis / sign / cubits entry is flagged by the
#     matching *_mask array,
#   * its earliest and latest candidate days differ (date not pinned down),
#   * its time entry is flagged by times_mask.
# The flags are OR-ed together one at a time, then inverted.
# NOTE(review): years_mask and months_mask are not part of this test -- confirm
# that is intended.
_any_problem = jnp.logical_or(objects_mask, references_mask)
_any_problem = jnp.logical_or(_any_problem, axes_mask)
_any_problem = jnp.logical_or(_any_problem, signs_mask)
_any_problem = jnp.logical_or(_any_problem, cubits_mask)
_any_problem = jnp.logical_or(_any_problem, earliest_days != latest_days)
_any_problem = jnp.logical_or(_any_problem, times_mask)

complete_dated_timed_mask = jnp.logical_not(_any_problem)
del _any_problem  # scratch name only; keep the notebook namespace unchanged

In [10]:
# Fraction of observations that pass the completeness filter.
complete_dated_timed_mask.mean()

Array(0.66820824, dtype=float32)

In [29]:
# select masked observations
mask = complete_dated_timed_mask
s_objects = objects[mask]
s_references = references[mask]
s_axes = axes[mask]
s_signs = signs[mask]
s_cubits = cubits[mask]
s_years = years[mask]
s_months = months[mask]
s_days = earliest_days[mask]
s_times = times[mask]

def model():
    """NumPyro model restricted to fully recorded, dated and timed observations.

    All categorical fields are taken as given (the ``s_*`` arrays), so the only
    latent quantities are the global parameters, one offset mean/scale per
    time category (``o_means`` / ``o_vars``) and a per-observation time
    ``tau``. The signed cubit measurement is scored against a two-component
    Normal mixture (index 0 = inlier, index 1 = outlier), and the per-component
    posterior log-responsibilities are recorded as the deterministic site ``m``.

    Bug fix: the observed value used to be ``s_signs * s_cubits``. ``signs``
    holds 0/1 category ids (the full model in this notebook one-hot encodes it
    into two classes and converts ids to -1/+1 via ``id*2-1``), so multiplying
    by the raw id zeroed every sign-0 measurement instead of negating it. The
    id is now mapped to -1/+1 first.

    Reads notebook-level globals; takes no arguments and returns nothing.
    """
    # number of observations
    n = len(s_objects)

    # length of a cubit
    length_cubit = numpyro.sample('length_cubit', dist.TruncatedNormal(2.0, 1.0, low=0.0))

    # observation variance
    # NOTE(review): used below as the Normal *scale*, despite the name.
    distance_variance = numpyro.sample('distance_variance', dist.Gamma(0.5,0.5))

    # outlier distribution parameters
    mu_outlier = numpyro.sample('mu_outlier', dist.Normal(0,1))
    sigma_outlier = numpyro.sample('sigma_outlier', dist.Gamma(2.0,0.5))

    # outlier probability prior
    q = numpyro.sample('q', dist.Beta(1/2,1))

    # time offsets for different observation times
    # TODO: Unclear if the .to_event(1) is correct. Maybe can be written with a plate?
    o_means = numpyro.sample('o_means', dist.Normal(jnp.zeros(n_times), jnp.ones(n_times) * 6.0).to_event(1))
    # NOTE(review): o_vars is likewise used as a Normal scale for tau.
    o_vars = numpyro.sample('o_vars', dist.Gamma(jnp.ones(n_times) * 0.5, jnp.ones(n_times) * 0.5).to_event(1))

    # the observation times of different observations
    # (one value per observation, looked up by its time category; sampled
    # outside the observations plate)
    tau = numpyro.sample('tau', dist.Normal(o_means[s_times], o_vars[s_times]))

    observations_plate = numpyro.plate('observations', n)

    # observation dates
    date = julian_date_index[year_month_julian_dates[s_years, s_months] + s_days - 1]

    # true distances at the given observation times
    distance_range = object_reference_axis_date_positions[s_objects, s_references, s_axes, date]

    # straight line through (time_range[0], range[...,0]) and
    # (time_range[1], range[...,1]), evaluated at tau
    true_distance = ((distance_range[...,1]-distance_range[...,0])/(time_range[1] - time_range[0])*(tau-time_range[1]))+distance_range[...,1]

    # outlier mixture model
    cat = dist.Categorical(probs=jnp.array([1-q,q]))
    inlier_dist = dist.Normal(true_distance / length_cubit, distance_variance)
    outlier_dist = dist.Normal(mu_outlier, sigma_outlier)
    mix = dist.Mixture(cat, [inlier_dist, outlier_dist])

    # sign category id 0/1 -> multiplicative sign -1/+1
    signed_cubits = (s_signs * 2 - 1) * s_cubits

    with observations_plate:
        c = numpyro.sample('c', mix, obs=signed_cubits)

        # record m for outlier identification
        log_probs = mix.component_log_probs(c)
        numpyro.deterministic('m', log_probs - jax.nn.logsumexp(log_probs, axis=-1, keepdims=True))

In [30]:
# Plain NUTS on the complete-observations model defined in the previous cell;
# single chain, 1000 warm-up + 4000 kept draws, fixed seed.
kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(
    kernel,
    num_warmup=1000,
    num_samples=4000,
    num_chains=1,
    progress_bar=True,
)
mcmc.run(jax.random.PRNGKey(1234))

warmup:   2%|▏         | 75/5000 [00:19<21:22,  3.84it/s, 1023 steps of size 6.55e-06. acc. prob=0.71]


KeyboardInterrupt: 

In [13]:
# Quick look at the object ids that survived the completeness filter.
s_objects

Array([0, 0, 0, ..., 0, 0, 0], dtype=int32)