# Image classification with PyTorch

## Import libraries

In [None]:
# . . import libraries
import os
from pathlib import Path
# . . pytorch modules
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
from torch.utils import data
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

# . . numpy
import numpy as np
# . . scikit-learn
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
# . . matplotlib
import matplotlib.pyplot as plt
import matplotlib.image as npimg
# . .  set this to be able to see the figure axis labels in a dark theme
from matplotlib import style
#style.use('dark_background')
# . . to see the available options
# print(plt.style.available)

from torchsummary import summary

# . . import libraries by tugrulkonuk
import utils
#from dataset import Dataset
from model import *
from trainer import Trainer
from callbacks import ReturnBestModel, EarlyStopping


## Set device and precision

In [None]:
# . . run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# . . uncomment to force CPU execution
#device = torch.device("cpu")
# . . do NOT make CUDA the default tensor type; move models/tensors explicitly instead
#torch.set_default_tensor_type('torch.cuda.FloatTensor')

# . . default floating-point precision for this notebook
dtype = torch.float32

# . . cuDNN backend on, and let it benchmark convolution algorithms to pick the fastest
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

In [None]:
# . . arguments . .
# . . stand-in for parsed command-line arguments: a notebook has no argv
class Args:
    """Hyper-parameters and data-loading settings for this notebook.

    Every setting is a class attribute, so ``Args()`` and ``Args`` expose
    the same values.
    """

    # . . optimisation
    epochs: int = 100          # number of training epochs
    lr: float = 0.001          # learning rate
    batch_size: int = 1024     # mini-batch size
    train_size: float = 0.8    # fraction of data for training (NOTE: not used by the visible cells)

    # . . early stopping
    min_delta: float = 0.0005  # minimum improvement that counts for early stopping
    patience: int = 10         # patience for early stopping

    # . . data loading
    num_workers: int = 8       # worker processes for the DataLoader
    pin_memory: bool = False   # pinned host memory speeds up CPU-to-GPU transfer

    # . . logging
    jprint: int = 1            # print interval



In [None]:
# . . the notebook has no command line, so the settings come from an Args instance
args = Args()

# . . expose each setting as a module-level name used by the later cells
num_epochs = args.epochs
learning_rate = args.lr
batch_size = args.batch_size
train_size = args.train_size

# . . early-stopping settings
min_delta = args.min_delta
patience = args.patience

# . . data-loading and logging settings
num_workers = args.num_workers
pin_memory = args.pin_memory
jprint = args.jprint

## Import the data


In [None]:
# . . training-time data augmentation
transformer_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
])
# . . other augmentations that were tried and are currently disabled:
#   transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2)
#   transforms.RandomRotation(degrees=15)
#   transforms.RandomPerspective()

# . . CIFAR-10 lives in the working directory and is downloaded on first use
cifar_root = '.'

# . . the train set: augmented
train_dataset = torchvision.datasets.CIFAR10(
    root=cifar_root, train=True, transform=transformer_train, download=True)

# . . the validation set (CIFAR-10 test split): no augmentation!
valid_dataset = torchvision.datasets.CIFAR10(
    root=cifar_root, train=False, transform=transforms.ToTensor(), download=True)


In [None]:
# . . count the distinct labels present in the training targets
distinct_labels = set(train_dataset.targets)
num_classes = len(distinct_labels)
print('number of classes: ', num_classes)

In [None]:
# . . human-readable CIFAR-10 class names, in label order 0-9
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

## Data loaders

In [None]:
# . . settings shared by both loaders
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)

# . . training batches are reshuffled every epoch
trainloader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)

# . . validation batches keep the dataset order (the confusion-matrix cell relies on this)
validloader = DataLoader(dataset=valid_dataset, shuffle=False, **loader_kwargs)

In [None]:
# . . CIFAR-10 images are RGB, 32x32
in_channels = 3
image_size = (in_channels, 32, 32)

# . . build the project's Bayesian CNN and move it to the selected device
model = BayesianCNNClassifierCIFAR(in_channels, num_classes, lrt=False)
model.to(device)

# . . layer-by-layer summary for a single (C, H, W) input
summary(model, image_size)

In [None]:
# . . wrap the model and the target device in the project's training driver
trainer = Trainer(model, device)

