In [1]:
from psutil import virtual_memory
# Convert the byte count in psutil's `total` field to gigabytes (1e9 bytes).
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

# 20 GB is the threshold this notebook uses to call a runtime "high-RAM".
is_high_ram = ram_gb >= 20
print('You are using a high-RAM runtime!' if is_high_ram else 'Not using a high-RAM runtime')

Your runtime has 270.2 gigabytes of available RAM

You are using a high-RAM runtime!


In [2]:
import numpy as np

import jax
import jax.numpy as jnp
import jax.scipy.optimize as jsco
from jax.lib import xla_bridge
from jax._src import compilation_cache as cc

import tinygp
from tinygp import GaussianProcess, kernels
from tinygp.kernels import quasisep

import eztao
from eztao.carma import DRW_term
from eztao.ts import gpSimRand

import logging

# Enable 64-bit floats; the results recorded in this notebook are float64.
jax.config.update("jax_enable_x64", True)
# NOTE(review): `cc` is imported from the private `jax._src` namespace, which may
# change between jax releases — confirm there is a public equivalent for the
# installed version before relying on this.
cc.initialize_cache("./cache_min_example")
# Minimum compile time (in seconds, per the option name) for a persistent cache entry.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.1)

print("eztao version: " + eztao.__version__)

print("jax version: " + jax.__version__)
print("tinygp version: " + tinygp.__version__)

# jax.default_backend() is the public API for the platform name of the default
# backend; it replaces the internal xla_bridge.get_backend().platform lookup.
print("cpu/gpu: " + str(jax.default_backend()))

Initialized persistent compilation cache at ./cache_min_example


eztao version: 0.4.1
jax version: 0.4.18
tinygp version: 0.2.4
cpu/gpu: gpu


In [3]:
t_batch, y_batch, yerr_batch, len_lc = [], [], [], []

# Amplitude and timescale are the same for every simulated light curve.
amp = 0.2
tau = 100

# Simulate 10 light curves; the last gpSimRand argument grows from 203 to 212,
# so the curves are intended to have different lengths (recorded in len_lc).
for i in range(10):
  DRW_kernel = DRW_term(np.log(amp), np.log(tau))
  t, y_drw, yerr = gpSimRand(DRW_kernel, 10, 365*10, 203+i)

  # Superimpose a sinusoid (amplitude 0.2) on the simulated DRW signal.
  y = y_drw + 0.2* np.sin(t/100)

  len_lc.append(len(t))
  t_batch.append(t)
  y_batch.append(y)
  yerr_batch.append(yerr)

In [4]:
# create a single simple testing lightcurve
# Deterministic test light curve: integer times 0..999, y = sin(t), unit errors.
t = np.arange(0, 1000, 1)
y = np.sin(t)
yerr = np.ones(1000)

# Initial parameters in log10 space; the builders below read them by index as
# [log_drw_scale, log_drw_amp, log_per_scale, log_per_amp].
theta_init_float = [np.log10(value) for value in (100, 0.25, 1.2, 4.3)]

def build_exp_gp_kernels(theta_float, t, y, yerr):
    """Build a Gaussian process with a single exponential kernel.

    Parameters
    ----------
    theta_float : sequence
        Only ``theta_float[0]`` is read: log10 of the kernel scale.
    t : array
        Input coordinates handed to ``GaussianProcess``.
    y : array
        Observations; only their mean is used, as the constant GP mean.
    yerr : array
        Passed unchanged as the ``diag`` argument.

    Returns
    -------
    GaussianProcess
        The process built from ``kernels.Exp``.
    """
    log_drw_scale = theta_float[0]
    exp_kernel = kernels.Exp(scale=10**log_drw_scale)
    # jnp.mean keeps the mean traceable when `y` is a JAX tracer (jit/grad/vmap);
    # np.mean only worked there through NumPy's fallback to the object's method.
    # NOTE(review): `diag` receives `yerr` as-is. If `yerr` holds 1-sigma errors
    # and `diag` expects variances, this should be `yerr**2` — confirm.
    return GaussianProcess(exp_kernel, t, diag=yerr, mean=jnp.mean(y))

