In [1]:
#
# Overlapping sequences
#
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Define HTPC Layer
class HTPCLayer(nn.Module):
    def __init__(self, size, next_size=None, prev_size=None):
        super().__init__()
        self.size = size
        self.state = torch.zeros(size)
        self.prediction = torch.zeros(size)
        self.error = torch.zeros(size)

        self.ff_weights = nn.Parameter(torch.randn(size, prev_size) * 0.1) if prev_size else None
        self.fb_weights = nn.Parameter(torch.randn(size, next_size) * 0.1) if next_size else None

    def forward(self, bottom_up=None, top_down=None):
        ff_input = F.linear(bottom_up, self.ff_weights) if self.ff_weights is not None and bottom_up is not None else 0
        fb_input = F.linear(top_down, self.fb_weights) if self.fb_weights is not None and top_down is not None else 0

        self.state = torch.tanh(ff_input + fb_input)
        self.prediction = fb_input
        self.error = self.state - self.prediction

        return self.state, self.error

# Define HTPC Model
class HTPCModel(nn.Module):
    def __init__(self, input_size=30, l2_size=10, l3_size=5):
        super().__init__()
        self.L1 = HTPCLayer(input_size, next_size=l2_size)
        self.L2 = HTPCLayer(l2_size, next_size=l3_size, prev_size=input_size)
        self.L3 = HTPCLayer(l3_size, prev_size=l2_size)

    def forward(self, input_vec):
        self.L1.state = input_vec.clone()
        l2_state, _ = self.L2.forward(bottom_up=self.L1.state)
        l3_state, _ = self.L3.forward(bottom_up=self.L2.state)

        self.L2.forward(top_down=self.L3.state)
        self.L1.forward(top_down=self.L2.state)

        return {
            "L1": self.L1.state,
            "L2": self.L2.state,
            "L3": self.L3.state,
            "error_L1": self.L1.error.clone(),
            "error_L2": self.L2.error.clone()
        }

    def learn(self, lr=0.01):
        with torch.no_grad():
            self.L2.ff_weights += lr * torch.ger(self.L2.error, self.L1.state)
            self.L3.ff_weights += lr * torch.ger(self.L3.error, self.L2.state)
            self.L1.fb_weights += lr * torch.ger(self.L1.error, self.L2.state)

# Helper to create sparse input vectors
def make_input(indices, size=30):
    vec = torch.zeros(size)
    vec[indices] = 1.0
    return vec

# Create model
model = HTPCModel()

# Define sequences
A = make_input([0, 1, 2])
B = make_input([10, 11])
C = make_input([20])
X = make_input([12, 13])
Y = make_input([21])

seq1 = [A, B, C]  # Sequence 1
seq2 = [A, X, Y]  # Sequence 2

# === Train on both sequences ===
for epoch in range(30):
    for seq in [seq1, seq2]:
        for step in seq:
            model.forward(step)
            model.learn(lr=0.01)

# === Test disambiguation ===
print("\n=== Test: Sequence 1 (A→B→C) ===")
model.forward(A)
model.forward(B)
output1 = model.forward(C)
print("Prediction Error (L1):", torch.sum(output1["error_L1"]).item())

print("\n=== Test: Sequence 2 (A→X→Y) ===")
model.forward(A)
model.forward(X)
output2 = model.forward(Y)
print("Prediction Error (L1):", torch.sum(output2["error_L1"]).item())

print("\n=== Ambiguous Case: A only, then B and X")
model.forward(A)
out_B = model.forward(B)
out_X = model.forward(X)
print("Error when continuing with B:", torch.sum(out_B["error_L1"]).item())
print("Error when continuing with X:", torch.sum(out_X["error_L1"]).item())



=== Test: Sequence 1 (A→B→C) ===
Prediction Error (L1): 3.1257513910532e-08

=== Test: Sequence 2 (A→X→Y) ===
Prediction Error (L1): 2.264278009533882e-08

=== Ambiguous Case: A only, then B and X
Error when continuing with B: -4.94212144985795e-07
Error when continuing with X: -5.864421837031841e-08
