# Example usage of `u_tensor_algebra.py`

Suppose we want to derive the expression for the unbiased estimator of

$$\sum_{i_1, i_2} \sum_{a_1,a_2} X_{i_1,a_1}X_{i_2,a_1}X_{i_1,a_2}X_{i_2,a_2} $$

which, in this example, is the second spectral moment.
The unbiased version is obtained by restricting each sum to distinct indices:

$$\sum_{i_1 \neq i_2} \sum_{a_1 \neq a_2} X_{i_1,a_1}X_{i_2,a_1}X_{i_1,a_2}X_{i_2,a_2} $$

Calling `get_unbiased_einsums` derives an expression for the above sum in terms of "regular sums" (i.e., sums with no distinct-index constraint).
The expression is written in einsum notation.
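
To make the idea of rewriting a distinct-index sum as regular sums concrete, here is a minimal NumPy sketch for the example above, without any centering or normalization. It applies inclusion-exclusion by hand, using $(1-\delta_{i_1 i_2})(1-\delta_{a_1 a_2}) = 1 - \delta_{i_1 i_2} - \delta_{a_1 a_2} + \delta_{i_1 i_2}\delta_{a_1 a_2}$, and checks the result against a brute-force loop. This is an illustration only, not the derivation performed inside `get_unbiased_einsums`; the function names `distinct_sum_bruteforce` and `distinct_sum_einsum` are hypothetical.

```python
import itertools
import numpy as np

def distinct_sum_bruteforce(X):
    """Sum X[i1,a1]*X[i2,a1]*X[i1,a2]*X[i2,a2] over i1 != i2 and a1 != a2."""
    P, Q = X.shape
    total = 0.0
    for i1, i2 in itertools.permutations(range(P), 2):
        for a1, a2 in itertools.permutations(range(Q), 2):
            total += X[i1, a1] * X[i2, a1] * X[i1, a2] * X[i2, a2]
    return total

def distinct_sum_einsum(X):
    """Same quantity via inclusion-exclusion over unconstrained einsums."""
    full = np.einsum('ia,ja,ib,jb->', X, X, X, X)        # no constraint
    rows_equal = np.einsum('ia,ia,ib,ib->', X, X, X, X)  # i1 == i2
    cols_equal = np.einsum('ia,ja,ia,ja->', X, X, X, X)  # a1 == a2
    both_equal = np.einsum('ia,ia,ia,ia->', X, X, X, X)  # i1 == i2 and a1 == a2
    return full - rows_equal - cols_equal + both_equal

# Small random matrix (hypothetical test data) so the brute-force loop stays cheap.
rng = np.random.default_rng(0)
X = rng.standard_normal((5, 6))
print(np.isclose(distinct_sum_bruteforce(X), distinct_sum_einsum(X)))
```

The brute-force version costs on the order of $P^2 Q^2$ operations, whereas each einsum term is an unconstrained contraction that can be evaluated much more cheaply, which is the practical motivation for expressing the estimator this way.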

The expression can then be passed to `compute_estimate`, along with an actual dataset, to compute the unbiased estimate.


By Chanwoo Chun, Mar. 2025

In [9]:
import u_tensor_algebra as uta

# Define the matrix names and the summation indices each one carries.
all_indices = {
    'X1': ('i1','a1'),
    'X2': ('i2','a1'),
    'X3': ('i2','a2'),
    'X4': ('i1','a2')}

# The instruction below means that we want to center all columns individually.
centerings = {
    'X1': ('','c'),
    'X2': ('','c'),
    'X3': ('','c'),
    'X4': ('','c')}

# Indices within each group must take distinct values (i1 != i2, a1 != a2).
dist_groups = [('i1', 'i2'), ('a1', 'a2')]

# Compute the expression of the unbiased estimator
estimator_formula = uta.get_unbiased_einsums(all_indices, centerings, dist_groups)


Distinct indices:  [('0', '2', '3', '5'), ('1', '4')]
Add or Subtract:  1
Formula:  {'X1': ('0', '1'), 'X2': ('2', '1'), 'X3': ('3', '4'), 'X4': ('5', '4')}

Distinct indices:  [('0', '2', '4'), ('1', '3')]
Add or Subtract:  -1
Formula:  {'X1': ('0', '1'), 'X2': ('2', '1'), 'X3': ('2', '3'), 'X4': ('4', '3')}

Distinct indices:  [('0', '2', '3'), ('1', '4')]
Add or Subtract:  -1
Formula:  {'X1': ('0', '1'), 'X2': ('2', '1'), 'X3': ('3', '4'), 'X4': ('0', '4')}

Distinct indices:  [('0', '2'), ('1', '3')]
Add or Subtract:  1
Formula:  {'X1': ('0', '1'), 'X2': ('2', '1'), 'X3': ('2', '3'), 'X4': ('0', '3')}
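
Each printed `Formula` maps a matrix name to a pair of index labels, and a label shared by two matrices indicates a shared summation index. As a rough illustration of how such a dictionary corresponds to einsum notation, the sketch below converts the last formula above into a subscript string and evaluates it with `numpy.einsum`. The helper `formula_to_einsum` is hypothetical and not part of `u_tensor_algebra`; it ignores the sign, the "Distinct indices" groups, and any centering or normalization, and how `compute_estimate` actually evaluates each term is not shown in the output.

```python
import string
import numpy as np

def formula_to_einsum(formula):
    """Map index labels to letters and build an einsum string summing over all indices."""
    letters = {}
    operands = []
    for name, labels in formula.items():
        subs = ''
        for label in labels:
            # Assign the next unused letter the first time a label appears.
            if label not in letters:
                letters[label] = string.ascii_lowercase[len(letters)]
            subs += letters[label]
        operands.append(subs)
    return ','.join(operands) + '->'

formula = {'X1': ('0', '1'), 'X2': ('2', '1'), 'X3': ('2', '3'), 'X4': ('0', '3')}
subscripts = formula_to_einsum(formula)
print(subscripts)  # ab,cb,cd,ad->

# Evaluate the contraction on hypothetical random data, using the same matrix for every name.
rng = np.random.default_rng(0)
X = rng.standard_normal((5, 6))
data = {name: X for name in formula}
value = np.einsum(subscripts, *(data[name] for name in formula))
```

For this particular formula the contraction equals the squared Frobenius norm of `X @ X.T`, which gives a quick sanity check on the subscript string.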



In [10]:
import numpy as np
import jax.numpy as jnp

np.random.seed(10)

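P = 150  # number of rows of the data matrix
Q = 200  # number of columns of the data matrix
d = 4    # inner dimension of the random factors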

Xin = jnp.array(np.random.randn(P,d))
W = jnp.array(np.random.randn(d,Q))
W2 = jnp.array(np.random.randn(d,Q))
Xa = jnp.square(jnp.matmul(Xin, W) / d)
# Xb is built from an independent weight matrix; it is not used in this example.
Xb = jnp.square(jnp.matmul(Xin, W2) / d)

# Assign a data matrix to each matrix name. Here every name gets the same matrix.
factor_data = {'X1': Xa, 'X2': Xa, 'X3': Xa, 'X4': Xa}

# Compute the estimate
estimate = uta.compute_estimate(factor_data, estimator_formula)


{'X1': (150, 200), 'X2': (150, 200), 'X3': (150, 200), 'X4': (150, 200)}

0.0037856258757100747
-0.005984057701624579
-0.005984057701624579
0.010373302546213381

Total number of terms:  164
Estimate:  0.002190813018674298
