<a href="https://colab.research.google.com/github/liuyao12/pytorch-cifar/blob/master/cifar10_with_PDE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ResNet with a "twist"

* As far as I'm aware, this is a simple and novel ConvNet (Convolutional Neural Network) architecture that can be readily applied to any existing ResNet backbone.

* The key idea would be hard to come by, or to justify, without viewing ResNet as a partial differential equation (like the heat equation). The standard machine-learning toolkit typically includes the basics of multivariable calculus, linear algebra, and statistics, but not much PDE. This partly explains why ResNet came on the scene relatively late (2015), and why this enhanced version of ResNet has not yet been "reinvented" by the DL community.

* The code is based on https://github.com/kuangliu/pytorch-cifar

* Questions and comments are greatly appreciated: [@liuyao12](https://twitter.com/liuyao12) or liuyao@gmail.com

Below is a quick summary of ConvNets from a partial differential equations (PDE) point of view. For details, see my [blog post on Observable](https://observablehq.com/@liuyao12/neural-networks-and-partial-differential-equations).

neural network | heat equation
:----:|:-------:
input layer | initial condition
feed forward | solving the equation
hidden layers | solution at intermediate times
output layer | solution at final time
convolution with 3×3 kernel | differential operator of order ≤ 2
weights | coefficients
boundary handling (padding) | boundary condition
multiple channels/filters/feature maps | system of (coupled) PDEs
e.g. 16×16×3×3 kernel | 16×16 matrix of differential operators
16×16×1×1 kernel | 16×16 matrix of constants
groups=2 (in Conv2d) | matrix is block diagonal (direct sum of 2 blocks)


In short, classical ConvNets (ResNets) are **linear PDEs with constant coefficients**. Here I simply make the **coefficients variable**, namely polynomials of degree ≤ 1 in the spatial coordinates, which should (in theory) let the neural net learn more ways to deform the input than diffusion and translation (e.g., rotation and scaling).
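
As a minimal illustration of the "constant coefficients" reading, the sketch below applies two fixed 3×3 kernels to a sampled polynomial using plain Python lists. The kernels are the textbook central-difference stencils for ∂/∂x and for the Laplacian, chosen for illustration rather than taken from the network, and `correlate3x3` is a hypothetical helper that mimics what a 3×3 `nn.Conv2d` computes on interior pixels with unit grid spacing.

```python
def correlate3x3(img, kernel):
    """Cross-correlate a 2-D list with a 3x3 kernel on interior pixels only."""
    h, w = len(img), len(img[0])
    out = []
    for r in range(1, h - 1):
        row = []
        for c in range(1, w - 1):
            row.append(sum(kernel[i][j] * img[r + i - 1][c + j - 1]
                           for i in range(3) for j in range(3)))
        out.append(row)
    return out

# Central-difference stencil for d/dx (x runs along columns).
DX = [[0.0, 0.0, 0.0],
      [-0.5, 0.0, 0.5],
      [0.0, 0.0, 0.0]]

# Five-point stencil for the Laplacian d2/dx2 + d2/dy2.
LAPLACIAN = [[0.0, 1.0, 0.0],
             [1.0, -4.0, 1.0],
             [0.0, 1.0, 0.0]]

# Sample f(x, y) = x^2 + y^2 on a 5x5 integer grid (row index = y, column index = x).
f = [[float(x * x + y * y) for x in range(5)] for y in range(5)]

dx_f = correlate3x3(f, DX)           # equals 2x at interior points
lap_f = correlate3x3(f, LAPLACIAN)   # equals 4 at every interior point
print(dx_f[0])   # [2.0, 4.0, 6.0]
print(lap_f[0])  # [4.0, 4.0, 4.0]
```

The same kernel is used at every pixel, which is what "constant coefficients" means here; the twist multiplies extra kernels by the pixel coordinates so that the effective stencil changes across the image.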

In [1]:
''' 
ResNet in PyTorch, forked from https://github.com/kuangliu/pytorch-cifar
Reference:
    Kaiming He 何恺明, Xiangyu Zhang 张祥雨, Shaoqing Ren 任少卿, Jian Sun 孙剑 (Microsoft Research Asia)
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, channels, stride=1):
        super(BasicBlock, self).__init__()
        self.match = stride == 1 and in_channels == self.expansion * channels
        self.twist = False
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.XX, self.YY = None, None
        self.conv1x = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv1y = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2x = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2y = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)

        self.shortcut = nn.Sequential()
        if not self.match:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * channels)
            )

    def forward(self, x):
        x = F.relu(self.bn1(x))
        x1 = self.conv1(x)
        if self.twist:
            _, c, h, w = tuple(x1.shape)
            # antisymmetrize the x-kernel under point reflection (forcing it to be a 1st-order differential operator, aka a vector field)
            self.conv1x.weight.data = (self.conv1x.weight - self.conv1x.weight.flip(2).flip(3)) / 2
            # derive the y-kernel from the x-kernel by a 90-degree rotation
            # self.conv1y.weight.data = (self.conv1y.weight - self.conv1y.weight.flip(2).flip(3)) / 2
            self.conv1y.weight.data = self.conv1x.weight.transpose(2,3).flip(2)
            if self.XX is None:
                self.XX = torch.from_numpy(np.indices((h,w), dtype='float32')[1] / w - 0.5).to(x.device)
                self.YY = torch.from_numpy(np.indices((h,w), dtype='float32')[0] / h - 0.5).to(x.device)
                # print("twist initialized, self.XX", self.XX.shape, self.XX.mean().item())
            x1 = self.conv1(x) + self.XX * self.conv1x(x) + self.YY * self.conv1y(x)
            # print("twist initialized, outside self.XX", self.XX.shape, self.XX.mean().item())
        
        x2 = F.relu(self.bn2(x1))
        if self.twist:
            # antisymmetrize the x-kernel under point reflection (forcing it to be a 1st-order differential operator, aka a vector field)
            self.conv2x.weight.data = (self.conv2x.weight - self.conv2x.weight.flip(2).flip(3)) / 2
            # derive the y-kernel from the x-kernel by a 90-degree rotation
            # self.conv2y.weight.data = (self.conv2y.weight - self.conv2y.weight.flip(2).flip(3)) / 2
            self.conv2y.weight.data = self.conv2x.weight.transpose(2,3).flip(2)
            x3 = self.conv2(x2) + self.XX * self.conv2x(x2) + self.YY * self.conv2y(x2)
        else:
            x3 = self.conv2(x2)
        x3 += self.shortcut(x)
        return x3

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, channels, stride=1):
        super(Bottleneck, self).__init__()
        self.twist = False
        self.channels = channels
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2x = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2y = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.XX, self.YY = None, None
        self.bn2 = nn.BatchNorm2d(channels)
        self.conv3 = nn.Conv2d(channels, self.expansion * channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * channels)
            )

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = F.relu(self.bn1(x1))
        if self.twist: 
            # antisymmetrize the x-kernel under each axis flip in turn
            # (intended to force a 1st-order diff op, i.e. a vector field)
            tmp = (self.conv2x.weight - self.conv2x.weight.flip(2)) / 2
            self.conv2x.weight.data = (tmp - tmp.flip(3)) / 2
            # self.conv2y.weight.data = (self.conv2y.weight - self.conv2y.weight.flip(2).flip(3)) / 2
            # make the y-vector perpendicular to the x-vector: rotate the x-kernel by 90 degrees
            self.conv2y.weight.data = self.conv2x.weight.transpose(2,3).flip(3)
        x2 = self.conv2(x1)
        if self.twist:
            if self.XX is None:  # initialize self.XX and self.YY (coordinate grids)
                _, c, h, w = tuple(x2.shape)
                self.XX = torch.from_numpy(np.indices((h,w), dtype='float32')[1] / w - 0.5).to(x.device)
                self.YY = torch.from_numpy(np.indices((h,w), dtype='float32')[0] / h - 0.5).to(x.device)
            x2 += self.XX * self.conv2x(x1) + self.YY * self.conv2y(x1)
        x3 = F.relu(self.bn2(x2))
        x4 = self.conv3(x3)
        x4 = self.bn3(x4)
        x4 += self.shortcut(x)
        x4 = F.relu(x4)
        return x4


class PDEBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, channels, stride=1):
        super(PDEBlock, self).__init__()
        self.twist = False
        self.iterations = 2
        self.expand = in_channels != channels or stride != 1
        if self.expand:
            self.conv0 = nn.Conv2d(in_channels, channels, kernel_size=1, stride=stride, padding=1, bias=False)
            self.bn0 = nn.BatchNorm2d(channels)
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=1, groups=4, padding=1, bias=False)
        self.XX, self.YY = None, None
        self.convx = nn.Conv2d(channels, channels, kernel_size=3, stride=1, groups=4, padding=1, bias=False)
        self.convy = nn.Conv2d(channels, channels, kernel_size=3, stride=1, groups=4, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)

    def forward(self, x):
        if self.expand:
            x = self.bn0(self.conv0(x))
        if self.XX is None:
            _, _, h, w = tuple(x.shape)
            self.XX = torch.from_numpy(np.indices((1,1,h,w), dtype='float32')[3]/w-0.5).to(x.device)
            self.YY = torch.from_numpy(np.indices((1,1,h,w), dtype='float32')[2]/h-0.5).to(x.device)
        
        if self.twist:
            # antisymmetrize the x-kernel under point reflection; the y-kernel is the x-kernel rotated by 90 degrees
            self.convx.weight.data = (self.convx.weight - self.convx.weight.flip(2).flip(3)) / 2
            # self.convy.weight.data = (self.convy.weight - self.convy.weight.flip(2).flip(3)) / 2
            self.convy.weight.data = self.convx.weight.transpose(2,3).flip(2)
        for i in range(self.iterations):
            x1 = self.conv(x)
            if self.twist:
                x1 += self.XX * self.convx(x) + self.YY * self.convy(x)
            x = x + x1 / self.iterations
        x = self.bn(x)
        x = F.relu(x)
        return x

    
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 32
        channels = [self.in_channels * i for i in [1, 2, 4, 4]]
        self.num_blocks = num_blocks
        self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels[0])
        self.layer1 = self._make_layer(block, channels[0], num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, channels[1], num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, channels[2], num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, channels[3], num_blocks[3], stride=2)
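        # the classifier reads the output of layer4, so size it from channels[3] (numerically equal to channels[2] here)
        self.linear = nn.Linear(channels[3] * block.expansion, num_classes)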

    def _make_layer(self, block, channels, num_blocks, stride):
        if num_blocks == 0:
            layers = [nn.Conv2d(self.in_channels, channels, kernel_size=1, stride=stride, padding=1, bias=False),
                      nn.BatchNorm2d(channels)]
            self.in_channels = channels * block.expansion
        else:
            strides = [stride] + [1] * (num_blocks - 1)
            layers = []
            for idx, stride in enumerate(strides):
                # twist = twist and idx < 3
                layers.append(block(self.in_channels, channels, stride))
                self.in_channels = channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = F.adaptive_avg_pool2d(x,1)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x
    
    
def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])

def ResNet34():
    return ResNet(BasicBlock, [3,4,6,3])

def ResNet50():
    return ResNet(Bottleneck, [3,4,6,3])

def ResNet101():
    return ResNet(Bottleneck, [3,4,23,3])

def ResNet152():
    return ResNet(Bottleneck, [3,8,36,3])

net = ResNet34()
# net = ResNet(PDEBlock, [3,3,3])
epoch = 0 
lr = 0.1
checkpoint = {'acc': 0, 'epoch': 0}
history = [{'acc': 0, 'epoch': 0}]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('device =', device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True
net.to(device)
print('Testing on a random input:')
test = torch.randn(1,3,224,224).to(device)
print('INPUT ', test.shape)
print('OUTPUT', net(test).shape)

device = cuda
Testing on a random input:
INPUT  torch.Size([1, 3, 224, 224])
OUTPUT torch.Size([1, 10])
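
To make the kernel manipulations in the `twist` branches easier to follow, here is a small sketch on plain Python lists rather than tensors. It mirrors the `BasicBlock` variant: antisymmetrize a 3×3 kernel under point reflection (as `flip(2).flip(3)` does), rotate it by 90 degrees (as `transpose(2,3).flip(2)` does) to obtain the y-kernel, and build centered coordinate grids that play the role of `self.XX` and `self.YY`. The helper names (`antisymmetrize`, `rotate90`, `coord_grids`) and the sample kernel are illustrative assumptions, not part of the notebook.

```python
def flip_rows(k):
    return [list(row) for row in k[::-1]]

def flip_cols(k):
    return [list(row[::-1]) for row in k]

def transpose(k):
    return [list(col) for col in zip(*k)]

def antisymmetrize(k):
    """(k - point_reflection(k)) / 2, like (w - w.flip(2).flip(3)) / 2."""
    reflected = flip_cols(flip_rows(k))
    return [[(a - b) / 2 for a, b in zip(row, ref_row)]
            for row, ref_row in zip(k, reflected)]

def rotate90(k):
    """Transpose, then reverse the row order, like w.transpose(2, 3).flip(2)."""
    return flip_rows(transpose(k))

def coord_grids(h, w):
    """Per-pixel coordinates in [-0.5, 0.5), like np.indices((h, w)) / size - 0.5."""
    xx = [[c / w - 0.5 for c in range(w)] for _ in range(h)]
    yy = [[r / h - 0.5 for _ in range(w)] for r in range(h)]
    return xx, yy

# Arbitrary sample kernel (an assumption for illustration only).
k = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0],
     [7.0, 8.0, 9.0]]

kx = antisymmetrize(k)  # odd kernel: kx[i][j] == -kx[2 - i][2 - j], zero centre
ky = rotate90(kx)       # y-kernel derived from the x-kernel
xx, yy = coord_grids(2, 4)

print(kx)     # [[-4.0, -3.0, -2.0], [-1.0, 0.0, 1.0], [2.0, 3.0, 4.0]]
print(ky)     # [[-2.0, 1.0, 4.0], [-3.0, 0.0, 3.0], [-4.0, -1.0, 2.0]]
print(xx[0])  # [-0.5, -0.25, 0.0, 0.25]
```

In the blocks, the twisted output is then `conv(x) + XX * convx(x) + YY * convy(x)`, so the effective kernel varies linearly with position across the feature map.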


In [2]:
import torchvision
import torchvision.transforms as transforms

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    # transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomAffine(degrees=10, translate=(0.3,0.3), scale=(0.8,1.2)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainset = torchvision.datasets.ImageFolder(root='./data/imagenette2/train', transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True, num_workers=2)

# testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testset = torchvision.datasets.ImageFolder(root='./data/imagenette2/val', transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=16, shuffle=False, num_workers=2)

classes = tuple(trainset.classes)  # class names come from the Imagenette folder names (the CIFAR-10 names no longer apply)

==> Preparing data..


In [3]:
# Training
def train(loss_func, opt):
    global history
    print('Epoch {} (lr={:.4f})'.format(epoch, lr))
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (x, y) in enumerate(trainloader):
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        pred = net(x)
        loss = loss_func(pred, y)
        loss.backward()
        opt.step()
        train_loss += loss.item()
        _, predicted = pred.max(1)
        total += y.size(0)
        correct += predicted.eq(y).sum().item()
    print('train loss: {:.3f} | acc: {:.3f} ({}/{})'.format(
        train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    history.append({'epoch': epoch, 'train_loss': train_loss, 'train_acc': 100. * correct / total})
    
def test(loss_func):
    global checkpoint, history
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(testloader):
            x, y = x.to(device), y.to(device)
            pred = net(x)
            loss = loss_func(pred, y)
            test_loss += loss.item()
            _, predicted = pred.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()
    acc = 100. * correct / total
    history[-1]['loss'] = test_loss
    history[-1]['acc'] = acc
    if acc > checkpoint['acc']:
        print('test  loss: {:.3f} | acc: {:.2f}  ( {}/{}) (up by {:.2f})'.format(
               test_loss / (batch_idx + 1), 100. * correct / total, correct, total,
               acc - checkpoint['acc']))
        # print('Saving..')
        state = {
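            # clone the tensors: state_dict() returns references that later training steps would overwrite
            'net': {k: v.clone() for k, v in net.state_dict().items()},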
            'lr': lr,
            'acc': acc,
            'epoch': epoch
        }
        checkpoint = state
        # if not os.path.isdir('checkpoint'):
        #     os.mkdir('checkpoint')
        # torch.save(state, './checkpoint/ckpt.pth')
    else:
        print('test  loss: {:.3f} | acc: {:.2f}  ( {}/{})'.format(
            test_loss / (batch_idx + 1), 100. * correct / total, correct, total))

loss_func = nn.CrossEntropyLoss()
opt = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)  # 5e-4


from math import exp, e

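def lr_schedule(x, lr):
    # Rises roughly linearly, peaks at epoch x0 with lr = 0.001 + y0, then decays exponentially.
    # The `lr` argument is unused by this formula (it is kept for the commented-out variant below).
    x0 = 10
    y0 = 0.1
    return 0.001 + exp(- x / x0) * x * e * y0 / x0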
#     if x < x0:
#         return 0.001 if x == 0 else lr + y0 / x0
#     elif x < 100:
#         return 0.001 + (lr - 0.001) * 0.95

for _ in range(10):  # epoch, checkpoint and history are module-level names, so no `global` statement is needed
    if history[-1].get('train_acc', 0) > 99.99:
        break
    if epoch == 0:
        m = net.module if isinstance(net, nn.DataParallel) else net  # net is only wrapped when running on CUDA
        for layer in [m.layer1]: #, m.layer2, m.layer3]: #, m.layer4]:
            for i in range(len(layer)):
                layer[i].twist = True
        print("twist on")
        print('testing on random initial weights:')
        test(loss_func)
    if epoch < 0:
        lr = lr_schedule(epoch, lr)
        for param_group in opt.param_groups:
            param_group['lr'] = lr
    elif epoch - checkpoint['epoch'] >= 30:
        print('\nloading state_dict from Epoch {} (acc = {})'.format(checkpoint['epoch'], checkpoint['acc']))
        net.load_state_dict(checkpoint['net'])
        m = net.module if isinstance(net, nn.DataParallel) else net  # net is only wrapped when running on CUDA
        for layer in [m.layer1, m.layer2]: #, m.layer3]: #, m.layer4]:
            for i in range(len(layer)):
                layer[i].twist = True
        test(loss_func)
        lr = checkpoint['lr'] * 0.1
        print('\nlearning rate downgraded to {} at epoch {}'.format(lr, epoch))
        checkpoint['epoch'] = epoch
        history.append({'epoch': checkpoint['epoch'], 'acc': checkpoint['acc']})
        opt = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
        
    train(loss_func, opt)
    test(loss_func)
    epoch += 1
print('finish at lr = {}, acc = {}'.format(lr, checkpoint['acc']))

twist on
testing on random initial weights:
test  loss: 2.305 | acc: 10.80  ( 424/3925) (up by 10.80)
Epoch 0 (lr=0.1000)
train loss: 2.294 | acc: 14.162 (1341/9469)
test  loss: 2.292 | acc: 16.38  ( 643/3925) (up by 5.58)
Epoch 1 (lr=0.1000)
train loss: 2.125 | acc: 21.291 (2016/9469)
test  loss: 2.119 | acc: 22.01  ( 864/3925) (up by 5.63)
Epoch 2 (lr=0.1000)
train loss: 2.050 | acc: 24.543 (2324/9469)
test  loss: 2.044 | acc: 24.92  ( 978/3925) (up by 2.90)
Epoch 3 (lr=0.1000)
train loss: 1.984 | acc: 27.585 (2612/9469)
test  loss: 3.112 | acc: 17.38  ( 682/3925)
Epoch 4 (lr=0.1000)
train loss: 1.974 | acc: 29.443 (2788/9469)
test  loss: 2.181 | acc: 18.88  ( 741/3925)
Epoch 5 (lr=0.1000)
train loss: 1.816 | acc: 36.392 (3446/9469)
test  loss: 1.833 | acc: 36.25  ( 1423/3925) (up by 11.34)
Epoch 6 (lr=0.1000)
train loss: 1.728 | acc: 39.159 (3708/9469)
test  loss: 4.603 | acc: 20.15  ( 791/3925)
Epoch 7 (lr=0.1000)


KeyboardInterrupt: 

In [18]:
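# Weight statistics per block; assumes the net was built from PDEBlock (the only block with .conv and .convx).
m = net.module if isinstance(net, nn.DataParallel) else net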
print(m.conv1.weight.mean().item(), m.conv1.weight.std().item(), m.conv1.weight.shape)
for layer in [m.layer1, m.layer2, m.layer3]: #, m.layer4]:
    print()
    for i in range(len(layer)):
        print(layer[i].conv.weight.mean().item(), layer[i].conv.weight.std().item(), layer[i].convx.weight.std().item())

-0.0017278349259868264 0.24676665663719177 torch.Size([8, 3, 3, 3])

-0.003532852977514267 0.1745581179857254 0.034341562539339066
-0.006805328652262688 0.14749710261821747 0.03434890881180763
0.007838931865990162 0.15485425293445587 0.022962460294365883

-0.006107522640377283 0.10852225124835968 0.018574967980384827
-0.004374963231384754 0.12109889090061188 0.02475770376622677
0.008549327962100506 0.11630651354789734 0.02603377401828766

-0.00048096803948283195 0.08133669197559357 0.015090527012944221
-0.0058867610059678555 0.09106168895959854 0.016228066757321358
0.0011466664727777243 0.05562036111950874 0.00782526470720768


### Results: 

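Training with lr = [0.1, 0.01, 0.001], reducing the learning rate tenfold whenever accuracy has plateaued for 20 epochs. Each entry below lists test accuracy (%), block type, base channel width, number of blocks per stage, and which stages have the twist enabled (T/F).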
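
The plateau rule can be sketched in a few lines of plain Python. This is a simplified stand-in for the training loop, not the loop itself: it only tracks the learning rate (the notebook also reloads the best weights before continuing), and the function name `plateau_lrs` and its parameters are hypothetical. The notebook's loop uses a 30-epoch threshold in code while the text here says 20, so `patience` is left as a parameter.

```python
def plateau_lrs(accuracies, lr=0.1, factor=0.1, patience=20, min_lr=0.001):
    """Return the learning rate used at each epoch, given per-epoch test accuracies."""
    best_acc, best_epoch = float("-inf"), 0
    used = []
    for epoch, acc in enumerate(accuracies):
        # Reduce the rate once `patience` epochs have passed without a new best accuracy.
        if epoch - best_epoch >= patience and lr > min_lr:
            lr = max(lr * factor, min_lr)
            best_epoch = epoch  # restart the patience window
        used.append(lr)
        if acc > best_acc:
            best_acc, best_epoch = acc, epoch
    return used

# Toy accuracy curve that stops improving after epoch 1 (made-up numbers).
toy_acc = [50.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0]
schedule = plateau_lrs(toy_acc, patience=3)
print([round(v, 6) for v in schedule])
# [0.1, 0.1, 0.1, 0.1, 0.01, 0.01, 0.01, 0.001]
```

Rounding is only for display: repeated multiplication by 0.1 leaves small floating-point residue, which is why the notebook prints values such as `0.0010000000000000002`.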

Baseline (classic ResNet):

* 94.29 ResNet50: Bottleneck 32, [3,4,6,3], twist=[F,F,F,F]

With "twist":
* 94.52 BasicBlock 64, [2,2,2,2], twist=[T,T,T,T], rotation_aug=10

* 94.15 Bottleneck 32, [3,4,6,3], twist=[T,T,T,T]
* 94.59 Bottleneck 32, [3,4,6,3], twist=[T,T,T,F]
* 94.28 Bottleneck 32, [3,4,6,3], twist=[T,F,F,F]
* 94.84 Bottleneck 32, [8,8,8,3], twist=[T,T,T,F]
* 94.22 Bottleneck 32, [8,8,16,3], twist=[T,T,T,F]

When training with lr = x e^(-x) (rising linearly, then falling off exponentially), training converges faster, in about 100 epochs:

* 93.98 ResNet50 Bottleneck 16
* 94.53 (train_acc=96.13), BasicBlock 64, [3,4,6,3], twist=[T,T,T,F]
* 93.19 Bottleneck 16, [3,4,6,3], twist=[T,T,T,F]
* 94.06 Bottleneck 32, [3,4,6,3], twist=[T,T,T,F]
* 94.72 Bottleneck 64, [3,4,6,3], twist=[T,T,T,F]
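
For reference, the rise-then-decay schedule can be written out as a standalone function. The sketch below restates the formula from `lr_schedule` in the training cell, `0.001 + exp(-x / x0) * x * e * y0 / x0` with `x0 = 10` and `y0 = 0.1`; the function name `warmup_decay_lr` and its parameter names are hypothetical, chosen to expose the peak epoch, peak height, and floor explicitly.

```python
from math import e, exp

def warmup_decay_lr(epoch, peak_epoch=10, peak=0.1, floor=0.001):
    """floor + peak * t * e^(1 - t), where t = epoch / peak_epoch."""
    t = epoch / peak_epoch
    return floor + peak * t * exp(-t) * e

# Print the schedule at a few sample epochs.
for epoch in (0, 5, 10, 20, 50, 100):
    print(epoch, round(warmup_decay_lr(epoch), 4))
```

The rate starts at `floor`, reaches its maximum of `floor + peak` at `peak_epoch` (since t·e^(1−t) is maximized at t = 1), and then decays back toward `floor`.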


In [19]:
torch.tensor([
 [2.8104, 2.6607, 2.0765, 1.7392, 1.5272, 1.3853, 1.1833, 1.1261, 0.8286,
        0.7381, 0.6732, 0.5530, 0.4448, 0.3987, 0.3280, 0.2829, 0.2381, 0.2129,
        0.1887, 0.1716, 0.1538, 0.1227, 0.1100, 0.0991, 0.0793, 0.0575, 0.0470],
 [2.5849, 2.4312, 2.0399, 1.4542, 1.3491, 1.2283, 1.0040, 0.7782, 0.7403,
        0.6404, 0.4244, 0.4096, 0.3201, 0.2533, 0.2170, 0.1899, 0.1393, 0.1194,
        0.1111, 0.0810, 0.0702, 0.0599, 0.0491, 0.0427, 0.0316, 0.0281, 0.0211],
 [2.3489, 2.1176, 1.8689, 1.3103, 1.1440, 1.1354, 0.8471, 0.7960, 0.6468,
        0.4969, 0.4144, 0.3684, 0.2797, 0.2498, 0.1899, 0.1561, 0.1421, 0.1118,
        0.0815, 0.0706, 0.0602, 0.0536, 0.0389, 0.0341, 0.0209, 0.0138, 0.0093],
 [2.3597, 2.0739, 1.8342, 1.2758, 1.1094, 1.0668, 0.9078, 0.8299, 0.6184,
        0.5731, 0.4374, 0.4200, 0.2464, 0.2361, 0.2333, 0.1684, 0.1497, 0.1077,
        0.0909, 0.0606, 0.0542, 0.0475, 0.0420, 0.0290, 0.0157, 0.0121, 0.0053],
 [2.4256, 2.0896, 1.8365, 1.3078, 1.1393, 1.1217, 1.0316, 0.8114, 0.6341,
        0.5491, 0.4557, 0.3869, 0.2779, 0.2309, 0.1898, 0.1831, 0.1766, 0.1271,
        0.0781, 0.0603, 0.0509, 0.0406, 0.0363, 0.0292, 0.0185, 0.0120, 0.0060],
 [2.4849, 2.0532, 1.8571, 1.3805, 1.1323, 1.1066, 1.0433, 0.8157, 0.6629,
        0.5746, 0.4429, 0.3829, 0.2724, 0.2594, 0.1690, 0.1577, 0.1473, 0.1279,
        0.0809, 0.0602, 0.0458, 0.0376, 0.0331, 0.0263, 0.0196, 0.0110, 0.0061],
 [2.4709, 2.0658, 1.8636, 1.4249, 1.1889, 1.1478, 0.9837, 0.7942, 0.6960,
        0.5268, 0.4503, 0.4186, 0.2792, 0.2359, 0.1945, 0.1636, 0.1547, 0.1163,
        0.0766, 0.0685, 0.0539, 0.0411, 0.0342, 0.0238, 0.0154, 0.0120, 0.0039],
 [2.5026, 2.0597, 1.8135, 1.4707, 1.2017, 1.0554, 0.9883, 0.8208, 0.7253,
        0.5223, 0.4767, 0.4090, 0.3169, 0.2375, 0.1904, 0.1654, 0.1372, 0.1143,
        0.0842, 0.0540, 0.0447, 0.0322, 0.0279, 0.0192, 0.0152, 0.0099, 0.0069],
 [2.5069, 2.0750, 1.7927, 1.5006, 1.2215, 1.1006, 0.9533, 0.8030, 0.7206,
        0.5561, 0.4817, 0.4372, 0.3454, 0.2270, 0.1583, 0.1296, 0.1175, 0.0945,
        0.0714, 0.0528, 0.0451, 0.0355, 0.0255, 0.0217, 0.0172, 0.0125, 0.0101],
 [2.5199, 2.0277, 1.7792, 1.4563, 1.2592, 1.0659, 0.9456, 0.8021, 0.7113,
        0.5484, 0.4876, 0.4590, 0.3279, 0.1928, 0.1862, 0.1847, 0.1571, 0.1049,
        0.0790, 0.0580, 0.0396, 0.0352, 0.0295, 0.0217, 0.0183, 0.0103, 0.0057],
 [2.3893, 1.9121, 1.6831, 1.3817, 1.1941, 1.0182, 0.9111, 0.7532, 0.6538,
        0.5459, 0.4672, 0.4358, 0.3111, 0.1830, 0.1812, 0.1646, 0.1533, 0.0977,
        0.0770, 0.0575, 0.0403, 0.0324, 0.0275, 0.0221, 0.0167, 0.0096, 0.0064],
 [2.2690, 1.8165, 1.5975, 1.3160, 1.1309, 0.9703, 0.8707, 0.7099, 0.6191,
        0.5291, 0.4453, 0.4059, 0.2917, 0.1800, 0.1698, 0.1525, 0.1403, 0.0821,
        0.0720, 0.0552, 0.0420, 0.0310, 0.0275, 0.0211, 0.0167, 0.0098, 0.0062],
 [2.1601, 1.7282, 1.5242, 1.2564, 1.0730, 0.9203, 0.8337, 0.6863, 0.5854,
        0.5140, 0.4252, 0.3822, 0.2758, 0.1659, 0.1575, 0.1521, 0.1334, 0.0783,
        0.0741, 0.0565, 0.0421, 0.0305, 0.0250, 0.0212, 0.0163, 0.0090, 0.0055],
 [2.0608, 1.6436, 1.4555, 1.2012, 1.0342, 0.8702, 0.7983, 0.6782, 0.5550,
        0.4913, 0.4171, 0.3638, 0.2737, 0.1557, 0.1502, 0.1329, 0.1208, 0.0756,
        0.0704, 0.0500, 0.0414, 0.0295, 0.0238, 0.0203, 0.0151, 0.0089, 0.0044],
 [2.0513, 1.6338, 1.4452, 1.1936, 1.0283, 0.8667, 0.7949, 0.6695, 0.5496,
        0.4944, 0.4150, 0.3623, 0.2676, 0.1548, 0.1484, 0.1310, 0.1200, 0.0749,
        0.0701, 0.0505, 0.0418, 0.0292, 0.0238, 0.0202, 0.0152, 0.0090, 0.0041]
]).shape

torch.Size([15, 27])

In [20]:
history

[{'acc': 7.94, 'epoch': 0, 'loss': 230.29033517837524},
 {'epoch': 0,
  'train_loss': 738.2270909547806,
  'train_acc': 28.212,
  'loss': 178.15970933437347,
  'acc': 35.03},
 {'epoch': 1,
  'train_loss': 636.0984075069427,
  'train_acc': 39.642,
  'loss': 171.37144315242767,
  'acc': 39.55},
 {'epoch': 2,
  'train_loss': 579.8472344875336,
  'train_acc': 45.41,
  'loss': 189.14156591892242,
  'acc': 44.0},
 {'epoch': 3,
  'train_loss': 551.47753739357,
  'train_acc': 48.558,
  'loss': 140.14471018314362,
  'acc': 52.36},
 {'epoch': 4,
  'train_loss': 531.3351120948792,
  'train_acc': 50.55,
  'loss': 194.9198615550995,
  'acc': 41.94},
 {'epoch': 5,
  'train_loss': 514.1015273332596,
  'train_acc': 52.394,
  'loss': 174.3349907398224,
  'acc': 45.35},
 {'epoch': 6,
  'train_loss': 500.8305673599243,
  'train_acc': 53.652,
  'loss': 155.5255310535431,
  'acc': 50.97},
 {'epoch': 7,
  'train_loss': 490.45059484243393,
  'train_acc': 54.622,
  'loss': 181.0910484790802,
  'acc': 47.19},


In [21]:
print(lr)

0.0010000000000000002
