In [1]:
import einops

from functools import partial
from itertools import product
import numpy as np
from pathlib import Path
from plotnine import (
    ggplot,
    geom_point, 
    geom_histogram, 
    geom_line,
    geom_ribbon,
    qplot, 
    coord_fixed, 
    aes, 
    facet_wrap, 
    labs,
    scale_x_log10,
    scale_y_log10
)
import polars as pl
import torch

from tokengrams import MemmapIndex, InMemoryIndex
from tqdm.notebook import tqdm, trange
from transformers import AutoTokenizer
from transformer_lens import HookedTransformerConfig

from ngram_markov.hooked_transformer import HookedTransformer
from torch.nn.functional import softmax, log_softmax


import collections
from collections import defaultdict
from itertools import islice
import numpy as np
import rustworkx as rx


from ngram_markov.model import GPT, GPTConfig
from ngram_markov.utils import create_ngrams, nanogpt_to_hooked_transformer_config, convert_nanogpt_weights
import einops
import torch
import plotly.express as px



In [None]:
import numpy as np
from tqdm import tqdm

def efron_stein_decomposition(tensor):
    """
    Perform Efron-Stein decomposition on a 3D tensor representing a single output logit.
    
    :param tensor: numpy array of shape (512, 512, 512)
    :return: tuple of (zeroth_order, first_order, second_order, third_order)
    """
    # Zeroth-order effect (mean)
    zeroth_order = np.mean(tensor)
    
    # First-order effects
    first_order = {
        0: np.mean(tensor, axis=(1, 2)) - zeroth_order,
        1: np.mean(tensor, axis=(0, 2)) - zeroth_order,
        2: np.mean(tensor, axis=(0, 1)) - zeroth_order
    }
    
    # Second-order effects
    second_order = {
        (0, 1): np.mean(tensor, axis=2) - first_order[0][:, np.newaxis] - first_order[1][np.newaxis, :] - zeroth_order,
        (0, 2): np.mean(tensor, axis=1) - first_order[0][:, np.newaxis] - first_order[2][np.newaxis, :] - zeroth_order,
        (1, 2): np.mean(tensor, axis=0) - first_order[1][:, np.newaxis] - first_order[2][np.newaxis, :] - zeroth_order
    }
    
    # Third-order effect
    third_order = (tensor - 
                   zeroth_order - 
                   first_order[0][:, np.newaxis, np.newaxis] -
                   first_order[1][np.newaxis, :, np.newaxis] -
                   first_order[2][np.newaxis, np.newaxis, :] -
                   second_order[(0, 1)][:, :, np.newaxis] -
                   second_order[(0, 2)][:, np.newaxis, :] -
                   second_order[(1, 2)][np.newaxis, :, :])
    
    return zeroth_order, first_order, second_order, third_order

def check_orthogonality(first, second, third, tolerance=1e-6):
    def expand_to_full(component, order, shape):
        full = np.zeros(shape)
        new_dims = tuple([i for i in range(3) if i not in order])
        if len(order) < len(shape):
            full += np.expand_dims(component, axis=new_dims)
        else:
            full = component
        return full

    shape = (512, 512, 512)
    components = [
        expand_to_full(first[0], (0,), shape),
        expand_to_full(first[1], (1,), shape),
        expand_to_full(first[2], (2,), shape),
        expand_to_full(second[(0,1)], (0,1), shape),
        expand_to_full(second[(0,2)], (0,2), shape),
        expand_to_full(second[(1,2)], (1,2), shape),
        third
    ]
    
    # Flatten each component
    flat_components = [comp.flatten() for comp in components]
    
    for i in tqdm(range(len(flat_components)), desc="Checking orthogonality"):
        for j in range(i+1, len(flat_components)):
            dot_product = np.dot(flat_components[i], flat_components[j])
            assert np.abs(dot_product) < tolerance, f"Components {i} and {j} are not orthogonal. Dot product: {dot_product}"
    
    print("All components are orthogonal in the function space.")

def check_reconstruction(tensor, zeroth_order, first_order, second_order, third_order):
    reconstructed = (zeroth_order + 
                     first_order[0][:, np.newaxis, np.newaxis] +
                     first_order[1][np.newaxis, :, np.newaxis] +
                     first_order[2][np.newaxis, np.newaxis, :] +
                     second_order[(0,1)][:, :, np.newaxis] +
                     second_order[(0,2)][:, np.newaxis, :] +
                     second_order[(1,2)][np.newaxis, :, :] +
                     third_order)
    max_error = np.max(np.abs(tensor - reconstructed))
    print(f"Maximum reconstruction error: {max_error}")
    assert np.allclose(tensor, reconstructed), "Reconstruction failed"

def check_variances(tensor, first_order, second_order, third_order):
    total_var = np.var(tensor)
    component_vars = (np.var(first_order[0]) + np.var(first_order[1]) + np.var(first_order[2]) +
                      np.var(second_order[(0,1)]) + np.var(second_order[(0,2)]) + np.var(second_order[(1,2)]) +
                      np.var(third_order))
    print(f"Total variance: {total_var}")
    print(f"Sum of component variances: {component_vars}")
    assert np.allclose(total_var, component_vars), "Variance decomposition failed"


