In [None]:
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, models, transforms, utils
from copy import deepcopy
from os import path
import time
import torch
import wandb
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from sklearn.metrics import confusion_matrix
import plotly.figure_factory as ff

In [None]:
class myImages:
  """Image-folder datasets, data loaders and a fine-tuning loop for a classifier.

  Attributes set by ``__init__``:
    device: CUDA device 0 when available, otherwise the CPU.
    advprop: whether the ``img * 2 - 1`` normalisation is used instead of mean/std.
    mu, sd: per-channel mean / std passed to ``transforms.Normalize``.
    arr: names of the splits in use (``train``, ``val`` and optionally ``test``).
    image_datasets, dataloaders, dataset_sizes: dicts keyed by split name.
    class_names, nb_classes: classes of the training ``ImageFolder`` and their count.

  Attributes set by ``train_model``:
    model: the trained model carrying the best validation weights.
    val_acc_history, test_acc_history: one accuracy value per epoch.
  """

  def __init__(self, data_dir, n, test_dir = '', advprop = False, ):
    """Build the transforms, datasets and loaders.

    Args:
      data_dir: directory holding ``train`` and ``val`` sub-folders (and
        optionally ``test``), each laid out for ``datasets.ImageFolder``.
      n: side length of the final square crop fed to the model.
      test_dir: optional directory holding a ``test`` sub-folder; when empty,
        ``data_dir/test`` is used if it exists.
      advprop: use the ``img * 2 - 1`` normalisation (advprop pretrained weights).
    """
    # Fall back to the CPU so the class stays usable on machines without a GPU.
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.advprop = advprop
    # Always defined so no later code path can hit a missing attribute.
    self.mu = [0.485, 0.456, 0.406]
    self.sd = [0.229, 0.224, 0.225]
    if advprop:  # for models using advprop pretrained weights
      normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0)
    else:
      normalize = transforms.Normalize(mean=self.mu, std=self.sd)

    def eval_transform():
      # Shared by 'val' and 'test' so both use the same normalisation as 'train'.
      return transforms.Compose([
          transforms.Resize(n+32),
          transforms.CenterCrop(n),
          transforms.ToTensor(),
          normalize
      ])

    # https://gist.github.com/WillKoehrsen/d5de7d61e9cc2971c5aed1763f6e1ff3#file-pytorch_transforms-py
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(n+32, scale=(0.8, 1.0)),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(),
            transforms.RandomHorizontalFlip(),
            transforms.CenterCrop(size=n),
            transforms.ToTensor(),
            normalize
        ]),
        'val': eval_transform(),
        'test': eval_transform(),
    }

    # Decide which splits exist and which root directory each one lives under.
    if test_dir == '':
      if path.isdir(path.join(data_dir, 'test')):
        self.arr = ['train', 'val', 'test']
      else:
        self.arr = ['train', 'val']
      roots = {x: data_dir for x in self.arr}
    else:
      self.arr = ['train', 'val', 'test']
      roots = {'train': data_dir, 'val': data_dir, 'test': test_dir}

    self.image_datasets = {x: datasets.ImageFolder(path.join(roots[x], x), data_transforms[x]) for x in self.arr}
    self.dataloaders = {x: torch.utils.data.DataLoader(self.image_datasets[x], batch_size=4, shuffle=True, num_workers=4) for x in self.arr}
    self.dataset_sizes = {x: len(self.image_datasets[x]) for x in self.arr}
    self.class_names = self.image_datasets['train'].classes
    self.nb_classes = len(self.class_names)

  def imshow(self, title=None):
    """Show one training batch as an image grid.

    Args:
      title: figure title; defaults to the class names of the batch.
    """
    inputs, classes = next(iter(self.dataloaders['train']))
    out = utils.make_grid(inputs)
    inp = out.numpy().transpose((1, 2, 0))
    # Undo the normalisation applied by the transforms so colours look right.
    if self.advprop:  # inverse of img * 2 - 1
      inp = (inp + 1.0) / 2.0
    else:
      inp = np.array(self.sd) * inp + np.array(self.mu)
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is None:
      title = [self.class_names[x] for x in classes]
    plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

  def train_model(self, model, criterion, optimizer, scheduler, num_epochs=25, name='no-name', is_inception = False):
    """Train ``model``, evaluate every split each epoch and log to wandb.

    The weights with the best validation accuracy are loaded back into the
    model, which is stored on ``self.model``; per-epoch accuracies are stored
    on ``self.val_acc_history`` and ``self.test_acc_history``.

    Args:
      model: network to train; it must already be on ``self.device``.
      criterion: loss taking ``(outputs, labels)``.
      optimizer: optimizer over the model parameters.
      scheduler: learning-rate scheduler, stepped once per epoch.
      num_epochs: number of epochs to run.
      name: suffix of the wandb project name.
      is_inception: the model returns ``(outputs, aux_outputs)`` in train mode.
    """
    since = time.time()
    val_acc_history = []
    test_acc_history = []

    project_name = model.__class__.__name__ + '-' + name

    wandb.init(project=project_name)
    config = wandb.config
    param_group = optimizer.param_groups[0]
    config.learning_rate = param_group['lr']
    # Not every optimizer has momentum and not every scheduler is a StepLR,
    # so only record the hyper-parameters that actually exist.
    optional_settings = (
        ('momentum', param_group.get('momentum')),
        ('step_size', getattr(scheduler, 'step_size', None)),
        ('gamma', getattr(scheduler, 'gamma', None)),
    )
    for key, value in optional_settings:
      if value is not None:
        setattr(config, key, value)
    config.max_epochs = num_epochs
    wandb.watch(model)

    best_model_wts = deepcopy(model.state_dict())
    best_acc = 0.0
    # Fixed label set keeps the confusion matrix aligned with class_names even
    # when a class is missing from a split.
    class_ids = list(range(len(self.class_names)))

    for epoch in range(num_epochs):
      print('Epoch {}/{}'.format(epoch, num_epochs - 1))
      print('-' * 10)

      # Each epoch runs every available split: train, val and optionally test.
      for phase in self.arr:
        is_train = phase == 'train'
        if is_train:
          model.train()  # Set model to training mode
        else:
          model.eval()   # Set model to evaluate mode

        running_loss = 0.0
        running_corrects = torch.zeros((), dtype=torch.long, device=self.device)
        true_batches = []
        pred_batches = []

        # Iterate over data.
        for inputs, labels in self.dataloaders[phase]:
          inputs = inputs.to(self.device)
          labels = labels.to(self.device)

          # zero the parameter gradients
          optimizer.zero_grad()

          # track history only in train
          with torch.set_grad_enabled(is_train):
            # Inception has an auxiliary output in train mode; its loss is
            # added with weight 0.4. Evaluation only uses the final output.
            if is_inception and is_train:
              # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
              outputs, aux_outputs = model(inputs)
              loss = criterion(outputs, labels) + 0.4 * criterion(aux_outputs, labels)
            else:
              outputs = model(inputs)
              loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)

            # Log a plain float rather than a tensor that is attached to the graph.
            batch_loss = loss.item()
            wandb.log({phase + "_loss": batch_loss})

            # backward + optimize only if in training phase
            if is_train:
              loss.backward()
              optimizer.step()

          # statistics
          running_loss += batch_loss * inputs.size(0)
          running_corrects += torch.sum(preds == labels)
          pred_batches.append(preds.cpu().numpy())
          true_batches.append(labels.cpu().numpy())

        if is_train:
          # The scheduler must step after the optimizer, otherwise the first
          # value of the learning-rate schedule is skipped.
          scheduler.step()

        epoch_loss = running_loss / self.dataset_sizes[phase]
        epoch_acc = running_corrects.double() / self.dataset_sizes[phase]

        y_true = np.concatenate(true_batches) if true_batches else np.array([], dtype=int)
        y_pred = np.concatenate(pred_batches) if pred_batches else np.array([], dtype=int)
        z = confusion_matrix(y_true, y_pred, labels=class_ids)
        # Row-normalise; rows of classes without samples stay 0 instead of NaN.
        row_totals = z.sum(axis=1, keepdims=True).astype(float)
        z_p = np.divide(z, row_totals, out=np.zeros(z.shape, dtype=float), where=row_totals != 0)
        fig = conmat(self.class_names, z)
        fig_p = conmat(self.class_names, z_p)
        wandb.log({
            'epoch': epoch,
            'confusion_matrix ' + phase: wandb.data_types.Plotly(fig),
            'confusion_matrix percentage ' + phase: wandb.data_types.Plotly(fig_p),
            'acc' + phase: epoch_acc.item(),
            'loss' + phase: epoch_loss,
            'acc': epoch_acc.item(),
            'loss': epoch_loss,
        })

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(
            phase, epoch_loss, float(epoch_acc)))

        # deep copy the model
        if phase == 'val' and epoch_acc > best_acc:
          best_acc = epoch_acc
          best_model_wts = deepcopy(model.state_dict())

        if phase == "val":
          val_acc_history.append(epoch_acc)

        if phase == "test":
          test_acc_history.append(epoch_acc)

      print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(float(best_acc)))

    # load best model weights
    model.load_state_dict(best_model_wts)
    self.model = model
    self.val_acc_history = val_acc_history
    self.test_acc_history = test_acc_history