def build_exp_gp_quasi(theta_float, t, y, yerr):
    """Build a Gaussian process with a single quasiseparable exponential kernel.

    Parameters
    ----------
    theta_float : sequence
        Only ``theta_float[0]`` is read: log10 of the kernel scale.
    t : array
        Input coordinates handed to ``GaussianProcess``.
    y : array
        Observations; only their mean is used, as the constant GP mean.
    yerr : array
        Passed unchanged as the ``diag`` argument.

    Returns
    -------
    GaussianProcess
        The process built from ``quasisep.Exp``.
    """
    log_drw_scale = theta_float[0]
    exp_kernel = quasisep.Exp(scale=10**log_drw_scale)
    # jnp.mean keeps the mean traceable when `y` is a JAX tracer (jit/grad/vmap);
    # np.mean only worked there through NumPy's fallback to the object's method.
    # NOTE(review): `diag` receives `yerr` as-is. If `yerr` holds 1-sigma errors
    # and `diag` expects variances, this should be `yerr**2` — confirm.
    return GaussianProcess(exp_kernel, t, diag=yerr, mean=jnp.mean(y))

def build_gp_float(theta_float, t, y, yerr):
    """Build a Gaussian process whose kernel is an exponential plus a cosine term.

    Parameters
    ----------
    theta_float : sequence of 4 values
        ``[log_drw_scale, log_drw_amp, log_per_scale, log_per_amp]``, all log10.
    t : array
        Input coordinates handed to ``GaussianProcess``.
    y : array
        Observations; only their mean is used, as the constant GP mean.
    yerr : array
        Passed unchanged as the ``diag`` argument.

    Returns
    -------
    GaussianProcess
        The process built from ``kernels.Exp`` and ``kernels.Cosine``.
    """
    log_drw_scale = theta_float[0]
    log_drw_amp = theta_float[1]
    log_per_scale = theta_float[2]
    log_per_amp = theta_float[3]

    sigma_drw = 10**log_drw_amp
    sigma_per = 10**log_per_amp

    exp_kernel = kernels.Exp(scale=10**log_drw_scale)
    periodic_kernel = kernels.Cosine(scale=10**(log_per_scale))

    # Scale each term by sigma**2 so the amplitudes mean the same thing as the
    # `sigma=` arguments in build_gp_float_quasisep (the quasiseparable kernels
    # use sigma**2 as the prefactor). The previous linear scaling made the two
    # builders describe different models for the same theta_float.
    kernel = sigma_drw**2 * exp_kernel + sigma_per**2 * periodic_kernel

    # jnp.mean keeps the mean traceable when `y` is a JAX tracer (jit/grad/vmap).
    # NOTE(review): `diag` receives `yerr` as-is. If `yerr` holds 1-sigma errors
    # and `diag` expects variances, this should be `yerr**2` — confirm.
    return GaussianProcess(kernel, t, diag=yerr, mean=jnp.mean(y))

