In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms 

import torch_geometric

from tqdm.notebook import tqdm

ModuleNotFoundError: No module named 'torch_geometric'

In [None]:
import sys
sys.path.append('./..')

In [None]:
from megnn.invertible import RevNetBlock
from megnn.backprop import MemoryEfficientNet

# Setup Network

In [None]:
def dense_net(d_in, d_out, d_hidden=50):
    return torch.nn.Sequential(
        torch.nn.Linear(d_in, d_hidden),
        torch.nn.LeakyReLU(),
        torch.nn.BatchNorm1d(d_hidden),
        torch.nn.Linear(d_hidden, d_out)
    )

Very simple ResNet using dense layers for the residual part. The information is propagated as follows

```
    x = mnist_pixels
    y = zeros
    ...
    x <- x + f(y)
    y <- y + g(x)
    ...
    pred = linear([x, y])
```

In [None]:
class Net(torch.nn.Module):
    
    def __init__(self):
        super().__init__()
        self._net = MemoryEfficientNet([
            RevNetBlock(*(dense_net(dim, dim) for _ in range(2)))
            for _ in range(n_blocks)
        ])
        self._out = torch.nn.Linear(2 * dim, 10)
    
    def forward(self, x):
        y = torch.zeros_like(x)  # augment feature space with zeros
        x_, y_ = self._net(x, y)
        z = torch.cat([x_, y_], dim=-1)
        return self._out(z)

# Data Loader

In [None]:
dim = 28 * 28

transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    transforms.Lambda(lambda x: x.view(dim))
])

mnist_train = MNIST("../../coarse-vae/data/mnist/", download=True, train=True, transform=transform)
mnist_test = MNIST("../../coarse-vae/data/mnist/", download=True, train=False, transform=transform)

# Train Model / Evaluate

In [None]:
n_blocks = 10
n_epochs = 10
batch_size = 100

In [None]:
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size)

In [None]:
net = Net()

In [None]:
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

In [None]:
criterion = torch.nn.CrossEntropyLoss()

In [None]:
net = net.train()
for epoch in range(n_epochs):
    for data, labels in train_loader:
        preds = net(data)
        loss = criterion(preds, labels)
        optim.zero_grad()
        loss.backward()
        optim.step()
        print(f"\r epoch: {epoch}/{n_epochs}, loss: {loss.item():.4}", end="")

In [None]:
net = net.eval()
with torch.no_grad():
    acc = []
    for data, labels in tqdm(test_loader):
        preds = net(data).max(dim=-1).indices
        acc.append((preds == labels).float().mean())
    acc = sum(acc) / len(acc)
print(f"Accuracy {acc.item() * 100:.4}%")