In [4]:
import sys

sys.path.append("../..")
import torch
from hippocampal_sensory_layers import (
    ComplexIterativeBidirectionalPseudoInverseHippocampalSensoryLayerComplexScalars,
    ComplexIterativeBidirectionalPseudoInverseHippocampalSensoryLayer,
    ComplexExactPseudoInverseHippocampalSensoryLayerComplexScalars,
    ComplexExactPseudoInverseHippocampalSensoryLayer,
    RegularizedComplexExactPseudoInverseHippocampalSensoryLayerComplexScalars,
    HippocampalSensoryLayer,
)
from fourier_scaffold import FourierScaffold, HadamardShiftMatrixRat
from preprocessing_cnn import (
    Preprocessor,
    RescalePreprocessing,
    SequentialPreprocessing,
    GrayscaleAndFlattenPreprocessing,
)
from experiments.fourier_miniworld_gridsearch.room_env import RoomExperiment
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from graph_utils import plot_imgs_side_by_side
from agent import TrueData
from tqdm import tqdm

# Trajectory used by every experiment: six repetitions of
# (20 x action 2, 20 x action 1) followed by a final 20 x action 2.
# Action meanings ("forward" / "right 60 deg") come from the variable names
# only — NOTE(review): confirm against the environment's action table.
forward_20 = [2 for _ in range(20)]
right_60_deg = [1 for _ in range(20)]
loop_path = [*forward_20, *right_60_deg] * 6 + forward_20
device = "cuda"

In [5]:
def _make_iterative_layer(layer_cls, sbook, D, device, hidden_layer_factor):
    """Instantiate an iterative bidirectional pseudo-inverse layer of class ``layer_cls``.

    ``input_size`` is taken from ``sbook.shape[1]``; both learning rates
    (``epsilon_sh`` and ``epsilon_hs``) are fixed at 0.1.
    """
    return layer_cls(
        input_size=sbook.shape[1],
        N_h=D,
        hidden_layer_factor=hidden_layer_factor,
        epsilon_sh=0.1,
        epsilon_hs=0.1,
        device=device,
    )


def make_layer_iterative_no_hidden(sbook, D, device, gbook):
    """Complex-scalar iterative layer with ``hidden_layer_factor=0``.

    ``gbook`` is accepted but unused so that all factories share the
    ``(sbook, D, device, gbook)`` call signature used by the sweep loop.
    """
    return _make_iterative_layer(
        ComplexIterativeBidirectionalPseudoInverseHippocampalSensoryLayerComplexScalars,
        sbook,
        D,
        device,
        hidden_layer_factor=0,
    )


def make_layer_iterative(sbook, D, device, gbook):
    """Complex-scalar iterative layer with ``hidden_layer_factor=1`` (``gbook`` unused)."""
    return _make_iterative_layer(
        ComplexIterativeBidirectionalPseudoInverseHippocampalSensoryLayerComplexScalars,
        sbook,
        D,
        device,
        hidden_layer_factor=1,
    )

def make_layer_iterative_noncomplex(sbook, D, device, gbook):
    """Iterative layer (non-complex-scalar variant) with ``hidden_layer_factor=1`` (``gbook`` unused)."""
    return _make_iterative_layer(
        ComplexIterativeBidirectionalPseudoInverseHippocampalSensoryLayer,
        sbook,
        D,
        device,
        hidden_layer_factor=1,
    )

def make_layer_iterative_no_hidden_noncomplex(sbook, D, device, gbook):
    """Iterative layer (non-complex-scalar variant) with ``hidden_layer_factor=0`` (``gbook`` unused)."""
    return _make_iterative_layer(
        ComplexIterativeBidirectionalPseudoInverseHippocampalSensoryLayer,
        sbook,
        D,
        device,
        hidden_layer_factor=0,
    )




def _make_exact_layer(layer_cls, sbook, D, device, gbook):
    """Instantiate an exact pseudo-inverse layer of class ``layer_cls``.

    ``N_patts`` is ``len(gbook)`` and ``gbook`` itself is passed as ``hbook``;
    ``input_size`` is ``sbook.shape[1]``.
    """
    return layer_cls(
        input_size=sbook.shape[1],
        N_patts=len(gbook),
        hbook=gbook,
        N_h=D,
        device=device,
    )


