In [60]:
# ruff: noqa: F722

In [61]:
import torch as t
from torch.nn import functional as F
%env PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
from torchvision import datasets, transforms

from einops import rearrange, reduce, repeat

from icecream import ic

env: PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python


In [62]:
from jaxtyping import Float, jaxtyped
from typeguard import typechecked
from typing import List
typechecker = typechecked

In [63]:
from dataclasses import dataclass


def relu_gradient(x: t.Tensor):
    return t.where(x > 0, t.ones_like(x), t.zeros_like(x))


def sigmoid_gradient(x: t.Tensor):
    return t.sigmoid(x) * (1 - t.sigmoid(x))


@dataclass
class PCNConfig:
    learning_rate = 0.1
    integration_step = 0.01
    n_relaxation_steps = 8


class PCN:
    def __init__(self, cfg: PCNConfig, shape: List[int], device="cuda:0") -> None:
        self.cfg = cfg
        self.n_layers = len(shape)
        self.shape = shape
        assert (
            self.n_layers >= 2
        ), "at least two dims are required (input dimension and output dimension)"

        self.device = t.device(device)
        with self.device:
            in_dim, *hidded, out_dim = shape

            self.activations = [None for _ in range(self.n_layers)]  # type: List[t.Tensor | None]

            self.errors = [None for _ in range(self.n_layers)]  # type: List[t.Tensor | None]

            self.weights = [
                t.nn.init.xavier_normal_(t.empty(from_, to))
                for [from_, to] in zip(shape, shape[1:])
            ]

            self.activation_fn = t.sigmoid
            self.activation_fn_gradient = sigmoid_gradient

    # @jaxtyped(typechecker=typechecker)
    # def error_prediction(
    #     x0: Float[t.Tensor, "batch act0"],
    #     x1: Float[t.Tensor, "batch act1"],
    #     w: Float[t.Tensor, "act0 act1"],
    #     f,
    # ) -> Float[t.Tensor, "batch act1"]:
    #     # Eq 11
    #     return x1 - w.matmul(f(x0))

    # @jaxtyped(typechecker=typechecker)
    # def weight_dynamics_step(
    #     e0: Float[t.Tensor, "batch act0"],
    #     e1: Float[t.Tensor, "batch act1"],
    #     x0: Float[t.Tensor, "batch act0"],
    #     w: Float[t.Tensor, "act0 act1"],
    #     df,
    # ) -> Float[t.Tensor, "batch act0"]:
    #     return -e0 + t.dot(df(x0), t.matmul(w, e1))

    def relaxation_step(self):
        x = self.activations
        e = self.errors
        w = self.weights
        f = self.activation_fn
        df = self.activation_fn_gradient

        for i in range(0, self.n_layers - 1):
            # Eq 11
            e[i + 1] = x[i + 1] - f(x[i]).matmul(w[i])

        for i in range(1, self.n_layers - 1):
            # Eq 12
            dx = -e[i] + t.einsum("bi,bi->bi", df(x[i]), t.matmul(w[i], e[i + 1].T).T)
            x[i] = x[i] + self.cfg.integration_step * dx

    def weight_update(self):
        x = self.activations
        e = self.errors
        w = self.weights
        f = self.activation_fn

        for i in range(0, len(x) - 1):
            dw = t.einsum("Bb,Ba->Bab", e[i + 1], f(x[i])).mean(dim=0)
            # ic(i, (dw * dw).mean(0).mean(0))
            w[i] = w[i] + self.cfg.learning_rate * dw

    def clear_activaitons(self, batch: int):
        x = self.activations
        with self.device:
            for i in range(self.n_layers):
                x[i] = t.zeros(batch, self.shape[i])

    def forward(self, input_: t.Tensor):
        self.clear_activaitons(input_.shape[0])
        self.activations[0] = input_.to(self.device)

        for _ in range(self.cfg.n_relaxation_steps):
            self.relaxation_step()

        return t.matmul(self.activations[-2], self.weights[-1])

    def training_step(self, input_: t.Tensor, output: t.Tensor, n_relaxations_steps=32):
        self.clear_activaitons(input_.shape[0])
        self.activations[0] = input_.to(self.device)
        self.activations[-1] = output.to(self.device)

        for _ in range(self.cfg.n_relaxation_steps):
            self.relaxation_step()

        loss = F.mse_loss(
            F.softmax(t.matmul((self.activations[-2]), self.weights[-1])), output
        )

        self.weight_update()

        return loss

In [64]:
from torch.utils.data import DataLoader

