In [1]:
# %% md
import torch
from utils import accuracy
from torchvision import transforms
from torchvision.datasets import CIFAR10
from nni.retiarii.oneshot.pytorch import DartsTrainer, EnasTrainer
import itertools
import torch.nn.functional as F
import nni.retiarii.nn.pytorch as nn
from collections import OrderedDict

In [2]:
def reward_accuracy(output, target, topk=(1,)):
    batch_size = target.size(0)
    _, predicted = torch.max(output.data, 1)
    return (predicted == target).sum().item() / batch_size

In [1]:
# Retiarii Example - One-shot NAS

## Step 1: Express the Model Space

### Step 1.1: Define the Base Model

class CIFAR_17(nn.Module):
    '''
    BaseModel which has 3 CNN layers and 2 FC layers
    '''

    def __init__(self, head_size=10):
        super(CIFAR_17, self).__init__()

        self.body = nn.Sequential(OrderedDict([
            ('cnn1', nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(3, 8, 3, 1, 1)),
                ('batchnorm', nn.BatchNorm2d(8)),
                ('relu', nn.ReLU(inplace=True)),
                ('pool', nn.MaxPool2d(2))
            ]))),
            ('cnn2', nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(8, 8, 3, 1, 1)),
                ('batchnorm', nn.BatchNorm2d(8)),
                ('relu', nn.ReLU(inplace=True)),
                ('pool', nn.MaxPool2d(2))
            ]))),
            ('cnn3', nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(8, 8, 3, 1, 1)),
                ('batchnorm', nn.BatchNorm2d(8)),
                ('relu', nn.ReLU(inplace=True)),
                ('pool', nn.MaxPool2d(2)),
            ])))
        ]))

        self.head = nn.Sequential(OrderedDict([
            ('dense', nn.Sequential(OrderedDict([
                ('fc1', nn.Conv2d(8 * 4 * 4, 32, kernel_size=1, bias=True)),  # implement dense layer in CNN way
                ('relu', nn.ReLU(inplace=True)),
                ('fc2', nn.Conv2d(32, head_size, kernel_size=1, bias=True)),
            ])))
        ]))

    def features(self, x):
        feat = self.body(x)
        feat = x.view(x.shape[0], -1)
        return feat

    def forward(self, x):
        x = self.body(x)
        x = x.view(x.shape[0], -1, 1, 1)  # flatten
        x = self.head(x)
        x = x.view(x.shape[0], -1)
        return x

model = CIFAR_17()

In [3]:
### Step 1.2: Define the Model Mutations


import torch.nn.functional as F
import nni.retiarii.nn.pytorch as nn


class Net(nn.Module):
    def __init__(self, head_size=10, lower_range=8, upper_range=16):
        super(Net, self).__init__()
        self.head_size = head_size
        self.lower_range = lower_range
        self.upper_range = upper_range
        choice_dict = self._get_mutator()
        self.net = nn.LayerChoice(choice_dict)

    def _get_mutator(self):
        ## this is supposed to be slooow 
        layer_choices = []
        a = [range(self.lower_range, self.upper_range+1),
             range(self.lower_range, self.upper_range+1),
             range(self.lower_range, self.upper_range+1)]

        for comb in list(itertools.product(*a)):
            i, j, k = comb
            layer_choices.append(
                nn.Sequential(
                        nn.Conv2d(3, i, 3, 1, 1),
#                         nn.BatchNorm2d(i),
                        nn.ReLU(inplace=True),
                        nn.MaxPool2d(2),

                        nn.Conv2d(i, j, 3, 1, 1),
#                         nn.BatchNorm2d(j),
                        nn.ReLU(inplace=True),
                        nn.MaxPool2d(2),

                        nn.Conv2d(j, k, 3, 1, 1),
#                         nn.BatchNorm2d(k),
                        nn.ReLU(inplace=True),
                        nn.MaxPool2d(2),

                        nn.Flatten(),
                        nn.Linear(k * 4 * 4, 32),
                        nn.ReLU(inplace=True),
                        nn.Linear(32, self.head_size),
                    )
            )
        return layer_choices

    def forward(self, x):
        out = self.net(x)
        return out


model = Net()
model.forward(torch.rand(1, 3, 32, 32))



