In [1]:
import sys

sys.path.append("../..")
import torch
from hippocampal_sensory_layers import (
    ComplexIterativeBidirectionalPseudoInverseHippocampalSensoryLayerComplexScalars,
    ComplexExactPseudoInverseHippocampalSensoryLayerComplexScalars,
    HippocampalSensoryLayer,
)
from fourier_scaffold import FourierScaffold
from preprocessing_cnn import (
    Preprocessor,
    RescalePreprocessing,
    SequentialPreprocessing,
    GrayscaleAndFlattenPreprocessing,
)
from experiments.fourier_miniworld_gridsearch.room_env import RoomExperiment
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from graph_utils import plot_imgs_side_by_side
from agent import TrueData
from tqdm import tqdm

# Action sequence for the looped trajectory. Action ids follow the environment's
# action space (assumed: 2 = move forward, 1 = turn right — TODO confirm).
forward_20 = [2] * 20
right_60_deg = [1] * 20
# Six (forward, turn) legs followed by one final forward leg.
loop_path = (forward_20 + right_60_deg) * 6 + forward_20
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def make_layer_iterative_no_hidden(sbook, D, device, gbook):
    """Build an iterative complex pseudo-inverse layer with ``hidden_layer_factor=0``.

    Args:
        sbook: sensory book; only ``sbook.shape[1]`` is read (used as ``input_size``).
        D: hippocampal size, passed as ``N_h``.
        device: torch device the layer is created on.
        gbook: unused; accepted so all ``make_layer_*`` factories share one signature.

    Returns:
        A freshly constructed, untrained layer.
    """
    settings = dict(
        input_size=sbook.shape[1],
        N_h=D,
        hidden_layer_factor=0,
        epsilon_sh=0.1,
        epsilon_hs=0.1,
        device=device,
    )
    return ComplexIterativeBidirectionalPseudoInverseHippocampalSensoryLayerComplexScalars(
        **settings
    )


def make_layer_iterative(sbook, D, device, gbook):
    """Build an iterative complex pseudo-inverse layer with ``hidden_layer_factor=1``.

    ``sbook.shape[1]`` becomes ``input_size`` and ``D`` becomes ``N_h``;
    ``gbook`` is ignored (kept for a uniform factory signature).
    """
    layer_cls = ComplexIterativeBidirectionalPseudoInverseHippocampalSensoryLayerComplexScalars
    sensory_dim = sbook.shape[1]
    return layer_cls(
        input_size=sensory_dim,
        N_h=D,
        hidden_layer_factor=1,
        epsilon_sh=0.1,
        epsilon_hs=0.1,
        device=device,
    )


def make_layer_analytic(sbook, D, device, gbook):
    """Build the exact pseudo-inverse layer, seeded with the full ``gbook``.

    Unlike the iterative factories, this one consumes ``gbook`` up front: it
    is passed as ``hbook`` and its length as ``N_patts``. ``sbook.shape[1]``
    becomes ``input_size`` and ``D`` becomes ``N_h``.
    """
    n_patterns = len(gbook)
    sensory_dim = sbook.shape[1]
    return ComplexExactPseudoInverseHippocampalSensoryLayerComplexScalars(
        input_size=sensory_dim,
        N_patts=n_patterns,
        hbook=gbook,
        N_h=D,
        device=device,
    )


# ---- Experiment grid -------------------------------------------------------
# Module shapes passed to FourierScaffold; one entry per scaffold configuration.
shapes_list = [
    [(3, 3, 3), (5, 5, 5)],
    # [(5, 5, 5), (7, 7, 7)],
    # [(7, 7, 7), (9, 9, 9)],
]

# Grid dimensionality D, and the 2-D shape used to display a D-vector as an image.
D_list = [2000]
D_reshape_size_map = {
    2000: (40, 50),
    800: (20, 40),
    300: (15, 20),
}

# (display name, factory); every factory has the signature (sbook, D, device, gbook).
layers_list = [
    ("iterative no hidden", make_layer_iterative_no_hidden),
    # ("iterative hidden", make_layer_iterative),
    # Previously mapped to make_layer_iterative_no_hidden, so the "analytic"
    # figures were silently a second run of the iterative layer.
    ("analytic", make_layer_analytic),
]

preprocessing_list = [
    SequentialPreprocessing(
        transforms=[
            RescalePreprocessing(0.25),
            GrayscaleAndFlattenPreprocessing(device),
        ]
    )
]

