<a href="https://colab.research.google.com/github/mamagoudou/QNN-with-dithering/blob/main/VGG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Model definition

In [1]:
import torch
import torch.nn as nn

import math

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

PATH_Models = '/content/drive/MyDrive/Memory/Models/VGG/'
PATH_Measures = '/content/drive/MyDrive/Memory/Measures/VGG/'

In [2]:
# adapted from 
# https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
# paper: https://arxiv.org/pdf/1409.1556.pdf

class VGG(nn.Module):

  def __init__(self, features, classes = 10):

    super(VGG, self).__init__()
    self.features = features
    """ #GENUINE CLASSIFIER
    self.classifier = nn.Sequential(
      nn.Dropout(),
      nn.Linear(512, 512),
      nn.ReLU(True),
      nn.Dropout(),
      nn.Linear(512, 512),
      nn.ReLU(True),
      nn.Linear(512, classes),
    )
    """
    # SIMPLIFIED CLASSIFIER
    self.classifier = nn.Linear(512, 10) 

    # Initialize weights
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        m.bias.data.zero_()

  def forward(self, x):

    x = self.features(x)
    x = x.view(x.size(0), -1)
    x = self.classifier(x)
    return x

In [3]:
def make_layers(cfg, batch_norm=False):

  layers = []
  in_channels = 3
  for v in cfg:
    if v == 'M': # pooling
      layers += [nn.MaxPool2d(kernel_size=2, stride=2)]

    else: # convolution
      conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)

      if batch_norm:
          layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
      else:
          layers += [conv2d, nn.ReLU(inplace=True)]

      in_channels = v
  return nn.Sequential(*layers)

In [4]:
cfg = {
    '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 
           512, 'M'],
    '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 
           512, 512, 512, 'M'],
    '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 
           512, 'M', 512, 512, 512, 512, 'M'],
}

In [5]:
def VGG11():
  return VGG(make_layers(cfg['11']))

def VGG11_bn():
  return VGG(make_layers(cfg['11'], batch_norm=True))

def VGG13():
  return VGG(make_layers(cfg['13']))

def VGG13_bn():
  return VGG(make_layers(cfg['13'], batch_norm=True))

def VGG16():
  return VGG(make_layers(cfg['16']))

def VGG16_bn():
  return VGG(make_layers(cfg['16'], batch_norm=True))

def VGG19():
  return VGG(make_layers(cfg['19']))

def VGG19_bn():
  return VGG(make_layers(cfg['19'], batch_norm=True))

In [6]:
# NAME_DD_MM_TEST
PATH_Name = 'VGG16bn_15_02_TEST'

network = VGG16_bn()

epoch = 0
network.to(device)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256

# Dataset download and processing

In [7]:
import torchvision
import torchvision.transforms as transforms

In [8]:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), 
                                                     (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


# Training & validation

In [9]:
BATCH_SIZE = 128
MAX_EPOCHS = 50
NUM_WORKERS = 64

trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=NUM_WORKERS)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=NUM_WORKERS)

In [18]:
LEARNING_RATE = 0.005
OPTIMIZER = "SGD"
CRITERION = "CrossEntropyLoss"

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(network.parameters(), lr=LEARNING_RATE)

In [11]:
import time
import csv
from tqdm.notebook import tqdm


TrainLoss = []
TrainAcc = []
Traintime = []
TestLoss = []
TestAcc = []

