# MiniWorld VectorHash + CNN Performance Tests

This notebook runs a grid of VectorHaSH localization tests in the
MiniWorld-Maze environment, using a pretrained ResNet-18 encoder as the
observation preprocessor so that the results can be compared against the
raw-flattened-pixel baseline.

In [1]:
# ── Imports & reproducibility ──
import os, itertools, pickle
import torch, numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt

from miniworld.params import DEFAULT_PARAMS
from preprocessing_cnn import PreprocessingCNN
from vectorhash import build_vectorhash_architecture
from agent import VectorhashAgent
from agent_history import VectorhashAgentKidnappedHistory
from smoothing import IdentitySmoothing, PolynomialSmoothing, SoftmaxSmoothing

# Seed both RNGs from one constant so the torch and numpy seeds cannot drift apart.
SEED = 0
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(SEED)


2025-06-09 17:44:13.544 Python[52628:3106088] ApplePersistenceIgnoreState: Existing state will not be touched. New state will be written to /var/folders/zz/_8nvyjvj4jd9v1rfts46s0ph0000gn/T/com.apple.python3.savedState


In [3]:
# ── Helpers: build_env & build_model ──

def build_env():
    """
    Create a fresh MiniWorld-Maze-v0 environment with randomisation disabled.

    A copy of DEFAULT_PARAMS is passed through `no_random()` and handed to
    the environment, and `domain_rand` is switched off, so the shared
    DEFAULT_PARAMS object itself is not the one given to the env.

    Returns:
        The environment object produced by `gym.make`.
    """
    fixed_params = DEFAULT_PARAMS.copy().no_random()
    make_kwargs = dict(
        # presumably -1 disables the episode step limit — confirm against
        # the installed gymnasium version
        max_episode_steps=-1,
        params=fixed_params,
        domain_rand=False,
    )
    return gym.make("MiniWorld-Maze-v0", **make_kwargs)

def build_model(
    smoothing,
    shift=None,
    shapes=[(5,5,5),(8,8,8)],
    N_h=600,
    latent_dim=128,
    device=torch.device("cpu"),
):
    """
    Build the CNN preprocessor and the VectorHaSH model that consumes its output.

    Args:
        smoothing: smoothing object forwarded to `build_vectorhash_architecture`.
        shift: shift object forwarded unchanged to
            `build_vectorhash_architecture`.
            NOTE(review): the recorded run of this notebook relies on the
            `None` default and ended with "AttributeError: 'NoneType' object
            has no attribute 'device'" — presumably a concrete shift instance
            is required; confirm against `build_vectorhash_architecture`.
        shapes: module shapes forwarded to `build_vectorhash_architecture`.
            The default list is shared between calls; it is only read here,
            never mutated.
        N_h: forwarded as `N_h` to `build_vectorhash_architecture`.
        latent_dim: output size of the CNN encoder, also used as the
            VectorHaSH `input_size` so the two components line up.
        device: device handed to BOTH the CNN preprocessor and VectorHaSH.

    Returns (vh, preproc) tuple:
      - vh: VectorHaSH expecting inputs of size `latent_dim`.
      - preproc: CNN encoder mapping images → R^latent_dim.
    """
    # 1) CNN preprocessor (ResNet-18 adapter).
    #    Uses the caller's `device`; it was previously pinned to CPU, which
    #    silently ignored the argument and split the two components across
    #    devices whenever a non-CPU device was requested.
    preproc = PreprocessingCNN(
        device=device,
        latent_dim=latent_dim,
        input_channels=3,
        target_size=(224,224),
        model_path="resnet18_adapter.pth"
    )

    # 2) VectorHaSH scaffold
    #    input_size = latent_dim because agent feeds vh the CNN output
    vh = build_vectorhash_architecture(
        shapes=shapes,
        N_h=N_h,
        input_size=latent_dim,
        smoothing=smoothing,
        shift=shift,
        limits=(2*np.pi, 2*np.pi, 2*np.pi),
        relu=True,
        percent_nonzero_relu=0.2,
        device=device
    )

    return vh, preproc