# . . the ELBO-style loss is defined next; the trainer itself is compiled after it
class elbo(nn.Module):
    """ELBO-style training loss for the Bayesian classifier.

    Computes ``nll * batch_size + beta * kl``, where ``nll`` is the mean
    negative log-likelihood of ``target`` under ``input``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target, kl, beta, batch_size):
        """Return the scaled NLL plus the beta-weighted KL term.

        Args:
            input: per-class log-probabilities (``F.nll_loss`` expects
                log-probabilities, e.g. a ``log_softmax`` output); assumed
                shape (batch, classes) — TODO confirm against the Trainer.
            target: integer class labels; must not require gradients.
            kl: KL-divergence term, presumably supplied by the model.
            beta: weight applied to ``kl``.
            batch_size: factor multiplying the mean NLL.
        """
        # . . labels are data, never parameters
        assert not target.requires_grad
        nll = F.nll_loss(input, target, reduction='mean')
        return nll * batch_size + beta * kl

# . . the loss
criterion = elbo().to(device)

# . . optimizer hyper-parameters forwarded to trainer.compile
optimparams = {'lr': learning_rate}

# . . project callbacks: best-model tracking and early stopping
cb = [ReturnBestModel(), EarlyStopping(min_delta=min_delta, patience=patience)]

# . . compile the trainer: optimizer, loss, callbacks and print interval
trainer.compile(optimizer='adam', criterion=criterion, callbacks=cb, jprint=jprint, **optimparams)

# . . the learning-rate scheduler: halve the LR after a plateau, never below min_lr
# . . 'verbose' is intentionally omitted: it is deprecated since PyTorch 2.2 and
# . . slated for removal (query scheduler.get_last_lr() instead)
# . . NOTE(review): patience=50 is far larger than the early-stopping patience (10),
# . . so this scheduler would rarely get a chance to act even if it were stepped
schedulerparams = {'factor': 0.5,
                   'patience': 50,
                   'threshold': 1e-5,
                   'cooldown': 5,
                   'min_lr': 1e-5,
                   }
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(trainer.optimizer, **schedulerparams)

In [None]:
# . . train the network; returns the training and validation loss histories
# . . NOTE(review): the ReduceLROnPlateau scheduler built in the compile cell is
# . . not passed here (scheduler=None), so it never adjusts the learning rate.
# . . Confirm this is intended before wiring it in (how Trainer.fit steps it is not visible).
train_loss, valid_loss = trainer.fit(
    trainloader,
    validloader,
    scheduler=None,
    num_epochs=num_epochs,
)

In [None]:
# . . learning curves; x axis assumes fit() returns one loss per epoch — TODO confirm
plt.plot(train_loss, label='train_loss')
plt.plot(valid_loss, label='valid_loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()

In [None]:
# . . smaller batches for evaluation
eval_batch_size = 32

# . . the training images again, without augmentation, so training accuracy is measured on unmodified images
train_dataset_noaug = torchvision.datasets.CIFAR10(
    root='.', train=True, transform=transforms.ToTensor(), download=True)

# . . fixed-order loader over the non-augmented training set
trainloader_noaug = DataLoader(
    dataset=train_dataset_noaug,
    batch_size=eval_batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

In [None]:
# . . evaluate on the non-augmented training set and on the validation (test) set
eval_loaders = (trainloader_noaug, validloader)
training_accuracy, test_accuracy = trainer.evaluate(*eval_loaders)

In [None]:
#. . calculate and plot the confusion matrix
# . . raw validation images and their true labels, in dataset order
x_test = valid_dataset.data
y_test = np.array(valid_dataset.targets)

# . . number of stochastic forward passes combined per batch
num_ensemble = 1

# . . NOTE(review): the model is not switched with model.eval() here, so it stays in
# . . whatever mode the trainer left it; for a Bayesian network that may control
# . . weight sampling — confirm which mode is intended for this evaluation.

# . . per-batch predictions, concatenated once after the loop (avoids quadratic growth)
batch_predictions = []

# . . inference only: no autograd graph is needed
with torch.no_grad():
    # . . validloader does not shuffle, so the batch order matches y_test
    for inputs, _ in validloader:
        # . . only the inputs are needed on the device; labels come from y_test
        inputs = inputs.to(device)

        # . . log-probabilities of each ensemble member: (batch, classes, ensemble)
        outputs = torch.zeros(inputs.shape[0], trainer.model.num_classes, num_ensemble,
                              device=device)
        for ens in range(num_ensemble):
            # . . the second return value (the KL term) is not needed for predictions
            logits, _ = trainer.model(inputs)
            outputs[:, :, ens] = F.log_softmax(logits, dim=1)

        # . . combine the ensemble members along dim 2 in log space
        # . . (presumably log of the mean probability — see utils.logmeanexp)
        log_outputs = utils.logmeanexp(outputs, dim=2)

        # . . predicted class = index of the largest combined log-probability
        _, predictions = torch.max(log_outputs, 1)
        batch_predictions.append(predictions.cpu().numpy())

# . . integer class predictions for the whole validation set
p_test = np.concatenate(batch_predictions)

# . . the confusion matrix
cm = confusion_matrix(y_test, p_test)

# . . plot the confusion matrix with one tick per class
utils.plot_confusion_matrix(cm, list(range(num_classes)))

In [None]:
# . . save the trained weights; create the output directory first so the save
# . . does not fail on a fresh checkout where 'models/' does not exist yet
model_dir = Path('models')
model_dir.mkdir(parents=True, exist_ok=True)
torch.save(trainer.model.state_dict(), model_dir / 'final_model.pt')