def make_layer_analytic(sbook, D, device, gbook):
    """Complex-scalar exact pseudo-inverse layer built from ``gbook``."""
    return _make_exact_layer(
        ComplexExactPseudoInverseHippocampalSensoryLayerComplexScalars,
        sbook,
        D,
        device,
        gbook,
    )


def make_layer_analytic_noncomplex(sbook, D, device, gbook):
    """Exact pseudo-inverse layer (non-complex-scalar variant) built from ``gbook``."""
    return _make_exact_layer(
        ComplexExactPseudoInverseHippocampalSensoryLayer,
        sbook,
        D,
        device,
        gbook,
    )


def make_layer_complex_regularized(sbook, D, device, gbook):
    """Regularized complex-scalar exact pseudo-inverse layer built from ``gbook``."""
    return _make_exact_layer(
        RegularizedComplexExactPseudoInverseHippocampalSensoryLayerComplexScalars,
        sbook,
        D,
        device,
        gbook,
    )


# Scaffold module shapes to sweep; each entry is passed as `shapes` to FourierScaffold.
shapes_list = [
    [(3, 3, 3), (5, 5, 5)],
    # [(3, 3, 3), (7, 7, 7)],
    # [(5, 5, 5), (7, 7, 7)],
    # [(7, 7, 7), (9, 9, 9)],
]

# Scaffold dimensions D to sweep.
D_list = [300, 600, 800]
# (rows, cols) used to display a length-D vector as an image; rows * cols must equal D.
D_reshape_size_map = {
    4000: (50, 80),
    2000: (40, 50),
    800: (20, 40),
    600: (20, 30),
    300: (15, 20),
}

# Fail fast: the sweep loop only looks up D_reshape_size_map[D] when plotting,
# i.e. after data generation and layer training for that D have already run.
_D_without_valid_shape = [
    d
    for d in D_list
    if d not in D_reshape_size_map
    or D_reshape_size_map[d][0] * D_reshape_size_map[d][1] != d
]
assert not _D_without_valid_shape, (
    f"missing or inconsistent D_reshape_size_map entry for D in {_D_without_valid_shape}"
)

# (label, factory) pairs swept by the experiment loop. The label is used in
# figure titles and output filenames, so it must be unique and free of
# trailing whitespace.
# Previously both "hidden" entries pointed at the *no_hidden* factories, so the
# hidden-layer variants were never actually run.
layers_list = [
    ("iterative no hidden", make_layer_iterative_no_hidden),
    ("iterative hidden", make_layer_iterative),
    ("iterative no hidden noncomplex", make_layer_iterative_no_hidden_noncomplex),
    ("iterative hidden noncomplex", make_layer_iterative_noncomplex),
    ("analytic", make_layer_analytic),
    ("analytic_noncomplex", make_layer_analytic_noncomplex),
    # ("analytic_regularized", make_layer_complex_regularized),
]

# Single preprocessing pipeline: rescale by 0.25, then grayscale + flatten.
_downscale_then_flatten = [
    RescalePreprocessing(0.25),
    GrayscaleAndFlattenPreprocessing(device),
]
preprocessing_list = [SequentialPreprocessing(transforms=_downscale_then_flatten)]

# Sweep flags; the experiment loop iterates over every value in each list.
g_rescaling_list = [False]  # True -> train on gbook * 1e6
sbook_rand_list = [False]  # True -> replace sbook by torch.sign(torch.randn_like(sbook))
seeds = [44]


def make_env():
    """Create a fresh ``RoomExperiment`` environment.

    Always constructed with the positional arguments ``([3, 0, 3], 0)``; their
    meaning is defined by ``RoomExperiment`` (not visible here).
    NOTE(review): ``[3, 0, 3]`` equals the grid center hard-coded for
    ``get_xy_distributions`` — confirm that coupling is intentional.
    """
    room_args = ([3, 0, 3], 0)
    return RoomExperiment(*room_args)


