In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


In [2]:
class HTPCLayer(nn.Module):
    """One level of the predictive-coding hierarchy.

    The layer owns up to two weight matrices:

    * ``ff_weights`` (shape ``(size, prev_size)``): maps the layer below onto
      this one. It is only created when ``prev_size`` is given.
    * ``fb_weights`` (shape ``(size, next_size)``): maps the layer above onto
      this one. It is only created when ``next_size`` is given.

    The activity from the most recent ``forward`` call is cached in the plain
    tensor attributes ``state``, ``prediction`` and ``error``. These are not
    parameters or buffers.

    Args:
        size: number of units in this layer.
        next_size: size of the layer above (enables ``fb_weights``).
        prev_size: size of the layer below (enables ``ff_weights``).
    """

    def __init__(self, size, next_size=None, prev_size=None):
        super().__init__()
        self.size = size
        self.state = torch.zeros(size)
        self.prediction = torch.zeros(size)
        self.error = torch.zeros(size)

        # Feedforward weights (prev -> this)
        if prev_size:
            self.ff_weights = nn.Parameter(torch.randn(size, prev_size) * 0.1)
        else:
            self.ff_weights = None

        # Feedback weights (next -> this)
        if next_size:
            self.fb_weights = nn.Parameter(torch.randn(size, next_size) * 0.1)
        else:
            self.fb_weights = None

    def _zeros(self):
        """Return a length-``size`` zero vector with the dtype/device of this layer's weights."""
        # Using the weights as the template keeps the fallback input compatible
        # after .to(device) / .double(); a bare torch.zeros would stay on CPU/float32.
        for weights in (self.ff_weights, self.fb_weights):
            if weights is not None:
                return weights.new_zeros(self.size)
        return torch.zeros(self.size)

    def forward(self, bottom_up=None, top_down=None):
        """Update the layer from bottom-up and/or top-down activity.

        Args:
            bottom_up: activity of the layer below. It is ignored when the
                layer has no ``ff_weights``.
            top_down: activity of the layer above. It is ignored when the
                layer has no ``fb_weights``.

        Returns:
            ``(state, error)``, where ``state = tanh(ff_input + fb_input)`` and
            ``error = state - prediction``, with ``prediction = fb_input``.
            A missing contribution counts as a zero vector.
        """
        # Bottom-up input
        if bottom_up is not None and self.ff_weights is not None:
            ff_input = F.linear(bottom_up, self.ff_weights)
        else:
            ff_input = self._zeros()

        # Top-down prediction
        if top_down is not None and self.fb_weights is not None:
            fb_input = F.linear(top_down, self.fb_weights)
        else:
            fb_input = self._zeros()

        # Combine and update state
        self.state = torch.tanh(ff_input + fb_input)
        self.prediction = fb_input
        self.error = self.state - self.prediction

        return self.state, self.error


In [3]:
class HTPCModel(nn.Module):
    """Three-level predictive-coding hierarchy: L1 (input) -> L2 -> L3.

    ``forward`` first runs a bottom-up sweep that sets each layer's state. It
    then runs a top-down sweep in which L3 predicts L2 and L2 predicts L1
    through the feedback weights; the bottom-up states are left intact.
    ``learn`` applies outer-product weight updates using the states and errors
    cached by the most recent ``forward`` call.

    Args:
        input_size: size of L1 (the input vector).
        l2_size: size of L2.
        l3_size: size of L3.
    """

    def __init__(self, input_size=30, l2_size=10, l3_size=5):
        super().__init__()
        self.L1 = HTPCLayer(input_size, next_size=l2_size)
        self.L2 = HTPCLayer(l2_size, next_size=l3_size, prev_size=input_size)
        self.L3 = HTPCLayer(l3_size, prev_size=l2_size)

        # Plain list for ordered iteration. The layers are already registered
        # as submodules through the attribute assignments above.
        self.layers = [self.L1, self.L2, self.L3]

    @staticmethod
    def _predict_down(lower, upper):
        """Predict ``lower`` from ``upper.state`` and store the prediction error.

        Sets ``lower.prediction = fb_weights @ upper.state`` and
        ``lower.error = lower.state - lower.prediction``. It deliberately does
        NOT modify ``lower.state``.
        """
        lower.prediction = F.linear(upper.state, lower.fb_weights)
        lower.error = lower.state - lower.prediction

    def forward(self, input_vec):
        """Run one bottom-up sweep followed by one top-down sweep.

        Args:
            input_vec: activity clamped onto L1. Its last dimension must equal
                ``input_size``.

        Returns:
            A dict with the layer states ``"L1"`` (a copy of the input), ``"L2"``
            and ``"L3"``, plus the top-down prediction errors ``"error_L1"`` and
            ``"error_L2"``.
        """
        # L1 receives input directly
        self.L1.state = input_vec.clone()

        # Bottom-up pass: each state is driven by the layer below.
        self.L2.forward(bottom_up=self.L1.state)
        self.L3.forward(bottom_up=self.L2.state)

        # Top-down pass: compute predictions/errors without overwriting states.
        # Re-running HTPCLayer.forward(top_down=...) here would replace each
        # state with tanh(prediction). For L1 that would discard the input and
        # make error_L1 == tanh(p) - p, independent of the data.
        self._predict_down(self.L2, self.L3)
        self._predict_down(self.L1, self.L2)

        return {
            "L1": self.L1.state,
            "L2": self.L2.state,
            "L3": self.L3.state,
            "error_L1": self.L1.error,
            "error_L2": self.L2.error
        }

    def learn(self, lr=0.01):
        """Apply one weight update from the cached states/errors; call after ``forward``.

        The updates are:

        * ``L2.ff_weights += lr * outer(L2.error, L1.state)``
        * ``L3.ff_weights += lr * outer(L3.error, L2.state)``. L3 has no feedback
          weights, so its prediction is zero and its error equals its state.
        * ``L1.fb_weights += lr * outer(L1.error, L2.state)``. This is a delta
          rule that reduces the L2 -> L1 reconstruction error.

        L2's feedback weights are not updated. ``torch.outer`` requires 1-D
        tensors, so this assumes unbatched states.

        Args:
            lr: learning rate.
        """
        with torch.no_grad():
            # L2 <- L1
            self.L2.ff_weights += lr * torch.outer(self.L2.error, self.L1.state)

            # L3 <- L2
            self.L3.ff_weights += lr * torch.outer(self.L3.error, self.L2.state)

            # Feedback (optional): L2 -> L1
            self.L1.fb_weights += lr * torch.outer(self.L1.error, self.L2.state)


In [4]:
# Configuration. A fixed seed makes the torch.randn weight initialisation in
# HTPCLayer, and therefore the whole training run, reproducible.
SEED = 0
INPUT_SIZE = 30
N_EPOCHS = 20
LR = 0.01
torch.manual_seed(SEED)

# Create model
model = HTPCModel(input_size=INPUT_SIZE, l2_size=10, l3_size=5)


def block_pattern(start, stop, size=INPUT_SIZE):
    """Return a length-``size`` float vector with ones at indices [start, stop) and zeros elsewhere."""
    pattern = torch.zeros(size)
    pattern[start:stop] = 1.0
    return pattern


# Simulate sequence A → B → C: three non-overlapping 5-unit patterns
A = block_pattern(0, 5)
B = block_pattern(10, 15)
C = block_pattern(20, 25)

sequence = [A, B, C]

# Train: one forward sweep, then one weight update per pattern
for epoch in range(N_EPOCHS):
    for step in sequence:
        output = model.forward(step)
        model.learn(lr=LR)