# Define transformations for normalizing images
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            (0.1307,), (0.3081,)
        ),  # Mean and standard deviation of MNIST
    ]
)

    # Load training and test sets
train_set = datasets.MNIST("data", train=True, download=True, transform=transform)
test_set = datasets.MNIST("data", train=False, download=True, transform=transform)

# Create data loaders for batching
train_loader = DataLoader(train_set, batch_size=32, shuffle=True, pin_memory=True)

In [65]:
model = PCN(PCNConfig(), [784, 33, 16, 10])


for [i, [input_, label]] in enumerate(train_loader):
    input_ = input_.flatten(-3, -1).to(model.device)
    label = F.one_hot(label, num_classes=10).to(model.device)

    print(f"\n step {i}")
    print("loss", model.training_step(input_, label))
    # print("e[-2 norm]", model.errors[-1].norm())
    # print("x[-2] norm", model.activations[-2].norm())
    # print("eigenvalue", t.svd(model.weights[-1])[1])

    if i > 10000:
        break


 step 0
loss tensor(0.0896, device='cuda:0')

 step 1
loss tensor(0.0895, device='cuda:0')

 step 2
loss tensor(0.0896, device='cuda:0')



 step 3
loss tensor(0.0896, device='cuda:0')

 step 4
loss tensor(0.0895, device='cuda:0')

 step 5
loss tensor(0.0895, device='cuda:0')

 step 6
loss tensor(0.0896, device='cuda:0')

 step 7
loss tensor(0.0895, device='cuda:0')


  F.softmax(t.matmul((self.activations[-2]), self.weights[-1])), output



 step 8
loss tensor(0.0895, device='cuda:0')

 step 9
loss tensor(0.0896, device='cuda:0')

 step 10
loss tensor(0.0896, device='cuda:0')

 step 11
loss tensor(0.0893, device='cuda:0')

 step 12
loss tensor(0.0895, device='cuda:0')

 step 13
loss tensor(0.0897, device='cuda:0')

 step 14
loss tensor(0.0895, device='cuda:0')

 step 15
loss tensor(0.0895, device='cuda:0')

 step 16
loss tensor(0.0895, device='cuda:0')

 step 17
loss tensor(0.0896, device='cuda:0')

 step 18
loss tensor(0.0895, device='cuda:0')

 step 19
loss tensor(0.0895, device='cuda:0')

 step 20
loss tensor(0.0895, device='cuda:0')

 step 21
loss tensor(0.0895, device='cuda:0')

 step 22
loss tensor(0.0895, device='cuda:0')

 step 23
loss tensor(0.0895, device='cuda:0')

 step 24
loss tensor(0.0895, device='cuda:0')

 step 25
loss tensor(0.0896, device='cuda:0')

 step 26
loss tensor(0.0895, device='cuda:0')

 step 27
loss tensor(0.0894, device='cuda:0')

 step 28
loss tensor(0.0896, device='cuda:0')

 step 29
loss 

KeyboardInterrupt: 

In [None]:
batch = next(iter(test_set))

In [None]:
image, label = batch
image = image.flatten(0, -1) # batch in not handeled rn
model.forward(image)
model.weights[-2]

tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [n

In [None]:
model.activations[-2]

tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0')

In [None]:
t.matmul(model.activations[-2], model.weights[-1])

tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0')

In [69]:
image.detach().cpu().numpy()

array([-0.42421296, -0.42421296, -0.42421296, -0.42421296, -0.42421296,
       -0.42421296, -0.42421296, -0.42421296, -0.42421296, -0.42421296,
       -0.42421296, -0.42421296, -0.42421296, -0.42421296, -0.42421296,
       -0.42421296, -0.42421296, -0.42421296, -0.42421296, -0.42421296,
       -0.42421296, -0.42421296, -0.42421296, -0.42421296, -0.42421296,
       -0.42421296, -0.42421296, -0.42421296, -0.42421296, -0.42421296,
       -0.42421296, -0.42421296, -0.42421296, -0.42421296, -0.42421296,
       -0.42421296, -0.42421296, -0.42421296, -0.42421296, -0.42421296,
       -0.42421296, -0.42421296, -0.42421296, -0.42421296, -0.42421296,
       -0.42421296, -0.42421296, -0.42421296, -0.42421296, -0.42421296,
       -0.42421296, -0.42421296, -0.42421296, -0.42421296, -0.42421296,
       -0.42421296, -0.42421296, -0.42421296, -0.42421296, -0.42421296,
       -0.42421296, -0.42421296, -0.42421296, -0.42421296, -0.42421296,
       -0.42421296, -0.42421296, -0.42421296, -0.42421296, -0.42