**Import libraries**

In [0]:
import numpy as np
import os
import copy

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function, Variable
from torch import optim
from torch.optim import *
from torch.utils.data import Dataset, DataLoader, Subset
from pathlib import Path
from torch.utils import data
from torchvision import models
from torchvision.datasets import ImageFolder

from tqdm import tqdm
from itertools import cycle
from torch.distributions.beta import Beta
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
from glob import glob
import time
from PIL import Image, ImageOps, ImageEnhance
from sklearn import metrics
import pandas as pd
import seaborn as sns
from random import randint

**Mount drive**

In [0]:
# Colab-only setup: mount Google Drive and change into the project folder.
# DATA_DIR is a relative path and PROJECT_PATH is taken from os.getcwd(),
# so both depend on this working directory.
from google.colab import drive
drive.mount('/content/drive')
%cd "/content/drive/My Drive/AIML_Project"

**Set arguments**

In [0]:
# --- Training hyper-parameters ---
NUM_CLASSES = 50      # width of the one-hot label vectors built for mixup/cutmix
BATCH_SIZE = 32       # batch size shared by every DataLoader in this notebook
NUM_EPOCHS = 30       # default number of epochs for train_Ada_Net
LR = 1e-4             # initial learning rate handed to Adam
DECAY_START = 23      # epoch after which train_Ada_Net starts its manual LR decay
LMBDA = 0.1           # default of train_loop's `lmbda` argument
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ALPHA = 1.0           # both concentration parameters of the Beta distribution used for mixing
EPS = 1e-16           # numerical-stability constant inside KL_div
GAMMA = 0.1           # StepLR multiplicative factor
STEP_SIZE = 15        # StepLR period, in epochs

# --- Checkpoint / archive locations (relative to the current working directory) ---
PROJECT_PATH = str(Path(os.getcwd()))
ARCHIVE_PATH = str(Path(PROJECT_PATH) / "archives")
NPZ_PATH = str(Path(ARCHIVE_PATH) / "NPZ")
CHECKPOINT_PATH = str(Path(ARCHIVE_PATH) / "AwA2_checkpoints")
PATH_LAST = CHECKPOINT_PATH + '/last.pth'          # most recent epoch (model + optimizer)
PATH_BEST = CHECKPOINT_PATH + '/best.pth'          # best validation accuracy so far (model only)
PATH_AUG = CHECKPOINT_PATH + '/augmix.pth.tar'     # backbone weights loaded when ResNetFc(augmix=True)

# --- Dataset metadata ---
DATA_DIR = 'Animals_with_Attributes2'

# The last column of classes.txt / predicates.txt holds the names.
train_classes = np.array(np.genfromtxt(DATA_DIR + '/trainclasses.txt', dtype='str'))
classes = np.array(np.genfromtxt(DATA_DIR + '/classes.txt', dtype='str'))[:,-1]
predicates = np.array(np.genfromtxt(DATA_DIR + '/predicates.txt', dtype='str'))[:,-1]
# Integer matrix; AnimalDataset indexes its rows by class index to get the attribute vector.
predicate_binary_mat = np.array(np.genfromtxt(DATA_DIR + '/predicate-matrix-binary.txt', dtype='int'))
num_labels = len(predicates)

**Plot functions**

In [0]:
def imshow(image, title=None):
  """Display a normalised CHW image tensor in a 5x5 inch figure.

  The tensor is moved to HWC layout, de-normalised with the fixed
  per-channel mean/std below, clipped to [0, 1] and shown without axes.
  An optional `title` is drawn above the image.
  """
  channel_mean = np.array([0.485, 0.456, 0.406])
  channel_std = np.array([0.229, 0.224, 0.225])
  pixels = image.numpy().transpose((1, 2, 0))
  pixels = np.clip(channel_std * pixels + channel_mean, 0, 1)

  plt.figure(figsize=(5,5))
  plt.imshow(pixels)
  if title is not None:
    plt.title(title)
  plt.axis('off')
  plt.pause(0.001)  # short pause so that the plot gets refreshed
  plt.show()

def little_imshow(image, title=None):
  """Thumbnail variant of the image display helper (2x2 inch figure).

  The CHW tensor is converted to HWC, de-normalised with the fixed
  per-channel mean/std below, clipped to [0, 1] and shown without axes.
  """
  channel_mean = np.array([0.485, 0.456, 0.406])
  channel_std = np.array([0.229, 0.224, 0.225])
  pixels = image.numpy().transpose((1, 2, 0))
  pixels = np.clip(channel_std * pixels + channel_mean, 0, 1)

  plt.figure(figsize=(2,2))
  plt.imshow(pixels)
  if title is not None:
    plt.title(title)
  plt.axis('off')
  plt.pause(0.001)  # short pause so that the plot gets refreshed
  plt.show()

