In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

class HopfieldEnergyNet(nn.Module):
    def __init__(self, num_neurons):
        super().__init__()
        self.num_neurons = num_neurons
        self.weights = nn.Parameter(torch.zeros(num_neurons, num_neurons))  # Symmetric, learnable
        self.bias = nn.Parameter(torch.zeros(num_neurons))  # For flexibility

    def energy(self, states):
        # E = -0.5 * s^T W s - b^T s (batch-averaged)
        interaction = torch.bmm(states.unsqueeze(1), torch.matmul(states, self.weights.transpose(-2, -1)).unsqueeze(2)).squeeze()
        bias_term = torch.matmul(states, self.bias)
        return -0.5 * interaction.mean() - bias_term.mean()

    def forward(self, init_states, steps=10, beta=1.0):
        # Relaxation: Iterative update to minimize energy (continuous, stochastic for exploration)
        states = init_states.clone()  # batch x num_neurons
        for _ in range(steps):
            activation = torch.matmul(states, self.weights) + self.bias
            states = torch.tanh(beta * activation)  # Soft update; beta controls sharpness
            states += 0.01 * torch.randn_like(states)  # Noise for escaping local minima
        return states

# Setup for multi-task: Patterns as desired states (e.g., inputs + outputs encoded)
num_neurons = 20  # Pool for clusters
net = HopfieldEnergyNet(num_neurons)
optimizer = optim.Adam(net.parameters(), lr=0.01)

# Toy patterns: Encode XOR/AND as binary states in specific neurons
# E.g., neurons 0-1: inputs, 2: XOR out; 3-4: inputs, 5: AND out; rest intermediaries (init to 0)
xor_patterns = torch.tensor([
    [1., 1., -1., 0., 0., 0.] + [0.] * 14,  # [in1=1, in2=1, out=-1 (0)]
    [1., -1., 1., 0., 0., 0.] + [0.] * 14,   # [1,0,1]
    [-1., 1., 1., 0., 0., 0.] + [0.] * 14,   # [0,1,1]
    [-1., -1., -1., 0., 0., 0.] + [0.] * 14  # [0,0,0]
])  # batch x num_neurons (sign for binary)
and_patterns = torch.tensor([
    [0., 0., 0., 1., 1., 1.] + [0.] * 14,    # [in3=1, in4=1, out=1]
    [0., 0., 0., 1., -1., -1.] + [0.] * 14,  # [1,0,0]
    [0., 0., 0., -1., 1., -1.] + [0.] * 14,  # [0,1,0]
    [0., 0., 0., -1., -1., -1.] + [0.] * 14  # [0,0,0]
])

# Training: Minimize energy for patterns + L1 sparsity to drop trivial connections
for epoch in range(1000):
    optimizer.zero_grad()
    
    # Compute energy for both tasks
    xor_energy = net.energy(xor_patterns)
    and_energy = net.energy(and_patterns)
    task_energy = xor_energy + and_energy
    
    # Sparsity: L1 on off-diagonal weights
    sparsity_loss = 0.1 * (net.weights - torch.diag(torch.diag(net.weights))).abs().sum()
    
    full_loss = task_energy + sparsity_loss
    full_loss.backward()
    optimizer.step()
    
    # Enforce symmetry (project after update)
    net.weights.data = (net.weights.data + net.weights.data.t()) / 2
    net.weights.data.fill_diagonal_(0)  # No self-loops
    
    if epoch % 200 == 0:
        print(f"Epoch {epoch}: Energy {task_energy.item():.4f}, Sparsity {sparsity_loss.item():.4f}")

# Inference: Inject inputs, relax, read outputs
def infer(inputs, input_neurons, output_neurons, steps=10):
    init_states = torch.zeros(1, num_neurons)
    for i, neuron in enumerate(input_neurons):
        init_states[0, neuron] = inputs[i]
    final_states = net(init_states, steps=steps)
    outputs = final_states[0, output_neurons]
    return torch.sign(outputs)  # Binary decode

# Test XOR
xor_test = torch.tensor([1., 1.])  # Should output -1 (0)
xor_out = infer(xor_test, [0, 1], [2])
print("XOR Output:", xor_out.item())

# Test AND
and_test = torch.tensor([1., 1.])  # Should output 1
and_out = infer(and_test, [3, 4], [5])
print("AND Output:", and_out.item())

  cpu = _conversion_method_template(device=torch.device("cpu"))


Epoch 0: Energy -0.0000, Sparsity 0.0000
Epoch 200: Energy -2.9741, Sparsity 0.7896
Epoch 400: Energy -5.9688, Sparsity 1.5875
Epoch 600: Energy -8.9661, Sparsity 2.3864
Epoch 800: Energy -11.9643, Sparsity 3.1857
XOR Output: 1.0
AND Output: -1.0