def make_data(
    path: list[int],
    scaffold: FourierScaffold,
    preprocessing: Preprocessor,
    noise_dist=None,
):
    """Roll out ``path`` in a fresh environment while path-integrating ``scaffold``.

    After each ``env.step`` the displacement between consecutive true poses
    (optionally perturbed by ``noise_dist``) is scaled by
    ``scaffold.scale_factor`` and accumulated. Once any component of the
    accumulated velocity reaches 1 in absolute value, the scaffold is shifted
    by it and smoothed (both in place), a sample is recorded, and the
    accumulator is reset.

    Args:
        path: actions passed one by one to ``env.step``.
        scaffold: scaffold mutated in place via ``velocity_shift`` / ``smooth``.
        preprocessing: encoder applied to every raw observation.
        noise_dist: optional object whose ``sample(3)`` result is added to the
            displacement fed to the scaffold. Recorded true positions are never
            perturbed. ``None`` disables noise.

    Returns:
        ``(true_positions, gbook, sbook, Pbook)``: the first three are stacked
        with ``torch.vstack`` and ``Pbook`` with ``torch.concat(dim=0)``. Row 0
        is the post-reset state; one row is appended per scaffold shift.
    """
    env = make_env()

    def get_true_pos():
        # Pose is (x, z, heading); the agent's middle coordinate p_y is dropped.
        agent = env.get_wrapper_attr("agent")
        p_x, p_y, p_z = agent.pos
        return torch.tensor([p_x, p_z, agent.dir]).float().to(device)

    def encode_current(obs):
        # Encode the observation first, then read the agent pose.
        return preprocessing.encode(obs), get_true_pos()

    obs, _info = env.reset()
    start_img, start_pos = encode_current(obs)
    v_cumulative = torch.zeros(3, device=device)

    true_data = TrueData(start_pos)

    true_positions = [true_data.true_position.clone()]
    Pbook = [scaffold.P.clone().unsqueeze(0)]
    gbook = [(scaffold.P @ scaffold.g_s).clone()]
    sbook = [start_img]

    print("gbook 0 shape:", gbook[0].shape)
    print("sbook 0 shape:", sbook[0].shape)
    for action in path:
        obs, _reward, _terminated, _truncated, _info = env.step(action)
        new_img, new_pos = encode_current(obs)

        dp = new_pos - true_data.true_position
        true_data.true_position = new_pos
        # Noise perturbs only the odometry fed to the scaffold. Previously
        # `noisy_dp` aliased `new_pos` and was incremented in place, which
        # corrupted the recorded true position while the scaffold still got
        # the noise-free `dp`.
        noisy_dp = dp if noise_dist is None else dp + noise_dist.sample(3)

        dt = 1
        v_cumulative += (noisy_dp / dt) * scaffold.scale_factor

        # Keep accumulating until some axis has moved at least one unit.
        if v_cumulative.norm(p=float("inf")) < 1:
            continue

        scaffold.velocity_shift(v_cumulative)
        print("entropy(P): ", scaffold.entropy(scaffold.P))
        scaffold.smooth()
        g_avg = scaffold.P @ scaffold.g_s
        print("||g_avg||₂² :", g_avg.norm() ** 2)
        true_positions.append(true_data.true_position.clone())
        Pbook.append(scaffold.P.clone().unsqueeze(0))
        gbook.append(g_avg.clone())
        sbook.append(new_img)
        v_cumulative = torch.zeros(3, device=device)

    return (
        torch.vstack(true_positions),
        torch.vstack(gbook),
        torch.vstack(sbook),
        torch.concat(Pbook, dim=0),
    )