# g_rescaling_list = [True, False]
g_rescaling_list = [False]
sbook_rand_list = [True, False]

# Seed torch so the random ±1 sbooks (sbook_rand=True) are reproducible across runs.
SEED = 0
torch.manual_seed(SEED)


def make_env():
    """Create a fresh ``RoomExperiment`` with this notebook's fixed constructor arguments.

    The meaning of ``[3, 0, 3]`` and ``0`` is defined by ``RoomExperiment``
    (not visible here — presumably a position-like triple and a seed/index; confirm).
    """
    room_args = ([3, 0, 3], 0)
    return RoomExperiment(*room_args)


def make_gbook(D, shapes):
    """Return ``gbook()`` of a ``FourierScaffold`` built from ``D`` and ``shapes``.

    The scaffold is created with its ``_skip_K_calc`` / ``_skip_gs_calc`` flags
    set, and only its ``gbook()`` result is kept.
    """
    module_shapes = torch.tensor(shapes)
    return FourierScaffold(
        D=D,
        shapes=module_shapes,
        _skip_K_calc=True,
        _skip_gs_calc=True,
    ).gbook()


def make_data(
    path: list[int],
    scaffold: FourierScaffold,
    preprocessing: Preprocessor,
    noise_dist=None,
):
    """Roll out ``path`` in a fresh environment and record aligned grid/sensory codes.

    Each step's pose change ``dp`` (optionally perturbed by ``noise_dist``) is
    turned into a scaffold velocity and accumulated. Once the accumulated
    velocity reaches 1 in the infinity norm, the scaffold is shifted by the
    *accumulated* velocity, smoothed and sharpened, and a new
    (true position, ``scaffold.P @ scaffold.g_s``, processed image) triple is
    recorded.

    Args:
        path: sequence of environment actions to execute.
        scaffold: grid scaffold, path-integrated in place.
        preprocessing: converts raw observations to sensory vectors via ``encode``.
        noise_dist: optional object whose ``sample(3)`` returns noise added to
            each pose delta before path integration; ``None`` means noiseless.

    Returns:
        ``(true_positions, gbook, sbook)``, each built with ``torch.vstack``;
        row 0 is the state right after ``env.reset()``.
    """
    env = make_env()

    def get_true_pos():
        # Pose as (x, z, heading); the y component of agent.pos is dropped.
        agent = env.get_wrapper_attr("agent")
        p_x, _p_y, p_z = agent.pos
        return torch.tensor([p_x, p_z, agent.dir]).float().to(device)

    def _env_reset():
        obs, _info = env.reset()
        return preprocessing.encode(obs), get_true_pos()

    def _obs_postpreprocess(step_tuple, action):
        obs, _reward, _terminated, _truncated, _info = step_tuple
        return preprocessing.encode(obs), get_true_pos()

    start_img, start_pos = _env_reset()
    true_data = TrueData(start_pos)

    true_positions = [true_data.true_position.clone()]
    gbook = [scaffold.P @ scaffold.g_s]
    sbook = [start_img]

    dt = 1
    v_cumulative = torch.zeros(3, device=device)

    for action in path:
        # Processed (encoded) observation and true pose after this action.
        new_img, new_pos = _obs_postpreprocess(env.step(action), action)

        dp = new_pos - true_data.true_position
        true_data.true_position = new_pos
        # Perturb a copy of dp. The old `noisy_dp = new_pos; noisy_dp += ...`
        # aliased and corrupted the stored true position, and the noise never
        # reached the velocity fed to the scaffold.
        noisy_dp = dp if noise_dist is None else dp + noise_dist.sample(3)

        v = (noisy_dp / dt) * scaffold.scale_factor
        v_cumulative += v
        # Sub-threshold moves are skipped but keep accumulating.
        if v_cumulative.norm(p=float("inf")) < 1:
            continue

        # Shift by everything accumulated since the last recorded step; shifting
        # by `v` alone dropped the motion of every skipped step.
        scaffold.velocity_shift(v_cumulative)
        scaffold.smooth()
        scaffold.sharpen()
        true_positions.append(true_data.true_position.clone())
        gbook.append(scaffold.P @ scaffold.g_s)
        sbook.append(new_img)
        v_cumulative = torch.zeros(3, device=device)

    return torch.vstack(true_positions), torch.vstack(gbook), torch.vstack(sbook)

