In [None]:
import math
import torch


# for every pose cell, calculate the velocity shift
def inject_activity(P, v, theta, omega, k_x=1, k_y=1, k_theta=1):
    """Shift pose-cell activity by one step of self-motion (path integration).

    The displacement along each axis (``v*cos(theta)*k_x`` cells in x,
    ``v*sin(theta)*k_y`` cells in y, ``k_theta*omega`` layers in theta) is split
    into an integer part, applied as a circular ``torch.roll``, and a fractional
    part ``f`` in [0, 1). That fraction leaves a share ``1 - f`` of each cell's
    activity at the rolled position and moves ``f`` one cell further along the
    axis (weights come from ``calculate_alpha``). Every axis wraps around.

    Args:
        P: 3-D activity tensor laid out as (theta, x, y).
        v: translational speed.
        theta: heading in radians, used to split ``v`` into x and y components.
        omega: angular velocity.
        k_x, k_y, k_theta: scale factors applied to the x, y and theta velocity
            components before shifting.

    Returns:
        A new tensor with the same shape as ``P``.

    NOTE(review): a single ``theta`` is applied to every theta layer, so all
    layers translate together. If each layer should move along its own heading,
    the x/y shift needs to be computed per layer -- confirm intended.
    """
    v_x = v * math.cos(theta) * k_x
    v_y = v * math.sin(theta) * k_y
    v_theta = k_theta * omega

    # floor (not truncation toward zero) keeps the fractional part in [0, 1) and
    # always pointing toward +1, so negative velocities shift the right way.
    # math.modf + abs turned e.g. -1.3 into a mean shift of about -0.7.
    delta_x, delta_y, delta_theta = math.floor(v_x), math.floor(v_y), math.floor(v_theta)
    delta_f_x = v_x - delta_x
    delta_f_y = v_y - delta_y
    delta_f_theta = v_theta - delta_theta

    n_theta, n_x, n_y = P.shape
    # A one-cell axis wraps onto itself, so the fraction must stay in place.
    # Otherwise the size-1 kernel's only weight (1 - f) would silently drop activity.
    if n_theta == 1:
        delta_f_theta = 0.0
    if n_x == 1:
        delta_f_x = 0.0
    if n_y == 1:
        delta_f_y = 0.0

    kD, kH, kW = min(2, n_theta), min(2, n_x), min(2, n_y)
    alpha = calculate_alpha(delta_f_x, delta_f_y, delta_f_theta, shape=(kD, kH, kW))

    # updated[t, x, y] = sum_{a,b,c} alpha[a, b, c] * P[t - d_theta - a, x - d_x - b, y - d_y - c]
    # (indices modulo the grid size): a circular shift by the integer part,
    # then linear interpolation toward the next cell along each axis.
    updated_P = torch.zeros_like(P)
    for a in range(kD):
        for b in range(kH):
            for c in range(kW):
                updated_P += alpha[a, b, c] * torch.roll(
                    P,
                    shifts=(delta_theta + a, delta_x + b, delta_y + c),
                    dims=(0, 1, 2),
                )
    return updated_P


def calculate_velocity_shift(P, l, m, n, delta_x, delta_y, delta_theta, alpha):
    """Loop-based activity change for a single pose cell (reference version).

    Computes

        sum over (a, b, c) of alpha[a][b][c]
            * P[(n + delta_theta + a) % len(P)]
               [(l + delta_x + b) % len(P[0])]
               [(m + delta_y + c) % len(P[0][0])]

    where ``(n, l, m)`` index axes 0, 1 and 2 of ``P``, and the sums run over
    the extent of ``alpha`` along each of its axes. Every axis wraps around.
    The result is printed and returned.
    """
    change = sum(
        alpha[a][b][c]
        * P[(n + delta_theta + a) % len(P)][(l + delta_x + b) % len(P[0])][
            (m + delta_y + c) % len(P[0][0])
        ]
        for a in range(len(alpha))
        for b in range(len(alpha[0]))
        for c in range(len(alpha[0][0]))
    )
    print("change", change)
    return change


def calculate_alpha(delta_f_x, delta_f_y, delta_f_theta, shape=(2, 2, 2)):
    """Build the separable 3-D interpolation kernel used by the velocity shift.

    ``alpha[i, j, k] = g(delta_f_theta, i) * g(delta_f_x, j) * g(delta_f_y, k)``.
    Axis 0 is theta, axis 1 is x and axis 2 is y, which differs from the
    (x, y, theta) order of the arguments.

    Args:
        delta_f_x, delta_f_y, delta_f_theta: fractional displacement per axis.
        shape: (theta, x, y) size of the kernel.

    Returns:
        A ``torch.zeros(shape)`` tensor filled with the weights. The kernel is
        also printed.
    """
    alpha = torch.zeros(shape)
    for i in range(shape[0]):
        w_theta = g(delta_f_theta, i)
        for j in range(shape[1]):
            # theta and x weights multiplied first, matching the product order.
            w_theta_x = w_theta * g(delta_f_x, j)
            for k in range(shape[2]):
                alpha[i, j, k] = w_theta_x * g(delta_f_y, k)
    print("alpha", alpha)
    return alpha


# Interpolation weight used to build the velocity-shift kernel (see calculate_alpha).
def g(a, b):
    """Linear-interpolation weight for kernel offset ``b`` given fraction ``a``.

    A fractional displacement ``a`` in [0, 1) splits a cell's activity between
    offset 0 (weight ``1 - a``) and offset 1 (weight ``a``). Any other offset
    lies outside that two-cell interpolation window and gets weight 0. As a
    result, a kernel built from ``g`` over more than two offsets still has
    weights summing to 1 along each axis. Previously every offset other than 0
    received ``a``, so wider kernels added activity.
    """
    if b == 0:
        return 1 - a
    if b == 1:
        return a
    return 0.0

In [356]:
# Small test grid: 3 x 3 x 3 cells with all activity in cell [0, 0, 0].
updated_P = torch.zeros((3, 3, 3))
updated_P[0, 0, 0] = 1.0

In [357]:
# Pure rotation: no translation (v = 0.0) and omega * k_theta = 1, so the active
# cell should roll exactly one step along axis 0 with no fractional spreading.
updated_P = inject_activity(updated_P, 0.0, 0, 1, k_x=1, k_y=1, k_theta=1)

print(updated_P)

alpha tensor([[[1., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]]])
tensor([[[[[0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]],

          [[0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]],

          [[0., 0., 0., 0.],
           [0., 1., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]],

          [[0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]]]]])
tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[1., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])


In [2]:
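# NOTE: the excitation updates below return P + conv(P, kernel). The kernels from
# generate_epsilon / generate_delta have a centre weight of exp(0) = 1, so each
# cell's own activity is counted twice (once as P, once via the kernel centre).
# TODO: confirm this matches the intended P <- P + dP update.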
import torch
import math


def generate_epsilon(N_x, N_y, sigma):
    """
    Build an (N_x, N_y) Gaussian excitation kernel.

    Entry [i, j] is exp(-((i - N_x // 2)**2 + (j - N_y // 2)**2) / (2 * sigma**2)),
    so the peak value 1.0 sits at index (N_x // 2, N_y // 2).
    """
    centre_x, centre_y = N_x // 2, N_y // 2
    denom = 2 * sigma**2
    kernel = torch.zeros((N_x, N_y))
    for row in range(N_x):
        dx_sq = (row - centre_x) ** 2
        for col in range(N_y):
            kernel[row, col] = math.exp(-(dx_sq + (col - centre_y) ** 2) / denom)
    return kernel


def update_internal_P_jk(P, epsilon):
    """Excite each theta layer's (x, y) neighbourhood with the kernel ``epsilon``.

    Every layer ``P[t]`` is cross-correlated with ``epsilon`` using zero padding
    and a "same"-size output, and the result is added to ``P``.

    Args:
        P: 3-D activity tensor laid out as (theta, x, y).
        epsilon: 2-D kernel of shape (k_x, k_y), centred on index
            (k_x // 2, k_y // 2) like the output of ``generate_epsilon``. It may
            be smaller than the grid.

    Returns:
        ``P + excitation``, with the same shape as ``P``. The kernel's centre
        weight is applied on top of ``P`` itself, so a centre weight of 1
        doubles each cell's own activity.

    NOTE(review): x/y edges use zero padding (they do not wrap), whereas the
    theta update and the velocity shift wrap circularly -- confirm intended.
    """
    assert len(P.shape) == 3, "P should be a 3D matrix. (theta, x, y)"
    k_x, k_y = epsilon.shape
    # F.pad takes (left, right) pairs starting from the LAST dim:
    # (y_left, y_right, x_left, x_right). Padding (k // 2, (k - 1) // 2) keeps
    # the output the size of the input and centres the kernel on index k // 2,
    # for both odd and even k. The padding must follow the kernel's size per
    # axis, not the grid's, and must be applied to the matching axis.
    pad = (k_y // 2, (k_y - 1) // 2, k_x // 2, (k_x - 1) // 2)
    padded = torch.nn.functional.pad(P, pad, mode="constant", value=0)
    print(f"padded shape: {padded.shape}")
    excitation = torch.nn.functional.conv2d(
        padded.unsqueeze(1),  # each theta layer becomes one single-channel batch item
        epsilon.to(dtype=P.dtype, device=P.device).unsqueeze(0).unsqueeze(0),
        padding="valid",
        stride=1,
    )
    print(f"updated shape: {excitation.shape}")
    return P + excitation.reshape(P.shape)


def generate_delta(N_theta, sigma, gamma=2):
    """Return the 1-D Gaussian kernel used to excite neighbouring theta layers.

    The kernel has one tap per layer offset ``-gamma .. gamma`` with weight
    ``exp(-offset**2 / (2 * sigma**2))``. The centre tap is 1.

    Args:
        N_theta: number of theta layers. ``gamma`` is clamped to
            ``(N_theta - 1) // 2`` so the kernel never spans more layers than
            exist; a wider kernel would wrap onto the same layer twice.
        sigma: standard deviation of the Gaussian, in layers.
        gamma: reach of the kernel on each side of the centre.

    Returns:
        1-D float tensor of length ``2 * min(gamma, (N_theta - 1) // 2) + 1``.
    """
    # Taps are generated directly from their offsets. Slicing a window out of an
    # N_theta-long array breaks when N_theta // 2 < gamma: the start index goes
    # negative and wraps, yielding a truncated or even-length kernel.
    gamma = min(gamma, (N_theta - 1) // 2)
    return torch.tensor(
        [math.exp(-(x**2) / (2 * sigma**2)) for x in range(-gamma, gamma + 1)]
    )


def update_inter_layer_P_ijk(P, delta):
    """Excite neighbouring theta layers with the 1-D kernel ``delta``.

    For every (x, y) column, activity along the theta axis is cross-correlated
    with ``delta`` using circular (wrap-around) padding, and the result is
    added to ``P``.

    Args:
        P: 3-D activity tensor laid out as (theta, x, y).
        delta: 1-D kernel centred on index ``len(delta) // 2`` (as returned by
            ``generate_delta``). Its padding per side must not exceed the number
            of theta layers, because torch's circular padding cannot wrap more
            than once.

    Returns:
        ``P + excitation``, with the same shape as ``P``.
    """
    assert len(P.shape) == 3, f"P should be a 3D matrix. (theta, x, y), got {P.shape}"
    print(f"P shape: {P.unsqueeze(0).shape}")
    n_theta, n_x, n_y = P.shape
    k = len(delta)
    # One row per (x, y) column -> (n_x * n_y, 1, n_theta). The row count must be
    # n_x * n_y; n_theta * n_x only happens to match when n_theta == n_y.
    columns = P.permute(1, 2, 0).reshape(n_x * n_y, 1, n_theta)
    # (k // 2, (k - 1) // 2) keeps the output length n_theta for odd and even k.
    padded = torch.nn.functional.pad(columns, (k // 2, (k - 1) // 2), mode="circular")
    print(f"padded shape: {padded.shape}")
    excitation = torch.nn.functional.conv1d(
        padded,
        delta.to(dtype=P.dtype, device=P.device).unsqueeze(0).unsqueeze(0),
        padding="valid",
        stride=1,
    )
    print(f"updated shape: {excitation.shape}")
    # Undo the column layout: (n_x * n_y, 1, n_theta) -> (theta, x, y).
    return P + excitation.reshape(n_x, n_y, n_theta).permute(2, 0, 1)


def global_inhibition(P, inhibition_constant=0.004):
    """Suppress every cell in proportion to how far it sits below the global peak.

    Each cell becomes ``P + inhibition_constant * (P - max(P))``, and negative
    results are clamped to 0. The peak cell is therefore unchanged (when it is
    non-negative), and for a positive ``inhibition_constant`` weaker cells lose
    activity. The computation is element-wise, so the axis order of ``P`` does
    not matter here.
    """
    assert len(P.shape) == 3, "P should be a 3D matrix. (x, y, theta)"
    deficit = P - torch.max(P)  # <= 0 everywhere, 0 at the peak
    return torch.clamp(P + inhibition_constant * deficit, min=0)  # no negative activity


def normalize(P):
    """Scale ``P`` so its entries sum to 1.

    If the total is zero (e.g. an all-zero grid) there is nothing to rescale.
    In that case a copy of ``P`` is returned, because dividing by zero would
    turn every entry into NaN.
    """
    total = torch.sum(P)
    if total == 0:
        return P.clone()
    return P / total  # normalize the probabilities

In [None]:
from matplotlib.pyplot import disconnect


# Pose-cell grid size; P below is laid out as (theta, x, y).
N_theta = 50
N_x = 50
N_y = 50

EPSILON_SIGMA = 0.3  # std-dev (in cells) of the in-layer (x, y) excitation kernel
DELTA_SIGMA = 0.3  # std-dev (in layers) of the cross-layer (theta) excitation kernel
N_ITERATIONS = 50  # attractor update steps in the convergence run
SEED = 0  # fixes the random background so every printed value is reproducible

torch.manual_seed(SEED)

# Random low-level background plus two strong activity peaks (each boosted by
# the total cell count), then normalized so the grid sums to 1.
P = torch.rand(N_theta, N_x, N_y)
P[0, 0, 0] += N_theta * N_x * N_y
P[-5, -5, -5] += N_theta * N_x * N_y
P = P / torch.sum(P)  # normalize the probabilities

epsilon = generate_epsilon(N_x, N_y, EPSILON_SIGMA)
print("Original P:")
print(P)
print("Exciting nearby with internal xy connections")
print("Epsilon:")
print(epsilon)
updated_P = update_internal_P_jk(P, epsilon)
print("Updated P:")
print(updated_P)

delta = generate_delta(N_theta, DELTA_SIGMA)
print("Delta:")
print(delta)
print("Exciting nearby with intra layer")
updated_P = update_inter_layer_P_ijk(P, delta)
print("Updated P:")
print(updated_P)

# One full attractor step: in-layer excitation -> cross-layer excitation ->
# global inhibition -> renormalization.
print("all together: ")
updated_P = update_internal_P_jk(P, epsilon)
print("Updated P:")
print(updated_P)
updated_P = update_inter_layer_P_ijk(updated_P, delta)
print("Updated P:")
print(updated_P)
print("Global inhibition")
updated_P = global_inhibition(updated_P)
print("Updated P:")
print(updated_P)
updated_P = normalize(updated_P)
print("Updated P:")
print(updated_P)

# Repeat the same step from the original P, printing the peak value each step.
print("Will reach one fixed point")
updated_P = P
for i in range(N_ITERATIONS):
    print("Iteration", i)
    updated_P = update_internal_P_jk(updated_P, epsilon)
    updated_P = update_inter_layer_P_ijk(updated_P, delta)
    updated_P = global_inhibition(updated_P)
    updated_P = normalize(updated_P)
    print(updated_P.max())

print("Updated P:")
print(updated_P)
# Number of cells that still carry non-zero activity.
print(int((updated_P > 0).sum()))

Original P:
tensor([[[4.0000e-01, 2.6324e-06, 1.7954e-06,  ..., 1.9351e-06,
          1.3921e-06, 2.1707e-06],
         [4.1497e-07, 3.1724e-06, 2.8008e-06,  ..., 2.7830e-06,
          1.4756e-06, 1.0346e-06],
         [2.9319e-06, 2.3213e-06, 2.5922e-06,  ..., 2.4418e-06,
          1.5849e-06, 3.1970e-06],
         ...,
         [6.5957e-07, 1.1530e-06, 1.7759e-06,  ..., 5.3733e-07,
          9.1039e-07, 3.4111e-07],
         [7.1720e-07, 2.6045e-06, 3.0138e-06,  ..., 2.4191e-06,
          2.0372e-06, 1.3924e-06],
         [5.3318e-07, 2.0856e-06, 1.5016e-06,  ..., 2.3251e-07,
          4.8671e-08, 1.9446e-07]],

        [[2.5636e-06, 1.6051e-06, 3.0638e-06,  ..., 2.3826e-06,
          1.5895e-06, 1.6676e-07],
         [7.5330e-07, 2.5702e-06, 3.0804e-06,  ..., 1.8306e-06,
          8.3945e-07, 1.7333e-06],
         [6.0254e-07, 1.0593e-06, 1.7238e-07,  ..., 2.7682e-06,
          2.2228e-06, 7.1936e-07],
         ...,
         [4.6177e-07, 2.8001e-06, 2.5750e-06,  ..., 8.1100e-07,
   