This code is based on the following PyTorch tutorials:
* https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
* https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

In [None]:
import os
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
import data_utils.data_loading as data_load
from data_utils.cifar10_testset import CIFAR10Testset
from custom_nets.vgglike import VGGLike
from custom_nets.resnetlike import ResNetLike, ResBlock
from data_utils.data_preparation import Cutout

random_seed = 42

# Seed every random-number source the notebook touches: Python's `random`,
# NumPy's legacy global RNG, PyTorch on the CPU, and PyTorch on the current CUDA device.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)

# Ask cuDNN for deterministic kernels and disable its autotuner so runs are repeatable.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
class NetworkTrainer:
    """Train a network on the CIFAR-10 training data and track validation metrics.

    Construction builds the train/validation loaders via ``data_utils.data_loading``
    and makes sure the output directories ``../nets``, ``../loss`` and ``../accuracy``
    exist.
    """

    def __init__(self, batch_size: int, device: str, load_train_from_torch=False):
        """
        Args:
            batch_size: Mini-batch size handed to the data-loading helpers.
            device: Device that the network and every batch are moved to.
            load_train_from_torch: If True, load data with
                ``data_load.load_torch_train_data`` and random augmentation; otherwise
                use ``data_load.load_train_data`` without augmentation.
        """
        self.batch_size = batch_size
        self.device = device
        self.trainloader, self.valloader = self.load_data(load_from_torch=load_train_from_torch)
        for directory in ('../nets', '../loss', '../accuracy'):
            os.makedirs(directory, exist_ok=True)

    def load_data(self, load_from_torch):
        """Return ``(trainloader, valloader)`` built by the project loading helpers.

        Both branches end with the same ``Normalize((0.5,)*3, (0.5,)*3)`` that
        ``NetworkTester`` applies to the test images, so the network sees inputs on the
        same scale during training and testing.
        """
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        if not load_from_torch:
            transform = transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])
            return data_load.load_train_data(transform, self.batch_size)

        transform = transforms.Compose([
            transforms.Pad(4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32),
            transforms.ToTensor(),
            Cutout(2, 10, p=0.4),
            # Previously missing: this branch trained on un-normalised [0, 1] tensors
            # while the test set is normalised. Applied after Cutout so the cutout
            # fill keeps the meaning it had on the [0, 1] tensor.
            # NOTE(review): assumes Cutout returns a tensor - confirm in data_preparation.
            normalize,
        ])
        return data_load.load_torch_train_data(transform, self.batch_size)

    def train_one_epoch(self, network: nn.Module, optimizer, loss_criterion, epoch: int, log_every: int = 200):
        """Run one optimisation pass over ``self.trainloader``.

        Loss and accuracy are averaged over consecutive windows of ``log_every``
        batches; each completed window is printed.

        Args:
            network: Model being trained (switched to training mode here).
            optimizer: Optimizer stepping ``network``'s parameters.
            loss_criterion: Callable ``(outputs, labels) -> scalar loss tensor``.
            epoch: Zero-based epoch index, used only for logging.
            log_every: Number of batches per reporting window (must be >= 1).

        Returns:
            ``(loss, accuracy)`` of the last completed window. If the epoch has fewer
            than ``log_every`` batches, the averages over all batches seen are returned
            instead of the former ``(0.0, 0.0)``.
        """
        if log_every < 1:
            raise ValueError(f'log_every must be >= 1, got {log_every}')

        # Validation switches the network to eval mode; switch it back before training.
        network.train()

        running_loss = 0.0
        running_accuracy = 0.0
        window_batches = 0
        last_loss = 0.0
        last_accuracy = 0.0
        window_reported = False

        for i, data in enumerate(self.trainloader):
            inputs, labels = data[0].to(self.device), data[1].to(self.device)

            optimizer.zero_grad()
            outputs = network(inputs)
            loss = loss_criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(outputs.detach(), 1)
            running_loss += loss.item()
            running_accuracy += (labels == predicted).sum().item() / len(labels)
            window_batches += 1

            del inputs, labels, outputs

            if window_batches == log_every:
                last_loss = running_loss / log_every
                last_accuracy = running_accuracy / log_every
                print(f'[epoch: {epoch + 1}, batches: {i - log_every + 2:5d} - {i + 1:5d}] train loss: {last_loss:.3f}, train accuracy: {last_accuracy:.3f}')
                running_loss = 0.0
                running_accuracy = 0.0
                window_batches = 0
                window_reported = True

        # Short epoch: no full window completed, so report what was actually seen.
        if not window_reported and window_batches:
            last_loss = running_loss / window_batches
            last_accuracy = running_accuracy / window_batches

        return last_loss, last_accuracy

    def _validate(self, network: nn.Module, loss_criterion):
        """Return mean per-batch ``(loss, accuracy)`` of ``network`` over ``self.valloader``.

        Runs in eval mode without gradients, so dropout / batch-norm layers (if any)
        use their inference behaviour. Raises ValueError if the loader is empty.
        """
        network.eval()
        total_loss = 0.0
        total_accuracy = 0.0
        n_batches = 0
        with torch.no_grad():
            for vdata in self.valloader:
                vinputs, vlabels = vdata[0].to(self.device), vdata[1].to(self.device)
                voutputs = network(vinputs)
                _, vpredicted = torch.max(voutputs, 1)
                # .item() keeps the accumulator a Python float instead of a device tensor.
                total_loss += loss_criterion(voutputs, vlabels).item()
                total_accuracy += (vlabels == vpredicted).sum().item() / len(vlabels)
                n_batches += 1
                del vinputs, vlabels, voutputs
        if n_batches == 0:
            raise ValueError('validation loader yielded no batches')
        return total_loss / n_batches, total_accuracy / n_batches

    def train_network(self, network: nn.Module, optimizer, loss_criterion, number_of_epochs: int, name: str):
        """Train for ``number_of_epochs`` epochs, validating after each one.

        Per epoch, the train/validation loss and accuracy are written to
        ``../loss/cifar_<name>_<timestamp>.csv`` and
        ``../accuracy/cifar_<name>_<timestamp>.csv``. Row 0 holds train values and
        row 1 validation values. Whenever the validation loss improves,
        ``network.state_dict()`` is saved to ``../nets/cifar_<name>_<timestamp>``.

        Returns:
            Path of the best-validation-loss checkpoint.
        """
        network.to(self.device)
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        run_name = f'cifar_{name}_{timestamp}'
        best_networkstate_path = f'../nets/{run_name}'
        best_vloss = float('inf')  # any finite first-epoch loss is an improvement
        loss = np.empty((2, number_of_epochs))
        accuracy = np.empty((2, number_of_epochs))

        for epoch in range(number_of_epochs):
            loss[0, epoch], accuracy[0, epoch] = self.train_one_epoch(network, optimizer, loss_criterion, epoch)

            avg_vloss, avg_vaccuracy = self._validate(network, loss_criterion)
            loss[1, epoch] = avg_vloss
            accuracy[1, epoch] = avg_vaccuracy
            print(f'[epoch: {epoch + 1}] validation loss: {avg_vloss:.3f}, validation accuracy: {avg_vaccuracy:.3f}')

            # Rewrite the CSVs every epoch so partial results survive an interrupted run.
            np.savetxt(f'../loss/{run_name}.csv', loss[:, :epoch + 1], delimiter=',')
            np.savetxt(f'../accuracy/{run_name}.csv', accuracy[:, :epoch + 1], delimiter=',')

            if avg_vloss < best_vloss:
                best_vloss = avg_vloss
                torch.save(network.state_dict(), best_networkstate_path)

        print('Finished Training')
        self.visualize_loss(number_of_epochs, loss)
        self.visualize_accuracy(number_of_epochs, accuracy)
        return best_networkstate_path

    @staticmethod
    def _plot_metric(number_of_epochs, values, ylabel):
        """Plot row 0 (train) and row 1 (validation) of ``values`` against epoch number."""
        epochs = range(1, number_of_epochs + 1)
        plt.plot(epochs, values[0, :], marker='o')
        plt.plot(epochs, values[1, :], marker='o')
        plt.legend(['train', 'validation'])
        plt.xlabel('epoch')
        plt.ylabel(ylabel)
        plt.xticks(epochs)
        plt.show()

    def visualize_loss(self, number_of_epochs, loss):
        """Plot train and validation loss per epoch."""
        self._plot_metric(number_of_epochs, loss, 'loss')

    def visualize_accuracy(self, number_of_epochs, accuracy):
        """Plot train and validation accuracy per epoch."""
        self._plot_metric(number_of_epochs, accuracy, 'accuracy')