ngram_n = 3
epoch = 53_000
num_tokens = 512

mm_array = np.memmap(
    f'ngram_{ngram_n}_outputs_epoch_{epoch}.npy',
    dtype='float32', 
    mode='r',
    shape=(num_tokens, num_tokens, num_tokens, num_tokens)
)

data = np.copy(mm_array[..., 2])


# Perform Efron-Stein decomposition
zeroth_order, first_order, second_order, third_order = efron_stein_decomposition(data)

# Run checks
check_orthogonality(first_order, second_order, third_order)
check_reconstruction(data, zeroth_order, first_order, second_order, third_order)
check_variances(data, first_order, second_order, third_order)

print("All checks passed successfully!")

In [38]:
zo_ground_truth =  -8. / 9

fo_ground_truth = {
    0: np.array([-1., -1., 0.]),
    1: np.array([0., -1., -1.]),
    2: np.array([-1., 0., -1.]),
}

so_ground_truth = {
    (0, 1): np.eye(3) * 4.,
    (0, 2): np.zeros((3, 3)),
    (1, 2): np.eye(3) * 4.
    
}

to_ground_truth = np.zeros((3, 3, 3))
to_ground_truth[0, 1, 2] = 2.
to_ground_truth[2, 0, 1] = 2.
to_ground_truth[1, 2, 0] = 2.


In [49]:
ground_truth = ( 
    fo_ground_truth[0][:, None, None] + fo_ground_truth[1][None, :, None] + fo_ground_truth[2][None, None, :] + 
    so_ground_truth[(0,1)][:, :, None] + so_ground_truth[(0,2)][:, None, :] + so_ground_truth[(1,2)][None, :, :] + 
    to_ground_truth
)

In [51]:
np.allclose(ground_truth, test_tensor)

True

In [89]:
import numpy as np
import itertools


zo_ground_truth = 8. / 9

fo_ground_truth = {
    0: np.array([-1., -1., 0.]),
    1: np.array([0., -1., -1.]),
    2: np.array([-1., 0., -1.])
}

so_ground_truth = {
    (0, 1): np.eye(3) * 4.,
    (0, 2): np.zeros((3, 3)),
    (1, 2): np.eye(3) * 4.
    
}

to_ground_truth = np.zeros((3, 3, 3))
to_ground_truth[0, 1, 2] = 2.
to_ground_truth[2, 0, 1] = 2.
to_ground_truth[1, 2, 0] = 2.

ground_truth = (
    fo_ground_truth[0][:, None, None] + fo_ground_truth[1][None, :, None] + fo_ground_truth[2][None, None, :] + 
    so_ground_truth[(0,1)][:, :, None] + so_ground_truth[(0,2)][:, None, :] + so_ground_truth[(1,2)][None, :, :] + 
    to_ground_truth
)


def simple_discrete_function(x0, x1, x2):
    """
    A simple function with discrete inputs (0 or 1 for each position).
    This function includes main effects, two-way interactions, and a three-way interaction.
    """
    first_order = fo_ground_truth[0][x0] + fo_ground_truth[1][x1] + fo_ground_truth[2][x2]
    second_order = so_ground_truth[(0, 1)][x0, x1] + so_ground_truth[(0, 2)][x0, x1] + so_ground_truth[(1, 2)][x1, x2]
    third_order = to_ground_truth[x0, x1, x2]
    return first_order + second_order + third_order

def create_test_tensor(size=3):
    """
    Create a test tensor using all possible combinations of 0 and 1 for each position.
    """
    tensor = np.zeros((size, size, size))
    for x, y, z in itertools.product(range(size), repeat=3):
        tensor[x, y, z] = simple_discrete_function(x, y, z)
    return tensor

def efron_stein_decomposition(tensor):
    # Implement your Efron-Stein decomposition here
    # This is where you'll put your actual implementation
    pass


def check_means(zeroth, first, second, third):
    assert np.allclose(np.mean(first[0]), 0)
    assert np.allclose(np.mean(first[1]), 0)
    assert np.allclose(np.mean(first[2]), 0)
    assert np.allclose(np.mean(second[(0,1)]), 0)
    assert np.allclose(np.mean(second[(0,2)]), 0)
    assert np.allclose(np.mean(second[(1,2)]), 0)
    assert np.allclose(np.mean(third), 0)
    print("All component means (except zeroth-order) are zero.")


def expand_to_full(component, order, shape):
    full = np.zeros(shape)
    new_dims = tuple([i for i in range(3) if i not in order])
    if len(order) < len(shape):
        full += np.expand_dims(component, axis=new_dims)
    else:
        full = component
    return full



