In [None]:
import sys
sys.path.append("../..")
import torch
from hippocampal_sensory_layers import ComplexIterativeBidirectionalPseudoInverseHippocampalSensoryLayerComplexScalars
from fourier_scaffold import FourierScaffold
from preprocessing_cnn import RescalePreprocessing, SequentialPreprocessing, GrayscaleAndFlattenPreprocessing
from experiments.fourier_miniworld_gridsearch.room_env import RoomExperiment

In [None]:
# --- Experiment configuration -------------------------------------------------
# D is passed to FourierScaffold as `D` and to both sensory layers as `N_h`.
D = 400
# One tuple per module, so torch.tensor(shapes) is 2-D: (n_modules, n_dims).
# The non-prime list previously carried an extra level of nesting, which made
# its tensor (1, 3, 1) and inconsistent with the primes layout.
shapes_primes = [(3,), (5,)]
shapes_nonprimes = [(3,), (7,), (9,)]
# Prefer the GPU but keep the notebook runnable on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fix the RNG so Restart & Run All reproduces the same random patterns.
torch.manual_seed(0)

scaffold_primes = FourierScaffold(
    shapes=torch.tensor(shapes_primes), D=D, device=device, _skip_K_calc=True
)
scaffold_nonprimes = FourierScaffold(
    shapes=torch.tensor(shapes_nonprimes), D=D, device=device, _skip_K_calc=True
)
# NOTE(review): `env` and `preprocessing` are not referenced by the visible
# cells (patterns come from `sbook`); kept in case other cells rely on them.
env = RoomExperiment([3, 0, 3], 0, True, True)
preprocessing = SequentialPreprocessing(
    [RescalePreprocessing(0.5), GrayscaleAndFlattenPreprocessing(device=device)]
)

# Shape used only to reshape flat patterns for plotting. The trailing numbers
# on the alternatives are presumably rescale factors - TODO confirm.
# output_shape = (60,80) # 1
output_shape = (10, 10)  # 0.5
# output_shape = (15, 20)  # 0.25

Npatts = 20
# Derived from output_shape so the plotting reshape can never go out of sync.
N_s = output_shape[0] * output_shape[1]
# Random +/-1 patterns, one row per pattern: (Npatts, N_s).
sbook = torch.sign(torch.randn(Npatts, N_s, device=device))
print(f"N_s: {N_s}")

# Both layers are built with identical hyper-parameters; only the scaffold they
# are later trained against differs.
layer_kwargs = dict(
    input_size=N_s,
    N_h=D,
    hidden_layer_factor=0,
    epsilon_hs=0.1,
    epsilon_sh=0.1,
    device=device,
)
layer_primes = (
    ComplexIterativeBidirectionalPseudoInverseHippocampalSensoryLayerComplexScalars(
        **layer_kwargs
    )
)
layer_nonprimes = (
    ComplexIterativeBidirectionalPseudoInverseHippocampalSensoryLayerComplexScalars(
        **layer_kwargs
    )
)

In [None]:
import matplotlib.pyplot as plt
from graph_utils import plot_imgs_side_by_side

# Integer codes 0..3, listed in ascending order.
# NOTE(review): presumably agent action ids; none of them is referenced in the
# visible cells - confirm before removing.
LEFT, RIGHT, FORWARD, NONE = range(4)


def primes_P_from_s(img):
    """Turn a sensory pattern into a matrix using ``layer_primes``.

    Takes element 0 of ``layer_primes.hippocampal_from_sensory(img)`` and
    returns the outer product of that vector with its own complex conjugate.
    """
    recalled = layer_primes.hippocampal_from_sensory(img)[0]  # (N_h)
    return torch.outer(recalled, recalled.conj())


def primes_s_from_P(P):
    """Recover a sensory pattern from ``P`` using the primes scaffold and layer.

    ``P`` is right-multiplied by ``scaffold_primes.g_s`` and the product is fed
    to ``layer_primes.sensory_from_hippocampal``, whose output is returned.
    """
    projected = P @ scaffold_primes.g_s
    return layer_primes.sensory_from_hippocampal(projected)

def nonprimes_P_from_s(img):
    """Turn a sensory pattern into a matrix using ``layer_nonprimes``.

    Takes element 0 of ``layer_nonprimes.hippocampal_from_sensory(img)`` and
    returns the outer product of that vector with its own complex conjugate.
    """
    recalled = layer_nonprimes.hippocampal_from_sensory(img)[0]  # (N_h)
    return torch.outer(recalled, recalled.conj())


def nonprimes_s_from_P(P):
    """Recover a sensory pattern from ``P`` using the non-primes scaffold and layer.

    ``P`` is right-multiplied by ``scaffold_nonprimes.g_s`` and the product is
    fed to ``layer_nonprimes.sensory_from_hippocampal``, whose output is returned.
    """
    projected = P @ scaffold_nonprimes.g_s
    return layer_nonprimes.sensory_from_hippocampal(projected)


