In [None]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
from os import cpu_count
import time
import os
import copy
import wandb
from sklearn.metrics import ConfusionMatrixDisplay

# Let cuDNN benchmark the available convolution algorithms and reuse the
# fastest one; this pays off when input shapes stay the same between batches.
cudnn.benchmark = True
# Switch matplotlib to interactive mode so figures do not block execution.
plt.ion()

In [None]:
!wandb login <your key>

In [None]:
# Release cached, currently unused GPU memory held by PyTorch's allocator
# (e.g. left over from earlier runs in this kernel).
torch.cuda.empty_cache()

In [None]:
# Sweep definition handed to wandb.sweep(): a grid search in which every
# hyperparameter is pinned except the augmentation policy.

# Objective of the sweep; train() logs this key once after the test pass.
metric = {'name': 'Final accuracy', 'goal': 'maximize'}

# Hyperparameters that keep a single fixed value for every run.
parameters_dict = {
    hyperparameter: {'value': setting}
    for hyperparameter, setting in (
        ('model', 'squ'),
        ('pretrained', False),
        ('optimizer', 'adam'),
        ('batch_size', 32),
        ('epochs', 400),
        ('learning_rate', 0.0001),
        ('momentum', 0.9),
        ('scheduler', 'ReduceLROnPlateau'),
        ('weight_decay', 0.0001),
    )
}
# The only swept dimension: one run per listed augmentation policy.
parameters_dict['augmentation'] = {'values': ['mix', 'rand']}

sweep_config = {
    'method': 'grid',
    'metric': metric,
    'parameters': parameters_dict,
}


In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from torchvision import datasets, transforms

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
# The helper functions defined below read this module-level `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

def train(config=None):
    """Execute one full wandb run: build everything, fit, test and report.

    Intended as the callback handed to ``wandb.agent``. Inside the run the
    hyperparameters are read from ``wandb.config``. The function builds the
    dataloaders, network, optimizer and scheduler, trains via
    ``train_epoch``, evaluates the returned model on the test split, logs
    "Final loss" / "Final accuracy" and a confusion-matrix chart, and finally
    prints the elapsed wall-clock time.

    Args:
        config: Optional default configuration forwarded to ``wandb.init``.

    Returns:
        None.
    """
    torch.cuda.empty_cache()
    # Axis labels for the confusion matrix: 24 letter classes (no 'J', no 'Z').
    letters = list('ABCDEFGHIKLMNOPQRSTUVWXY')

    with wandb.init(config=config, resume=True):
        # When launched by wandb.agent the sweep controller fills wandb.config.
        hparams = wandb.config

        dataloaders, sizes = build_dataset(hparams.batch_size, hparams.augmentation)
        model = build_network(hparams.model, hparams.pretrained)
        optimizer = build_optimizer(
            model,
            hparams.optimizer,
            hparams.learning_rate,
            hparams.momentum,
            hparams.weight_decay,
        )
        scheduler = build_scheduler(scheduler_type=hparams.scheduler, optimizer_ft=optimizer)

        # The timer starts after construction, so it covers fit + test + reporting.
        started_at = time.time()

        fitted = train_epoch(
            model,
            dataloaders,
            optimizer,
            scheduler,
            sizes,
            scheduler_type=hparams.scheduler,
            num_epochs=hparams.epochs,
        )

        test_loss, test_acc = test_model(fitted, dataloaders, sizes)
        wandb.log({"Final loss": test_loss, "Final accuracy": test_acc})

        predicted_ids, true_ids = confusion_matrix_pass(fitted, dataloaders)
        chart = wandb.plot.confusion_matrix(
            probs=None,
            preds=predicted_ids,
            y_true=true_ids,
            class_names=letters,
        )
        wandb.log({"conf_mat": chart})

    elapsed = time.time() - started_at
    print(f'Training complete in {elapsed // 60:.0f}m {elapsed % 60:.0f}s')


