<a href="https://colab.research.google.com/github/mamagoudou/QNN-with-dithering/blob/main/GoogLeNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Model definition

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')

# Output folders for checkpoints (.pth) and CSV measures.
# NOTE(review): these are Google Drive paths under /content/drive; no
# drive.mount() call is visible in this notebook, so the Drive presumably has
# to be mounted beforehand — TODO confirm.
PATH_Models = '/content/drive/MyDrive/Memory/Models/GoogLeNet/'
PATH_Measures = '/content/drive/MyDrive/Memory/Measures/GoogLeNet/'

## Inception module

In [2]:
# adapted from 
# https://github.com/Ksuryateja/pytorch-cifar10/blob/master/models/googlenet.py
# paper: https://arxiv.org/pdf/1409.4842.pdf

class Inception(nn.Module):
  """Inception block: four parallel branches concatenated along channels.

  Every branch keeps the spatial size (3x3 convs and the 3x3 max-pool use
  stride 1, padding 1), so the outputs can be concatenated on dim 1:

    b1: 1x1 conv                                   -> n1x1 channels
    b2: 1x1 conv (n3x3red) -> 3x3 conv             -> n3x3 channels
    b3: 1x1 conv (n5x5red) -> 3x3 conv -> 3x3 conv -> n5x5 channels
        (two stacked 3x3 convs stand in for the paper's single 5x5 conv;
        together they cover a 5x5 receptive field)
    b4: 3x3 max-pool -> 1x1 conv                   -> pool_planes channels

  Each conv is followed by BatchNorm2d and an in-place ReLU.

  Args:
    in_planes: number of input channels.
    n1x1: output channels of branch b1.
    n3x3red: channels of the 1x1 reduction before the 3x3 conv in b2.
    n3x3: output channels of branch b2.
    n5x5red: channels of the 1x1 reduction at the start of b3.
    n5x5: output channels of branch b3.
    pool_planes: output channels of branch b4.

  The output has n1x1 + n3x3 + n5x5 + pool_planes channels.
  """

  def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
    super(Inception, self).__init__()

    def conv_bn_relu(cin, cout, k):
      # padding = k // 2 -> 0 for 1x1 convs, 1 for 3x3 convs (size preserving).
      return [nn.Conv2d(cin, cout, kernel_size=k, padding=k // 2),
              nn.BatchNorm2d(cout),
              nn.ReLU(True)]

    # Branches are built in b1..b4 order so that layer indices (and hence
    # state_dict keys such as "b3.6.weight") stay stable for saved checkpoints.
    self.b1 = nn.Sequential(*conv_bn_relu(in_planes, n1x1, 1))
    self.b2 = nn.Sequential(
      *conv_bn_relu(in_planes, n3x3red, 1),
      *conv_bn_relu(n3x3red, n3x3, 3),
    )
    self.b3 = nn.Sequential(
      *conv_bn_relu(in_planes, n5x5red, 1),
      *conv_bn_relu(n5x5red, n5x5, 3),
      *conv_bn_relu(n5x5, n5x5, 3),
    )
    self.b4 = nn.Sequential(
      nn.MaxPool2d(3, stride=1, padding=1),
      *conv_bn_relu(in_planes, pool_planes, 1),
    )

    # He/Kaiming normal init using fan-out (kh * kw * out_channels); zero bias.
    for layer in self.modules():
      if not isinstance(layer, nn.Conv2d):
        continue
      kh, kw = layer.kernel_size
      std = math.sqrt(2. / (kh * kw * layer.out_channels))
      layer.weight.data.normal_(0, std)
      layer.bias.data.zero_()

  def forward(self, x):
    """Run the four branches on x and concatenate their outputs on dim 1."""
    branches = (self.b1, self.b2, self.b3, self.b4)
    return torch.cat([branch(x) for branch in branches], 1)

## GoogLeNet module

In [3]:
class GoogLeNet(nn.Module):
  """GoogLeNet (Inception v1) for small RGB images such as CIFAR-10.

  Adapted from https://github.com/Ksuryateja/pytorch-cifar10 (paper:
  https://arxiv.org/pdf/1409.4842.pdf). The stem is a single 3x3 conv with no
  stride or pooling, and there are no auxiliary classifiers. Each of the two
  max-pools halves the spatial size, so a 32x32 input reaches global average
  pooling as an 8x8 map with 1024 channels.

  Args:
    num_classes: number of output logits (default 10, CIFAR-10).
  """

  def __init__(self, num_classes=10):
    super(GoogLeNet, self).__init__()
    self.pre_layers = nn.Sequential(
        nn.Conv2d(3, 192, kernel_size=3, padding=1),
        nn.BatchNorm2d(192),
        nn.ReLU(True),
    )

    self.a3 = Inception(192,  64,  96, 128, 16, 32, 32)
    self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

    # Shared by both down-sampling steps in forward(): halves H and W.
    self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

    self.a4 = Inception(480, 192,  96, 208, 16,  48,  64)
    self.b4 = Inception(512, 160, 112, 224, 24,  64,  64)
    self.c4 = Inception(512, 128, 128, 256, 24,  64,  64)
    self.d4 = Inception(512, 112, 144, 288, 32,  64,  64)
    self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

    self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
    self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

    # Global average pooling. On the 8x8 maps produced from 32x32 inputs this
    # equals the former AvgPool2d(8, stride=1), but it also yields 1x1 for any
    # other input size instead of feeding a wrongly sized vector to the Linear
    # layer. Pooling has no parameters, so saved checkpoints still load.
    self.avgpool = nn.AdaptiveAvgPool2d(1)
    self.linear = nn.Linear(1024, num_classes)

    # He/Kaiming normal init with fan-out (kh * kw * out_channels), zero bias,
    # for every conv in the network. self.modules() recurses into the
    # Inception blocks, so their convs are re-drawn here.
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        m.bias.data.zero_()

  def forward(self, x):
    """Map an image batch x of shape (N, 3, H, W) to logits of shape (N, num_classes)."""
    x = self.pre_layers(x)
    x = self.a3(x)
    x = self.b3(x)
    x = self.maxpool(x)
    x = self.a4(x)
    x = self.b4(x)
    x = self.c4(x)
    x = self.d4(x)
    x = self.e4(x)
    x = self.maxpool(x)
    x = self.a5(x)
    x = self.b5(x)
    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.linear(x)

    return x

In [4]:
# Run identifier (format NAME_DD_MM_TEST); used as the file stem of the
# checkpoint (.pth) and of the CSV measures written by the training cell.
PATH_Name = 'GoogLeNet_15_02_TEST'

# First epoch of the training loop (range(epoch, MAX_EPOCHS)); 0 = fresh run.
epoch = 0

# nn.Module.to() moves the parameters and returns the same module object.
network = GoogLeNet().to(device)
network

GoogLeNet(
  (pre_layers): Sequential(
    (0): Conv2d(3, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (a3): Inception(
    (b1): Sequential(
      (0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (b2): Sequential(
      (0): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (b3): Sequential(
      (0): Conv2d(192, 16, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, m

# Dataset download and processing

In [5]:
import torchvision
import torchvision.transforms as transforms

In [6]:
# Convert images to tensors and map each channel from [0, 1] to [-1, 1]
# (mean 0.5, std 0.5). No data augmentation is applied to the training set.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 train split first, then test split; downloaded to ./data if missing.
trainset, testset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train,
                                 download=True, transform=transform)
    for is_train in (True, False)
)

# Class names, indexed by CIFAR-10 label id.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


# Training & validation

In [7]:
import os

# NOTE(review): the captured training output shows progress-bar totals of
# 390.625 / 78.125 batches, i.e. 50000/128 and 10000/128, so the logged run
# appears to have used BATCH_SIZE = 128, not 64 — confirm before comparing.
BATCH_SIZE = 64
# 64 worker processes is far more than the CPUs of a typical VM (e.g. Colab);
# DataLoader warns in that case and each idle worker costs memory. Never ask
# for more workers than the CPUs this machine reports.
NUM_WORKERS = min(64, os.cpu_count() or 1)
MAX_EPOCHS = 50

# Training batches are reshuffled every epoch; the test order is fixed.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=NUM_WORKERS)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=NUM_WORKERS)

In [8]:
import torch.optim as optim

# Labels recorded in the *_infos.csv summary; keep them in sync with the
# objects constructed below.
OPTIMIZER = "SGD"
LEARNING_RATE = 0.003
CRITERION = "CrossEntropyLoss"

# Loss on raw logits (GoogLeNet.forward ends with a Linear layer, no softmax).
criterion = nn.CrossEntropyLoss()
# Plain SGD with library defaults (no momentum, no weight decay); no LR
# scheduler is used, so the learning rate stays constant.
optimizer = optim.SGD(network.parameters(), lr=LEARNING_RATE)

In [None]:
import time
import csv
from tqdm.notebook import tqdm


def _write_measures():
  """Write the per-epoch statistics and the run configuration to CSV files.

  Reads the module-level TrainLoss/TrainAcc/Traintime/TestLoss/TestAcc lists,
  the hyper-parameter constants, ``optimizer`` and ``epoch``, and writes:
    * PATH_Measures + PATH_Name + ".csv": a header row, then one row per
      epoch. ``zip`` stops at the shortest list, so an epoch whose test phase
      did not finish (interrupted run) is left out.
    * PATH_Measures + PATH_Name + "_infos.csv": a header row and one row of
      run settings. Non-string values such as the optimizer state dict are
      written through ``str()`` by the csv module.
  """
  stats = {"TrainLoss": TrainLoss, "TrainAcc": TrainAcc, "Traintime": Traintime,
           "TestLoss": TestLoss, "TestAcc": TestAcc}

  # The csv module requires files it writes to be opened with newline=''.
  with open(PATH_Measures + PATH_Name + ".csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(stats.keys())
    writer.writerows(zip(*stats.values()))

  infos = {"PATH_Name":PATH_Name,"BATCH_SIZE":BATCH_SIZE,"MAX_EPOCHS":MAX_EPOCHS,
           "NUM_WORKERS":NUM_WORKERS,"OPTIMIZER":OPTIMIZER,
           "LEARNING_RATE":LEARNING_RATE,"CRITERION":CRITERION,"OptimizerState":
           optimizer.state_dict(),"epoch":epoch}

  with open(PATH_Measures + PATH_Name + "_infos.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=infos.keys())
    writer.writeheader()
    writer.writerow(infos)


TrainLoss = []
TrainAcc = []
Traintime = []
TestLoss = []
TestAcc = []

# try/finally: the CSV measures are written even when the run stops early
# (e.g. KeyboardInterrupt), so no separate export cell is needed afterwards.
try:
  # Starts from the global `epoch` (0 for a fresh run, set in the model cell).
  for epoch in tqdm(range(epoch, MAX_EPOCHS), position=0, desc="Epoch"):

    print("Epoch: %d" %(epoch))
    # TRAINING
    network.train()
    start_time = time.time()
    train_loss = 0
    correct = 0
    total = 0
    # len(loader) is the real, integer number of batches (partial last batch
    # included); len(dataset)/BATCH_SIZE gave a fractional, too-small total.
    for data in tqdm(trainloader, position=1, desc="Training",
                     total=len(trainloader), leave=False):

      inputs, labels = data[0].to(device), data[1].to(device)
      optimizer.zero_grad()
      outputs = network(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()

      train_loss += loss.item()
      _, predicted = outputs.max(1)
      total += labels.size(0)
      correct += predicted.eq(labels).sum().item()
      end_time = time.time()

    # Mean of the per-batch losses accumulated above.
    TrainLoss.append(train_loss/len(trainloader))
    TrainAcc.append(100.*correct/total)
    Traintime.append(end_time-start_time)
    print('TrainLoss: %.3f | TrainAcc: %.3f%% (%d/%d) | Time Elapsed %.3f sec'
          % (TrainLoss[-1], TrainAcc[-1], correct, total, Traintime[-1]))

    # TESTING
    network.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
      for data in tqdm(testloader, position=2, desc="Testing",
                       total=len(testloader), leave=False):
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = network(inputs)
        loss = criterion(outputs, labels)

        test_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    TestLoss.append(test_loss/len(testloader))
    TestAcc.append(100.*correct/total)
    print('TestLoss: %.3f | TestAcc: %.3f%% (%d/%d)'
          % (TestLoss[-1], TestAcc[-1], correct, total))
    print('-' * 75)

    # SAVE MODEL IF BEST (a later epoch tying the best accuracy overwrites it)
    if TestAcc[-1] == max(TestAcc):
      torch.save({
        'optimizer': optimizer.state_dict(),
        'network': network.state_dict(),
        'epoch': epoch
      }, PATH_Models + PATH_Name + '.pth')

    # Stop once the training set is (almost) perfectly fit.
    if TrainAcc[-1] >= 99.9:
      break
finally:
  # WRITE INFOS & STATS IN CSV
  _write_measures()

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=50.0, style=ProgressStyle(description_width='…

Epoch: 0












HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…






TrainLoss: 2.015 | TrainAcc: 27.802% (13901/50000) | Time Elapsed 192.200 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.779 | TestAcc: 36.130% (3613/10000)
---------------------------------------------------------------------------
Epoch: 1


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 1.669 | TrainAcc: 39.558% (19779/50000) | Time Elapsed 202.280 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.565 | TestAcc: 43.470% (4347/10000)
---------------------------------------------------------------------------
Epoch: 2


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 1.504 | TrainAcc: 45.638% (22819/50000) | Time Elapsed 204.633 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.433 | TestAcc: 48.170% (4817/10000)
---------------------------------------------------------------------------
Epoch: 3


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 1.377 | TrainAcc: 50.304% (25152/50000) | Time Elapsed 204.691 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.336 | TestAcc: 51.690% (5169/10000)
---------------------------------------------------------------------------
Epoch: 4


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 1.268 | TrainAcc: 54.826% (27413/50000) | Time Elapsed 204.570 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.241 | TestAcc: 55.440% (5544/10000)
---------------------------------------------------------------------------
Epoch: 5


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 1.171 | TrainAcc: 58.292% (29146/50000) | Time Elapsed 204.698 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.192 | TestAcc: 57.630% (5763/10000)
---------------------------------------------------------------------------
Epoch: 6


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 1.086 | TrainAcc: 61.486% (30743/50000) | Time Elapsed 204.842 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.119 | TestAcc: 60.230% (6023/10000)
---------------------------------------------------------------------------
Epoch: 7


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 1.020 | TrainAcc: 63.636% (31818/50000) | Time Elapsed 205.217 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.067 | TestAcc: 62.070% (6207/10000)
---------------------------------------------------------------------------
Epoch: 8


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.959 | TrainAcc: 66.148% (33074/50000) | Time Elapsed 204.929 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.043 | TestAcc: 62.760% (6276/10000)
---------------------------------------------------------------------------
Epoch: 9


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.904 | TrainAcc: 68.036% (34018/50000) | Time Elapsed 204.941 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.045 | TestAcc: 63.120% (6312/10000)
---------------------------------------------------------------------------
Epoch: 10


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.856 | TrainAcc: 69.646% (34823/50000) | Time Elapsed 204.962 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 0.984 | TestAcc: 64.950% (6495/10000)
---------------------------------------------------------------------------
Epoch: 11


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.809 | TrainAcc: 71.398% (35699/50000) | Time Elapsed 204.396 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 0.966 | TestAcc: 65.950% (6595/10000)
---------------------------------------------------------------------------
Epoch: 12


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.762 | TrainAcc: 73.198% (36599/50000) | Time Elapsed 204.911 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 0.931 | TestAcc: 67.150% (6715/10000)
---------------------------------------------------------------------------
Epoch: 13


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.720 | TrainAcc: 74.696% (37348/50000) | Time Elapsed 204.705 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 0.948 | TestAcc: 66.900% (6690/10000)
---------------------------------------------------------------------------
Epoch: 14


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.675 | TrainAcc: 76.642% (38321/50000) | Time Elapsed 204.801 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 0.923 | TestAcc: 67.560% (6756/10000)
---------------------------------------------------------------------------
Epoch: 15


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.636 | TrainAcc: 77.930% (38965/50000) | Time Elapsed 204.881 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 0.996 | TestAcc: 66.530% (6653/10000)
---------------------------------------------------------------------------
Epoch: 16


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.595 | TrainAcc: 79.672% (39836/50000) | Time Elapsed 205.130 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.036 | TestAcc: 65.220% (6522/10000)
---------------------------------------------------------------------------
Epoch: 17


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.555 | TrainAcc: 80.834% (40417/50000) | Time Elapsed 205.006 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 0.949 | TestAcc: 68.240% (6824/10000)
---------------------------------------------------------------------------
Epoch: 18


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.516 | TrainAcc: 82.508% (41254/50000) | Time Elapsed 204.968 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 0.955 | TestAcc: 67.830% (6783/10000)
---------------------------------------------------------------------------
Epoch: 19


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.478 | TrainAcc: 83.810% (41905/50000) | Time Elapsed 206.625 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 0.933 | TestAcc: 68.740% (6874/10000)
---------------------------------------------------------------------------
Epoch: 20


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.440 | TrainAcc: 85.346% (42673/50000) | Time Elapsed 204.660 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 0.948 | TestAcc: 68.640% (6864/10000)
---------------------------------------------------------------------------
Epoch: 21


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.404 | TrainAcc: 86.606% (43303/50000) | Time Elapsed 204.559 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 0.951 | TestAcc: 68.790% (6879/10000)
---------------------------------------------------------------------------
Epoch: 22


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.372 | TrainAcc: 87.748% (43874/50000) | Time Elapsed 204.524 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.099 | TestAcc: 65.560% (6556/10000)
---------------------------------------------------------------------------
Epoch: 23


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.337 | TrainAcc: 89.072% (44536/50000) | Time Elapsed 204.898 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 0.968 | TestAcc: 68.800% (6880/10000)
---------------------------------------------------------------------------
Epoch: 24


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.302 | TrainAcc: 90.470% (45235/50000) | Time Elapsed 204.828 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 0.986 | TestAcc: 69.740% (6974/10000)
---------------------------------------------------------------------------
Epoch: 25


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.272 | TrainAcc: 91.554% (45777/50000) | Time Elapsed 204.393 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.088 | TestAcc: 67.540% (6754/10000)
---------------------------------------------------------------------------
Epoch: 26


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.244 | TrainAcc: 92.556% (46278/50000) | Time Elapsed 204.646 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.057 | TestAcc: 68.500% (6850/10000)
---------------------------------------------------------------------------
Epoch: 27


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.210 | TrainAcc: 93.826% (46913/50000) | Time Elapsed 205.472 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.156 | TestAcc: 67.120% (6712/10000)
---------------------------------------------------------------------------
Epoch: 28


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.192 | TrainAcc: 94.324% (47162/50000) | Time Elapsed 204.342 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.106 | TestAcc: 68.240% (6824/10000)
---------------------------------------------------------------------------
Epoch: 29


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.169 | TrainAcc: 95.270% (47635/50000) | Time Elapsed 204.769 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.249 | TestAcc: 65.230% (6523/10000)
---------------------------------------------------------------------------
Epoch: 30


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.140 | TrainAcc: 96.298% (48149/50000) | Time Elapsed 205.155 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.077 | TestAcc: 69.360% (6936/10000)
---------------------------------------------------------------------------
Epoch: 31


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.125 | TrainAcc: 96.844% (48422/50000) | Time Elapsed 204.697 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.174 | TestAcc: 67.580% (6758/10000)
---------------------------------------------------------------------------
Epoch: 32


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.103 | TrainAcc: 97.600% (48800/50000) | Time Elapsed 204.545 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.288 | TestAcc: 66.600% (6660/10000)
---------------------------------------------------------------------------
Epoch: 33


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.094 | TrainAcc: 97.824% (48912/50000) | Time Elapsed 204.520 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.304 | TestAcc: 67.190% (6719/10000)
---------------------------------------------------------------------------
Epoch: 34


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.082 | TrainAcc: 98.232% (49116/50000) | Time Elapsed 204.531 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.153 | TestAcc: 69.800% (6980/10000)
---------------------------------------------------------------------------
Epoch: 35


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.071 | TrainAcc: 98.552% (49276/50000) | Time Elapsed 204.659 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.278 | TestAcc: 68.280% (6828/10000)
---------------------------------------------------------------------------
Epoch: 36


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.066 | TrainAcc: 98.656% (49328/50000) | Time Elapsed 205.055 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.172 | TestAcc: 69.730% (6973/10000)
---------------------------------------------------------------------------
Epoch: 37


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.055 | TrainAcc: 98.990% (49495/50000) | Time Elapsed 206.203 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.191 | TestAcc: 69.830% (6983/10000)
---------------------------------------------------------------------------
Epoch: 38


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

TrainLoss: 0.048 | TrainAcc: 99.208% (49604/50000) | Time Elapsed 204.467 sec


HBox(children=(FloatProgress(value=0.0, description='Testing', max=78.125, style=ProgressStyle(description_wid…

TestLoss: 1.170 | TestAcc: 70.590% (7059/10000)
---------------------------------------------------------------------------
Epoch: 39


HBox(children=(FloatProgress(value=0.0, description='Training', max=390.625, style=ProgressStyle(description_w…

In [None]:
# WRITE INFOS & STATS IN CSV
# Stand-alone export of the training statistics (e.g. after an interrupted
# run). It imports csv itself instead of relying on the training cell.
import csv

# One column per statistic, one row per epoch (zip stops at the shortest list).
stats = {"TrainLoss": TrainLoss, "TrainAcc": TrainAcc, "Traintime": Traintime,
         "TestLoss": TestLoss, "TestAcc": TestAcc}

# The csv module requires files it writes to be opened with newline=''.
with open(PATH_Measures + PATH_Name + ".csv", "w", newline="") as f:
  writer = csv.writer(f)
  writer.writerow(stats.keys())
  writer.writerows(zip(*stats.values()))

# Single row of run settings; non-string values are written via str().
infos = {"PATH_Name":PATH_Name,"BATCH_SIZE":BATCH_SIZE,"MAX_EPOCHS":MAX_EPOCHS,
         "NUM_WORKERS":NUM_WORKERS,"OPTIMIZER":OPTIMIZER,
         "LEARNING_RATE":LEARNING_RATE,"CRITERION":CRITERION,"OptimizerState":
         optimizer.state_dict(),"epoch":epoch}

with open(PATH_Measures + PATH_Name + "_infos.csv", "w", newline="") as f:
  writer = csv.DictWriter(f, fieldnames=infos.keys())
  writer.writeheader()
  writer.writerow(infos)

# Drafts

In [None]:
# test inspiration: https://github.com/Ksuryateja/pytorch-cifar10/blob/master/cifar10.py