def similarity(P1, P2):
    """Normalised overlap between two tensors of the same shape.

    Returns ``|sum(P1 * P2)| / (||P1|| * ||P2||)``. The product is a plain
    element-wise multiply; no complex conjugate is taken on either argument.
    """
    overlap = (P1 * P2).sum().abs()
    scale = P1.norm() * P2.norm()
    return overlap / scale


def similarity_batch(P1s, P2s):
    """Row-wise normalised overlap between two batches of tensors.

    Both inputs are flattened from dimension 1 onward, so item ``b`` of each
    batch becomes one vector. For every ``b`` the result is
    ``|sum(P1s[b] * P2s[b])| / (||P1s[b]|| * ||P2s[b]||)``, using a plain
    element-wise product (no complex conjugate), matching ``similarity``.

    Args:
        P1s: tensor with a leading batch dimension.
        P2s: tensor with the same shape as ``P1s``.

    Returns:
        1-D tensor with one similarity value per batch item.
    """
    P1s = P1s.flatten(1)
    # Flatten the *second* argument; flattening P1s here would compare the
    # first batch with itself and ignore P2s entirely.
    P2s = P2s.flatten(1)
    return (P1s * P2s).sum(dim=1).abs() / (P1s.norm(dim=1) * P2s.norm(dim=1))


In [None]:
# no smoothing test
# Stores N patterns from `sbook` one at a time, shifting both scaffolds by one
# step after each, and records how well the first and the most recent pattern
# are recalled through each sensory layer.
first_img = sbook[0]  # preprocessing.encode(obs)
# Snapshot each scaffold's P before any velocity shift. Each layer's recall of
# the first image must be compared with the P of the scaffold it was trained
# against, so the non-primes run needs its own reference.
first_P = scaffold_primes.P.clone()
nonprimes_first_P = scaffold_nonprimes.P.clone()
# NOTE(review): sbook[0] is learned here and again at j == 0 in the loop below,
# at the same scaffold state - confirm the double presentation is intended.
layer_primes.learn(scaffold_primes.g_avg(), first_img)
layer_nonprimes.learn(scaffold_nonprimes.g_avg(), first_img)

N = 10
# Per-step similarity between the reference P and the P recalled from an image.
primes_results_first_P = torch.zeros(N)
primes_results_recent_P = torch.zeros(N)

# Per-step patterns: ground truth, recall of the current one, recall of the first.
primes_recent_imgs_true = torch.zeros(N, N_s)
primes_recent_imgs_recovered = torch.zeros(N, N_s)
primes_first_imgs_recovered = torch.zeros(N, N_s)

# Per-step squared norm of the recalled P.
primes_recovered_recent_P_Hs = torch.zeros(N)
primes_recovered_first_P_Hs = torch.zeros(N)

nonprimes_results_first_P = torch.zeros(N)
nonprimes_results_recent_P = torch.zeros(N)

nonprimes_recent_imgs_true = torch.zeros(N, N_s)
nonprimes_recent_imgs_recovered = torch.zeros(N, N_s)
nonprimes_first_imgs_recovered = torch.zeros(N, N_s)

nonprimes_recovered_recent_P_Hs = torch.zeros(N)
nonprimes_recovered_first_P_Hs = torch.zeros(N)

# 1. store image
# 2. calculate s->P->s for first image
# 3. calculate s->P->s for new image

for j in range(N):
    img = sbook[j]  # preprocessing.encode(obs)

    layer_primes.learn(scaffold_primes.g_avg(), img)
    layer_nonprimes.learn(scaffold_nonprimes.g_avg(), img)

    primes_results_first_P[j] = similarity(first_P, primes_P_from_s(first_img))
    primes_results_recent_P[j] = similarity(scaffold_primes.P, primes_P_from_s(img))
    primes_first_imgs_recovered[j] = primes_s_from_P(primes_P_from_s(first_img))
    primes_recent_imgs_true[j] = img
    primes_recent_imgs_recovered[j] = primes_s_from_P(primes_P_from_s(img))
    # Compare against the non-primes scaffold's own initial P, not the primes one.
    nonprimes_results_first_P[j] = similarity(
        nonprimes_first_P, nonprimes_P_from_s(first_img)
    )
    nonprimes_results_recent_P[j] = similarity(
        scaffold_nonprimes.P, nonprimes_P_from_s(img)
    )
    nonprimes_first_imgs_recovered[j] = nonprimes_s_from_P(
        nonprimes_P_from_s(first_img)
    )
    nonprimes_recent_imgs_true[j] = img
    nonprimes_recent_imgs_recovered[j] = nonprimes_s_from_P(nonprimes_P_from_s(img))
    # print(env.agent.pos)

    primes_recovered_recent_P_Hs[j] = primes_P_from_s(img).norm() ** 2
    primes_recovered_first_P_Hs[j] = primes_P_from_s(first_img).norm() ** 2
    # Advance each scaffold by one step only after all of this step's
    # measurements against its current P have been taken.
    scaffold_primes.velocity_shift(torch.tensor([1]))
    nonprimes_recovered_recent_P_Hs[j] = nonprimes_P_from_s(img).norm() ** 2
    nonprimes_recovered_first_P_Hs[j] = nonprimes_P_from_s(first_img).norm() ** 2
    scaffold_nonprimes.velocity_shift(torch.tensor([1]))