In [None]:
def visualize_model(model, dataloaders, class_names, num_images=6):
    """Plot validation images titled ``predicted (actual)``.

    Images are laid out in a two-column grid. The model's train/eval mode is
    restored before returning, even if an error occurs.

    NOTE(review): drawing is delegated to a module-level ``imshow(tensor)``
    helper that is not defined in the visible cells - confirm it exists
    elsewhere in the notebook.

    Args:
        model: classifier to evaluate.
        dataloaders: mapping with a ``'val'`` entry yielding ``(inputs, labels)``.
        class_names: sequence mapping a class index to its display name.
        num_images: maximum number of images to draw.
    """
    if num_images <= 0:
        return

    was_training = model.training
    model.eval()

    # Use the device the model lives on instead of a notebook-level global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    # Round up so an odd num_images still fits in the two-column grid.
    n_rows = (num_images + 1) // 2
    images_so_far = 0
    plt.figure()

    try:
        with torch.no_grad():
            for inputs, labels in dataloaders['val']:
                inputs = inputs.to(model_device)
                labels = labels.to(model_device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                for j in range(inputs.size(0)):
                    images_so_far += 1
                    ax = plt.subplot(n_rows, 2, images_so_far)
                    ax.axis('off')
                    ax.set_title('{} ({})'.format(class_names[int(preds[j])], class_names[int(labels[j])]))
                    imshow(inputs.cpu().data[j])

                    if images_so_far == num_images:
                        return
    finally:
        model.train(mode=was_training)

In [None]:
def conmat(class_names, z):
  """Build an annotated Plotly heatmap for a confusion matrix.

  Args:
    class_names: labels used for both heatmap axes.
    z: matrix given as an iterable of rows of numbers; every cell is rounded
      to 4 decimals and its string form is used as the cell annotation.

  Returns:
    The figure produced by ``ff.create_annotated_heatmap`` with title, axis
    titles, two axis captions, margins and a visible colour scale.
  """
  # https://stackoverflow.com/questions/60860121/plotly-how-to-make-an-annotated-confusion-matrix-using-a-heatmap
  rounded_rows = []
  label_rows = []
  for row in z:
    rounded = [round(cell, 4) for cell in row]
    rounded_rows.append(rounded)
    label_rows.append([str(cell) for cell in rounded])

  heatmap = ff.create_annotated_heatmap(
      rounded_rows,
      x=class_names,
      y=class_names,
      annotation_text=label_rows,
      colorscale='Viridis',
  )
  heatmap.update_layout(
      title_text='<i><b>Confusion matrix</b></i>',
      yaxis={'title': 'true'},
      xaxis={'title': 'pred'},
  )

  # Captions placed in paper coordinates: one under the x axis, one rotated
  # along the left of the y axis.
  captions = (
      {'font': {'color': "black", 'size': 14},
       'x': 0.5, 'y': -0.15, 'showarrow': False,
       'text': "Predicted value", 'xref': "paper", 'yref': "paper"},
      {'font': {'color': "black", 'size': 14},
       'x': -0.35, 'y': 0.5, 'showarrow': False,
       'text': "Real value", 'textangle': -90, 'xref': "paper", 'yref': "paper"},
  )
  for caption in captions:
    heatmap.add_annotation(caption)

  heatmap.update_layout(margin={'t': 50, 'l': 200})
  heatmap['data'][0]['showscale'] = True
  return heatmap
