In [1]:
from calendar import c
from hmac import new
import math
import torch


# for every pose cell, calculate the velocity shift
def inject_activity(P, v, theta, omega, k_x=1, k_y=1, k_theta=1):
    updated_P = torch.clone(P)
    for g_theta in range(P.shape[0]):
        for g_x in range(P.shape[1]):
            for g_y in range(P.shape[2]):
                updated_P[g_theta][g_x][g_y] = calculate_velocity_shift(
                    P, g_x, g_y, g_theta, -v, theta, -omega, k_x, k_y, k_theta
                )  # -v for same functionality as ezra had implimented, and - omega same reason

    return updated_P


def calculate_velocity_shift(P, l, m, n, v, theta, omega, k_x, k_y, k_theta):
    change = 0
    delta_x, delta_y, delta_theta = calculate_deltas(
        v, theta, omega, k_x, k_y, k_theta
    )  # speed and angular velocity of the robot
    delta_f_x, delta_f_y, delta_f_theta = calculate_delta_fs(
        delta_x, delta_y, delta_theta, v, theta, omega, k_x, k_y, k_theta
    )
    alpha = calculate_alpha(
        delta_f_x,
        delta_f_y,
        delta_f_theta,
        shape=(min(2, P.shape[0]), min(2, P.shape[1]), min(2, P.shape[2])),
    )
    for theta in range(delta_theta, delta_theta + len(alpha)):
        for x in range(delta_x, delta_x + len(alpha[0])):
            for y in range(delta_y, delta_y + len(alpha[0][0])):
                change += (
                    alpha[theta - delta_theta][x - delta_x][
                        y - delta_y
                    ]  # paper does x y theta (but there alpha indices are weird)
                    * P[(n + theta) % len(P)][(l + x) % len(P[0])][
                        (m + y) % len(P[0][0])
                    ]
                )
    return change


def calculate_deltas(
    velocity, theta, omega, k_x, k_y, k_theta
):  # velocity is really a speed, theta a global direction
    delta_x = math.floor(velocity * math.cos(theta) * k_x)
    delta_y = math.floor(velocity * math.sin(theta) * k_y)
    delta_theta = math.floor(k_theta * omega)

    return delta_x, delta_y, delta_theta


def calculate_delta_fs(
    delta_x, delta_y, delta_theta, v, theta, omega, k_x, k_y, k_theta
):
    delta_f_x = (
        k_x * v * math.cos(theta) - delta_x
    )  # i think they forgot this in the paper equations
    delta_f_y = k_y * v * math.sin(theta) - delta_y
    delta_f_theta = k_theta * omega - delta_theta

    return delta_f_x, delta_f_y, delta_f_theta


def calculate_alpha(
    delta_f_x, delta_f_y, delta_f_theta, shape=(2, 2, 2)
):  # delta_x, delta_y, delta_theta
    alpha = torch.zeros(shape)
    for i in range(0, shape[0]):
        for j in range(0, shape[1]):
            for k in range(0, shape[2]):
                alpha[i][j][k] = (
                    g(
                        delta_f_x, i
                    )  # paper has i - delta_x, but its really just saying floored value gets 1 - a, and the other gets a
                    * g(delta_f_y, j)
                    * g(delta_f_theta, k)
                )
    return alpha


# this seemingly has some errors in the paper as b can be values other than 0 or 1, but we can have any (integer) value for b (maybe we need to tune k? or is it mod?)
def g(a, b):
    if b == 0:
        return 1 - a
    else:
        return a

In [None]:
# MIGHT BE DOUBLE COUNTING P[i][j][k], since i do the += and do a copy (not fresh)
import torch
import math


def generate_epsilon(N_x, N_y, sigma):
    """
    Generate a 2D Gaussian kernel (epsilon) for the given dimensions and standard deviation.
    """
    kernel = torch.zeros((N_x, N_y))
    for i in range(N_x):
        for j in range(N_y):
            x = i
            y = j
            kernel[i, j] = math.exp(-(x**2 + y**2) / (2 * sigma**2))
    return kernel


def update_internal_P_jk(P, epsilon):
    assert len(P.shape) == 3, "P should be a 3D matrix. (x, y, theta)"
    updated_P = torch.clone(P)
    N_theta = len(P)
    N_x = len(P[0])
    N_y = len(P[0][0])
    # for every layer
    for i in range(N_theta):
        # for each P[j][k]
        for j in range(N_x):
            for k in range(N_y):
                # for each P[a][b]
                for a in range(N_x):
                    for b in range(N_y):
                        # update P[j][k] with the weighted sum of P[a][b]
                        # ours wraps around for x and y, so abs (instead of clipping)
                        # lower distance, more excitation
                        updated_P[i][j][k] += (
                            P[i][a][b] * epsilon[abs(j - a)][abs(k - b)]
                        )

    return updated_P