def build_gp_float_quasisep(theta_float, t, y, yerr):
    """Build a Gaussian process from quasiseparable exponential and cosine kernels.

    Parameters
    ----------
    theta_float : sequence of 4 values
        ``[log_drw_scale, log_drw_amp, log_per_scale, log_per_amp]``, all log10.
        The amplitudes are passed to the kernels as ``sigma``.
    t : array
        Input coordinates handed to ``GaussianProcess``.
    y : array
        Observations; only their mean is used, as the constant GP mean.
    yerr : array
        Passed unchanged as the ``diag`` argument.

    Returns
    -------
    GaussianProcess
        The process built from ``quasisep.Exp + quasisep.Cosine``.
    """
    log_drw_scale = theta_float[0]
    log_drw_amp = theta_float[1]
    log_per_scale = theta_float[2]
    log_per_amp = theta_float[3]

    exp_kernel = quasisep.Exp(
        scale=10**log_drw_scale, sigma=10**log_drw_amp
    )

    periodic_kernel = quasisep.Cosine(
        scale=10**(log_per_scale),
        sigma=10**(log_per_amp),
    )

    kernel = exp_kernel + periodic_kernel

    # jnp.mean keeps the mean traceable when `y` is a JAX tracer (jit/grad/vmap);
    # np.mean only worked there through NumPy's fallback to the object's method.
    # NOTE(review): `diag` receives `yerr` as-is. If `yerr` holds 1-sigma errors
    # and `diag` expects variances, this should be `yerr**2` — confirm.
    return GaussianProcess(kernel, t, diag=yerr, mean=jnp.mean(y))

# Disabled code: the triple-quoted string below is a bare expression statement,
# so the two jitted single-kernel likelihoods written inside it are never defined.
"""
@jax.jit
def neg_log_likelihood_kernels(theta, t, y, yerr):
    gp = build_exp_gp_kernels(theta, t, y, yerr)
    return -gp.log_probability(y)

@jax.jit
def neg_log_likelihood_quasi(theta, t, y, yerr):
    gp = build_exp_gp_quasi(theta, t, y, yerr)
    return -gp.log_probability(y)
"""

def neg_log_likelihood_float(theta_float, t, y, yerr):
    """Negative log-probability of ``y`` under the GP from ``build_gp_float`` (not jitted)."""
    log_prob = build_gp_float(theta_float, t, y, yerr).log_probability(y)
    return -log_prob

def neg_log_likelihood_float_quasisep(theta_float, t, y, yerr):
    """Negative log-probability of ``y`` under the GP from ``build_gp_float_quasisep`` (not jitted)."""
    log_prob = build_gp_float_quasisep(theta_float, t, y, yerr).log_probability(y)
    return -log_prob

@jax.jit
def neg_log_likelihood_float_jit(theta_float, t, y, yerr):
    """Jitted negative log-probability of ``y`` under the GP from ``build_gp_float``."""
    log_prob = build_gp_float(theta_float, t, y, yerr).log_probability(y)
    return -log_prob

@jax.jit
def neg_log_likelihood_float_jit_quasisep(theta_float, t, y, yerr):
    """Jitted negative log-probability of ``y`` under the GP from ``build_gp_float_quasisep``."""
    log_prob = build_gp_float_quasisep(theta_float, t, y, yerr).log_probability(y)
    return -log_prob


# Set the root logger to INFO so jax's compilation-cache messages are shown.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# DEBUG is below the INFO threshold set above, so this message is not expected to print.
logging.debug("test")

# First call of each jitted likelihood. In the recorded run only the dense-kernel
# version was written to the persistent cache; the quasisep version was skipped
# ("uses host callbacks"). Only the last expression's value is shown as cell output.
neg_log_likelihood_float_jit(theta_init_float, t, y, yerr)
neg_log_likelihood_float_jit_quasisep(theta_init_float, t, y, yerr)

2023-10-20 11:12:17.934670: W external/xla/xla/service/gpu/nvptx_compiler.cc:703] The NVIDIA driver's CUDA version is 12.1 which is older than the ptxas CUDA version (12.2.140). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
INFO:jax._src.compilation_cache:Writing jit_neg_log_likelihood_float_jit to persistent compilation cache with key jit_neg_log_likelihood_float_jit-dee2dc0c8308575790c32afdc7b585cbafee9b791cf75b69db235db5e45815ae.
INFO:jax._src.compiler:Not writing persistent cache entry for 'jit_neg_log_likelihood_float_jit_quasisep' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)


Array(1190.59536823, dtype=float64)

