In [5]:
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import models
from torchvision import transforms
import torch.nn as nn

In [6]:
import argparse
import os
from pathlib import Path
import shutil
import time

In [7]:
import augmentations

In [8]:
# The augmentations module defaults IMAGE_SIZE to 32 (CIFAR-sized images); override it for
# 224x224 ImageNet-style crops. Presumably consumed by size-dependent ops - confirm in augmentations.py.
setattr(augmentations, 'IMAGE_SIZE', 224)

In [9]:
# Notes: AugMix augments the same image several times - `mixture_width` augmentation chains, each applying
# 1-3 random ops when mixture_depth = 0 - and then takes a convex combination of those chains and the original image to improve robustness/generalisation.

In [10]:
# Hyperparameters and run configuration (could later be moved into a config class / argparse).
bs = 8  # training batch size; also used for the clean-validation loader
epochs = 15
pretrained = True  # start from torchvision's pretrained weights
lr = 0.1  # initial learning rate - a scheduler (e.g. adjust_learning_rate) is expected to anneal it
momentum = 0.9  # standard SGD momentum (OneCycle policies typically cycle it between 0.95 and 0.85)
wd = 0.001  # SGD weight decay

mixture_depth = 0  # ops per augmentation chain; values <= 0 sample np.random.randint(1, 4), i.e. 1-3 ops
mixture_width = 3  # number of augmentation chains mixed into each augmented view (3 by default)
model = 'resnet50'  # name of a torchvision.models constructor

resume = False  # path to a checkpoint to resume from, or False to train from scratch
save_dir = False  # directory for checkpoints and the CSV training log; False disables saving

all_ops = True  # sample ops from augmentations.augmentations_all instead of augmentations.augmentations
aug_severity = 1  # severity passed to every augmentation op
aug_prob_coeff = 1.  # concentration of the Dirichlet (chain weights) and Beta (blend weight) distributions
no_jsd = False  # False => train with the JSD consistency loss; True => plain cross-entropy in train()

# Dataset root containing train/ and val/ ImageFolder directories.
# Override with the COVID_DATASET_DIR environment variable instead of editing this absolute path.
clean_path = Path(os.environ.get('COVID_DATASET_DIR',
                                 '/home/jack/Documents/DL/Experiments/covid-19/dataset'))
evaluate = True  # evaluate on the clean validation set once before training starts

evaluate_corrupt = False  # also evaluate on a corrupted copy of the dataset (requires corrupt_path)
corrupt_path = None  # root of the corrupted dataset, laid out as <corrupt_path>/<corruption>/<severity 1-5>/
num_workers = 4  # DataLoader worker processes
print_freq = 1  # print training statistics every `print_freq` batches

In [11]:
# Each corruption type is paired with AlexNet's raw error on it (taken from
# https://github.com/hendrycks/robustness). The names are also the sub-directories that
# test_c reads. compute_mce normalizes the model's per-corruption error by these baselines.
CORRUPTIONS, ALEXNET_ERR = (list(column) for column in zip(*[
    ('gaussian_noise', 0.886428),
    ('shot_noise', 0.894468),
    ('impulse_noise', 0.922640),
    ('defocus_blur', 0.819880),
    ('glass_blur', 0.826268),
    ('motion_blur', 0.785948),
    ('zoom_blur', 0.798360),
    ('snow', 0.866816),
    ('frost', 0.826572),
    ('fog', 0.819324),
    ('brightness', 0.564592),
    ('contrast', 0.853204),
    ('elastic_transform', 0.646056),
    ('pixelate', 0.717840),
    ('jpeg_compression', 0.606500),
]))

In [12]:
# main() could be split into smaller cells, but the optimizer still needs model.parameters(), so the model and optimizer have to be built together.

In [13]:
class AugMixDataset(torch.utils.data.Dataset):
  """Wrap a (image, label) dataset so every item is returned with AugMix applied.

  The wrapped images are presumably PIL images (aug() documents a PIL.Image input).
  With no_jsd=False an item is ((clean, augmix_1, augmix_2), label), where `clean` is
  just `preprocess(image)`; with no_jsd=True it is (augmix_view, label).
  """

  def __init__(self, dataset, preprocess, no_jsd=False):
    self.dataset, self.preprocess, self.no_jsd = dataset, preprocess, no_jsd

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, i):
    image, label = self.dataset[i]
    if self.no_jsd:
      return aug(image, self.preprocess), label
    # Clean view first, then two independent AugMix views (in this sampling order).
    views = (self.preprocess(image),
             aug(image, self.preprocess),
             aug(image, self.preprocess))
    return views, label

