In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


In [None]:
# function to create a 3D plot of the firing rates of cells

def create3Dplot(rows,cols,Z):
    """Show the firing rates of a sheet of cells as a 3D surface.

    Parameters
    ----------
    rows, cols : int
        Number of rows and columns of the cell sheet.
    Z : 2D array, shape (rows, cols)
        Firing rate of each cell; ``Z[r, c]`` is drawn at x = c, y = r.

    Returns
    -------
    None
        The call blocks on ``input()`` until Enter is pressed and then
        closes the figure.
    """
    # Coordinate grids with the same (rows, cols) shape as Z: columns run
    # along the x axis and rows along the y axis. Building the grids in this
    # order keeps X, Y and Z aligned for non-square sheets as well.
    col_coords = np.arange(cols)
    row_coords = np.arange(rows)
    X, Y = np.meshgrid(col_coords, row_coords)

    # Create the 3D axes
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')

    # Plot the surface
    ax.plot_surface(X, Y, Z, cmap='viridis', edgecolor='none')

    # Add labels and title
    ax.set_xlabel('Column')
    ax.set_ylabel('Row')
    ax.set_zlabel('Firing rate')
    ax.set_title('3D Surface Plot')

    # Show the plot without blocking, then wait for the user before closing
    # it so that the caller's loop advances one figure at a time.
    plt.ion()
    plt.show()
    plt.pause(0.001)
    input("Press Enter to continue...")  # Wait for the user to press Enter
    plt.close(fig)

In [None]:
# STEP 1: Set up the place-cell (PC) sheet, with one cell representing each
# location of the environment, and allocate memory for its weights matrix.
pc_rows, pc_cols = 20, 20
pc_nCells = pc_rows * pc_cols

# pc_weightsMat[:, :, k] is the (pc_rows x pc_cols) sheet of weights that
# belongs to place cell k; everything starts at zero.
pc_weightsMat = np.zeros(shape=(pc_rows, pc_cols, pc_nCells))

# Simulation settings
nIterations = 10     # number of update steps of the network
NetThreshold = 0.5   # normalised input below this value is set to zero


# Index sheet for the PC output: a (pc_rows x pc_cols) matrix in which each
# element is a unique cell index from 0 to pc_nCells - 1, numbered row by row.
PC_idx = np.arange(pc_nCells).reshape(pc_rows, pc_cols)

In [None]:
# STEP 2: Divide the environment into bins for the sensory cells.
# The environment is a grid of env_rows x env_cols bins (20 x 20).
env_rows, env_cols = 20, 20

# 1D bin coordinates along each axis: x runs over columns, y over rows.
x, y = np.arange(env_cols), np.arange(env_rows)

# 2D coordinate grids covering every bin: X[r, c] == c and Y[r, c] == r.
# Having the coordinates of all bins as arrays lets the distance terms of the
# Gaussian firing fields be computed for the whole grid at once.
Y, X = np.indices((env_rows, env_cols))


In [None]:
# STEP 3: Set up the sensory cells and allocate storage for their firing fields.
sc_nCells = 10  # number of sensory cells

# Standard deviation intended for the Gaussian firing fields, in bins; a larger
# value spreads a field over more of the environment (the original note
# estimates roughly 1/3 of the environment for this value).
sigma = 6

# One 2D firing-rate map per sensory cell, indexed as
# (sensory cell, environment row, environment column). All zeros for now.
sc_firing_fields = np.zeros(shape=(sc_nCells, env_rows, env_cols))


In [None]:
# STEP 4: Evenly space the sensory cells' centres over the environment.
# The layout is derived from sc_nCells and the environment size.

# Layout grid for the centres: the number of rows is the square root of
# sc_nCells truncated to an integer, and the number of columns is rounded up
# so that n_rows * n_cols is large enough to hold every sensory cell.
n_rows = int(np.sqrt(sc_nCells))
n_cols = int(np.ceil(sc_nCells / n_rows))