print(primes_results_first_P)
print(primes_results_recent_P)
print(primes_recovered_recent_P_Hs)
print(primes_recovered_first_P_Hs)
print(nonprimes_results_first_P)
print(nonprimes_results_recent_P)
print(nonprimes_recovered_recent_P_Hs)
print(nonprimes_recovered_first_P_Hs)

In [None]:
# One row per stored pattern: the true pattern next to what each layer recalled
# for it right after it was stored.
# Take the row count from the recorded data rather than the global N, which a
# later cell rebinds to a different value.
n_rows = primes_recent_imgs_true.shape[0]
fig, axes = plt.subplots(nrows=n_rows, ncols=3, figsize=(10, 30))

for j in range(n_rows):
    plot_imgs_side_by_side(
        axs=axes[j],
        imgs=[
            # reshape (not the deprecated non-inplace resize) so a mismatch
            # between N_s and output_shape raises instead of passing silently.
            primes_recent_imgs_true[j].cpu().reshape(output_shape),
            nonprimes_recent_imgs_recovered[j].cpu().reshape(output_shape),
            primes_recent_imgs_recovered[j].cpu().reshape(output_shape),
        ],
        titles=["original", "recovered nonprimes", "recovered primes"],
        use_first_img_scale=False,
        fig=fig,
    )

plt.show()

In [None]:
# One row per step: the first stored pattern next to what each layer recalls
# for it after that step, showing how recall of it changes as more are stored.
# Take the row count from the recorded data rather than the global N, which a
# later cell rebinds to a different value.
n_rows = primes_first_imgs_recovered.shape[0]
fig, axes = plt.subplots(nrows=n_rows, ncols=3, figsize=(10, 30))

for j in range(n_rows):
    plot_imgs_side_by_side(
        axs=axes[j],
        imgs=[
            # reshape (not the deprecated non-inplace resize) so a mismatch
            # between N_s and output_shape raises instead of passing silently.
            first_img.cpu().reshape(output_shape),
            nonprimes_first_imgs_recovered[j].cpu().reshape(output_shape),
            primes_first_imgs_recovered[j].cpu().reshape(output_shape),
        ],
        titles=["original", "nonprime recovered", "prime recovered"],
        use_first_img_scale=False,
        fig=fig,
    )

plt.show()

# Entropy calculations

In [None]:
import math
import itertools
import numpy as np

# NOTE: this cell rebinds N, D, N_s and shapes_primes, which earlier cells also
# define. Re-run the notebook from the top before re-running those cells.
N = 1000
D = 200
N_s = 100
shapes_primes = [(3,), (5,), (7,), (11,)]
# Per-dimension product of the module sizes, i.e. how many distinct k values
# are enumerated along each dimension.
dim_sizes = np.prod(shapes_primes, axis=0)
combinations = [np.arange(dim_sizes[i]) for i in range(len(shapes_primes[0]))]


def partial_entropy(shapes, index_ranges, max_terms):
    """Sum ``log2(m / gcd(k, m))`` over the first ``max_terms`` terms.

    Terms are enumerated in this order: every index tuple from the Cartesian
    product of ``index_ranges``, then every dimension ``d`` of that tuple, then
    every module shape ``m_i`` in ``shapes`` (using ``m_i[d]`` as ``m``).

    Args:
        shapes: sequence of module shapes, each indexable by dimension.
        index_ranges: one iterable of integer indices per dimension.
        max_terms: exact number of terms to accumulate; fewer are used only if
            the enumeration runs out first.

    Returns:
        The accumulated sum in bits (``0`` if no term is consumed).
    """
    terms = (
        math.log2(m_i[d] / math.gcd(k_d, m_i[d]))
        for k_tuple in itertools.product(*index_ranges)
        for d, k_d in enumerate(k_tuple)
        for m_i in shapes
    )
    # islice stops after exactly max_terms terms. The previous nested loops
    # used a trailing `continue`, which did nothing, so the cap was only
    # checked between index tuples and could overshoot.
    return sum(itertools.islice(terms, max_terms))


H = partial_entropy(shapes_primes, combinations, N)

print(D*H)
print(D * H / N)
print(D * H / N_s)