In [14]:
# Step-decay schedule (not cosine annealing): the LR is linearly scaled by the batch size and
# divided by 10 after each third of training. Built-in schedulers (e.g. OneCycle, cyclic LR)
# could replace this.

def adjust_learning_rate(optimizer, epoch, base_lr=None, batch_size=None, total_epochs=None):
  """Set every param group's LR to base_lr * (batch_size / 256) * decay(epoch).

  decay is 1 for the first third of `total_epochs`, 0.1 for the second third and 0.01 after that.

  Args:
    optimizer: torch optimizer whose param groups are updated in place.
    epoch: current 0-based epoch.
    base_lr: initial learning rate; defaults to the global `lr`.
    batch_size: training batch size; defaults to the global `bs`.
    total_epochs: number of training epochs; defaults to the global `epochs`.
  Returns:
    The learning rate that was written into the optimizer.
  """
  # Use distinct local names: assigning to `lr` here would shadow the global and raise
  # UnboundLocalError when it is read.
  if base_lr is None:
    base_lr = lr
  if batch_size is None:
    batch_size = bs
  if total_epochs is None:
    total_epochs = epochs

  scale = batch_size / 256.  # linear scaling rule relative to a reference batch size of 256
  # At least one epoch per phase, so runs shorter than 3 epochs don't start at the 0.01 phase.
  phase_len = max(1, total_epochs // 3)
  if epoch < phase_len:
    decay = 1.
  elif epoch < 2 * phase_len:
    decay = 0.1
  else:
    decay = 0.01

  new_lr = base_lr * decay * scale
  for param_group in optimizer.param_groups:
    param_group['lr'] = new_lr
  return new_lr

In [15]:
def accuracy(output, target, topk=(1,)):
  """Compute top-k accuracy (in percent) of `output` scores against `target` labels.

  Args:
    output: (batch, num_classes) scores; classes are ranked along dim 1.
    target: integer class labels, one per row of `output`.
    topk: iterable of k values to evaluate.
  Returns:
    List with one single-element tensor per k, holding the accuracy in [0, 100].
    A k larger than the number of classes is clamped to it, so that accuracy is 100%.
  """
  with torch.no_grad():
    # topk() raises if k exceeds the class count (e.g. top-5 on a 3-class dataset).
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch): row j holds each sample's (j+1)-th ranked class
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
      # reshape, not view: `correct` inherits the non-contiguous layout of pred.t(),
      # so .view(-1) fails for k > 1.
      correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
      res.append(correct_k.mul_(100.0 / batch_size))
    return res

In [16]:
def compute_mce(corruption_accs, corruptions=None, baseline_errs=None):
  """Compute mCE (mean Corruption Error) normalized by AlexNet performance.

  Args:
    corruption_accs: dict mapping corruption name -> list of accuracies (fractions in
      [0, 1], one per severity), as produced by test_c.
    corruptions: corruption names to average over; defaults to CORRUPTIONS.
    baseline_errs: baseline error per corruption, aligned with `corruptions`;
      defaults to ALEXNET_ERR.
  Returns:
    mCE in percent: the mean over corruptions of 100 * model_error / baseline_error.
  Raises:
    ValueError: if there are no corruptions, or the two lists differ in length.
  """
  if corruptions is None:
    corruptions = CORRUPTIONS
  if baseline_errs is None:
    baseline_errs = ALEXNET_ERR
  if not corruptions:
    raise ValueError('no corruptions to average over')
  if len(corruptions) != len(baseline_errs):
    raise ValueError('corruptions and baseline_errs must have the same length')

  errors = [100 * (1 - np.mean(corruption_accs[name])) / baseline
            for name, baseline in zip(corruptions, baseline_errs)]
  # Average over the actual number of corruptions rather than a hard-coded 15.
  return sum(errors) / len(errors)

