In [1]:
import torch
import torch.nn as nn
import torch.utils.tensorboard as tb

import shutil
import tempfile
import numpy as np
import sys

sys.path.append("..")
from datasets import load_data, MNIST
import pcn_kernels


class pcnpass(torch.autograd.Function):
    """
    Autograd bridge to the compiled ``pcn_kernels`` extension.

    Written in the ``setup_context`` style: ``forward`` receives no ``ctx``;
    ``setup_context`` stashes all four inputs, and ``backward`` hands them back
    to ``pcn_kernels.backward`` together with the (made contiguous) upstream
    gradient. One gradient is returned per forward input, in input order.
    """

    @staticmethod
    def forward(c: torch.Tensor, l: torch.Tensor, lnext: torch.Tensor, b: torch.Tensor):
        # As called from PCN.forward: c = incoming activations, l / lnext =
        # this layer's and the next layer's learnable tensors, b = next-layer
        # bias. The kernel's exact math is not visible here.
        return pcn_kernels.forward(c, l, lnext, b)

    @staticmethod
    def setup_context(ctx, inputs, _):
        # The forward output is not needed for the backward pass; only the inputs.
        ctx.save_for_backward(*inputs[:4]) if len(inputs) == 4 else ctx.save_for_backward(*inputs)

    @staticmethod
    def backward(ctx, grad_cnext):
        c, l, lnext, b = ctx.saved_tensors
        # The compiled kernel is handed a contiguous gradient buffer.
        upstream = grad_cnext.contiguous()
        grads = pcn_kernels.backward(upstream, c, l, lnext, b)
        grad_c, grad_l, grad_lnext, grad_b = grads
        return grad_c, grad_l, grad_lnext, grad_b

class PCN(nn.Module):
    """
    Network whose layers are learnable tensors, each pass running in the
    custom ``pcn_kernels`` extension (through ``pcnpass``).

    Args:
        layers: number of rows for each layer tensor, input first; at least two.
            ``self.layers[k]`` has shape ``(layers[k], dimensions)`` and is
            initialised uniformly in [-1, 1).
        dimensions: number of columns of every layer tensor.

    ``self.layers_bias[k]`` has shape ``(layers[k],)`` and is initialised
    uniformly in [-0.1, 0.1). Only biases of layers 1..N-1 are used by
    ``forward``; ``layers_bias[0]`` is kept so ``layers`` and ``layers_bias``
    have equal length (``pcnSGD`` splits the parameter list in half).

    ``self.weight_transform`` is built but not used by ``forward``; presumably
    the kernel implements the equivalent transform — TODO confirm.

    Raises:
        ValueError: if fewer than two layers are given.
    """

    def __init__(
        self,
        layers: list[int],
        dimensions: int = 20,
    ):
        super().__init__()
        self.weight_transform = tri(0.1, 1)
        if not len(layers) >= 2:
            raise ValueError("At least 2 layers are required")

        # Draw all position tensors first, then all biases, so the RNG
        # consumption order is fixed.
        positions = []
        for size in layers:
            positions.append(nn.Parameter(torch.rand(size, dimensions) * 2 - 1))
        self.layers = nn.ParameterList(positions)

        biases = []
        for size in layers:
            biases.append(nn.Parameter((torch.rand(size) * 2 - 1) * 0.1))
        self.layers_bias = nn.ParameterList(biases)

    def forward(self, x: torch.Tensor):
        n_pairs = len(self.layers) - 1
        out = x
        for idx in range(n_pairs):
            out = pcnpass.apply(
                out, self.layers[idx], self.layers[idx + 1], self.layers_bias[idx + 1]
            )
            # ReLU between layers, but not after the final (output) layer.
            if idx < n_pairs - 1:
                out = torch.relu(out)
        return out

def tri(period: float, amplitude: float):
    """
    Build a zero-mean triangle-wave transform.

    NOTE: despite the parameter names, the returned wave repeats every
    ``2 * period`` and oscillates between ``-amplitude / 2`` (at x = 0,
    2*period, ...) and ``+amplitude / 2`` (at x = period, 3*period, ...),
    varying linearly in between. Works on Python floats and torch tensors.
    """
    scale = amplitude / period
    half = period / 2

    def triangle_wave_transform(x: torch.Tensor):
        phase = x % (2 * period)  # position within one repetition, in [0, 2*period)
        distance_from_peak = abs(phase - period)
        return scale * (period - distance_from_peak - half)

    return triangle_wave_transform

