In [None]:
! pip install torchvision torch

In [21]:
import torchvision
import torch

In [78]:
from torchvision.transforms import v2

# Per-sample preprocessing: PIL image -> tensor -> float32 -> divide by 255
# (presumably mapping 8-bit pixel intensities into [0, 1]).
_preprocessing_steps = [
    v2.PILToTensor(),
    v2.ToDtype(torch.float32),
    v2.Lambda(lambd=lambda img: img / 255),
]
transform = v2.Compose(_preprocessing_steps)

# Both MNIST splits live under the same root directory, are downloaded on
# demand, and share the preprocessing pipeline above.
train_dataset, test_dataset = (
    torchvision.datasets.MNIST("dataset", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [393]:
import torch.nn as nn
import torch.nn.functional as F

class SubSampling(nn.Module):
    """Trainable 2x2 sub-sampling layer.

    Every non-overlapping 2x2 window of each channel is summed, and the sum
    is then scaled by a per-channel coefficient ``alpha`` and shifted by a
    per-channel bias ``beta``. No squashing function is applied here.

    Args:
        in_shape: ``(channels, height, width)`` of the incoming feature maps.
            Height and width must be equal.
        out_shape: ``(channels, height, width)`` of the produced feature maps.
            Must keep the channel count and halve the spatial size.

    Raises:
        AssertionError: if the two shapes violate any of the above.
    """

    def __init__(self, in_shape, out_shape):
        super().__init__()
        assert len(in_shape) == 3
        assert len(out_shape) == 3

        in_channels, in_height, in_width = in_shape
        out_channels, out_height, out_width = out_shape
        assert in_channels == out_channels
        assert in_height == in_width
        assert out_height == out_width
        assert in_height // 2 == out_height

        # One scale and one bias per channel, broadcast over the spatial dims.
        per_channel = (out_channels, 1, 1)
        self.alpha = nn.Parameter(torch.ones(per_channel))
        self.beta = nn.Parameter(torch.zeros(per_channel))

    def forward(self, x):
        # divisor_override=1 turns the average pool into a plain window sum.
        window_sum = F.avg_pool2d(x, kernel_size=2, divisor_override=1)
        # TODO: should a sigmoid be present here? given that it is followed by a Tanh
        return window_sum * self.alpha + self.beta

class LeNetTanh(nn.Module):
    """Scaled hyperbolic tangent activation: ``1.7159 * tanh(2/3 * x)``.

    The module has no parameters and is applied element-wise.
    """

    def forward(self, x):
        squashed = F.tanh(x * (2 / 3))
        return squashed * 1.7159

class RBF(nn.Module):
    """Fixed radial-basis-function output layer with one template per digit.

    Each digit 0-9 is described by a hand-drawn 12x7 bitmap. The bitmaps are
    mapped from {0, 1} to {-1, +1}, flattened to 84 values and stacked into
    ``W`` of shape ``(10, 84)``. The templates are constants, not parameters.

    ``forward`` takes ``x`` of shape ``(N, 84)`` and returns an ``(N, 10)``
    tensor holding the squared Euclidean distance between every sample and
    every digit template (smaller means closer).
    """

    def __init__(self):
        super().__init__()
        _zero = torch.tensor([
            [0,1,1,1,1,1,0],
            [0,0,0,0,0,0,0],
            [0,0,1,1,1,0,0],
            [0,1,1,0,1,1,0],
            [1,1,0,0,0,1,1],
            [1,1,0,0,0,1,1],
            [1,1,0,0,0,1,1],
            [1,1,0,0,0,1,1],
            [0,1,1,0,1,1,0],
            [0,0,1,1,1,0,0],
            [0,0,0,0,0,0,0],
            [0,0,0,0,0,0,0],
        ])
        _zero = torch.where(_zero == 0, -1, _zero).flatten()

        _one = torch.tensor([
            [0,0,0,1,1,0,0],
            [0,0,1,1,1,0,0],
            [0,1,1,1,1,0,0],
            [0,0,0,1,1,0,0],
            [0,0,0,1,1,0,0],
            [0,0,0,1,1,0,0],
            [0,0,0,1,1,0,0],
            [0,0,0,1,1,0,0],
            [0,0,0,1,1,0,0],
            [0,1,1,1,1,1,1],
            [0,0,0,0,0,0,0],
            [0,0,0,0,0,0,0],
        ])
        _one = torch.where(_one == 0, -1, _one).flatten()

        _two = torch.tensor([
            [0,1,1,1,1,1,0],
            [0,0,0,0,0,0,0],
            [0,1,1,1,1,1,0],
            [1,1,0,0,0,1,1],
            [1,0,0,0,0,1,1],
            [0,0,0,0,1,1,0],
            [0,0,1,1,1,0,0],
            [0,1,1,0,0,0,0],
            [1,1,0,0,0,0,0],
            [1,1,1,1,1,1,1],
            [0,0,0,0,0,0,0],
            [0,0,0,0,0,0,0],
        ])
        _two = torch.where(_two == 0, -1, _two).flatten()

        _three = torch.tensor([
            [1,1,1,1,1,1,1],
            [0,0,0,0,0,1,1],
            [0,0,0,0,1,1,0],
            [0,0,0,1,1,0,0],
            [0,0,1,1,1,1,0],
            [0,0,0,0,0,1,1],
            [0,0,0,0,0,1,1],
            [0,0,0,0,0,1,1],
            [1,0,0,0,0,1,1],
            [0,1,1,1,1,1,0],
            [0,0,0,0,0,0,0],
            [0,0,0,0,0,0,0],
        ])
        _three = torch.where(_three == 0, -1, _three).flatten()

        _four = torch.tensor([
            [0,1,1,1,1,1,0],
            [0,0,0,0,0,0,0],
            [0,0,0,0,0,0,0],
            [0,1,1,0,0,1,1],
            [0,1,1,0,0,1,1],
            [1,1,1,0,0,1,1],
            [1,1,0,0,0,1,1],
            [1,1,0,0,0,1,1],
            [1,1,0,0,1,1,1],
            [0,1,1,1,1,1,1],
            [0,0,0,0,0,1,1],
            [0,0,0,0,0,1,1],
        ])
        _four = torch.where(_four == 0, -1, _four).flatten()

        _five = torch.tensor([
            [0,1,1,1,1,1,0],
            [0,0,0,0,0,0,0],
            [1,1,1,1,1,1,1],
            [1,1,0,0,0,0,0],
            [1,1,0,0,0,0,0],
            [0,1,1,1,1,0,0],
            [0,0,1,1,1,1,0],
            [0,0,0,0,0,1,1],
            [1,1,0,0,0,1,1],
            [0,1,1,1,1,1,0],
            [0,0,0,0,0,0,0],
            [0,0,0,0,0,0,0],
        ])
        _five = torch.where(_five == 0, -1, _five).flatten()

        _six = torch.tensor([
            [0,0,1,1,1,1,0],
            [0,1,1,0,0,0,0],
            [1,1,0,0,0,0,0],
            [1,1,0,0,0,0,0],
            [1,1,1,1,1,1,0],
            [1,1,1,0,0,1,1],
            [1,1,0,0,0,1,1],
            [1,1,0,0,0,1,1],
            [1,1,1,0,0,1,1],
            [0,1,1,1,1,1,0],
            [0,0,0,0,0,0,0],
            [0,0,0,0,0,0,0],
        ])
        _six = torch.where(_six == 0, -1, _six).flatten()

        _seven = torch.tensor([
            [1,1,1,1,1,1,1],
            [0,0,0,0,0,1,1],
            [0,0,0,0,0,1,1],
            [0,0,0,0,1,1,0],
            [0,0,0,1,1,0,0],
            [0,0,0,1,1,0,0],
            [0,0,1,1,0,0,0],
            [0,0,1,1,0,0,0],
            [0,0,1,1,0,0,0],
            [0,0,1,1,0,0,0],
            [0,0,0,0,0,0,0],
            [0,0,0,0,0,0,0],
        ])
        _seven = torch.where(_seven == 0, -1, _seven).flatten()

        _eight = torch.tensor([
            [0,1,1,1,1,1,0],
            [1,1,0,0,0,1,1],
            [1,1,0,0,0,1,1],
            [1,1,0,0,0,1,1],
            [0,1,1,1,1,1,0],
            [1,1,0,0,0,1,1],
            [1,1,0,0,0,1,1],
            [1,1,0,0,0,1,1],
            [1,1,0,0,0,1,1],
            [0,1,1,1,1,1,0],
            [0,0,0,0,0,0,0],
            [0,0,0,0,0,0,0],
        ])
        _eight = torch.where(_eight == 0, -1, _eight).flatten()

        _nine = torch.tensor([
            [0,1,1,1,1,1,0],
            [1,1,0,0,1,1,1],
            [1,1,0,0,0,1,1],
            [1,1,0,0,0,1,1],
            [1,1,0,0,1,1,1],
            [0,1,1,1,1,1,1],
            [0,0,0,0,0,1,1],
            [0,0,0,0,0,1,1],
            [0,0,0,0,1,1,0],
            [1,1,1,1,1,0,0],
            [0,0,0,0,0,0,0],
            [0,0,0,0,0,0,0],
        ])
        _nine = torch.where(_nine == 0, -1, _nine).flatten()

        templates = torch.stack(
            (_zero, _one, _two, _three, _four, _five, _six, _seven, _eight, _nine),
            dim=0,
        )
        # Register the templates as a buffer so that `.to(device)` / `.cuda()`
        # move them together with the rest of the model; as a plain attribute
        # they would stay on the CPU and `forward` would fail on other devices.
        # `persistent=False` keeps them out of `state_dict()`, exactly as when
        # they were a plain attribute, so existing checkpoints still load.
        self.register_buffer("W", templates, persistent=False)

    def forward(self, x):
        # (N, 84) -> (N, 1, 84) so the subtraction broadcasts against W (10, 84).
        x = x.unsqueeze(dim=1)
        # Squared Euclidean distance to each template: (N, 10, 84) -> (N, 10).
        return (x - self.W).pow(2).sum(dim=2)

class LeNet5(nn.Module):
    """LeNet-5 style network assembled as a single ``nn.Sequential`` stack.

    The stack alternates 5x5 convolutions with trainable 2x2 sub-sampling,
    applying ``LeNetTanh`` after each of them, then flattens the 120 feature
    maps, projects them to 84 units and ends with the ``RBF`` layer.
    The layer shapes are written for ``(N, 1, 28, 28)`` inputs.
    """

    def __init__(self):
        super().__init__()
        stages = [
            # 1 x 28 x 28 -> 6 x 28 x 28 (padding keeps the spatial size)
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
            LeNetTanh(),
            SubSampling(in_shape=(6, 28, 28), out_shape=(6, 14, 14)),
            LeNetTanh(),
            # todo the next layer should be sparse, not all 6 in_channels are connected with the 16 out_channels
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            LeNetTanh(),
            SubSampling(in_shape=(16, 10, 10), out_shape=(16, 5, 5)),
            LeNetTanh(),
            # 16 x 5 x 5 -> 120 x 1 x 1
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5),
            LeNetTanh(),
            nn.Flatten(),
            nn.Linear(in_features=120, out_features=84),
            LeNetTanh(),
            RBF(),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through every stage in order and return the final output."""
        return self.layers(x)

In [496]:
# Instantiate the LeNet5 network defined above.
model = LeNet5()

In [None]:
def loss(X, y, model, j=2):
    """Mean penalty-based loss for a model whose outputs are distances.

    For every sample the loss is::

        pred[target] + log(exp(-j) + sum_{k != target} exp(-pred[k]))

    The first term pulls the output of the correct class towards zero; the
    second term pushes the outputs of the other classes up.

    Args:
        X: input batch passed unchanged to ``model``.
        y: integer class labels, one per sample (any shape with ``N`` elements).
        model: callable returning a ``(N, num_classes)`` tensor of penalties
            (lower means more likely).
        j: constant of the ``exp(-j)`` term inside the logarithm.

    Returns:
        Scalar tensor with the loss averaged over the batch.
    """
    pred = model(X)
    # Derive the layout and device from the prediction itself instead of
    # assuming 10 classes on the CPU.
    target = y.reshape(-1, 1).to(device=pred.device, dtype=torch.long)
    class_ids = torch.arange(pred.size(1), device=pred.device).unsqueeze(0)
    is_target = class_ids == target

    correct = pred.gather(1, target).squeeze(1)

    # log(exp(-j) + sum_{k != target} exp(-pred_k)) computed with logsumexp
    # for numerical stability; the target class is excluded by setting its
    # entry to -inf, which contributes exp(-inf) = 0 to the sum.
    competing = (-pred).masked_fill(is_target, float("-inf"))
    margin = pred.new_full((pred.size(0), 1), -float(j))
    incorrect = torch.logsumexp(torch.cat((margin, competing), dim=1), dim=1)

    return (correct + incorrect).mean()

In [495]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=64)

# Peek at the first batch only, to check the layout of what the loader yields.
X, y = next(iter(train_dataloader))
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")

# Smoke-test the loss: a fresh model evaluated on random inputs and labels.
model = LeNet5()
X = torch.randn((4, 1, 28, 28))
y = torch.randint(high=10, size=(4,))

loss(X, y, model)

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64


tensor(96.4081, grad_fn=<MeanBackward0>)