In [17]:
def aug(image, preprocess):
  """Build one AugMix view of `image`.

  Samples `mixture_width` chain weights from a Dirichlet distribution and one blend weight
  from a Beta distribution (both with concentration `aug_prob_coeff`). For each chain, applies
  `mixture_depth` random ops (1-3 when mixture_depth <= 0) to a copy of the image. Returns a
  convex blend of the preprocessed original and the weighted sum of the preprocessed chains.

  Args:
    image: PIL.Image input image.
    preprocess: Preprocessing function which should return a torch tensor.
  Returns:
    mixed: Augmented and mixed image tensor.
  """
  ops = augmentations.augmentations_all if all_ops else augmentations.augmentations

  chain_weights = np.float32(np.random.dirichlet([aug_prob_coeff] * mixture_width))
  blend = np.float32(np.random.beta(aug_prob_coeff, aug_prob_coeff))

  mix = torch.zeros_like(preprocess(image))
  for weight in chain_weights:
    augmented = image.copy()
    n_ops = mixture_depth if mixture_depth > 0 else np.random.randint(1, 4)
    for _ in range(n_ops):
      augmented = np.random.choice(ops)(augmented, aug_severity)
    # Mixing after preprocessing is equivalent to mixing raw images because the weights are
    # convex (assuming an affine preprocess such as ToTensor + Normalize).
    mix += weight * preprocess(augmented)

  return (1 - blend) * preprocess(image) + blend * mix

In [18]:
# test is standard cross entropy loss between clean images x_val and y_val
def test(net, test_loader):
  """Evaluate network on given dataset."""
  net.eval()
  total_loss = 0.
  total_correct = 0
  with torch.no_grad():
    for images, targets in test_loader:
      images, targets = images.cuda(), targets.cuda()
      logits = net(images)
      loss = F.cross_entropy(logits, targets)
      pred = logits.data.max(1)[1]
      total_loss += float(loss.data)
      total_correct += pred.eq(targets.data).sum().item()

  return total_loss / len(test_loader.dataset), total_correct / len(
      test_loader.dataset)


def test_c(net, test_transform, corrupt_root=None, batch_size=None):
  """Evaluate `net` on every corruption in CORRUPTIONS at severities 1-5.

  Each <corrupt_root>/<corruption>/<severity> directory is loaded with torchvision
  ImageFolder using `test_transform`, then scored with test().

  Args:
    net: model to evaluate.
    test_transform: transform applied to every corrupted image.
    corrupt_root: root of the corrupted dataset; defaults to the global `corrupt_path`.
    batch_size: evaluation batch size; defaults to the global `bs`.
  Returns:
    dict mapping corruption name -> list of 5 accuracies (fractions), one per severity.
  Raises:
    ValueError: if no corrupted-dataset root is configured.
  """
  # The original referenced the undefined names `corrupted_data` / `eval_batch_size`;
  # fall back to the config-cell values instead.
  if corrupt_root is None:
    corrupt_root = corrupt_path
  if batch_size is None:
    batch_size = bs
  if corrupt_root is None:
    raise ValueError('corrupt_path must point at the corrupted dataset to evaluate corruptions')

  corruption_accs = {}
  for c in CORRUPTIONS:
    print(c)
    for s in range(1, 6):
      valdir = os.path.join(corrupt_root, c, str(s))
      val_loader = torch.utils.data.DataLoader(
          datasets.ImageFolder(valdir, test_transform),
          batch_size=batch_size,
          shuffle=False,
          num_workers=num_workers,
          pin_memory=True)

      loss, acc1 = test(net, val_loader)
      corruption_accs.setdefault(c, []).append(acc1)

      print('\ts={}: Test Loss {:.3f} | Test Acc1 {:.3f}'.format(
          s, loss, 100. * acc1))

  return corruption_accs

In [19]:
def _jsd_consistency_loss(logits_clean, logits_aug1, logits_aug2):
  """Jensen-Shannon consistency loss between the predictions for three views of a batch.

  Returns the mean over the three views of KL(p_view || M), where the p_view are the
  softmax distributions and M is their average, clamped away from 0 before the log.
  """
  p_clean = F.softmax(logits_clean, dim=1)
  p_aug1 = F.softmax(logits_aug1, dim=1)
  p_aug2 = F.softmax(logits_aug2, dim=1)

  # Clamp mixture distribution to avoid exploding KL divergence
  log_p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
  return (F.kl_div(log_p_mixture, p_clean, reduction='batchmean') +
          F.kl_div(log_p_mixture, p_aug1, reduction='batchmean') +
          F.kl_div(log_p_mixture, p_aug2, reduction='batchmean')) / 3.