class PCN_Old(nn.Module):
    """
    Pure-PyTorch reference version of ``PCN``.

    Each layer is a learnable tensor ``self.layers[k]`` of shape
    ``(layers[k], dimensions)`` (uniform in [-1, 1)). The weight matrix between
    layers k and k+1 is ``tri(0.1, 1)`` applied to the pairwise Euclidean
    distances (``torch.cdist``) between their rows, divided by
    ``sqrt(layers[k])``. Biases ``self.layers_bias[k]`` have shape
    ``(layers[k], 1)`` (uniform in [-0.1, 0.1)) and are transposed to a row so
    they broadcast over the batch rows; ``layers_bias[0]`` is unused.

    The last dimension of the input must equal ``layers[0]`` (required by the
    matrix product in ``forward``).

    Raises:
        ValueError: if fewer than two layers are given.
    """

    def __init__(
        self,
        layers: list[int],
        dimensions: int = 20,
    ):
        super().__init__()
        self.weight_transform = tri(0.1, 1)
        if not len(layers) >= 2:
            raise ValueError("At least 2 layers are required")

        positions = []
        for size in layers:
            positions.append(nn.Parameter(torch.rand(size, dimensions) * 2 - 1))
        self.layers = nn.ParameterList(positions)

        biases = []
        for size in layers:
            biases.append(nn.Parameter((torch.rand(size, 1) * 2 - 1) * 0.1))
        self.layers_bias = nn.ParameterList(biases)

    def forward(self, x: torch.Tensor):
        n_pairs = len(self.layers) - 1
        out = x
        for idx in range(n_pairs):
            src = self.layers[idx]
            dst = self.layers[idx + 1]
            weights = self.weight_transform(torch.cdist(src, dst)) / np.sqrt(src.shape[0])
            out = out @ weights + self.layers_bias[idx + 1].T
            # ReLU between layers, but not after the final (output) layer.
            if idx < n_pairs - 1:
                out = torch.relu(out)
        return out

class FCN(nn.Module):
    """
    Plain fully-connected ReLU network used as a baseline.

    Args:
        n_in: size of the input feature dimension.
        n_out: size of the output (logit) dimension.
        hidden: widths of the hidden layers, in order. May be empty, in which
            case the network is a single ``nn.Linear(n_in, n_out)``.
    """

    def __init__(
        self,
        n_in: int,
        n_out: int,
        hidden: list[int],
    ):
        super().__init__()

        modules = []
        width = n_in
        for h in hidden:
            modules.append(nn.Linear(width, h))
            modules.append(nn.ReLU())
            width = h

        # Use the running width rather than the loop variable: with an empty
        # `hidden` the loop variable would never be bound.
        modules.append(nn.Linear(width, n_out))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)

class pcnSGD(torch.optim.Optimizer):
    """
    Plain SGD with size-dependent learning-rate scaling for PCN-style models.

    The parameters of ``model`` are split in half in registration order: the
    first half are treated as position tensors (e.g. ``PCN.layers``), the second
    half as biases (e.g. ``PCN.layers_bias``). This assumes one bias tensor per
    position tensor, positions registered first — true for ``PCN`` and
    ``PCN_Old``, NOT for arbitrary modules.

    A position tensor of shape ``(n, ...)`` is updated as
        ``p -= lr * scale(n) * grad``
    with
        ``opp=None``   -> ``scale(n) = n``
        ``opp="log"``  -> ``scale(n) = n / log2(n)`` (log2 clamped to >= 1 so
                          a 1-row tensor does not divide by zero)
        ``opp="sqrt"`` -> ``scale(n) = sqrt(n)``
    Biases are updated as ``b -= lr * bias_lr_scale * grad``.
    Parameters whose ``.grad`` is None are skipped.

    Args:
        model: module whose parameters are optimised.
        lr: base learning rate.
        opp: position-scaling rule: None, "log" or "sqrt".
        bias_lr_scale: multiplier applied to ``lr`` for bias tensors.

    Raises:
        ValueError: if ``opp`` is not one of None, "log", "sqrt".
    """

    _OPPS = (None, "log", "sqrt")

    def __init__(self, model: nn.Module, lr=1e-3, opp=None, bias_lr_scale: float = 1e5):
        # Reject unknown rules up front; previously they silently disabled all
        # position updates.
        if opp not in self._OPPS:
            raise ValueError(f"Unknown opp {opp!r}; expected one of {self._OPPS}")
        self.opp = opp
        self.model = model
        self.bias_lr_scale = bias_lr_scale
        defaults = dict(lr=lr)
        super().__init__(model.parameters(), defaults)

    @staticmethod
    def _position_scale(n: int, opp):
        """Learning-rate multiplier for a position tensor with ``n`` rows."""
        if opp is None:
            return n
        if opp == "log":
            return n / max(np.log2(n), 1.0)
        if opp == "sqrt":
            return np.sqrt(n)
        raise ValueError(f"Unknown opp {opp!r}; expected one of {pcnSGD._OPPS}")

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one update; returns the closure's loss if a closure is given."""
        loss = None
        if closure is not None:
            # The closure re-evaluates the model and calls backward, which
            # needs autograd even though the update itself runs under no_grad.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params = group["params"]
            half = len(params) // 2
            positions, biases = params[:half], params[half:]
            lr = group["lr"]

            for pos in positions:
                if pos.grad is None:
                    continue
                scale = self._position_scale(pos.shape[0], self.opp)
                pos.add_(pos.grad, alpha=-lr * scale)

            for bias in biases:
                if bias.grad is None:
                    continue
                bias.add_(bias.grad, alpha=-lr * self.bias_lr_scale)

        return loss

