In [None]:
import sys

sys.path.append("../..")
import torch
from hippocampal_sensory_layers import (
    ComplexIterativeBidirectionalPseudoInverseHippocampalSensoryLayerComplexScalars,
    ComplexExactPseudoInverseHippocampalSensoryLayerComplexScalars,
)
from fourier_scaffold import FourierScaffold

SEED = 0
torch.manual_seed(SEED)  # sbook_all below is random; seed it so Restart & Run All is reproducible

# Fall back to CPU when no GPU is available so the notebook still runs end to end.
device = "cuda" if torch.cuda.is_available() else "cpu"
runs = 10           # independent repetitions, each with its own sensory book
Ns = 1000           # sensory dimensionality (entries per pattern)
Npatts_total = 200  # maximum number of patterns stored in a run
# Random sensory patterns of shape (runs, Npatts_total, Ns); sign of Gaussian
# noise, so entries are +/-1 (0 only if randn returns exactly 0).
sbook_all = torch.sign(torch.randn(runs, Npatts_total, Ns, device=device))
# Pattern counts at which recall is evaluated: 5, 6, ..., Npatts_total.
Npatts_list = torch.arange(5, Npatts_total + 1)

###### CONFIG 1: test MI/D dependence #####
# Note: iterative vs analytic pseudoinverse makes no difference
D_list = [100, 200, 300]  # scaffold sizes D swept inside run_test
ps_types = ['iterative']  # layer types to compare: "iterative" and/or "analytic"
shape_configs = [
    # Each entry is a list of module shapes (one tuple per module). The tuples in
    # an entry are stacked with torch.tensor / np.prod(axis=0) later, so they
    # presumably must all have the same length.
    # [(3,), (5,), (7,),(11,)],
    # [(3, 3), (5, 5), (7, 7)],
    [(3, 3, 3), (5, 5, 5)],
    # [(3, 3, 3), (7,7,7)],
    # [(11, 13, 17), (23, 29, 31)],
]



def _make_layer(layer_type, D, gbook):
    """Build the hippocampal<->sensory layer used for one scaffold size ``D``.

    "analytic"  -> exact pseudo-inverse layer, handed the first Npatts_total
                   rows of ``gbook`` as its hippocampal book at construction.
    "iterative" -> iterative bidirectional pseudo-inverse layer.

    Raises:
        ValueError: for any other ``layer_type`` (a subclass of Exception, so
            callers catching Exception still work).
    """
    if layer_type == "analytic":
        return ComplexExactPseudoInverseHippocampalSensoryLayerComplexScalars(
            input_size=Ns,
            N_h=D,
            N_patts=Npatts_total,
            hbook=gbook[:Npatts_total],
            device=device,
        )
    if layer_type == "iterative":
        return ComplexIterativeBidirectionalPseudoInverseHippocampalSensoryLayerComplexScalars(
            input_size=Ns,
            N_h=D,
            epsilon_hs=0.1,
            epsilon_sh=0.1,
            hidden_layer_factor=0,
            device=device,
        )
    raise ValueError("invalid layer type")


def run_test(shapes, layer_type, D_list, Npatts_list, sbook):
    """Measure recall bit-flip probability as patterns are stored one by one.

    For each scaffold size ``D`` a fresh FourierScaffold and layer are built.
    Pattern ``idx`` pairs row ``idx`` of the transposed scaffold gbook with
    ``sbook[idx]``. Patterns are learned incrementally. After reaching each
    count in ``Npatts_list``, every stored sensory pattern is mapped
    sensory -> hippocampal -> sensory and compared with the original by sign.

    Args:
        shapes: module shapes passed to ``FourierScaffold``.
        layer_type: "analytic" or "iterative".
        D_list: scaffold sizes to evaluate.
        Npatts_list: pattern counts at which recall is measured.
        sbook: sensory patterns indexed by pattern along the first axis
            (assumed +/-1 valued, as produced by the config cell).

    Returns:
        Tensor of shape ``(len(D_list), len(Npatts_list))``. Each entry is the
        mean of ``|true - sign(recovered)| / 2``. For +/-1 patterns, a sign
        mismatch counts as one flip and a zero prediction as half a flip.

    Raises:
        ValueError: if ``layer_type`` is not recognised.
    """
    errors = torch.zeros(len(D_list), len(Npatts_list))

    for k, D in enumerate(D_list):
        scaffold = FourierScaffold(shapes, D=D, _skip_K_calc=True, device=device)
        gbook = scaffold.gbook().T
        layer = _make_layer(layer_type, D, gbook)

        n_learned = 0  # how many (gbook, sbook) pairs have been passed to layer.learn
        for l, Npatts in enumerate(Npatts_list):
            Npatts = int(Npatts)
            # Learn only the patterns added since the previous evaluation point.
            for idx in range(n_learned, Npatts):
                layer.learn(gbook[idx], sbook[idx])
            n_learned = max(n_learned, Npatts)

            true = sbook[:Npatts]
            # The analytic layer is fed a complex query with zero imaginary part;
            # the iterative layer takes the real patterns directly.
            if layer_type == "analytic":
                query = torch.complex(true, torch.zeros_like(true))
            else:
                query = true
            recovered = layer.sensory_from_hippocampal(
                layer.hippocampal_from_sensory(query)
            )
            error = torch.abs(true.real - torch.sign(recovered.real))
            errors[k, l] = (error.mean() / 2).item()

    return errors

In [None]:
# results[run, shape_cfg, method, D, Npatts] holds the bit-flip probability
# returned by run_test for every combination in the sweep.
results = torch.zeros(
    runs, len(shape_configs), len(ps_types), len(D_list), len(Npatts_list)
)

