In [9]:
#
# FULL INTEGRATED CODE (HTM + HTPC in PyTorch)
#
import torch
import torch.nn as nn
import torch.nn.functional as F

# === Minicolumn Encoder ===
class MinicolumnEncoder(nn.Module):
    def __init__(self, n_columns, cells_per_column):
        super().__init__()
        self.n_columns = n_columns
        self.cells_per_column = cells_per_column
        self.total_cells = n_columns * cells_per_column
        self.recurrent_weights = nn.Parameter(torch.zeros(self.total_cells, self.total_cells))

    def activate(self, input_vector, prev_state):
        input_cells = input_vector.repeat_interleave(self.cells_per_column)
        context_input = torch.matmul(self.recurrent_weights, prev_state)
        combined = input_cells + context_input
        combined = combined.view(self.n_columns, self.cells_per_column)
        active_cells = torch.zeros_like(combined)
        winners = torch.argmax(combined, dim=1)
        active_cells[range(self.n_columns), winners] = 1.0
        return active_cells.view(-1)

    def learn(self, prev_state, current_state, lr=0.01):
        dw = torch.ger(current_state, prev_state)
        self.recurrent_weights.data += lr * dw

# === HTPC Layer ===
class HTPCLayer(nn.Module):
    def __init__(self, size, next_size=None, prev_size=None):
        super().__init__()
        self.size = size
        self.state = torch.zeros(size)
        self.prediction = torch.zeros(size)
        self.error = torch.zeros(size)
        self.ff_weights = nn.Parameter(torch.randn(size, prev_size) * 0.1) if prev_size else None
        self.fb_weights = nn.Parameter(torch.randn(size, next_size) * 0.1) if next_size else None

    def forward(self, bottom_up=None, top_down=None):
        ff_input = F.linear(bottom_up, self.ff_weights) if self.ff_weights is not None and bottom_up is not None else 0
        fb_input = F.linear(top_down, self.fb_weights) if self.fb_weights is not None and top_down is not None else 0
        self.state = torch.tanh(ff_input + fb_input)
        self.prediction = fb_input
        self.error = self.state - self.prediction
        return self.state, self.error

# === HTPC Model with HTM L1 ===
class HTPCModelHTM(nn.Module):
    def __init__(self, n_columns=10, cells_per_column=16, l2_size=10):
        super().__init__()
        self.encoder = MinicolumnEncoder(n_columns, cells_per_column)
        self.n_input = n_columns * cells_per_column
        self.L2 = HTPCLayer(l2_size, prev_size=self.n_input)
        self.L1_fb = nn.Parameter(torch.randn(self.n_input, l2_size) * 0.1)
        self.prev_L1_state = torch.zeros(self.n_input)

    def forward(self, column_input):
        self.current_input_column = column_input.clone()  # Save for learning
        # Step 1: Sparse L1 activation with context
        L1_state = self.encoder.activate(column_input, self.prev_L1_state)

        # Step 2: L2 feedforward
        L2_state, _ = self.L2.forward(bottom_up=L1_state)

        # Step 3: Top-down prediction to L1
        L1_pred = F.linear(L2_state, self.L1_fb)
        L1_error = L1_state - L1_pred

        # Store L1 state for next step context
        self.prev_L1_state = L1_state.clone()

        return {
            "L1_state": L1_state,
            "L2_state": L2_state,
            "L1_error": L1_error,
        }
    
    def learn(self, lr=0.01):
        with torch.no_grad():
            self.L2.ff_weights += lr * torch.ger(self.L2.error, self.prev_L1_state)
            self.L1_fb += lr * torch.ger(-self.prev_L1_state, self.L2.state)

            # Correct use of original column input:
            new_L1_state = self.encoder.activate(self.current_input_column, self.prev_L1_state)
            self.encoder.learn(prev_state=self.prev_L1_state, current_state=new_L1_state, lr=lr)


In [10]:
def make_column_input(indices, n_columns=10):
    vec = torch.zeros(n_columns)
    vec[indices] = 1.0
    return vec

n_columns = 10
cells_per_column = 16  # increase this for more separation

model = HTPCModelHTM(n_columns, cells_per_column, l2_size=10)

# Sequences
A = make_column_input([0], n_columns)
B = make_column_input([1], n_columns)
C = make_column_input([2], n_columns)
X = make_column_input([3], n_columns)
Y = make_column_input([4], n_columns)

# Check shapes match
print("Encoder output:", model.encoder.activate(A, torch.zeros(model.n_input)).shape)
print("L2 weight input dim:", model.L2.ff_weights.shape[1])


seq1 = [A, B, C]
seq2 = [A, X, Y]

# === Training ===
for epoch in range(30):
    for seq in [seq1, seq2]:
        for step in seq:
            out = model.forward(step)
            model.learn(lr=0.01)

# === Testing Disambiguation ===
print("\nTesting Sequence A → B → C")
model.prev_L1_state = torch.zeros_like(model.prev_L1_state)
for step in [A, B, C]:
    out = model.forward(step)
    print("L1 Error:", torch.sum(out['L1_error']).item())

print("\nTesting Sequence A → X → Y")
model.prev_L1_state = torch.zeros_like(model.prev_L1_state)
for step in [A, X, Y]:
    out = model.forward(step)
    print("L1 Error:", torch.sum(out['L1_error']).item())


Encoder output: torch.Size([160])
L2 weight input dim: 160

Testing Sequence A → B → C
L1 Error: 171.73410034179688
L1 Error: 171.73410034179688
L1 Error: 171.73410034179688

Testing Sequence A → X → Y
L1 Error: 171.73410034179688
L1 Error: 171.73410034179688
L1 Error: 171.73410034179688