tensor([[[[0.8355, 0.0372, 0.4820,  ..., 0.6036, 0.5490, 0.5939],
          [0.9390, 0.1123, 0.1321,  ..., 0.3496, 0.5150, 0.4703],
          [0.1982, 0.4452, 0.6646,  ..., 0.8955, 0.7899, 0.4220],
          ...,
          [0.8055, 0.0359, 0.5717,  ..., 0.5338, 0.2097, 0.1777],
          [0.6984, 0.3758, 0.0012,  ..., 0.4088, 0.4660, 0.1637],
          [0.5624, 0.0933, 0.2508,  ..., 0.6675, 0.1395, 0.4679]],

         [[0.9481, 0.0475, 0.4693,  ..., 0.5553, 0.7257, 0.2556],
          [0.7076, 0.7270, 0.7303,  ..., 0.6901, 0.2042, 0.8692],
          [0.7110, 0.2516, 0.6331,  ..., 0.5559, 0.3112, 0.6330],
          ...,
          [0.3921, 0.6627, 0.7397,  ..., 0.8701, 0.2475, 0.5895],
          [0.6491, 0.2870, 0.8292,  ..., 0.6190, 0.8112, 0.7844],
          [0.4368, 0.7931, 0.0874,  ..., 0.3731, 0.0355, 0.2749]],

         [[0.8739, 0.0190, 0.7372,  ..., 0.1666, 0.1791, 0.5846],
          [0.3235, 0.0745, 0.1028,  ..., 0.6966, 0.7469, 0.6754],
          [0.7236, 0.6912, 0.9650,  ..., 0

In [4]:
## Step 2: Explore the Model Space

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), 
                            lr=0.001, 
                            momentum=0.9)

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.4914, 0.4822, 0.4465), 
                                                     (0.2023, 0.1994, 0.2010)),
                               ])

train_dataset = CIFAR10(root="./data",
                        train=True,
                        download=True,
                        transform=transform)


Files already downloaded and verified


## DARTS or ENAS

In [5]:
trainer = DartsTrainer(
        model=model, 
        loss=criterion,
        metrics=lambda output, target: accuracy(output, target), 
        optimizer=optimizer,
        num_epochs=2, 
        dataset=train_dataset,
        batch_size=8, 
        log_frequency=10,
#         reward_function=reward_accuracy,
#         device=torch.device("cpu")
    
)
        

trainer.fit()

[2021-06-13 22:13:32] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [1/3125]  acc1 0.125000 (0.125000)  loss 2.302913 (2.302913)
[2021-06-13 22:14:01] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [11/3125]  acc1 0.125000 (0.113636)  loss 2.300397 (2.303009)
[2021-06-13 22:14:29] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [21/3125]  acc1 0.000000 (0.101190)  loss 2.303456 (2.302628)
[2021-06-13 22:14:58] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [31/3125]  acc1 0.375000 (0.112903)  loss 2.302419 (2.302710)
[2021-06-13 22:15:26] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [41/3125]  acc1 0.000000 (0.097561)  loss 2.302354 (2.302768)
[2021-06-13 22:15:55] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [51/3125]  acc1 0.000000 (0.098039)  loss 2.304265 (2.302703)
[2021-06-13 22:16:23] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Ep