for run_idx in range(runs):
    sbook_run = sbook_all[run_idx]
    for cfg_idx, shape_config in enumerate(shape_configs):
        for method_idx, layer_type in enumerate(ps_types):
            print(f"run {run_idx+1}/{runs}, shapes={shape_config}, layer_type={layer_type} ")
            shapes_tensor = torch.tensor(shape_config)
            results[run_idx, cfg_idx, method_idx] = run_test(
                shapes_tensor, layer_type, D_list, Npatts_list, sbook_run
            )
            print(results[run_idx, cfg_idx, method_idx])

In [None]:
import matplotlib.pyplot as plt

# A single shared axis; every (shapes, method, D) curve below is drawn onto it.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 9))


def build_title():
    """Return the figure title: settings shared by every plotted curve.

    Ns and runs are always listed. Each sweep dimension (shapes, pseudo-inverse
    method, D) is appended only when it holds exactly one value, i.e. when it
    is constant across all curves.
    """
    sweeps = (
        ("shapes", shape_configs),
        ("ps_method", ps_types),
        ("D", D_list),
    )
    title_parts = [f"Ns={Ns}", f"runs={runs}"]
    title_parts.extend(
        f"{key}={values[0]}" for key, values in sweeps if len(values) == 1
    )
    return ", ".join(title_parts)


def build_label(shape_config, ps_type, D):
    """Return a legend label naming only the sweep dimensions that vary.

    A dimension with a single configured value is already in the title, so it
    is left out here. The result is empty when nothing varies.
    """
    candidates = (
        (shape_configs, f"shapes={shape_config}"),
        (ps_types, f"ps_method={ps_type}"),
        (D_list, f"D={D}"),
    )
    return ", ".join(text for sweep, text in candidates if len(sweep) != 1)


def binary_entropy_bits(p):
    """Binary entropy H2(p) in bits, elementwise; finite (0) at p = 0 and p = 1.

    torch.special.entr(x) = -x * ln(x) with entr(0) = 0, so neither endpoint
    produces 0 * log(0) = NaN.
    """
    ln2 = torch.log(torch.tensor(2.0, dtype=p.dtype))
    return (torch.special.entr(p) + torch.special.entr(1 - p)) / ln2


# For a bit-flip probability p, each sensory bit behaves like a binary
# symmetric channel carrying MI = 1 - H2(p) bits.
for i, shape_config in enumerate(shape_configs):
    for k, D in enumerate(D_list):
        for l, method in enumerate(ps_types):
            # Index the method axis with l. It was previously hard-coded to 0,
            # which plotted the first method's data under every method's label.
            pflip = results[:, i, l, k]  # shape (runs, len(Npatts_list))
            MI = 1 - binary_entropy_bits(pflip)
            MI_mean = MI.mean(dim=0)
            MI_std = MI.std(dim=0)

            ax.plot(
                Npatts_list,
                MI_mean,
                label=build_label(shape_config, method, D),
            )
            ax.fill_between(
                Npatts_list,
                MI_mean - MI_std,
                MI_mean + MI_std,
                alpha=0.5,
            )
ax.legend()
ax.set_xlabel("N_patts")
ax.set_ylabel("MI per sensory bit")
ax.set_yscale("log")
ax.set_xticks(torch.arange(0, Npatts_total + 1, 5))
ax.grid()
ax.set_title(build_title())

In [None]:
import math
import itertools
import numpy as np

def phase_entropy_sum(shapes, n_patts):
    """Sum the phase-entropy terms over the first ``n_patts`` lattice positions.

    Positions ``k`` are enumerated with ``itertools.product`` over
    ``range(prod_i m_i[d])`` for each dimension ``d``. Each position contributes
    ``sum_d sum_i log2(m_i[d] / gcd(k_d, m_i[d]))``.

    Returns:
        ``(H, n_counted)``. ``n_counted`` is the number of positions summed,
        which is below ``n_patts`` only when the lattice is smaller than that.
    """
    dim_sizes = np.prod(shapes, axis=0)
    positions = itertools.product(*(range(int(size)) for size in dim_sizes))
    H = 0.0
    n_counted = 0
    # Count patterns (positions), not individual dim x module terms. The old
    # `continue` was a no-op, and the counter stopped after roughly
    # Npatts_total / (dims * modules) positions.
    for k_tuple in itertools.islice(positions, n_patts):
        for d, k_d in enumerate(k_tuple):
            for m_i in shapes:
                H += math.log2(m_i[d] / math.gcd(k_d, m_i[d]))
        n_counted += 1
    return H, n_counted


# The entropy depends only on the shape config (not on D or Ns): compute it once.
entropy_per_config = [phase_entropy_sum(shapes, Npatts_total) for shapes in shape_configs]

print(f"|X| = {Npatts_total}, Ns = {Ns}")
# Use a dedicated loop variable: rebinding the global Ns here would silently
# change the sensory size used by run_test when earlier cells are re-run.
for Ns_candidate in [100, 1000]:
    for D in D_list:
        print(f"--------------[D = {D}, Ns = {Ns_candidate}]---------------")
        for H, n_counted in entropy_per_config:
            print(f"   H(X)    = {D*H:.3f}")
            print(f"   E[H(X)] = {D * H / max(n_counted, 1):.3f}")
            print(f"   E[#]    = {math.floor(D * H / Ns_candidate)}      (imgs perfectly stored (MI/input bit=1))")

In [None]:
# Save the MI-vs-N_patts figure; the file name refers to the
# (3, 3, 3) / (5, 5, 5) shape configuration.
fig_path = 'periodic-vs-shapes-333-555.png'
fig.savefig(fig_path)