def train(net, train_loader, optimizer, jsd_weight=12.):
  """Train `net` for one epoch.

  When the global `no_jsd` is True, each batch is (images, targets) and the loss is plain
  cross-entropy. Otherwise each batch is ((clean, aug1, aug2), targets). The three views are
  forwarded together, cross-entropy is taken on the clean logits only, and `jsd_weight`
  times the JSD consistency loss is added.

  Args:
    net: model to train (put into train mode).
    train_loader: iterable of batches shaped as described above.
    optimizer: optimizer over `net`'s parameters.
    jsd_weight: weight of the JSD consistency term (12 in the original code).
  Returns:
    (loss_ewma, acc1_ewma, batch_time_ewma) after the last batch; acc1 is in percent.
  """
  net.train()
  device = next(net.parameters()).device  # follow the model rather than hard-coding .cuda()
  data_ewma = batch_ewma = loss_ewma = acc1_ewma = acc5_ewma = 0.

  end = time.time()
  for i, (images, targets) in enumerate(train_loader):
    # Compute data loading time
    data_time = time.time() - end
    optimizer.zero_grad()
    targets = targets.to(device)

    if no_jsd:
      logits = net(images.to(device))
      loss = F.cross_entropy(logits, targets)
    else:
      images_all = torch.cat(images, 0).to(device)
      logits_all = net(images_all)
      logits_clean, logits_aug1, logits_aug2 = torch.split(logits_all, images[0].size(0))

      # Cross-entropy is only computed on clean images
      loss = F.cross_entropy(logits_clean, targets)
      loss += jsd_weight * _jsd_consistency_loss(logits_clean, logits_aug1, logits_aug2)
      logits = logits_clean

    # Clamp "top-5" to the class count so datasets with fewer than 5 classes don't crash topk().
    top_k = min(5, logits.size(1))
    acc1, acc5 = accuracy(logits, targets, topk=(1, top_k))  # pylint: disable=unbalanced-tuple-unpacking

    loss.backward()
    optimizer.step()

    # Compute batch computation time and update moving averages
    # (exponential averages weighted 0.9 towards the newest batch).
    batch_time = time.time() - end
    end = time.time()

    data_ewma = data_ewma * 0.1 + float(data_time) * 0.9
    batch_ewma = batch_ewma * 0.1 + float(batch_time) * 0.9
    loss_ewma = loss_ewma * 0.1 + float(loss) * 0.9
    acc1_ewma = acc1_ewma * 0.1 + float(acc1) * 0.9
    acc5_ewma = acc5_ewma * 0.1 + float(acc5) * 0.9

    if i % print_freq == 0:
      print(
          'Batch {}/{}: Data Time {:.3f} | Batch Time {:.3f} | Train Loss {:.3f} | Train Acc1 '
          '{:.3f} | Train Acc5 {:.3f}'.format(i, len(train_loader), data_ewma,
                                              batch_ewma, loss_ewma, acc1_ewma,
                                              acc5_ewma))
      print('\n')

  return loss_ewma, acc1_ewma, batch_ewma

In [24]:
def _build_loaders():
    """Return (train_loader, val_loader, test_transform) for the dataset under `clean_path`."""
    # Standard ImageNet normalization statistics.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Spatial augmentation only; AugMixDataset applies `preprocess` (ToTensor + Normalize) itself.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ])
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        preprocess,
    ])

    # no_jsd must match train(), which reads the same global to decide the batch layout.
    train_dataset = AugMixDataset(
        datasets.ImageFolder(os.path.join(clean_path, 'train'), train_transform),
        preprocess,
        no_jsd=no_jsd)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=bs,
        shuffle=True,
        num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(os.path.join(clean_path, 'val'), test_transform),
        batch_size=bs,
        shuffle=False,
        num_workers=num_workers)
    return train_loader, val_loader, test_transform


def _build_model_and_optimizer():
    """Create the torchvision `model` and its SGD optimizer, then wrap the net in DataParallel on GPU."""
    if pretrained:
        print("=> using pre-trained model '{}'".format(model))
        net = models.__dict__[model](pretrained=True)
    else:
        print("=> creating model '{}'".format(model))
        net = models.__dict__[model]()
    # NOTE(review): the stock classification head is kept (1000 outputs for torchvision's
    # ImageNet models); consider resizing it to the number of dataset classes.

    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=wd)

    # Distribute model across all visible GPUs. Module.cuda() moves parameters in place,
    # so the optimizer keeps referencing the same tensors.
    net = torch.nn.DataParallel(net).cuda()
    cudnn.benchmark = True
    return net, optimizer


