<a href="https://colab.research.google.com/github/jcandane/cxfel_work/blob/main/JAX_N2_dist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax
import jax.numpy as jnp

# PRNG key with seed 10, used to draw a small 12 x 8 random test matrix.
key = jax.random.key(10)

R_ix = jax.random.uniform(key, shape=(12, 8))

## Key takeaway: the whole computation cannot be parallelized at once, otherwise the memory allocation becomes too large (5/26/24)

### Therefore, nest a parallel loop inside a sequential one instead, e.g. `lax.scan` over rows with `vmap` within each row.

In [None]:
try:
    import gpjax as gpx
except:
    !pip install gpjax==0.8.2
    import gpjax as gpx

In [None]:
# GPJax building blocks: a zero-mean GP prior with an RBF kernel ...
mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.RBF()
# ... and a Gaussian likelihood.
# NOTE(review): num_datapoints=123 looks like a placeholder; confirm it
# matches the size of the training set.
likelihood = gpx.likelihoods.Gaussian(num_datapoints=123)
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)

# Combining prior and likelihood yields the posterior object.
posterior = prior * likelihood

### row-wise algorithm

In [None]:
import jax
import jax.numpy as jnp

## kernel computation
@jax.jit
def kernel(v_x, w_x, length_scale=1.2):
    """RBF (squared-exponential) kernel between two feature/pixel vectors.

        k(v, w) = exp(-||v - w||^2 / length_scale^2)

    The squared distance is summed directly instead of squaring
    ``jnp.linalg.norm``. This skips a redundant sqrt and keeps the gradient
    finite when ``v_x == w_x`` (the gradient of the norm is NaN at zero).

    Args:
        v_x: real feature vector.
        w_x: real feature vector with the same shape as ``v_x``.
        length_scale: RBF length-scale hyperparameter (default 1.2).

    Returns:
        Scalar kernel value in [0, 1].
    """
    sq_dist = jnp.sum((v_x - w_x) ** 2)
    return jnp.exp(-sq_dist / length_scale ** 2)
kernel_matrixelement = jax.vmap(kernel, in_axes=(None, 0))  # maps over the rows of w_x

@jax.jit
def covariance_vector(v_x, R_ix):
    """One row of the covariance matrix.

    covariance_vector(v_x, R_ix)[i] = k(v_x, R_ix[i]) for every row i of R_ix.
    """
    return kernel_matrixelement(v_x, R_ix)  # over all rows i
# Maps covariance_vector over the rows of the first argument:
# covariance_matrix(A, B)[j, i] = k(A[j], B[i]).
covariance_matrix = jax.vmap(covariance_vector, in_axes=(0, None))

### Example input: 120 samples x 80 features (2-D jax.Array)
R_ix = jax.random.uniform(key, shape=(120, 80))

### RBF covariance matrix between every pair of rows, shape (120, 120)
R_ij = covariance_matrix(R_ix, R_ix)

In [None]:
R_ix = jax.random.uniform(key, (12000,80))
%timeit R_ij = covariance_matrix(R_ix, R_ix)

22.4 ms ± 12 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
### Full PYP dataset: 152677 snapshots x 21556 pixels.
### Here only the first 6000 snapshots' worth of random data is used.
R_ix = jax.random.uniform(key, shape=(6000, 21556))
covariance_matrix(R_ix, R_ix)

Array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]], dtype=float32)

In [None]:
### PYP DATASET:: 152677 = snapshots & 21556 = pixels
R_ix = jax.random.uniform(key, (2000,21556))
%timeit covariance_matrix(R_ix, R_ix)

256 ms ± 11.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
# Size in GiB of a dense 100000 x 100000 float32 matrix (4 bytes per entry).
100_000 * 100_000 * 4 / 1024**3, "GB"

(37.25290298461914, 'GB')

A related report of this memory problem (high memory consumption with nested `vmap`): https://stackoverflow.com/questions/76109349/high-memory-consumption-in-jax-with-nested-vmap

In [None]:
import jax
import jax.numpy as jnp
from jax import lax

# Kernel computation
@jax.jit
def kernel(v_x, w_x, length_scale=1.2):
    """
    RBF (squared-exponential) kernel between two feature/pixel vectors:

        k(v, w) = exp(-||v - w||^2 / length_scale^2)

    The squared distance is summed directly, not obtained by squaring
    jnp.linalg.norm. This avoids a redundant sqrt and keeps the gradient
    finite when v_x == w_x, where the gradient of the norm is NaN.
    length_scale is the RBF length-scale hyperparameter (default 1.2).
    """
    sq_dist = jnp.sum((v_x - w_x) ** 2)
    return jnp.exp(-sq_dist / length_scale ** 2)

kernel_matrixelement = jax.vmap(kernel, in_axes=(None, 0))  # maps over the rows of w_x

@jax.jit
def covariance_vector(v_x, R_ix):
    """
    One row of the covariance matrix:
    covariance_vector(v_x, R_ix)[i] = k(v_x, R_ix[i]) for every row i.
    """
    return kernel_matrixelement(v_x, R_ix)  # over all rows i

# Body of the fori_loop: fills row i of the covariance matrix.
def covariance_matrix_element(i, R_ij, R_ix):
    """Return a copy of R_ij whose row i is covariance_vector(R_ix[i], R_ix)."""
    return R_ij.at[i].set(covariance_vector(R_ix[i], R_ix))

@jax.jit
def covariance_matrix(R_ix):
    """
    Full (N, N) covariance matrix of the rows of R_ix, computed one row per
    lax.fori_loop iteration, so only one row's kernel evaluations run at a time.
    """
    N = R_ix.shape[0]
    R_ij = jnp.zeros((N, N))
    R_ij = lax.fori_loop(0, N, lambda i, acc: covariance_matrix_element(i, acc, R_ix), R_ij)
    return R_ij

# Small example input: 120 samples x 80 features.
key = jax.random.PRNGKey(0)
R_ix = jax.random.uniform(key, shape=(120, 80))

# Full covariance matrix via covariance_matrix (row-by-row fori_loop).
R_ij = covariance_matrix(R_ix)

print("R_ix:\n", R_ix)
print(jnp.shape(R_ij))


R_ix:
 [[0.0176878  0.5916569  0.89398885 ... 0.6814302  0.29505873 0.8417723 ]
 [0.48681045 0.28730118 0.67160344 ... 0.37929893 0.09927273 0.3082807 ]
 [0.81797326 0.22183383 0.21658444 ... 0.4808247  0.70558167 0.2004193 ]
 ...
 [0.5076928  0.6204833  0.5446193  ... 0.3674395  0.54845095 0.643455  ]
 [0.9651996  0.5978706  0.03186083 ... 0.84701216 0.69112897 0.52002454]
 [0.18267298 0.7189157  0.01269352 ... 0.7409458  0.36291182 0.47115946]]
(120, 120)


In [None]:
R_ix = jax.random.uniform(key, (40000,21556))
%timeit covariance_matrix(R_ix)

Array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]], dtype=float32)



1.   $N^2$ algorithm with a purely dense covariance matrix
2.   $N^2$ algorithm with a mixed dense-sparse covariance matrix
3.   $N^2$ algorithm with a purely sparse covariance matrix

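If the cost is $\sim N^2 D + N^2$ operations, then for $N = 1.5\times10^{5}$ snapshots and $D = 2\times10^{4}$ pixels this is $(1.5\times10^{5})^2 \cdot 2\times10^{4} \approx 4.5\times10^{14}$ operations.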



### Using `jax.lax.scan`

In [None]:
### dense N^2 distance/covariance-matrix calculation in jax by j. candanedo

import jax
import jax.numpy as jnp

@jax.jit
def kernel(v_x, w_x, length_scale=1.2):
    """ Kernel Computation
    RBF (squared-exponential) kernel between two feature/pixel vectors:

        k(v, w) = exp(-||v - w||^2 / length_scale^2)

    The squared distance is summed directly rather than squaring
    jnp.linalg.norm. This avoids a redundant sqrt and keeps the gradient
    finite on the diagonal (v_x == w_x), where the gradient of the norm is NaN.
    """
    sq_dist = jnp.sum((v_x - w_x) ** 2)
    return jnp.exp(-sq_dist / length_scale ** 2)
kernel_matrixelement = jax.vmap(kernel, in_axes=(None, 0))  # maps over the rows of w_x

@jax.jit
def covariance_vector(v_x, R_ix):
    """ compute 1 row of the covariance-matrix
    covariance_vector(v_x, R_ix)[i] = k(v_x, R_ix[i]) for every row i
    """
    return kernel_matrixelement(v_x, R_ix)  # over all "i"s

@jax.jit
def update_covariance_matrix(carry, i):
    """ scan body: overwrite row "i" of the covariance-matrix
    carry = (R_ij, R_ix): the (N, N) matrix being filled and the data whose
    rows are paired with each other; R_ix is passed through unchanged.
    Returns ((R_ij with row i set, R_ix), None).
    """
    R_ij, R_ix = carry
    R_ij = R_ij.at[i].set(covariance_vector(R_ix[i], R_ix))
    return (R_ij, R_ix), None

@jax.jit
def covariance_matrix(R_ix):
    """ compute the entire covariance-matrix
    GIVEN > R_ix
    GET > Σ_ij, shape (N, N), Σ_ij = k(R_ix[i], R_ix[j])
    """
    N    = R_ix.shape[0]
    Σ_ij = jnp.zeros((N, N)) ### preallocated covariance-matrix, filled one row per scan step
    (Σ_ij, _), _ = jax.lax.scan(update_covariance_matrix, (Σ_ij, R_ix), jnp.arange(N))
    return Σ_ij

###################
@jax.jit
def covariance_matrixX(R_ix, R_jx):
    """ compute the entire cross-covariance-matrix
    GIVEN > R_ix : (N, D), R_jx : (M, D) jnp.ndarray
    GET > Σ_ij, shape (N, M), Σ_ij = k(R_ix[i], R_jx[j])
    """
    # Scan over the rows of R_ix; each step emits one row against all of R_jx
    # and scan stacks those rows into the (N, M) result.
    def _row(carry, v_x):
        return carry, covariance_vector(v_x, R_jx)
    _, Σ_ij = jax.lax.scan(_row, None, R_ix)
    return Σ_ij

## !! TODO: benchmark just initializing the array and then mutating it; time this to see whether it matches the theoretical FLOP count

### example

In [None]:
key = jax.random.PRNGKey(7)
R_ix = jax.random.uniform(key, shape=(120, 80))  # 2-D jax.Array: 120 samples x 80 features
R_ij = covariance_matrix(R_ix)  # (120, 120) covariance matrix

In [None]:
key = jax.random.PRNGKey(7)
# 40000 snapshots x 21556 pixels; the author measured ~101 s on an A100 for this shape.
R_ix = jax.random.uniform(key, shape=(40000, 21556))
R_ij = covariance_matrix(R_ix)

In [None]:
jnp.shape(R_ij)  # expected (N, N) for N input rows

(40000, 40000)

### Another variant: `lax.fori_loop` with a row-update body

In [None]:
import jax
import jax.numpy as jnp
from jax import lax

# Kernel computation
@jax.jit
def kernel(v_x, w_x, length_scale=1.2):
    """
    RBF (squared-exponential) kernel between two feature/pixel vectors:

        k(v, w) = exp(-||v - w||^2 / length_scale^2)

    The squared distance is summed directly, not obtained by squaring
    jnp.linalg.norm. This avoids a redundant sqrt and keeps the gradient
    finite when v_x == w_x, where the gradient of the norm is NaN.
    length_scale is the RBF length-scale hyperparameter (default 1.2).
    """
    sq_dist = jnp.sum((v_x - w_x) ** 2)
    return jnp.exp(-sq_dist / length_scale ** 2)

kernel_matrixelement = jax.vmap(kernel, in_axes=(None, 0))  # maps over the rows of w_x

@jax.jit
def covariance_vector(v_x, R_ix):
    """
    One row of the covariance matrix:
    covariance_vector(v_x, R_ix)[i] = k(v_x, R_ix[i]) for every row i.
    """
    return kernel_matrixelement(v_x, R_ix)  # over all rows i

# Function to update each row of the covariance matrix
def update_covariance_matrix(i, R_ij, R_ix):
    """Return a copy of R_ij whose row i is covariance_vector(R_ix[i], R_ix)."""
    return R_ij.at[i].set(covariance_vector(R_ix[i], R_ix))

@jax.jit
def covariance_matrix(R_ix):
    """
    Full (N, N) covariance matrix of the rows of R_ix, filled one row per
    lax.fori_loop iteration.
    """
    N = R_ix.shape[0]
    R_ij = jnp.zeros((N, N))
    # fori_loop's body must have signature body_fun(i, val); R_ix is closed over
    # and only the matrix is carried through the loop.
    R_ij = lax.fori_loop(0, N, lambda i, acc: update_covariance_matrix(i, acc, R_ix), R_ij)
    return R_ij

# Small example input: 120 samples x 80 features.
key = jax.random.PRNGKey(0)
R_ix = jax.random.uniform(key, shape=(120, 80))

# Covariance matrix between all rows of R_ix. The printed label says
# "Distance", but the entries are RBF kernel values.
R_ij = covariance_matrix(R_ix)

print("R_ix:\n", R_ix)
print("Distance matrix R_ij:\n", R_ij)


NameError: name 'i' is not defined

### Fully element-wise algorithm (nested `vmap`)

In [None]:
import jax
import jax.numpy as jnp
from jax import jit, vmap

# Function to compute the pairwise distance between two vectors
@jit
def pairwise_distance(vec1, vec2):
    """Euclidean distance ||vec1 - vec2|| between two feature vectors."""
    return jnp.linalg.norm(vec1 - vec2)

# Fully element-wise nested vmap over (rows of arg 1) x (rows of arg 2).
# On large inputs XLA can materialize enormous N x N x (partial D)
# intermediates, so compute_distance_matrix does not use it.
pairwise_distance_vmap = vmap(vmap(pairwise_distance, in_axes=(None, 0)), in_axes=(0, None))

# Distances from one vector to every row of a matrix.
_pairwise_distance_row = vmap(pairwise_distance, in_axes=(None, 0))

def compute_distance_matrix(R):
    """Full (N, N) Euclidean distance matrix between the rows of R.

    Rows are processed sequentially with ``jax.lax.map``, with a ``vmap``
    inside each row. Each step therefore only needs one row-vs-all-rows
    difference instead of the whole pairwise difference tensor.
    """
    R = jnp.asarray(R)
    return jax.lax.map(lambda r: _pairwise_distance_row(r, R), R)

# Tiny example: three 2-D points, one per row.
R_ix = jnp.array([
    [1.0, 2.0],
    [3.0, 4.0],
    [5.0, 6.0],
])

# Euclidean distance between every pair of rows.
R_ij = compute_distance_matrix(R_ix)

print("R_ix:\n", R_ix)
print("Distance matrix R_ij:\n", R_ij)


R_ix:
 [[1. 2.]
 [3. 4.]
 [5. 6.]]
Distance matrix R_ij:
 [[0.       2.828427 5.656854]
 [2.828427 0.       2.828427]
 [5.656854 2.828427 0.      ]]


In [None]:
# Large random input: 8000 snapshots x 21556 pixels.
R_ix = jax.random.uniform(key, shape=(8000, 21556))
l = compute_distance_matrix(R_ix)

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 39424000000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    1.28GiB
              constant allocation:         0B
        maybe_live_out allocation:  244.14MiB
     preallocated temp allocation:   36.72GiB
  preallocated temp fragmentation:         0B (0.00%)
                 total allocation:   38.24GiB
              total fragmentation:        12B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 36.72GiB
		Operator: op_name="jit(pairwise_distance)/jit(main)/jit(norm)/reduce_sum[axes=(2,)]" source_file="<ipython-input-2-95acb946888c>" source_line=13
		XLA Label: fusion
		Shape: f32[8000,8000,154]
		==========================

	Buffer 2:
		Size: 657.84MiB
		Entry Parameter Subshape: f32[8000,21556]
		==========================

	Buffer 3:
		Size: 657.84MiB
		Entry Parameter Subshape: f32[8000,21556]
		==========================

	Buffer 4:
		Size: 244.14MiB
		Operator: op_name="jit(pairwise_distance)/jit(main)/jit(norm)/sqrt" source_file="<ipython-input-2-95acb946888c>" source_line=13
		XLA Label: fusion
		Shape: f32[8000,8000]
		==========================

	Buffer 5:
		Size: 4B
		XLA Label: parameter
		Shape: f32[]
		==========================

	Buffer 6:
		Size: 4B
		XLA Label: parameter
		Shape: f32[]
		==========================

	Buffer 7:
		Size: 4B
		Operator: op_name="jit(pairwise_distance)/jit(main)/jit(norm)/reduce_sum[axes=(2,)]" source_file="<ipython-input-2-95acb946888c>" source_line=13
		XLA Label: add
		Shape: f32[]
		==========================

