In [29]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

In [30]:
# Pick the accelerator when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device  # bare last expression so the notebook displays the chosen device

device(type='cuda')

### Architecture

In [31]:
class CustomNeuron(nn.Module):
    """Single-step recurrent cell built from two ``MimicLayer`` projections.

    At every step the current input is concatenated with the previous hidden
    state along dim 1. That joint tensor feeds two independent layers: one
    produces the next hidden state (squashed by a sigmoid), the other produces
    the step output (left linear).

    Args:
        input_size: Number of features in each input row.
        hidden_size: Number of features in the hidden state.
        output_size: Number of features in each output row.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(CustomNeuron, self).__init__()
        self.hidden_size = hidden_size
        # Both projections read the same [input, hidden] concatenation.
        self.in2hidden = MimicLayer(input_size + hidden_size, hidden_size)
        self.in2output = MimicLayer(input_size + hidden_size, output_size)

    def forward(self, x, hidden_state):
        """Advance the cell by one step.

        Args:
            x: Input tensor of shape ``(batch, input_size)``.
            hidden_state: Previous hidden state of shape
                ``(batch, hidden_size)``, e.g. from :meth:`init_hidden`.

        Returns:
            Tuple ``(output, hidden)`` where ``output`` has shape
            ``(batch, output_size)`` and ``hidden`` is the new hidden state
            of shape ``(batch, hidden_size)``.
        """
        combined = torch.cat((x, hidden_state), 1)
        hidden = torch.sigmoid(self.in2hidden(combined))
        output = self.in2output(combined)
        return output, hidden

    def init_hidden(self, batch_size):
        """Return an all-zero initial hidden state of shape ``(batch_size, hidden_size)``.

        The state is allocated with the same device and dtype as the module's
        parameters, so it can be concatenated with inputs in ``forward`` after
        the model has been moved with ``.to(device)`` or cast with
        ``.double()``/``.half()``.
        """
        reference = next(self.parameters(), None)
        if reference is None:
            # No parameters to copy device/dtype from: use PyTorch defaults.
            return torch.zeros(batch_size, self.hidden_size)
        return reference.new_zeros(batch_size, self.hidden_size)


In [None]:
import torch
import torch.nn as nn
import math

class MimicLayer(nn.Module):
    """Linear layer whose weight matrix is scaled by one fixed random factor.

    The layer computes ``x @ (light_weights * dendrite_distance).T + bias``.
    ``light_weights`` and ``bias`` are learnable parameters, while
    ``dendrite_distance`` is a single scalar drawn once at construction time
    from ``U(dd_min, dd_max)`` and stored as a buffer: it is saved with the
    module and moves with ``.to(...)`` but is never updated by an optimizer.

    Args:
        size_in: Number of input features.
        size_out: Number of output features.
        dd_min: Lower bound of the uniform range for ``dendrite_distance``.
        dd_max: Upper bound of the uniform range for ``dendrite_distance``.
    """

    def __init__(self, size_in, size_out, dd_min=0.1, dd_max=1.0):
        super().__init__()
        self.size_in, self.size_out = size_in, size_out

        self.light_weights = nn.Parameter(torch.empty(size_out, size_in))
        self.bias = nn.Parameter(torch.empty(size_out))

        dendrite_tensor = torch.empty(1)
        nn.init.uniform_(dendrite_tensor, dd_min, dd_max)
        self.register_buffer("dendrite_distance", dendrite_tensor)  # fixed, not learnable

        # Weight and bias follow the initialisation scheme nn.Linear uses.
        # A zero-element weight has nothing to initialise, so skip it.
        if self.light_weights.numel() > 0:
            nn.init.kaiming_uniform_(self.light_weights, a=math.sqrt(5))

        # For a (size_out, size_in) matrix the fan-in is simply size_in.
        # Guard the zero case so a layer with no inputs does not divide by zero.
        fan_in = size_in
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Apply the scaled affine map to ``x``.

        Args:
            x: Tensor whose last dimension has ``size_in`` features.

        Returns:
            Tensor with ``size_out`` features in its last dimension.
        """
        effective_weights = self.light_weights * self.dendrite_distance
        return torch.matmul(x, effective_weights.T) + self.bias


In [33]:
from torchview import draw_graph
# Trace a small CustomNeuron (2 inputs, 8 hidden units, 1 output) on a single
# random sample plus a zero hidden state, then export the graph as a PNG.
model = CustomNeuron(input_size=2, hidden_size=8, output_size=1)
hidden = model.init_hidden(batch_size=1)
graph = draw_graph(model, input_data=[torch.randn(1, 2), hidden])
graph.visual_graph.render(filename="model_arch", format="png")


'model_arch.png'

In [34]:
from torchsummary import summary
# Rebuild the cell on the selected device and print a per-layer summary.
# The two shape tuples describe the input and the hidden state fed to forward().
model = CustomNeuron(2, 8, 1)
model = model.to(device)
summary(model, input_size=[(2,), (8,)], device=str(device))



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
        MimicLayer-1                    [-1, 8]               8
        MimicLayer-2                    [-1, 1]               1
Total params: 9
Trainable params: 0
Non-trainable params: 9
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------


### Training run to verify that it works

In [None]:
# Toy task: at every time step the cell should output the sum of its two inputs.
seq_len = 5
input_dim = 2
hidden_dim = 8
output_dim = 1
batch_size = 16

# Seed the global RNG so weight initialisation and the sampled batches, and
# therefore the printed losses, repeat on every Restart & Run All.
torch.manual_seed(0)

model = CustomNeuron(input_dim, hidden_dim, output_dim)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(200):
    # Fresh random batch each epoch, shape (seq_len, batch_size, input_dim).
    x_seq = torch.rand(seq_len, batch_size, input_dim)
    y_seq = x_seq.sum(dim=2, keepdim=True)  # target: sum of two inputs

    hidden = model.init_hidden(batch_size)
    loss = 0

    # Unroll the cell over the sequence and accumulate the per-step MSE so a
    # single backward pass covers every time step.
    for t in range(seq_len):
        out, hidden = model(x_seq[t], hidden)
        loss += criterion(out, y_seq[t])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# Single-sample sanity check; no gradients are needed for inference.
test_x = torch.tensor([[0.2, 0.6]])
with torch.no_grad():
    hidden = model.init_hidden(1)
    out, _ = model(test_x, hidden)
print("Test input:", test_x)
print("Predicted sum:", out.item())
print("Actual sum:", test_x.sum().item())


Epoch 20, Loss: 1.7533
Epoch 40, Loss: 1.0984
Epoch 60, Loss: 1.0990
Epoch 80, Loss: 0.8414
Epoch 100, Loss: 0.5220
Epoch 120, Loss: 0.4730
Epoch 140, Loss: 0.5377
Epoch 160, Loss: 0.5247
Epoch 180, Loss: 0.4386
Epoch 200, Loss: 0.4892
Test input: tensor([[0.2000, 0.6000]])
Predicted sum: 0.7605530023574829
Actual sum: 0.800000011920929
