Weight memory in which the readout is not the home vector itself, but activity shifted from the current heading towards the home vector. The readout therefore encodes which direction to turn, rather than how much to turn.

This mostly works, but when the stored memory is very flat, the current-heading bump is barely shifted. The memory readout is then roughly the same as the current heading. As a result, steering does not kick in until the agent is quite far from home, which makes the search pattern very large.

In [None]:
%load_ext autoreload
%autoreload 2
import json
import matplotlib.pyplot as plt
import matplotlib.colors as clr
import numpy as np
from loguru import logger
logger.remove()

from pim.simulator import SimulationExperiment

In [None]:
# (layer, key) of the recording shown as "weights" in the plots below.
# Swap in the commented alternative to inspect CPU4 instead of the weight memory.
weights_source = ("memory", "internal")
# weights_source = ("CPU4", "output")

# Trade-off factor: mem_fade and pfn_weight_factor are divided by it,
# mem_gain is multiplied by it.
memory_balance = 0.25

parameters = {
    "type": "simulation",
    "T_outbound": 1500,
    "T_inbound": 3000,
    "min_homing_distance": 300,
    # "motor_factor": 0.1,
    "motor_factor": 4,
    "record": ["memory", "TB1", "CPU1", "Pontine", "motor", "theory", "CPU4", "CPU4.old"],
    "cx": {
        "type": "weights",
        "output_layer": "motor",
        "params": {
            "noise": 0.1,
            "mem_fade": 0.1 / memory_balance,
            "mem_gain": 0.0025 * memory_balance,
            "pfn_weight_factor": 1 / memory_balance,
        },
    },
}

# Timestep used for the vertical markers and the snapshot plot (after T_outbound = 1500).
T = 1600


def timeline(t, style="--", n_rows=16):
    """Draw a vertical marker at timestep ``t`` across an ``n_rows``-row heatmap.

    ``imshow`` centres row ``i`` on ``y = i``, so the rows span
    ``[-0.5, n_rows - 0.5]``. ``style`` is a matplotlib format string.
    """
    plt.plot([t, t], [-0.5, n_rows - 0.5], style)


def plot_snapshot(step, layers):
    """Plot each layer's first 8 rows (solid, "L") and remaining rows (dashed, "R") at ``step``.

    ``layers`` is a sequence of ``(label, activity, color)``, where ``activity`` is
    indexed as ``activity[neuron, timestep]`` and ``color`` is a one-letter
    matplotlib colour code.
    """
    plt.figure(figsize=(10, 5))
    for label, activity, color in layers:
        plt.plot(activity[:8, step], f"{color}-", label=f"{label} L")
        plt.plot(activity[8:, step], f"{color}--", label=f"{label} R")
    plt.legend()


experiment = SimulationExperiment(parameters)
results = experiment.run("test")
t = results.closest_position_timestep()

plt.figure(figsize=(10, 10))
ax = plt.axes()
ax.axis("equal")
results.plot_path(ax)
plt.legend()

# Transpose to (neuron, timestep) so imshow puts time on the x axis.
weights = np.array(results.recordings[weights_source[0]][weights_source[1]]).T
memory = np.array(results.recordings["memory"]["output"]).T
cpu4 = np.array(results.recordings["CPU4"]["output"]).T
cpu1 = np.array(results.recordings["CPU1"]["output"]).T

plt.figure(figsize=(10, 5))
plt.imshow(cpu4, cmap="hot", interpolation="nearest", aspect="auto", norm=clr.Normalize(0, 1))
timeline(T)
plt.title("PFN")

plt.figure(figsize=(10, 5))
plt.imshow(weights, cmap="hot", interpolation="nearest", aspect="auto", norm=clr.Normalize(0, 1))
timeline(t)
timeline(T)
timeline(parameters["T_outbound"], style="-")
plt.title("weights")

plt.figure(figsize=(10, 5))
plt.imshow(memory, cmap="hot", interpolation="nearest", aspect="auto", norm=clr.Normalize(0, 1))
timeline(T)
plt.title("memory output")

# Signed left-minus-right difference: use a diverging colormap centred on 0.
plt.figure(figsize=(10, 5))
plt.imshow(weights[:8, :] - weights[8:, :], cmap="coolwarm", interpolation="nearest", aspect="auto", norm=clr.Normalize(-1, 1))
plt.title("weights L - R")

plot_snapshot(T, [
    ("weights", weights, "b"),
    ("CPU4", cpu4, "g"),
    ("memory", memory, "m"),
    ("CPU1", cpu1, "r"),
])
plt.show()

In [None]:
plt.figure(figsize=(10, 10))
ax = plt.axes()
ax.axis("equal")
results.plot_path(ax)

# Draw one arrow every `nth` inbound timestep.
nth = 10
T_outbound = results.parameters["T_outbound"]

# Decimated positions/velocities after T_outbound.
# NOTE(review): positions/velocities start at T_outbound + 1 while the motor
# recordings below start at T_outbound; presumably the motor output at step k
# drives the movement recorded at step k + 1 - confirm this alignment.
velocities = results.velocities[T_outbound + 1::nth]
path = np.array(results.reconstruct_path())[T_outbound + 1::nth, :]

# Perpendicular to the velocity: (vx, vy) -> (vy, -vx).
# NOTE(review): in y-up axes this is the clockwise perpendicular despite the
# name "left"; confirm against the sign convention of the motor output.
left = np.flip(velocities, 1) * (1, -1)

motor_motor = np.array(results.recordings["motor"]["output"][T_outbound::nth])
motor_theory = np.array(results.recordings["theory"]["output"][T_outbound::nth])

for signal, color, label in ((motor_motor, "yellow", "motor"), (motor_theory, "lightgreen", "cheat")):
    plt.quiver(path[:, 0], path[:, 1], left[:, 0] * signal, left[:, 1] * signal,
               color=color, label=label, scale=0.5)

# Compare the model's motor output with the "theory" output over every inbound
# step, not just the decimated samples used for the arrows.
motor_all = np.array(results.recordings["motor"]["output"][T_outbound:])
theory_all = np.array(results.recordings["theory"]["output"][T_outbound:])
print("MSE motor vs theory (all inbound steps):", np.mean((motor_all - theory_all) ** 2))

plt.legend()
plt.show()

# Snapshot of each layer's first 8 rows (solid, "L") and remaining rows (dashed, "R") at timestep T.
plt.figure(figsize=(10, 5))
for name, activity, color in (("weights", weights, "b"), ("CPU4", cpu4, "g"),
                              ("memory", memory, "m"), ("CPU1", cpu1, "r")):
    plt.plot(activity[:8, T], f"{color}-", label=f"{name} L")
    plt.plot(activity[8:, T], f"{color}--", label=f"{name} R")
plt.legend()
plt.show()