In [5]:
# Repeat the two calls from the previous cell with identical arguments; only the
# value of the last expression is shown as the cell output.
neg_log_likelihood_float_jit(theta_init_float, t, y, yerr)
neg_log_likelihood_float_jit_quasisep(theta_init_float, t, y, yerr)

Array(1190.59536823, dtype=float64)

In [6]:
def jsoln_jax_ty(t, y, yerr):
    """Minimise ``neg_log_likelihood_float`` with JAX's BFGS ``minimize``.

    Starts from the module-level ``theta_init_float`` and returns the ``fun``
    attribute of the optimisation result.

    TODO: drop the jnp.array conversions and require JAX-array inputs instead?
    """
    start = jnp.array(theta_init_float)
    data = (jnp.array(t), jnp.array(y), jnp.array(yerr))
    result = jsco.minimize(neg_log_likelihood_float, x0=start, method="bfgs", args=data)
    return result.fun

def jsoln_jax_ty_quasisep(t, y, yerr):
    """Minimise ``neg_log_likelihood_float_quasisep`` with JAX's BFGS ``minimize``.

    Starts from the module-level ``theta_init_float`` and returns the ``fun``
    attribute of the optimisation result.

    TODO: drop the jnp.array conversions and require JAX-array inputs instead?
    """
    start = jnp.array(theta_init_float)
    data = (jnp.array(t), jnp.array(y), jnp.array(yerr))
    result = jsco.minimize(neg_log_likelihood_float_quasisep, x0=start, method="bfgs", args=data)
    return result.fun

# what if I use already jitted likelihoods?
def jsoln_jax_ty_jit(t, y, yerr):
    """Minimise the already-jitted ``neg_log_likelihood_float_jit`` with JAX's BFGS ``minimize``.

    Starts from the module-level ``theta_init_float`` and returns the ``fun``
    attribute of the optimisation result.

    TODO: drop the jnp.array conversions and require JAX-array inputs instead?
    """
    start = jnp.array(theta_init_float)
    data = (jnp.array(t), jnp.array(y), jnp.array(yerr))
    result = jsco.minimize(neg_log_likelihood_float_jit, x0=start, method="bfgs", args=data)
    return result.fun

def jsoln_jax_ty_jit_quasisep(t, y, yerr):
    """Minimise the already-jitted ``neg_log_likelihood_float_jit_quasisep`` with JAX's BFGS ``minimize``.

    Starts from the module-level ``theta_init_float`` and returns the ``fun``
    attribute of the optimisation result.

    TODO: drop the jnp.array conversions and require JAX-array inputs instead?
    """
    start = jnp.array(theta_init_float)
    data = (jnp.array(t), jnp.array(y), jnp.array(yerr))
    result = jsco.minimize(neg_log_likelihood_float_jit_quasisep, x0=start, method="bfgs", args=data)
    return result.fun

# Wrap each optimisation driver in jax.jit, requesting the "gpu" backend.
# The compiled-function names in the recorded logs (e.g. 'jit_jsoln_jax_ty_quasisep')
# come from the wrapped functions, not from these variable names.
jsoln_jax_ty_gpu = jax.jit(jsoln_jax_ty, backend="gpu")
jsoln_jax_ty_quasisep_gpu = jax.jit(jsoln_jax_ty_quasisep, backend="gpu")
jsoln_jax_ty_jit_gpu = jax.jit(jsoln_jax_ty_jit, backend="gpu")
jsoln_jax_ty_jit_quasisep_gpu = jax.jit(jsoln_jax_ty_jit_quasisep, backend="gpu")

In [7]:
# doesnt work, maybe cuda problem 

%time res_jsoln_jax_ty_gpu = jsoln_jax_ty_gpu(t, y, yerr)
%time res_jsoln_jax_ty_quasisep_gpu = jsoln_jax_ty_quasisep_gpu(t, y, yerr)
%time res_jsoln_jax_ty_gpu = jsoln_jax_ty_jit_gpu(t, y, yerr)
%time res_jsoln_jax_ty_quasisep_gpu = jsoln_jax_ty_jit_quasisep_gpu(t, y, yerr)

