In [None]:
import torch
from nd_scaffold import GridModule

mod = GridModule(shape=(3, 3, 5), device="cpu:0", T=0.1)


def _show_module(grid_module):
    """Print a grid module's raw ``state`` followed by its ``onehot()`` encoding."""
    print("state", grid_module.state)
    print("one hot", grid_module.onehot())


# Show the initial state, then the state after each operation:
# denoise -> shift by [1, 0, 0] -> denoise again.
_show_module(mod)

mod.denoise_self()
_show_module(mod)

mod.shift(torch.tensor([1, 0, 0], device="cpu:0"))
_show_module(mod)

mod.denoise_self()
_show_module(mod)

In [None]:
import numpy as np
import torch
from nd_scaffold import GridScaffold, SparseMatrixBySparsityInitializer
from graph_utils import graph_scaffold, print_imgs_side_by_side
import os


def test_mnist(
    num_imgs=1,
    prefix="",
    relu_theta=0.5,
    sparsity=0.1,
    N_h=400,
    T=0.01,
    plot_figs=False,
):
    """Store the first ``num_imgs`` MNIST training images in a GridScaffold and
    report how faithfully they are recalled.

    Args:
        num_imgs: number of MNIST training images to store, taken from the
            start of the training split.
        prefix: filename prefix for the comparison figure (only used when
            ``plot_figs`` is True).
        relu_theta: forwarded to ``GridScaffold(relu_theta=...)``.
        sparsity: forwarded to ``SparseMatrixBySparsityInitializer``.
        N_h: forwarded to ``GridScaffold(N_h=...)``.
        T: forwarded to ``GridScaffold(T=...)``.
        plot_figs: if True, save an original/noisy/recalled figure for the
            first stored image.

    Returns:
        Tensor of per-image cosine similarities between each stored image and
        its recall.
    """
    import torchvision

    mnist = torchvision.datasets.MNIST(root="data", train=True, download=True)

    # `.data` is read directly, so no dataset transform applies here. Slice first,
    # then flatten each image to a 784-vector.
    mnist_data = mnist.data[:num_imgs].flatten(1).float().to("cpu")
    # Global standardisation: a single mean/std over every pixel of the selection.
    mnist_data = (mnist_data - mnist_data.mean()) / mnist_data.std()
    l = mnist_data.shape[0]

    shapes = [(3, 3, 5), (4, 4, 7)]
    # The same [1, 1, 1] velocity for each of the l stored images.
    velocities = torch.tile(torch.tensor([[1, 1, 1]]), (l, 1)).to("cpu")

    GS = GridScaffold(
        shapes=shapes,
        N_h=N_h,
        input_size=784,
        device="cpu",
        sparse_matrix_initializer=SparseMatrixBySparsityInitializer(
            sparsity=sparsity,
            device="cpu",
        ),
        relu_theta=relu_theta,
        T=T,
    )

    # No noise is added at the moment: the recall cue is the clean data.
    noisy_mnist = mnist_data

    GS.learn_path(observations=mnist_data, velocities=velocities)
    recalled_imgs = GS.recall(noisy_mnist)

    if plot_figs:
        # min() keeps the loop valid when no images were selected.
        for i in range(min(1, l)):
            original_img = mnist_data[i].reshape(28, 28).cpu().numpy()
            noisy_img = noisy_mnist[i].reshape(28, 28).cpu().numpy()
            recalled_img = recalled_imgs[i].reshape(28, 28).cpu().numpy()
            print_imgs_side_by_side(
                original_img,
                noisy_img,
                recalled_img,
                out=f"{prefix}mnist_learned_{i}.png",
                captions=["original", "noisy", "recalled"],
                title="Learned",
            )

    # Score the recall that was just produced (and possibly plotted) instead of
    # calling GS.recall again. A second call doubles the work, and if recall is
    # not deterministic it can yield a different result from the one shown.
    similarity = torch.nn.functional.cosine_similarity(mnist_data, recalled_imgs)
    return similarity


# Smoke run: store and recall 11 images (sparsity=0.99 and relu_theta=0.9
# forwarded to the scaffold); the returned similarity tensor is the cell output.
test_mnist(11, sparsity=0.99, N_h=1000, prefix="", relu_theta=0.9)

In [None]:
from matplotlib import pyplot as plt

# Sweep the temperature T (forwarded to GridScaffold) against the hidden size N_h.
# Each entry is the mean per-image cosine similarity returned by test_mnist.
# NOTE(review): 0.03 precedes 0.01, so rows are not in increasing-T order --
# confirm whether 0.003 was intended.
temperatures = [0.001, 0.03, 0.01, 0.1, 0.3, 1]
N_h = [200, 400, 600, 800, 1000]