In [4]:
# ── Experiment grid & output directory ──
# Each list below is one axis of the Cartesian product taken by the run loop.
store_opts = [True, False]   # True="Always", False="When New"
shift_opts = ["additive", "multiplicative"]
hard_opts = [True, False]    # True="Hard", False="Soft"

# Smoothing variants, in run order: identity, PolynomialSmoothing at
# k=1.0 and k=1.5, then SoftmaxSmoothing at T=0.1.
smooth_opts = [IdentitySmoothing()]
smooth_opts += [PolynomialSmoothing(k=k) for k in (1.0, 1.5)]
smooth_opts.append(SoftmaxSmoothing(T=0.1))

# Pickled histories and error plots are written under this directory.
basedir = "miniworld_cnn_tests"
os.makedirs(basedir, exist_ok=True)


In [5]:

# ── Run the experiment grid ──
# One VectorhashAgentKidnappedHistory run per combination of store option,
# shift method, hard/soft storage and smoothing. Each run writes a pickle of
# its history and a PNG of its error curve under `basedir`.
for store_new, shift_m, hard_store, smooth in itertools.product(
        store_opts, shift_opts, hard_opts, smooth_opts
    ):

    # Human-readable tag for the smoothing variant, used in the run name.
    sm_str = (
        "Identity" if isinstance(smooth, IdentitySmoothing) else
        f"Poly(k={smooth.k})" if hasattr(smooth, "k") else
        f"Softmax(T={smooth.T})"
    )
    run_name = (
        f"cnn__{'Always' if store_new else 'WhenNew'}"
        f"__{shift_m}__{'Hard' if hard_store else 'Soft'}__{sm_str}"
    )
    print("Running", run_name)

    env = build_env()
    try:
        # NOTE(review): build_model is called with its default shift=None and
        # the recorded run of this cell ended with "AttributeError: 'NoneType'
        # object has no attribute 'device'". The disabled snippet below
        # suggests a shift instance was meant to be passed — confirm, and
        # import the class before re-enabling it.
        #shift_inst = RatShiftWithCompetitiveAttractorDynamics(
        #    sigma_xy=0.3, sigma_theta=0.3,
        #    inhibition_constant=0.004, delta_gamma=1,
        #    device=torch.device("cpu")
        #)
        vh, preproc = build_model(smoothing=smooth)

        agent = VectorhashAgent(
            vectorhash=vh,
            env=env,
            hard_store=hard_store,
            store_new=store_new,
            shift_method=shift_m,
            preprocessor=preproc
        )

        hist = VectorhashAgentKidnappedHistory(
            agent=agent,
            n_steps=500,
            kidnapping_step=200,
            smoothing=smooth
        )
        hist.run()

        # Persist the full history object for later analysis.
        with open(os.path.join(basedir, f"{run_name}.pkl"), "wb") as f:
            pickle.dump(hist, f)

        # Plot the error curve on an explicit figure so that exactly this
        # figure is saved and closed, independent of pyplot's "current" one.
        fig, ax = plt.subplots(figsize=(6, 3))
        try:
            ax.plot(hist.errors, label=run_name)
            ax.set_title(run_name)
            ax.set_xlabel("Timestep")
            ax.set_ylabel("Position error")
            ax.legend(fontsize="x-small")
            fig.tight_layout()
            fig.savefig(os.path.join(basedir, f"{run_name}_error.png"), dpi=120)
        finally:
            # Free the figure even if plotting/saving fails; 32 runs would
            # otherwise accumulate open figures.
            plt.close(fig)
    finally:
        # Always release the environment, including when a run raises;
        # previously an exception skipped env.close() and leaked the env.
        env.close()


Running cnn__Always__additive__Hard__Identity
Falling back to num_samples=4
Falling back to non-multisampled frame buffer
Falling back to num_samples=4
Falling back to non-multisampled frame buffer




by_scaling
module shapes:  [(5, 5, 5), (8, 8, 8)]
N_g     :  637
N_patts :  64000
N_h     :  600


AttributeError: 'NoneType' object has no attribute 'device'