def get_xy_distributions(
    scaffold: "FourierScaffold",
    Pbook: torch.Tensor,
    center=(3, 0, 3),
    radius=(7, 7, 7),
    device=None,
):
    """Evaluate the scaffold's probability on an integer grid and sum out axis 3.

    Builds the grid ``center[d] - radius[d] .. center[d] + radius[d]`` (inclusive)
    on each of the three axes (named x, y, th in the original code) via
    ``torch.cartesian_prod``. For every ``Pbook[i]`` it calls
    ``scaffold.get_probability_abs_batched(omega, P=Pbook[i])`` and sums the
    result over the third axis.

    Args:
        scaffold: object with ``get_probability_abs_batched(omega, P=...)``
            returning one value per row of ``omega``.
        Pbook: stack of P matrices, one per recorded step.
        center: grid center per axis; default is the previously hard-coded
            ``(3, 0, 3)``.
        radius: grid half-width per axis; default ``(7, 7, 7)``.
        device: device for the query grid; defaults to ``Pbook.device``
            instead of relying on the notebook-global ``device``.

    Returns:
        CPU tensor of shape ``(len(Pbook), 2*radius[0]+1, 2*radius[1]+1)``.
    """
    if device is None:
        device = Pbook.device
    l_x, l_y, l_th = (2 * r + 1 for r in radius)
    omega = torch.cartesian_prod(
        *(
            torch.arange(c - r, c + r + 1, 1, device=device)
            for c, r in zip(center, radius)
        )
    )
    xy_distributions = torch.empty(len(Pbook), l_x, l_y)
    for i in range(len(Pbook)):
        dist = scaffold.get_probability_abs_batched(omega, P=Pbook[i])
        # cartesian_prod varies the last axis fastest, matching this reshape.
        xy_distributions[i] = dist.reshape(l_x, l_y, l_th).sum(-1)

    return xy_distributions

In [6]:
def test_layer(
    layer: HippocampalSensoryLayer,
    hbook: torch.Tensor,
    sbook: torch.Tensor,
    large: bool,
):
    """Train ``layer`` online on (h_i, sbook[i]) pairs and track recall quality.

    h_i is ``hbook[i]``. If ``large`` is set, h_i is instead the flattened
    outer product ``hbook[i] (x) conj(hbook[i])``. After ``layer.learn(h_i,
    sbook[i])`` at step i the function records:

    * ``gbook_recovered[i]``: ``hippocampal_from_sensory(sbook[i])``
    * ``recovered_first_imgs_shs[i]``: s->h->s reconstruction of ``sbook[0]``
    * ``recovered_last_imgs_shs[i]``: s->h->s reconstruction of ``sbook[i]``
    * ``recovered_last_imgs_hs[i]``: h->s reconstruction from h_i
    * the mean absolute s->h->s error for ``sbook[0]`` and ``sbook[i]``
    * the mean squared s->h->s error over ``sbook[: i + 1]``

    Single-image reconstructions take row ``[0]`` of ``sensory_from_hippocampal``.
    NOTE(review): this assumes the layer returns a batch — confirm.

    The loop stops early once the first-image or accumulated error is NaN.
    Error entries for steps never reached stay at -1, and reconstructions
    stay at 0.

    NOTE(review): ``gbook_recovered`` is shaped like ``hbook``. With
    ``large=True`` the recovered code is presumably D**2 long and would not
    fit — confirm before using that mode.

    Returns:
        ``(err_l1_first_img_s_h_s, err_l1_last_img_s_h_s, avg_accumulated_err_l2,
        recovered_first_imgs_shs, recovered_last_imgs_shs, gbook_recovered,
        recovered_last_imgs_hs)``.
    """
    n_steps = len(sbook)
    err_l1_first_img_s_h_s = -torch.ones(n_steps)
    err_l1_last_img_s_h_s = -torch.ones(n_steps)
    avg_accumulated_err_l2 = -torch.ones(n_steps)
    first_img = sbook[0]

    # Previously torch.zeros_like(gbook_): a notebook global, not the argument.
    gbook_recovered = torch.zeros_like(hbook)
    recovered_first_imgs_shs = torch.zeros_like(sbook)
    recovered_last_imgs_shs = torch.zeros_like(sbook)
    recovered_last_imgs_hs = torch.zeros_like(sbook)

    for i in tqdm(range(n_steps)):
        h_ = hbook[i]
        h = torch.einsum("i,j->ij", h_, h_.conj()).flatten() if large else h_

        s = sbook[i]
        layer.learn(h, s)

        # Each forward pass is computed once and reused. The errors are taken
        # from these uncast results, not from the (possibly dtype-cast) copies
        # stored in the preallocated buffers.
        h_from_s = layer.hippocampal_from_sensory(s)
        rec_first = layer.sensory_from_hippocampal(
            layer.hippocampal_from_sensory(first_img)
        )[0]
        rec_last = layer.sensory_from_hippocampal(h_from_s)[0]

        gbook_recovered[i] = h_from_s
        recovered_first_imgs_shs[i] = rec_first
        recovered_last_imgs_shs[i] = rec_last
        recovered_last_imgs_hs[i] = layer.sensory_from_hippocampal(h)[0]

        err_l1_first_img_s_h_s[i] = torch.mean(torch.abs(rec_first - first_img))
        err_l1_last_img_s_h_s[i] = torch.mean(torch.abs(rec_last - s))

        seen = sbook[: i + 1]
        avg_accumulated_err_l2[i] = torch.mean(
            (
                layer.sensory_from_hippocampal(layer.hippocampal_from_sensory(seen))
                - seen
            )
            ** 2
        )
        if torch.any(torch.isnan(err_l1_first_img_s_h_s[i])) or torch.any(
            torch.isnan(avg_accumulated_err_l2[i])
        ):
            break

    return (
        err_l1_first_img_s_h_s,
        err_l1_last_img_s_h_s,
        avg_accumulated_err_l2,
        recovered_first_imgs_shs,
        recovered_last_imgs_shs,
        gbook_recovered,
        recovered_last_imgs_hs,
    )