def generate_delta(N_theta, sigma):
    delta = torch.zeros(N_theta)
    for i in range(N_theta):
        x = i
        delta[i] = math.exp(-(x**2) / (2 * sigma**2))
    return delta


def update_inter_layer_P_ijk(P, delta, gamma=2):
    assert len(P.shape) == 3, "P should be a 3D matrix. (x, y, theta)"
    updated_P = torch.clone(P)
    N_theta = len(P)
    N_x = len(P[0])
    N_y = len(P[0][0])
    # for every P[i][j][k]
    for i in range(N_theta):
        for j in range(N_x):
            for k in range(N_y):
                for c in range(i - gamma, i + gamma + 1):
                    # update P[i][j][k] with the weighted sum of P[c][j][k]
                    updated_P[i][j][k] += delta[abs(c - i)] * P[c % N_theta][j][k]
    return updated_P


def global_inhibition(P, inhibition_constant=0.004):
    assert len(P.shape) == 3, "P should be a 3D matrix. (x, y, theta)"
    updated_P = torch.clone(P)
    N_theta = len(P)
    N_x = len(P[0])
    N_y = len(P[0][0])
    for i in range(N_theta):
        for j in range(N_x):
            for k in range(N_y):
                updated_P[i][j][k] = max(
                    0, P[i, 1, k] + inhibition_constant * (P[i][j][k] - P.max())
                )
    return updated_P


def normalize(P):
    return P / torch.sum(P)  # normalize the probabilities

In [3]:
l = 0
for i in range(l - 2, l + 2):
    print(i)

-2
-1
0
1


In [4]:
from matplotlib.pyplot import disconnect


N_x = 5
N_y = 5

P = torch.tensor(
    [
        [
            [0.2, 0.1, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0.2, 0],
            [0, 0, 0, 0, 0],
        ],
        [
            [0.1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0.1, 0],
            [0, 0, 0, 0, 0],
        ],
        [
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ],
        [
            [0.05, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0.05, 0],
            [0, 0, 0, 0, 0],
        ],
        [
            [0.1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0.1, 0],
            [0, 0, 0, 0, 0],
        ],
    ]
)

epsilon = generate_epsilon(N_x, N_y, 1.0)
print("Original P:")
print(P)
print("Exciting nearby with internal xy connections")
print("Epsilon:")
print(epsilon)
updated_P = update_internal_P_jk(P, epsilon)
print("Updated P:")
print(updated_P)

delta = generate_delta(5, 1.0)
print("Delta:")
print(delta)
print("Exciting nearby with intra layer")
updated_P = update_inter_layer_P_ijk(P, delta)
print("Updated P:")
print(updated_P)


print("Only checks 2 layers away")
disconnect_P = torch.tensor(
    [
        [
            [0.2, 0.1, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0.2, 0],
            [0, 0, 0, 0, 0],
        ],
        [
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ],
        [
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ],
        [
            [0.2, 0.1, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0.2, 0],
            [0, 0, 0, 0, 0],
        ],
        [
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ],
        [
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ],
    ]
)
updated_P = update_inter_layer_P_ijk(disconnect_P, delta)
print("Updated P:")
print(updated_P)

print("all together: ")
updated_P = update_internal_P_jk(P, epsilon)
print("Updated P:")
print(updated_P)
updated_P = update_inter_layer_P_ijk(updated_P, delta)
print("Updated P:")
print(updated_P)
print("Global inhibition")
updated_P = global_inhibition(updated_P)
print("Updated P:")
print(updated_P)
updated_P = normalize(updated_P)
print("Updated P:")
print(updated_P)

print("Will reach one fixed point")
for i in range(1000):
    updated_P = update_internal_P_jk(updated_P, epsilon)
    updated_P = update_inter_layer_P_ijk(updated_P, delta)
    updated_P = global_inhibition(updated_P)
    updated_P = normalize(updated_P)

print("Updated P:")
print(updated_P)

Original P:
tensor([[[0.2000, 0.1000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.2000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.1000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.1000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0500, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0500, 0.0000],
         [0.0

KeyboardInterrupt: 