# Spacing (in bins) between neighbouring centres along each axis. Dividing by
# (count + 1) leaves a margin of one step at both edges of the environment.
row_step = env_rows // (n_rows + 1)
col_step = env_cols // (n_cols + 1)

# Walk through the layout grid row by row: cell i sits in layout row
# i // n_cols and layout column i % n_cols. Adding 1 before scaling by the
# step skips the margin, giving the (row, col) bin of each centre.
centers = [
    ((grid_row + 1) * row_step, (grid_col + 1) * col_step)
    for grid_row, grid_col in (divmod(cell, n_cols) for cell in range(sc_nCells))
]

In [None]:
# STEP 5: Create firing field of the sensory cells

# Define Gaussian function for firing fields
def gaussian(x, y, cx, cy, sigma=3.0):
    """Evaluate an unnormalised 2D Gaussian centred on (cx, cy).

    Parameters
    ----------
    x, y : scalar or array
        Coordinates at which to evaluate the Gaussian (arrays are evaluated
        element-wise and must broadcast against each other).
    cx, cy : scalar
        Centre of the Gaussian along x and y; the value there is 1 (the peak).
    sigma : float, default 3.0
        Standard deviation controlling the spread.

    Returns
    -------
    Same shape as the broadcast of ``x`` and ``y``: values in (0, 1].
    """
    # Use the coordinates that were passed in rather than module-level grids,
    # so the function works for any grid it is given.
    return np.exp(-((x - cx) ** 2 + (y - cy) ** 2) / (2 * sigma ** 2))

# Compute one Gaussian firing field per sensory cell, centred on that cell's
# entry in `centers`, and store it at the matching index of sc_firing_fields.
# The module-level `sigma` is passed explicitly; without it gaussian() falls
# back to its own default of 3.0 and the configured spread is ignored.
# NOTE(review): `centers` stores (row, col) pairs but they are unpacked here
# as (cx, cy), so the row value acts as the x (column) coordinate. Step 6
# indexes the fields as [col, row] of the place-cell grid, which undoes the
# swap for the weights, but the images below are transposed relative to the
# (row, col) centres - confirm which orientation is intended.
for i, (cx, cy) in enumerate(centers):
    sc_firing_fields[i] = gaussian(X, Y, cx, cy, sigma=sigma)

