In [None]:
import sys

sys.path.append("../../")
import torch

from fourier_scaffold import (
    FourierScaffold,
    FourierScaffoldDebug,
    HadamardSharpening,
    ContractionSharpening,
    GaussianFourierSmoothing,
    GuassianFourierSmoothingMatrix,
    HadamardShiftMatrix,
    calculate_padding,
)
import math

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# One row per shape (presumably one grid module per row — TODO confirm); the combined
# grid size along each axis is the product of that column (here 5*7 = 35 per axis).
shapes = torch.tensor([(5, 5), (7, 7)], device=device)
nruns = 5  # repetitions per (distribution, D) pair in the sweep
dim_sizes = [int(shapes[:, dim].prod().item()) for dim in range(shapes.shape[1])]
rescaling = True
Ds = [10, 31, 100, 310, 1000]  # feature counts swept, in ~half-decade steps

def zero():
    """Return an all-zero float tensor shaped ``dim_sizes`` on ``device``."""
    return torch.zeros(tuple(dim_sizes), device=device)


def uniform():
    """Return the uniform distribution: each grid cell holds 1 / (number of cells)."""
    ones = torch.ones_like(zero())
    total_mass = ones.sum()
    return ones / total_mass


def degenerate():
    """Return a point mass: probability 1 at the origin cell (0, 0, ...), 0 elsewhere."""
    grid = zero()
    origin = (0,) * shapes.shape[1]
    grid[origin] = 1
    return grid


def gaussian(sigma=1):
    """Return a Gaussian bump centred on the origin cell of a 2-D toroidal grid.

    A point mass at index (0, 0) (from ``degenerate()``) is circularly padded and
    cross-correlated with a normalised, symmetric Gaussian kernel, so mass that falls
    off one edge of the grid wraps around to the opposite edge.

    Args:
        sigma: standard deviation of the kernel, in grid cells.

    Returns:
        A 2-D tensor on ``device``. Only 2-D grids are supported (``conv2d``).
    """
    t = degenerate()
    # At least 21 taps wide, and wide enough to cover +/- 3 sigma for larger sigmas.
    kernel_size = 2 * max(10, 3 * math.ceil(sigma)) + 1
    offsets = torch.arange(kernel_size, device=device) - kernel_size // 2
    # indexing="ij" is the legacy default; passing it explicitly silences the
    # deprecation warning and keeps identical results.
    x, y = torch.meshgrid(offsets, offsets, indexing="ij")
    kernel = torch.exp(-(x**2 + y**2) / (2 * sigma**2))
    kernel = kernel / kernel.sum()

    x_padding = calculate_padding(kernel_size, kernel.shape[0], 1)
    y_padding = calculate_padding(kernel_size, kernel.shape[1], 1)

    # F.pad lists the last dimension first, hence y_padding before x_padding.
    padded = torch.nn.functional.pad(
        t.unsqueeze(0).unsqueeze(0),
        y_padding + x_padding,
        mode="circular",
    )

    convoluted = torch.nn.functional.conv2d(
        input=padded, weight=kernel.unsqueeze(0).unsqueeze(0)
    )

    return convoluted.squeeze(0).squeeze(0)