In [None]:
class NetworkTester:
    """Predict labels for the CIFAR-10 test data with a saved network and write a CSV.

    Construction builds the test loader via ``data_utils.data_loading`` and makes sure
    ``../tests`` exists.
    """

    def __init__(self, batch_size: int, device: str):
        """
        Args:
            batch_size: Mini-batch size handed to ``data_load.load_test_data``.
            device: Device that the network and test batches are moved to.
        """
        self.batch_size = batch_size
        self.device = device
        # Second value from load_test_data is presumably the test-set size - confirm.
        self.testloader, self.number_of_images = self.load_data()
        os.makedirs('../tests', exist_ok=True)

    def load_data(self):
        """Return ``(testloader, number_of_images)`` using the same normalisation as training."""
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        return data_load.load_test_data(transform, self.batch_size)

    def test_network(self, network: nn.Module, best_networkstate_path: str):
        """Load a checkpoint into ``network``, predict every test image and save the labels.

        Writes ``../tests/<checkpoint basename>.csv`` with an ``id`` column (1-based,
        in loader order) and a ``label`` column holding the predicted class names.

        Args:
            network: Freshly constructed model matching the checkpoint's architecture.
            best_networkstate_path: Path of a ``state_dict`` saved with ``torch.save``.
        """
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
        network.load_state_dict(torch.load(best_networkstate_path, map_location=self.device))
        network.to(self.device)
        # Inference mode: without this, dropout stays active and predictions are random.
        network.eval()

        predicted_classes = []
        with torch.no_grad():
            for data in self.testloader:
                images = data[0].to(self.device)
                _, predicted = torch.max(network(images), 1)
                predicted_classes.extend(
                    CIFAR10Testset.label_number_to_class[predicted_label]
                    for predicted_label in predicted.tolist()
                )

        # Rebuild the path from the basename instead of str.replace('nets', 'tests'),
        # which also rewrote any 'nets' occurring inside the run name.
        output_path = os.path.join('../tests', os.path.basename(best_networkstate_path) + '.csv')
        pd.DataFrame(
            {'label': predicted_classes},
            index=range(1, len(predicted_classes) + 1),
        ).to_csv(output_path, index_label='id')

