<a href="https://colab.research.google.com/github/liuyao12/pytorch-cifar/blob/master/cifar10_with_PDE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Cifar10 with PDE

* As far as I'm aware, this is a simple and novel ConvNet (Convolutional Neural Network) architecture that can be applied directly to any existing ResNet backbone.

* The key idea would be hard to come by, or to justify, without viewing ResNet as a partial differential equation (like the heat equation). Traditionally, the standard machine-learning toolkit includes only bits of multivariable calculus, linear algebra, and statistics, and not much PDE. This partly explains why ResNet came on the scene relatively late (2015), and why this enhanced version of ResNet has not been "reinvented" by the DL community.

* Code based on https://github.com/kuangliu/pytorch-cifar and the [official PyTorch tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)

* Questions and comments are greatly appreciated: [@liuyao12](https://twitter.com/liuyao12) or liuyao@gmail.com




A quick summary of ConvNets from the point of view of partial differential equations (PDEs). For details, see my [notebook](https://observablehq.com/@liuyao12/neural-networks-and-partial-differential-equations) on Observable.

neural network | heat equation
:----:|:-------:
input layer | initial condition
feed forward | solving the equation
hidden layers | solution at intermediate times
output layer | solution at final time
convolution with 3×3 kernel | differential operator of order ≤ 2
weights | coefficients
boundary handling (padding) | boundary condition
multiple channels | system of (coupled) PDEs
e.g. 16×16×3×3 kernel | 16×16 matrix of differential operators
16×16×1×1 kernel | 16×16 matrix of constants
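
The table rows "convolution with 3×3 kernel ↔ differential operator of order ≤ 2" and "feed forward ↔ solving the equation" can be made concrete with a minimal sketch. It is not the notebook's code. It takes one forward-Euler step of the heat equation u_t = Δu and writes it as a residual block: the 5-point Laplacian stencil is a fixed 3×3 convolution, and the identity shortcut adds u back. The grid on [-0.5, 0.5) with spacing `h = 1/16`, the step size `dt`, and the helper names `laplacian_kernel` and `heat_step` are all illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def laplacian_kernel(h):
    # 5-point stencil for the Laplacian on a grid with spacing h, shaped as a (1, 1, 3, 3) conv weight
    k = torch.tensor([[0., 1., 0.],
                      [1., -4., 1.],
                      [0., 1., 0.]]) / h**2
    return k.view(1, 1, 3, 3)

def heat_step(u, h, dt):
    # One forward-Euler step u + dt * Laplacian(u): the "identity shortcut + conv" pattern of a residual block.
    # Zero padding acts like a (Dirichlet-type) boundary condition on a layer of ghost cells.
    return u + dt * F.conv2d(u, laplacian_kernel(h), padding=1)

# Sample u(x, y) = x^2 + y^2 on a 16x16 grid over [-0.5, 0.5)
n = 16
h = 1.0 / n
coords = torch.arange(n, dtype=torch.float32) * h - 0.5
Y, X = torch.meshgrid(coords, coords, indexing='ij')
u = (X**2 + Y**2)[None, None]  # shape (1, 1, 16, 16)

lap = F.conv2d(u, laplacian_kernel(h), padding=1)
# Away from the zero-padded border this is close to 4, the exact Laplacian of x^2 + y^2
print(lap[0, 0, 1:-1, 1:-1].mean().item())
```

Central differences are exact on quadratics, so the interior values match the continuous Laplacian. Only the border is affected by the padding, which is the "boundary condition" row of the table.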


Basically, classical ConvNets (ResNets) are **linear PDEs with constant coefficients**, and here I simply make the coefficients **variable**: each coefficient is a polynomial of degree ≤ 1 in the spatial coordinates (x, y). In theory, this should let the network learn more kinds of deformation than diffusion and translation (e.g., rotation and scaling).
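
The following minimal sketch shows why degree-1 coefficients add rotation and scaling. It uses coordinate grids X and Y normalized to [-0.5, 0.5), the same range as `XX` and `YY` in `PDE_Block`, together with two odd 3×3 kernels. With these, the term X·(k_x ⋆ u) + Y·(k_y ⋆ u) can represent the rotation generator x∂_y − y∂_x and the scaling generator x∂_x + y∂_y. Both generators have coefficients that are linear in x and y, so they do not commute with translations and are out of reach for constant-coefficient convolutions. In this sketch the kernels are hand-picked central differences rather than learned weights, and `diff_kernels` and `degree1_op` are hypothetical helper names.

```python
import torch
import torch.nn.functional as F

def diff_kernels(h):
    # Central-difference kernels for d/dx and d/dy (grid spacing h), shaped (1, 1, 3, 3).
    # F.conv2d is a cross-correlation, so for dx: out[i, j] = (u[i, j+1] - u[i, j-1]) / (2h)
    dx = torch.zeros(1, 1, 3, 3)
    dx[0, 0, 1, 2], dx[0, 0, 1, 0] = 1 / (2 * h), -1 / (2 * h)
    dy = torch.zeros(1, 1, 3, 3)
    dy[0, 0, 2, 1], dy[0, 0, 0, 1] = 1 / (2 * h), -1 / (2 * h)
    return dx, dy  # both are odd: kernel.flip(2).flip(3) == -kernel

def degree1_op(u, X, Y, kx, ky):
    # Variable-coefficient operator with degree-1 coefficients: X * (kx applied to u) + Y * (ky applied to u)
    return X * F.conv2d(u, kx, padding=1) + Y * F.conv2d(u, ky, padding=1)

n = 16
h = 1.0 / n
coords = torch.arange(n, dtype=torch.float32) * h - 0.5  # x, y in [-0.5, 0.5)
Y, X = torch.meshgrid(coords, coords, indexing='ij')
X, Y = X[None, None], Y[None, None]
dx, dy = diff_kernels(h)

u = X**2 + Y**2                          # radially symmetric test image
rotation = degree1_op(u, X, Y, dy, -dx)  # (x d/dy - y d/dx) u
scaling = degree1_op(u, X, Y, dx, dy)    # (x d/dx + y d/dy) u

# Compare away from the zero-padded border
print(rotation[..., 1:-1, 1:-1].abs().max().item())                # ~0: rotation leaves a radial image unchanged
print((scaling[..., 1:-1, 1:-1] - 2 * u[..., 1:-1, 1:-1]).abs().max().item())  # ~0: scaling maps x^2 + y^2 to 2(x^2 + y^2)
```

In `PDE_Block`, `convx` and `convy` are learned and projected onto odd kernels. The network can then, in principle, discover combinations like these rather than having them fixed by hand.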

In [229]:
''' 
ResNet in PyTorch  https://github.com/kuangliu/pytorch-cifar
Reference:
    Kaiming He 何恺明, Xiangyu Zhang 张祥雨, Shaoqing Ren 任少卿, Jian Sun 孙剑 (Microsoft Research Asia)
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, channels, stride=1, iterations=1):
        super(BasicBlock, self).__init__()
        self.iterations = iterations
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * channels)
            )

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = F.relu(self.bn1(x1))
        x2 = self.conv2(x1)
        x3 = self.shortcut(x) + self.bn2(x2)
        x3 = F.relu(x3)
        return x3

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, channels, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.conv3 = nn.Conv2d(channels, self.expansion * channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * channels)
            )

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = F.relu(self.bn1(x1))
        x2 = self.conv2(x1)
        x3 = F.relu(self.bn2(x2))
        x4 = self.conv3(x3)
        x4 = self.bn3(x4)
        x4 += self.shortcut(x)
        x4 = F.relu(x4)
        return x4


class PDE_Block(nn.Module):
    expansion = 1

    def __init__(self, in_channels, channels, stride=1, act=True):
        super(PDE_Block, self).__init__()
        self.act = act
        self.conv = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.XX, self.YY = None, None
        self.convx = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.convy = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * channels)
            )

    def forward(self, x):
        x1 = self.conv(x)
        x1 += self.twist(x, x1.shape)
        x2 = self.bn(x1)
        x2 += self.shortcut(x)
        if self.act:
            x2 = F.relu(x2)
        return x2

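    def twist(self, x, out_shape):
        # Variable-coefficient term XX * convx(x) + YY * convy(x), where XX and YY are the
        # x- and y-coordinates of the output grid, normalized to [-0.5, 0.5)
        if self.XX is None:
            # Build the coordinate grids once, lazily, since the output size is only known here
            _, _, h, w = list(out_shape)
            self.XX = torch.from_numpy(np.indices((1,1,h,w), dtype='float32')[3]/w-0.5).to(x.device)
            self.YY = torch.from_numpy(np.indices((1,1,h,w), dtype='float32')[2]/h-0.5).to(x.device)
        # Project convx and convy onto their point-antisymmetric (odd) part, K(-i,-j) = -K(i,j),
        # so they act like first-order differential operators (away from the padded border,
        # they map constant inputs to zero)
        self.convx.weight.data = (self.convx.weight - self.convx.weight.flip(2).flip(3))/2
        self.convy.weight.data = (self.convy.weight - self.convy.weight.flip(2).flip(3))/2
        return self.XX * self.convx(x) + self.YY * self.convy(x)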

    
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 8
        self.num_blocks = num_blocks
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(8)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 128, num_blocks[3], stride=2)
        self.linear = nn.Linear(128 * block.expansion, num_classes)

    def _make_layer(self, block, channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for idx, stride in enumerate(strides):
            # layers.append(block(self.in_channels, channels, stride))
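            # PDE blocks: every third block in a stage (idx % 3 == 2) is built with act=False, i.e. no final ReLU
            if block == PDE_Block and idx % 3 == 2: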
                layers.append(block(self.in_channels, channels, stride, False))
            else:
                layers.append(block(self.in_channels, channels, stride))
            self.in_channels = channels * block.expansion
                
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = F.avg_pool2d(x, 4)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x
    

    
def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])

def ResNet34():
    return ResNet(BasicBlock, [3,4,6,3])

def ResNet50():
    return ResNet(Bottleneck, [3,4,6,3])

def ResNet101():
    return ResNet(Bottleneck, [3,4,23,3])

def ResNet152():
    return ResNet(Bottleneck, [3,8,36,3])


net = ResNet(PDE_Block, [8,8,8,8])
epoch = 0 
lr = 0.1
checkpoint = {'acc': 0, 'epoch': 0}
history = [{'acc': 0, 'epoch': 0}]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('device =', device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True
net.to(device)
print('Testing on a random input:')
sample = torch.randn(1,3,32,32).to(device)  # not named `test`, which would shadow the test() function defined later
print('INPUT ', sample.shape)
print('OUTPUT', net(sample).shape)

device = cuda
Testing on a random input:
INPUT  torch.Size([1, 3, 32, 32])
OUTPUT torch.Size([1, 10])
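
`PDE_Block.twist` overwrites `convx.weight.data` with its odd part on every forward pass, and it caches `XX`/`YY` as plain attributes, which `.to(device)` does not move. A possible alternative is sketched below. It assumes that only the odd part of each kernel should ever be used: the odd part is computed functionally inside `forward`, and the coordinate grids are stored as non-persistent buffers. `OddTwist` is a hypothetical name and this is not the notebook's code. Because the stored weights keep an (ignored) even part, quantities computed from the raw weights, such as `get_energy`, would not match the original.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class OddTwist(nn.Module):
    """Computes XX * (odd part of Kx applied to x) + YY * (odd part of Ky applied to x) for a fixed output size."""

    def __init__(self, in_channels, channels, out_hw, stride=1):
        super().__init__()
        self.stride = stride
        self.wx = nn.Parameter(torch.randn(channels, in_channels, 3, 3) * 0.1)
        self.wy = nn.Parameter(torch.randn(channels, in_channels, 3, 3) * 0.1)
        h, w = out_hw
        # Same normalization as XX, YY in PDE_Block: coordinates in [-0.5, 0.5)
        yy = (torch.arange(h, dtype=torch.float32) / h - 0.5).view(1, 1, h, 1).expand(1, 1, h, w)
        xx = (torch.arange(w, dtype=torch.float32) / w - 0.5).view(1, 1, 1, w).expand(1, 1, h, w)
        # Non-persistent buffers: moved by .to(device), but not saved in state_dict
        self.register_buffer('XX', xx.contiguous(), persistent=False)
        self.register_buffer('YY', yy.contiguous(), persistent=False)

    @staticmethod
    def odd_part(w):
        # Projection onto point-antisymmetric kernels, K(-i, -j) = -K(i, j); gradients flow through it
        return (w - w.flip((2, 3))) / 2

    def forward(self, x):
        gx = F.conv2d(x, self.odd_part(self.wx), stride=self.stride, padding=1)
        gy = F.conv2d(x, self.odd_part(self.wy), stride=self.stride, padding=1)
        return self.XX * gx + self.YY * gy

layer = OddTwist(in_channels=8, channels=16, out_hw=(16, 16), stride=2)
print(layer(torch.randn(2, 8, 32, 32)).shape)  # torch.Size([2, 16, 16, 16])
```

The design trade-off is that the output size must be known at construction time. In exchange, the forward pass has no side effects on the parameters, which also avoids surprises when `nn.DataParallel` replicates the module.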


In [59]:
import torchvision
import torchvision.transforms as transforms

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Training
def train(loss_func, opt):
    global history
    print('Epoch {}'.format(epoch))
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (x, y) in enumerate(trainloader):
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        pred = net(x)
        loss = loss_func(pred, y)
        energy = get_energy(net)  # list of tensors, still attached to the graph
        if batch_idx % 100 == 0:
            print(batch_idx, loss.item(), [float(e) for e in energy])
        if epoch < 0:  # never true as written (epoch starts at 0), so the energy penalty is switched off
            energy_sum = energy[1] + energy[2]
            loss = loss + energy_sum * 100
        loss.backward()
        opt.step()
        train_loss += loss.item()
        _, predicted = pred.max(1)
        total += y.size(0)
        correct += predicted.eq(y).sum().item()
    print('train loss: {:.3f} | acc: {:.3f} ({}/{})'.format(
        train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    history.append({'epoch': epoch, 'train_loss': train_loss, 'train_acc': 100. * correct / total})

def get_energy(net):
    # "Energy" between consecutive blocks of a stage whose kernels have the same shape:
    # squared kernel differences summed over the 3x3 window, then min over output channels
    # and max over input channels. energy[0] is for conv, energy[1] for convx, energy[2] for convy.
    # Returns tensors (not floats) so that the optional penalty in train() can backpropagate.
    energy = [0, 0, 0]
    m = net.module if isinstance(net, nn.DataParallel) else net  # net is only wrapped on CUDA
    layers = [m.layer1, m.layer2, m.layer3, m.layer4]
    for n, layer in enumerate(layers):
        for i in range(m.num_blocks[n]-1):
            A = layer[i].conv.weight
            B = layer[i+1].conv.weight
            if A.shape == B.shape:
                diff = A - B
                energy[0] += torch.sum(diff * diff, dim=(2,3)).min(0)[0].max()
                A = layer[i].convx.weight
                B = layer[i+1].convx.weight
                diff = A - B
                energy[1] += torch.sum(diff * diff, dim=(2,3)).min(0)[0].max()
                A = layer[i].convy.weight
                B = layer[i+1].convy.weight
                diff = A - B
                energy[2] += torch.sum(diff * diff, dim=(2,3)).min(0)[0].max()
    return energy
    
def test(loss_func):
    global checkpoint, history
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(testloader):
            x, y = x.to(device), y.to(device)
            pred = net(x)
            loss = loss_func(pred, y)
            test_loss += loss.item()
            _, predicted = pred.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()
    acc = 100. * correct / total
    history[-1]['loss'] = test_loss
    history[-1]['acc'] = acc
    if acc > checkpoint['acc']:
        print('test  loss: {:.3f} | acc: {:.2f}  ( {}/{}) (up by {:.2f})'.format(
               test_loss / (batch_idx + 1), 100. * correct / total, correct, total,
               acc - checkpoint['acc']))
        # print('Saving..')
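        state = {
            # Clone the tensors: state_dict() returns references to the live parameters, so without
            # a copy the "best" checkpoint would keep changing as training continues
            'net': {k: v.detach().clone() for k, v in net.state_dict().items()},
            'acc': acc,
            'epoch': epoch
        }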
        checkpoint = state
        # if not os.path.isdir('checkpoint'):
        #     os.mkdir('checkpoint')
        # torch.save(state, './checkpoint/ckpt.pth')
    else:
        print('test  loss: {:.3f} | acc: {:.2f}  ( {}/{})'.format(
            test_loss / (batch_idx + 1), 100. * correct / total, correct, total))

loss_func = nn.CrossEntropyLoss()
opt = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)

for _ in range(200):
    # If there has been no new best test accuracy (or lr cut) for 20 epochs,
    # cut lr by 10x and restart from the best checkpoint
    if epoch - checkpoint['epoch'] >= 20:
        lr *= 0.1
        print('learning rate downgraded to {} at epoch {}'.format(lr, epoch))
        print('loading state_dict from Epoch {} (acc = {})'.format(checkpoint['epoch'], checkpoint['acc']))
        net.load_state_dict(checkpoint['net'])
        checkpoint['epoch'] = epoch
        history.append({'epoch': checkpoint['epoch'], 'acc': checkpoint['acc']})
        opt = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    train(loss_func, opt)
    test(loss_func)
    epoch += 1
print('finish at lr =', lr)

Epoch 0
0 2.621119260787964 [0.26603174209594727, 0.2705247402191162, 0.2586051821708679]
100 2.254976987838745 [0.42982226610183716, 0.3049733340740204, 0.2759949564933777]
200 2.0854451656341553 [0.3846457302570343, 0.2742742598056793, 0.25108227133750916]
300 1.8720414638519287 [0.35999801754951477, 0.24873992800712585, 0.22601842880249023]
train loss: 2.154 | acc: 21.532 (10766/50000)
test  loss: 2.099 | acc: 27.36  ( 2736/10000) (up by 27.36)
Epoch 1
0 1.7926530838012695 [0.3287651538848877, 0.050074175000190735, 0.05373675748705864]
100 1.8145718574523926 [0.3062690496444702, 0.04554499313235283, 0.05205224081873894]
200 1.7525367736816406 [0.27903982996940613, 0.04170724377036095, 0.04629070311784744]
300 1.5481927394866943 [0.25506502389907837, 0.03975250944495201, 0.04279080405831337]
train loss: 1.716 | acc: 36.148 (18074/50000)
test  loss: 1.649 | acc: 39.36  ( 3936/10000) (up by 12.00)
Epoch 2
0 1.6551792621612549 [0.23592537641525269, 0.03381640836596489, 0.035643130540847

train loss: 0.575 | acc: 80.270 (40135/50000)
test  loss: 0.775 | acc: 73.44  ( 7344/10000)
Epoch 18
0 0.6972066760063171 [0.08370932936668396, 0.0009309304878115654, 0.0012791709741577506]
100 0.5550433993339539 [0.07841374725103378, 0.0014296907465904951, 0.002110864967107773]
200 0.6731474995613098 [0.07262791693210602, 0.0014554635854437947, 0.0028041889891028404]
300 0.4247148633003235 [0.06934884935617447, 0.0020732718985527754, 0.0031631195452064276]
train loss: 0.569 | acc: 80.294 (40147/50000)
test  loss: 0.714 | acc: 76.77  ( 7677/10000)
Epoch 19
0 0.5979697704315186 [0.07526425272226334, 0.000888223119545728, 0.0014423412503674626]
100 0.4077020287513733 [0.08384042233228683, 0.0014802508521825075, 0.0022622165270149708]
200 0.40080755949020386 [0.07797499001026154, 0.001807714463211596, 0.002396781463176012]
300 0.5369938611984253 [0.07168252766132355, 0.002119452226907015, 0.002372283022850752]
train loss: 0.550 | acc: 80.882 (40441/50000)
test  loss: 0.699 | acc: 76.68  (

300 0.3880174160003662 [0.0838351845741272, 0.0030637714080512524, 0.00335335242561996]
train loss: 0.504 | acc: 82.714 (41357/50000)
test  loss: 0.842 | acc: 72.61  ( 7261/10000)
Epoch 36
0 0.5108032822608948 [0.09085855633020401, 0.0009114841232076287, 0.0013477579923346639]
100 0.4734669327735901 [0.08414360135793686, 0.001761490129865706, 0.002367902547121048]
200 0.46316665410995483 [0.08655106276273727, 0.0019181885290890932, 0.003252035938203335]
300 0.48596546053886414 [0.09223815053701401, 0.0021711874287575483, 0.0030762297101318836]
train loss: 0.505 | acc: 82.678 (41339/50000)
test  loss: 0.715 | acc: 76.60  ( 7660/10000)
Epoch 37
0 0.5684757828712463 [0.09389279782772064, 0.0009987927041947842, 0.0012832869542762637]
100 0.6188259720802307 [0.08870349824428558, 0.0015754939522594213, 0.0021625813096761703]
200 0.40221819281578064 [0.07853559404611588, 0.0017384688835591078, 0.0030544812325388193]
300 0.3984885811805725 [0.09229308366775513, 0.0020795317832380533, 0.0034130

0 0.15532353520393372 [0.0869075208902359, 0.0008254848653450608, 0.0008913796627894044]
100 0.18233615159988403 [0.08691118657588959, 0.0007987382705323398, 0.0009884374449029565]
200 0.33762088418006897 [0.08588048070669174, 0.0008053504279814661, 0.0009811898926272988]
300 0.19287925958633423 [0.08547412604093552, 0.000810656463727355, 0.0011303172213956714]
train loss: 0.228 | acc: 92.154 (46077/50000)
test  loss: 0.322 | acc: 89.24  ( 8924/10000)
Epoch 54
0 0.24385741353034973 [0.08596447110176086, 0.0007617035298608243, 0.0009427473996765912]
100 0.2659147083759308 [0.08669768273830414, 0.0007924020173959434, 0.0010589573066681623]
200 0.09842280298471451 [0.08726697415113449, 0.0007809321396052837, 0.0010691372444853187]
300 0.2494063824415207 [0.08785205334424973, 0.0008091972558759153, 0.0011279555037617683]
train loss: 0.222 | acc: 92.406 (46203/50000)
test  loss: 0.311 | acc: 89.40  ( 8940/10000)
Epoch 55
0 0.24449476599693298 [0.08817902952432632, 0.0007038518670015037, 0.0

In [204]:
history

[{'acc': 0, 'epoch': 0},
 {'epoch': 0,
  'train_loss': 1169.85959982872,
  'train_acc': 16.206,
  'loss': 202.5710642337799,
  'acc': 20.63},
 {'epoch': 1,
  'train_loss': 857.1845973730087,
  'train_acc': 27.856,
  'loss': 173.35470294952393,
  'acc': 35.48},
 {'epoch': 2,
  'train_loss': 737.6731110811234,
  'train_acc': 37.112,
  'loss': 156.45672821998596,
  'acc': 41.09},
 {'epoch': 3,
  'train_loss': 658.7332085371017,
  'train_acc': 43.226,
  'loss': 147.26791775226593,
  'acc': 46.12},
 {'epoch': 4,
  'train_loss': 579.6091915369034,
  'train_acc': 49.828,
  'loss': 133.8121293783188,
  'acc': 52.62},
 {'epoch': 5,
  'train_loss': 497.3395382165909,
  'train_acc': 56.332,
  'loss': 118.1007451415062,
  'acc': 58.48},
 {'epoch': 6,
  'train_loss': 436.36036986112595,
  'train_acc': 62.228,
  'loss': 117.68855303525925,
  'acc': 57.54},
 {'epoch': 7,
  'train_loss': 399.8831261396408,
  'train_acc': 65.47,
  'loss': 96.41954684257507,
  'acc': 66.25},
 {'epoch': 8,
  'train_loss'

In [0]:
ResNet18 = {'run0': [(0, 40.77), (1, 50.38), (2, 60.33), (3, 67.54), (4, 72.53), (5, 74.63), (8, 77.29), (10, 80.53), (14, 81.13), (16, 84.1), (20, 84.58), (29, 85.29), (42, 87.32), (50, 92.76), (51, 93.15), (52, 93.52), (53, 93.64), (54, 93.65), (58, 93.92), (75, 94.05), (76, 94.35), (82, 94.37), (85, 94.38), (86, 94.39), (92, 94.47)],
        'run1': [(0, 46.66), (1, 59.49), (2, 64.82), (3, 67.94), (4, 73.07), (5, 76.9), (7, 80.42), (10, 82.33), (12, 83.24), (15, 83.56), (22, 84.14), (31, 84.36), (32, 86.76), (36, 87.09), (50, 92.81), (51, 92.97), (52, 93.31), (53, 93.33), (54, 93.58), (55, 93.61), (58, 93.7), (60, 93.71), (62, 93.72), (64, 93.88), (75, 94.32), (76, 94.47), (77, 94.52), (78, 94.55), (79, 94.56), (80, 94.59), (81, 94.68), (83, 94.74), (85, 94.76), (87, 94.77), (88, 94.82), (99, 94.89)], 
        'run2': [(0, 53.2), (1, 59.44), (2, 74.87), (4, 77.51), (5, 80.48), (8, 81.88), (13, 82.32), (16, 82.91), (18, 82.99), (19, 83.11), (24, 83.8), (25, 85.36), (28, 86.4), (50, 92.47), (51, 93.07), (52, 93.63), (56, 93.71), (61, 93.77), (63, 93.94), (75, 94.26), (76, 94.41), (77, 94.43), (78, 94.67), (81, 94.81), (84, 94.84)]}

### Results on ResNet18: 
Following kuangliu, I manually changed the learning rate as follows:

* `lr=0.1` for Epoch [0:50] 
* `lr=0.01` for Epoch [50:75]
* `lr=0.001` for Epoch [75:100]
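
The manual schedule above could also be automated. The sketch below uses PyTorch's built-in `torch.optim.lr_scheduler.MultiStepLR` with milestones 50 and 75 and `gamma=0.1`, and assumes one `scheduler.step()` call per epoch. The single dummy parameter stands in for `net.parameters()`, and `lr_schedule` is a hypothetical helper that only records which learning rate each epoch would use.

```python
import torch

def lr_schedule(num_epochs=100):
    # Dummy parameter standing in for net.parameters()
    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=0.1, momentum=0.9, weight_decay=5e-4)
    # Multiply lr by 0.1 once 50 and then 75 epochs have completed
    scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[50, 75], gamma=0.1)
    lrs = []
    for epoch in range(num_epochs):
        lrs.append(opt.param_groups[0]['lr'])  # lr used during this epoch
        # train(...) and test(...) would go here
        opt.step()        # no gradients exist here, so this does nothing; it keeps the call order the scheduler expects
        scheduler.step()  # advance the schedule once per epoch
    return lrs

lrs = lr_schedule()
print(lrs[0], lrs[49], lrs[50], lrs[74], lrs[75], lrs[99])
```

Note that the training cell above instead lowers the learning rate after 20 epochs without improvement, so the two approaches can give different schedules.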

test acc (%) at epoch | 0 | 5 | 10 | 15 | 20 | 25 | 30 | 45 | 50 | 55 | 65 | 75 | 80 | 90 | 99
:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:
without twist ('run0') | 40.77 | 74.63 | 80.53 | 81.13 | 84.58 | - | 85.29 | 87.32 | 92.76 | 93.65 | - | 94.05 | 94.35 | 94.39 | 94.47
with twist | 42.39 | 74.86 | 81.51 | 82.88 | - | 84.34 | 86.98 | 87.85 | 92.34 | 93.65 | - | 93.73 | 94.22
with twist ('run1') | 46.66 | 76.90 | 82.33 | 83.56 | - | 84.14 | - | 87.09 | 92.81 | 93.61 | 93.88 | 94.32 | 94.59 | 94.82 | 94.89
with twist | 50.56 | 80.40 | 82.81 | 85.15 | - | - | - | 87.36 | 93.20 | 93.84 | 94.02
deep twist (3) | 41.55 | 72.37 | 78.81 | 79.49 | 82.38 | 82.38 | 83.14 | 83.36 
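
The `ResNet18` run lists above appear to record an `(epoch, acc)` pair only when the test accuracy improves. Under that reading, the best accuracy reached by a given epoch is the last recorded value at or before that epoch. The hypothetical helper `best_acc_by` below looks that value up, using the `run0` list copied from the cell above. It reproduces the numeric entries of the 'run0' row. Where the table shows '-', at epochs 25 and 65, the helper returns 84.58 and 93.92.

```python
from bisect import bisect_right

# run0 from the ResNet18 dict above: (epoch, test accuracy %) at each improvement
run0 = [(0, 40.77), (1, 50.38), (2, 60.33), (3, 67.54), (4, 72.53), (5, 74.63), (8, 77.29), (10, 80.53),
        (14, 81.13), (16, 84.1), (20, 84.58), (29, 85.29), (42, 87.32), (50, 92.76), (51, 93.15),
        (52, 93.52), (53, 93.64), (54, 93.65), (58, 93.92), (75, 94.05), (76, 94.35), (82, 94.37),
        (85, 94.38), (86, 94.39), (92, 94.47)]

def best_acc_by(run, epoch):
    """Best test accuracy reached at or before `epoch`, or None if nothing was recorded yet."""
    epochs = [e for e, _ in run]      # recorded epochs, in increasing order
    i = bisect_right(epochs, epoch)   # number of records with e <= epoch
    return run[i - 1][1] if i > 0 else None

table_epochs = [0, 5, 10, 15, 20, 25, 30, 45, 50, 55, 65, 75, 80, 90, 99]
print([best_acc_by(run0, e) for e in table_epochs])
```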

The improvement is not as significant as I initially thought, based on the accuracy reported at https://github.com/kuangliu/pytorch-cifar.