In [111]:
import torch as t
from torch.nn import functional as F
from typing import List

%env PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
from torchvision import datasets, transforms

env: PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python


In [112]:
# Hyperparams
integration_step = 0.2
learning_rate = 0.05

In [113]:
# Variable setup
s_in = t.tensor(
    [-1.0, 1.0],
)
s_target = t.tensor(
    [0.0, 1.0],
)

x = [
    s_in,
    t.zeros(
        [5],
    ),
    t.zeros(
        [3],
    ),
    s_target,
]
e = [None, None, None, None]  # type: List[t.Tensor | None]

w = [
    t.rand(
        [2, 5],
    ),
    t.rand(
        [5, 3],
    ),
    t.rand(
        [3, 2],
    ),
]

f = t.relu


def df(x: t.Tensor):
    return t.where(x > 0, t.ones_like(x), t.zeros_like(x))

In [114]:
for _ in range(100):
    for i in range(1, len(x) - 1):
        x[i] = t.zeros_like(x[i])

    for _ in range(32):
        for i in range(0, len(x) - 1):
            # Eq 11
            e[i + 1] = x[i + 1] - f(x[i]).matmul(w[i])

        for i in range(1, len(x) - 1):
            # Eq 12
            dx = integration_step * (-e[i] + t.dot(df(x[i]), t.matmul(w[i], e[i + 1])))
            x[i] = x[i] + dx

    for i in range(0, len(x) - 1):
        dw = learning_rate * t.matmul(e[i + 1].unsqueeze(0).T, f(x[i]).unsqueeze(0))
        w[i] = w[i] + dw.T

    print(x[2])

