In [None]:
import torch
import torch.nn as nn

In [None]:
class MTL(nn.Module):
    """Linear medial-temporal-lobe circuit: EC-in -> CA3 -> CA1 -> EC-out.

    Every forward pass also applies a BTSP-style plasticity update to the
    CA3 -> CA1 weights, driven by an instructive signal computed from the
    direct EC -> CA1 projection. No nonlinearity is applied between layers.

    Args:
        n_entorhinal_in: Width of the entorhinal input layer.
        n_ca3: Number of CA3 units.
        n_ca1: Number of CA1 units.
        n_entorhinal_out: Width of the entorhinal output layer.
        beta_btsp: Mixing rate of the BTSP update
            ``W <- (1 - beta) * W + beta * W'``. Defaults to 0.9.

    NOTE(review): weights are drawn from ``torch.randn`` (unit variance) with
    no fan-in scaling, so activation magnitudes grow with layer width;
    consider scaling by ``1 / sqrt(fan_in)``.
    """

    def __init__(self, n_entorhinal_in, n_ca3, n_ca1, n_entorhinal_out, beta_btsp=0.9):
        super().__init__()

        # Network parameters
        self.beta_btsp = beta_btsp

        # Weight matrices, stored as (fan_in, fan_out) so inputs are right-multiplied.
        self.W_ei_ca3 = nn.Parameter(torch.randn(n_entorhinal_in, n_ca3))
        self.W_ei_ca1 = nn.Parameter(torch.randn(n_entorhinal_in, n_ca1))
        self.W_ca3_ca1 = nn.Parameter(torch.randn(n_ca3, n_ca1))
        self.W_ca1_eo = nn.Parameter(torch.randn(n_ca1, n_entorhinal_out))

    def forward(self, x_ei):
        """Run one pass through the circuit, then apply the BTSP update.

        Args:
            x_ei: Entorhinal input whose last dimension is ``n_entorhinal_in``
                (e.g. shape ``(batch, n_entorhinal_in)``).

        Returns:
            Entorhinal output with the same leading dimensions as ``x_ei`` and
            last dimension ``n_entorhinal_out``.

        The output is computed with the CA3 -> CA1 weights as they were
        *before* this call; the plasticity update only affects later calls.
        """
        # EC -> CA3
        x_ca3 = torch.matmul(x_ei, self.W_ei_ca3)

        # CA3 -> CA1. Multiply by a clone so autograd saves a private copy of
        # the weights: the in-place BTSP update below would otherwise bump the
        # saved tensor's version and make backward() raise.
        x_ca1 = torch.matmul(x_ca3, self.W_ca3_ca1.clone())

        # Instructive signal: direct EC -> CA1 drive
        IS = torch.matmul(x_ei, self.W_ei_ca1)

        # Update CA3 -> CA1 connectivity via BTSP (in place, outside autograd).
        self._btsp_update(x_ca3, IS)

        # CA1 -> EC output
        x_eo = torch.matmul(x_ca1, self.W_ca1_eo)
        return x_eo

    @torch.no_grad()
    def _btsp_update(self, x_ca3, IS):
        """Blend the batch-averaged CA3 x IS outer product into ``W_ca3_ca1``.

        The update is applied in place so the registered Parameter object (and
        any optimizer reference to it) is preserved and its shape stays
        ``(n_ca3, n_ca1)`` regardless of batch size. It is a plasticity rule,
        not a gradient step, hence ``no_grad``.
        """
        # Flatten any leading dims into one batch axis (1-D input -> batch of 1).
        pre = x_ca3.reshape(-1, x_ca3.shape[-1])
        post = IS.reshape(-1, IS.shape[-1])
        n_samples = pre.shape[0]
        if n_samples == 0:
            return  # nothing to learn from; avoid dividing by zero
        # Mean over the batch; identical to the single outer product when batch == 1.
        W_ca3_ca1_prime = torch.einsum('bm,bn->mn', pre, post) / n_samples
        self.W_ca3_ca1.copy_(
            (1 - self.beta_btsp) * self.W_ca3_ca1 + self.beta_btsp * W_ca3_ca1_prime
        )

In [None]:
# Example usage. Seed the global torch RNG first so the random weights (and
# the random input in the next cell) are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

n_entorhinal_in = 100
n_ca3 = 200
n_ca1 = 150
n_entorhinal_out = 100

network = MTL(n_entorhinal_in, n_ca3, n_ca1, n_entorhinal_out)

In [None]:
input_data = torch.randn(1, n_entorhinal_in)  # Batch size of 1 for simplicity
output_data = network(input_data)

# Output must keep the input's batch dimension; re-running this cell must not
# change the shape (guards against the plasticity update altering weight shapes).
assert output_data.shape == (1, n_entorhinal_out), output_data.shape

# Show the shape and a small slice instead of dumping the whole tensor.
output_data.shape, output_data[0, :5]

tensor([[[ -828109.2500, -2431570.2500,  2248487.0000, -3377903.5000,
          -7256623.0000, -1265284.0000,   705316.6250,  3768275.2500,
          -3346849.7500, -1362203.6250,  3239516.5000,   372517.4375,
            222799.2031,  4753923.0000,  1697995.1250, -2359737.0000,
          -3376488.0000, -2174224.7500, -1635557.6250, -1312488.0000,
           7309594.0000, -2257230.2500,  3105848.0000,  1698194.2500,
           1783592.1250,  -218549.5781,  -765356.5625,  -590373.2500,
            258794.9688, -2569851.2500,  -645466.6250,   865262.0625,
           -332693.3438,  3538089.0000,  1892420.6250, -3281589.5000,
            755682.1250,   599740.3125,  1957883.8750, -1949523.7500,
            -90988.0469,  2998818.2500, -3008883.2500, -5017602.0000,
            957276.1875, -3724553.5000,  1044929.2500,  3254864.2500,
          -1077123.2500, -2731304.7500, -2578856.5000,  2773974.7500,
           -493858.9375,  3154586.0000,  1261040.5000,  -511922.9375,
          -2640615.7