scores = np.zeros((len(temperatures), len(N_h)))

for i, T in enumerate(temperatures):
    for j, N in enumerate(N_h):
        # Forward the swept temperature. Previously T only reached the file
        # prefix, so every row was computed at test_mnist's default T.
        scores[i, j] = (
            test_mnist(11, sparsity=0.99, N_h=N, T=T, prefix=f"T_{T}_N_{N}_")
            .mean()
            .item()
        )

print(scores)

fig, ax = plt.subplots()
cax = ax.matshow(scores, cmap="viridis")
ax.set_xticks(range(len(N_h)))
ax.set_xticklabels([str(n) for n in N_h])
ax.set_yticks(range(len(temperatures)))
ax.set_yticklabels([str(t) for t in temperatures])
ax.set_xlabel("N_h")
ax.set_ylabel("T")
fig.colorbar(cax, ax=ax)

In [4]:
import pickle

# Persist the similarity grid so it can be reloaded without re-running the sweep.
with open("scores.pkl", "wb") as score_file:
    pickle.dump(scores, score_file)

In [None]:
# Similarity vs N_h for the first sweep, one line per temperature.
# Guard against `scores` having been rebound or computed for different lists.
assert scores.shape == (len(temperatures), len(N_h)), "scores does not match the temperatures x N_h sweep"
fig, ax = plt.subplots()
for i, T in enumerate(temperatures):
    ax.plot(N_h, scores[i], label=f"T={T}")

ax.set_xlabel("N_h")
ax.set_ylabel("Similarity")
ax.legend()

In [None]:
from vectorhash_functions import spacefillingcurve

# Two 2-D grid modules with shapes (2, 3) and (3, 4).
modules = [(2, 3), (3, 4)]
# Sequence of velocity steps; plot_walk reads step[0] and step[1] of each.
# Presumably a walk visiting every combined grid state -- TODO confirm.
v = spacefillingcurve(modules)

# graph walk

import matplotlib.pyplot as plt
import numpy as np


def plot_walk(v, modules):
    """Scatter-plot the positions visited by a walk and draw an arrow per step.

    Positions are the running sums of the step components, wrapped modulo the
    products of the two modules' side lengths. Points are coloured by visit
    order, and a colorbar is added to the current figure.
    """
    period_y = modules[0][0] * modules[1][0]
    period_x = modules[0][1] * modules[1][1]
    # NOTE(review): x accumulates step[0] but wraps at the product of the
    # *second* dimensions (and y vice versa) -- confirm this pairing is intended.
    # Products also equal the combined period only when the sides are coprime.
    steps_x = [0] + [step[0] for step in v]
    steps_y = [0] + [step[1] for step in v]
    x = np.cumsum(steps_x) % period_x
    y = np.cumsum(steps_y) % period_y

    points = plt.scatter(x, y, c=range(len(x)), cmap="viridis", s=20)
    for x0, y0, x1, y1 in zip(x[:-1], y[:-1], x[1:], y[1:]):
        plt.arrow(
            x0,
            y0,
            x1 - x0,
            y1 - y0,
            head_width=0.2,
            head_length=0.2,
            length_includes_head=True,
        )
    plt.colorbar(points)


# Report the number of steps in the walk, then draw it.
print(len(v))
plot_walk(v, modules)

In [None]:
# Second sweep with 100 stored images: temperature T x hidden size N_h.
# Each entry is the mean per-image cosine similarity returned by test_mnist.
T = [0.01, 0.03, 0.1, 0.3, 1]
N_h = [500, 600, 700, 800, 900]

scores = np.zeros((len(T), len(N_h)))

for i, t in enumerate(T):
    for j, dim in enumerate(N_h):
        # Forward the swept temperature. Previously `t` only reached the file
        # prefix, so every row ran at test_mnist's default T.
        scores[i, j] = (
            test_mnist(
                100,
                sparsity=0.99,
                N_h=dim,
                T=t,
                prefix=f"T_{t}_N_{dim}_",
                plot_figs=False,
            )
            .mean()
            .item()
        )

In [None]:
# Similarity vs N_h for the second sweep, one line per temperature.
# Reuse the sweep's own T / N_h lists so the labels cannot drift from the grid.
assert scores.shape == (len(T), len(N_h)), "scores does not match the T x N_h sweep"
fig, ax = plt.subplots()
for i, t in enumerate(T):
    ax.plot(N_h, scores[i], label=f"T={t}")

