<a href="https://colab.research.google.com/github/liuyao12/pytorch-cifar/blob/master/cifar10_with_PDE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Cifar10 with PDE

* As far as I'm aware, this is a simple and novel ConvNet (Convolutional Neural Network) architecture that can be readily applied to any existing ResNet backbone.

* The key idea would be hard to come by, or to justify, without viewing ResNet as a partial differential equation (like the heat equation). Traditionally, the standard toolkit for machine learning includes only bits of multivariable calculus, linear algebra, and statistics, and not much PDE. This partly explains why ResNet came on the scene relatively late (2015), and why this enhanced version of ResNet has not yet been "reinvented" by the DL community.

* The code is based on https://github.com/kuangliu/pytorch-cifar and the [official PyTorch tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html).

* Questions and comments are greatly appreciated: [@liuyao12](https://twitter.com/liuyao12) or liuyao@gmail.com




Below is a quick summary of ConvNets from a partial differential equations (PDE) point of view. For details, see my [notebook](https://observablehq.com/@liuyao12/neural-networks-and-partial-differential-equations) on Observable.

neural network | heat equation
:----:|:-------:
input layer | initial condition
feed forward | solving the equation
hidden layers | solution at intermediate times
output layer | solution at final time
convolve with 3×3 kernel | differential operator of order ≤ 2
weights | coefficients
boundary handling (padding) | boundary condition
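
To make the dictionary above concrete, here is a minimal NumPy sketch of one explicit Euler step of the heat equation written as "identity shortcut plus a 3×3 convolution". It is an illustration only and not part of the training code in this notebook; the names `conv3x3` and `heat_step`, the step size `dt`, and the choice of the 5-point Laplacian stencil are assumptions made for the example.

```python
import numpy as np

# 5-point discrete Laplacian: a constant-coefficient operator of order 2.
LAPLACIAN = np.array([[0., 1., 0.],
                      [1., -4., 1.],
                      [0., 1., 0.]])

def conv3x3(u, kernel):
    """Cross-correlate a 2-D array with a 3x3 kernel, using zero padding
    (the analogue of padding=1, i.e. a zero boundary condition)."""
    h, w = u.shape
    padded = np.pad(u, 1)
    out = np.zeros_like(u)
    for i in range(3):
        for j in range(3):
            out += kernel[i, j] * padded[i:i + h, j:j + w]
    return out

def heat_step(u, dt=0.1):
    """One explicit Euler step of u_t = Laplacian(u):
    the shortcut term u plus dt times a 3x3 convolution of u."""
    return u + dt * conv3x3(u, LAPLACIAN)

# Initial condition: a unit of heat at the centre of a 7x7 grid.
u0 = np.zeros((7, 7))
u0[3, 3] = 1.0
u1 = heat_step(u0)   # "hidden layer": the solution one time step later
```

In this picture the kernel entries play the role of the weights, and stacking several `heat_step` calls corresponds to feeding forward through several residual blocks.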

Basically, classical ConvNets (ResNets) are **linear PDEs with constant coefficients**, and here I simply make them **variable-coefficient**, with the coefficients being polynomials of degree ≤ 1 in the spatial coordinates. In theory, this should enable the neural net to learn more ways to deform an image than diffusion and translation (e.g., rotation and scaling).
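
As a sketch of why degree-1 coefficients could be enough for rotation: the operator x·∂/∂y − y·∂/∂x generates rotations about the image centre, and it has the form "X times a first-order operator plus Y times a first-order operator". The NumPy example below is an illustration under assumptions: finite differences from `np.gradient` stand in for learned 3×3 kernels, and `rotation_generator` and `theta` are hypothetical names that do not appear in the notebook's model code.

```python
import numpy as np

h = w = 32
# Coordinate grids centred near 0, built the same way as X1/Y1 in the model code.
ys, xs = np.indices((h, w), dtype="float64")
X = xs / w - 0.5
Y = ys / h - 0.5

def rotation_generator(u):
    """Apply (x d/dy - y d/dx) to u, using finite differences for the derivatives."""
    du_dy, du_dx = np.gradient(u, 1.0 / h, 1.0 / w)  # axis 0 is y, axis 1 is x
    return X * du_dy - Y * du_dx

u = X.copy()                                  # test image: a linear ramp in x
theta = 0.1                                   # small rotation angle (radians)
u_rot = u + theta * rotation_generator(u)     # one explicit Euler step

# To first order in theta, the ramp x becomes x - theta * y, i.e. a rotated ramp.
print(np.allclose(u_rot, X - theta * Y))
```

Replacing the generator with x·∂/∂x + y·∂/∂y gives scaling about the centre in the same way, which is the other deformation mentioned above.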

In [3]:
''' 
ResNet in PyTorch  https://github.com/kuangliu/pytorch-cifar
Reference:
    Kaiming He 何恺明, Xiangyu Zhang 张祥雨, Shaoqing Ren 任少卿, Jian Sun 孙剑 (Microsoft Research Asia)
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, channels, stride=1, iterations=1):
        super(BasicBlock, self).__init__()
        self.iterations = iterations
        self.X1, self.Y1, self.X2, self.Y2 = None, None, None, None
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv1x = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv1y = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2x = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2y = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * channels)
            )

    def forward(self, x):
        x1 = self.conv1(x)
        x1 += self.twist1(x)       # <===== can manually comment out to compare
        x1 = F.relu(self.bn1(x1))
        x2 = self.conv2(x1)
        x2 += self.twist2(x1)      # <===== can manually comment out to compare
        x3 = self.shortcut(x) + self.bn2(x2)
        x3 = F.relu(x3)
        return x3

    def twist1(self, x):
        if self.X1 is None:
            _, _, h, w = list(self.conv1(x).shape)
            self.X1 = torch.from_numpy(np.indices((1,1,h,w), dtype='float32')[3]/w-0.5).to(x.device)
            self.Y1 = torch.from_numpy(np.indices((1,1,h,w), dtype='float32')[2]/h-0.5).to(x.device)
            with torch.no_grad():
                self.conv1.weight /= 3
                self.conv1x.weight /= 3
                self.conv1y.weight /= 3
        return self.X1 * self.conv1x(x) + self.Y1 * self.conv1y(x)

    def twist2(self, x):
        if self.X2 is None:
            _, _, h, w = list(self.conv2(x).shape)
            self.X2 = torch.from_numpy(np.indices((1,1,h,w), dtype='float32')[3]/w-0.5).to(x.device)
            self.Y2 = torch.from_numpy(np.indices((1,1,h,w), dtype='float32')[2]/h-0.5).to(x.device)
            with torch.no_grad():
                self.conv2.weight /= 3
                self.conv2x.weight /= 3
                self.conv2y.weight /= 3
        return self.X2 * self.conv2x(x) + self.Y2 * self.conv2y(x)

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, channels, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.XX, self.YY = None, None
        self.conv2x = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2y = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.conv3 = nn.Conv2d(channels, self.expansion * channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * channels)
            )

    def forward(self, x):
        x1 = F.relu(self.bn1(self.conv1(x)))
        x2 = self.conv2(x1)
        # x2 += self.twist(x1)       # <===== disabled here; uncomment to enable the twist and compare
        x3 = F.relu(self.bn2(x2))
        x3 = self.bn3(self.conv3(x3))
        x3 += self.shortcut(x)
        x3 = F.relu(x3)
        return x3

    def twist(self, x):
        if self.XX is None:
            _, _, h, w = list(self.conv2(x).shape)
            self.XX = torch.from_numpy(np.indices((1,1,h,w), dtype='float32')[3]/w-0.5).to(x.device)
            self.YY = torch.from_numpy(np.indices((1,1,h,w), dtype='float32')[2]/h-0.5).to(x.device)
            # with torch.no_grad():
            #     print(x.mean().item(), x.std().item())
            #     x = self.conv2(x)
            #     print(x.mean().item(), x.std().item())
            #     x += self.XX * self.conv2x(x)
            #     print(x.mean().item(), x.std().item())
            #     x += self.YY * self.conv2y(x)
            #     print(x.mean().item(), x.std().item())
            #     self.conv2.weight /= 3
            #     self.conv2x.weight /= 3
            #     self.conv2y.weight /= 3
        return self.XX * self.conv2x(x) + self.YY * self.conv2y(x)

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, channels, stride))
            self.in_channels = channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = F.avg_pool2d(x, 4)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x


def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])

def ResNet34():
    return ResNet(BasicBlock, [3,4,6,3])

def ResNet50():
    return ResNet(Bottleneck, [3,4,6,3])

def ResNet101():
    return ResNet(Bottleneck, [3,4,23,3])

def ResNet152():
    return ResNet(Bottleneck, [3,8,36,3])


net = ResNet50()
epoch = 0 
lr = 0.1
checkpoint = {'acc': 0, 'epoch': 0}
history = [{'acc': 0, 'epoch': 0}]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('device =', device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True
net.to(device)
print('Testing on a random input:')
test = torch.randn(1,3,32,32).to(device)
print('INPUT', test.shape)
print('OUTPUT', net(test).shape)

device = cuda
Testing on a random input:
INPUT torch.Size([1, 3, 32, 32])
OUTPUT torch.Size([1, 10])


In [4]:
import torchvision
import torchvision.transforms as transforms

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

==> Preparing data..
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [0]:
# Training
def train(loss_func, opt):
    global history
    print('Epoch %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (x, y) in enumerate(trainloader):
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        pred = net(x)
        loss = loss_func(pred, y)
        loss.backward()
        opt.step()
        train_loss += loss.item()
        _, predicted = pred.max(1)
        total += y.size(0)
        correct += predicted.eq(y).sum().item()
    print('train loss: {:.3f} | acc: {:.3f} ({}/{})'.format(
        train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    history.append({'epoch': epoch, 'train_loss': train_loss, 'train_acc': 100. * correct / total})

def test(loss_func):
    global checkpoint, history
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(testloader):
            x, y = x.to(device), y.to(device)
            pred = net(x)
            loss = loss_func(pred, y)
            test_loss += loss.item()
            _, predicted = pred.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()
    acc = 100. * correct / total
    history[-1]['loss'] = test_loss
    history[-1]['acc'] = acc
    if acc > checkpoint['acc']:
        print('test  loss: {:.3f} | acc: {:.2f}  ( {}/{}) (up by {:.2f})'.format(
               test_loss / (batch_idx + 1), 100. * correct / total, correct, total,
               acc - checkpoint['acc']))
        # print('Saving..')
        state = {
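            # state_dict() returns references to the live tensors, so clone them;
            # otherwise the saved "best" weights keep changing as training continues.
            'net': {k: v.clone() for k, v in net.state_dict().items()},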
            'acc': acc,
            'epoch': epoch
        }
        checkpoint = state
        # if not os.path.isdir('checkpoint'):
        #     os.mkdir('checkpoint')
        # torch.save(state, './checkpoint/ckpt.pth')
    else:
        print('test  loss: {:.3f} | acc: {:.2f}  ( {}/{})'.format(
            test_loss / (batch_idx + 1), 100. * correct / total, correct, total))

loss_func = nn.CrossEntropyLoss()
opt = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)

for _ in range(100):
    if epoch - checkpoint['epoch'] >= 15:
        lr *= 0.1
        print('learning rate downgraded to {} at epoch {}'.format(lr, epoch))
        print('loading state_dict from Epoch {} (acc = {})'.format(checkpoint['epoch'], checkpoint['acc']))
        net.load_state_dict(checkpoint['net'])
        opt = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    train(loss_func, opt)
    test(loss_func)
    epoch += 1
print('finish at lr =', lr)

Epoch 0
train loss: 2.667 | acc: 19.380 (9690/50000)
test  loss: 2.012 | acc: 24.11  ( 2411/10000) (up by 0.00)
Epoch 1
train loss: 1.804 | acc: 32.202 (16101/50000)
test  loss: 1.672 | acc: 37.41  ( 3741/10000) (up by 0.00)
Epoch 2
train loss: 1.596 | acc: 40.638 (20319/50000)
test  loss: 1.466 | acc: 46.32  ( 4632/10000) (up by 0.00)
Epoch 3
train loss: 1.426 | acc: 47.638 (23819/50000)
test  loss: 1.315 | acc: 51.03  ( 5103/10000) (up by 0.00)
Epoch 4
train loss: 1.273 | acc: 53.834 (26917/50000)
test  loss: 1.304 | acc: 54.75  ( 5475/10000) (up by 0.00)
Epoch 5
train loss: 1.124 | acc: 59.392 (29696/50000)
test  loss: 1.184 | acc: 56.75  ( 5675/10000) (up by 0.00)
Epoch 6
train loss: 0.999 | acc: 64.378 (32189/50000)
test  loss: 1.165 | acc: 60.75  ( 6075/10000) (up by 0.00)
Epoch 7
train loss: 0.905 | acc: 67.652 (33826/50000)
test  loss: 0.878 | acc: 68.57  ( 6857/10000) (up by 0.00)
Epoch 8
train loss: 0.843 | acc: 70.146 (35073/50000)
test  loss: 0.997 | acc: 67.72  ( 6772/1000

In [0]:
history

In [0]:
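# (epoch, test accuracy) pairs, apparently recorded only at epochs where the best accuracy improved.
# Named resnet18_runs so that it does not shadow the ResNet18() constructor defined above.
resnet18_runs = {'run0': [(0, 40.77), (1, 50.38), (2, 60.33), (3, 67.54), (4, 72.53), (5, 74.63), (8, 77.29), (10, 80.53), (14, 81.13), (16, 84.1), (20, 84.58), (29, 85.29), (42, 87.32), (50, 92.76), (51, 93.15), (52, 93.52), (53, 93.64), (54, 93.65), (58, 93.92), (75, 94.05), (76, 94.35), (82, 94.37), (85, 94.38), (86, 94.39), (92, 94.47)],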
        'run1': [(0, 46.66), (1, 59.49), (2, 64.82), (3, 67.94), (4, 73.07), (5, 76.9), (7, 80.42), (10, 82.33), (12, 83.24), (15, 83.56), (22, 84.14), (31, 84.36), (32, 86.76), (36, 87.09), (50, 92.81), (51, 92.97), (52, 93.31), (53, 93.33), (54, 93.58), (55, 93.61), (58, 93.7), (60, 93.71), (62, 93.72), (64, 93.88), (75, 94.32), (76, 94.47), (77, 94.52), (78, 94.55), (79, 94.56), (80, 94.59), (81, 94.68), (83, 94.74), (85, 94.76), (87, 94.77), (88, 94.82), (99, 94.89)], 
        'run2': [(0, 53.2), (1, 59.44), (2, 74.87), (4, 77.51), (5, 80.48), (8, 81.88), (13, 82.32), (16, 82.91), (18, 82.99), (19, 83.11), (24, 83.8), (25, 85.36), (28, 86.4), (50, 92.47), (51, 93.07), (52, 93.63), (56, 93.71), (61, 93.77), (63, 93.94), (75, 94.26), (76, 94.41), (77, 94.43), (78, 94.67), (81, 94.81), (84, 94.84)]}

### Results on ResNet18: 
Following kuangliu, I manually change the learning rate as follows:

* `lr=0.1` for Epoch [0:50] 
* `lr=0.01` for Epoch [50:75]
* `lr=0.001` for Epoch [75:100]
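
The step schedule listed above can be sketched as a small pure-Python function. This is an illustration, not the code used for the runs below (the training cell earlier in this notebook lowers the rate after 15 epochs without improvement instead); the name `step_lr` and its parameters are assumptions made for the example.

```python
def step_lr(epoch, base_lr=0.1, milestones=(50, 75), gamma=0.1):
    """Learning rate at a given epoch: multiply base_lr by gamma
    once for every milestone that has been reached."""
    drops = sum(epoch >= m for m in milestones)
    return base_lr * gamma ** drops

for e in (0, 49, 50, 74, 75, 99):
    print(e, step_lr(e))
```

In PyTorch, `torch.optim.lr_scheduler.MultiStepLR` with `milestones=[50, 75]` and `gamma=0.1` should give the same schedule when stepped once per epoch.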

epoch | 0 | 5 |  10 | 15 | 20 | 25 | 30 | 45 | 50 | 55 | 65 | 75 | 80 | 90 | 99
:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:
without twist ('run0') | 40.77 | 74.63 | 80.53 | 81.13 | 84.58 | - | 85.29 | 87.32 | 92.76 | 93.65 | - | 94.05 | 94.35 | 94.39 | 94.47
with twist | 42.39 | 74.86 | 81.51 | 82.88 | - | 84.34 | 86.98 | 87.85 | 92.34 | 93.65 | - | 93.73 | 94.22
with twist ('run1') | 46.66 | 76.90 | 82.33 | 83.56 | - | 84.14 | - | 87.09 | 92.81 | 93.61 | 93.88 | 94.32 | 94.59 | 94.82 | 94.89
with twist | 50.56 | 80.40 | 82.81 | 85.15 | - | - | - | 87.36 | 93.20 | 93.84 | 94.02
deep twist (3) | 41.55 | 72.37 | 78.81 | 79.49 | 82.38 | 82.38 | 83.14 | 83.36 

The improvement is not as significant as I initially thought based on the accuracy reported at https://github.com/kuangliu/pytorch-cifar