In [None]:
def _plot_columns(fig, axs, columns, use_first_img_scale):
    """Fill column ``c`` of the 2-D ``axs`` grid with the c-th ``(imgs, titles)`` pair."""
    for col, (imgs, titles) in enumerate(columns):
        plot_imgs_side_by_side(
            imgs=imgs,
            titles=titles,
            axs=axs[:, col],
            fig=fig,
            use_first_img_scale=use_first_img_scale,
        )


# Display shape of one flattened sensory vector (15 * 20 = 300 values).
# Hard-coded; presumably matches the 0.25-rescaled grayscale observations —
# keep in sync with `preprocessing_list`.
IMG_SHAPE = (15, 20)

for seed in seeds:
    # NOTE(review): seed 42 is skipped explicitly; the reason is not recorded here.
    if seed == 42:
        continue
    torch.manual_seed(seed)
    for i, D in enumerate(D_list):
        # Looked up up front so a missing entry fails before any training.
        g_img_shape = D_reshape_size_map[D]
        for j, shapes in enumerate(shapes_list):
            for k, preprocessing in enumerate(preprocessing_list):
                scaffold = FourierScaffold(
                    shapes=torch.tensor(shapes),
                    D=D,
                    shift=HadamardShiftMatrixRat(shapes=torch.tensor(shapes)),
                    device=device,
                )
                true_positions, gbook, sbook, Pbook = make_data(
                    loop_path, scaffold, preprocessing
                )
                xy_dists_true = get_xy_distributions(scaffold=scaffold, Pbook=Pbook)
                for l, g_rescaling in enumerate(g_rescaling_list):
                    gbook_ = gbook * 1e6 if g_rescaling else gbook
                    for m, sbook_rand in enumerate(sbook_rand_list):
                        # Optionally replace the observations with random sign vectors.
                        sbook_ = torch.sign(torch.randn_like(sbook)) if sbook_rand else sbook

                        for n, [layer_name, layer_constr] in enumerate(layers_list):
                            layer = layer_constr(sbook_, D, device, gbook_)
                            (
                                err_l1_first_img_s_h_s,
                                err_l1_last_img_s_h_s,
                                avg_accumulated_err_l2,
                                recovered_first_imgs_shs,
                                recovered_last_imgs_shs,
                                gbook_recovered,
                                recovered_last_imgs_hs,
                            ) = test_layer(layer, gbook_, sbook_, False)

                            n_rows = len(recovered_first_imgs_shs)
                            n_steps = len(sbook_)
                            run_desc = (
                                f"layer_type={layer_name}, D={D}, downscaling=0.25, "
                                f"g_rescaling={g_rescaling}, rand_sbook={sbook_rand}, shapes={shapes}"
                            )
                            run_file = (
                                f"D={D}-layer={layer_name}-rescaling={g_rescaling}"
                                f"-rand_sbook={sbook_rand}-seed={seed}"
                            )

                            # Figure 1: reconstructions and hippocampal codes, one row per step.
                            fig, axs = plt.subplots(
                                nrows=n_rows,
                                ncols=6,
                                figsize=(28, 48),
                            )
                            _plot_columns(
                                fig,
                                axs,
                                [
                                    (
                                        [img.cpu().reshape(IMG_SHAPE) for img in recovered_first_imgs_shs],
                                        ["recovered sbook[0]" for _ in range(n_steps)],
                                    ),
                                    (
                                        [img.cpu().reshape(IMG_SHAPE) for img in sbook_],
                                        [f"sbook[{r}]" for r in range(n_steps)],
                                    ),
                                    (
                                        [img.cpu().reshape(IMG_SHAPE) for img in recovered_last_imgs_shs],
                                        [f"s->h->s sbook[{r}]" for r in range(n_steps)],
                                    ),
                                    # h -> s column: previously plotted recovered_last_imgs_shs.
                                    (
                                        [img.cpu().reshape(IMG_SHAPE) for img in recovered_last_imgs_hs],
                                        [f"h-> s sbook[{r}]" for r in range(n_steps)],
                                    ),
                                    # Norms are indexed by row `r` (previously the stale loop index `k`).
                                    (
                                        [g.cpu().reshape(g_img_shape).real for g in gbook_],
                                        [
                                            f"Re(gbook[{r}]). ||g||²={gbook_[r].norm() **2:.1e}"
                                            for r in range(n_steps)
                                        ],
                                    ),
                                    (
                                        [g.cpu().reshape(g_img_shape).real for g in gbook_recovered],
                                        [
                                            f"Re(gbook_rec[{r}]), ||g||²={gbook_recovered[r].norm()**2:.1e}"
                                            for r in range(n_steps)
                                        ],
                                    ),
                                ],
                                use_first_img_scale=True,
                            )
                            fig.suptitle(run_desc)
                            fig.savefig(f"{run_file}.png")
                            plt.close(fig)

                            # Figure 2: summed-out (x, y) distributions from the true scaffold
                            # states, from gbook, and from the recovered codes.
                            fig, axs = plt.subplots(
                                nrows=n_rows,
                                ncols=3,
                                figsize=(14, 48),
                            )
                            xy_dists_from_gbook = get_xy_distributions(
                                scaffold,
                                Pbook=torch.einsum("bi,bj->bij", gbook_, gbook_.conj()),
                            )
                            xy_dists_from_g_rec = get_xy_distributions(
                                scaffold,
                                Pbook=torch.einsum(
                                    "bi,bj->bij", gbook_recovered, gbook_recovered.conj()
                                ),
                            )
                            n_dists = len(xy_dists_true)
                            _plot_columns(
                                fig,
                                axs,
                                [
                                    (
                                        [xy_dists_true[r].cpu() for r in range(n_dists)],
                                        [f"true_xy_dist[{r}]" for r in range(n_dists)],
                                    ),
                                    (
                                        [xy_dists_from_gbook[r].cpu() for r in range(n_dists)],
                                        [f"gbook_xy_dist[{r}]" for r in range(n_dists)],
                                    ),
                                    (
                                        [xy_dists_from_g_rec[r].cpu() for r in range(n_dists)],
                                        [f"g_rec_xy_dist[{r}]" for r in range(n_dists)],
                                    ),
                                ],
                                use_first_img_scale=False,
                            )
                            fig.suptitle(f"dists-{run_desc}")
                            fig.savefig(f"dists-{run_file}.png")
                            plt.close(fig)

  self.shapes = torch.tensor(shapes).int()