def show_image_with_attributes(img):
  """Plot one evaluated image next to a colour-coded table of its attributes.

  Args:
    img: dict with tensor entries 'image' (normalised CHW image),
      'pred_label' / 'label' (predicted and true class index) and
      'pred_attr' / 'attr' (predicted and true attribute vectors, one
      entry per name in the module-level `predicates`).

  Each table cell reads "<attribute>: <predicted>/<true>" and is green
  when both agree, red otherwise. The predicted and true class names are
  printed below the figure.
  """
  columns = 10  # table width; the number of rows follows from the attribute count

  image = img['image'].cpu()
  pred_label = img['pred_label'].cpu()
  label = img['label'].cpu()
  attributes_pred = img['pred_attr'].cpu()
  attributes = img['attr'].cpu()

  cell_text = []
  cell_colours = []
  for name, predicted, actual in zip(predicates, attributes_pred, attributes):
    cell_text.append(name + ': ' + str(int(predicted)) + '/' + str(int(actual)))
    cell_colours.append('green' if predicted == actual else 'red')

  # Pad with blank white cells so the entries fill a complete grid
  # (85 attributes -> 9 rows of 10 with 5 blanks).
  rows = -(-len(cell_text) // columns)
  padding = rows * columns - len(cell_text)
  cell_text.extend([''] * padding)
  cell_colours.extend(['white'] * padding)
  text_grid = [cell_text[r * columns:(r + 1) * columns] for r in range(rows)]
  colour_grid = [cell_colours[r * columns:(r + 1) * columns] for r in range(rows)]

  fig = plt.figure(figsize=(20,4))
  grid = plt.GridSpec(1, 5, wspace=0.0, hspace=0.0)
  fig.suptitle("Random image vs grid of attributes (green = match, red = mismatch) and labels")

  plt.subplot(grid[0, 0])
  out = image.numpy().transpose((1, 2, 0))
  # Undo the normalisation with the same statistics the `preprocess`
  # transform applied (MEAN/STD), not the ImageNet defaults.
  out = np.clip(np.array(STD) * out + np.array(MEAN), 0, 1)
  plt.axis('off')
  plt.subplots_adjust(hspace=0, wspace=1)
  plt.imshow(out)

  ax2 = plt.subplot(grid[0, 1:])
  ax2.axis('off')
  mpl_table = ax2.table(cellText=text_grid, bbox=[0, 0, 1, 1], cellLoc='left', cellColours=colour_grid)
  mpl_table.auto_set_font_size(False)
  mpl_table.set_fontsize(9)

  plt.show()
  print('Predicted class: {}'.format(classes[pred_label.item()]))
  print('True class: {}'.format(classes[label.item()]))

**Data Augmentations (AugMix)**

In [0]:
# Side length (pixels) of the square output canvas used by the shear/translate
# operations, and the base for the maximum translation (IMAGE_SIZE / 3).
IMAGE_SIZE = 224

def int_parameter(level, maxval):
  """Scale `level` (on a 0-10 scale) to the range [0, maxval] and truncate to int."""
  scaled = level * maxval / 10
  return int(scaled)

def float_parameter(level, maxval):
  """Scale `level` (on a 0-10 scale) to the range [0, maxval] as a float."""
  numerator = float(level) * maxval
  return numerator / 10.

def sample_level(n):
  """Draw a severity level uniformly from [0.1, n)."""
  lower_bound = 0.1
  return np.random.uniform(low=lower_bound, high=n)

def autocontrast(pil_img, _):
  """Apply PIL auto-contrast; the severity argument is ignored."""
  result = ImageOps.autocontrast(pil_img)
  return result

def equalize(pil_img, _):
  """Apply PIL histogram equalisation; the severity argument is ignored."""
  result = ImageOps.equalize(pil_img)
  return result

def posterize(pil_img, level):
  """Posterize the image, keeping `4 - sampled_level` bits per channel."""
  removed_bits = int_parameter(sample_level(level), 4)
  kept_bits = 4 - removed_bits
  return ImageOps.posterize(pil_img, kept_bits)

def rotate(pil_img, level):
  """Rotate by a sampled angle of at most 30 degrees, with a random sign."""
  angle = int_parameter(sample_level(level), 30)
  flip_sign = np.random.uniform() > 0.5
  if flip_sign:
    angle = -angle
  return pil_img.rotate(angle, resample=Image.BILINEAR)

def solarize(pil_img, level):
  """Solarize the image with threshold `256 - sampled_level`."""
  strength = int_parameter(sample_level(level), 256)
  threshold = 256 - strength
  return ImageOps.solarize(pil_img, threshold)

def shear_x(pil_img, level):
  """Shear horizontally by a sampled factor of at most 0.3, with a random sign."""
  shear = float_parameter(sample_level(level), 0.3)
  if np.random.uniform() > 0.5:
    shear = -shear
  affine = (1, shear, 0, 0, 1, 0)
  canvas = (IMAGE_SIZE, IMAGE_SIZE)
  return pil_img.transform(canvas, Image.AFFINE, affine, resample=Image.BILINEAR)



def shear_y(pil_img, level):
  """Shear vertically by a sampled factor of at most 0.3, with a random sign."""
  shear = float_parameter(sample_level(level), 0.3)
  if np.random.uniform() > 0.5:
    shear = -shear
  affine = (1, 0, 0, shear, 1, 0)
  canvas = (IMAGE_SIZE, IMAGE_SIZE)
  return pil_img.transform(canvas, Image.AFFINE, affine, resample=Image.BILINEAR)

def translate_x(pil_img, level):
  """Translate horizontally by a sampled offset of at most IMAGE_SIZE / 3 pixels, with a random sign."""
  shift = int_parameter(sample_level(level), IMAGE_SIZE / 3)
  if np.random.random() > 0.5:
    shift = -shift
  affine = (1, 0, shift, 0, 1, 0)
  canvas = (IMAGE_SIZE, IMAGE_SIZE)
  return pil_img.transform(canvas, Image.AFFINE, affine, resample=Image.BILINEAR)

def translate_y(pil_img, level):
  """Translate vertically by a sampled offset of at most IMAGE_SIZE / 3 pixels, with a random sign."""
  shift = int_parameter(sample_level(level), IMAGE_SIZE / 3)
  if np.random.random() > 0.5:
    shift = -shift
  affine = (1, 0, 0, 0, 1, shift)
  canvas = (IMAGE_SIZE, IMAGE_SIZE)
  return pil_img.transform(canvas, Image.AFFINE, affine, resample=Image.BILINEAR)

def color(pil_img, level):
  """Adjust colour saturation by a sampled enhancement factor (0.1 plus at most 1.8)."""
  factor = float_parameter(sample_level(level), 1.8) + 0.1
  enhancer = ImageEnhance.Color(pil_img)
  return enhancer.enhance(factor)

def contrast(pil_img, level):
  """Adjust contrast by a sampled enhancement factor (0.1 plus at most 1.8)."""
  factor = float_parameter(sample_level(level), 1.8) + 0.1
  enhancer = ImageEnhance.Contrast(pil_img)
  return enhancer.enhance(factor)

def brightness(pil_img, level):
  """Adjust brightness by a sampled enhancement factor (0.1 plus at most 1.8)."""
  factor = float_parameter(sample_level(level), 1.8) + 0.1
  enhancer = ImageEnhance.Brightness(pil_img)
  return enhancer.enhance(factor)

def sharpness(pil_img, level):
  """Adjust sharpness by a sampled enhancement factor (0.1 plus at most 1.8)."""
  factor = float_parameter(sample_level(level), 1.8) + 0.1
  enhancer = ImageEnhance.Sharpness(pil_img)
  return enhancer.enhance(factor)

# Every available augmentation op; each takes (pil_img, severity) and returns a PIL image.
augmentations_all = [autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, translate_x, translate_y, color, contrast, brightness, sharpness]
# Pool that `aug` samples its operations from.
aug_list = augmentations_all

**Data Preprocessing**

In [0]:
# Per-channel statistics passed to transforms.Normalize below.
MEAN = (0.4665, 0.4589, 0.3968)
STD = (0.2454, 0.2356, 0.2443)

# PIL-level preprocessing applied inside the dataset: resize then centre-crop to 224x224.
standard_preprocess = transforms.Compose([
  transforms.Resize(256), 
  transforms.CenterCrop(224),    
])

# Tensor-level preprocessing applied after any augmentation: to tensor, then normalise.
preprocess = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(MEAN, STD)
])

aug_prob_coeff = 1. # Probability distribution coefficients
mixture_width = 3 # Number of augmentation chains to mix per augmented example
mixture_depth = -1 # Depth of augmentation chains. -1 denotes stochastic depth in [1, 3]
aug_severity = 1 # Severity of base augmentation operators

def aug(image, preprocess):
  """Build one AugMix sample from a PIL image.

  `mixture_width` augmentation chains are applied to copies of `image`,
  each chain made of 1-3 random ops (or `mixture_depth` ops when that is
  positive). The preprocessed chain outputs are combined with
  Dirichlet-sampled weights, and the result is blended with the
  preprocessed original using a Beta-sampled coefficient.

  Args:
    image: input PIL image.
    preprocess: callable turning a PIL image into a tensor.

  Returns:
    The mixed image tensor.
  """
  chain_weights = np.float32(np.random.dirichlet([aug_prob_coeff] * mixture_width))
  blend = np.float32(np.random.beta(aug_prob_coeff, aug_prob_coeff))

  mix = torch.zeros_like(preprocess(image))
  for chain_idx in range(mixture_width):
    augmented = image.copy()
    if mixture_depth > 0:
      chain_length = mixture_depth
    else:
      chain_length = np.random.randint(1, 4)
    for _ in range(chain_length):
      operation = np.random.choice(aug_list)
      augmented = operation(augmented, aug_severity)
    # Preprocessing commutes with mixing because the weights are convex.
    mix += chain_weights[chain_idx] * preprocess(augmented)

  return (1 - blend) * preprocess(image) + blend * mix

**Dataset & dataloaders**

