In [13]:
import torch
import torchsummary
import torchvision as tv

In [14]:
# CIFAR-10 data, stored under ../../CIFAR10 and downloaded on first use.
# set0 is the train=True split, set1 the train=False split; ToTensor turns
# every image into a tensor before it reaches the model.
set0 = tv.datasets.CIFAR10(
    root="../../CIFAR10",
    train=True,
    transform=tv.transforms.ToTensor(),
    download=True,
)
set1 = tv.datasets.CIFAR10(
    root="../../CIFAR10",
    train=False,
    transform=tv.transforms.ToTensor(),
    download=True,
)

# Mini-batches of 100 samples; only the training loader is shuffled.
loader0 = torch.utils.data.DataLoader(dataset=set0, batch_size=100, shuffle=True)
loader1 = torch.utils.data.DataLoader(dataset=set1, batch_size=100)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
def _conv_stage(in_channels, out_channels):
    """Return the layers of one convolutional stage, in execution order.

    A stage is two 3x3 convolutions with padding 1 (spatial size preserved),
    each followed by ReLU, and then a 2x2 max-pool that halves height and
    width. The first convolution maps in_channels -> out_channels, the second
    keeps out_channels.
    """
    return [
        torch.nn.Conv2d(in_channels, out_channels, 3, padding=1),
        torch.nn.ReLU(),
        torch.nn.Conv2d(out_channels, out_channels, 3, padding=1),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2),
    ]


# CNN classifier for 3 x 32 x 32 inputs with 10 output logits. The stages are
# unpacked into one flat Sequential, so layer indices (and therefore
# state_dict keys) follow the order written below.
model = torch.nn.Sequential(
    # Stage 1: 3 x 32 x 32 -> 32 x 16 x 16
    torch.nn.BatchNorm2d(3),
    *_conv_stage(3, 32),

    # Stage 2: 32 x 16 x 16 -> 64 x 8 x 8
    torch.nn.BatchNorm2d(32),
    torch.nn.Dropout(0.2),
    *_conv_stage(32, 64),

    # Stage 3: 64 x 8 x 8 -> 128 x 4 x 4
    torch.nn.BatchNorm2d(64),
    torch.nn.Dropout(0.3),
    *_conv_stage(64, 128),

    # Classifier head: flatten the 128 x 4 x 4 feature map, then 128 -> 10.
    torch.nn.BatchNorm2d(128),
    torch.nn.Dropout(0.4),
    torch.nn.Flatten(),
    torch.nn.Linear(128 * 4 * 4, 128),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.2),
    torch.nn.Linear(128, 10),
).cuda()
# To start from previously saved weights, uncomment:
#model.load_state_dict(torch.load("dictionary.pt"))

# Print the per-layer output shapes and parameter counts.
torchsummary.summary(model, input_size=(3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
       BatchNorm2d-1            [-1, 3, 32, 32]               6
            Conv2d-2           [-1, 32, 32, 32]             896
              ReLU-3           [-1, 32, 32, 32]               0
            Conv2d-4           [-1, 32, 32, 32]           9,248
              ReLU-5           [-1, 32, 32, 32]               0
         MaxPool2d-6           [-1, 32, 16, 16]               0
       BatchNorm2d-7           [-1, 32, 16, 16]              64
           Dropout-8           [-1, 32, 16, 16]               0
            Conv2d-9           [-1, 64, 16, 16]          18,496
             ReLU-10           [-1, 64, 16, 16]               0
           Conv2d-11           [-1, 64, 16, 16]          36,928
             ReLU-12           [-1, 64, 16, 16]               0
        MaxPool2d-13             [-1, 64, 8, 8]               0
      BatchNorm2d-14             [-1, 6

In [None]:
# Train `model` with Adam, evaluate on loader1 after every epoch, and save the
# weights to "dictionary.pt" whenever the loader1 accuracy improves.
#
# Naming convention: UPPER_CASE names are tensors, lower_case names are plain
# Python numbers. Suffix 0 = training pass (loader0), 1 = evaluation pass
# (loader1).
#
# Each epoch prints one line:
#   epoch, train loss, train acc, train acc at best checkpoint,
#          eval loss,  eval acc,  best eval acc
#
# NOTE(review): the checkpoint is selected on loader1 accuracy; if loader1 is
# the final test split, the reported best accuracy is optimistically biased.
accuracy0, accuracy1 = 0., 0.
optimizer = torch.optim.Adam(model.parameters())
# Run on whatever device the model lives on instead of hard-coding "cuda".
device = next(model.parameters()).device
for epoch in range(1000):
    # ---- training pass ----
    model.train()
    LOSS0 = torch.zeros((), device = device)
    ACCURACY0 = torch.zeros((), device = device)
    count0 = 0
    for DATA, TARGET in loader0:
        optimizer.zero_grad()
        DATA = DATA.to(device)
        TARGET = TARGET.to(device)
        count = TARGET.size(0)
        ACTIVATION = model(DATA)
        # Mean loss over the batch; this is the value that is back-propagated.
        LOSS = torch.nn.functional.cross_entropy(ACTIVATION, TARGET)
        # Accumulate a detached copy: adding LOSS itself would make LOSS0 part
        # of the autograd graph and chain every batch's graph together for the
        # whole epoch. Re-weight by batch size so the epoch mean is per sample.
        LOSS0 += LOSS.detach() * count
        # Training accuracy comes from the same train-mode forward pass that
        # is used for the gradient step.
        VALUE = ACTIVATION.argmax(1)
        ACCURACY0 += torch.eq(VALUE, TARGET).sum()
        count0 += count
        LOSS.backward()
        optimizer.step()
    LOSS0 /= count0
    ACCURACY0 /= count0

    # ---- evaluation pass (eval mode, no gradient tracking) ----
    model.eval()
    with torch.no_grad():
        LOSS1 = torch.zeros((), device = device)
        ACCURACY1 = torch.zeros((), device = device)
        count1 = 0
        for DATA, TARGET in loader1:
            DATA = DATA.to(device)
            TARGET = TARGET.to(device)
            ACTIVATION = model(DATA)
            # Summed (not averaged) so that dividing by count1 below yields
            # the per-sample mean.
            LOSS1 += torch.nn.functional.cross_entropy(ACTIVATION, TARGET,
                                                       reduction = "sum")
            VALUE = ACTIVATION.argmax(1)
            ACCURACY1 += torch.eq(VALUE, TARGET).sum()
            count1 += TARGET.size(0)
        LOSS1 /= count1
        ACCURACY1 /= count1

    # ---- checkpoint on improvement, then log ----
    accuracy = ACCURACY1.item()
    if accuracy1 < accuracy:
        accuracy0, accuracy1 = ACCURACY0.item(), accuracy
        torch.save(model.state_dict(), "dictionary.pt")
    print("%4d %12.3f %4.3f %4.3f %12.3f %4.3f %4.3f" % \
          (epoch, LOSS0.item(), ACCURACY0.item(), accuracy0,
           LOSS1.item(), accuracy, accuracy1), flush = True)