In [None]:
def build_dataset(batch_size, augmentation, img_size = (256,256), ): #
    """Create the train/val/test dataloaders for the image-folder dataset.

    Every split is squared with the project's ``squareImage`` transform,
    resized to ``img_size``, converted to a tensor and normalized. The
    training split additionally gets the selected automatic augmentation
    policy and a random horizontal flip.

    Args:
        batch_size: Batch size used by all three dataloaders.
        augmentation: Augmentation policy for the training split; one of
            'rand' (RandAugment), 'trivial' (TrivialAugmentWide) or
            'mix' (AugMix).
        img_size: Target size passed to ``transforms.Resize``.

    Returns:
        Tuple ``(dataloaders, dataset_sizes)``; both are dicts keyed by
        'train', 'val' and 'test'.

    Raises:
        ValueError: If ``augmentation`` is not a supported policy name.
    """
    splits = ('train', 'val', 'test')
    supported_augmentations = ('rand', 'trivial', 'mix')

    # Validate first so a bad sweep value fails immediately with a clear message.
    if augmentation not in supported_augmentations:
        raise ValueError(
            f"Unknown augmentation {augmentation!r}; "
            f"expected one of {supported_augmentations}"
        )

    from squareImage import squareImage

    if augmentation == 'rand':
        augmentation_transform = transforms.RandAugment()
    elif augmentation == 'trivial':
        augmentation_transform = transforms.TrivialAugmentWide()
    else:  # 'mix'
        augmentation_transform = transforms.AugMix()

    # Standard ImageNet channel mean / std.
    normalize_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    def eval_pipeline():
        # Deterministic preprocessing shared by the 'val' and 'test' splits.
        return transforms.Compose([
            transforms.Lambda(squareImage),
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize(*normalize_stats),
        ])

    data_transforms = {
        'train': transforms.Compose([
            transforms.Lambda(squareImage), # custom transform to square image before resizing to keep scale
            transforms.Resize(img_size),
            augmentation_transform,
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(*normalize_stats),
        ]),
        'val': eval_pipeline(),
        'test': eval_pipeline(),
    }

    # Built with os.path.join instead of a backslash literal: the old string
    # relied on the invalid escape '\d' and only resolved on Windows.
    data_dir = os.path.join('.', 'datasets', 'degree_256_hands', 'dataset_v3')

    image_datasets = {
        split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
        for split in splits
    }
    # batch size and num_workers could become predefined variables for easier logging
    dataloaders = {
        split: torch.utils.data.DataLoader(
            image_datasets[split],
            batch_size=batch_size,
            # Only training needs a random order; evaluation metrics are
            # order-independent, so val/test stay deterministic.
            shuffle=(split == 'train'),
            num_workers=4,
        )
        for split in splits
    }
    dataset_sizes = {split: len(image_datasets[split]) for split in splits}
    return (dataloaders, dataset_sizes)


def build_network(model_arch = 'resnet18', pretrained = True, dropout = 0.5, batch_norm = True):
    """Build a torchvision classifier with a 24-class head on ``device``.

    Args:
        model_arch: Architecture name. ResNet and VGG variants must be named
            exactly ('resnet18', 'resnet34', 'resnet50', 'resnet101',
            'vgg11', 'vgg13', 'vgg16', 'vgg19'). The other families are
            matched by prefix: 'alex...' (AlexNet), 'squ...' (SqueezeNet 1.1),
            'ince...' (Inception v3) and 'goo...' (GoogLeNet).
        pretrained: Truthy to load the torchvision ``DEFAULT`` weights,
            falsy to start from random initialization.
        dropout: Currently unused; kept for interface compatibility.
        batch_norm: For VGG only, truthy selects the ``*_bn`` variant.

    Returns:
        The model with its classification head replaced, moved to ``device``.

    Raises:
        ValueError: If ``model_arch`` is not one of the supported names.
    """
    num_classes = 24

    resnet_weights = {
        'resnet18': 'ResNet18_Weights',
        'resnet34': 'ResNet34_Weights',
        'resnet50': 'ResNet50_Weights',
        'resnet101': 'ResNet101_Weights',
    }
    vgg_depths = {'vgg11': '11', 'vgg13': '13', 'vgg16': '16', 'vgg19': '19'}
    # (name prefix, torchvision factory, torchvision weights enum)
    prefix_families = (
        ('alex', 'alexnet', 'AlexNet_Weights'),
        ('squ', 'squeezenet1_1', 'SqueezeNet1_1_Weights'),
        ('ince', 'inception_v3', 'Inception_V3_Weights'),
        ('goo', 'googlenet', 'GoogLeNet_Weights'),
    )

    # Resolve the name to a factory *before* touching torchvision so that an
    # unsupported architecture raises a clear error instead of leaving
    # ``model_ft`` unbound.
    family = None
    if model_arch in resnet_weights:
        family = 'resnet'
        factory_name = model_arch
        weights_name = resnet_weights[model_arch]
    elif model_arch in vgg_depths:
        family = 'vgg'
        bn_suffix = '_bn' if batch_norm else ''
        factory_name = model_arch + bn_suffix
        weights_name = f'VGG{vgg_depths[model_arch]}{bn_suffix.upper()}_Weights'
    else:
        for prefix, candidate_factory, candidate_weights in prefix_families:
            if str(model_arch).startswith(prefix):
                family = prefix
                factory_name = candidate_factory
                weights_name = candidate_weights
                break

    if family is None:
        supported = sorted(resnet_weights) + sorted(vgg_depths) + [
            prefix + '*' for prefix, _, _ in prefix_families
        ]
        raise ValueError(
            f"Unsupported model architecture {model_arch!r}; expected one of {supported}"
        )

    weights = getattr(models, weights_name).DEFAULT if pretrained else None
    model_ft = getattr(models, factory_name)(weights=weights)

    # Swap the ImageNet head for a ``num_classes``-way one.
    if family in ('resnet', 'goo'):
        if family == 'goo':
            # Only the main classifier output is used.
            model_ft.aux_logits = False
        model_ft.fc = nn.Linear(model_ft.fc.in_features, num_classes)
    elif family in ('vgg', 'alex'):
        model_ft.classifier[6] = nn.Linear(4096, num_classes)
    elif family == 'squ':
        # SqueezeNet classifies with a 1x1 convolution rather than a Linear layer.
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
    else:  # 'ince'
        model_ft.aux_logits = False
        if model_ft.AuxLogits is not None:
            model_ft.AuxLogits.fc = nn.Linear(768, num_classes)
        model_ft.fc = nn.Linear(2048, num_classes)

    return model_ft.to(device)
        