def check_orthogonality(first, second, third, tolerance=1e-6):
  

    shape = (3, 3, 3)
    components = [
        expand_to_full(first[0], (0,), shape),
        expand_to_full(first[1], (1,), shape),
        expand_to_full(first[2], (2,), shape),
        expand_to_full(second[(0,1)], (0,1), shape),
        expand_to_full(second[(0,2)], (0,2), shape),
        expand_to_full(second[(1,2)], (1,2), shape),
        third
    ]
    
    # Flatten each component
    flat_components = [comp.flatten() for comp in components]
    
    for i in range(len(flat_components)):
        for j in range(i+1, len(flat_components)):
            dot_product = np.dot(flat_components[i], flat_components[j])
            assert np.abs(dot_product) < tolerance, f"Components {i} and {j} are not orthogonal. Dot product: {dot_product}"
    
    print("All components are orthogonal in the function space.")



def check_reconstruction(tensor, zeroth, first, second, third):
    reconstructed = (zeroth + 
                     first[0][:, None, None] + first[1][None, :, None] + first[2][None, None, :] + 
                     second[(0,1)][:, :, None] + second[(0,2)][:, None, :] + second[(1,2)][None, :, :] + 
                     third)
    assert np.allclose(tensor, reconstructed), "Reconstruction failed"
    print("Reconstruction successful.")


def check_variances(tensor, zeroth, first, second, third):
    total_var = np.var(tensor)
    component_vars = (np.var(first[0]) + np.var(first[1]) + np.var(first[2]) +
                      np.var(second[(0,1)]) + np.var(second[(0,2)]) + np.var(second[(1,2)]) +
                      np.var(third))
    assert np.allclose(total_var, component_vars), "Variance decomposition failed"
    print("Variance decomposition successful.")

def manual_calculation(tensor):
    """
    Manually calculate the Efron-Stein decomposition for verification.
    """
    mean = np.mean(tensor)
    
    first_order = {
        0: np.mean(tensor, axis=(1, 2)) - mean,
        1: np.mean(tensor, axis=(0, 2)) - mean,
        2: np.mean(tensor, axis=(0, 1)) - mean
    }
    
    second_order = {
        (0, 1): np.mean(tensor, axis=2) - first_order[0][:, np.newaxis] - first_order[1][np.newaxis, :] - mean,
        (0, 2): np.mean(tensor, axis=1) - first_order[0][:, np.newaxis] - first_order[2][np.newaxis, :] - mean,
        (1, 2): np.mean(tensor, axis=0) - first_order[1][:, np.newaxis] - first_order[2][np.newaxis, :] - mean
    }
    
    third_order = (tensor - 
                   mean - 
                   first_order[0][:, np.newaxis, np.newaxis] -
                   first_order[1][np.newaxis, :, np.newaxis] -
                   first_order[2][np.newaxis, np.newaxis, :] -
                   second_order[(0, 1)][:, :, np.newaxis] -
                   second_order[(0, 2)][:, np.newaxis, :] -
                   second_order[(1, 2)][np.newaxis, :, :])
    
    return mean, first_order, second_order, third_order

# Run tests
test_tensor = create_test_tensor()

# Your implementation
#zeroth, first, second, third = efron_stein_decomposition(test_tensor)

# Manual calculation for verification
zeroth, first, second, third = manual_calculation(test_tensor)



print("Implementation matches manual calculation.")

# These all pass except for the orthogonality one, but I actually am not sure that's implemented correctly. Don't worry about it.
check_means(zeroth, first, second, third)
check_orthogonality(first, second, third) # This fails
check_reconstruction(test_tensor, zeroth, first, second, third)
check_variances(test_tensor, zeroth, first, second, third)


# Compare implementation with "ground truth"
assert np.allclose(ground_truth, test_tensor) # we implemented "basic function" correctly

assert np.allclose(zeroth, zo_ground_truth), "Zeroth-order effect mismatch"

for k in first.keys():
    # This fails for an interesting normalization reason. The output of `first` is fo_ground_truth[i] / 3
    assert np.allclose(first[k], fo_ground_truth[k]), f"First-order effect mismatch for position {k}"

for k in second.keys():
    # this fails for a more complicated reason. I expect the output of second[(0, 2)] to be all zeros, and second[(0, 1)] to have
    # just 4s on the diagonal, but this is what they look like:
    # {(0,1): array([[ 2.44444444, -0.88888889, -1.55555556],
    #    [-1.55555556,  2.44444444, -0.88888889],
    #    [-0.88888889, -1.55555556,  2.44444444]]),
    # (0,2): array([[-0.22222222, -0.22222222,  0.44444444],
    #    [ 0.44444444, -0.22222222, -0.22222222],
    #    [-0.22222222,  0.44444444, -0.22222222]]),
    # (1,2): array([[ 2.44444444, -0.88888889, -1.55555556],
    #    [-1.55555556,  2.44444444, -0.88888889],
    #    [-0.88888889, -1.55555556,  2.44444444]])}
    assert np.allclose(second[k], manual_second[k]), f"Second-order effect mismatch for positions {k}"

# Haven't really looked at this yet
assert np.allclose(third, manual_third), "Third-order effect mismatch"

