In [1]:
import os
#os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

In [None]:
import jax.numpy as jnp
import jax.random as jr
from jax import scipy as jsp

import blackjax as bj
import tensorflow_probability.substrates.jax as tfp

In [None]:
import jax
from jax.lib import xla_bridge

In [None]:
print(jax.devices())
print(jax.default_backend())
print(jax.device_count())
print(xla_bridge.get_backend().platform)

[CpuDevice(id=0)]
cpu
1
cpu


In [None]:
from functools import partial

import distrax
import haiku as hk
import jax
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from jax import numpy as jnp
from jax import random
from jax import scipy as jsp
from jax import vmap
from functools import partial

In [None]:
def likelihood_fn(theta):
    mu = jnp.tile(theta[:2], 4)
    s1, s2 = theta[2] ** 2, theta[3] ** 2
    corr = s1 * s2 * jnp.tanh(theta[4])
    cov = jnp.array([[s1**2, corr], [corr, s2**2]])
    cov = jsp.linalg.block_diag(*[cov for _ in range(4)])
    p = distrax.MultivariateNormalFullCovariance(mu, cov)
    return p

In [None]:
lik = likelihood_fn(jr.normal(jr.PRNGKey(123), (5,)))
y = lik.sample(seed=jr.PRNGKey(1), sample_shape=(10,))

In [None]:
def likelihood_fn(theta, y):
    mu = jnp.tile(theta[:2], 4)
    s1, s2 = theta[2] ** 2, theta[3] ** 2
    corr = s1 * s2 * jnp.tanh(theta[4])
    cov = jnp.array([[s1**2, corr], [corr, s2**2]])
    cov = jsp.linalg.block_diag(*[cov for _ in range(4)])
    p = distrax.MultivariateNormalFullCovariance(mu, cov)
    return p.log_prob(y)

In [None]:
prior = distrax.Independent(
    distrax.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0)), 1
)

In [None]:
def log_density_fn(theta, y):
    prior_lp = prior.log_prob(theta)
    likelihood_lp = likelihood_fn(theta, y)

    lp = jnp.sum(prior_lp) + jnp.sum(likelihood_lp)
    return lp

target_log_prob_fn_partial = partial(log_density_fn, y=y)
target_log_prob_fn = lambda theta: jax.vmap(target_log_prob_fn_partial)(theta)

In [None]:
n_samples = 10000
n_warmup = 5000
n_chains = 4

In [None]:
nuts = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn,
    step_size=0.1,
    max_tree_depth=10,
    max_energy_diff=1000.0,
    unrolled_leapfrog_steps=1,
)


nuts = tfp.mcmc.DualAveragingStepSizeAdaptation(
    inner_kernel=nuts,
    num_adaptation_steps=int(0.8 * n_warmup),
    target_accept_prob=jnp.asarray(0.75, jnp.float32)
)

tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(
    nuts,
    initial_running_variance,
    num_estimation_steps=None,
    momentum_distribution_setter_fn=hmc_like_momentum_distribution_setter_fn,
    momentum_distribution_getter_fn=hmc_like_momentum_distribution_getter_fn,
    validate_args=False,
    experimental_shard_axis_names=None,
    name=None
)


AttributeError: 'functools.partial' object has no attribute 'experimental_default_event_space_bijector'

In [None]:
jax.device_count()

4

In [None]:
initial_states = jr.normal(jr.PRNGKey(4), shape=(n_chains, 5))
samples = tfp.mcmc.sample_chain(
    num_results=n_samples - n_warmup,
    current_state=initial_states,
    num_steps_between_results=1,
    kernel=adaptive_sampler,
    num_burnin_steps=n_warmup,
    trace_fn=None,
    seed=jr.PRNGKey(2),
)
samples = samples[n_warmup:, ...]

In [19]:
jax.devices()

[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]

In [24]:
def sfunc(x): 
    while True: pass

jax.pmap(sfunc)(jnp.arange(4))

KeyboardInterrupt: 

In [9]:
import time, os, jax, numpy as np, jax.numpy as jnp
jax.config.update('jax_platform_name', 'cpu') # insures we use the CPU

def timer(name, f, x, shouldBlock=True):
   # warmup
   y = f(x).block_until_ready() if shouldBlock else f(x)
   # running the code
   start_wall = time.perf_counter()
   start_cpu = time.process_time()
   y = f(x).block_until_ready() if shouldBlock else f(x)
   end_wall = time.perf_counter()
   end_cpu = time.process_time()
   # computing the metric and displaying it
   wall_time = end_wall - start_wall
   cpu_time = end_cpu - start_cpu
   cpu_count = os.cpu_count()
   print(f"{name}: cpu usage {cpu_time/wall_time:.1f}/{cpu_count} wall_time:{wall_time:.1f}s")

# test functions
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, shape=(500000000,), dtype=jnp.float64)
x_mat = jax.random.normal(key, shape=(10000,10000), dtype=jnp.float64)
f_numpy = np.cos
f_vmap = jax.jit(jax.vmap(jnp.cos))
f_dot = jax.jit(lambda x: jnp.dot(x,x.T)) # to show that JAX can indeed use all cores

timer('numpy', f_numpy, x, shouldBlock=False)
timer('vmap', f_vmap, x)
timer('dot', f_dot, x_mat)

numpy: cpu usage 0.6/8 wall_time:1.7s
vmap: cpu usage 1.8/8 wall_time:1.9s
dot: cpu usage 7.0/8 wall_time:3.7s


In [10]:
os.cpu_count()

8

In [8]:
from jax import local_device_count
print(jax.local_devices()) #1 instead of 2

[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