XlaRuntimeError: INTERNAL: cuSolver internal error

INFO:jax._src.compiler:Not writing persistent cache entry for 'jit_jsoln_jax_ty_quasisep' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)


CPU times: user 7.69 s, sys: 276 ms, total: 7.96 s
Wall time: 9.42 s


XlaRuntimeError: INTERNAL: cuSolver internal error

INFO:jax._src.compiler:Not writing persistent cache entry for 'jit_jsoln_jax_ty_jit_quasisep' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)


CPU times: user 7.96 s, sys: 179 ms, total: 8.14 s
Wall time: 9.51 s


In [17]:

def build_gp(theta, X):
    """Build a GP whose kernel is the sum of four terms parameterised by ``theta``.

    ``theta`` is a dict with keys ``log_amps``, ``log_scales``, ``log_period``,
    ``log_gamma``, ``log_alpha``, ``log_diag`` and ``mean``. All ``log_*`` entries
    are exponentiated with ``jnp.exp`` so the underlying values stay positive.
    ``X`` holds the input coordinates handed to ``GaussianProcess``.
    """
    amplitude = jnp.exp(theta["log_amps"])
    length = jnp.exp(theta["log_scales"])

    # Term 1: squared-exponential.
    k1 = amplitude[0] * kernels.ExpSquared(length[0])

    # Term 2: squared-exponential multiplied by an exp-sine-squared kernel.
    sine_squared = kernels.ExpSineSquared(
        scale=jnp.exp(theta["log_period"]),
        gamma=jnp.exp(theta["log_gamma"]),
    )
    k2 = amplitude[1] * kernels.ExpSquared(length[1]) * sine_squared

    # Term 3: rational quadratic.
    rational_quadratic = kernels.RationalQuadratic(
        alpha=jnp.exp(theta["log_alpha"]), scale=length[2]
    )
    k3 = amplitude[2] * rational_quadratic

    # Term 4: a second squared-exponential.
    k4 = amplitude[3] * kernels.ExpSquared(length[3])

    total_kernel = k1 + k2 + k3 + k4
    jitter = jnp.exp(theta["log_diag"])
    return GaussianProcess(total_kernel, X, diag=jitter, mean=theta["mean"])


def neg_log_likelihood(theta, X, y):
    """Negative log-probability of ``y`` under the GP returned by ``build_gp(theta, X)``."""
    log_prob = build_gp(theta, X).log_probability(y)
    return -log_prob


# Initial parameter dictionary read by build_gp; every log_* entry is
# exponentiated there.
theta_init = {
    "mean": np.float64(340.0),
    "log_diag": np.log(0.19),
    "log_amps": np.log([66.0, 2.4, 0.66, 0.18]),
    "log_scales": np.log([67.0, 90.0, 0.78, 1.6]),
    "log_period": np.float64(0.0),
    "log_gamma": np.log(4.3),
    "log_alpha": np.log(1.2),
}

# `jax` can be used to differentiate functions, and also note that we're calling
# `jax.jit` for the best performance.
obj = jax.jit(jax.value_and_grad(neg_log_likelihood))

# Evaluate once and reuse both outputs, instead of calling obj twice with the
# same arguments just to index a different element of the returned pair.
nll_init, grad_init = obj(theta_init, t, y)

print(f"Initial negative log likelihood: {nll_init}")
print(
    f"Gradient of the negative log likelihood, wrt the parameters:\n{grad_init}"
)

DEBUG:jax._src.dispatch:Finished tracing + transforming <lambda> for pjit in 0.0010919570922851562 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming <lambda> for pjit in 0.00042510032653808594 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming square for pjit in 0.0004372596740722656 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming _reduce_sum for pjit in 0.0007770061492919922 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming true_divide for pjit in 0.0005528926849365234 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming absolute for pjit in 0.00037169456481933594 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming fn for pjit in 0.0004432201385498047 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming _power for pjit in 0.00043702125549316406 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming fn for pjit in 0.0004482269287109375 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming _broadcast_arrays for pj