ax.set_xlabel("N_h")
ax.set_ylabel("Similarity")
ax.legend()

In [None]:
from nd_scaffold import GridScaffold, SparseMatrixBySparsityInitializer
from matrix_initializers import SparseMatrixByScalingInitializer
import torchvision
from vectorhash_functions import calculate_big_theta, calculate_relu_theta, solve_mean
from torchvision import transforms
from graph_utils import print_imgs_side_by_side, graph_scaffold
from vectorhash_functions import spacefillingcurve
import matplotlib.pyplot as plt
from scipy.stats import norm
import torch
import math
import os


def test_mnist2(
    percent=1,
    prefix="",
    relu_theta=0.5,
    sparsity=0.1,
    N_h=400,
    T=0.01,
    plot_figs=False,
    plot_scaffold=False,
    continualupdate_=False,
    ratshift_=True,
    sparsitymethod=0,
    ratio=0.1,
):
    """Store MNIST images along a space-filling walk over a 2-D GridScaffold
    and measure recall quality.

    Args:
        percent: fraction of the (at most 133) walk steps for which an MNIST
            image is stored.
        prefix: filename prefix for figures written under ``mnist_test_2/``.
        relu_theta: threshold passed to ``GridScaffold`` when
            ``sparsitymethod != 0``.
        sparsity: forwarded to the sparse-matrix initializer and to the
            threshold helper functions.
        N_h: hidden-layer size forwarded to ``GridScaffold``.
        T: forwarded to ``GridScaffold(T=...)``.
        plot_figs: save original/noisy/recalled comparisons for up to 3 images.
        plot_scaffold: call ``graph_scaffold`` on the built scaffold.
        continualupdate_: forwarded as ``continualupdate``.
        ratshift_: forwarded as ``ratshift``.
        sparsitymethod: 1 selects ``SparseMatrixByScalingInitializer``; any
            other value selects ``SparseMatrixBySparsityInitializer``. 0
            additionally uses ``-solve_mean(p=ratio, ...)`` as ``relu_theta``.
        ratio: target probability ``p`` passed to ``solve_mean``.

    Returns:
        ``(similarity, GS)``: per-image cosine similarities between the stored
        images and their recalls, and the scaffold itself.
    """
    os.makedirs("mnist_test_2", exist_ok=True)

    shapes = [(3, 3), (4, 4)]
    velocities = spacefillingcurve(shapes)
    velocities = velocities[:133]
    num_imgs = int(len(velocities) * percent)
    print("imgs:    ", num_imgs)

    mnist = torchvision.datasets.MNIST(root="data", train=True, download=True)

    # `.data` is read directly (no dataset transform applies), so slice and
    # flatten each image to a 784-vector here.
    mnist_data = mnist.data[:num_imgs].flatten(1).float()
    # Global standardisation: a single mean/std over every pixel of the selection.
    mnist_data = (mnist_data - mnist_data.mean()) / mnist_data.std()
    l = mnist_data.shape[0]
    # Keep one velocity per stored image, matching test_mnist. With percent < 1
    # the walk would otherwise be longer than the observation sequence.
    velocities = velocities[:l]

    relu_theta_ = calculate_relu_theta(
        modules=len(shapes), sparsity=1 - sparsity, target_prob=0.1
    )
    print("relu_theta", relu_theta_)
    relu_big_theta = calculate_big_theta(len(shapes), sparsity=sparsity, targetp=0.2)
    print("relu_big_theta", relu_big_theta)

    # Diagnostic: plot normal curves around the candidate thresholds and print
    # how much probability mass each places below zero.
    # NOTE(review): the CDF checks pair each threshold with the *other* std (and
    # flip the sign of relu_theta_) relative to the plotted curves -- confirm.
    x = torch.linspace(-5, 5, 100)
    plt.plot(x, norm.pdf(x, relu_theta_, math.sqrt(len(shapes) * (1 - sparsity))))
    plt.plot(x, norm.pdf(x, -relu_big_theta, math.sqrt(len(shapes) * (sparsity))))
    print("what are 10 percent of numbers going to be below? 1: ezra, 2: johnny")
    print(norm(-relu_theta_, math.sqrt(len(shapes) * (sparsity))).cdf(0))
    print(norm(-relu_big_theta, math.sqrt(len(shapes) * (1 - sparsity))).cdf(0))
    relu_new = solve_mean(p=ratio, var=len(shapes) * (1 - sparsity))
    print(norm(-relu_new, math.sqrt(len(shapes) * (1 - sparsity))).cdf(0))

    # Parameters for the alternative (scaling-based) matrix initializer.
    k = 2 * len(shapes)
    var = 1.0
    print("SPARSITY VARSITY BANANA LAME ")
    a = SparseMatrixBySparsityInitializer(sparsity=sparsity)
    print((a((N_h, sum([module[0] * module[1] for module in shapes])))))

    if sparsitymethod == 1:
        initializer = SparseMatrixByScalingInitializer(
            mean=relu_theta_, scale=var / math.sqrt(k)
        )
    else:
        initializer = SparseMatrixBySparsityInitializer(sparsity=sparsity)

    GS = GridScaffold(
        shapes=shapes,
        N_h=N_h,
        input_size=784,
        device=None,
        sparse_matrix_initializer=initializer,
        relu_theta=(-relu_new) if sparsitymethod == 0 else relu_theta,
        T=T,
        continualupdate=continualupdate_,
        ratshift=ratshift_,
    )

    if plot_scaffold:
        graph_scaffold(GS, dir="mnist_test_2")

    # No noise is added at the moment: the recall cue is the clean data.
    noisy_mnist = mnist_data

    GS.learn_path(observations=mnist_data, velocities=velocities)
    recalled_imgs = GS.recall(noisy_mnist)

    if plot_figs:
        # Plot at most 3 images; fewer may be stored when percent is small.
        for i in range(min(3, l)):
            original_img = mnist_data[i].reshape(28, 28).cpu().numpy()
            noisy_img = noisy_mnist[i].reshape(28, 28).cpu().numpy()
            recalled_img = recalled_imgs[i].reshape(28, 28).cpu().numpy()
            print_imgs_side_by_side(
                original_img,
                noisy_img,
                recalled_img,
                out=f"mnist_test_2/{prefix}mnist2_LRND_{i}.png",
                captions=["original", "noisy", "recalled"],
                title="Learned",
            )

    # Score the recall that was just produced (and possibly plotted) instead of
    # calling GS.recall a second time on the same input.
    similarity = torch.nn.functional.cosine_similarity(mnist_data, recalled_imgs)
    return similarity, GS


