In [None]:
from functools import partial
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
import pandas as pd
from data_utils.cifar10_testset import CIFAR10Testset
import data_utils.data_loading as data_load
from custom_nets.vgglike import VGGLike
import random
from torch.utils.tensorboard import SummaryWriter

random_seed = 42

# Seed every random number generator this notebook relies on with the same value.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(random_seed)
del _seed_rng

# Ask cuDNN for deterministic kernels and switch off its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
class NetworkHyperparamsTuner():
    """Ray Tune helpers for tuning and training a classification network.

    ``tune_network`` launches the hyperparameter search; ``general_tune_network``
    is the per-trial training loop that a trainable function is expected to call.
    """

    # Number of epochs each trial trains for; also given to ASHA as ``max_t``.
    MAX_EPOCHS = 2
    # Resources requested from Ray for every trial.
    GPUs_COUNT = 0
    CPUs_COUNT = 4

    @staticmethod
    def tune_network(train_function, tune_config: dict, device: str, num_samples: int):
        """Run a Ray Tune search that minimises the reported ``loss``.

        Args:
            train_function: trainable called by Ray; it receives ``device`` as
                a keyword argument in addition to Ray's own arguments.
            tune_config: search space handed to ``tune.run`` as ``config``.
            device: device forwarded unchanged to ``train_function``.
            num_samples: number of configurations to sample.

        Returns:
            The object returned by ``tune.run``.
        """
        scheduler = ASHAScheduler(
            metric='loss',
            mode='min',
            max_t=NetworkHyperparamsTuner.MAX_EPOCHS,
            grace_period=1,
            reduction_factor=2
        )

        reporter = CLIReporter(
            metric_columns=['loss', 'accuracy', 'training_iteration']
        )

        results = tune.run(
            partial(train_function, device=device),
            resources_per_trial={'cpu': NetworkHyperparamsTuner.CPUs_COUNT, 'gpu': NetworkHyperparamsTuner.GPUs_COUNT},
            config=tune_config,
            num_samples=num_samples,
            scheduler=scheduler,
            progress_reporter=reporter,
            local_dir='../raytune_logs'
        )

        return results

    @staticmethod
    def general_tune_network(net, criterion, optimizer, checkpoint_dir, device):
        """Train ``net`` for ``MAX_EPOCHS`` epochs inside a Ray Tune trial.

        After every epoch the model and optimizer state are checkpointed,
        training/validation loss and accuracy are logged to TensorBoard in the
        trial directory, and validation loss/accuracy are reported to Ray Tune.

        Args:
            net: network to train.
            criterion: loss function called as ``criterion(outputs, labels)``.
            optimizer: optimizer already bound to ``net``'s parameters.
            checkpoint_dir: directory holding a ``checkpoint`` file to restore
                from, or a falsy value to start from scratch.
            device: target device, as a string or a ``torch.device``.
        """
        # A torch.device never compares equal to a plain string, so compare the
        # string form; this accepts both 'cuda:0' and torch.device('cuda:0').
        if str(device) == 'cuda:0' and torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
        net.to(device)

        # Checkpoints hold the unwrapped model's weights (no DataParallel
        # 'module.' key prefix) so they can be loaded into a plain network.
        bare_net = net.module if isinstance(net, nn.DataParallel) else net

        if checkpoint_dir:
            model_state, optimizer_state = torch.load(
                os.path.join(checkpoint_dir, 'checkpoint'))
            bare_net.load_state_dict(model_state)
            optimizer.load_state_dict(optimizer_state)

        trainloader, valloader = NetworkHyperparamsTuner.__load_data()

        trial_dir = tune.get_trial_dir()
        print(trial_dir)

        # One writer for the whole trial, closed even if training fails.
        writer = SummaryWriter(trial_dir)
        try:
            for epoch in range(NetworkHyperparamsTuner.MAX_EPOCHS):
                train_loss, train_acc = NetworkHyperparamsTuner.__train_one_epoch(trainloader, net, optimizer, criterion, epoch, device)

                val_loss, val_acc = NetworkHyperparamsTuner.__validate_one_epoch(valloader, net, criterion, device)

                with tune.checkpoint_dir(epoch) as epoch_checkpoint_dir:
                    path = os.path.join(epoch_checkpoint_dir, 'checkpoint')
                    torch.save((bare_net.state_dict(), optimizer.state_dict()), path)

                writer.add_scalars('loss',
                                   {'Training': train_loss, 'Validation': val_loss},
                                   epoch)
                writer.add_scalars('acc',
                                   {'Training': train_acc, 'Validation': val_acc},
                                   epoch)
                # Flush before reporting: the scheduler may stop the trial once
                # it has seen this epoch's metrics.
                writer.flush()

                tune.report(loss=val_loss, accuracy=val_acc)
        finally:
            writer.close()

        print('Finished Training')

    @staticmethod
    def __load_data():
        """Return the (train, validation) loaders from ``data_load.load_train_data``."""
        batch_size = 32
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        return data_load.load_train_data(transform, batch_size)

    @staticmethod
    def __train_one_epoch(trainloader, net: nn.Module, optimizer, criterion, epoch, device):
        """Run one optimisation pass over ``trainloader``.

        Returns:
            ``(mean batch loss, accuracy)`` over the epoch as Python floats.
        """
        # Validation leaves the network in eval mode; switch dropout and
        # batch-norm layers back to training behaviour.
        net.train()

        running_loss = 0.0
        epoch_steps = 0
        total = 0
        correct = 0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            epoch_steps += 1

        train_loss = running_loss / epoch_steps
        train_acc = correct / total

        return train_loss, train_acc

    @staticmethod
    def __validate_one_epoch(valloader, net, criterion, device):
        """Evaluate ``net`` on ``valloader`` without updating it.

        Returns:
            ``(mean batch loss, accuracy)`` over the loader as Python floats.
        """
        # Eval mode disables dropout so the reported metrics are deterministic.
        net.eval()

        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for inputs, labels in valloader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                val_loss += loss.item()
                val_steps += 1

        val_loss = val_loss / val_steps
        val_acc = correct / total

        return val_loss, val_acc