In [None]:
def test_layer(
    layer: HippocampalSensoryLayer,
    hbook: torch.Tensor,
    sbook: torch.Tensor,
    large: bool,
):
    err_l1_first_img_s_h_s = -torch.ones(len(sbook))
    err_l1_last_img_s_h_s = -torch.ones(len(sbook))
    avg_accumulated_err_l2 = -torch.ones(len(sbook))
    first_img = sbook[0]

    gbook_recovered = torch.zeros_like(gbook_)
    recovered_first_imgs = torch.zeros_like(sbook)
    recovered_last_imgs = torch.zeros_like(sbook)

    for i in tqdm(range(len(sbook))):
        h_ = hbook[i]
        if large:
            h = torch.einsum("i,j->ij", h_, h_.conj()).flatten()
        else:
            h = h_

        s = sbook[i]
        layer.learn(h, s)

        gbook_recovered[i] = layer.hippocampal_from_sensory(s)
        recovered_first_imgs[i] = layer.sensory_from_hippocampal(
            layer.hippocampal_from_sensory(first_img)
        )[0]
        recovered_last_imgs[i] = layer.sensory_from_hippocampal(
            layer.hippocampal_from_sensory(s)
        )[0]

        err_l1_first_img_s_h_s[i] = torch.mean(
            torch.abs(
                layer.sensory_from_hippocampal(
                    layer.hippocampal_from_sensory(first_img)
                )[0]
                - first_img
            )
        )

        err_l1_last_img_s_h_s[i] = torch.mean(
            torch.abs(
                layer.sensory_from_hippocampal(
                    layer.hippocampal_from_sensory(sbook[i])
                )[0]
                - sbook[i]
            )
        )

        avg_accumulated_err_l2[i] = torch.mean(
            (
                layer.sensory_from_hippocampal(
                    layer.hippocampal_from_sensory(sbook[: i + 1])
                )
                - sbook[: i + 1]
            )
            ** 2
        )
        if (
            # err_l1_first_img_s_h_s[i] > 10e5
            # or avg_accumulated_err_l2[i] > 10e5
            torch.any(torch.isnan(err_l1_first_img_s_h_s[i]))
            or torch.any(torch.isnan(avg_accumulated_err_l2[i]))
        ):
            break

    return (
        err_l1_first_img_s_h_s,
        err_l1_last_img_s_h_s,
        avg_accumulated_err_l2,
        recovered_first_imgs,
        recovered_last_imgs,
        gbook_recovered,
    )


def _plot_mean_std_on_ax(ax: Axes, runs: torch.Tensor, label):
    """Plot the dim-0 mean of ``runs`` with a ±1 std band on ``ax``; return ``ax``.

    ``runs`` is reduced over dim 0 (assumed (n_runs, n_steps) — confirm); the
    x axis is ``0 .. len(runs[0]) - 1``. With a single run the unbiased std is
    NaN, so the band is NaN.
    """
    x = torch.arange(0, len(runs[0]))
    mean = runs.mean(dim=0)
    std = runs.std(dim=0)
    ax.plot(x, mean, label=label)
    ax.fill_between(x, mean - std, mean + std, alpha=0.2)
    return ax


def plot_avg_acc_l2_err_on_ax(ax: Axes, avg_accumulated_err_l2: torch.Tensor, label):
    """Plot mean ± std of the average accumulated L2 error; returns ``ax``."""
    return _plot_mean_std_on_ax(ax, avg_accumulated_err_l2, label)


def plot_first_img_l1_err_on_ax(ax: Axes, err_l1_first_img_s_h_s: torch.Tensor, label):
    """Plot mean ± std of the first-image L1 reconstruction error; returns ``ax``."""
    return _plot_mean_std_on_ax(ax, err_l1_first_img_s_h_s, label)


def plot_last_img_l1_err_on_ax(ax: Axes, err_l1_last_img_s_h_s: torch.Tensor, label):
    """Plot mean ± std of the current-image L1 reconstruction error; returns ``ax``."""
    return _plot_mean_std_on_ax(ax, err_l1_last_img_s_h_s, label)


def set_ax_titles(ax: Axes, title, xtitle, ytitle):
    """Set the title and x/y axis labels of ``ax``, then draw its legend."""
    for setter, text in (
        (ax.set_title, title),
        (ax.set_xlabel, xtitle),
        (ax.set_ylabel, ytitle),
    ):
        setter(text)
    ax.legend()