def build_optimizer(network, optimizer_type, learning_rate, momentum = 0.9, weight_decay=1e-4):
    """Create the optimizer for ``network``'s parameters.

    Args:
        network: Module whose ``parameters()`` are optimized.
        optimizer_type: One of "sgd", "adam" or "adagrad".
        learning_rate: Learning rate passed to the optimizer.
        momentum: Momentum; only used by "sgd".
        weight_decay: Weight decay passed to the optimizer.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: If ``optimizer_type`` is not supported (previously this
            surfaced as an ``UnboundLocalError`` at the return statement).
    """
    if optimizer_type == "sgd":
        return optim.SGD(network.parameters(),
                         lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    if optimizer_type == "adam":
        return optim.Adam(network.parameters(),
                          lr=learning_rate, weight_decay=weight_decay)
    if optimizer_type == "adagrad":
        return optim.Adagrad(network.parameters(),
                             lr=learning_rate, weight_decay=weight_decay)
    raise ValueError(
        f"Unknown optimizer type {optimizer_type!r}; expected 'sgd', 'adam' or 'adagrad'"
    )

def build_scheduler(scheduler_type, optimizer_ft, step_sched = 3, gamma_sched = 0.1):
    """Create the learning-rate scheduler attached to ``optimizer_ft``.

    Args:
        scheduler_type: One of "ReduceLROnPlateau", "ConstantLR" or "StepLR".
        optimizer_ft: Optimizer whose learning rate is scheduled.
        step_sched: ``step_size`` for "StepLR"; ignored otherwise.
        gamma_sched: ``gamma`` for "StepLR"; ignored otherwise.

    Returns:
        The constructed ``torch.optim.lr_scheduler`` scheduler.

    Raises:
        ValueError: If ``scheduler_type`` is not supported (previously this
            surfaced as an ``UnboundLocalError`` at the return statement).
    """
    # ``verbose`` is intentionally not passed: False was the default, and the
    # argument is deprecated in recent PyTorch releases (it triggers a warning
    # and may be rejected by newer versions).
    if scheduler_type == "ReduceLROnPlateau":
        return lr_scheduler.ReduceLROnPlateau(
            optimizer_ft, mode='min', factor=0.1, threshold_mode='rel', min_lr=0, eps=1e-08)

    if scheduler_type == "ConstantLR":
        # factor=1 keeps the learning rate unchanged.
        return lr_scheduler.ConstantLR(optimizer_ft, factor=1, last_epoch=-1)

    if scheduler_type == "StepLR":
        return lr_scheduler.StepLR(optimizer_ft, step_size=step_sched, gamma=gamma_sched)

    raise ValueError(
        f"Unknown scheduler type {scheduler_type!r}; "
        "expected 'ReduceLROnPlateau', 'ConstantLR' or 'StepLR'"
    )


def train_epoch(model, dataloaders, optimizer, scheduler, dataset_sizes, scheduler_type, num_epochs, criterion = nn.CrossEntropyLoss(), ):
    """Train for ``num_epochs`` epochs and return the best-validating model.

    Despite the name, this runs the whole training loop: each epoch performs
    a 'train' phase followed by a 'val' phase, prints and logs loss/accuracy
    for both to wandb (using the epoch index as the step), and remembers the
    weights with the highest validation accuracy.

    Args:
        model: Network to train; modified in place.
        dataloaders: Dict with 'train' and 'val' iterables yielding
            ``(inputs, labels)`` batches.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per epoch after the train phase.
        dataset_sizes: Dict with the number of samples in 'train' and 'val';
            used as the denominator for the per-epoch averages.
        scheduler_type: Scheduler name; "ReduceLROnPlateau" is stepped with
            the summed training loss, every other scheduler without arguments.
        num_epochs: Number of epochs to run.
        criterion: Loss function applied to ``(outputs, labels)``.

    Returns:
        ``model`` with the best validation-accuracy weights loaded.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            is_training = phase == 'train'
            model.train(is_training)  # train mode for 'train', eval mode for 'val'

            running_loss = 0.0
            # Kept as a plain int so the accuracy, ``best_acc`` and the logged
            # metrics are Python floats rather than device tensors, and so an
            # empty loader does not break the accuracy computation.
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Track gradients only while training.
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                # The criterion result is scaled by the batch size so that the
                # epoch total can be divided by the dataset size below.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += int(torch.sum(preds == labels).item())

            if is_training:
                if scheduler_type == "ReduceLROnPlateau":
                    # NOTE(review): the plateau scheduler monitors the summed
                    # *training* loss; confirm that validation loss was not intended.
                    scheduler.step(running_loss)
                else:
                    scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            metric_prefix = 'Train' if is_training else 'Validation'
            wandb.log(
                {f"{metric_prefix} loss": epoch_loss, f"{metric_prefix} accuracy": epoch_acc},
                step=epoch,
            )

            if not is_training and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_model_wts)
    return model

def test_model(network, loader, dataset_sizes, criterion = nn.CrossEntropyLoss()):
    
    model = network
    model.eval()
    running_loss = 0.0
    running_corrects = 0

    # Iterate over data.
    for inputs, labels in loader['test']:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # forward
        # track history if only in train
        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

        # statistics
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data) 

    epoch_loss = running_loss / dataset_sizes['test']
    epoch_acc = running_corrects.double() / dataset_sizes['test']
    
    return (epoch_loss, epoch_acc)
    
def confusion_matrix_pass(model, loader):
    """Collect predicted and true class ids over the 'test' split.

    The model is evaluated in eval mode without gradient tracking; its
    previous train/eval mode is restored afterwards, even if a batch fails.

    Args:
        model: Model to evaluate.
        loader: Dict whose 'test' entry yields ``(inputs, labels)`` batches.

    Returns:
        Tuple ``(predictions, truths)`` of equally long lists of Python ints,
        in the order the loader produced the samples.
    """
    was_training = model.training
    model.eval()
    predictions = []
    truths = []
    try:
        with torch.no_grad():
            for inputs, labels in loader['test']:
                outputs = model(inputs.to(device))
                _, preds = torch.max(outputs, 1)
                # Convert one whole batch at a time instead of calling int()
                # per element, which forced one device sync per sample.
                predictions.extend(preds.tolist())
                # Labels are only needed as Python ints, so they are not
                # moved to the device.
                truths.extend(labels.long().tolist())
    finally:
        model.train(mode=was_training)
    return (predictions, truths)
                

In [None]:
import pprint

# Print the assembled sweep configuration for a visual check before it is
# registered with wandb.
pprint.pprint(sweep_config)

In [None]:
# Register the sweep with wandb and keep its id for the agent call.
# NOTE(review): "Name of sweep" looks like a placeholder project name — confirm
# or replace before running.
sweep_id = wandb.sweep(sweep_config, project="Name of sweep")

In [None]:
# Start a sweep agent in this kernel; it calls train() for the sweep's runs.
wandb.agent(sweep_id, train)

In [None]:
# Close any wandb run that is still open once the agent has returned.
wandb.finish()