Initial negative log likelihood: 19692.68549568087
Gradient of the negative log likelihood, wrt the parameters:
{'log_alpha': Array(108.16638225, dtype=float64), 'log_amps': Array([-1.71192863e+04, -1.24699148e+02,  2.25648065e+02, -7.62560177e+00],      dtype=float64), 'log_diag': Array(402.97888208, dtype=float64), 'log_gamma': Array(266.50043559, dtype=float64), 'log_period': Array(-213.51481254, dtype=float64), 'log_scales': Array([-11937.8256461 ,   -356.5112549 ,   -552.36351147,   -115.72900226],      dtype=float64), 'mean': Array(104.33296606, dtype=float64)}


In [11]:
# Simulate one more DRW light curve with the same amp/tau values used for the
# batch above. This overwrites the `t`, `y`, `yerr` of the sine-wave test curve.
amp = 0.2
tau = 100
DRW_kernel = DRW_term(np.log(amp), np.log(tau))
# The last argument (2222) matches the float64[2222] shapes in the compile log
# recorded further down.
# NOTE(review): no seed is passed, so the simulated curve presumably differs from
# run to run — confirm against gpSimRand's signature.
t, y_drw, yerr = gpSimRand(DRW_kernel, 10, 365*10, 2222)

# This adds a periodic component to the DRW process
y = y_drw + 0.2* np.sin(t/100)

In [12]:
# Same option and value as already set in the setup cell.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.1)
# Switch off the option that (per its name) includes metadata in the compilation-cache key.
jax.config.update("jax_compilation_cache_include_metadata_in_key", False)

In [13]:
# what if I have a simpler function

@jax.jit
def neg_log_likelihood_float_test(theta_float, t, y, yerr):
    """Trivial jitted stand-in for a likelihood: minus the sum of ``t + y + yerr``.

    ``theta_float`` is accepted for signature compatibility but is not used.
    """
    summed_inputs = jnp.sum(t + y + yerr)
    return -summed_inputs

# theta_init_float is passed but unused by neg_log_likelihood_float_test; the
# printed value is -(sum of t + y + yerr).
print(f"Initial negative log likelihood: {neg_log_likelihood_float_test(theta_init_float, t, y, yerr)}")

DEBUG:jax._src.dispatch:Finished tracing + transforming fn for pjit in 0.0006060600280761719 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming _reduce_sum for pjit in 0.0006413459777832031 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming neg_log_likelihood_float_test for pjit in 0.005285024642944336 sec
DEBUG:jax._src.interpreters.pxla:Compiling neg_log_likelihood_float_test for with global shapes and types [ShapedArray(float64[2222]), ShapedArray(float64[2222]), ShapedArray(float64[2222])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(neg_log_likelihood_float_test) in 0.0034036636352539062 sec
DEBUG:jax._src.compiler:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[cuda(id=0)]]
DEBUG:jax._src.compiler:get_compile_options XLA-AutoFDO profile: using XLA-AutoFDO profile version -1
DEBUG:jax._src.cache_key:get_cache

Initial negative log likelihood: -4047113.132524418


In [16]:
@jax.jit
def neg_log_likelihood_drw(theta, t, y, yerr):
    """Return the negative log-probability of ``y`` under the GP from ``build_exp_gp``.

    NOTE(review): ``build_exp_gp`` is not defined in the cells visible here (only
    ``build_exp_gp_kernels`` and ``build_exp_gp_quasi`` are), so on a fresh kernel
    this presumably raises NameError — confirm which builder was intended.
    """
    gp = build_exp_gp(theta, t, y, yerr)
    return -gp.log_probability(y)