def add_vertical_bar_on_ax(ax: Axes, x):
    """Draw a dashed blue vertical reference line at ``x`` on ``ax``."""
    dashed_blue = {"color": "b", "linestyle": "--"}
    ax.axvline(x=x, **dashed_blue)


def add_horizontal_bar_on_ax(ax: Axes, y, label):
    """Draw a dashed black horizontal reference line at ``y``, labelled ``label``."""
    line_kwargs = dict(y=y, color="k", linestyle="--", label=label)
    ax.axhline(**line_kwargs)


# Grid search over (D, module shapes, preprocessing, gbook rescaling, random
# sbook, layer type). The trajectory is generated once per (D, shapes,
# preprocessing); each layer is trained/evaluated on it via test_layer, and a
# 5-column figure (recovered sbook[0], sbook, recovered current sbook,
# Re(gbook), Re(recovered gbook)) with one row per recorded step is saved.
SENSORY_IMG_SHAPE = (15, 20)  # assumed H x W of one preprocessed sbook row — TODO confirm

for D in D_list:
    for shapes in shapes_list:
        for preprocessing in preprocessing_list:
            scaffold = FourierScaffold(shapes=torch.tensor(shapes), D=D, device=device)
            true_positions, gbook, sbook = make_data(loop_path, scaffold, preprocessing)
            for g_rescaling in g_rescaling_list:
                gbook_ = gbook * 1e6 if g_rescaling else gbook
                for sbook_rand in sbook_rand_list:
                    # Random ±1 sensory patterns, as a control for the real images.
                    sbook_ = torch.sign(torch.randn_like(sbook)) if sbook_rand else sbook

                    for layer_name, layer_constr in layers_list:
                        layer = layer_constr(sbook_, D, device, gbook_)
                        (
                            err_l1_first_img_s_h_s,
                            err_l1_last_img_s_h_s,
                            avg_accumulated_err_l2,
                            recovered_first_imgs,
                            recovered_last_imgs,
                            gbook_recovered,
                        ) = test_layer(layer, gbook_, sbook_, False)

                        # One (images, per-row title template) pair per figure column.
                        g_shape = D_reshape_size_map[D]
                        columns = [
                            (
                                [r.cpu().reshape(SENSORY_IMG_SHAPE) for r in recovered_first_imgs],
                                "recovered sbook[0]",
                            ),
                            (
                                [s.cpu().reshape(SENSORY_IMG_SHAPE) for s in sbook_],
                                "sbook[{}]",
                            ),
                            (
                                [r.cpu().reshape(SENSORY_IMG_SHAPE) for r in recovered_last_imgs],
                                "recovered sbook[{}]",
                            ),
                            (
                                [g.cpu().reshape(g_shape).real for g in gbook_],
                                "Re(gbook[{}])",
                            ),
                            (
                                [g.cpu().reshape(g_shape).real for g in gbook_recovered],
                                "Re(gbook_recovered[{}])",
                            ),
                        ]

                        # squeeze=False keeps `axs` 2-D even when only one row is recorded.
                        fig, axs = plt.subplots(
                            nrows=len(recovered_first_imgs),
                            ncols=len(columns),
                            figsize=(20, 48),
                            squeeze=False,
                        )
                        for col, (imgs, title_template) in enumerate(columns):
                            plot_imgs_side_by_side(
                                imgs=imgs,
                                titles=[
                                    title_template.format(row) for row in range(len(sbook_))
                                ],
                                axs=axs[:, col],
                                fig=fig,
                                use_first_img_scale=True,
                            )
                        fig.suptitle(
                            f"layer_type={layer_name}, D={D}, downscaling=0.25, g_rescaling={g_rescaling}, rand_sbook={sbook_rand}, shapes={shapes}"
                        )
                        # NOTE(review): the filename omits `shapes` and the preprocessing,
                        # so multiple shapes_list entries overwrite each other's figures.
                        fig.savefig(
                            f"D={D}-layer={layer_name}-rescaling={g_rescaling}-rand_sbook={sbook_rand}.png"
                        )
                        # Display now (inline backend), then free the figure so the
                        # grid does not accumulate large open figures in memory.
                        plt.show()
                        plt.close(fig)

module shapes:  tensor([[3, 3, 3],
        [5, 5, 5]])
N_g (D) :  2000
M       :  2
d       :  3
N_patts :  3375


KeyboardInterrupt: 