In [2]:
# Load the MNIST train and test splits via the project's `datasets` helpers
# (train first, then test).
train, test = (load_data(MNIST(split)) for split in ("train", "test"))

In [6]:
shutil.rmtree(path="log_dir", ignore_errors=True)  # wipe previous TensorBoard runs; missing dir is fine

In [7]:

# Run configuration.
log_dir = "log_dir"  # same directory cleared above and served by TensorBoard
learn_rate = 0.01  # NOTE(review): unused below; each optimizer sets its own lr
epochs = 100

cpu = torch.device("cpu")
gpu = torch.device("cuda:0")
device = gpu if torch.cuda.is_available() else cpu

# Models trained side by side on identical batches; the name is the
# TensorBoard run name.
models: list[tuple[str, nn.Module]] = [
    ("pcn512_16D_Old", PCN_Old([28 * 28, 512, 10], dimensions=16).to(device)),
    ("pcn512_16D_New", PCN([28 * 28, 512, 10], dimensions=16).to(device)),
    ("fcn_100_100_100", FCN(28 * 28, 10, [100, 100, 100]).to(device)),
]

loggers = [tb.SummaryWriter(log_dir + "/" + name) for name, _ in models]
# One optimizer per model, in the same order as `models`.
optimizers = [
    pcnSGD(models[0][1], lr=0.00002, opp="log"),
    pcnSGD(models[1][1], lr=0.00002, opp="log"),
    torch.optim.SGD(models[2][1].parameters(), lr=0.01),
]

loss = nn.CrossEntropyLoss()

step = 0  # global batch counter; shared x-axis for every model's loss curve
for epoch in range(epochs):
    for x, y in train:
        # presumably x is a batch of flattened 28*28 images and y class ids — TODO confirm
        x = x.float().to(device)
        y = y.to(device)
        for (_, model), optimizer, logger in zip(models, optimizers, loggers):
            pred = model(x).float()
            l = loss(pred, y)

            l.backward()

            optimizer.step()
            optimizer.zero_grad()

            # Log a plain Python float rather than a graph-attached tensor.
            logger.add_scalar("loss", l.item(), step)

        step += 1

    # Test-set accuracy once per epoch.
    with torch.no_grad():
        for (_, model), logger in zip(models, loggers):
            correct = 0
            total = 0
            for x, y in test:
                x = x.float().to(device)
                y = y.to(device)
                pred = model(x).float()

                correct += (pred.argmax(dim=1) == y).sum().item()
                total += y.shape[0]

            logger.add_scalar("accuracy", correct / total, epoch)
            # Write buffered events now so TensorBoard shows progress each
            # epoch instead of waiting for the writer's periodic flush.
            logger.flush()

# Release the event files once training finishes.
for logger in loggers:
    logger.close()

In [4]:

# Start TensorBoard inside the notebook, serving the runs written to `log_dir`.

# Remove the `.tensorboard-info` directory under the system temp dir (it
# presumably records previously launched TensorBoard servers) so the magic
# below starts a fresh server instead of reusing a stale one.
shutil.rmtree(tempfile.gettempdir() + "/.tensorboard-info", ignore_errors=True) # sort of 'force reload' for tensorboard
# Load the TensorBoard IPython extension, which provides the %tensorboard magic.
%load_ext tensorboard 
# Serve ./log_dir on port 6005, re-scanning the logs every --reload_interval (1) seconds.
%tensorboard --logdir log_dir --reload_interval 1 --port 6005