# NOTE(review): this passes `theta_init` (the dict defined for `build_gp`), not
# `theta_init_float` — confirm that this is intended.
print(f"Initial negative log likelihood: {neg_log_likelihood_drw(theta_init, t, y, yerr)}")

DEBUG:jax._src.dispatch:Finished tracing + transforming _reduce_sum for pjit in 0.0006961822509765625 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming _mean for pjit in 0.0024716854095458984 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming ravel for pjit in 0.0001938343048095703 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming append for pjit in 0.001699686050415039 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming matmul for pjit in 0.0008721351623535156 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming fn for pjit in 0.00038242340087890625 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming _reduce_sum for pjit in 0.0004951953887939453 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming dot for pjit in 0.0006043910980224609 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming __add__ for pjit in 0.0005941390991210938 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming fn for pjit in 0.0003800392150878

Initial negative log likelihood: -1949.2229968120407


In [1]:

# Generate a fake batch by repeating one light curve.
# NOTE: this cell needs `t`, `y`, `yerr` from an earlier cell; run on a fresh
# kernel without them it fails with NameError (as in the recorded output).

# Number of light curves in the batch.
n_lc = 10
print(f"\nN = {n_lc}:")

# Intended size of random per-curve variations. It is not applied anywhere in
# this cell, so every row of the batch is an exact copy of the input curve.
variation_range = 0.0000001

# Stack n_lc copies of each array along a new leading axis (the recorded run
# shows batches of shape (10, 203)).
t_batch = np.array([t] * n_lc)
y_batch = np.array([y] * n_lc)
yerr_batch = np.array([yerr] * n_lc)

# JAX copies of the batches for the jitted/vmapped solvers.
t_batch_jax = jnp.array(t_batch)
y_batch_jax = jnp.array(y_batch)
yerr_batch_jax = jnp.array(yerr_batch)



N = 10:


NameError: name 't' is not defined

In [12]:
# NOTE(review): `jsoln_jax_ty_gpu_vmap` is not defined in any cell visible here;
# presumably a vmapped variant of `jsoln_jax_ty` (the recorded log compiles
# `jsoln_jax_ty` for float64[10,203] inputs) — confirm and add its definition.
# block_until_ready() is called on the result inside the timed statement.
%time jsoln_jax_ty_gpu_vmap(t_batch_jax, y_batch_jax, yerr_batch_jax).block_until_ready()

DEBUG:jax._src.interpreters.pxla:Compiling jsoln_jax_ty for with global shapes and types [ShapedArray(float64[10,203]), ShapedArray(float64[10,203]), ShapedArray(float64[10,203])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(jsoln_jax_ty) in 0.697577714920044 sec
DEBUG:jax._src.compiler:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[cuda(id=0)]]
DEBUG:jax._src.compiler:get_compile_options XLA-AutoFDO profile: using XLA-AutoFDO profile version -1
DEBUG:jax._src.cache_key:get_cache_key hash of serialized computation: 19294f348f7ebeea596452d60a258b965c2071d8558fbc11c76dd5d1921d21f7
DEBUG:jax._src.cache_key:get_cache_key hash after serializing computation: 19294f348f7ebeea596452d60a258b965c2071d8558fbc11c76dd5d1921d21f7
DEBUG:jax._src.cache_key:get_cache_key hash of serialized jax_lib version: 99b916eb3ced1f033fba99b2686e992e18db31b37

INFO:jax._src.compiler:Not writing persistent cache entry for 'jit_jsoln_jax_ty' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:jax._src.dispatch:Finished XLA compilation of jit(jsoln_jax_ty) in 4.813591718673706 sec


CPU times: user 5.57 s, sys: 403 ms, total: 5.97 s
Wall time: 6.95 s


Array([-119.68832781, -119.68832781, -119.68832781, -119.68832781,
       -119.68832781, -119.68832781, -119.68832781, -119.68832781,
       -119.68832781, -119.68832781], dtype=float64)