In [0]:
class AnimalDataset(data.dataset.Dataset):
  """Image dataset yielding (image, attribute vector, class index) triples.

  Images are read from ``DATA_DIR/JPEGImages/<class_name>/*.jpg`` for
  every class listed in `classes_file`. Class indices follow the line
  order of ``DATA_DIR/classes.txt``; the attribute vector is the matching
  row of ``predicate-matrix-binary.txt``.
  """

  def __init__(self, classes_file, transform=None):
    """
    Args:
      classes_file: path to a tab-separated file whose second column
        holds the class names to include.
      transform: optional callable applied to each PIL image.
    """
    self.predicate_binary_mat = np.array(np.genfromtxt(DATA_DIR + '/predicate-matrix-binary.txt', dtype='int'))
    self.transform = transform

    # Map class name -> index, using the line order of classes.txt.
    with open(DATA_DIR + '/classes.txt') as f:
      class_to_index = {line.split('\t')[1].strip(): index for index, line in enumerate(f)}
    self.class_to_index = class_to_index

    img_names = []
    img_index = []
    with open(classes_file) as f:
      for line in f:
        class_name = line.split('\t')[1].strip()
        folder_dir = os.path.join(DATA_DIR + '/JPEGImages', class_name)
        # glob() returns files in arbitrary filesystem order; sort so that
        # index-based train/test splits are reproducible across sessions.
        files = sorted(glob(os.path.join(folder_dir, '*.jpg')))
        img_names.extend(files)
        img_index.extend([class_to_index[class_name]] * len(files))
    self.img_names = img_names
    self.img_index = img_index

  def __getitem__(self, index):
    """Return (image, attribute row, class index) for sample `index`."""
    # Convert every non-RGB mode (L, P, RGBA, CMYK, ...) so that the
    # 3-channel normalisation downstream always applies; convert() also
    # loads the pixel data, so the file can be closed right away.
    with Image.open(self.img_names[index]) as raw:
      im = raw.convert('RGB')
    if self.transform is not None:
      im = self.transform(im)

    im_index = self.img_index[index]
    im_predicate = self.predicate_binary_mat[im_index,:]
    return im, im_predicate, im_index

  def __len__(self):
    """Number of images found on disk."""
    return len(self.img_names)

class WrapDataset(torch.utils.data.Dataset):
  """Wrap a dataset of (image, attributes, label) triples with preprocessing.

  With `augmix=True` each image is randomly flipped horizontally and then
  passed through `aug`; otherwise only `preprocess` is applied.
  """

  def __init__(self, dataset, preprocess, augmix=False):
    self.dataset = dataset
    self.preprocess = preprocess
    self.augmix = augmix

  def __getitem__(self, i):
    image, attributes, label = self.dataset[i]
    if not self.augmix:
      return self.preprocess(image), attributes, label
    flip = transforms.Compose([transforms.RandomHorizontalFlip()])
    flipped = flip(image)
    return aug(flipped, self.preprocess), attributes, label

  def __len__(self):
    return len(self.dataset)



def prepare_dataloaders(labeled_dataset, unlabeled_dataset, test_dataset, ind_dataset, batch_size):
  """Create the four DataLoaders used by training and evaluation.

  Returns a dict with keys 'labeled_dataloader', 'unlabeled_dataloader'
  (both shuffled), 'val_dataloader' and 'ind_dataloader' (both in fixed
  order). Every loader drops the last incomplete batch and uses 4 workers.
  """
  loader_specs = (
    ('labeled_dataloader', labeled_dataset, True),
    ('unlabeled_dataloader', unlabeled_dataset, True),
    ('val_dataloader', test_dataset, False),
    ('ind_dataloader', ind_dataset, False),
  )
  return {
    key: DataLoader(dataset, batch_size=batch_size, shuffle=shuffled, drop_last=True, num_workers=4)
    for key, dataset, shuffled in loader_specs
  }

print('Loading dataset...')
all_dataset = AnimalDataset(DATA_DIR + '/classes.txt', standard_preprocess)
print('Dataset loaded.')

# Check dataset sizes
print('All Dataset: {}'.format(len(all_dataset)))

# Split by position in all_dataset:
#   train -> odd indices
#   test  -> multiples of 10
#   ind   -> every 8th of the remaining even indices (disjoint from train and test)
train_indexes = [idx for idx in range(len(all_dataset)) if idx % 2]
test_indexes = [idx for idx in range(len(all_dataset)) if not idx % 10]
test_index_set = set(test_indexes)  # O(1) membership instead of scanning the list
ind_indexes = [idx for idx in range(len(all_dataset)) if not idx % 2]
ind_indexes = [idx for idx in ind_indexes if idx not in test_index_set]
# Subsample the filtered list itself. Iterating range(len(ind_indexes)) would
# yield positions 0, 8, 16, ... of all_dataset, which overlap the test split
# and only cover the first part of the dataset.
ind_indexes = ind_indexes[::8]
train_dataset = Subset(all_dataset, train_indexes)
test_dataset = Subset(all_dataset, test_indexes)
test_dataset = WrapDataset(test_dataset, preprocess)
ind_dataset = Subset(all_dataset, ind_indexes)
ind_dataset = WrapDataset(ind_dataset, preprocess)
# Labeled subset: positions within train_dataset divisible by both 2 and 3.
labeled_indexes = [idx for idx in range(len(train_dataset)) if not idx % 2]
labeled_indexes = [idx for idx in labeled_indexes if not idx % 3]
labeled_dataset = Subset(train_dataset, labeled_indexes)
labeled_dataset =  WrapDataset(labeled_dataset, preprocess, augmix=True)
unlabeled_dataset = WrapDataset(train_dataset, preprocess, augmix=True)

# Check dataset sizes
print('labeled Dataset: {}'.format(len(labeled_dataset)))
print('unlabeled Dataset: {}'.format(len(unlabeled_dataset)))
print('Test Dataset: {}'.format(len(test_dataset)))
print('Ind Dataset: {}'.format(len(ind_dataset)))

dataloaders = prepare_dataloaders(labeled_dataset, unlabeled_dataset, test_dataset, ind_dataset, BATCH_SIZE)

**ADA-Net**

In [0]:
class ReverseLayerF(Function):
  """Gradient reversal: identity in the forward pass, gradient scaled by -alpha in the backward pass."""

  @staticmethod
  def forward(ctx, x, alpha):
    # Remember the scaling factor for the backward pass.
    ctx.alpha = alpha
    return x.view_as(x)

  @staticmethod
  def backward(ctx, grad_output):
    reversed_grad = torch.neg(grad_output) * ctx.alpha
    # `alpha` is not a tensor input, so it receives no gradient.
    return reversed_grad, None