# Plot the firing fields, five panels per row. The number of panel rows is
# derived from the number of fields so the figure still works if sc_nCells
# changes (10 cells give the 2 x 5 layout).
n_fields = len(sc_firing_fields)
panels_per_row = 5
n_panel_rows = max(1, -(-n_fields // panels_per_row))  # ceiling division
fig, axes = plt.subplots(
    n_panel_rows, panels_per_row,
    figsize=(15, 3 * n_panel_rows),
    squeeze=False,  # always a 2D array of axes, even for a single row
)
axes = axes.flatten()

for i, ax in enumerate(axes):
    if i >= n_fields:
        # Hide panels that have no sensory cell to show.
        ax.axis('off')
        continue
    ax.imshow(sc_firing_fields[i], origin='lower', cmap='viridis')
    ax.set_title(f'Sensory Cell {i+1}')
    ax.set_xticks([])
    ax.set_yticks([])

In [None]:
# STEP 6: Create a weights matrix between the sensory cells and each of the place cells (attractor cells)
# sc2pc_weightsMat[sc, pc] is the weight of the connection from sensory cell
# `sc` to place cell `pc` (place cells are addressed by their flat index).

# Grid coordinates of every place cell, computed once for all cells.
# unravel_index converts the flat indices 0 .. pc_nCells-1 into positions on
# the (pc_rows, pc_cols) sheet: pc_x is the index along the first (row) axis
# and pc_y the index along the second (column) axis.
pc_x, pc_y = np.unravel_index(np.arange(pc_nCells), (pc_rows, pc_cols))

# The weight is the sensory cell's firing rate sampled at [pc_y, pc_x] of its
# firing field. Indexing all sensory cells and all place cells in one step
# gives the same values as looping over every (sc, pc) pair. The result is
# copied into a contiguous (sc_nCells, pc_nCells) array.
sc2pc_weightsMat = np.ascontiguousarray(sc_firing_fields[:, pc_y, pc_x], dtype=float)

# Normalize the weights per sensory cell: divide each row by its sum so that
# the weights leaving each sensory cell add up to 1.
sc2pc_weightsMat /= sc2pc_weightsMat.sum(axis=1, keepdims=True)

In [None]:
# STEP 7: Initialise the weights in the weights matrix of place cells
# Each cell has a weight of 1 with itself and a weight of 0.5 with its
# immediate neighbours (up, down, left, right); all other weights stay 0.

for cell in range(pc_nCells):  # cycle through each cell, 0 .. pc_nCells - 1
    # Position of this cell on the PC sheet. PC_idx numbers the cells row by
    # row, so the position follows directly from the index and there is no
    # need to search the whole sheet for it.
    r, c = np.unravel_index(cell, (pc_rows, pc_cols))

    pc_weightsMat[r, c, cell] = 1  # Set weight with itself to 1

    if r > 0:  # If the row is not the top...
        pc_weightsMat[r - 1, c, cell] = 0.5
    if r < pc_rows - 1:  # If the row is not the bottom (last index is pc_rows - 1)...
        pc_weightsMat[r + 1, c, cell] = 0.5
    if c > 0:  # If the column is not the leftmost...
        pc_weightsMat[r, c - 1, cell] = 0.5
    if c < pc_cols - 1:  # If the column is not the rightmost...
        pc_weightsMat[r, c + 1, cell] = 0.5
        

In [None]:
# Step 8: Set initial firing rate in the environment
# NetAct is a matrix of the same size as the PC sheet; NetAct[r, c] is the
# firing rate (activation) of the cell at row r, column c, i.e. how strongly
# that cell drives the rest of the network on the next update.

# Use a seeded generator so that Restart & Run All reproduces the same
# starting noise. Set RNG_SEED = None for a different random start each run.
RNG_SEED = 0
rng = np.random.default_rng(RNG_SEED)
NetAct = rng.uniform(low=0, high=1, size=(pc_rows, pc_cols))  # uniform noise in [0, 1)

# Alternatively, set the initial activation of every cell to 0.5:
# NetAct = np.ones((pc_rows, pc_cols)) / 2


In [None]:
# Step 9: Calculate the firing of each cell

for _ in range(nIterations):
    # `_` because the loop only needs to run nIterations times; the counter
    # itself is never used.

    # Update rule: the total input to the sheet is the sum over all cells k of
    # pc_weightsMat[:, :, k] * (current activity of cell k). PC_idx numbers
    # the cells row by row, so the activity of cell k is element k of the
    # flattened NetAct, and the whole sum is a single matrix-vector product.
    # The result is a new (pc_rows, pc_cols) array, so NetAct is not modified
    # while the new values are being computed.
    NetInput = pc_weightsMat @ NetAct.ravel()

    # Normalise to values between 0 and 1 so that firing rates stay relative
    # between cells instead of growing on every iteration. If nothing is
    # active the peak is 0; skip the division then to avoid 0/0 = NaN.
    peak_input = np.max(NetInput)
    if peak_input > 0:
        NetInput = NetInput / peak_input

    # Threshold: multiply by the boolean mask (NetInput >= NetThreshold).
    # Values at or above the threshold are kept unchanged; values below it
    # are multiplied by 0, which turns those cells off.
    NetInput = NetInput * (NetInput >= NetThreshold)

    NetAct = NetInput.copy()  # The input to each cell becomes its new firing rate

    create3Dplot(pc_rows,pc_cols,NetAct) # plot - you need to press Enter to progress on to next iteration
       
       