In [1]:
%load_ext nb_black
%matplotlib inline
from IPython.display import set_matplotlib_formats
import matplotlib.pyplot as plt
import numpy as np
import logging
import jax


# Root logger at INFO; the "absl" logger at DEBUG. The absl DEBUG level is what
# surfaces JAX's backend-initialization and "Compiling ..." messages seen in the
# cell outputs below. (getLogger(None) returns the root logger.)
for _logger_name, _logger_level in [(None, logging.INFO), ("absl", logging.DEBUG)]:
    logging.getLogger(_logger_name).setLevel(_logger_level)

# Render inline matplotlib figures as SVG.
set_matplotlib_formats("svg")


<IPython.core.display.Javascript object>

In [2]:
from math import log, exp
from sim import sim_and_fit, sim_wf
from common import Observation
from plotting import plot_summary
# Notebook-level NumPy generator. Generator(PCG64(seed)) is exactly what
# np.random.default_rng(seed) constructs. NOTE(review): `rng` is not referenced
# in the cells shown here; sim_and_fit below is given its own `seed=`.
RNG_SEED = 2
rng = np.random.Generator(np.random.PCG64(RNG_SEED))


DEBUG:absl:Initializing backend 'interpreter'
DEBUG:absl:Backend 'interpreter' initialized
DEBUG:absl:Initializing backend 'cpu'
DEBUG:absl:Backend 'cpu' initialized
DEBUG:absl:Initializing backend 'tpu_driver'
INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
DEBUG:absl:Initializing backend 'gpu'
DEBUG:absl:Backend 'gpu' initialized
DEBUG:absl:Initializing backend 'tpu'
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.


<IPython.core.display.Javascript object>

In [10]:
from jax import vmap, jit, value_and_grad
from betamix import BetaMixture
from estimate import jittable_estimate

# Uniform prior; per the compile log it appears to flatten to three float64[100]
# leaves when passed through jit.
prior = BetaMixture.uniform(100)

# Map over the leading axis of the first two arguments (obs and Ne in the
# benchmark call below); the remaining four arguments are shared across the batch.
batch_estimate = vmap(jittable_estimate, in_axes=(0, 0, None, None, None, None))

<IPython.core.display.Javascript object>

In [11]:
# One simulated dataset: constant "s" and "h" trajectories of length 100 plus an
# initial "f0". The returned "Ne" and "obs" entries are reused (tiled into
# batches) by the benchmark cell.
model = {
    "s": [0.01] * 100,
    "h": [0.5] * 100,
    "f0": 0.1,
}
res = sim_and_fit(model, seed=1, lam=1e2)
Ne, obs = res["Ne"], res["obs"]

<IPython.core.display.Javascript object>

In [12]:
for batch_size in [1, 10, 100]:

    def b(ary):
        return np.repeat(ary[None], batch_size, axis=0)

    Neb = b(Ne)
    obsb = b(obs)

    for plat in ["cpu", "gpu"]:
        f = jit(batch_estimate, backend=plat, static_argnums=5)
        for i in range(2):
            print("b=%d, %s mode, run=%i" % (batch_size, plat, i))
            %time f(obsb, Neb, 0., prior, 0., 100)[0].block_until_ready()

b=1, cpu mode, run=0


DEBUG:absl:Compiling jittable_estimate (140327327096192) for args (ShapedArray(float64[1,101,2]), ShapedArray(float64[1,100]), ShapedArray(float64[], weak_type=True), ShapedArray(float64[100]), ShapedArray(float64[100]), ShapedArray(float64[100]), ShapedArray(float64[], weak_type=True)).


CPU times: user 11.3 s, sys: 27.2 ms, total: 11.3 s
Wall time: 11.4 s
b=1, cpu mode, run=1
CPU times: user 1.18 s, sys: 6.25 ms, total: 1.19 s
Wall time: 1.18 s
b=1, gpu mode, run=0


DEBUG:absl:Compiling jittable_estimate (140327724471680) for args (ShapedArray(float64[1,101,2]), ShapedArray(float64[1,100]), ShapedArray(float64[], weak_type=True), ShapedArray(float64[100]), ShapedArray(float64[100]), ShapedArray(float64[100]), ShapedArray(float64[], weak_type=True)).


CPU times: user 11.6 s, sys: 252 ms, total: 11.9 s
Wall time: 7.63 s
b=1, gpu mode, run=1
CPU times: user 3.88 s, sys: 176 ms, total: 4.05 s
Wall time: 4.05 s
b=10, cpu mode, run=0


DEBUG:absl:Compiling jittable_estimate (140326784064384) for args (ShapedArray(float64[10,101,2]), ShapedArray(float64[10,100]), ShapedArray(float64[], weak_type=True), ShapedArray(float64[100]), ShapedArray(float64[100]), ShapedArray(float64[100]), ShapedArray(float64[], weak_type=True)).


CPU times: user 55.3 s, sys: 339 ms, total: 55.6 s
Wall time: 55.1 s
b=10, cpu mode, run=1
CPU times: user 13.2 s, sys: 189 ms, total: 13.4 s
Wall time: 12.9 s
b=10, gpu mode, run=0


DEBUG:absl:Compiling jittable_estimate (140327327166656) for args (ShapedArray(float64[10,101,2]), ShapedArray(float64[10,100]), ShapedArray(float64[], weak_type=True), ShapedArray(float64[100]), ShapedArray(float64[100]), ShapedArray(float64[100]), ShapedArray(float64[], weak_type=True)).


CPU times: user 15.6 s, sys: 1.8 s, total: 17.4 s
Wall time: 13.3 s
b=10, gpu mode, run=1
CPU times: user 6.6 s, sys: 1.67 s, total: 8.27 s
Wall time: 8.26 s
b=100, cpu mode, run=0


DEBUG:absl:Compiling jittable_estimate (140326185929408) for args (ShapedArray(float64[100,101,2]), ShapedArray(float64[100,100]), ShapedArray(float64[], weak_type=True), ShapedArray(float64[100]), ShapedArray(float64[100]), ShapedArray(float64[100]), ShapedArray(float64[], weak_type=True)).


CPU times: user 2min 30s, sys: 4.1 s, total: 2min 34s
Wall time: 2min 14s
b=100, cpu mode, run=1
CPU times: user 2min 17s, sys: 4.35 s, total: 2min 21s
Wall time: 2min 1s
b=100, gpu mode, run=0


DEBUG:absl:Compiling jittable_estimate (140326048627584) for args (ShapedArray(float64[100,101,2]), ShapedArray(float64[100,100]), ShapedArray(float64[], weak_type=True), ShapedArray(float64[100]), ShapedArray(float64[100]), ShapedArray(float64[100]), ShapedArray(float64[], weak_type=True)).


CPU times: user 17.1 s, sys: 2.27 s, total: 19.3 s
Wall time: 14.9 s
b=100, gpu mode, run=1
CPU times: user 7.48 s, sys: 2.3 s, total: 9.77 s
Wall time: 9.77 s


<IPython.core.display.Javascript object>