In [None]:
class Hyperparams:
    """Optimiser and regularisation settings for one training run.

    Args:
        learning_rate: Learning rate passed to the optimizer.
        optimizer_name: ``'ADAM'`` or ``'SGD'`` (case-sensitive).
        weight_decay: Weight-decay factor passed to the optimizer.
        dropout_p: Dropout probability. It is stored for the caller to pass to a network
            constructor and is recorded in the run name.
    """

    # Supported optimizer names mapped to their torch.optim classes.
    _OPTIMIZERS = {'ADAM': optim.Adam, 'SGD': optim.SGD}

    def __init__(self, learning_rate, optimizer_name, weight_decay, dropout_p=0):
        self.learning_rate = learning_rate
        self.optimizer_name = optimizer_name
        self.weight_decay = weight_decay
        self.dropout_p = dropout_p

    def get_optimizer(self, network):
        """Return an optimizer over ``network.parameters()`` configured from these settings.

        Raises:
            ValueError: if ``optimizer_name`` is not supported. Previously ``None`` was
                returned silently and the run failed later with an unrelated error.
        """
        try:
            optimizer_cls = self._OPTIMIZERS[self.optimizer_name]
        except KeyError:
            raise ValueError(
                f'Unsupported optimizer {self.optimizer_name!r}; '
                f'expected one of {sorted(self._OPTIMIZERS)}'
            ) from None
        return optimizer_cls(network.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

    def get_network_params_name(self, network):
        """Return a run name encoding ``network.name`` and every hyperparameter."""
        return f'{network.name}_lr_{self.learning_rate}_o_{self.optimizer_name}_wd_{self.weight_decay}_d_{self.dropout_p}'

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

## Training and tuning the custom VGG-like network

In [None]:
# Run configuration for the VGG-like network.
number_of_epochs = 20
batch_size = 32
hyperparams = Hyperparams(learning_rate=0.001, optimizer_name='ADAM', weight_decay=0, dropout_p=0.3)

# The trainer loads the train/validation loaders on construction (non-augmented path).
trainer = NetworkTrainer(batch_size, device)
network = VGGLike(dropout_p=hyperparams.dropout_p)
optimizer = hyperparams.get_optimizer(network)
save_name = hyperparams.get_network_params_name(network)

# train_network returns the path it saves the lowest-validation-loss weights to.
best_state_path = trainer.train_network(
    network,
    optimizer,
    nn.CrossEntropyLoss(),
    number_of_epochs,
    save_name,
)

# Evaluate a fresh VGGLike instance loaded from that checkpoint.
tester = NetworkTester(batch_size, device)
tester.test_network(VGGLike(dropout_p=hyperparams.dropout_p), best_state_path)

## Training and tuning the custom ResNet-like network

In [None]:
# Run configuration for the ResNet-like network. dropout_p stays at Hyperparams'
# default of 0; here it only appears in the run name.
number_of_epochs = 20
batch_size = 32
hyperparams = Hyperparams(learning_rate=0.001, optimizer_name='ADAM', weight_decay=0.001)

# Explicitly request the non-augmented data-loading path.
trainer = NetworkTrainer(batch_size, device, load_train_from_torch=False)
network = ResNetLike(ResBlock, [2, 2, 2]).to(device)
optimizer = hyperparams.get_optimizer(network)
save_name = hyperparams.get_network_params_name(network)

# train_network returns the path it saves the lowest-validation-loss weights to.
best_state_path = trainer.train_network(
    network,
    optimizer,
    nn.CrossEntropyLoss(),
    number_of_epochs,
    save_name,
)

# Evaluate a fresh ResNetLike with the same architecture, loaded from that checkpoint.
tester = NetworkTester(batch_size, device)
tester.test_network(ResNetLike(ResBlock, [2, 2, 2]).to(device), best_state_path)