def bimodal():
    """Return two equal point masses (0.5 each): at the origin and at the grid centre."""
    grid = zero()
    origin = (0,) * shapes.shape[1]
    centre = tuple(size // 2 for size in dim_sizes)
    grid[origin] = 0.5
    grid[centre] = 0.5
    return grid

def bimodal2():
    """Return two equal point masses (0.5 each) at cells (1, 1, ...) and (5, 5, ...)."""
    grid = zero()
    ndim = shapes.shape[1]
    grid[(1,) * ndim] = 0.5
    grid[(5,) * ndim] = 0.5
    return grid

def gaussian_mixture():
    """Return a normalised sum of three wrapped Gaussians (sigma 1.5, 2, 2.5) at fixed offsets."""
    first, second, third = (gaussian(s) for s in (1.5, 2, 2.5))

    # roll is circular, so the shifted bumps wrap around the grid edges.
    mixture = first.roll(shifts=(7, 7), dims=(0, 1))
    mixture = mixture + second.roll(shifts=(30, 13), dims=(0, 1))
    mixture = mixture + third.roll(shifts=(17, 28), dims=(0, 1))
    return mixture / mixture.sum()

In [None]:
# import matplotlib.pyplot as plt
# x = gaussian_mixture().cpu()
# plt.imshow(x)

In [None]:
# Test distributions as (label, tensor) pairs, each built once, in this order.
distributions = []
for label, build in [
    ("degenerate", degenerate),
    ("uniform", uniform),
    ("gaussian σ=0.5", lambda: gaussian(0.5)),
    ("gaussian σ=1", lambda: gaussian(1)),
    ("gaussian σ=2", lambda: gaussian(2)),
    ("bimodal", bimodal),
    ("bimodal2", bimodal2),
    ("gaussian mixture", gaussian_mixture),
]:
    distributions.append((label, build()))

# Reference model whose `ptensor` is treated as the true distribution.
scaffold_debug = FourierScaffoldDebug(shapes, device=device, rescale=rescaling)


def l2_err(v1: torch.Tensor, v2: torch.Tensor):
    """Return the Euclidean (L2) norm of ``v1 - v2``, taken over all elements."""
    return torch.linalg.vector_norm(v1 - v2, ord=2)


def similarity(v1: torch.Tensor, v2: torch.Tensor):
    """Cosine similarity |<v1, v2>| / (||v1|| * ||v2||).

    Both inputs must already be flattened. For complex inputs the inner product is
    sum(v1 * conj(v2)), and its absolute value is used.
    """
    inner = torch.sum(v1 * torch.conj(v2))
    norm_product = v1.norm() * v2.norm()
    return torch.abs(inner) / norm_product


def run_test(
    distribution: torch.Tensor,
    scaffold: FourierScaffold,
    scaffold_debug: FourierScaffoldDebug,
):
    """Load ``distribution`` into both models, sharpen once, and score the scaffold.

    ``scaffold_debug.ptensor`` is treated as the true distribution. ``scaffold`` is
    scored with ``similarity(true, estimated)`` and ``l2_err(true, estimated)``.

    Returns a 6-tuple, in this order:
        (sharpened_encoding_similarity, sharpened_encoding_l2,
         prob_similarity, sharpened_prob_similarity,
         prob_l2, sharpened_prob_l2)
    """

    def score(true: torch.Tensor, estimated: torch.Tensor):
        # similarity first, then l2 error
        return similarity(true, estimated), l2_err(true, estimated)

    # Load the same distribution into both models.
    scaffold.g = scaffold.encode_probability(distribution)
    scaffold_debug.ptensor = distribution

    # Decode before sharpening.
    decoded = scaffold.get_all_probabilities().flatten().abs()
    prob_similarity, prob_l2 = score(scaffold_debug.ptensor.flatten(), decoded)

    # One sharpening step on each model.
    scaffold.sharpen()
    scaffold_debug.sharpen()

    # Encode the exactly sharpened tensor and compare with the scaffold's own state.
    reference_encoding = scaffold.encode_probability(scaffold_debug.ptensor).flatten()
    sharpened_encoding_similarity, sharpened_encoding_l2 = score(
        reference_encoding, scaffold.g.flatten()
    )

    # Decode after sharpening.
    decoded = scaffold.get_all_probabilities().flatten().abs()
    sharpened_prob_similarity, sharpened_prob_l2 = score(
        scaffold_debug.ptensor.flatten(), decoded
    )

    return (
        sharpened_encoding_similarity,
        sharpened_encoding_l2,
        prob_similarity,
        sharpened_prob_similarity,
        prob_l2,
        sharpened_prob_l2,
    )

In [None]:
# Result tensors, each shaped (len(distributions), len(Ds), 2, nruns).
# Third axis: 0 = vector scaffold with Hadamard sharpening (scaffold_v),
#             1 = matrix scaffold with contraction sharpening (scaffold_m).
sharpened_encoding_similarities = torch.zeros(len(distributions), len(Ds), 2, nruns)
sharpened_encoding_l2s = torch.zeros(len(distributions), len(Ds), 2, nruns)

prob_similarities = torch.zeros(len(distributions), len(Ds), 2, nruns)
sharpened_prob_similarities = torch.zeros(len(distributions), len(Ds), 2, nruns)

prob_l2s = torch.zeros(len(distributions), len(Ds), 2, nruns)
sharpened_prob_l2s = torch.zeros(len(distributions), len(Ds), 2, nruns)


# Sweep distribution x D x run; fresh scaffolds are built for every run.
# NOTE(review): keep the construction order below as-is. If the constructors draw
# random features, reordering would change which features each run samples.
for i, [name, distribution] in enumerate(distributions):
    for j, D in enumerate(Ds):
        for run in range(nruns):
            print(
                f" ----------------------- running test: {name} ({i}/{len(distributions)}), D={D} ({j}/{len(Ds)}), run {run}/{nruns} --------------------"
            )
            # Vector representation with Hadamard sharpening, built with D**2 features.
            scaffold_v = FourierScaffold(
                shapes,
                D=D**2,
                sharpening=HadamardSharpening(2),
                smoothing=GaussianFourierSmoothing(
                    kernel_radii=[10, 10], kernel_sigmas=[1, 1]
                ),  # not exercised by run_test; only needed to construct the scaffold
                device=device,
                rescaling=rescaling,
                _skip_K_calc=True,
            )
            # Matrix representation with contraction sharpening, built with D features.
            scaffold_m = FourierScaffold(
                shapes,
                D=D,
                sharpening=ContractionSharpening(2),
                shift=HadamardShiftMatrix(),
                smoothing=GuassianFourierSmoothingMatrix(
                    kernel_radii=[10, 10], kernel_sigmas=[1, 1]
                ),  # not exercised by run_test; only needed to construct the scaffold
                device=device,
                representation="matrix",
                rescaling=rescaling,
                _skip_K_calc=True,
            )
            # Reference model; run_test overwrites its ptensor and treats it as the truth.
            scaffold_debug = FourierScaffoldDebug(
                shapes, device=device, rescale=rescaling
            )

            # Each run_test call receives its own copy of the distribution.
            (
                sharpened_encoding_similarity_v,
                sharpened_encoding_l2_v,
                prob_similarity_v,
                sharpened_prob_similarity_v,
                prob_l2_v,
                sharpened_prob_l2_v,
            ) = run_test(distribution.clone(), scaffold_v, scaffold_debug)
            (
                sharpened_encoding_similarity_m,
                sharpened_encoding_l2_m,
                prob_similarity_m,
                sharpened_prob_similarity_m,
                prob_l2_m,
                sharpened_prob_l2_m,
            ) = run_test(distribution.clone(), scaffold_m, scaffold_debug)

            # Store results: index 0 = vector/Hadamard, index 1 = matrix/contraction.
            sharpened_encoding_similarities[i, j, 0, run] = (
                sharpened_encoding_similarity_v
            )
            sharpened_encoding_similarities[i, j, 1, run] = (
                sharpened_encoding_similarity_m
            )

            sharpened_encoding_l2s[i, j, 0, run] = sharpened_encoding_l2_v
            sharpened_encoding_l2s[i, j, 1, run] = sharpened_encoding_l2_m

            prob_similarities[i, j, 0, run] = prob_similarity_v
            prob_similarities[i, j, 1, run] = prob_similarity_m

            sharpened_prob_similarities[i, j, 0, run] = sharpened_prob_similarity_v
            sharpened_prob_similarities[i, j, 1, run] = sharpened_prob_similarity_m

            prob_l2s[i, j, 0, run] = prob_l2_v
            prob_l2s[i, j, 1, run] = prob_l2_m

            sharpened_prob_l2s[i, j, 0, run] = sharpened_prob_l2_v
            sharpened_prob_l2s[i, j, 1, run] = sharpened_prob_l2_m

In [None]:
import pickle

# Persist every metric tensor. Each is shaped (len(distributions), len(Ds), 2, nruns);
# index 0 on the third axis is the vector/Hadamard scaffold, index 1 the
# matrix/contraction scaffold. Keys exactly match the variable names (no stray
# whitespace), so loaders can index them as data["prob_l2s"], etc.
data = {
    "sharpened_encoding_similarities": sharpened_encoding_similarities,
    "sharpened_encoding_l2s": sharpened_encoding_l2s,
    "prob_similarities": prob_similarities,
    "sharpened_prob_similarities": sharpened_prob_similarities,
    "prob_l2s": prob_l2s,
    "sharpened_prob_l2s": sharpened_prob_l2s,
}

results_path = "results_rescaling.pkl" if rescaling else "results.pkl"
with open(results_path, "wb") as f:
    pickle.dump(data, f)

In [None]:
import matplotlib.pyplot as plt

# Cosine similarity between the true and the decoded distribution, before ("original")
# and after ("sharpened") one sharpening step. Points are means over runs and error
# bars are standard deviations.
# Left: vector scaffold with Hadamard sharpening, built with D**2 features and so
# plotted at D**2. Right: matrix scaffold with contraction sharpening (D features).
for i, [name, distribution] in enumerate(distributions):
    fig, ax = plt.subplots(ncols=2, figsize=(15, 5))

    ax[0].errorbar(
        (torch.tensor(Ds) ** 2).tolist(),
        prob_similarities[i, :, 0].mean(dim=1),
        yerr=prob_similarities[i, :, 0].std(dim=1),
        label="original",
        capsize=3,
        fmt="--o",
    )
    ax[0].errorbar(
        (torch.tensor(Ds) ** 2).tolist(),
        sharpened_prob_similarities[i, :, 0].mean(dim=1),
        yerr=sharpened_prob_similarities[i, :, 0].std(dim=1),
        label="sharpened (hadamard)",
        capsize=3,
        fmt="--o",
    )
    ax[0].set_xscale("log")
    ax[0].legend()
    ax[0].set_ylabel("cosine similarity")
    ax[0].set_xlabel("D")
    ax[0].set_title("hadamard sharpening")

    ax[1].errorbar(
        Ds,
        prob_similarities[i, :, 1].mean(dim=1),
        yerr=prob_similarities[i, :, 1].std(dim=1),
        label="original",
        capsize=3,
        fmt="--o",
    )
    # The sharpened series must come from sharpened_prob_similarities.
    ax[1].errorbar(
        Ds,
        sharpened_prob_similarities[i, :, 1].mean(dim=1),
        yerr=sharpened_prob_similarities[i, :, 1].std(dim=1),
        label="sharpened (contraction)",
        capsize=3,
        fmt="--o",
    )
    ax[1].set_xscale("log")
    ax[1].legend()
    ax[1].set_ylabel("cosine similarity")
    ax[1].set_xlabel("D")
    ax[1].set_title("contraction sharpening")

    fig.suptitle(
        f"Mean cosine similarity between true and recovered original distribution and sharpened distribution for hadamard and contraction sharpening. distribution={name}.\nshapes={shapes.cpu().tolist()}"
    )
    if rescaling:
        fig.savefig(f"fourier_scaffold_contraction_testing_cosine_sim_{i}.png")
    else:
        fig.savefig(f"fourier_scaffold_contraction_testing_no_rescaling_cosine_sim_{i}.png")

In [None]:
# L2 error between the true and the decoded distribution, before ("original") and
# after ("sharpened") one sharpening step. Error bars are the std over runs.
# Column 0: vector scaffold with Hadamard sharpening, plotted at D**2 (its feature
# count). Column 1: matrix scaffold with contraction sharpening, plotted at D.
for i, [name, distribution] in enumerate(distributions):
    fig, ax = plt.subplots(ncols=2, figsize=(15, 5))

    panel_specs = (
        ((torch.tensor(Ds) ** 2).tolist(), "sharpened (hadamard)", "hadamard sharpening"),
        (Ds, "sharpened (contraction)", "contraction sharpening"),
    )
    for col, (x_values, sharpened_label, panel_title) in enumerate(panel_specs):
        panel = ax[col]
        for metric, series_label in (
            (prob_l2s, "original"),
            (sharpened_prob_l2s, sharpened_label),
        ):
            panel.errorbar(
                x_values,
                metric[i, :, col].mean(dim=1),
                yerr=metric[i, :, col].std(dim=1),
                label=series_label,
                capsize=3,
                fmt="--o",
            )
        panel.set_xscale("log")
        panel.legend()
        panel.set_ylabel("||true-estimated||₂")
        panel.set_xlabel("D")
        panel.set_ylim(0, 1)
        panel.set_title(panel_title)

    fig.suptitle(
        f"Mean l2 error between true and recovered original distribution and sharpened distribution for hammard and contraction sharpening. distribution={name}.\nshapes={shapes.cpu().tolist()}"
    )
    if rescaling:
        fig.savefig(f"fourier_scaffold_contraction_testing_l2err_{i}.png")
    else:
        fig.savefig(f"fourier_scaffold_contraction_testing_no_rescaling_l2err_{i}.png")

In [None]:
# Same L2-error figures as the cell above, drawn on a narrower 11x5 canvas.
# NOTE(review): the output file names are identical to the previous cell's, so running
# this cell overwrites those PNGs.
for i, [name, distribution] in enumerate(distributions):
    fig, ax = plt.subplots(ncols=2, figsize=(11, 5))
    vector_ax, matrix_ax = ax
    bar_style = {"capsize": 3, "fmt": "--o"}

    # Left: vector scaffold (Hadamard), built with D**2 features.
    vector_x = (torch.tensor(Ds) ** 2).tolist()
    vector_ax.errorbar(
        vector_x,
        prob_l2s[i, :, 0].mean(dim=1),
        yerr=prob_l2s[i, :, 0].std(dim=1),
        label="original",
        **bar_style,
    )
    vector_ax.errorbar(
        vector_x,
        sharpened_prob_l2s[i, :, 0].mean(dim=1),
        yerr=sharpened_prob_l2s[i, :, 0].std(dim=1),
        label="sharpened (hadamard)",
        **bar_style,
    )

    # Right: matrix scaffold (contraction), built with D features.
    matrix_ax.errorbar(
        Ds,
        prob_l2s[i, :, 1].mean(dim=1),
        yerr=prob_l2s[i, :, 1].std(dim=1),
        label="original",
        **bar_style,
    )
    matrix_ax.errorbar(
        Ds,
        sharpened_prob_l2s[i, :, 1].mean(dim=1),
        yerr=sharpened_prob_l2s[i, :, 1].std(dim=1),
        label="sharpened (contraction)",
        **bar_style,
    )

    for panel, panel_title in (
        (vector_ax, "hadamard sharpening"),
        (matrix_ax, "contraction sharpening"),
    ):
        panel.set_xscale("log")
        panel.legend()
        panel.set_ylabel("||true-estimated||₂")
        panel.set_xlabel("D")
        panel.set_ylim(0, 1)
        panel.set_title(panel_title)

    fig.suptitle(
        f"Mean l2 error between true and recovered original distribution and sharpened distribution for hammard and contraction sharpening. distribution={name}.\nshapes={shapes.cpu().tolist()}"
    )
    file_tag = "l2err" if rescaling else "no_rescaling_l2err"
    fig.savefig(f"fourier_scaffold_contraction_testing_{file_tag}_{i}.png")

In [None]:
# Sharpened-encoding accuracy compares the scaffold's own sharpened state with a fresh
# encoding of the exactly sharpened distribution. Both scaffolds are plotted at D;
# the vector (Hadamard) scaffold was built with D**2 features, as its legend says.
for i, [name, distribution] in enumerate(distributions):
    fig, ax = plt.subplots(ncols=2, figsize=(15, 5))

    ax[0].errorbar(
        Ds,
        sharpened_encoding_similarities[i, :, 0].mean(dim=1),
        yerr=sharpened_encoding_similarities[i, :, 0].std(dim=1),
        label="hadamard (uses D² parameters)",
        capsize=3,
        fmt="--o",
    )
    ax[0].errorbar(
        Ds,
        sharpened_encoding_similarities[i, :, 1].mean(dim=1),
        yerr=sharpened_encoding_similarities[i, :, 1].std(dim=1),
        label="contraction",
        capsize=3,
        fmt="--o",
    )
    ax[0].set_xscale("log")
    ax[0].legend()
    ax[0].set_ylabel("similarity(true, computed)")
    ax[0].set_xlabel("D")
    ax[0].set_ylim(0, 1)
    ax[0].set_title("mean cosine sim between true and computed sharpened encodings")

    ax[1].errorbar(
        Ds,
        sharpened_encoding_l2s[i, :, 0].mean(dim=1),
        yerr=sharpened_encoding_l2s[i, :, 0].std(dim=1),
        label="hadamard (uses D² parameters)",
        capsize=3,
        fmt="--o",
    )
    ax[1].errorbar(
        Ds,
        sharpened_encoding_l2s[i, :, 1].mean(dim=1),
        yerr=sharpened_encoding_l2s[i, :, 1].std(dim=1),
        label="contraction",
        capsize=3,
        fmt="--o",
    )
    ax[1].set_xscale("log")
    ax[1].legend()
    # L2 norm of the difference, written like the other l2 plots.
    ax[1].set_ylabel("||true-computed||₂")
    ax[1].set_xlabel("D")
    ax[1].set_ylim(0, 1)
    ax[1].set_title("mean l2 error between true and computed sharpened encodings")

    fig.suptitle(f"distribution={name}; shapes={shapes.tolist()}")
    if rescaling:
        fig.savefig(f"fourier_scaffold_contraction_testing_encodings_{i}.png")
    else:
        fig.savefig(
            f"fourier_scaffold_contraction_testing_encodings_no_rescaling_{i}.png"
        )

Qualitative results: true vs. encoded distributions, before and after one sharpening step

In [None]:
import torch
import matplotlib.pyplot as plt

from fourier_scaffold import (
    FourierScaffold,
    FourierScaffoldDebug,
    HadamardSharpening,
    ContractionSharpening,
    GaussianFourierSmoothing,
    GuassianFourierSmoothingMatrix,
    HadamardShiftMatrix,
    calculate_padding,
)

# Settings for the qualitative comparison: a single D, a fresh reference model, and
# freshly rebuilt test distributions (same labels and construction order as above).
D = 1000
scaffold_debug = FourierScaffoldDebug(shapes, device=device, rescale=rescaling)
distribution_builders = [
    ("degenerate", degenerate),
    ("uniform", uniform),
    ("gaussian σ=0.5", lambda: gaussian(0.5)),
    ("gaussian σ=1", lambda: gaussian(1)),
    ("gaussian σ=2", lambda: gaussian(2)),
    ("bimodal", bimodal),
    ("bimodal2", bimodal2),
    ("gaussian mixture", gaussian_mixture),
]
distributions = [(label, make()) for label, make in distribution_builders]

In [None]:
from graph_utils import plot_imgs_side_by_side

# Both scaffolds below reuse `dummy_scaffold.features`. They therefore share the same
# feature draw and are both built with D features.
dummy_scaffold = FourierScaffold(
    shapes,
    D=D,
    smoothing=GaussianFourierSmoothing(
        kernel_radii=[10, 10], kernel_sigmas=[1, 1]
    ),  # smoothing is not called in this cell; only needed to construct the scaffold
    device=device,
    representation="vector",
    rescaling=rescaling,
)

# NOTE(review): scaffolds[0] passes no `sharpening=` and relies on FourierScaffold's
# default, which the "hadamard" label assumes is Hadamard sharpening — TODO confirm.
scaffolds = [
    FourierScaffold(
        shapes,
        D=D,
        smoothing=GaussianFourierSmoothing(
            kernel_radii=[10, 10], kernel_sigmas=[1, 1]
        ),  # smoothing is not called in this cell; only needed to construct the scaffold
        device=device,
        representation="vector",
        rescaling=rescaling,
        features=dummy_scaffold.features,
    ),
    FourierScaffold(
        shapes,
        D=D,
        sharpening=ContractionSharpening(2),
        shift=HadamardShiftMatrix(),
        smoothing=GuassianFourierSmoothingMatrix(
            kernel_radii=[10, 10], kernel_sigmas=[1, 1]
        ),  # smoothing is not called in this cell; only needed to construct the scaffold
        device=device,
        representation="matrix",
        rescaling=rescaling,
        features=dummy_scaffold.features,
    ),
]
names = ["hadamard", "contraction"]

for scaffold, contraction_name in zip(scaffolds, names):
    for name, distribution in distributions:
        # The reference model gets a private copy: it is sharpened below and reused
        # for the next scaffold, so it must never alias the shared tensor stored in
        # `distributions` (the sweep's run_test callers pass a clone for the same reason).
        scaffold_debug.ptensor = distribution.clone()
        scaffold.g = scaffold.encode_probability(distribution)
        original = scaffold.get_all_probabilities().cpu().clone()
        distribution_cpu = distribution.cpu()

        l2_o = l2_err(original, distribution_cpu)
        sim_o = similarity(original.flatten(), distribution_cpu.flatten())

        fig, ax = plt.subplots(2, 2, figsize=(10, 8))
        plot_imgs_side_by_side(
            [distribution_cpu, original],
            ax[0],
            [
                "distribution",
                f"encoded distribution\nsimilarity={sim_o:.3f}, ||true-encoding||₂ = {l2_o:.3f}",
            ],
            fig,
            False,
        )

        scaffold.sharpen()
        scaffold_debug.sharpen()

        sharpened = scaffold.get_all_probabilities().cpu().clone()
        sharpened_true = scaffold_debug.ptensor.cpu()

        # Encode the exactly sharpened distribution once and reuse it for both metrics.
        generated_encoding = scaffold.g.flatten()
        reference_encoding = scaffold.encode_probability(scaffold_debug.ptensor).flatten()
        encoding_similarity = similarity(generated_encoding, reference_encoding)
        encoding_l2err = l2_err(generated_encoding, reference_encoding)
        l2_s = l2_err(sharpened, sharpened_true)
        sim_s = similarity(sharpened.flatten(), sharpened_true.flatten())
        print(l2_o, l2_s)
        plot_imgs_side_by_side(
            [sharpened_true, sharpened],
            ax[1],
            [
                "sharpened",
                f"encoded sharpened\nsimilarity={sim_s:.3f}, ||true-encoding||₂ = {l2_s:.3f}, encoding_similarity={encoding_similarity:.3f}, encoding_l2={encoding_l2err:.3f}",
            ],
            fig,
            False,
        )

        # Both scaffolds in this cell are built with D features (shared with dummy_scaffold).
        fig.suptitle(
            f"original vs. sharpened {name}, D={D}, sharpening={contraction_name}, shapes={shapes.tolist()}"
        )
        if rescaling:
            fig.savefig(
                f"org_vs_sharp_D={D}, dist={name} sharpening={contraction_name}.png"
            )
        else:
            fig.savefig(
                f"org_vs_sharp_no_rescaling_D={D}, dist={name} sharpening={contraction_name}.png"
            )

Cross-term error analysis: how far the pairwise inner products |⟨g(k_i), g(k_j)⟩| of the encodings deviate from the identity

In [None]:
# Stand-alone matrix-representation scaffold with contraction sharpening, used only for
# the cross-term analysis below. The second positional argument (1000) is presumably
# D (feature count) — the plot titles below hard-code this value; keep them in sync.
# NOTE(review): unlike the scaffolds above, no `device=` is passed, so this uses
# FourierScaffold's default device.
x = FourierScaffold(
    shapes,
    1000,
    shift=HadamardShiftMatrix(),
    smoothing=GuassianFourierSmoothingMatrix([10, 10], [0.4, 0.4]),
    sharpening=ContractionSharpening(2),
    representation="matrix",
)

In [None]:
# Every grid coordinate k in Ω (the Cartesian product of [0, size) over dim_sizes),
# one coordinate per row.
omega = torch.cartesian_prod(*(torch.arange(0, size) for size in dim_sizes))
if omega.ndim == 1:
    # cartesian_prod returns a 1-D tensor for a single input; keep it 2-D as (N, 1).
    omega = omega.unsqueeze(1)
print(omega.shape)
encodings = x.encode_batch(omega.T)

In [None]:
# Sanity check: the coordinate grid and the batch of encodings produced from it.
print(omega.shape)
print(encodings.shape)

In [None]:
# Merge the first two axes of `encodings` into a single axis.
encodings = encodings.flatten(start_dim=0, end_dim=1)
print(encodings.shape)

In [None]:
# Pairwise products between the columns of `encodings` (right operand conjugated),
# followed by an elementwise square root.
# NOTE(review): presumably each column is a flattened matrix (outer-product) encoding,
# so the square root yields |<g(k_i), g(k_j)>| as in the plot titles — TODO confirm.
results = encodings.T.matmul(encodings.conj()).pow(0.5)
print(results)

In [None]:
# Heatmap of |results[i, j]| over all pairs of grid points in Ω.
# imshow puts the column index j on the x-axis and the row index i on the y-axis.
fig, ax = plt.subplots(figsize=(15, 15))
im = ax.imshow(results.abs().cpu())
ax.set_xlabel('j')
ax.set_ylabel('i')
# D must match the value used to build `x` (1000).
ax.set_title(f'|<g(k_i), g(k_j)>| for each i,j in Ω. shapes={shapes.tolist()}, D=1000, sharpening=contraction')
fig.colorbar(im, ax=ax)

In [None]:
# Log-magnitude heatmap of the first 200x200 block of `results`.
# imshow puts the column index j on the x-axis and the row index i on the y-axis.
fig, ax = plt.subplots(figsize=(15, 15))
ax.set_xlabel('j')
ax.set_ylabel('i')
# D must match the value used to build `x` (1000).
ax.set_title(f'log|<g(k_i), g(k_j)>| for each i,j in Ω[:200]. shapes={shapes.tolist()}, D=1000, sharpening=contraction')
im = ax.imshow(results[:200, :200].abs().log().cpu())
fig.colorbar(im, ax=ax)

In [None]:
# Deviation of the pairwise matrix from the identity. Subtracting the identity implies
# the ideal is 1 on the diagonal and 0 elsewhere. The identity is created on the same
# device as `results`, so this also works when `x` lives on the GPU.
error = results - torch.eye(len(omega), device=results.device)

In [None]:
# Print the full error matrix (large output: len(omega) x len(omega) entries).
print(error)

In [None]:
# Summary statistics of |error| (max, min, mean, std, sum), then the number of entries.
abs_error = error.abs()
for summarize in (torch.max, torch.min, torch.mean, torch.std, torch.sum):
    print(summarize(abs_error))
print(error.numel())

In [None]:
# Histogram of the raw entries of `error`.
# NOTE(review): `results` is an elementwise square root of encodings.T @ encodings.conj(),
# so `error` may be complex (and may live on the GPU). plt.hist could then reject it or
# drop the imaginary part; consider error.abs() / error.real and .cpu() — TODO confirm.
plt.hist(error.flatten())

In [None]:
# Histogram of log(|error| + 1e-8); the 1e-8 floor keeps exact zeros finite.
# Bins span log-values -15 .. -0.5 in steps of 0.5.
bins = torch.arange(-15, 0, 0.5)
plt.ylabel('count')
plt.xlabel('log(|error| + 1e-8)')
plt.title('counts of logs of error')
plt.hist((error.abs() + 1e-8).log().flatten().cpu(), bins=bins)