[2021-06-13 22:37:09] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [531/3125]  acc1 0.000000 (0.080508)  loss 2.302478 (2.302670)
[2021-06-13 22:37:39] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [541/3125]  acc1 0.000000 (0.079945)  loss 2.303230 (2.302673)
[2021-06-13 22:38:08] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [551/3125]  acc1 0.125000 (0.079855)  loss 2.300917 (2.302672)
[2021-06-13 22:38:38] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [561/3125]  acc1 0.250000 (0.079768)  loss 2.303063 (2.302678)
[2021-06-13 22:39:08] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [571/3125]  acc1 0.000000 (0.079466)  loss 2.302926 (2.302677)
[2021-06-13 22:39:38] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [581/3125]  acc1 0.250000 (0.079174)  loss 2.301330 (2.302686)
[2021-06-13 22:40:08] INFO (nni.retiarii.oneshot.pytorch.darts/MainThr

[2021-06-13 23:00:29] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [1061/3125]  acc1 0.000000 (0.079053)  loss 2.304769 (2.302659)
[2021-06-13 23:00:54] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [1071/3125]  acc1 0.125000 (0.079015)  loss 2.299605 (2.302650)
[2021-06-13 23:01:19] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [1081/3125]  acc1 0.125000 (0.078747)  loss 2.301654 (2.302655)
[2021-06-13 23:01:44] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [1091/3125]  acc1 0.125000 (0.078598)  loss 2.302218 (2.302653)
[2021-06-13 23:02:09] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [1101/3125]  acc1 0.000000 (0.078565)  loss 2.303739 (2.302651)
[2021-06-13 23:02:34] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [1111/3125]  acc1 0.250000 (0.078870)  loss 2.302104 (2.302651)
[2021-06-13 23:02:59] INFO (nni.retiarii.oneshot.pytorch.darts/M

[2021-06-13 23:22:31] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [1591/3125]  acc1 0.000000 (0.079510)  loss 2.304319 (2.302647)
[2021-06-13 23:22:56] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [1601/3125]  acc1 0.125000 (0.079638)  loss 2.303254 (2.302646)
[2021-06-13 23:23:21] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [1611/3125]  acc1 0.250000 (0.079609)  loss 2.303160 (2.302647)
[2021-06-13 23:23:46] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [1621/3125]  acc1 0.125000 (0.079503)  loss 2.303931 (2.302650)
[2021-06-13 23:24:10] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [1631/3125]  acc1 0.375000 (0.079706)  loss 2.300795 (2.302655)
[2021-06-13 23:24:35] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [1641/3125]  acc1 0.125000 (0.079525)  loss 2.302819 (2.302657)
[2021-06-13 23:25:00] INFO (nni.retiarii.oneshot.pytorch.darts/M

[2021-06-13 23:44:33] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [2121/3125]  acc1 0.125000 (0.081153)  loss 2.302503 (2.302663)
[2021-06-13 23:44:58] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [2131/3125]  acc1 0.000000 (0.081007)  loss 2.303407 (2.302664)
[2021-06-13 23:45:23] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [2141/3125]  acc1 0.000000 (0.081095)  loss 2.303429 (2.302663)
[2021-06-13 23:45:48] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [2151/3125]  acc1 0.125000 (0.080951)  loss 2.302589 (2.302662)
[2021-06-13 23:46:13] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [2161/3125]  acc1 0.000000 (0.080692)  loss 2.303370 (2.302664)
[2021-06-13 23:46:38] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [2171/3125]  acc1 0.125000 (0.080493)  loss 2.299268 (2.302665)
[2021-06-13 23:47:03] INFO (nni.retiarii.oneshot.pytorch.darts/M

[2021-06-14 00:06:38] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [2651/3125]  acc1 0.125000 (0.080488)  loss 2.301579 (2.302646)
[2021-06-14 00:07:03] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [2661/3125]  acc1 0.000000 (0.080468)  loss 2.304271 (2.302646)
[2021-06-14 00:07:28] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [2671/3125]  acc1 0.125000 (0.080541)  loss 2.301756 (2.302644)
[2021-06-14 00:07:53] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [2681/3125]  acc1 0.000000 (0.080520)  loss 2.303443 (2.302644)
[2021-06-14 00:08:17] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [2691/3125]  acc1 0.125000 (0.080686)  loss 2.303517 (2.302642)
[2021-06-14 00:08:42] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [2701/3125]  acc1 0.000000 (0.080572)  loss 2.302420 (2.302643)
[2021-06-14 00:09:07] INFO (nni.retiarii.oneshot.pytorch.darts/M

[2021-06-14 00:28:25] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [51/3125]  acc1 0.125000 (0.075980)  loss 2.300497 (2.302438)
[2021-06-14 00:28:50] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [61/3125]  acc1 0.125000 (0.071721)  loss 2.303217 (2.302543)
[2021-06-14 00:29:15] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [71/3125]  acc1 0.000000 (0.070423)  loss 2.303904 (2.302561)
[2021-06-14 00:29:40] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [81/3125]  acc1 0.125000 (0.075617)  loss 2.301561 (2.302526)
[2021-06-14 00:30:05] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [91/3125]  acc1 0.250000 (0.076923)  loss 2.304818 (2.302595)
[2021-06-14 00:30:31] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [101/3125]  acc1 0.000000 (0.077970)  loss 2.301888 (2.302577)
[2021-06-14 00:30:56] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) 

[2021-06-14 00:50:30] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [581/3125]  acc1 0.000000 (0.090577)  loss 2.304876 (2.302495)
[2021-06-14 00:50:55] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [591/3125]  acc1 0.125000 (0.089679)  loss 2.302621 (2.302503)
[2021-06-14 00:51:20] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [601/3125]  acc1 0.125000 (0.090474)  loss 2.303468 (2.302501)
[2021-06-14 00:51:45] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [611/3125]  acc1 0.000000 (0.090221)  loss 2.302054 (2.302495)
[2021-06-14 00:52:10] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [621/3125]  acc1 0.000000 (0.090378)  loss 2.303878 (2.302496)
[2021-06-14 00:52:35] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [631/3125]  acc1 0.125000 (0.090927)  loss 2.301956 (2.302488)
[2021-06-14 00:53:00] INFO (nni.retiarii.oneshot.pytorch.darts/MainThr

[2021-06-14 01:12:39] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [1111/3125]  acc1 0.125000 (0.094397)  loss 2.301702 (2.302464)
[2021-06-14 01:13:04] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [1121/3125]  acc1 0.125000 (0.094558)  loss 2.300599 (2.302467)
[2021-06-14 01:13:29] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [1131/3125]  acc1 0.250000 (0.094717)  loss 2.300587 (2.302469)
[2021-06-14 01:13:54] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [1141/3125]  acc1 0.125000 (0.094763)  loss 2.302217 (2.302468)
[2021-06-14 01:14:19] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [1151/3125]  acc1 0.125000 (0.095026)  loss 2.302615 (2.302469)
[2021-06-14 01:14:44] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [1161/3125]  acc1 0.000000 (0.094746)  loss 2.304230 (2.302471)
[2021-06-14 01:15:09] INFO (nni.retiarii.oneshot.pytorch.darts/M

[2021-06-14 01:34:45] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [1641/3125]  acc1 0.125000 (0.096511)  loss 2.302208 (2.302465)
[2021-06-14 01:35:10] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [1651/3125]  acc1 0.125000 (0.096835)  loss 2.303352 (2.302466)
[2021-06-14 01:35:35] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [1661/3125]  acc1 0.000000 (0.096553)  loss 2.306169 (2.302471)
[2021-06-14 01:36:00] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [1671/3125]  acc1 0.000000 (0.096275)  loss 2.302676 (2.302471)
[2021-06-14 01:36:25] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [1681/3125]  acc1 0.000000 (0.095925)  loss 2.305575 (2.302474)
[2021-06-14 01:36:50] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [1691/3125]  acc1 0.250000 (0.096245)  loss 2.302801 (2.302473)
[2021-06-14 01:37:15] INFO (nni.retiarii.oneshot.pytorch.darts/M

[2021-06-14 01:57:56] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [2171/3125]  acc1 0.000000 (0.098457)  loss 2.304977 (2.302463)
[2021-06-14 01:58:21] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [2181/3125]  acc1 0.000000 (0.098521)  loss 2.303649 (2.302465)
[2021-06-14 01:58:46] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [2191/3125]  acc1 0.000000 (0.098699)  loss 2.302085 (2.302463)
[2021-06-14 01:59:11] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [2201/3125]  acc1 0.125000 (0.098819)  loss 2.302430 (2.302463)
[2021-06-14 01:59:36] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [2211/3125]  acc1 0.000000 (0.098654)  loss 2.301126 (2.302462)
[2021-06-14 02:00:01] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [2221/3125]  acc1 0.000000 (0.098379)  loss 2.301821 (2.302465)
[2021-06-14 02:00:26] INFO (nni.retiarii.oneshot.pytorch.darts/M

[2021-06-14 02:19:47] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [2701/3125]  acc1 0.000000 (0.101120)  loss 2.304972 (2.302443)
[2021-06-14 02:20:12] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [2711/3125]  acc1 0.375000 (0.101116)  loss 2.300725 (2.302443)
[2021-06-14 02:20:37] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [2721/3125]  acc1 0.000000 (0.101112)  loss 2.302159 (2.302442)
[2021-06-14 02:21:02] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [2731/3125]  acc1 0.000000 (0.101108)  loss 2.302120 (2.302441)
[2021-06-14 02:21:27] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [2741/3125]  acc1 0.125000 (0.101104)  loss 2.303539 (2.302441)
[2021-06-14 02:21:52] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [2751/3125]  acc1 0.125000 (0.101054)  loss 2.301830 (2.302441)
[2021-06-14 02:22:17] INFO (nni.retiarii.oneshot.pytorch.darts/M

In [26]:
# Similarly, the optimal structure found can be exported.
print('Final architecture:', trainer.export())

Final architecture: {}


In [17]:
trainer.nas_modules[0][1].op_choices['278']

Sequential(
  (0): Conv2d(3, 11, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(11, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (4): Conv2d(11, 11, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): BatchNorm2d(11, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): ReLU(inplace=True)
  (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (8): Conv2d(11, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (10): ReLU(inplace=True)
  (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (12): Flatten(start_dim=1, end_dim=-1)
  (13): Linear(in_features=256, out_features=32, bias=True)
  (14): ReLU(inplace=True)
  (15): Linear(in_features=32, out_features=10, bias=Tr