def init_weights(m):
    """Initialise one module in place; intended for use with ``Module.apply``.

    The scheme is picked from the module's class name:
      * Conv2d / ConvTranspose2d -> Kaiming-uniform weights
      * BatchNorm*               -> weights drawn from N(1.0, 0.02)
      * Linear                   -> Xavier-normal weights
    Biases of these layers are zeroed. Layers created without a bias
    (``bias=False``) or without affine parameters (``affine=False``) are
    handled gracefully; any other module type is left untouched.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv2d' in classname or 'ConvTranspose2d' in classname:
        nn.init.kaiming_uniform_(weight)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
    elif 'Linear' in classname:
        nn.init.xavier_normal_(weight)
    else:
        return

    # Bias is None for layers built with bias=False / affine=False.
    if bias is not None:
        nn.init.zeros_(bias)

# Maps a backbone name to its torchvision constructor; ResNetFc looks up `resnet_name` here.
resnet_dict = {"ResNet18":models.resnet18, "ResNet34":models.resnet34, "ResNet50":models.resnet50, "ResNet101":models.resnet101, "ResNet152":models.resnet152}

class ResNetFc(nn.Module):
  """ResNet backbone with a bottleneck, class/attribute heads and a domain discriminator.

  The discriminator receives the bottleneck features through a
  gradient-reversal layer whose strength is `self.lambd`.

  Args:
    resnet_name: key into `resnet_dict` selecting the backbone.
    pretrained: if True, load ``CHECKPOINT_PATH/<resnet_name>.pth`` into the backbone.
    augmix: if True, load the 'state_dict' entry of `PATH_AUG` into the backbone.
    bottleneck_dim: width of the bottleneck layer.
    class_num: number of outputs of the class head.
    attr_num: number of outputs of the attribute head.
  """

  def __init__(self, resnet_name="ResNet50", pretrained=False, augmix=True, bottleneck_dim=256, class_num=50, attr_num=85):
    super(ResNetFc, self).__init__()
    model_resnet = resnet_dict[resnet_name]()
    # map_location lets checkpoints saved on a GPU load on a CPU-only
    # runtime, matching the DEVICE fallback used everywhere else.
    if pretrained:
      state_dict = torch.load(CHECKPOINT_PATH + "/" + resnet_name + ".pth", map_location=DEVICE)
      model_resnet.load_state_dict(state_dict)
    if augmix:
      state_dict = torch.load(PATH_AUG, map_location=DEVICE)['state_dict']
      # Strip the 'module.' prefix only from keys that actually carry it.
      prefix = 'module.'
      state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
      model_resnet.load_state_dict(state_dict)
    self.lambd = 0.0  # gradient-reversal strength, see update_lambda
    # The backbone layers stay registered both individually and inside
    # feature_layers; saved checkpoints contain both sets of keys.
    self.conv1 = model_resnet.conv1
    self.bn1 = model_resnet.bn1
    self.relu = model_resnet.relu
    self.maxpool = model_resnet.maxpool
    self.layer1 = model_resnet.layer1
    self.layer2 = model_resnet.layer2
    self.layer3 = model_resnet.layer3
    self.layer4 = model_resnet.layer4
    self.avgpool = model_resnet.avgpool
    self.feature_layers = nn.Sequential(self.conv1, self.bn1, self.relu, self.maxpool, self.layer1, self.layer2, self.layer3, self.layer4, self.avgpool)

    self.bottleneck = nn.Linear(model_resnet.fc.in_features, bottleneck_dim)
    self.fc_attr = nn.Linear(bottleneck_dim, attr_num)
    self.fc_class = nn.Linear(bottleneck_dim, class_num)

    self.bottleneck.apply(init_weights)
    self.fc_attr.apply(init_weights)
    self.fc_class.apply(init_weights)
    self.__in_features = bottleneck_dim

    # Two-way discriminator; it outputs probabilities (Softmax).
    self.discriminator = nn.Sequential(
      nn.Linear(bottleneck_dim, 1024),
      nn.ReLU(inplace=True),
      nn.Dropout(),
      nn.Linear(1024, 1024),
      nn.ReLU(inplace=True),
      nn.Dropout(),
      nn.Linear(1024, 2),
      nn.Softmax(dim=1)
    )

  def update_lambda(self, x):
    """Set the gradient-reversal strength to ``2 / (1 + exp(-10 * x)) - 1``."""
    self.lambd = 1.0 * (2. / (1. + np.exp(-10 * (x))) - 1.)

  def forward(self, x):
    """Return (class logits, attribute logits, domain probabilities) for a batch of images."""
    # feature_layers already ends with the average pool, so it is not applied again.
    x = self.feature_layers(x)
    x = torch.flatten(x, 1)
    x = self.bottleneck(x)
    # Autograd Functions are invoked through the class, not an instance.
    reverse_feature = ReverseLayerF.apply(x, self.lambd)
    domain_output = self.discriminator(reverse_feature)
    attribute_output = self.fc_attr(x)
    class_output = self.fc_class(x)
    # Class and attribute outputs are raw logits; callers apply softmax / sigmoid.
    return class_output, attribute_output, domain_output

**Train**

In [0]:
def train_Ada_Net(model, dataloaders, optimizer, scheduler, criterions, num_epochs=NUM_EPOCHS, decay_start=DECAY_START):
  """Run the full training schedule, resuming from checkpoints when present.

  Args:
    model: network exposing `update_lambda`.
    dataloaders: dict with 'labeled_dataloader', 'unlabeled_dataloader' and 'val_dataloader'.
    optimizer: optimizer whose param groups are also adjusted by the manual decay below.
    scheduler: LR scheduler stepped once per epoch.
    criterions: dict of loss modules forwarded to `train_loop`.
    num_epochs: index of the last epoch (epochs are 1-based).
    decay_start: epoch after which the manual LR decay kicks in; falsy disables it.

  After every epoch the model is evaluated, the best checkpoint is
  refreshed when top-1 accuracy improves, and the 'last' checkpoint is
  written.
  """
  labeled_dataloader = dataloaders['labeled_dataloader']
  unlabeled_dataloader = dataloaders['unlabeled_dataloader']
  val_dataloader = dataloaders['val_dataloader']
  start_epoch = 1  # whether 1 or loaded from the checkpoint.
  best_accuracy = -1.0

  # Load the best checkpoint first: it also writes its weights into the
  # model, and the most recent ('last') weights must win when resuming.
  if Path(PATH_BEST).exists():
    best_accuracy = load_checkpoint('best', model, optimizer)

  if Path(PATH_LAST).exists():
    model, optimizer, start_epoch = load_checkpoint('last', model, optimizer)
    # The reversal strength is a plain attribute and is not part of the
    # state dict; restore the value it had at the end of the previous epoch.
    model.update_lambda(float((start_epoch - 1) / num_epochs))
    # NOTE(review): the scheduler state is not checkpointed, so its step
    # count restarts on resume - confirm whether that is acceptable.

  for epoch in range(start_epoch, num_epochs+1):
    print('\n\nEpoch [{}/{}]'.format(epoch, num_epochs), flush=True)
    if decay_start and epoch > decay_start:
      # NOTE(review): the factor is applied to the already-decayed lr, so
      # the decay compounds and reaches 0 in the final epoch - confirm intent.
      for g in optimizer.param_groups:
        g['lr'] = (num_epochs - epoch) * g['lr'] / (num_epochs - decay_start)
        g['betas'] = (0.5, 0.999) # reducing the momentum coefficient in the end of training can improve convergence

    model = train_loop(model, labeled_dataloader, unlabeled_dataloader, optimizer, criterions, epoch)

    attribute_accuracy, current_accuracy = eval_loop(model, val_dataloader)

    if current_accuracy > best_accuracy:
      best_accuracy = current_accuracy
      # 'best' mode does not use the optimizer argument.
      save_checkpoint('best', epoch, model, None, best_accuracy)

    scheduler.step()

    model.update_lambda(float((epoch)/num_epochs))

    # 'last' mode does not use the accuracy argument.
    save_checkpoint('last', epoch, model, optimizer, None)

def save_checkpoint(mode, epoch, model, optimizer, accuracy):
  """Write a checkpoint to disk.

  mode == 'last': epoch, model and optimizer state go to PATH_LAST.
  mode == 'best': epoch, model state and `accuracy` go to PATH_BEST.
  Any other mode writes nothing.
  """
  if mode not in ('last', 'best'):
    return

  state = {'epoch': epoch, 'model_state_dict': model.state_dict()}
  if mode == 'last':
    state['optimizer_state_dict'] = optimizer.state_dict()
    destination = PATH_LAST
  else:
    state['best_accuracy'] = accuracy
    destination = PATH_BEST
  torch.save(state, destination)
      
      
def load_checkpoint(mode, model, optimizer):
  """Restore state from disk.

  mode == 'last': loads model and optimizer state from PATH_LAST and
    returns ``(model, optimizer, start_epoch)`` where `start_epoch` is the
    epoch following the saved one.
  mode == 'best': loads the model state from PATH_BEST (overwriting the
    weights currently in `model`) and returns the stored best accuracy.
  Any other mode returns None.

  Checkpoints are mapped onto DEVICE so that files written on a GPU can
  be read on a CPU-only runtime.
  """
  if mode == 'last':
    checkpoint = torch.load(PATH_LAST, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print('Loaded checkpoint at epoch {}'.format(start_epoch-1))
    return model, optimizer, start_epoch
  if mode == 'best':
    checkpoint = torch.load(PATH_BEST, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    best_accuracy = checkpoint['best_accuracy']
    best_epoch = checkpoint['epoch']
    print('Loaded checkpoint of best model at epoch {}'.format(best_epoch))
    return best_accuracy

def train_loop(model, labeled_dataloader, unlabeled_dataloader, optimizer, criterions, epoch, lmbda=LMBDA):
  """Train `model` for one pass over the unlabeled loader and return it.

  Each step draws one labeled and one unlabeled batch and minimises
  class loss + attribute loss + 0.3 * (mixed class loss + mixed domain loss),
  where the mixed terms are computed on CutMix images built from both batches.

  Args:
    criterions: dict with 'class', 'attr' and 'mixed' loss modules.
    epoch, lmbda: accepted but not read inside this function.
  """
  batch_num = tqdm(range(len(unlabeled_dataloader)), position=0, leave=True)
  # NOTE(review): itertools.cycle stores the batches of its first pass and
  # replays them, so repeated labeled batches are not re-shuffled or re-augmented.
  labeled_dataloader = cycle(labeled_dataloader)
  unlabeled_dataloader = iter(unlabeled_dataloader)
  
  running_corrects = 0

  for i in batch_num:
    labeled_images, attributes, labels = next(labeled_dataloader)
    unlabeled_images, _, _ = next(unlabeled_dataloader)
    labeled_images, attributes, labels = labeled_images.to(DEVICE), attributes.to(DEVICE).float(), labels.to(DEVICE)
    unlabeled_images = unlabeled_images.to(DEVICE)
    model.train()

    # Supervised forward pass on the labeled batch.
    pred_labels, pred_attributes, _ = model(labeled_images)

    # Running top-1 training accuracy, shown in the progress bar.
    running_corrects += torch.sum(pred_labels.max(1)[1] == labels.data)
    batch_acc = running_corrects.double() / (len(labeled_images)*(i+1)) 

    class_loss = criterions['class'](pred_labels, labels)
    # The attribute criterion receives probabilities, hence the explicit sigmoid.
    sigmoid_attr_outputs = torch.sigmoid(pred_attributes)
    attr_loss = criterions['attr'](sigmoid_attr_outputs, attributes)
    
    # Pseudo-labels for the unlabeled batch: computed in eval mode and without gradients.
    model.eval()
    with torch.no_grad():
      pseudo_labels, _, _ = model(unlabeled_images)
    model.train()
    pseudo_labels = F.softmax(pseudo_labels, 1)

    # generate mixed inputs, soft mixed labels and the [lam, 1 - lam] domain targets
    #mixed_images, mixed_labels, mixed_indexes = mixup_data(labeled_images, labels, unlabeled_images, pseudo_labels)
    mixed_images, mixed_labels, mixed_indexes = cutmix_data(labeled_images, labels, unlabeled_images, pseudo_labels)

    pred_labels_mixed, _, pred_domains_mixed = model(mixed_images)
    # The 'mixed' criterion compares probability vectors, so logits go through softmax first.
    pred_labels_mixed = F.softmax(pred_labels_mixed, 1)
    class_loss_mixed = criterions['mixed'](pred_labels_mixed, mixed_labels)

    # The discriminator output is compared against the mixing proportions.
    domain_loss_mixed = criterions['mixed'](pred_domains_mixed, mixed_indexes)
    mixed_loss = (class_loss_mixed + domain_loss_mixed) * 0.3
    
    total_loss = class_loss + attr_loss + mixed_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    batch_num.set_postfix({'class_loss': class_loss.item(),
                           'attr_loss': attr_loss.item(),
                           'class_loss_mixed': class_loss_mixed.item(),
                           'domain_loss_mixed': domain_loss_mixed.item(),
                           'total_loss': total_loss.item(),
                           'train_acc': batch_acc.item()})
  return model

def eval_loop(model, _val_dataloader): 
  """Evaluate `model` on the validation loader.

  Tracks top-1 / top-5 class accuracy and the attribute score (mean
  `hamming_score` over all evaluated samples), shows one randomly picked
  sample with `show_image_with_attributes`, and returns
  ``(attribute_score, top1_accuracy)``.

  Both metrics are averaged over the samples actually evaluated, so
  batches skipped by ``drop_last`` do not bias the result.
  """
  model.eval()
  num_batches = len(_val_dataloader)
  batch_num = tqdm(range(num_batches), position=0, leave=True)
  val_dataloader = iter(_val_dataloader)
  running_corrects_top1 = 0
  running_corrects_top5 = 0

  num_elements = 0
  running_attr_score = 0.0  # sum of per-sample attribute scores
  attr_score = 0.0
  r_img = None

  with torch.no_grad():
    # Batch whose sample gets visualised (the final batch is never picked).
    shown_batch = randint(0, max(num_batches - 2, 0))
    for val_batch_num in batch_num:
      val_images, val_attributes, val_labels = next(val_dataloader)
      val_images, val_attributes, val_labels = val_images.to(DEVICE), val_attributes.to(DEVICE).float(), val_labels.to(DEVICE)
      val_pred_labels, val_pred_attributes, _ = model(val_images)
      batch_len = len(val_images)

      attr_outputs = (torch.sigmoid(val_pred_attributes) > 0.5).float()

      # hamming_score is a mean over samples, so weighting each batch mean by
      # its size gives the same overall mean without re-scoring earlier batches.
      running_attr_score += hamming_score(val_attributes.cpu(), attr_outputs.cpu()) * batch_len

      _, preds = torch.max(val_pred_labels.data, 1)
      running_corrects_top1 += torch.sum(preds == val_labels.data).item()

      # A sample counts as a top-5 hit when its label is among the 5 largest logits.
      top5_indices = torch.topk(val_pred_labels, 5).indices
      running_corrects_top5 += (top5_indices == val_labels.unsqueeze(1)).any(dim=1).sum().item()

      num_elements += batch_len
      attr_score = running_attr_score / float(num_elements)
      corrects_top1 = running_corrects_top1 / float(num_elements)
      corrects_top5 = running_corrects_top5 / float(num_elements)
      batch_num.set_postfix({'top1_acc': corrects_top1,
                             'top5_acc': corrects_top5,
                             'hamming score': attr_score})

      if val_batch_num == shown_batch:
        y = randint(0, batch_len - 1)
        r_img = {}
        r_img['image'] = val_images[y]
        r_img['pred_label'] = preds[y]
        r_img['label'] = val_labels[y]
        r_img['pred_attr'] = attr_outputs[y]
        r_img['attr'] = val_attributes[y]

  if r_img is not None:
    show_image_with_attributes(r_img)

  accuracy = running_corrects_top1 / float(num_elements) if num_elements else 0.0
  return attr_score, accuracy



def mixup_data(labeled_images, labels, unlabeled_images, pseudo_labels):
  '''Mix each labeled image with its unlabeled counterpart (mixup).

  One coefficient lam ~ Beta(ALPHA, ALPHA) is drawn per sample.
  Returns:
    mixed_images: lam * labeled + (1 - lam) * unlabeled
    mixed_labels: lam * one_hot(labels) + (1 - lam) * pseudo_labels
    mixed_indexes: per-sample [lam, 1 - lam] pairs
  '''
  batch_size, *image_shape = labeled_images.shape
  mixing_dist = Beta(torch.tensor([ALPHA]), torch.tensor([ALPHA]))
  lam = mixing_dist.sample((batch_size,)).squeeze(1).to(DEVICE) # one coefficient per sample
  lam_per_pixel = lam.view(batch_size, 1, 1, 1).repeat(1, *image_shape) # broadcast each coefficient over its image

  # One-hot encode the ground-truth labels.
  onehot_buffer = torch.zeros(batch_size, NUM_CLASSES, dtype=float).to(DEVICE)
  labels_onehot = onehot_buffer.scatter_(1, labels.unsqueeze(1), 1)

  mixed_images = lam_per_pixel * labeled_images + (1 - lam_per_pixel) * unlabeled_images
  labeled_share = lam.view(batch_size, 1) * labels_onehot
  pseudo_share = (1 - lam).view(batch_size, 1) *  pseudo_labels
  mixed_labels = labeled_share + pseudo_share
  mixed_indexes = torch.stack([lam, 1 - lam], dim=1).to(DEVICE)

  return mixed_images, mixed_labels, mixed_indexes

def cutmix_data(labeled_images, labels, unlabeled_images, pseudo_labels):
  '''Paste a box from each unlabeled image onto the matching labeled image (CutMix).

  One coefficient lam ~ Beta(ALPHA, ALPHA) is drawn per sample and then
  replaced by the exact fraction of labeled pixels that remain.
  Returns:
    mixed_images: labeled images with the box replaced by unlabeled pixels
    mixed_labels: lam * one_hot(labels) + (1 - lam) * pseudo_labels
    mixed_indexes: per-sample [lam, 1 - lam] pairs
  '''
  beta = Beta(torch.tensor([ALPHA]), torch.tensor([ALPHA])) # defining the Beta distribution
  batch_size, *shape = labeled_images.shape # splitting into batch size and image dimensions
  lam = beta.sample((batch_size,)).squeeze(1) # one coefficient per sample; kept on the CPU for the numpy math below

  zeros_onehot = torch.zeros(batch_size, NUM_CLASSES, dtype=float).to(DEVICE)
  labels_onehot = zeros_onehot.scatter_(1, labels.unsqueeze(1), 1)

  # CutMix
  image_height, image_width = labeled_images.shape[2:]
  cut_x = np.random.uniform(0, image_width)  # x coordinate of the box centre, shared by the whole batch
  cut_y = np.random.uniform(0, image_height) # y coordinate of the box centre, shared by the whole batch

  cut_ratio = np.sqrt(1. - lam.numpy()) # per-sample side ratio, so the box covers about (1 - lam) of the image
  cut_width = (image_width * cut_ratio).astype(int)    # per-sample box width in pixels
  cut_height = (image_height * cut_ratio).astype(int)  # per-sample box height in pixels

  # Box corners, clipped to the image borders.
  x0s = np.clip(cut_x - cut_width // 2, 0, image_width).astype(int)   # initial x positions of the cut
  x1s = np.clip(cut_x + cut_width // 2, 0, image_width).astype(int)   # final x positions of the cut
  y0s = np.clip(cut_y - cut_height // 2, 0, image_height).astype(int) # initial y positions of the cut
  y1s = np.clip(cut_y + cut_height // 2, 0, image_height).astype(int) # final y positions of the cut

  # Work on a copy so the caller's labeled batch is left untouched.
  mixed_images = copy.deepcopy(labeled_images)

  for idx, (x0, x1, y0, y1) in enumerate(zip(x0s, x1s, y0s, y1s)):
    mixed_images[idx, :, y0:y1, x0:x1] = unlabeled_images[idx, :, y0:y1, x0:x1] # take the cut box of the unlabeled img and substitute onto the labeled img
    # adjust lambda to exactly match pixel ratio
    lam[idx] = 1 - ((x1 - x0) * (y1 - y0) / (labeled_images[idx].size()[-1] * labeled_images[idx].size()[-2]))

  lam = lam.to(DEVICE)
  mixed_labels = lam.view(batch_size, 1) * labels_onehot + (1 - lam).view(batch_size, 1) *  pseudo_labels # soft labels weighted by the area share
  mixed_indexes = torch.stack([lam, 1 - lam], dim=1).to(DEVICE) # per-sample [lam, 1 - lam] pairs

  return mixed_images, mixed_labels, mixed_indexes



class KL_div(nn.Module):
  """KL divergence between two batches of probability vectors, averaged over the batch."""

  def __init__(self):
    super().__init__()

  def forward(self, prob, target):
    # EPS keeps the ratio and the logarithm finite when a probability is zero.
    log_ratio = torch.log((target + EPS) / (prob + EPS))
    per_sample = (target * log_ratio).sum(1)
    return per_sample.mean()

def hamming_score(y_true, y_pred, normalize=True, sample_weight=None):
  '''
  Mean per-sample overlap between true and predicted label sets.

  For every row, the score is |true AND pred| / |true OR pred| over the
  indices of the non-zero entries; a row where both sets are empty scores 1.
  `normalize` and `sample_weight` are accepted but not used.
  '''
  def _row_score(true_row, pred_row):
      active_true = set(np.where(true_row)[0])
      active_pred = set(np.where(pred_row)[0])
      if not active_true and not active_pred:
          return 1
      return len(active_true & active_pred) / float(len(active_true | active_pred))

  row_scores = [_row_score(y_true[row], y_pred[row]) for row in range(y_true.shape[0])]
  return np.mean(row_scores)

**Prepare training**

In [0]:
def prepare_training(lr=LR, step_size=STEP_SIZE, gamma=GAMMA):
  """Build the model, optimizer, scheduler and loss functions.

  Returns:
    (net, optimizer, scheduler, criterions), where `criterions` maps
    'class' -> CrossEntropyLoss (on logits), 'attr' -> BCELoss (on
    probabilities) and 'mixed' -> KL_div (on probability vectors).
  """
  net = ResNetFc().to(DEVICE)

  optimizer = optim.Adam(net.parameters(), lr=lr)
  # Reference the scheduler through `optim` explicitly rather than relying on
  # the star import from torch.optim to expose `lr_scheduler`.
  scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

  criterions = {
    'class': nn.CrossEntropyLoss(),
    'attr': nn.BCELoss(),
    'mixed': KL_div(),
  }

  return net, optimizer, scheduler, criterions

**Main**

In [0]:
# Build the model, optimizer, scheduler and losses with the default hyper-parameters...
net, optimizer, scheduler, criterions = prepare_training()

# ...and train; existing 'last'/'best' checkpoint files are picked up by train_Ada_Net.
train_Ada_Net(net, dataloaders, optimizer, scheduler, criterions)

**Test Dataset**

In [0]:
class TestDataset(data.dataset.Dataset):
  """Images under ``Test/<class_name>/*.jpg``, including an extra 'unknown' class.

  Class indices follow the line order of ``DATA_DIR/classes.txt``;
  'unknown' is assigned index 50. Items are
  ``(image, placeholder, class_index)`` so they unpack like the training
  triples; no attribute vector is available here.
  """

  def __init__(self, classes_file, transform=None):
    """
    Args:
      classes_file: accepted for signature parity with AnimalDataset; not read.
      transform: optional callable applied to each PIL image.
    """
    self.transform = transform

    # Map class name -> index, using the line order of classes.txt.
    with open(DATA_DIR + '/classes.txt') as f:
      class_to_index = {line.split('\t')[1].strip(): index for index, line in enumerate(f)}
    class_to_index['unknown'] = 50
    self.class_to_index = class_to_index

    img_names = []
    img_index = []
    for class_name, class_index in class_to_index.items():
      # Sort: glob() returns files in arbitrary filesystem order.
      files = sorted(glob(os.path.join('Test', class_name, '*.jpg')))
      img_names.extend(files)
      img_index.extend([class_index] * len(files))
    self.img_names = img_names
    self.img_index = img_index

  def __getitem__(self, index):
    """Return (image, -1, class index) for sample `index`."""
    # Convert every non-RGB mode so the 3-channel normalisation always
    # applies; convert() loads the pixels, so the file can be closed.
    with Image.open(self.img_names[index]) as raw:
      im = raw.convert('RGB')
    if self.transform is not None:
      im = self.transform(im)

    im_index = self.img_index[index]
    # -1 is an explicit "no attributes" placeholder: `_` is not defined in
    # this scope, and an integer collates cleanly in a DataLoader.
    return im, -1, im_index

  def __len__(self):
    """Number of images found on disk."""
    return len(self.img_names)

# Full preprocessing for TestDataset images: resize, centre-crop to 224x224,
# convert to tensor and normalise with MEAN/STD.
test_preprocess = transforms.Compose([
  transforms.Resize(256), 
  transforms.CenterCrop(224),
  transforms.ToTensor(),
  transforms.Normalize(MEAN, STD)
])

**Confidence**

In [0]:
# Loss used by `evaluate` to compute the input gradient for the perturbation step.
criterion = nn.CrossEntropyLoss()

# Same pipeline as the `test_preprocess` defined in the previous cell.
test_preprocess = transforms.Compose([
  transforms.Resize(256), 
  transforms.CenterCrop(224),
  transforms.ToTensor(),
  transforms.Normalize(MEAN, STD)
])

# Loader over the TestDataset images; it is scored as the out-of-distribution set below.
unknown_dataset = TestDataset(DATA_DIR + '/classes.txt', test_preprocess)
ood_dataloader = DataLoader(unknown_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, num_workers=4)

# In-distribution loader built earlier by prepare_dataloaders.
ind_dataloader = dataloaders['ind_dataloader']

def tpr95(ind_confidences, ood_confidences):
    """Return the false-positive rate at (approximately) 95% true-positive rate.

    Thresholds are swept in 100000 uniform steps between the smallest and
    largest confidence. A threshold's TPR is the fraction of in-distribution
    confidences >= threshold and its FPR the fraction of out-of-distribution
    confidences > threshold. The result is the mean FPR over all thresholds
    whose TPR lies in [0.9495, 0.9505]; when no threshold lands in that band
    (e.g. few samples), the thresholds whose TPR is closest to 0.95 are used.

    Raises:
        ValueError: if either input is empty.
    """
    ind = np.sort(np.asarray(ind_confidences).ravel())
    ood = np.sort(np.asarray(ood_confidences).ravel())
    if ind.size == 0 or ood.size == 0:
        raise ValueError("tpr95 needs at least one confidence value in each set")

    start = min(ind[0], ood[0])
    end = max(ind[-1], ood[-1])
    gap = (end - start) / 100000
    # All confidences identical: evaluate the single possible threshold.
    thresholds = np.arange(start, end, gap) if gap > 0 else np.array([start])

    # On sorted data, counting elements >= t (or > t) is a binary search.
    tpr = (ind.size - np.searchsorted(ind, thresholds, side='left')) / float(len(ind_confidences))
    fpr = (ood.size - np.searchsorted(ood, thresholds, side='right')) / float(len(ood_confidences))

    in_band = (tpr >= 0.9495) & (tpr <= 0.9505)
    if not in_band.any():
        distance = np.abs(tpr - 0.95)
        in_band = distance == distance.min()

    return float(fpr[in_band].mean())

def detection(ind_confidences, ood_confidences, n_iter=100000, return_data=False):
    """Find the threshold minimising the detection error.

    For each of `n_iter` uniformly spaced thresholds between the smallest
    and largest confidence, the detection error is the average of
      * the fraction of in-distribution confidences below the threshold, and
      * the fraction of out-of-distribution confidences above it.

    Returns:
        (best_error, best_delta), plus (all_errors, all_thresholds) as
        lists when `return_data` is True. `best_delta` is None when no
        threshold achieves an error below 1.0.

    Raises:
        ValueError: if either input is empty.
    """
    ind = np.sort(np.asarray(ind_confidences).ravel())
    ood = np.sort(np.asarray(ood_confidences).ravel())
    if ind.size == 0 or ood.size == 0:
        raise ValueError("detection needs at least one confidence value in each set")

    start = min(ind[0], ood[0])
    end = max(ind[-1], ood[-1])
    gap = (end - start) / n_iter
    # All confidences identical: evaluate the single possible threshold.
    thresholds = np.arange(start, end, gap) if gap > 0 else np.array([start])

    # On sorted data, counting elements < t (or > t) is a binary search.
    miss_rate = np.searchsorted(ind, thresholds, side='left') / float(len(ind_confidences))
    false_alarm = (ood.size - np.searchsorted(ood, thresholds, side='right')) / float(len(ood_confidences))
    errors = (miss_rate + false_alarm) / 2.0

    best_error = 1.0
    best_delta = None
    if errors.size:
        best_idx = int(np.argmin(errors))  # first threshold reaching the minimum
        if errors[best_idx] < best_error:
            best_error = float(errors[best_idx])
            best_delta = float(thresholds[best_idx])

    if return_data:
        return best_error, best_delta, errors.tolist(), thresholds.tolist()
    else:
        return best_error, best_delta

def evaluate(model, dataloader, T, eps):
  """Return the max temperature-scaled softmax confidence for every sample.

  For each batch:
    1. if `eps` is non-zero, take the gradient of the cross-entropy between
       the temperature-scaled logits and the model's own predictions, and
       move the input by ``-eps * sign(gradient)``;
    2. run the (possibly perturbed) input through the model and keep
       ``max(softmax(logits / T))`` per sample.

  Args:
    model: network returning (class logits, ..., ...).
    dataloader: yields tuples/lists whose first element is the image batch.
    T: softmax temperature.
    eps: perturbation magnitude; 0 skips the gradient pass entirely.

  Returns:
    1-D numpy array with one confidence per evaluated sample.
  """
  out = []
  for data in tqdm(dataloader, position=0, leave=True):
    images = data[0].to(DEVICE)

    if eps:
      images = images.detach().requires_grad_(True)
      model.zero_grad()
      logits, _, _ = model(images)
      predicted = torch.max(logits.detach(), 1)[1]
      loss = criterion(logits / T, predicted)
      loss.backward()
      images = (images - eps * torch.sign(images.grad)).detach()

    # No gradients are needed for the scoring pass.
    with torch.no_grad():
      logits, _, _ = model(images)
      confidence = torch.max(F.softmax(logits / T, dim=-1), 1)[0]

    out.append(confidence.cpu().numpy())

  out = np.concatenate(out)
  return out

# GRID SEARCH
# Scores the in-distribution and out-of-distribution loaders for every
# (temperature, epsilon) pair, reports the detection metrics and remembers the
# pair with the lowest detection error.

temperatures = np.linspace(1.23, 1.45, 12)
# Alternative, wider epsilon grid: [0, 0.002, 0.004, 0.006]
epsilons = [0]
best_error = 1000.0
best_eps = None
best_temp = None

# The checkpoint does not depend on T or eps, so it is loaded once.
# map_location lets a checkpoint saved on GPU be opened on a CPU-only runtime.
model = ResNetFc().to(DEVICE)
checkpoint = torch.load(PATH_BEST, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

for T in temperatures:
  for eps in epsilons:
    print('\nT = {}, eps = {}'.format(T, eps), flush=True)
    print('Evaluation of ind_dataset...', flush=True)
    ind_scores = evaluate(model, ind_dataloader, T, eps)
    ind_labels = np.ones(ind_scores.shape[0])
    print('\nEvaluation of ood_dataset...', flush=True)
    ood_scores = evaluate(model, ood_dataloader, T, eps)
    ood_labels = np.zeros(ood_scores.shape[0])

    # In-distribution is the positive class (label 1) for every metric below.
    labels = np.concatenate([ind_labels, ood_labels])
    scores = np.concatenate([ind_scores, ood_scores])

    fpr, tpr, threshold = metrics.roc_curve(labels, scores)
    roc_auc = metrics.auc(fpr, tpr)

    plt.title('Receiver Operating Characteristic')
    plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
    plt.legend(loc = 'lower right')
    plt.plot([0, 1], [0, 1],'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.show()

    fpr_at_95_tpr = tpr95(ind_scores, ood_scores)
    detection_error, best_delta = detection(ind_scores, ood_scores)

    if detection_error < best_error:
      best_error = detection_error
      best_eps = eps
      best_temp = T

    auroc = metrics.roc_auc_score(labels, scores)
    aupr_in = metrics.average_precision_score(labels, scores)
    # Flip labels and scores so that out-of-distribution becomes the positive class.
    aupr_out = metrics.average_precision_score(-1 * labels + 1, 1 - scores)

    print("")
    print("TPR95 (lower is better): ", fpr_at_95_tpr)
    print("Detection error (lower is better): ", detection_error)
    print("Best threshold:", best_delta)
    print("AUROC (higher is better): ", auroc)
    print("AUPR_IN (higher is better): ", aupr_in)
    print("AUPR_OUT (higher is better): ", aupr_out)

    # Shared range so both histograms use the same bin edges.
    # NOTE: sns.distplot is deprecated in recent seaborn releases.
    ranges = (np.min(scores), np.max(scores))
    plt.figure()
    sns.distplot(ind_scores.ravel(), hist_kws={'range': ranges}, kde=False, bins=50, norm_hist=True, label='In-distribution')
    sns.distplot(ood_scores.ravel(), hist_kws={'range': ranges}, kde=False, bins=50, norm_hist=True, label='Out-of-distribution')
    plt.xlabel('Confidence')
    plt.ylabel('Density')
    plt.xticks([0, 0.2, 0.4, 0.6, 0.8, 1])
    plt.legend()
    plt.show()

# Report the winning configuration, which was tracked but never shown.
print('\nBest setting: T = {}, eps = {} (detection error = {})'.format(best_temp, best_eps, best_error))

**Test**

In [0]:
def _describe_novel_class(pred_label, pred_attribute):
  """Return (attribute name, class name) used to label a rejected sample.

  The class is the one with the highest predicted probability. The attribute
  is the most confident predicted attribute (sigmoid > 0.5) that the binary
  predicate matrix marks as absent for that class, or 'strange' when there is
  no such attribute.
  """
  top_class = torch.argmax(pred_label).item()
  # Convert once instead of calling .item() on tensor elements inside loops.
  attr_probs = torch.sigmoid(pred_attribute).tolist()
  class_attributes = predicate_binary_mat[top_class]

  novel = [
      z for z, (prob, present) in enumerate(zip(attr_probs, class_attributes))
      if present == 0 and prob > 0.5
  ]
  if novel:
    most_attribute = predicates[max(novel, key=lambda z: attr_probs[z])]
  else:
    most_attribute = 'strange'
  return most_attribute, classes[top_class]


def test_out(model, dataloader, T, eps, threshold, ood_label=50):
  """Count correct decisions and detection errors on a mixed test set.

  Each batch is perturbed by ``eps`` against the sign of the input gradient
  of the class loss (computed against the model's own predictions) and then
  scored with a temperature-scaled softmax. A sample is rejected when its
  highest class probability is below ``threshold``.

  Counting rules:
    * label == ``ood_label`` and rejected: correct; the image is shown and a
      name for the new class is printed.
    * label == ``ood_label`` and accepted: error.
    * other label and rejected: error.
    * other label, accepted and classified correctly: correct.
    * other label, accepted and misclassified: neither counter changes.

  Args:
    model: network whose forward pass returns a 3-tuple; the first element
      is used as class logits and the second as attribute logits.
    dataloader: yields ``(images, _, labels)`` batches.
    T: temperature the class logits are divided by.
    eps: magnitude of the signed-gradient input perturbation.
    threshold: confidence below which a sample is rejected.
    ood_label: label value that marks an out-of-distribution sample.

  Returns:
    Tuple ``(running_corrects, errors)``.
  """
  running_corrects = 0
  errors = 0

  for images, _, true_labels in tqdm(dataloader, position=0, leave=True):

    images = images.to(DEVICE).requires_grad_(True)

    model.zero_grad()
    logits, pred_attributes, _ = model(images)
    # The model's own prediction acts as the target of the loss.
    pseudo_labels = torch.max(logits.detach(), 1)[1]
    loss = criterions['class'](logits / T, pseudo_labels)
    loss.backward()

    # No gradient is needed past this point, so skip building a second graph.
    with torch.no_grad():
      images = images - eps * torch.sign(images.grad)
      logits, _, _ = model(images)
      pred_labels = F.softmax(logits / T, dim=-1)
    # Attribute logits come from the first (unperturbed) forward pass.
    pred_attributes = pred_attributes.detach()

    samples = zip(pred_labels, pred_attributes, true_labels)
    for idx, (pred_label, pred_attribute, true_label) in enumerate(samples):
      rejected = bool(pred_label.max() < threshold)

      if true_label == ood_label:
        #Here we have an out of distribution sample
        print('There is an out of distribution sample.')
        if rejected:
          running_corrects += 1
          imshow(images[idx].detach().cpu())
          most_attribute, nearest_class = _describe_novel_class(pred_label, pred_attribute)
          print('New class found: {} {}.'.format(most_attribute, nearest_class))
        else:
          errors += 1
      else:
        print('There is an in of distribution sample.')
        if rejected:
          errors += 1
        elif torch.argmax(pred_label).item() == true_label.item():
          running_corrects += 1
  return running_corrects, errors

# Final test run: restore a checkpoint into `net`, build a combined test set
# and report accuracy / detection errors for fixed T, eps and threshold.
model = net
# NOTE(review): the guard checks that PATH_BEST exists, but the call asks
# load_checkpoint for 'last'. The grid search above loads PATH_BEST, so one of
# the two is presumably unintended — confirm against load_checkpoint.
if Path(PATH_BEST).exists():
  best_accuracy = load_checkpoint('last', model, optimizer)
model.eval()

# Combined test set: the TestDataset built from classes.txt followed by
# ind_dataset. Presumably TestDataset supplies the out-of-distribution samples
# (label 50 in test_out) — TODO confirm.
unknown_dataset = TestDataset(DATA_DIR + '/classes.txt', test_preprocess)
test_dataset = unknown_dataset + ind_dataset
print('Length dataset: {}'.format(len(test_dataset)))
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# TEST PARAMETERS AFTER GRID SEARCH
# Hard-coded here; they are not read from the grid-search variables.

T = 1.4
eps = 0.0
threshold = 0.50

corrects, errors = test_out(model, test_dataloader, T, eps, threshold)

# Both figures are reported relative to the size of the whole combined dataset.
total = len(test_dataloader.dataset)

print('Accuracy: {:.4f}'.format(corrects / total))
print('Errors: [{}/{}]'.format(errors, total))