for epoch in tqdm(range(epoch, MAX_EPOCHS), position=0, desc="Epoch"):

  print("Epoch: %d" %(epoch))
  # TRAINING
  network.train()
  start_time = time.time()
  train_loss = 0
  correct = 0
  total = 0
  for i, data in tqdm(enumerate(trainloader, 0), position=1, desc="Training", 
                      total=len(trainloader.dataset)/BATCH_SIZE, leave=False):
    
    inputs, labels = data[0].to(device), data[1].to(device)
    optimizer.zero_grad()
    outputs = network(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    _, predicted = outputs.max(1)
    total += labels.size(0)
    correct += predicted.eq(labels).sum().item()
    end_time = time.time()

  TrainLoss.append(train_loss/(i+1))
  TrainAcc.append(100.*correct/total)
  Traintime.append(end_time-start_time)
  print('TrainLoss: %.3f | TrainAcc: %.3f%% (%d/%d) | Time Elapsed %.3f sec' 
        % (TrainLoss[-1], TrainAcc[-1], correct, total, Traintime[-1]))
  
  # TESTING
  network.eval()
  test_loss = 0
  correct = 0
  total = 0
  with torch.no_grad():
    for i, data in tqdm(enumerate(testloader, 0), position=2, desc="Testing", 
                        total=len(testloader.dataset)/BATCH_SIZE, leave=False):
      inputs, labels = data[0].to(device), data[1].to(device)
      outputs = network(inputs)
      loss = criterion(outputs, labels)

      test_loss += loss.item()
      _, predicted = outputs.max(1)
      total += labels.size(0)
      correct += predicted.eq(labels).sum().item()

    TestLoss.append(test_loss/(i+1))
    TestAcc.append(100.*correct/total)
    print('TestLoss: %.3f | TestAcc: %.3f%% (%d/%d)' 
          % (TestLoss[-1], TestAcc[-1], correct, total))
    print('-' * 75)
  # SAVE MODEL IF BEST
  if TestAcc[-1] == max(TestAcc):
    torch.save({
      'optimizer': optimizer.state_dict(),
      'network': network.state_dict(),
      'epoch': epoch
    }, PATH_Models + PATH_Name + '.pth')

  if TrainAcc[-1] >= 99.9:
    break


# WRITE INFOS & STATS IN CSV
infos = {"Type": OPTIMIZER, "Optimizer": optimizer.state_dict(),
         "BatchSize": BATCH_SIZE, "Criterion": CRITERION}

stats = {"TrainLoss": TrainLoss, "TrainAcc": TrainAcc, "Traintime": Traintime,
         "TestLoss": TestLoss, "TestAcc": TestAcc}

with open(PATH_Measures + PATH_Name + ".csv", "w") as f:
  writer = csv.writer(f)
  writer.writerow(stats.keys())
  writer.writerows(zip(*stats.values()))

with open(PATH_Measures + PATH_Name + "_infos.csv", "w") as f:
  writer = csv.DictWriter(f, fieldnames=infos.keys())
  writer.writeheader()
  writer.writerow(infos)

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=50.0, style=ProgressStyle(description_width='…

Epoch: 0


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 1.581 | TrainAcc: 42.124% (21062/50000) | Time Elapsed 199.823 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.337 | TestAcc: 50.510% (5051/10000)
---------------------------------------------------------------------------
Epoch: 1


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 1.186 | TrainAcc: 56.972% (28486/50000) | Time Elapsed 198.640 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.286 | TestAcc: 53.940% (5394/10000)
---------------------------------------------------------------------------
Epoch: 2


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.990 | TrainAcc: 64.796% (32398/50000) | Time Elapsed 197.860 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.056 | TestAcc: 62.580% (6258/10000)
---------------------------------------------------------------------------
Epoch: 3


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.822 | TrainAcc: 71.160% (35580/50000) | Time Elapsed 198.481 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.032 | TestAcc: 63.970% (6397/10000)
---------------------------------------------------------------------------
Epoch: 4


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.672 | TrainAcc: 76.588% (38294/50000) | Time Elapsed 195.350 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.010 | TestAcc: 64.980% (6498/10000)
---------------------------------------------------------------------------
Epoch: 5


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.526 | TrainAcc: 82.372% (41186/50000) | Time Elapsed 198.589 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 0.978 | TestAcc: 66.940% (6694/10000)
---------------------------------------------------------------------------
Epoch: 6


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.382 | TrainAcc: 87.828% (43914/50000) | Time Elapsed 198.624 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.100 | TestAcc: 64.880% (6488/10000)
---------------------------------------------------------------------------
Epoch: 7


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.246 | TrainAcc: 92.886% (46443/50000) | Time Elapsed 198.661 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.050 | TestAcc: 67.540% (6754/10000)
---------------------------------------------------------------------------
Epoch: 8


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.136 | TrainAcc: 97.012% (48506/50000) | Time Elapsed 199.416 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.148 | TestAcc: 67.390% (6739/10000)
---------------------------------------------------------------------------
Epoch: 9


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.067 | TrainAcc: 99.062% (49531/50000) | Time Elapsed 199.683 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.205 | TestAcc: 67.690% (6769/10000)
---------------------------------------------------------------------------
Epoch: 10


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.033 | TrainAcc: 99.758% (49879/50000) | Time Elapsed 198.534 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.164 | TestAcc: 69.480% (6948/10000)
---------------------------------------------------------------------------
Epoch: 11


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.019 | TrainAcc: 99.936% (49968/50000) | Time Elapsed 199.440 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.201 | TestAcc: 69.230% (6923/10000)
---------------------------------------------------------------------------
Epoch: 12


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.013 | TrainAcc: 99.952% (49976/50000) | Time Elapsed 199.286 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.233 | TestAcc: 69.360% (6936/10000)
---------------------------------------------------------------------------
Epoch: 13


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.010 | TrainAcc: 99.978% (49989/50000) | Time Elapsed 198.892 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.241 | TestAcc: 69.400% (6940/10000)
---------------------------------------------------------------------------
Epoch: 14


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.007 | TrainAcc: 99.994% (49997/50000) | Time Elapsed 199.665 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.271 | TestAcc: 69.450% (6945/10000)
---------------------------------------------------------------------------
Epoch: 15


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.006 | TrainAcc: 99.990% (49995/50000) | Time Elapsed 198.709 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.285 | TestAcc: 69.850% (6985/10000)
---------------------------------------------------------------------------
Epoch: 16


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.006 | TrainAcc: 99.988% (49994/50000) | Time Elapsed 197.845 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.288 | TestAcc: 69.480% (6948/10000)
---------------------------------------------------------------------------
Epoch: 17


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.005 | TrainAcc: 99.998% (49999/50000) | Time Elapsed 200.342 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.309 | TestAcc: 69.470% (6947/10000)
---------------------------------------------------------------------------
Epoch: 18


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.004 | TrainAcc: 99.994% (49997/50000) | Time Elapsed 200.652 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.321 | TestAcc: 69.570% (6957/10000)
---------------------------------------------------------------------------
Epoch: 19


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.004 | TrainAcc: 99.994% (49997/50000) | Time Elapsed 199.932 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.385 | TestAcc: 68.840% (6884/10000)
---------------------------------------------------------------------------
Epoch: 20


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.003 | TrainAcc: 99.994% (49997/50000) | Time Elapsed 199.996 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.345 | TestAcc: 69.360% (6936/10000)
---------------------------------------------------------------------------
Epoch: 21


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.003 | TrainAcc: 100.000% (50000/50000) | Time Elapsed 200.071 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.348 | TestAcc: 69.420% (6942/10000)
---------------------------------------------------------------------------
Epoch: 22


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.003 | TrainAcc: 99.998% (49999/50000) | Time Elapsed 198.554 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.354 | TestAcc: 69.700% (6970/10000)
---------------------------------------------------------------------------
Epoch: 23


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.002 | TrainAcc: 100.000% (50000/50000) | Time Elapsed 200.112 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.370 | TestAcc: 69.700% (6970/10000)
---------------------------------------------------------------------------
Epoch: 24


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.003 | TrainAcc: 99.990% (49995/50000) | Time Elapsed 199.863 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.381 | TestAcc: 69.570% (6957/10000)
---------------------------------------------------------------------------
Epoch: 25


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.002 | TrainAcc: 99.996% (49998/50000) | Time Elapsed 199.601 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.384 | TestAcc: 69.600% (6960/10000)
---------------------------------------------------------------------------
Epoch: 26


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

KeyboardInterrupt: ignored

In [26]:
# WRITE INFOS & STATS IN CSV
infos = {"Type": OPTIMIZER, "Optimizer": optimizer.state_dict(),
         "BatchSize": BATCH_SIZE, "Criterion": CRITERION}

stats = {"TrainLoss": TrainLoss, "TrainAcc": TrainAcc, "Traintime": Traintime,
         "TestLoss": TestLoss, "TestAcc": TestAcc}

with open(PATH_Measures + PATH_Name + ".csv", "w") as f:
  writer = csv.writer(f)
  writer.writerow(stats.keys())
  writer.writerows(zip(*stats.values()))

with open(PATH_Measures + PATH_Name + "_infos.csv", "w") as f:
  writer = csv.DictWriter(f, fieldnames=infos.keys())
  writer.writeheader()
  writer.writerow(infos)

# Drafts

In [None]:
from tqdm.notebook import tqdm


network.train()
for epoch in tqdm(range(epoch, MAX_EPOCHS), position=0, desc="Epoch"):
  for i, data in tqdm(enumerate(trainloader, 0), position=1, desc="Batch", 
                      total=len(trainloader.dataset)/BATCH_SIZE, leave=True):
  
    # get the inputs; data is a list of [inputs, labels]
    inputs, labels = data[0].to(device), data[1].to(device)

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = network(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

epoch += 1


PATH = '/content/drive/MyDrive/Colab Notebooks/Memory/VGG11_bn.pth'
# Save the state of the training
torch.save({
    'optimizer': optimizer.state_dict(),
    'network': network.state_dict(),
    'epoch': epoch,
}, PATH)

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))




In [None]:
network.eval()

correct = 0
total = 0
with torch.no_grad():
  for data in testloader:
    inputs, labels = data[0].to(device), data[1].to(device)
    outputs = network(inputs)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
      100 * correct / total))

Accuracy of the network on the 10000 test images: 81 %


In [None]:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
  for data in testloader:
    inputs, labels = data[0].to(device), data[1].to(device)
    outputs = network(inputs)
    _, predicted = torch.max(outputs, 1)
    c = (predicted == labels).squeeze()
    for i in range(4):
      label = labels[i]
      class_correct[label] += c[i].item()
      class_total[label] += 1


for i in range(10):
  print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

Accuracy of plane : 82 %
Accuracy of   car : 92 %
Accuracy of  bird : 75 %
Accuracy of   cat : 69 %
Accuracy of  deer : 85 %
Accuracy of   dog : 61 %
Accuracy of  frog : 89 %
Accuracy of horse : 78 %
Accuracy of  ship : 87 %
Accuracy of truck : 89 %


In [13]:
print(optimizer.state_dict())

{'state': {}, 'param_groups': [{'lr': 0.005, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53]}]}
