# goom tests
Use this notebook to run pytest from the repo root with the `goom_jax` conda env active.

In [1]:
!pwd
!git clone https://github.com/hkozachkov/goom.git

/content
fatal: destination path 'goom' already exists and is not an empty directory.


In [4]:
%cd goom
%pip install -e .
!pytest


[Errno 2] No such file or directory: 'goom'
/content/goom
Obtaining file:///content/goom
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: goom
  Building editable for goom (pyproject.toml) ... [?25l[?25hdone
  Created wheel for goom: filename=goom-0.1.0-py3-none-any.whl size=2158 sha256=153af367e8615cbbb659647d3fcb66ad8476de13e6bbd569b241d77cd738b67c
  Stored in directory: /tmp/pip-ephem-wheel-cache-nwtgs328/wheels/a3/25/cc/3cf8fcc9e6f3c09d274a4bd56f0a22b3af69533fc9ac1cb798
Successfully built goom
Installing collected packages: goom
  Attempting uninstall: goom
    Found existing installation: goom 0.1.0
    Uninstalling goom-0.1.0:
      Successfully uninstalled goom-0.1.0
Successfully 

In [1]:
import jax
import jax.numpy as jnp
from jax import lax

import goom.goom as goom
import goom.lmme as lmme
import goom.operations as oprs
import goom.utils as utils
import goom.lle as lle

config = goom.config  # grab the config object from the goom module

config.keep_logs_finite = True          # log(0) will return a finite floor
config.cast_all_logs_to_complex = True  # GOOMs are complex-typed
config.float_dtype = jnp.float32        # real dtype



%load_ext autoreload
%autoreload 2




ModuleNotFoundError: No module named 'goom.utils'

In [None]:
# --- Make a chain of matrices: shape (T, N, N) ------------------------
T, N = 10_000, 20
key = jax.random.PRNGKey(0)
mats = 20*jax.random.normal(key, (T, N, N), dtype=config.float_dtype)/(jnp.sqrt(N))


# --- 1. Parallel product over *float* tensors (usually blows up) ------
def matmul_op(a, b):
    # a, b: (..., N, N)
    return a @ b

# lax.associative_scan does a parallel prefix of an associative op
float_prefix_products = lax.associative_scan(matmul_op, mats, axis=0)

# Final product of the whole chain:
float_prod = float_prefix_products[-1]

print("Computes over float tensors?",
      bool(jnp.isfinite(float_prod).all()))


# --- 2. Parallel product over GOOMs -----------------------------------
# Turn each matrix into its GOOM "log" representation
log_mats = goom.to_goom(mats)   # shape (T, N, N), complex dtype


# Parallel scan with the GOOM matmul log kernel
log_prefix_products = lax.associative_scan(lmme.log_matmul_exp, log_mats, axis=0)

# Final log-product of the whole chain in GOOM space:
log_prod = log_prefix_products[-1]

# Map back from GOOMs to (approximate) real-space product
prod_goom = goom.from_goom(log_prod)

print("Computes over complex GOOMs?",
      bool(jnp.isfinite(log_prod).all()))


Computes over float tensors? False
Computes over complex GOOMs? True


In [None]:
# compare LLE computed in paralle to sequential algorithm
lle_est_par = lle.jax_estimate_lle_parallel(mats, key, dt=1.0)
lle_est_seq = lle.jax_estimate_lle_sequential(mats, key, dt=1.0)
lle_ratio = lle_est_par/lle_est_seq



print("Parallel LLE vs Sequential LLE Close?:",
      bool(jnp.isclose(lle_ratio, 1.0, atol=1e-3, rtol=0.0)))

Parallel LLE vs Sequential LLE Close?: True