def _maybe_resume(net, optimizer):
    """Restore state from the `resume` checkpoint when set; return (start_epoch, best_acc1)."""
    start_epoch, best_acc1 = 0, 0
    if resume:
        if os.path.isfile(resume):
            # torch.load unpickles: only resume from checkpoints you trust.
            checkpoint = torch.load(resume)
            start_epoch = checkpoint['epoch'] + 1
            best_acc1 = checkpoint['best_acc1']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print('Model restored from epoch:', start_epoch)
        else:
            print("=> no checkpoint found at '{}', training from scratch".format(resume))
    return start_epoch, best_acc1


def _report_corruption_results(net, test_transform):
    """Evaluate `net` on all corruptions and print per-corruption accuracies and the mCE."""
    corruption_accs = test_c(net, test_transform)
    for c in CORRUPTIONS:
        print('\t'.join(map(str, [c] + corruption_accs[c])))
    print('mCE (normalized by AlexNet):', compute_mce(corruption_accs))


def _prepare_log(start_epoch):
    """Create `save_dir` if needed and return the CSV log path; write the header for fresh runs."""
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    if not os.path.isdir(save_dir):
        raise Exception('%s is not a dir' % save_dir)

    log_path = os.path.join(save_dir, 'imagenet_{}_training_log.csv'.format(model))
    # A resumed run appends to its existing log instead of truncating it.
    if start_epoch == 0 or not os.path.exists(log_path):
        with open(log_path, 'w') as f:
            f.write('epoch,batch_time,train_loss,train_acc1(%),test_loss,test_acc1(%)\n')
    return log_path


def _save_checkpoint(checkpoint, is_best):
    """Write `checkpoint` to save_dir/checkpoint.pth.tar, and copy it to model_best.pth.tar if best."""
    checkpoint_path = os.path.join(save_dir, 'checkpoint.pth.tar')
    torch.save(checkpoint, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path, os.path.join(save_dir, 'model_best.pth.tar'))


def main():
    """Train and evaluate an AugMix model on the dataset under `clean_path`.

    Reads its configuration from the notebook-level globals in the config cell. It builds
    the loaders, model and optimizer, optionally resumes from `resume` and optionally
    evaluates before training. It then trains for `epochs` epochs, evaluating on the clean
    validation set after each one. When `save_dir` is set, it writes a checkpoint and a
    CSV log row every epoch.
    """
    torch.manual_seed(1)
    np.random.seed(1)

    train_loader, val_loader, test_transform = _build_loaders()
    net, optimizer = _build_model_and_optimizer()

    # best_acc1 comes from the checkpoint on resume (previously it was reset to 0 afterwards).
    start_epoch, best_acc1 = _maybe_resume(net, optimizer)

    if evaluate:
        test_loss, test_acc1 = test(net, val_loader)
        print('Clean\n\tVal Loss {:.3f} | Val Acc1 {:.3f}'.format(test_loss, 100 * test_acc1))

    if evaluate_corrupt:
        _report_corruption_results(net, test_transform)

    log_path = _prepare_log(start_epoch) if save_dir else None

    print('Beginning training from epoch:', start_epoch + 1)
    for epoch in range(start_epoch, epochs):
        print(f'epoch : {epoch + 1}')
        # No LR schedule is applied, so `lr` stays constant; call
        # adjust_learning_rate(optimizer, epoch) here to enable step decay.

        train_loss_ewma, train_acc1_ewma, batch_ewma = train(net, train_loader, optimizer)
        test_loss, test_acc1 = test(net, val_loader)

        is_best = test_acc1 > best_acc1
        best_acc1 = max(test_acc1, best_acc1)

        if save_dir:
            checkpoint = {
                'epoch': epoch,
                'model': model,
                'state_dict': net.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }
            _save_checkpoint(checkpoint, is_best)

            with open(log_path, 'a') as f:
                f.write('%03d,%0.3f,%0.6f,%0.2f,%0.5f,%0.2f\n' % (
                    epoch + 1,
                    batch_ewma,
                    train_loss_ewma,
                    train_acc1_ewma,  # already a percentage (accuracy() scales by 100)
                    test_loss,
                    100. * test_acc1,  # test() returns a fraction
                ))

        print(
            'Epoch {:3d} | Train Loss {:.4f} | Test Loss {:.3f} | Test Acc1 '
            '{:.2f}'
            .format((epoch + 1), train_loss_ewma, test_loss, 100. * test_acc1))

        if evaluate_corrupt:
            _report_corruption_results(net, test_transform)