In [None]:
class NetworkTester():
    """Runs a trained network over the test set and saves its predictions."""

    @staticmethod
    def test_accuracy(net, device, net_name):
        """Predict a class for every test image and write the result to CSV.

        The predictions are written to ``../tests/<net_name>.csv`` with an
        ``id`` index column starting at 1 and a ``label`` column. Despite its
        name, this method computes no accuracy figure and returns ``None``.

        Args:
            net: trained network; it is moved to ``device`` and put in eval mode.
            device: device on which inference runs.
            net_name: base name (without extension) of the output CSV file.
        """
        output_dir = '../tests'
        os.makedirs(output_dir, exist_ok=True)

        net.to(device)
        # Inference must not run with dropout active.
        net.eval()
        # The image count reported by the loader is not needed: labels are
        # collected in a list that grows with the predictions actually made.
        testloader, _ = NetworkTester.__load_data()

        predicted_classes = []

        with torch.no_grad():
            for data in testloader:
                images = data[0].to(device)
                outputs = net(images)
                predicted = torch.max(outputs.data, 1)[1]
                predicted_classes.extend(
                    CIFAR10Testset.label_number_to_class[predicted_label]
                    for predicted_label in predicted.tolist()
                )

        predictions = pd.DataFrame(
            predicted_classes,
            index=range(1, len(predicted_classes) + 1),
            columns=['label']
        )
        predictions.to_csv(os.path.join(output_dir, net_name + '.csv'), index_label='id')

    @staticmethod
    def __load_data():
        """Return whatever ``data_load.load_test_data`` yields for the test set."""
        batch_size = 32
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        return data_load.load_test_data(transform, batch_size)
        

In [None]:
# tuning & training
def tune_net(config, checkpoint_dir=None, device=None):
    """Ray Tune trainable: build a ``VGGLike`` network from ``config`` and train it.

    Args:
        config: sampled hyperparameters. ``'dropout_factor'`` is required;
            ``'lr'`` and ``'momentum'`` are optional and default to 0.001 and
            0.9, so they can be added to the search space without code changes.
        checkpoint_dir: checkpoint directory to resume from, or ``None``.
        device: device the network is trained on.
    """
    net = VGGLike(config['dropout_factor'])
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        net.parameters(),
        lr=config.get('lr', 0.001),
        momentum=config.get('momentum', 0.9)
    )
    NetworkHyperparamsTuner.general_tune_network(net, criterion, optimizer, checkpoint_dir, device)

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Search space: the dropout factor is drawn uniformly from [0.1, 0.5).
config = dict(dropout_factor=tune.uniform(0.1, 0.5))

# Launch the search with two sampled configurations.
net_results = NetworkHyperparamsTuner.tune_network(
    train_function=tune_net,
    tune_config=config,
    device=device,
    num_samples=2
)

In [None]:
# Select the trial whose last reported loss is the lowest.
net_best_trial = net_results.get_best_trial(metric='loss', mode='min', scope='last')
net_best_trained_model = VGGLike(net_best_trial.config['dropout_factor'])

# Restore the weights stored in that trial's checkpoint.
net_best_checkpoint_dir = net_best_trial.checkpoint.value
checkpoint_path = os.path.join(net_best_checkpoint_dir, 'checkpoint')
model_state, optimizer_state = torch.load(checkpoint_path)

net_best_trained_model.load_state_dict(model_state)

# Run the restored model over the test set under the name 'tuned'.
test_acc = NetworkTester.test_accuracy(net_best_trained_model, device, 'tuned')