In [None]:
# Record which GPU model this benchmark ran on, since the GPU timings below depend on it.
!nvidia-smi --query-gpu=gpu_name --format=csv

In [None]:
import os

# These variables are set before `import jax` (next cell) so that JAX/XLA see
# them when they initialise.

# Enable 64-bit (double-precision) arrays in JAX by default.
os.environ["JAX_ENABLE_X64"] = "True"
# Keep OpenMP-based libraries single-threaded so CPU timings are comparable.
os.environ["OMP_NUM_THREADS"] = "1"

# Ask XLA's CPU backend not to multi-thread its Eigen kernels.
# NOTE: `intra_op_parallelism_threads` is a TensorFlow session option, not an
# XLA flag (and previously lacked the `--` prefix), so it is not passed here.
_XLA_CPU_FLAG = "--xla_cpu_multi_thread_eigen=false"
_xla_flags = os.environ.get("XLA_FLAGS", "")
if _XLA_CPU_FLAG not in _xla_flags.split():
    # Guarding the append keeps this cell idempotent when re-run in one kernel.
    os.environ["XLA_FLAGS"] = f"{_xla_flags} {_XLA_CPU_FLAG}".strip()

In [None]:
from functools import partial

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import jax

import tinygp

# Turn on 64-bit (double-precision) arrays in JAX.
jax.config.update("jax_enable_x64", True)

# Hyperparameters shared by every likelihood implementation in this notebook:
# kernel amplitude, kernel length scale, and white-noise standard deviation.
sigma = 1.5
rho = 2.5
jitter = 0.1

# Synthetic benchmark data: 100k sorted inputs drawn uniformly from [0, 10),
# with observations equal to a sine curve plus Gaussian noise of std `jitter`.
random = np.random.default_rng(49382)
x = np.sort(random.uniform(low=0, high=10, size=100_000))
y = np.sin(x) + jitter * random.normal(loc=0, scale=1, size=len(x))


def tinygp_loglike(x, y):
    """Return the GP log-likelihood of `y` observed at `x` using tinygp's general kernel.

    The kernel is `sigma**2 * Matern32(rho)` with an added diagonal of
    `jitter**2`. `sigma`, `rho` and `jitter` are read from module globals
    rather than passed in.
    """
    gp = tinygp.GaussianProcess(
        sigma**2 * tinygp.kernels.Matern32(rho),
        x,
        diag=jitter**2,
    )
    return gp.log_probability(y)


# Compile the same direct likelihood once per backend (CPU first, then GPU).
# NOTE(review): the `backend=` argument of `jax.jit` may be deprecated in newer
# JAX releases — confirm against the installed version.
tinygp_loglike_cpu, tinygp_loglike_gpu = (
    jax.jit(tinygp_loglike, backend=backend_name) for backend_name in ("cpu", "gpu")
)


@partial(jax.jit, backend="cpu", static_argnames=("num_terms",))
def quasisep_loglike(x, y, num_terms=1):
    """Return the GP log-likelihood of `y` at `x` using tinygp's quasiseparable kernels.

    The kernel is the sum of `num_terms` identical quasiseparable Matern-3/2
    terms (amplitude `sigma`, scale `rho`, both module globals), with
    `jitter**2` added to the diagonal. The function is compiled for the CPU
    backend. `num_terms` is static, so each distinct value triggers its own
    compilation.

    Args:
        x: input locations.
        y: observations at `x`.
        num_terms: number of Matern-3/2 terms to sum; must be >= 1.

    Raises:
        ValueError: if `num_terms` is less than 1.
    """
    if num_terms < 1:
        # `num_terms` is a static (Python) value, so this check runs at trace
        # time. Previously values < 1 silently fell back to a one-term kernel.
        raise ValueError(f"num_terms must be >= 1, got {num_terms}")
    kernel = tinygp.kernels.quasisep.Matern32(sigma=sigma, scale=rho)
    for _ in range(num_terms - 1):
        kernel += tinygp.kernels.quasisep.Matern32(sigma=sigma, scale=rho)
    gp = tinygp.GaussianProcess(kernel, x, diag=jitter**2)
    return gp.log_probability(y)

In [None]:
ns = [10, 20, 100, 200, 1_000, 2_000, 10_000, 20_000, len(x)]
num_terms = [1, 3, 5]
data = []
for n in ns:
    print(f"\nN = {n}:")
    row = [n]

    args = x[:n], y[:n]
    gpu_args = jax.device_put(x[:n]), jax.device_put(y[:n])

    if n < 10_000:
        tinygp_loglike_cpu(*args).block_until_ready()
        results = %timeit -o tinygp_loglike_cpu(*args).block_until_ready()
        row.append(results.average)
    else:
        row.append(np.nan)

    if n <= 20_000:
        tinygp_loglike_gpu(*gpu_args).block_until_ready()
        results = %timeit -o tinygp_loglike_gpu(*gpu_args).block_until_ready()
        row.append(results.average)
    else:
        row.append(np.nan)

    for j in num_terms:
        quasisep_loglike(*args, num_terms=j).block_until_ready()
        results = %timeit -o quasisep_loglike(*args, num_terms=j).block_until_ready()
        row.append(results.average)

    data.append(tuple(row))

data = np.array(
    data,
    dtype=[
        ("n", int),
        ("cpu", float),
        ("gpu", float),
    ]
    + [(f"qs{j}", float) for j in num_terms],
)

In [None]:
# Tabulate the timings indexed by problem size and persist them to disk.
df = pd.DataFrame.from_records(data).set_index("n")
df.to_csv("scaling.csv")

In [None]:
# Log-log plot of evaluation cost vs. problem size, one line per method.
# Sizes that were skipped (stored as NaN) are left out of each line.
fig, ax = plt.subplots()
for k in df.columns:
    if k.startswith("qs"):
        terms = k[2:]
        # Compare the whole suffix, so that e.g. "qs10" is still pluralised.
        label = f"celerite ({terms} term{'' if terms == '1' else 's'})"
        marker = "o"
    else:
        label = f"direct ({k.upper()})"
        # Marker is chosen by column name, not position, so extra quasisep
        # columns are never silently dropped.
        marker = {"cpu": "s", "gpu": "^"}.get(k, "o")
    measured = df[k].dropna()
    ax.loglog(measured.index, measured, f"{marker}-", label=label)
ax.legend()
ax.set_xlabel("number of data points")
ax.set_ylabel("cost of one likelihood evaluation [sec]");