<a href="https://colab.research.google.com/github/liuyao12/pytorch-cifar/blob/master/cifar10_with_PDE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ResNet with a "twist"

* As far as I'm aware, this is a simple and novel ConvNet (Convolutional Neural Network) architecture that can be applied directly to any existing ResNet backbone.

* The key idea would be hard to arrive at, or to justify, without viewing ResNet as a partial differential equation (PDE) such as the heat equation. The standard machine-learning toolkit has traditionally covered multivariable calculus, linear algebra, and statistics, but not much PDE theory. This partly explains why ResNet came on the scene relatively late (2015), and why this enhanced version of ResNet has not been "reinvented" by the DL community.

* Code based on https://github.com/kuangliu/pytorch-cifar

* Questions and comments are greatly appreciated: [@liuyao12](https://twitter.com/liuyao12) or liuyao@gmail.com

A quick summary of ConvNets from a Partial Differential Equations (PDE) point of view. For details, see my [blog post on Observable](https://observablehq.com/@liuyao12/neural-networks-and-partial-differential-equations).

neural network | heat equation
:----:|:-------:
input layer | initial condition
feed forward | solving the equation
hidden layers | solution at intermediate times
output layer | solution at final time
convolution with 3×3 kernel | differential operator of order ≤ 2
weights | coefficients
boundary handling (padding) | boundary condition
multiple channels | system of (coupled) PDEs
e.g. 16×16×3×3 kernel | 16×16 matrix of differential operators
16×16×1×1 kernel | 16×16 matrix of constants
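The first rows of the table can be checked numerically. The sketch below assumes the standard 5-point Laplacian stencil, a fixed kernel chosen for illustration (the network learns its kernels instead). Convolving with it differentiates a quadratic image exactly, and one explicit Euler step of the heat equation turns out to be an identity shortcut plus a 3×3 convolution, the same shape as a residual block. The helper `euler_step` and the step size `dt` are hypothetical.

```python
import torch
import torch.nn.functional as F

n = 16
ys, xs = torch.meshgrid(torch.arange(n, dtype=torch.float64),
                        torch.arange(n, dtype=torch.float64), indexing="ij")
f = (xs ** 2 + ys ** 2).reshape(1, 1, n, n)   # (batch, channel, H, W), grid spacing 1

# 5-point discrete Laplacian: a 3x3 stencil for d^2/dx^2 + d^2/dy^2
laplacian = torch.tensor([[0.0, 1.0, 0.0],
                          [1.0, -4.0, 1.0],
                          [0.0, 1.0, 0.0]], dtype=torch.float64).reshape(1, 1, 3, 3)

lap_f = F.conv2d(f, laplacian)   # no padding: only the 14x14 interior is returned
print(lap_f.unique())            # all entries are 4, the exact Laplacian of x^2 + y^2

def euler_step(u, dt=0.1):
    """One explicit Euler step of u_t = Laplacian(u): identity shortcut + fixed 3x3 conv."""
    return u + dt * F.conv2d(u, laplacian, padding=1)

u = torch.zeros(1, 1, n, n, dtype=torch.float64)
u[0, 0, n // 2, n // 2] = 1.0    # a point source of heat
for _ in range(3):               # three steps, like the loop in PDE_Block.forward
    u = euler_step(u)
print(u.sum())                   # stays ~1: the stencil sums to 0, so no heat is created or lost
```

In the notebook's `PDE_Block.Euler_step`, the identity shortcut and the learned `self.conv` play the roles of `u` and `dt * Laplacian(u)` here.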


Basically, classical ConvNets (ResNets) are **linear PDEs with constant coefficients**, and here I'm simply trying to give them **variable coefficients**, namely polynomials of degree ≤ 1 in the spatial coordinates (x, y). In theory, this should let the neural net learn more ways to deform the input than diffusion and translation (e.g., rotation and scaling).
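Here is why degree-1 coefficients matter. A constant-coefficient first-order operator a∂x + b∂y generates a translation. If you instead weight the outputs of the same two kernels by the coordinate maps, you get the rotation generator −y∂x + x∂y, which a constant-coefficient operator cannot express. The sketch below assumes fixed central-difference kernels and the centered coordinates used for `XX`/`YY` in the notebook. The helper names `grad` and `rotation_generator` are hypothetical.

```python
import torch
import torch.nn.functional as F

h = w = 32
rows, cols = torch.meshgrid(torch.arange(h, dtype=torch.float64),
                            torch.arange(w, dtype=torch.float64), indexing="ij")
# Centered coordinate maps in [-0.5, 0.5), built like XX and YY in the notebook
X = cols / w - 0.5
Y = rows / h - 0.5

def grad(u):
    """Central-difference (du/dx, du/dy) on the interior, in units of X and Y."""
    kx = torch.tensor([[0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [0.0, 0.0, 0.0]], dtype=u.dtype) * (w / 2)
    ky = torch.tensor([[0.0, -1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=u.dtype) * (h / 2)
    u4 = u.reshape(1, 1, h, w)
    ux = F.conv2d(u4, kx.reshape(1, 1, 3, 3))[0, 0]
    uy = F.conv2d(u4, ky.reshape(1, 1, 3, 3))[0, 0]
    return ux, uy

def rotation_generator(u):
    """Apply -Y d/dx + X d/dy: fixed kernels whose outputs get degree-1 coefficients."""
    ux, uy = grad(u)
    Xi, Yi = X[1:-1, 1:-1], Y[1:-1, 1:-1]   # coordinates of the interior points
    return -Yi * ux + Xi * uy

radial = X ** 2 + Y ** 2
print(rotation_generator(radial).abs().max())                 # ~0: rotation leaves a radial image unchanged
print(torch.allclose(rotation_generator(X), -Y[1:-1, 1:-1]))  # d/dtheta of x is -y
```

In the notebook, the kernels are learned and the coordinate weighting is applied by the `mask1`/`mask2`/`XY` tensors.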

In [109]:
''' 
ResNet in PyTorch, forked from https://github.com/kuangliu/pytorch-cifar
Reference:
    Kaiming He 何恺明, Xiangyu Zhang 张祥雨, Shaoqing Ren 任少卿, Jian Sun 孙剑 (Microsoft Research Asia)
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, channels, stride=1, twist=False):
        super(BasicBlock, self).__init__()
        self.twist = twist
        self.mask1, self.mask2 = None, None
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * channels)
            )

    def forward(self, x):
        x1 = self.conv1(x)
        if self.twist:
            _, c, h, w = list(x1.shape)
            k = c // 3
            # antisymmetrize the x-kernel (K = -K rotated by 180°), forcing it to be a 1st-order differential operator, aka a vector field
            self.conv1.weight.data[:k] = (self.conv1.weight[:k] - self.conv1.weight[:k].flip(2).flip(3))/2
            # copy the x-kernel, rotated by 90°, to be the y-kernel
            self.conv1.weight.data[k:2*k] = self.conv1.weight[:k].transpose(2,3).flip(3)
            x1 = self.conv1(x)
            if self.mask1 is None:
                ones = np.ones((h,w), dtype='float32')
                XX = np.indices((h,w), dtype='float32')[1] / w - 0.5
                YY = np.indices((h,w), dtype='float32')[0] / h - 0.5
                mask = [XX] * k + [YY] * k + [ones] * (c - 2 * k)
                self.mask1 = torch.from_numpy(np.stack(mask, axis=0)).to(x.device)
            x1 = self.mask1 * x1
            # add the x- and y-branches onto the main branch
            x1[:,2*k:3*k] += x1[:,:k] + x1[:,k:2*k]
        x1 = F.relu(self.bn1(x1))
        
        x2 = self.conv2(x1)
        if self.twist:
            _, c, h, w = list(x2.shape)
            k = c // 3
            # antisymmetrize the x-kernel (K = -K rotated by 180°), forcing it to be a 1st-order differential operator, aka a vector field
            self.conv2.weight.data[:k] = (self.conv2.weight[:k] - self.conv2.weight[:k].flip(2).flip(3))/2
            # copy and rotate the x-kernel to be the y-kernel
            self.conv2.weight.data[k:2*k] = self.conv2.weight[:k].transpose(2,3).flip(3)
            x2 = self.conv2(x1)
            if self.mask2 is None:
                ones = np.ones((h,w), dtype='float32')
                XX = np.indices((h,w), dtype='float32')[1] / w - 0.5
                YY = np.indices((h,w), dtype='float32')[0] / h - 0.5
                mask = [XX] * k + [YY] * k + [ones] * (c - 2 * k)
                self.mask2 = torch.from_numpy(np.stack(mask, axis=0)).to(x.device)
            x2 = self.mask2 * x2
            x2[:,2*k:3*k] += x2[:,:k] + x2[:,k:2*k]
        x3 = self.shortcut(x) + self.bn2(x2)
        x3 = F.relu(x3)
        return x3

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, channels, stride=1, twist=True):
        super(Bottleneck, self).__init__()
        self.twist = twist
        self.channels = channels
        self.XY = None
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.conv3 = nn.Conv2d(channels, self.expansion * channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * channels)
            )

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = F.relu(self.bn1(x1))
        k = self.channels // 3
        if self.twist: 
            # antisymmetrize the kernels (forcing them to be 1st-order diff ops, i.e. vector fields)
            self.conv2.weight.data[:,:k] = (self.conv2.weight[:,:k] - self.conv2.weight[:,:k].flip(2).flip(3))/2
            # make y-vector perpendicular to x-vector
            self.conv2.weight.data[:,k:2*k] = self.conv2.weight[:,:k].transpose(2,3).flip(3)
        x2 = self.conv2(x1)
        if self.twist:
            if self.XY is None: # initialize self.XY
                _, c, h, w = tuple(x2.shape)
                ones = np.ones((h,w), dtype='float32')
                XX = np.indices((h,w), dtype='float32')[1] * 0.1 / w - 0.05
                YY = np.indices((h,w), dtype='float32')[0] * 0.1 / h - 0.05
                XY = [XX] * k + [YY] * k + [ones] * (c - 2 * k)
                self.XY = torch.from_numpy(np.stack(XY, axis=0)).to(x.device)
            x2 = self.XY * x2
            x2[:,2*k:3*k] += x2[:,:k] + x2[:,k:2*k]
        x3 = F.relu(self.bn2(x2))
        x4 = self.conv3(x3)
        x4 = self.bn3(x4)
        x4 += self.shortcut(x)
        x4 = F.relu(x4)
        return x4


class PDE_Block(nn.Module):
    expansion = 1

    def __init__(self, in_channels, channels, stride=1, twist=True):
        super(PDE_Block, self).__init__()
        self.twist = twist and in_channels == channels and stride == 1
        self.conv = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.XX, self.YY = None, None
        self.convx = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.convy = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)
        self.kernel_x = torch.tensor([[1,2,1],[0,0,0],[-1,-2,-1]]) / 2
        self.kernel_y = torch.tensor([[-1,0,1],[-2,0,2],[-1,0,1]]) / 2
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * channels)
            )

    def forward(self, x):
        if self.XX is None:
            _, _, h, w = tuple(x.shape)
            self.XX = torch.from_numpy(np.indices((1,1,h,w), dtype='float32')[3]/w-0.5).to(x.device)
            self.YY = torch.from_numpy(np.indices((1,1,h,w), dtype='float32')[2]/h-0.5).to(x.device)
            # self.conv.weight.data = self.orthogonal(self.conv.weight)
            # self.convx.weight.data = self.orthogonal(self.convx.weight)
        
        if self.twist:
            # antisymmetrize kernels
            # self.convx.weight.data = self.convx.weight[:,:,0:1,1:2] * self.kernel_y.to(x.device
            #                    ) +  self.convx.weight[:,:,1:2,0:1] * self.kernel_x.to(x.device)
            self.convx.weight.data = (self.convx.weight - self.convx.weight.flip(2).flip(3)) / 2
            # self.convy.weight.data = (self.convy.weight - self.convy.weight.flip(2).flip(3)) / 2
            self.convy.weight.data = self.convx.weight.transpose(2,3).flip(2)
            for _ in range(3):
                x = self.Euler_step(x)
        else:
            x = self.Euler_step(x)
        x = self.bn(x)
        x = F.relu(x)
        return x
    
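    def orthogonal(self, w):
        # a weight of the same shape whose rows (one per output channel) are orthonormal;
        # assumes out_channels <= in_channels * kernel_h * kernel_w
        shape = w.shape
        q, _ = torch.linalg.qr(w.reshape(shape[0], -1).t())
        return q.t().reshape(shape)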
    
    def Euler_step(self, x):
        x1 = self.conv(x)
        if self.twist:
            x1 += self.XX * self.convx(x) + self.YY * self.convy(x)
        x1 += self.shortcut(x)
        return x1

    
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        c0 = 32
        channel = [c0, 2 * c0, 4 * c0, 4 * c0]
        self.in_channels = c0
        self.num_blocks = num_blocks
        self.conv1 = nn.Conv2d(3, channel[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channel[0])
        self.layer1 = self._make_layer(block, channel[0], num_blocks[0], stride=1, twist=True)
        self.layer2 = self._make_layer(block, channel[1], num_blocks[1], stride=2, twist=True)
        self.layer3 = self._make_layer(block, channel[2], num_blocks[2], stride=2, twist=True)
        self.layer4 = self._make_layer(block, channel[3], num_blocks[3], stride=2, twist=False)
        self.linear = nn.Linear(channel[3] * block.expansion, num_classes)

    def _make_layer(self, block, channels, num_blocks, stride, twist=True):
        strides = [stride] + [1] * (num_blocks - 1)  # only the first block may downsample; num_blocks blocks in total
        layers = []
        for idx, stride in enumerate(strides):
            # twist = twist and idx < 3
            layers.append(block(self.in_channels, channels, stride, twist))
            self.in_channels = channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = F.avg_pool2d(x, 4)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x
    
    
def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])

def ResNet34():
    return ResNet(BasicBlock, [3,4,6,3])

def ResNet50():
    return ResNet(Bottleneck, [3,4,6,3])

def ResNet101():
    return ResNet(Bottleneck, [3,4,23,3])

def ResNet152():
    return ResNet(Bottleneck, [3,8,36,3])

net = ResNet18()
# net = ResNet(PDE_Block, [3,4,6,3])
epoch = 0 
lr = 0.1
checkpoint = {'acc': 0, 'epoch': 0}
history = [{'acc': 0, 'epoch': 0}]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('device =', device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True
net.to(device)
print('Testing on a random input:')
test = torch.randn(1,3,32,32).to(device)
print('INPUT ', test.shape)
print('OUTPUT', net(test).shape)

device = cuda
Testing on a random input:
INPUT  torch.Size([1, 3, 32, 32])
OUTPUT torch.Size([1, 10])
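The two weight edits in `BasicBlock.forward` can be checked in isolation. This small sketch uses random weights as a stand-in for the learned `conv.weight[:k]`. It also uses a hypothetical helper, `vector_field`, that reads off the first-order coefficients. After antisymmetrization, the kernel acts on a quadratic image exactly as a first-order operator v_x ∂x + v_y ∂y, because the constant and second-order parts cancel. The rotated copy carries the perpendicular vector field.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
k, c_in, n = 4, 3, 12
w = torch.randn(k, c_in, 3, 3, dtype=torch.float64)   # stand-in for conv.weight[:k]

# The two weight edits made in BasicBlock.forward when twist=True
w_x = (w - w.flip(2).flip(3)) / 2    # odd part: w_x[..., r, c] == -w_x[..., 2-r, 2-c]
w_y = w_x.transpose(2, 3).flip(3)    # the same stencil rotated by 90 degrees

offsets = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float64)

def vector_field(weight):
    """First-order coefficients (v_x, v_y) per output channel, summed over input channels."""
    wd = weight.sum(dim=1)                                 # (k, 3, 3)
    vx = (wd * offsets.reshape(1, 1, 3)).sum(dim=(1, 2))   # column offset -> d/dx
    vy = (wd * offsets.reshape(1, 3, 1)).sum(dim=(1, 2))   # row offset -> d/dy
    return vx, vy

# A quadratic test image p(x, y), with x = column index and y = row index
ys, xs = torch.meshgrid(torch.arange(n, dtype=torch.float64),
                        torch.arange(n, dtype=torch.float64), indexing="ij")
p = 0.3 * xs**2 - 0.5 * xs * ys + 0.2 * ys**2 + xs - 2 * ys + 1
px = 0.6 * xs - 0.5 * ys + 1         # exact dp/dx
py = -0.5 * xs + 0.4 * ys - 2        # exact dp/dy

inp = p.expand(1, c_in, n, n).contiguous()   # same image in every input channel
out = F.conv2d(inp, w_x)[0]                  # (k, n-2, n-2): interior points only
vx, vy = vector_field(w_x)
first_order = vx.reshape(k, 1, 1) * px[1:-1, 1:-1] + vy.reshape(k, 1, 1) * py[1:-1, 1:-1]
print(torch.allclose(out, first_order))      # no 0th- or 2nd-order part survives

vx_y, vy_y = vector_field(w_y)
print(torch.allclose(vx_y, -vy), torch.allclose(vy_y, vx))   # (v_x, v_y) rotated to (-v_y, v_x)
```

On higher-degree images, an odd 3×3 stencil also picks up third-order terms. So "1st-order" is exact only up to quadratics, which is the usual sense of a finite-difference first derivative.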


In [82]:
import torchvision
import torchvision.transforms as transforms

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomAffine(degrees=10, translate=(0.1,0.1), scale=(1.1,1.2)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Training
def train(loss_func, opt):
    global history
    print('Epoch {} (lr = {:.6f})'.format(epoch, lr))
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (x, y) in enumerate(trainloader):
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        pred = net(x)
        loss = loss_func(pred, y)
        loss.backward()
        opt.step()
        train_loss += loss.item()
        _, predicted = pred.max(1)
        total += y.size(0)
        correct += predicted.eq(y).sum().item()
    print('train loss: {:.3f} | acc: {:.3f} ({}/{})'.format(
        train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    history.append({'epoch': epoch, 'train_loss': train_loss, 'train_acc': 100. * correct / total})
    
def test(loss_func):
    global checkpoint, history
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(testloader):
            x, y = x.to(device), y.to(device)
            pred = net(x)
            loss = loss_func(pred, y)
            test_loss += loss.item()
            _, predicted = pred.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()
    acc = 100. * correct / total
    history[-1]['loss'] = test_loss
    history[-1]['acc'] = acc
    if acc > checkpoint['acc']:
        print('test  loss: {:.3f} | acc: {:.2f}  ( {}/{}) (up by {:.2f})'.format(
               test_loss / (batch_idx + 1), 100. * correct / total, correct, total,
               acc - checkpoint['acc']))
        # print('Saving..')
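        state = {
            # clone the tensors: state_dict() returns references that keep changing as training continues
            'net': {name: t.detach().clone() for name, t in net.state_dict().items()},
            'acc': acc,
            'epoch': epoch
        }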
        checkpoint = state
        # if not os.path.isdir('checkpoint'):
        #     os.mkdir('checkpoint')
        # torch.save(state, './checkpoint/ckpt.pth')
    else:
        print('test  loss: {:.3f} | acc: {:.2f}  ( {}/{})'.format(
            test_loss / (batch_idx + 1), 100. * correct / total, correct, total))

loss_func = nn.CrossEntropyLoss()
opt = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)  # 5e-4

for _ in range(300):
    if epoch < 15:
        lr = (epoch + 1) / 150
        for param_group in opt.param_groups:
            param_group['lr'] = lr
    if epoch - checkpoint['epoch'] >= 20:
        if lr < 0.00011 or history[-1].get('train_acc', 0) > 99.9:
            break
        lr *= 0.1
        print('learning rate downgraded to {} at epoch {}'.format(lr, epoch))
        print('loading state_dict from Epoch {} (acc = {})'.format(checkpoint['epoch'], checkpoint['acc']))
        net.load_state_dict(checkpoint['net'])
        checkpoint['epoch'] = epoch
        history.append({'epoch': checkpoint['epoch'], 'acc': checkpoint['acc']})
        for param_group in opt.param_groups:
            param_group['lr'] = lr
    train(loss_func, opt)
    test(loss_func)
    epoch += 1
print('finish at lr = {}, acc = {}'.format(lr, checkpoint['acc']))

Epoch 0 (lr = 0.006667)
train loss: 1.740 | acc: 34.704 (17352/50000)
test  loss: 1.448 | acc: 46.79  ( 4679/10000) (up by 46.79)
Epoch 1 (lr = 0.013333)
train loss: 1.422 | acc: 47.890 (23945/50000)
test  loss: 1.321 | acc: 53.34  ( 5334/10000) (up by 6.55)
Epoch 2 (lr = 0.020000)
train loss: 1.235 | acc: 55.408 (27704/50000)
test  loss: 1.116 | acc: 60.05  ( 6005/10000) (up by 6.71)
Epoch 3 (lr = 0.026667)
train loss: 1.100 | acc: 60.772 (30386/50000)
test  loss: 0.925 | acc: 67.35  ( 6735/10000) (up by 7.30)
Epoch 4 (lr = 0.033333)
train loss: 1.004 | acc: 64.468 (32234/50000)
test  loss: 0.955 | acc: 67.07  ( 6707/10000)
Epoch 5 (lr = 0.040000)


### Results

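Test accuracy on CIFAR-10. `c0` is the base channel width, and "twist per stage" lists, for each of the four stages, whether the twist is enabled.

setup | test acc. (%) | block | `c0` | blocks per stage | twist per stage
:----|:----:|:----:|:----:|:----:|:----:
baseline (classic ResNet50) | 94.29 | Bottleneck | 32 | [3,4,6,3] | [F,F,F,F]
with "twist" | 94.15 | Bottleneck | 32 | [3,4,6,3] | [T,T,T,T]
with "twist" | 94.59 | Bottleneck | 32 | [3,4,6,3] | [T,T,T,F]
with "twist" | 94.84 | Bottleneck | 32 | [8,8,8,3] | [T,T,T,F]
with "twist" | 94.22 | Bottleneck | 32 | [8,8,16,3] | [T,T,T,F]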

All runs use lr = [0.1, 0.01, 0.001], dividing the learning rate by 10 whenever test accuracy has plateaued for 20 epochs.
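The training loop implements this plateau rule, and it can be replayed without a GPU. In this minimal sketch, the hypothetical helper `replay_schedule` copies the loop's 15-epoch warmup, its 20-epoch patience, and its stopping test. It is driven by made-up accuracies, not real results. It ignores the early stop on train accuracy, and it cannot reload weights.

```python
def replay_schedule(test_accs):
    """Learning rate used at each epoch of the notebook's training loop.

    `test_accs[e]` is the test accuracy observed at epoch e.
    """
    lr, best_acc, best_epoch = 0.1, 0.0, 0
    lrs = []
    for epoch, acc in enumerate(test_accs):
        if epoch < 15:
            lr = (epoch + 1) / 150            # linear warmup, reaching 0.1 at epoch 14
        if epoch - best_epoch >= 20:          # no improvement for 20 epochs
            if lr < 0.00011:
                break                         # already at the smallest rate: stop
            lr *= 0.1                         # the notebook also reloads the best weights here
            best_epoch = epoch                # restart the 20-epoch patience window
        lrs.append(lr)
        if acc > best_acc:                    # test() records a new best checkpoint
            best_acc, best_epoch = acc, epoch
    return lrs

# Synthetic accuracies (not real results): one good epoch, then no further improvement
lrs = replay_schedule([50.0] + [40.0] * 200)
print(len(lrs), lrs[14], lrs[20], lrs[40], lrs[60])
```

With these inputs, the rate warms up to 0.1 and then drops to 0.01, 0.001, and 0.0001 at epochs 20, 40, and 60. The loop stops at epoch 80. Because the stopping check is `lr < 0.00011`, a run that keeps plateauing at 0.001 is reduced once more before training ends.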