print("All tests passed successfully!")

Implementation matches manual calculation.
All component means (except zeroth-order) are zero.
All components are orthogonal in the function space.
Reconstruction successful.
Variance decomposition successful.


AssertionError: First-order effect mismatch for position 0

In [83]:
full = np.zeros((3, 3, 3))
test1 = np.ones((3,))
test2 = np.ones((3, 3))


In [87]:
np.expand_dims(test1, axis=(1, 2))

array([[[1.]],

       [[1.]],

       [[1.]]])

In [86]:
full[(slice(None), slice(None), 1]

SyntaxError: closing parenthesis ']' does not match opening parenthesis '(' (543767458.py, line 1)

In [71]:
shape = (3, 3, 3)
expand_to_full(first[0], (0,), shape)
        

array([-0.33333333, -0.33333333,  0.66666667])

In [72]:
expand_to_full(first[1], (1,), shape)


array([ 0.66666667, -0.33333333, -0.33333333])

In [48]:
so_ground_truth[(0, 1)]

array([[4., 0., 0.],
       [0., 4., 0.],
       [0., 0., 4.]])

In [2]:
ngram_n = 3
epoch = 53_000

mm_array = np.memmap(
    f'ngram_{ngram_n}_outputs_epoch_{epoch}.npy',
    dtype='float32', 
    mode='r',
    shape=(512, 512, 512, 512)
)

num_tokens = 512

In [4]:
#zeroth_order = mm_array.mean(axis=(0, 1, 2))
#first_order = {
#    0: np.mean(mm_array, axis=(1, 2)) - zeroth_order[np.newaxis, :],
#    1: np.mean(mm_array, axis=(0, 2)) - zeroth_order[np.newaxis, :],
#    2: np.mean(mm_array, axis=(1, 2)) - zeroth_order[np.newaxis, :]
#}
#second_order = {
#    (0, 1): np.mean(mm_array, axis=2) - first_order[0][:, np.newaxis] - first_order[1][np.newaxis, :] - zeroth_order[np.newaxis, np.newaxis, :],
#    (0, 2): np.mean(mm_array, axis=1) - first_order[0][:, np.newaxis] - first_order[2][np.newaxis, :] - zeroth_order[np.newaxis, np.newaxis, :],
#    (1, 2): np.mean(mm_array, axis=0) - first_order[1][:, np.newaxis] - first_order[2][np.newaxis, :] - zeroth_order[np.newaxis, np.newaxis, :]
#}

partial_es = torch.load('partial_efron_stein.pt')

zeroth_order = partial_es['zeroth_order'].numpy()
first_order = {k: v.numpy() for k, v in partial_es['first_order'].items()}
second_order = {k: v.numpy() for k, v in partial_es['second_order'].items()}

In [22]:
[np.var(v[..., 24], axis=(0, 1)) for v in second_order.values()]

[0.12062385, 3.783763, 4.334224]

In [14]:
third_order = np.memmap(
    'third_order_results.npy',
    dtype='float32', 
    mode='r',
    shape=(512, 512, 512, 512)
)

4.018879

In [17]:

np.var(mm_array[..., 24])

7.4358907

In [14]:
#torch.save({
#        'zeroth_order': torch.from_numpy(zeroth_order),
#        'first_order': {k: torch.from_numpy(v) for k, v in first_order.items()},
#        'second_order': {k: torch.from_numpy(v) for k, v in second_order.items()}
#}, 'partial_efron_stein.pt')

In [15]:
import numpy as np
from tqdm import tqdm

def calculate_third_order(mm_array, zeroth_order, first_order, second_order, output_file):
    # Create a memory-mapped array for the output
    third_order = np.memmap(output_file, dtype='float32', mode='w+', 
                            shape=(512, 512, 512, 512))

    # Process the data in chunks
    chunk_size = 32  # Adjust this based on your available memory
    for i in tqdm(range(0, 512, chunk_size)):
        for j in range(0, 512, chunk_size):
            for k in range(0, 512, chunk_size):
                # Load a chunk of the input array
                chunk = mm_array[i:i+chunk_size, j:j+chunk_size, k:k+chunk_size, :]
                
                # Perform the calculation on the chunk
                result = (
                    chunk
                    - zeroth_order[np.newaxis, np.newaxis, np.newaxis, :]
                    - first_order[0][i:i+chunk_size, np.newaxis, np.newaxis, :]
                    - first_order[1][np.newaxis, j:j+chunk_size, np.newaxis, :]
                    - first_order[2][np.newaxis, np.newaxis, k:k+chunk_size, :]
                    - second_order[(0, 1)][i:i+chunk_size, j:j+chunk_size, np.newaxis, :]
                    - second_order[(0, 2)][i:i+chunk_size, np.newaxis, k:k+chunk_size, :]
                    - second_order[(1, 2)][np.newaxis, j:j+chunk_size, k:k+chunk_size, :]
                )
                
                # Write the result to the memory-mapped array
                third_order[i:i+chunk_size, j:j+chunk_size, k:k+chunk_size, :] = result

    # Flush to ensure all data is written to disk
    third_order.flush()

    return third_order

# Usage
output_file = 'third_order_results.npy'
third_order = calculate_third_order(mm_array, zeroth_order, first_order, second_order, output_file)

# To load and use the results later:
# third_order = np.memmap('third_order_results.npy', dtype='float32', mode='r', shape=(512, 512, 512, 512))

100%|██████████████████████████████████████████████████████████████████████████████████████| 16/16 [16:24<00:00, 61.56s/it]


In [22]:

# Analysis
for i in range(3):
    print(f"First-order effect (dim {i}) mean magnitude: {np.abs(first_order[i]).mean()}")
for (i, j) in second_order:
    print(f"Second-order effect (dims {i},{j}) mean magnitude: {np.abs(second_order[(i,j)]).mean()}")
print(f"Third-order effect mean: {third_order.mean()}")

# Variance explained
total_var = np.var(mm_array)
first_var = sum(np.var(eff) for eff in first_order.values())
second_var = sum(np.var(eff) for eff in second_order.values())
third_var = np.var(third_order)

print("\nVariance explained:")
print(f"First order: {first_var / total_var * 100:.2f}%")
print(f"Second order: {second_var / total_var * 100:.2f}%")
print(f"Third order: {third_var / total_var * 100:.2f}%")

First-order effect (dim 0) mean magnitude: 3.24462890625
First-order effect (dim 1) mean magnitude: 3.2683374881744385
First-order effect (dim 2) mean magnitude: 3.24462890625
Second-order effect (dims 0,1) mean magnitude: 3.2460999488830566
Second-order effect (dims 0,2) mean magnitude: 3.7644736766815186
Second-order effect (dims 1,2) mean magnitude: 3.81758189201355
Third-order effect mean: -3.159119129180908


MemoryError: Unable to allocate 256. GiB for an array with shape (512, 512, 512, 512) and data type float32

In [None]:
token_vars = np.array([np.var(mm_array[..., i]) for i in range(512)])

In [11]:
third_order = (
    mm_array 
    - zeroth_order[np.newaxis, np.newaxis, np.newaxis, :]
    - first_order[0][:, np.newaxis, np.newaxis, :] # (cond_0, out) -> (cond_0, 1, 1, out)
    - first_order[1][np.newaxis, :, np.newaxis, :] # (cond_1, out) -> (1, cond_1, 1, out)
    - first_order[2][np.newaxis, np.newaxis, :, :] # (cond_2, out) -> (1, 1, cond_2, out)
    - second_order[(0, 1)][:, :, np.newaxis, :] # (cond_0, cond_1, out) -> (cond_0, cond_1, 1, out)
    - second_order[(0, 2)][:, np.newaxis, :, :]  # (cond_0, cond_2, out) -> (cond_0, 1, cond_2, out)
    - second_order[(1, 2)][np.newaxis, :, :, :]  # (cond_1, cond_2, out) -> (1, cond_1, cond_2, out)
)

In [16]:
#np.save('partial_efron_stein.npy', {"zeroth": zeroth_order, "first": first_order}, )


In [None]:
second_order = {
    (0, 1): np.mean(mm_array, axis=2) - first_order[0][:, np.new_axis] - first_order[1][np.newaxis, :] - zeroth_order,
    (0, 2): np.mean(mm_array, axis=1) - first_order[0][:, np.new_axis] - first_order[2][np.newaxis, :] - zeroth_order,
    (1, 2): np.mean(mm_array, axis=0) - first_order[1][:, np.newaxis] - first_order[2][np.newaxis, :] - zeroth_order
}

In [None]:
import numpy as np

def efron_stein_decomposition(mm_array):
    # Zeroth-order effect
    print(f'Calculating 0th Order')
    zeroth_order = np.mean(mm_array)

    print(f'Calculating 1st Order')
    # First-order effects
    first_order = {
        0: np.mean(mm_array, axis=(1, 2, 3)) - zeroth_order,
        1: np.mean(mm_array, axis=(0, 2, 3)) - zeroth_order,
        2: np.mean(mm_array, axis=(0, 1, 3)) - zeroth_order
    }

    print(f'Calculating 2nd Order')
    # Second-order effects
    second_order = {
        (0, 1): np.mean(mm_array, axis=(2, 3)) - first_order[0][:, np.newaxis] - first_order[1][np.newaxis, :] - zeroth_order,
        (0, 2): np.mean(mm_array, axis=(1, 3)) - first_order[0][:, np.newaxis] - first_order[2][np.newaxis, :] - zeroth_order,
        (1, 2): np.mean(mm_array, axis=(0, 3)) - first_order[1][:, np.newaxis] - first_order[2][np.newaxis, :] - zeroth_order
    }

    print(f'Calculating 3rd Order')
    # Third-order effects
    third_order = (
        mm_array 
        - zeroth_order
        - first_order[0][:, np.newaxis, np.newaxis, np.newaxis]
        - first_order[1][np.newaxis, :, np.newaxis, np.newaxis]
        - first_order[2][np.newaxis, np.newaxis, :, np.newaxis]
        - second_order[(0, 1)][:, :, np.newaxis, np.newaxis]
        - second_order[(0, 2)][:, np.newaxis, :, np.newaxis]
        - second_order[(1, 2)][np.newaxis, :, :, np.newaxis]
    )

    return zeroth_order, first_order, second_order, third_order

# Usage
ngram_n = 3
epoch = 53_000
mm_array = np.memmap(
    f'ngram_{ngram_n}_outputs_epoch_{epoch}.npy',
    dtype='float32', 
    mode='r',
    shape=(512, 512, 512, 512)
)
zeroth_order, first_order, second_order, third_order = efron_stein_decomposition(mm_array)

# Analysis
print(f"Zeroth-order effect: {zeroth_order}")
for i in range(3):
    print(f"First-order effect (dim {i}) mean magnitude: {np.abs(first_order[i]).mean()}")
for (i, j) in second_order:
    print(f"Second-order effect (dims {i},{j}) mean magnitude: {np.abs(second_order[(i,j)]).mean()}")
print(f"Third-order effect mean magnitude: {np.abs(third_order).mean()}")

# Variance explained
total_var = np.var(mm_array)
first_var = sum(np.var(eff) for eff in first_order.values())
second_var = sum(np.var(eff) for eff in second_order.values())
third_var = np.var(third_order)

print("\nVariance explained:")
print(f"First order: {first_var / total_var * 100:.2f}%")
print(f"Second order: {second_var / total_var * 100:.2f}%")
print(f"Third order: {third_var / total_var * 100:.2f}%")

Calculating 0th Order
Calculating 1st Order
Calculating 2nd Order


In [3]:
zeroth_order = data.mean(axis=0)
zeroth_order = torch.from_numpy(zeroth_order)

In [4]:
from tqdm import trange
#zeroth_order = torch.from_numpy(data.mean(axis=0))
ngrams = torch.cartesian_prod(torch.arange(512), torch.arange(512), torch.arange(512))
first_order = {}
for pos in [0, 1, 2]:
    effects = []
    for token in trange(512):
        indices = torch.argwhere(ngrams[:, pos] == token).squeeze().numpy()
        val = torch.from_numpy(data[indices]).mean(dim=0) - zeroth_order
        effects.append(val)
    first_order[pos] = torch.stack(effects, dim=0)

100%|████████████████████████████████████████████████████████████████████████████████████| 512/512 [06:47<00:00,  1.26it/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 512/512 [08:39<00:00,  1.01s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 512/512 [17:22<00:00,  2.04s/it]


In [5]:
efron_stein = {'zeroth': zeroth_order, 'first': first_order}
torch.save(efron_stein, 'partial_efron_stein.pt')

In [55]:
import torch
import itertools

def generate_random_function_2d(n_tokens=5):
    ngrams = torch.cartesian_prod(torch.arange(n_tokens), torch.arange(n_tokens))
    fn = torch.randn(n_tokens ** 2)
    return ngrams, fn

def compute_efron_stein_2d(ngrams, fn, n_tokens=5):
    # 0th order effect (mean)
    zeroth_order = fn.mean()

    # 1st order effects
    first_order = {}
    for pos in range(2):
        first_order[pos] = torch.stack([
            fn[ngrams[:, pos] == token].mean() - zeroth_order
            for token in range(n_tokens)
        ])

    # 2nd order effect
    second_order = torch.zeros((n_tokens, n_tokens))
    for i in range(n_tokens):
        for j in range(n_tokens):
            mask = (ngrams[:, 0] == i) & (ngrams[:, 1] == j)
            second_order[i, j] = (
                fn[mask]  # This is a single value for 2D case
                - zeroth_order
                - first_order[0][i]
                - first_order[1][j]
            )

    return zeroth_order, first_order, second_order

def reconstruct_function_2d(zeroth_order, first_order, second_order, n_tokens=5):
    reconstructed = torch.zeros(n_tokens, n_tokens)
    for i in range(n_tokens):
        for j in range(n_tokens):
            reconstructed[i, j] = (
                zeroth_order
                + first_order[0][i]
                + first_order[1][j]
                + second_order[i, j]
            )
    return reconstructed.view(-1)

# Generate random function
n_tokens = 5
ngrams, fn = generate_random_function_2d(n_tokens)

# Compute Efron-Stein decomposition
zeroth_order, first_order, second_order = compute_efron_stein_2d(ngrams, fn, n_tokens)

# Reconstruct the function
reconstructed_fn = reconstruct_function_2d(zeroth_order, first_order, second_order, n_tokens)

# Verify reconstruction
error = torch.abs(fn - reconstructed_fn).max()
print(f"Max reconstruction error: {error.item()}")
assert torch.allclose(fn, reconstructed_fn, atol=1e-6), "Reconstruction failed!"

# Print components
print(f"Zeroth order effect: {zeroth_order.item()}")
print("First order effects:")
for pos in range(2):
    print(f"Position {pos}:\n{first_order[pos]}")
print("Second order effect:\n", second_order)

# Analyze relative importance
total_var = torch.var(fn)
first_var = sum(torch.var(eff) for eff in first_order.values())
second_var = torch.var(second_order)

print("\nVariance explained:")
print(f"First order: {first_var / total_var * 100:.2f}%")
print(f"Second order: {second_var / total_var * 100:.2f}%")

Max reconstruction error: 2.384185791015625e-07
Zeroth order effect: -0.06464104354381561
First order effects:
Position 0:
tensor([ 0.4549,  0.1557, -0.0581, -0.2491, -0.3034])
Position 1:
tensor([ 0.7062, -0.5175, -0.7466,  0.7363, -0.1783])
Second order effect:
 tensor([[ 0.3490,  1.0507, -0.4164, -0.0022, -0.9812],
        [-0.1368,  0.3021, -1.6845,  0.1970,  1.3221],
        [-0.1967, -0.2287,  1.2454, -0.9075,  0.0874],
        [ 0.5951, -0.8946,  0.8644,  1.2440, -1.8088],
        [-0.6106, -0.2295, -0.0089, -0.5314,  1.3804]])

Variance explained:
First order: 45.22%
Second order: 62.31%


In [64]:
import torch
import itertools

def generate_random_function_3d(n_tokens=5):
    ngrams = torch.cartesian_prod(torch.arange(n_tokens), torch.arange(n_tokens), torch.arange(n_tokens))
    fn = torch.randn(n_tokens ** 3)
    return ngrams, fn

def compute_efron_stein_3d(ngrams, fn, n_tokens=5):
    # 0th order effect (mean)
    zeroth_order = fn.mean()

    # 1st order effects
    first_order = {}
    for pos in range(3):
        first_order[pos] = torch.stack([
            fn[ngrams[:, pos] == token].mean() - zeroth_order
            for token in range(n_tokens)
        ])

    # 2nd order effects
    second_order = {}
    for pos1, pos2 in itertools.combinations(range(3), 2):
        effects = torch.zeros((n_tokens, n_tokens))
        for token1 in range(n_tokens):
            for token2 in range(n_tokens):
                mask = (ngrams[:, pos1] == token1) & (ngrams[:, pos2] == token2)
                effects[token1, token2] = (
                    fn[mask].mean() 
                    - zeroth_order
                    - first_order[pos1][token1] 
                    - first_order[pos2][token2]
                )
        second_order[(pos1, pos2)] = effects

    # 3rd order effect
    third_order = torch.zeros((n_tokens, n_tokens, n_tokens))
    for i in range(n_tokens):
        for j in range(n_tokens):
            for k in range(n_tokens):
                mask = (ngrams[:, 0] == i) & (ngrams[:, 1] == j) & (ngrams[:, 2] == k)
                third_order[i, j, k] = (
                    fn[mask].item()  # This is a single value for 3D case
                    - zeroth_order
                    - first_order[0][i] - first_order[1][j] - first_order[2][k]
                    - second_order[(0, 1)][i, j] - second_order[(0, 2)][i, k] - second_order[(1, 2)][j, k]
                )

    return zeroth_order, first_order, second_order, third_order

def reconstruct_function_3d(zeroth_order, first_order, second_order, third_order, n_tokens=5):
    reconstructed = torch.zeros(n_tokens, n_tokens, n_tokens)
    for i in range(n_tokens):
        for j in range(n_tokens):
            for k in range(n_tokens):
                reconstructed[i, j, k] = (
                    zeroth_order
                    + first_order[0][i] + first_order[1][j] + first_order[2][k]
                    + second_order[(0, 1)][i, j] + second_order[(0, 2)][i, k] + second_order[(1, 2)][j, k]
                    + third_order[i, j, k]
                )
    return reconstructed.view(-1)

# Generate random function
n_tokens = 5
ngrams, fn = generate_random_function_3d(n_tokens)

# Compute Efron-Stein decomposition
zeroth_order, first_order, second_order, third_order = compute_efron_stein_3d(ngrams, fn, n_tokens)

# Reconstruct the function
reconstructed_fn = reconstruct_function_3d(zeroth_order, first_order, second_order, third_order, n_tokens)

# Verify reconstruction
error = torch.abs(fn - reconstructed_fn).max()
print(f"Max reconstruction error: {error.item()}")
assert torch.allclose(fn, reconstructed_fn, atol=1e-6), "Reconstruction failed!"

# Print components
print(f"Zeroth order effect: {zeroth_order.item()}")
print("First order effects:")
for pos in range(3):
    print(f"Position {pos}:\n{first_order[pos]}")
print("Second order effects:")
for (pos1, pos2), effect in second_order.items():
    print(f"Positions {pos1}, {pos2}:\n{effect}")
print("Third order effect shape:", third_order.shape)

# Analyze relative importance
total_var = torch.var(fn)
first_var = sum(torch.var(eff) for eff in first_order.values())
second_var = sum(torch.var(eff.flatten()) for eff in second_order.values())
third_var = torch.var(third_order)

print("\nVariance explained:")
print(f"First order: {first_var / total_var * 100:.2f}%")
print(f"Second order: {second_var / total_var * 100:.2f}%")
print(f"Third order: {third_var / total_var * 100:.2f}%")

Max reconstruction error: 2.384185791015625e-07
Zeroth order effect: 0.16686588525772095
First order effects:
Position 0:
tensor([ 0.4146,  0.2782, -0.3489, -0.0337, -0.3103])
Position 1:
tensor([ 0.1551,  0.2297, -0.0873, -0.0148, -0.2827])
Position 2:
tensor([ 0.2148,  0.3035, -0.3975, -0.0563, -0.0646])
Second order effects:
Positions 0, 1:
tensor([[-0.2600,  0.2227,  0.2253,  0.3525, -0.5405],
        [-0.4240,  0.3028,  0.1430,  0.4904, -0.5122],
        [ 0.9624, -0.6138,  0.4313, -0.3964, -0.3835],
        [ 0.0895,  0.0603, -0.6111,  0.3487,  0.1126],
        [-0.3678,  0.0280, -0.1885, -0.7952,  1.3235]])
Positions 0, 2:
tensor([[-0.0020,  0.2337, -0.0996, -0.0786, -0.0536],
        [-0.1320, -0.6032, -0.2954,  0.6642,  0.3663],
        [-0.0080, -0.0310,  0.1330,  0.2046, -0.2986],
        [-0.2261,  0.1516,  0.0728, -0.2234,  0.2251],
        [ 0.3681,  0.2489,  0.1892, -0.5668, -0.2393]])
Positions 1, 2:
tensor([[ 0.4966, -0.4041, -0.2700,  0.1109,  0.0666],
        [ 0.245

In [11]:
n_tokens = 5
ngrams = torch.cartesian_prod(torch.arange(n_tokens), torch.arange(n_tokens), torch.arange(n_tokens))

In [16]:
f = torch.randn(n_tokens ** 3)

In [27]:
view1 = torch.stack([f[ngrams[:, 1] == i] for i in range(n_tokens)], dim=0)

In [34]:
[torch.argwhere(ngrams[:, 1] == i).squeeze() for i in range(n_tokens)]

[tensor([  0,   1,   2,   3,   4,  25,  26,  27,  28,  29,  50,  51,  52,  53,
          54,  75,  76,  77,  78,  79, 100, 101, 102, 103, 104]),
 tensor([  5,   6,   7,   8,   9,  30,  31,  32,  33,  34,  55,  56,  57,  58,
          59,  80,  81,  82,  83,  84, 105, 106, 107, 108, 109]),
 tensor([ 10,  11,  12,  13,  14,  35,  36,  37,  38,  39,  60,  61,  62,  63,
          64,  85,  86,  87,  88,  89, 110, 111, 112, 113, 114]),
 tensor([ 15,  16,  17,  18,  19,  40,  41,  42,  43,  44,  65,  66,  67,  68,
          69,  90,  91,  92,  93,  94, 115, 116, 117, 118, 119]),
 tensor([ 20,  21,  22,  23,  24,  45,  46,  47,  48,  49,  70,  71,  72,  73,
          74,  95,  96,  97,  98,  99, 120, 121, 122, 123, 124])]

In [26]:
f.reshape(5, -1)

tensor([[-0.8750,  1.4016,  0.0270, -0.4332,  0.0133,  0.6124, -2.0095,  0.6408,
         -1.7025, -1.5452, -0.8682,  1.1706, -1.4370, -1.2018, -1.9680, -1.5157,
          0.9573, -0.1997,  0.4207, -2.7122, -0.5849, -0.5995,  0.7393, -0.9333,
          0.1884],
        [ 1.2517, -0.9314,  0.5711,  0.0932,  1.0021, -2.1479,  0.2673, -0.7387,
          0.3344, -0.7081, -0.2440,  1.8040,  0.7881, -2.1420,  0.3635, -1.8689,
         -0.1884, -1.9096,  0.0948, -0.6922,  0.7987, -0.9315, -1.4617,  1.0115,
         -0.3192],
        [ 1.0524,  0.6389,  0.6834,  1.4823, -0.2306,  0.8106, -1.0035, -2.0879,
          0.6985, -1.1134,  0.4746, -0.8444, -0.7959, -0.1769, -0.0194, -0.5012,
          1.8367,  0.8992,  1.0352,  0.9864,  1.7449,  0.2280, -1.4442, -0.5730,
         -1.7163],
        [ 0.2613,  0.9135, -0.1667, -1.3014, -0.0527,  0.1181,  1.0481,  0.8787,
         -0.5568,  1.0012, -1.1522,  0.0993,  0.1824, -2.1522,  0.8663,  0.0649,
          0.4856,  0.9354, -1.2093, -1.0450, -1.7064