# Run the 2-D MNIST experiment. The result is bound to a new name so the
# `scores` grid from the sweep cells (plotted and pickled above) is not
# replaced by a 1-D tensor.
mnist2_scores, GS = test_mnist2(
    percent=1,
    sparsity=0,
    N_h=1000,
    prefix="",
    relu_theta=0,
    T=0.0000001,
    plot_figs=True,
    plot_scaffold=True,
    continualupdate_=True,
    ratshift_=False,
    sparsitymethod=0,
    ratio=0.005,
)
print(mnist2_scores.mean().item())

In [None]:
from vectorhash_functions import solve_mean, calculate_big_theta, calculate_relu_theta
import torch
import math
import matplotlib.pyplot as plt
from scipy.stats import norm

# Compare three candidate ReLU thresholds for a 3-module setup with
# sparsity = 0.9, and print how much normal-distribution mass each puts below 0.
shapes = [(3, 3), (4, 4), (5, 5)]
sparsity = 0.9
n_modules = len(shapes)

relu_theta_ = calculate_relu_theta(
    modules=n_modules, sparsity=1 - sparsity, target_prob=0.1
)
print("relu_theta", relu_theta_)
relu_big_theta = calculate_big_theta(n_modules, sparsity=sparsity, targetp=0.1)
print("relu_big_theta", relu_big_theta)
relu_new = solve_mean(p=0.2, var=n_modules * (1 - sparsity))
print("relu_new", relu_new)

# The two standard deviations used below.
std_complement = math.sqrt(n_modules * (1 - sparsity))
std_sparsity = math.sqrt(n_modules * sparsity)

x = torch.linspace(-5, 5, 100)
plt.plot(x, norm.pdf(x, relu_theta_, std_complement))
plt.plot(x, norm.pdf(x, -relu_big_theta, std_sparsity))
print("what are 10 percent of numbers going to be below? 1: ezra, 2: johnny")
# NOTE(review): compared to the curves plotted above, these CDFs pair each
# threshold with the other std, and the first flips the sign of relu_theta_.
print(norm(-relu_theta_, std_sparsity).cdf(0))
print(norm(-relu_big_theta, std_complement).cdf(0))
# NOTE(review): test_mnist2 evaluates norm(-relu_new, ...); the sign differs
# here, and relu_new was solved with p=0.2 rather than 10%.
print(norm(relu_new, std_complement).cdf(0))