<a href="https://colab.research.google.com/github/liuyao12/pytorch-cifar/blob/master/cifar10_with_PDE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ResNet with a "twist"

* As far as I'm aware, this is a simple and novel ConvNet (Convolutional Neural Network) architecture that can be applied to any existing ResNet backbone.

* The key idea would be hard to arrive at, or to justify, without viewing ResNet as a partial differential equation (like the heat equation). The standard machine-learning toolkit includes the basics of multivariable calculus, linear algebra, and statistics, but not much in the way of PDEs. This partly explains why ResNet came on the scene relatively late (2015), and why this enhanced version of ResNet has not been "reinvented" by the DL community.

* Code based on https://github.com/kuangliu/pytorch-cifar

* Questions and comments are greatly appreciated: [@liuyao12](https://twitter.com/liuyao12) or liuyao@gmail.com

A quick summary of ConvNets from a Partial Differential Equations (PDE) point of view. For details, see my [blog post on Observable](https://observablehq.com/@liuyao12/neural-networks-and-partial-differential-equations).

neural network | heat equation
:----:|:-------:
input layer | initial condition
feed forward | solving the equation
hidden layers | solution at intermediate times
output layer | solution at final time
convolution with 3×3 kernel | differential operator of order ≤ 2
weights | coefficients
boundary handling (padding) | boundary condition
multiple channels | system of (coupled) PDEs
e.g. 16×16×3×3 kernel | 16×16 matrix of differential operators
16×16×1×1 kernel | 16×16 matrix of constants
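To make the first rows of the table concrete, here is a small sketch (not part of the original notebook) built on `torch.nn.functional.conv2d`. The 5-point stencil `lap` is a 3×3 kernel approximating the Laplacian ∂²/∂x² + ∂²/∂y², a differential operator of order 2. Repeatedly applying `u + dt * conv(u)` solves the heat equation by forward Euler, which has exactly the shape of a residual block (identity shortcut plus a convolution). The grid size `n`, the step `dt = 0.2` (below the explicit-scheme stability limit of 0.25 for this stencil) and the helper `conv3x3` are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

n = 9
idx = torch.arange(n, dtype=torch.float32) - (n - 1) / 2
Y, X = torch.meshgrid(idx, idx, indexing="ij")  # X varies along columns, Y along rows

def conv3x3(kernel, u):
    """Convolve a single-channel n×n field with a 3×3 kernel (zero padding, as in the notebook)."""
    return F.conv2d(u.reshape(1, 1, n, n), kernel.reshape(1, 1, 3, 3), padding=1)[0, 0]

# 5-point Laplacian: a 3×3 kernel approximating the 2nd-order operator ∂xx + ∂yy
lap = torch.tensor([[0., 1., 0.],
                    [1., -4., 1.],
                    [0., 1., 0.]])

# skip the border, where zero padding acts as a boundary condition
inner = (slice(1, -1), slice(1, -1))
print(conv3x3(lap, X**2 + Y**2)[inner].unique())  # the Laplacian of x² + y² is 4

# Heat equation u_t = Δu by forward Euler: u ← u + dt·Δu.
# Each step has the residual form x + conv(x) of a ResNet block.
dt = 0.2
u = torch.zeros(n, n)
u[n // 2, n // 2] = 1.0  # initial condition: a point source
for _ in range(3):
    u = u + dt * conv3x3(lap, u)
print(u.sum().item(), u.max().item())  # total heat stays (numerically) 1 while the peak spreads out
```

The total is conserved here only because the heat has not yet reached the border; once it does, the zero padding lets it leak out, which is the "boundary condition" row of the table.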


Basically, classical ConvNets (ResNets) are **linear PDEs with constant coefficients**, and here I'm simply trying to make them **variable-coefficient**, with the coefficients being polynomials of degree ≤ 1 in the spatial coordinates x and y. In theory, this should enable the network to learn more ways to deform the input than diffusion and translation (e.g., rotation and scaling).
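The claim that degree-1 coefficients add rotation and scaling can be checked numerically. The sketch below is an illustration, not the notebook's training code: it takes an odd (antisymmetric) 3×3 kernel, derives the y-kernel with the same `transpose(2,3).flip(3)` rotation used in `BasicBlock`, and forms X·(kx ⋆ u) + Y·(ky ⋆ u). Starting from a central difference in x gives the scaling field x∂x + y∂y; starting from a central difference in y gives the rotation field x∂y − y∂x. The names `odd_part`, `rotate90` and `twisted` are hypothetical, and the coordinates are integer offsets from the grid center (with y pointing down the rows) rather than the notebook's [−0.5, 0.5) range, which only rescales the coefficients.

```python
import torch
import torch.nn.functional as F

n = 9
idx = torch.arange(n, dtype=torch.float32) - (n - 1) / 2
Y, X = torch.meshgrid(idx, idx, indexing="ij")  # X varies along columns, Y along rows
inner = (slice(1, -1), slice(1, -1))            # skip the zero-padded border

def odd_part(k):
    # keep only the part of k that changes sign under a 180° rotation (a 1st-order operator)
    return (k - k.flip(-2).flip(-1)) / 2

def rotate90(k):
    # same index manipulation the notebook uses to derive the y-kernel from the x-kernel
    return k.transpose(-2, -1).flip(-1)

def twisted(u, kx):
    """X * (kx ⋆ u) + Y * (ky ⋆ u), with ky = kx rotated by 90° (cf. XX/YY in BasicBlock)."""
    kx = odd_part(kx)
    ky = rotate90(kx)
    u4 = u.reshape(1, 1, n, n)
    gx = F.conv2d(u4, kx.reshape(1, 1, 3, 3), padding=1)[0, 0]
    gy = F.conv2d(u4, ky.reshape(1, 1, 3, 3), padding=1)[0, 0]
    return X * gx + Y * gy

# central differences: dx approximates ∂/∂x, dy approximates ∂/∂y
dx = torch.tensor([[0., 0., 0.],
                   [-0.5, 0., 0.5],
                   [0., 0., 0.]])
dy = dx.t()

r2 = X**2 + Y**2
scaling = twisted(r2, dx)   # x∂x + y∂y, the generator of scaling
rotation = twisted(r2, dy)  # x∂y − y∂x, the generator of rotation
print(torch.allclose(scaling[inner], 2 * r2[inner]))               # x² + y² is homogeneous of degree 2
print(torch.allclose(rotation[inner], torch.zeros(n - 2, n - 2)))  # x² + y² is rotation-invariant
```

A constant-coefficient kernel can only produce a fixed direction of flow everywhere (translation) or diffusion; it is the multiplication by X and Y after the convolution that lets the same learned kernel act differently at different positions.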

In [306]:
''' 
ResNet in PyTorch, forked from https://github.com/kuangliu/pytorch-cifar
Reference:
    Kaiming He 何恺明, Xiangyu Zhang 张祥雨, Shaoqing Ren 任少卿, Jian Sun 孙剑 (Microsoft Research Asia)
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, channels, stride=1, twist=False):
        super(BasicBlock, self).__init__()
        self.twist = twist
        self.XX, self.YY = None, None
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv1x = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv1y = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2x = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2y = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * channels)
            )

    def forward(self, x):
        x = F.relu(self.bn1(x))
        x1 = self.conv1(x)
        if self.twist:
            _, c, h, w = list(x1.shape)
            # anti-symmetrize the x-kernel, keeping its odd part (forcing it to be a 1st-order differential operator, aka a vector field)
            self.conv1x.weight.data = (self.conv1x.weight - self.conv1x.weight.flip(2).flip(3)) / 2
            # the y-kernel is the x-kernel rotated by 90°
            self.conv1y.weight.data = self.conv1x.weight.transpose(2,3).flip(3)
            if self.XX is None:
                self.XX = torch.from_numpy(np.indices((h,w), dtype='float32')[1] / w - 0.5).to(x.device)
                self.YY = torch.from_numpy(np.indices((h,w), dtype='float32')[0] / h - 0.5).to(x.device)
#                 self.conv1x.weight.data = self.conv1x.weight / 100
#                 self.conv1y.weight.data = self.conv1y.weight / 100
#                 self.conv2x.weight.data = self.conv2x.weight / 100
#                 self.conv2y.weight.data = self.conv2y.weight / 100
                # print("twist initialized, self.XX", self.XX.shape, self.XX.mean().item())
            x1 = self.conv1(x) + self.XX * self.conv1x(x) + self.YY * self.conv1y(x)
            # print("twist initialized, outside self.XX", self.XX.shape, self.XX.mean().item())
        
        out = F.relu(self.bn2(x1))
        if self.twist:
            # anti-symmetrize the x-kernel, keeping its odd part (forcing it to be a 1st-order differential operator, aka a vector field)
            self.conv2x.weight.data = (self.conv2x.weight - self.conv2x.weight.flip(2).flip(3)) / 2
            # the y-kernel is the x-kernel rotated by 90°
            self.conv2y.weight.data = self.conv2x.weight.transpose(2,3).flip(3)
            # feed the BN+ReLU output (not the raw x1) into the second convolution
            x2 = self.conv2(out) + self.XX * self.conv2x(out) + self.YY * self.conv2y(out)
        else:
            x2 = self.conv2(out)
        x2 += self.shortcut(x)
        return x2

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, channels, stride=1, twist=True):
        super(Bottleneck, self).__init__()
        self.twist = twist
        self.channels = channels
        self.XY = None
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.conv3 = nn.Conv2d(channels, self.expansion * channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * channels)
            )

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = F.relu(self.bn1(x1))
        k = self.channels // 3
        if self.twist: 
            # anti-symmetrize the first k kernels, keeping their odd part (forces a 1st-order diff op, i.e. a vector field)
            self.conv2.weight.data[:,:k] = (self.conv2.weight[:,:k] - self.conv2.weight[:,:k].flip(2).flip(3)) / 2
            # the next k kernels are the first k rotated by 90° (the perpendicular vector field)
            self.conv2.weight.data[:,k:2*k] = self.conv2.weight[:,:k].transpose(2,3).flip(3)
        x2 = self.conv2(x1)
        if self.twist:
            if self.XY is None: # initialize self.XY
                _, c, h, w = tuple(x2.shape)
                ones = np.ones((h,w), dtype='float32')
                XX = np.indices((h,w), dtype='float32')[1] * 0.1 / w - 0.05
                YY = np.indices((h,w), dtype='float32')[0] * 0.1 / h - 0.05
                XY = [XX] * k + [YY] * k + [ones] * (c - 2 * k)
                self.XY = torch.from_numpy(np.stack(XY, axis=0)).to(x.device)
            x2 = self.XY * x2
            x2[:,2*k:3*k] += x2[:,:k] + x2[:,k:2*k]
        x3 = F.relu(self.bn2(x2))
        x4 = self.conv3(x3)
        x4 = self.bn3(x4)
        x4 += self.shortcut(x)
        x4 = F.relu(x4)
        return x4


class PDEBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, channels, stride=1, twist=True):
        super(PDEBlock, self).__init__()
        self.twist = twist
        self.freeze = True
        self.match = in_channels == channels and stride == 1
        self.conv = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.XX, self.YY = None, None
        self.convx = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.convy = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2x = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2y = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * channels)
            )

    def forward(self, x):
        if self.XX is None:
            _, _, h, w = tuple(x.shape)
            self.XX = torch.from_numpy(np.indices((1,1,h,w), dtype='float32')[3]/w-0.5).to(x.device)
            self.YY = torch.from_numpy(np.indices((1,1,h,w), dtype='float32')[2]/h-0.5).to(x.device)
            # self.conv.weight.data = self.orthogonal(self.conv.weight)
            # self.convx.weight.data = self.orthogonal(self.convx.weight)
        
        if self.twist and self.match:
            # anti-symmetrize the x-kernel (keep its odd part)
            self.convx.weight.data = (self.convx.weight - self.convx.weight.flip(2).flip(3)) / 2
            # self.convy.weight.data = (self.convy.weight - self.convy.weight.flip(2).flip(3)) / 2
            # the y-kernel is the x-kernel rotated by 90°
            self.convy.weight.data = self.convx.weight.transpose(2,3).flip(2)
            if self.freeze:
                # share the first Euler step's weights with the second step
                self.conv2.weight.data = self.conv.weight.data
                self.conv2x.weight.data = self.convx.weight.data
                self.conv2y.weight.data = self.convy.weight.data
            for i in range(2):
                x = self.Euler_step(x, i)
        else:
            x = self.Euler_step(x)
        x = self.bn(x)
        x = F.relu(x)
        return x
    
    def orthogonal(self, x):
        shape = x.shape
        q, r = torch.linalg.qr(x.reshape(shape[0], -1))
        return r.view(shape).to(x.device)
    
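    def Euler_step(self, x, i=0):
        # one forward-Euler step x + F(x): step 0 uses conv/convx/convy, step 1 uses conv2/conv2x/conv2y
        if i < 1:
            x1 = self.conv(x)  # in_channels -> channels, with the block's stride
            if self.twist and self.match:
                x1 += self.XX * self.convx(x) + self.YY * self.convy(x)
        else:
            x1 = self.conv2(x)  # channels -> channels; only reached when self.match is True
            if self.twist and self.match:
                x1 += self.XX * self.conv2x(x) + self.YY * self.conv2y(x)
        x1 += self.shortcut(x)
        return x1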

    
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 64
        channels = [self.in_channels * i for i in [1, 2, 4, 8]]
        self.num_blocks = num_blocks
        self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels[0])
        self.layer1 = self._make_layer(block, channels[0], num_blocks[0], stride=1, twist=False)
        self.layer2 = self._make_layer(block, channels[1], num_blocks[1], stride=2, twist=False)
        self.layer3 = self._make_layer(block, channels[2], num_blocks[2], stride=2, twist=False)
        self.layer4 = self._make_layer(block, channels[3], num_blocks[3], stride=2, twist=False)
        self.linear = nn.Linear(channels[3] * block.expansion, num_classes)

    def _make_layer(self, block, channels, num_blocks, stride, twist=False):
        strides = [stride] + [1] * num_blocks  # yields num_blocks + 1 blocks per stage (standard ResNet: [stride] + [1] * (num_blocks - 1))
        layers = []
        for idx, stride in enumerate(strides):
            # twist = twist and idx < 3
            layers.append(block(self.in_channels, channels, stride, twist))
            self.in_channels = channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        # x = self.conv1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = F.avg_pool2d(x, 4)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x
    
    
def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])

def ResNet34():
    return ResNet(BasicBlock, [3,4,6,3])

def ResNet50():
    return ResNet(Bottleneck, [3,4,6,3])

def ResNet101():
    return ResNet(Bottleneck, [3,4,23,3])

def ResNet152():
    return ResNet(Bottleneck, [3,8,36,3])

# net = ResNet34()
net = ResNet(PDEBlock, [1,3,3,1])
epoch = 0 
lr = 0.1
checkpoint = {'acc': 0, 'epoch': 0}
history = [{'acc': 0, 'epoch': 0}]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('device =', device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True
net.to(device)
print('Testing on a random input:')
test = torch.randn(1,3,32,32).to(device)
print('INPUT ', test.shape)
print('OUTPUT', net(test).shape)

device = cuda
Testing on a random input:
INPUT  torch.Size([1, 3, 32, 32])
OUTPUT torch.Size([1, 10])
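For readers who want to drop the "twist" into another backbone, here is a compact standalone sketch of the twisted 3×3 convolution used in `BasicBlock`, y = conv(x) + XX·conv_x(x) + YY·conv_y(x). It is an alternative formulation, not the notebook's code: `TwistConv2d` and `weight_x` are hypothetical names. In the notebook, `conv1y` is overwritten from `conv1x` at every forward pass, so the gradient computed for `conv1y` is discarded. Here only the x-kernel is a parameter, and the odd-part and 90° rotation constraints are applied as a differentiable reparametrization inside `forward`, so the gradient respects them.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwistConv2d(nn.Module):
    """Hypothetical standalone 'twist' layer: conv(x) + XX * conv_x(x) + YY * conv_y(x),
    where conv_x is constrained to be odd (a 1st-order operator) and conv_y is conv_x
    rotated by 90 degrees, mirroring BasicBlock in the notebook."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
        # only the x-kernel is free; zero init makes the layer start as a plain 3x3 conv
        self.weight_x = nn.Parameter(torch.zeros(out_channels, in_channels, 3, 3))

    def forward(self, x):
        wx = (self.weight_x - self.weight_x.flip(2).flip(3)) / 2  # odd part
        wy = wx.transpose(2, 3).flip(3)                           # rotate by 90 degrees
        gx = F.conv2d(x, wx, stride=self.stride, padding=1)
        gy = F.conv2d(x, wy, stride=self.stride, padding=1)
        h, w = gx.shape[-2:]
        # coordinates in [-0.5, 0.5), like XX/YY in the notebook; rebuilt on every call
        ys = torch.arange(h, device=x.device, dtype=x.dtype) / h - 0.5
        xs = torch.arange(w, device=x.device, dtype=x.dtype) / w - 0.5
        YY, XX = torch.meshgrid(ys, xs, indexing="ij")
        return self.conv(x) + XX * gx + YY * gy

layer = TwistConv2d(16, 32, stride=2)
out = layer(torch.randn(8, 16, 32, 32))
print(out.shape)  # torch.Size([8, 32, 16, 16])
```

Because the coordinate grids are recomputed from the output size, the layer does not cache `XX`/`YY` and works for any input resolution; the cost is a small amount of extra work per call.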


In [211]:
import torchvision
import torchvision.transforms as transforms

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomAffine(degrees=10, translate=(0.1,0.1), scale=(1,1.2)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [307]:
# Training
def train(loss_func, opt):
    global history
    print('Epoch {} (lr={:.4f})'.format(epoch, lr))
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (x, y) in enumerate(trainloader):
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        pred = net(x)
        loss = loss_func(pred, y)
        loss.backward()
        opt.step()
        train_loss += loss.item()
        _, predicted = pred.max(1)
        total += y.size(0)
        correct += predicted.eq(y).sum().item()
    print('train loss: {:.3f} | acc: {:.3f} ({}/{})'.format(
        train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    history.append({'epoch': epoch, 'train_loss': train_loss, 'train_acc': 100. * correct / total})
    
def test(loss_func):
    global checkpoint, history
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(testloader):
            x, y = x.to(device), y.to(device)
            pred = net(x)
            loss = loss_func(pred, y)
            test_loss += loss.item()
            _, predicted = pred.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()
    acc = 100. * correct / total
    history[-1]['loss'] = test_loss
    history[-1]['acc'] = acc
    if acc > checkpoint['acc']:
        print('test  loss: {:.3f} | acc: {:.2f}  ( {}/{}) (up by {:.2f})'.format(
               test_loss / (batch_idx + 1), 100. * correct / total, correct, total,
               acc - checkpoint['acc']))
        # print('Saving..')
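        state = {
            # clone: state_dict() returns live references that keep changing as training continues
            'net': {k: v.detach().clone() for k, v in net.state_dict().items()},
            'acc': acc,
            'epoch': epoch
        }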
        checkpoint = state
        # if not os.path.isdir('checkpoint'):
        #     os.mkdir('checkpoint')
        # torch.save(state, './checkpoint/ckpt.pth')
    else:
        print('test  loss: {:.3f} | acc: {:.2f}  ( {}/{})'.format(
            test_loss / (batch_idx + 1), 100. * correct / total, correct, total))

loss_func = nn.CrossEntropyLoss()
opt = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)  # 5e-4

for _ in range(300):
    if epoch == 0:
        m = net.module if isinstance(net, nn.DataParallel) else net
        for layer in [m.layer1, m.layer2, m.layer3, m.layer4]:
            for i in range(len(layer)):
                layer[i].twist = True
        print("twist on")
#     if epoch < 10:
#         lr = (epoch + 1) / 100
#         for param_group in opt.param_groups:
#             param_group['lr'] = lr
#     elif epoch >= 10 and epoch < 25:
#         lr = 0.1 - (epoch - 10) / 200
#         for param_group in opt.param_groups:
#             param_group['lr'] = lr
    elif epoch - checkpoint['epoch'] >= 20:
        if lr < 0.00011 or history[-1].get('train_acc', 0) > 99.9:
            break
        lr *= 0.1
        print('\nlearning rate downgraded to {} at epoch {}'.format(lr, epoch))
        print('loading state_dict from Epoch {} (acc = {})'.format(checkpoint['epoch'], checkpoint['acc']))
        net.load_state_dict(checkpoint['net'])
        checkpoint['epoch'] = epoch
        history.append({'epoch': checkpoint['epoch'], 'acc': checkpoint['acc']})
        opt = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
        
        
    train(loss_func, opt)
    test(loss_func)
    epoch += 1
print('finish at lr = {}, acc = {}'.format(lr, checkpoint['acc']))

twist on
Epoch 0 (lr=0.1000)
train loss: 1.804 | acc: 31.822 (15911/50000)
test  loss: 1.514 | acc: 45.05  ( 4505/10000) (up by 45.05)
Epoch 1 (lr=0.1000)
train loss: 1.454 | acc: 46.852 (23426/50000)
test  loss: 1.853 | acc: 37.93  ( 3793/10000)
Epoch 2 (lr=0.1000)
train loss: 1.279 | acc: 53.648 (26824/50000)
test  loss: 1.336 | acc: 51.81  ( 5181/10000) (up by 6.76)
Epoch 3 (lr=0.1000)
train loss: 1.159 | acc: 58.392 (29196/50000)
test  loss: 1.151 | acc: 59.67  ( 5967/10000) (up by 7.86)
Epoch 4 (lr=0.1000)
train loss: 1.069 | acc: 61.946 (30973/50000)
test  loss: 0.981 | acc: 65.49  ( 6549/10000) (up by 5.82)
Epoch 5 (lr=0.1000)
train loss: 1.000 | acc: 64.840 (32420/50000)
test  loss: 0.976 | acc: 65.73  ( 6573/10000) (up by 0.24)
Epoch 6 (lr=0.1000)
train loss: 0.954 | acc: 66.584 (33292/50000)
test  loss: 0.976 | acc: 67.14  ( 6714/10000) (up by 1.41)
Epoch 7 (lr=0.1000)
train loss: 0.916 | acc: 68.028 (34014/50000)
test  loss: 0.980 | acc: 67.48  ( 6748/10000) (up by 0.34)
Epo

train loss: 0.381 | acc: 87.292 (43646/50000)
test  loss: 0.302 | acc: 89.75  ( 8975/10000) (up by 7.82)
Epoch 72 (lr=0.0100)
train loss: 0.326 | acc: 88.884 (44442/50000)
test  loss: 0.279 | acc: 90.42  ( 9042/10000) (up by 0.67)
Epoch 73 (lr=0.0100)
train loss: 0.302 | acc: 89.826 (44913/50000)
test  loss: 0.271 | acc: 90.78  ( 9078/10000) (up by 0.36)
Epoch 74 (lr=0.0100)
train loss: 0.283 | acc: 90.348 (45174/50000)
test  loss: 0.272 | acc: 90.66  ( 9066/10000)
Epoch 75 (lr=0.0100)
train loss: 0.274 | acc: 90.688 (45344/50000)
test  loss: 0.255 | acc: 91.28  ( 9128/10000) (up by 0.50)
Epoch 76 (lr=0.0100)
train loss: 0.268 | acc: 90.908 (45454/50000)
test  loss: 0.262 | acc: 91.17  ( 9117/10000)
Epoch 77 (lr=0.0100)
train loss: 0.258 | acc: 91.150 (45575/50000)
test  loss: 0.265 | acc: 91.19  ( 9119/10000)
Epoch 78 (lr=0.0100)
train loss: 0.250 | acc: 91.580 (45790/50000)
test  loss: 0.258 | acc: 91.40  ( 9140/10000) (up by 0.12)
Epoch 79 (lr=0.0100)
train loss: 0.245 | acc: 91.736

train loss: 0.060 | acc: 98.022 (49011/50000)
test  loss: 0.221 | acc: 93.99  ( 9399/10000)
Epoch 141 (lr=0.0010)
train loss: 0.058 | acc: 98.050 (49025/50000)
test  loss: 0.219 | acc: 93.79  ( 9379/10000)
Epoch 142 (lr=0.0010)
train loss: 0.061 | acc: 97.966 (48983/50000)
test  loss: 0.234 | acc: 93.49  ( 9349/10000)
Epoch 143 (lr=0.0010)
train loss: 0.056 | acc: 98.214 (49107/50000)
test  loss: 0.221 | acc: 93.76  ( 9376/10000)
Epoch 144 (lr=0.0010)
train loss: 0.056 | acc: 98.138 (49069/50000)
test  loss: 0.226 | acc: 93.68  ( 9368/10000)
Epoch 145 (lr=0.0010)
train loss: 0.055 | acc: 98.242 (49121/50000)
test  loss: 0.223 | acc: 93.81  ( 9381/10000)
Epoch 146 (lr=0.0010)
train loss: 0.056 | acc: 98.158 (49079/50000)
test  loss: 0.227 | acc: 93.75  ( 9375/10000)
Epoch 147 (lr=0.0010)
train loss: 0.056 | acc: 98.158 (49079/50000)
test  loss: 0.221 | acc: 93.84  ( 9384/10000)
Epoch 148 (lr=0.0010)
train loss: 0.053 | acc: 98.266 (49133/50000)
test  loss: 0.220 | acc: 93.73  ( 9373/100

In [328]:
m = net.module if isinstance(net, nn.DataParallel) else net
print(m.conv1.weight.mean().item(), m.conv1.weight.std().item(), m.conv1.weight.shape)
for layer in [m.layer1, m.layer2, m.layer3, m.layer4]:
    for i in range(len(layer)):
        print(layer[i].conv.weight.mean().item(), layer[i].conv.weight.std().item(), layer[i].convx.weight.std().item())

6.521296745631844e-05 0.08932402729988098 torch.Size([64, 3, 3, 3])
-0.0007830829126760364 0.02403755486011505 0.005354593507945538
-0.001313234562985599 0.026450015604496002 0.005461629945784807
-0.0019167831633239985 0.020962413400411606 9.243061427355315e-09
-0.0008073995122686028 0.020578620955348015 0.004273314960300922
-0.000625637243501842 0.019594375044107437 0.004158452618867159
-0.0007864369545131922 0.017948370426893234 0.003565750550478697
-0.002031880198046565 0.013541772030293941 6.5324297082725025e-09
-0.00046382119762711227 0.013607258908450603 0.0024890019558370113
-0.00040515229920856655 0.010474667884409428 0.0015528416261076927
-0.000259602238656953 0.00654502771794796 0.0007960244547575712
-0.0007165137794800103 0.0032781183253973722 4.625624772103265e-09
2.8254662538529374e-05 0.0023896435741335154 0.00046875549014657736


### Results

Baseline (classic ResNet)

* ResNet50: 94.29 Bottleneck 32, [3,4,6,3], twist=[F,F,F,F]

With "twist":
* 94.15 Bottleneck 32, [3,4,6,3], twist=[T,T,T,T]
* 94.59 Bottleneck 32, [3,4,6,3], twist=[T,T,T,F]
* 94.84 Bottleneck 32, [8,8,8,3], twist=[T,T,T,F]
* 94.22 Bottleneck 32, [8,8,16,3], twist=[T,T,T,F]

All runs use the learning rates 0.1, 0.01 and 0.001 in turn, dividing the rate by 10 (and reloading the best checkpoint) whenever the test accuracy has not improved for 20 epochs.
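The plateau rule is easier to see in isolation. The hypothetical `replay_schedule` below replays, in plain Python, the learning-rate logic of the training loop on a given list of per-epoch test accuracies. It is a sketch for illustration only: it trains nothing and leaves out the notebook's "train accuracy > 99.9%" stop.

```python
def replay_schedule(test_accs, lr=0.1, patience=20, stop_below=0.00011):
    """Return the learning rate used at each epoch under the notebook's plateau rule.

    test_accs[e] is the test accuracy reached at the end of epoch e.
    """
    best_acc, best_epoch = 0.0, 0
    lrs = []
    for epoch, acc in enumerate(test_accs):
        if epoch > 0 and epoch - best_epoch >= patience:
            if lr < stop_below:
                break              # the training loop exits here
            lr *= 0.1              # the notebook also reloads the best weights at this point
            best_epoch = epoch     # and restarts the patience counter (checkpoint['epoch'] = epoch)
        lrs.append(lr)
        if acc > best_acc:         # test() only records strict improvements
            best_acc, best_epoch = acc, epoch
    return lrs

# a curve that stops improving after epoch 0
lrs = replay_schedule([50.0] * 100)
print(len(lrs), sorted(set(round(x, 6) for x in lrs), reverse=True))  # 80 [0.1, 0.01, 0.001, 0.0001]
```

As written, the `lr < 0.00011` test is checked before the rate is divided, so a curve that keeps plateauing also spends a stage at lr ≈ 1e-4 before the loop exits; in a real run the 99.9% train-accuracy test can end training earlier.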

