# goom tests
Use this notebook to run pytest from the repo root with the `goom_jax` conda env active.

In [1]:
import sys
import goom as goom

print("goom module:", goom)
print("Python executable:", sys.executable)
print("Python version:", sys.version)

import jax
import jax.numpy as jnp
from jax import lax

import goom.lmme as lmme
import goom.operations as oprs
import goom.lle as lle

config = goom.config  # grab the config object from the goom module

config.keep_logs_finite = True          # log(0) will return a finite floor
config.cast_all_logs_to_complex = True  # GOOMs are complex-typed
config.float_dtype = jnp.float32        # real dtype

%load_ext autoreload
%autoreload 2




goom module: <module 'goom' from '/oscar/home/lkozachk/code/goom/src/goom/__init__.py'>
Python executable: /users/lkozachk/.conda/envs/jax_goom/bin/python
Python version: 3.12.12 | packaged by Anaconda, Inc. | (main, Oct 21 2025, 20:16:04) [GCC 11.2.0]


goom module: <module 'goom' from '/oscar/home/lkozachk/code/goom/src/goom/__init__.py'>
Python executable: /users/lkozachk/.conda/envs/jax_goom/bin/python
Python version: 3.12.12 | packaged by Anaconda, Inc. | (main, Oct 21 2025, 20:16:04) [GCC 11.2.0]


In [2]:
print(jax.devices())

[CudaDevice(id=0)]


W1128 19:36:35.343932 1583948 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML library doesn't have required functions.
W1128 19:36:35.346959 1534013 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML library doesn't have required functions.


In [6]:
# --- Make a chain of matrices: shape (T, N, N) ------------------------
T, N = 10_000, 20
key = jax.random.PRNGKey(0)
mats = 20*jax.random.normal(key, (T, N, N), dtype=config.float_dtype)/(jnp.sqrt(N))


# --- 1. Parallel product over *float* tensors (usually blows up) ------
def matmul_op(a, b):
    # a, b: (..., N, N)
    return a @ b

# lax.associative_scan does a parallel prefix of an associative op
float_prefix_products = lax.associative_scan(matmul_op, mats, axis=0)

# Final product of the whole chain:
float_prod = float_prefix_products[-1]

print("Computes over float tensors?",
      bool(jnp.isfinite(float_prod).all()))


# --- 2. Parallel product over GOOMs -----------------------------------
# Turn each matrix into its GOOM "log" representation
log_mats = goom.goom.to_goom(mats)   # shape (T, N, N), complex dtype


# Parallel scan with the GOOM matmul log kernel
log_prefix_products = lax.associative_scan(lmme.log_matmul_exp, log_mats, axis=0)

# Final log-product of the whole chain in GOOM space:
log_prod = log_prefix_products[-1]

# Map back from GOOMs to (approximate) real-space product
prod_goom = goom.goom.from_goom(log_prod)

print("Computes over complex GOOMs?",
      bool(jnp.isfinite(log_prod).all()))


Computes over float tensors? False
Computes over complex GOOMs? True


In [7]:
# compare LLE computed in paralle to sequential algorithm
lle_est_par = lle.jax_estimate_lle_parallel(mats, key, dt=1.0)
lle_est_seq = lle.jax_estimate_lle_sequential(mats, key, dt=1.0)

print("Parallel LLE vs Sequential LLE Close?:",
      bool(jnp.isclose(lle_est_par, lle_est_seq, atol=1e-3, rtol=1e-3)))

Parallel LLE vs Sequential LLE Close?: True