tensor([0.4418, 0.2911, 0.4763])
tensor([0.4468, 0.2948, 0.4806])
tensor([0.4516, 0.2984, 0.4847])
tensor([0.4563, 0.3018, 0.4888])
tensor([0.4608, 0.3051, 0.4927])
tensor([0.4651, 0.3084, 0.4964])
tensor([0.4692, 0.3114, 0.5000])
tensor([0.4732, 0.3143, 0.5035])
tensor([0.4771, 0.3171, 0.5068])
tensor([0.4808, 0.3197, 0.5100])
tensor([0.4844, 0.3223, 0.5131])
tensor([0.4878, 0.3247, 0.5160])
tensor([0.4911, 0.3270, 0.5189])
tensor([0.4942, 0.3292, 0.5215])
tensor([0.4973, 0.3312, 0.5241])
tensor([0.5002, 0.3332, 0.5266])
tensor([0.5030, 0.3351, 0.5289])
tensor([0.5056, 0.3368, 0.5312])
tensor([0.5082, 0.3385, 0.5333])
tensor([0.5107, 0.3401, 0.5354])
tensor([0.5130, 0.3416, 0.5373])
tensor([0.5153, 0.3431, 0.5392])
tensor([0.5175, 0.3444, 0.5410])
tensor([0.5195, 0.3457, 0.5427])
tensor([0.5215, 0.3469, 0.5443])
tensor([0.5235, 0.3481, 0.5458])
tensor([0.5253, 0.3492, 0.5473])
tensor([0.5271, 0.3502, 0.5487])
tensor([0.5288, 0.3512, 0.5500])
tensor([0.5304, 0.3521, 0.5513])
tensor([0.

tensor([0.5478, 0.3692, 0.5667])
tensor([0.5433, 0.3645, 0.5624])
tensor([0.5442, 0.3652, 0.5631])
tensor([0.5494, 0.3692, 0.5680])
tensor([0.5498, 0.3696, 0.5683])
tensor([0.5501, 0.3699, 0.5687])
tensor([0.5504, 0.3703, 0.5689])
tensor([0.5508, 0.3706, 0.5692])
tensor([0.5511, 0.3709, 0.5695])
tensor([0.5514, 0.3712, 0.5697])
tensor([0.5517, 0.3715, 0.5700])
tensor([0.5465, 0.3670, 0.5650])
tensor([0.5471, 0.3674, 0.5654])
tensor([0.5520, 0.3719, 0.5704])
tensor([0.5522, 0.3722, 0.5706])
tensor([0.5524, 0.3728, 0.5708])
tensor([0.5518, 0.3733, 0.5705])
tensor([0.5520, 0.3737, 0.5706])
tensor([0.5520, 0.3737, 0.5706])
tensor([0.5521, 0.3740, 0.5707])
tensor([0.5523, 0.3743, 0.5709])
tensor([0.5534, 0.3748, 0.5717])
tensor([0.5534, 0.3750, 0.5717])
tensor([0.5534, 0.3752, 0.5718])
tensor([0.5529, 0.3754, 0.5713])
tensor([0.5529, 0.3756, 0.5712])
tensor([0.5550, 0.3761, 0.5734])
tensor([0.5551, 0.3763, 0.5735])
tensor([0.5552, 0.3766, 0.5736])
tensor([0.5552, 0.3768, 0.5736])
tensor([0.

In [115]:
from dataclasses import dataclass


def relu_gradient(x: t.Tensor):
    return t.where(x > 0, t.ones_like(x), t.zeros_like(x))


def sigmoid_gradient(x: t.Tensor):
    return t.sigmoid(x) * (1 - t.sigmoid(x))


@dataclass
class PCNConfig:
    learning_rate = 0.005
    integration_step = 0.01
    n_relaxation_steps = 8


class PCN:
    def __init__(self, cfg: PCNConfig, shape: List[int], device="cuda:0") -> None:
        self.cfg = cfg
        self.n_layers = len(shape)
        self.shape = shape
        assert (
            self.n_layers >= 2
        ), "at least two dims are required (input dimension and output dimension)"

        self.device = t.device(device)
        with self.device:
            in_dim, *hidded, out_dim = shape

            self.activations = [None for _ in range(self.n_layers)]  # type: List[t.Tensor | None]

            self.errors = [None for _ in range(self.n_layers)]  # type: List[t.Tensor | None]

            self.weights = [
                t.nn.init.xavier_normal_(t.empty(from_, to))
                for [from_, to] in zip(shape, shape[1:])
            ]

            self.activation_fn = t.sigmoid
            self.activation_fn_gradient = sigmoid_gradient

    def relaxation_step(self):
        x = self.activations
        e = self.errors
        w = self.weights
        f = self.activation_fn
        df = self.activation_fn_gradient

        # its done once per input
        # for i in range(1, self.n_layers - 1):
        #     x[i] = t.zeros_like(x[i])
        print(x)
        for i in range(0, self.n_layers - 1):
            # Eq 11
            e[i + 1] = x[i + 1] - w[i].matmul(f(x[i]))
            print(e)

        print(e)

        for i in range(1, self.n_layers - 1):
            # Eq 12
            dx = -e[i] + t.dot(df(x[i]), t.matmul(w[i], e[i + 1]))
            x[i] = x[i] + self.cfg.integration_step * dx

    def weight_update(self):
        x = self.activations
        e = self.errors
        w = self.weights
        f = self.activation_fn

        for i in range(0, len(x) - 1):
            dw = t.matmul(e[i + 1].unsqueeze(0).T, f(x[i]).unsqueeze(0))
            w[i] = w[i] + self.cfg.learning_rate * dw.T

    def clear_activaitons(self, batch: int):
        x = self.activations
        with self.device:
            for i in range(self.n_layers):
                x[i] = t.zeros(batch, self.shape[i])

    def forward(self, input_: t.Tensor):
        self.clear_activaitons(input_.shape[0])
        self.activations[0] = input_.to(self.device)

        for _ in range(self.cfg.n_relaxation_steps):
            self.relaxation_step()

        return t.matmul(self.activations[-2], self.weights[-1])

    def training_step(self, input_: t.Tensor, output: t.Tensor, n_relaxations_steps=32):
        self.clear_activaitons(input_.shape[0])
        self.activations[0] = input_.to(self.device)
        self.activations[-1] = output.to(self.device)

        for _ in range(self.cfg.n_relaxation_steps):
            self.relaxation_step()

        loss = F.mse_loss(
            t.matmul(F.softmax(self.activations[-2]), self.weights[-1]), output
        )

        self.weight_update()

        return loss

In [116]:
from torch.utils.data import DataLoader

# Define transformations for normalizing images
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            (0.1307,), (0.3081,)
        ),  # Mean and standard deviation of MNIST
    ]
)

    # Load training and test sets
train_set = datasets.MNIST("data", train=True, download=True, transform=transform)
test_set = datasets.MNIST("data", train=False, download=True, transform=transform)

# Create data loaders for batching
train_loader = DataLoader(train_set, batch_size=1, shuffle=True, pin_memory=True)

In [117]:
model = PCN(PCNConfig(), [784, 32, 16, 10])


for [i, [input_, label]] in enumerate(train_loader):
    input_ = input_.flatten(-3, -1).to(model.device)
    label = F.one_hot(label, num_classes=10).to(model.device)

    print(f"\n step {i}")
    print("loss", model.training_step(input_, label))
    print("e[-2 norm]", model.errors[-1].norm())
    # print("x[-2] norm", model.activations[-2].norm())
    # print("eigenvalue", t.svd(model.weights[-1])[1])

    if i > 10000:
        break


 step 0
[tensor([[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (784x32 and 1x784)

In [None]:
batch = next(iter(test_set))

In [None]:
image, label = batch
image = image.flatten(0, -1) # batch in not handeled rn
model.forward(image)
model.weights[-2]

[None, None, None, None]


TypeError: unsupported operand type(s) for -: 'NoneType' and 'Tensor'

In [None]:
model.activations[-2]

tensor([-0.0042, -0.0041, -0.0041, -0.0041, -0.0042, -0.0041, -0.0042, -0.0042,
        -0.0041, -0.0042, -0.0042, -0.0041, -0.0042, -0.0042, -0.0042, -0.0042],
       device='cuda:0')

In [None]:
t.matmul(model.activations[-2], model.weights[-1])

tensor([-0.0006, -0.0010, -0.0004, -0.0009, -0.0008, -0.0011, -0.0008, -0.0011,
        -0.0007, -0.0009], device='cuda:0')