In [1]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from scipy.stats import pearsonr

2025-12-05 10:47:28.162258: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-05 10:47:29.468539: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-12-05 10:47:33.830251: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-12-05 10:47:33.834633: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Build the 2D goal space: an intervals x intervals grid of (x, y) points spanning
# [-au, au] on both axes, flattened to shape (intervals**2, 2).
intervals = 51
au = 0.8
xx = np.linspace(-au, au, intervals)
yy = np.linspace(-au, au, intervals)
xx, yy = np.meshgrid(xx, yy)
traj = np.vstack([xx.ravel(), yy.ravel()]).T
print(traj.shape)

# Candidate reward locations: a 7 x 7 grid of "holes" inset 0.2 from +/-au, listed
# row by row starting from the largest second coordinate, with the first
# coordinate ascending within each row.
ncues = 12
holes = np.linspace(-au+0.2, au-0.2, 7)
holoc_all = np.array([[hx, hy] for hy in holes[::-1] for hx in holes])

# Seed the draw so the same cue targets are chosen on every Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)
reward_locs = rng.choice(len(holoc_all), ncues, replace=False)
targets = holoc_all[reward_locs]
# Each cue's memory vector is its target location with a constant 1 appended.
memory = np.concatenate([targets, np.ones((ncues, 1))], axis=1)
print(memory.shape)


(2601, 2)
(12, 3)


In [12]:
class MotorControllerWithActivations(tf.keras.Model):
    """Wrap a trained motor controller so one forward pass also exposes hidden layers.

    The wrapper re-uses the very same ``h1``, ``h2`` and ``action`` layer objects of
    ``base_model`` (no copies), so it computes with the original trained weights.

    Args:
        base_model: a model exposing ``h1``, ``h2`` and ``action`` layer attributes.

    ``call`` returns the tuple ``(action_output, h1_activations, h2_activations)``.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base = base_model
        # Share the trained layers directly; no new weights are created here.
        self.h1, self.h2, self.action = base_model.h1, base_model.h2, base_model.action

    def call(self, x):
        first_hidden = self.h1(x)
        second_hidden = self.h2(first_hidden)
        return self.action(second_hidden), first_hidden, second_hidden

In [14]:
# Load the trained motor controller and wrap it so hidden activations are returned too.
N = 128  # units per hidden layer
L = 2    # number of hidden layers (h1, h2)
MODEL_PATH = '../motor_controller/mc_2h128_linear_30mb_31sp_0.6oe_20e_2022-10-08'

raw_model = tf.keras.models.load_model(MODEL_PATH)
model = MotorControllerWithActivations(raw_model)

# Fail fast if the hard-coded layer width disagrees with the loaded checkpoint,
# since activation buffers are sized from N and L.
assert raw_model.h1.units == N and raw_model.h2.units == N, \
    f'expected {N} units per hidden layer, got {raw_model.h1.units} and {raw_model.h2.units}'

# summary() prints the table itself and returns None; wrapping it in print()
# only appends a stray "None" to the output.
raw_model.summary()

Model: "motor_controller"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 h1 (Dense)                  multiple                  768       
                                                                 
 h2 (Dense)                  multiple                  16512     
                                                                 
 action (Dense)              multiple                  5160      
                                                                 
Total params: 22,440
Trainable params: 22,440
Non-trainable params: 0
_________________________________________________________________
None


In [4]:
# Sanity check: the controller input for one cue is that cue's memory row tiled
# once per grid point, followed by the grid point's (x, y) coordinates.
c = 0
np.hstack((np.tile(memory[c], (len(traj), 1)), traj)).shape

(2601, 5)

In [17]:
# Record hidden-layer responses for every cue at every grid point:
# hs[c, p, :] = [h1 activations | h2 activations] for cue c at grid point p.
hs = np.zeros([ncues, traj.shape[0], N*L])

for c in range(ncues):
    # Input rows: the cue's memory vector (tiled per grid point) followed by (x, y).
    x = np.concatenate([np.tile(memory[c], (traj.shape[0], 1)), traj], axis=1)

    # Cast explicitly; the grid is float64 while Keras layers presumably run in
    # float32 (the Keras default) -- TODO confirm against the saved model.
    _, h1_act, h2_act = model(x.astype(np.float32))
    hs[c] = np.concatenate([h1_act.numpy(), h2_act.numpy()], axis=1)

In [18]:
# Find neurons whose activity is tuned to one goal axis (x or y) but not the other.
# For every cue and every recorded unit, compute |Pearson r| between the unit's
# activation across the grid and the x / y coordinate of each grid point.

traj_x, traj_y = traj[:, 0], traj[:, 1]

# Scan every recorded unit: hs stacks h1 (indices 0..N-1) and h2 (N..N*L-1),
# so looping over range(N) alone would silently skip the whole h2 layer.
n_units = hs.shape[2]
x_rvals = np.zeros((n_units, ncues))
y_rvals = np.zeros((n_units, ncues))

for c in range(ncues):
    for n in range(n_units):
        h = hs[c, :, n]
        if np.ptp(h) == 0:
            # r is undefined for a constant unit (pearsonr would warn and return NaN);
            # record NaN directly so such units fail the threshold tests below.
            x_rvals[n, c] = y_rvals[n, c] = np.nan
            continue
        x_rvals[n, c] = abs(pearsonr(h, traj_x)[0])
        y_rvals[n, c] = abs(pearsonr(h, traj_y)[0])

# Keep cells with strong encoding of one axis and weak encoding of the other,
# averaged over cues. A NaN for any cue makes the mean NaN, which excludes the unit.
correct_threshold = 0.7
wrong_axis_threshold = 0.2
x_mean = x_rvals.mean(axis=1)
y_mean = y_rvals.mean(axis=1)
x_cells = np.where((x_mean > correct_threshold) & (y_mean < wrong_axis_threshold))[0]
y_cells = np.where((y_mean > correct_threshold) & (x_mean < wrong_axis_threshold))[0]

print('X-coding neurons:', x_cells)
print('Y-coding neurons:', y_cells)




X-coding neurons: [14 57 92]
Y-coding neurons: [ 32  37  58 121]
