# Configuration

In [None]:
# --- Training schedule -------------------------------------------------------
EPOCHS = 10           # number of passes over the training set
val_fraction = 0.1    # share of the images held out for validation/testing
bs = 32               # mini-batch size

# --- Hardware ----------------------------------------------------------------
# Index of the GPU to use (you probably need to set this value to 0 to run or
# re-train the model)
gpu = 1

# --- Model -------------------------------------------------------------------
use_resnet = True           # True: ResNet18, False: VGG16
save_better_models = True   # write a checkpoint when validation results improve
use_saved_model = True      # True: load `save_path` and skip training, False: retrain

# Checkpoint file for the selected architecture.
# Older VGG16 checkpoints:
#   'model_vgg16_bn.pth' (trained on all images)
#   'model.pth'
save_path = (
    'model_resnet18.pth'
    if use_resnet
    else 'model_vgg16_bn_max100samples.pth'  # trained on at most 100 images/individual
)

# --- Data sampling -----------------------------------------------------------
# Random seed (makes results reproducible). Change this integer to obtain a
# different train/validation split.
seed = 0

# Upper bound on the number of (randomly chosen) images kept per individual
max_files_per_individual = 100

# Load data

The images are stored in subdirectories named Part1, Part2, etc.

**Female** folders are Part1-Part6, and **Male** folders are Part7-Part9.

We are interested in all non-profile images matching the following pattern: `./PartX/Person_ID/non-profile/ImageID.jpg`

where

- PartX is one of Part1, Part2, ..., Part9
- Person_ID is a unique integer associated with each individual person
- ImageID is the image ID

A text file containing a list of all files matching the above pattern can be generated with the following Bash command (run in an Ubuntu/Linux terminal from inside the `data` directory, so that `cat.txt` is written next to this notebook):

`find . -type f -name '*.jpg' | grep non-profile | grep -i -v -E 'MAC' > ../cat.txt`



In [None]:
# Part folders holding images of men; the parsing cell labels every other Part folder as 'woman'
male_folders = ['Part7','Part8','Part9']

In [None]:
# The file list is generated with:
#   cd data
#   find . -type f -name '*.jpg' | grep non-profile | grep -i -v -E 'MAC' > ../cat.txt
filelist = 'cat.txt'

# Keys: person ID (string). Fields per person:
#   'gender': 'man' or 'woman' (derived from the Part folder)
#   'part'  : Part folder in which the person was first seen
#   'files' : list of jpg file names
data = {}
with open(filelist) as fp:
    for line in fp:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline at the end of the list)
            continue

        # Expected layout: ./PartX/Person_ID/non-profile/ImageID.jpg
        # (a line with a different number of path components raises ValueError)
        _, part, person_id, _, jpg = line.split('/')

        gender = 'man' if part in male_folders else 'woman'

        if person_id not in data:
            data[person_id] = {
                'gender': gender,
                'part': part,
                'files': [jpg]
            }
        else:
            data[person_id]['files'].append(jpg)


Get some statistics

In [None]:
import matplotlib.pyplot as plt

# Per-person gender flag and image count (same ordering as `data`)
people = list(data.values())
is_male = [person['gender'] == 'man' for person in people]
num_files = [len(person['files']) for person in people]

# Head counts
num_male = sum(is_male)
num_female = len(is_male) - num_male

# Image counts
num_img_male = sum(n for n, male in zip(num_files, is_male) if male)
num_img_female = sum(num_files) - num_img_male

print('male:',num_male,'\nfemale:',num_female)
print('num_files (male):',num_img_male,'\nnum_files (female):',num_img_female)

# Number of files per individual
plt.figure()
plt.plot(num_files)
plt.xlabel('Person ID')
plt.ylabel('Number of uploaded images')

plt.figure()
plt.hist(num_files,bins=100)
plt.title('Histogram');

Create lists of all image paths (`images`) and labels (`labels`)

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(seed)

# Load image meta data
class_names = ['man','woman']
images = []  # image paths
labels = []  # class index (position in class_names) of each path in `images`
for person_id in data.keys():
    class_name = data[person_id]['gender']
    if class_name not in class_names:
        # Fail loudly instead of silently reusing the class index of the previous person
        raise ValueError('error class name: %r' % (class_name,))
    class_index = class_names.index(class_name)

    part = data[person_id]['part']
    files = data[person_id]['files']

    # Shuffle so that the files kept for this person are a random subset
    indices = list(range(len(files)))
    np.random.shuffle(indices)

    # Keep at most `max_files_per_individual` files.
    # NOTE(review): earlier revisions kept one file more than the limit (`count > max`),
    # so checkpoints saved before this fix were trained on a slightly different
    # image list and train/validation split.
    for file_index in indices[:max_files_per_individual]:
        jpg = files[file_index]
        image_path = os.path.join('./data',part,person_id,'non-profile',jpg)
        images.append(image_path)
        labels.append(class_index)

# Remove empty images (this will take a while)
delete_ix = []
for i,path in enumerate(images):
    img = cv2.imread(path)
    if img is None:
        delete_ix.append(i)
        print('deleting',str(i))

# Pop from the back: removing in ascending order would shift the positions of the
# remaining entries and delete the wrong images/labels.
for i in reversed(delete_ix):
    images.pop(i)
    labels.pop(i)

In [None]:
# NOTE(review): this overwrites the `delete_ix` list computed by the previous cell with a
# hard-coded index. No later cell visible in this notebook reads `delete_ix`, so this looks
# like a leftover from a manual experiment - confirm and remove.
delete_ix = [1505]
#delete_ix = [13222]

In [None]:
# Total number of labelled images (paths and labels are appended/removed together,
# so the two counts are expected to match)
print(len(images),len(labels))

# Set up data loaders

In [None]:
import torch
from torch.utils.data import Dataset as BaseDataset
from PIL import Image

class FacebookDataset(BaseDataset):
    """Image classification dataset backed by a list of image paths.

    Args:
        images: sequence of image file paths.
        labels: sequence of class indices, aligned with `images`.
        class_names: list of class names (stored, not used internally).
        indices: positions in `images`/`labels` that belong to this dataset
            (e.g. the training or the validation subset).
        transform: optional callable applied to the PIL image in `__getitem__`.
    """

    def __init__(
            self,
            images,
            labels,
            class_names,
            indices,
            transform=None,
    ):
        self.images = [images[index] for index in indices]
        self.labels = [labels[index] for index in indices]
        self.class_names = class_names
        self.transform = transform

    def __getitem__(self, i):
        """Return `(x, y)`: the (optionally transformed) RGB image and its class index.

        Raises:
            OSError: if the image file cannot be read/decoded.
        """
        path = self.images[i]

        # read data
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None; report the offending
            # file instead of failing with an opaque OpenCV assertion.
            raise OSError('Could not read image: %s' % path)
        x = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        y = self.labels[i]

        # apply augmentations
        if self.transform:
            x = self.transform(x)

        return x,y

    def __len__(self):
        """Number of images in this subset."""
        return len(self.images)

In [None]:
import random
import torchvision as tv
from torch.utils.data.sampler import WeightedRandomSampler

# Use same seed for PyTorch and Python's `random` module as for Numpy.
# (`random` is imported above but was never seeded; some torchvision versions draw
# their augmentation parameters from it, so seed it as well to be reproducible.)
torch.manual_seed(seed)
random.seed(seed)

# Training data preprocessing and augmentation
def train_transform():
    """Build the training pipeline: random crop/rotation/flip/colour jitter,
    then tensor conversion and per-channel normalisation."""
    T = tv.transforms
    steps = [
        T.RandomResizedCrop(256, scale=(0.5, 1.0)),
        T.RandomRotation(degrees=15),
        T.CenterCrop(224),
        T.RandomHorizontalFlip(),
        T.ColorJitter(0.1, 0.1, 0.1, 0.05),
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return T.Compose(steps)

# Test data preprocessing (no augmentation)
def valid_transform():
    """Build the deterministic evaluation pipeline: resize, centre crop,
    tensor conversion and per-channel normalisation."""
    T = tv.transforms
    steps = [
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return T.Compose(steps)

# Needed to handle class imbalance
def class_imbalance_sampler(labels):
    """Create a sampler that draws every class with equal total probability.

    Each sample is weighted by `1 / (number of samples of its class)`. The sampler
    draws `len(labels)` samples per epoch (WeightedRandomSampler samples with
    replacement by default).

    Args:
        labels: sequence of integer class labels, one per sample.

    Returns:
        torch.utils.data.sampler.WeightedRandomSampler
    """
    # `inverse` maps every sample to the position of its class in the sorted list of
    # unique labels, so the lookup also works when a class is absent from `labels`
    # (e.g. a split that only contains label 1) or labels are not 0..K-1.
    _, inverse, class_sample_count = np.unique(
        np.asarray(labels), return_inverse=True, return_counts=True)
    weight = 1. / class_sample_count
    samples_weight = torch.from_numpy(weight[inverse.ravel()]).double()
    return WeightedRandomSampler(samples_weight, len(samples_weight))

def get_loaders(images, labels, class_names, val_fraction=0.2, bs=32, num_workers=4):
    """Split the images into train/validation subsets and build datasets and loaders.

    The split is a random permutation of the image indices drawn from Numpy's global
    random state; the last `int(val_fraction * len(images))` indices form the
    validation set.

    NOTE(review): the split is per image, not per person, so images of the same
    individual can end up in both subsets - confirm that this is intended.

    Args:
        images: list of image paths.
        labels: list of class indices aligned with `images`.
        class_names: list of class names.
        val_fraction: fraction of the images used for validation.
        bs: batch size of all loaders.
        num_workers: number of worker processes per DataLoader.

    Returns:
        (trainloader, validloader, trainset, validset,
         trainloader_deterministic, validloader_deterministic)
        - trainloader / validloader use a class-balancing weighted sampler.
        - the *_deterministic loaders iterate their dataset in order without a
          sampler; trainloader_deterministic uses the validation transform
          (no augmentation).
    """
    dataset_size = len(images)
    val_size = int(val_fraction * dataset_size)

    indices = list(range(dataset_size))
    np.random.shuffle(indices)

    train_idx = indices[:dataset_size - val_size]
    valid_idx = indices[dataset_size - val_size:]

    def make_dataset(subset_idx, transform):
        return FacebookDataset(images, labels, class_names, subset_idx, transform=transform)

    def make_loader(dataset, sampler=None):
        return torch.utils.data.DataLoader(
            dataset, batch_size=bs, num_workers=num_workers, sampler=sampler)

    trainset = make_dataset(train_idx, train_transform())
    validset = make_dataset(valid_idx, valid_transform())
    # Same images as trainset but without augmentation
    trainset_no_aug = make_dataset(train_idx, valid_transform())

    # Loaders with sampler (male and female are balanced)
    trainloader = make_loader(trainset, class_imbalance_sampler(trainset.labels))
    validloader = make_loader(validset, class_imbalance_sampler(validset.labels))

    # Deterministic loaders to generate/report results faster (no sampler)
    trainloader_deterministic = make_loader(trainset_no_aug)
    validloader_deterministic = make_loader(validset)

    return trainloader, validloader, trainset, validset, trainloader_deterministic, validloader_deterministic

In [None]:
# Build the datasets and loaders from the configuration values
loaders = get_loaders(images, labels, class_names, val_fraction=val_fraction, bs=bs)
(trainloader, validloader, trainset, validset,
 trainloader_deterministic, validloader_deterministic) = loaders

print(class_names)
for split_name, split in (('train', trainset), ('valid', validset)):
    print('%s length:' % split_name, len(split))

## Display some example images

In [None]:
def tens2im(tensor):
    """Convert a normalised image tensor back into a displayable image.

    Undoes the per-channel normalisation used by the transforms
    (mean (0.485, 0.456, 0.406), std (0.229, 0.224, 0.225)) and clips to [0, 1].

    Args:
        tensor: image tensor of shape (3, H, W); singleton dimensions (such as a
            leading batch dimension of 1) are squeezed away first. Tensors on the
            GPU or attached to an autograd graph are accepted.

    Returns:
        numpy array of shape (H, W, 3) with values in [0, 1].
    """
    std = np.array((0.229, 0.224, 0.225))
    mean = np.array((0.485, 0.456, 0.406))
    # detach()/cpu() make the conversion work for CUDA tensors and tensors that
    # require grad; both are no-ops for plain CPU tensors.
    hwc = tensor.detach().cpu().squeeze().numpy().transpose(1,2,0)
    return np.clip(hwc * std + mean, 0, 1)

# Three 2x5 grids: the first ten training images, ten random validation images and
# ten draws of the same training image (to show the effect of random augmentation).
example_grids = [
    ("TRAINING DATA EXAMPLES", lambda i: trainset[i]),
    ("TEST DATA EXAMPLES", lambda i: validset[np.random.randint(len(validset))]),
    ("SAME TRAINING IMAGE 10 TIMES (WITH RANDOM AUGMENTATION)", lambda i: trainset[0]),
]

for heading, pick in example_grids:
    print(heading)
    fig,ax = plt.subplots(2,5,figsize=(20,8))
    for i in range(10):
        t,l = pick(i)
        cell = ax[i//5,i%5]
        cell.imshow(tens2im(t))
        cell.set_title(class_names[l])
        cell.axis('off')
    plt.show()

# Train model (or load trained model)

In [None]:
# Select the configured GPU. Guarded so that the notebook also runs on machines
# without CUDA or with fewer GPUs than the configured index.
if torch.cuda.is_available():
    if gpu < torch.cuda.device_count():
        torch.cuda.set_device(gpu)
    else:
        print('GPU %d is not available (%d GPU(s) found); using GPU 0 instead' % (gpu, torch.cuda.device_count()))
        torch.cuda.set_device(0)
print(torch.cuda.is_available())

In [None]:
if use_saved_model:
    # Load trained model.
    # map_location remaps the stored tensors onto the device available here; without
    # it, a checkpoint saved on another GPU index (or loaded on a CPU-only machine)
    # fails to load.
    # NOTE(review): the checkpoint is a fully pickled model object; unpickling runs
    # arbitrary code, so only load checkpoints from a trusted source.
    model = torch.load(save_path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
else:
    # Initialize model (ImageNet-pretrained backbone, new untrained output layer
    # with one output per class)
    if use_resnet:
        model = tv.models.resnet18(pretrained=True)
        model.fc = torch.nn.Linear(512,len(class_names))
    else: # use VGG
        model = tv.models.vgg16_bn(pretrained=True)
        model.classifier[6] = torch.nn.Linear(4096,len(class_names))

In [None]:
def get_device():
    """Return 'cuda' when a CUDA device is available, otherwise 'cpu'."""
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'

In [None]:
from IPython.display import clear_output, display
import time

# Classification loss and optimiser over all model parameters
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)

def train(model, loss_fn, optimizer, trainloader, validloader, epochs, save_better_models, save_path, save_fig=None):
    """Train `model` and evaluate it on `validloader` after every epoch.

    After each epoch the loss/accuracy curves are plotted and the current status is
    printed. Progress is printed every 100 batches.

    Args:
        model: network to train; moved to CUDA when available, otherwise CPU.
        loss_fn: callable `loss_fn(prediction, target)` returning a scalar loss.
        optimizer: optimiser over the parameters of `model`.
        trainloader: iterable of `(inputs, target)` training batches.
        validloader: iterable of `(inputs, target)` validation batches.
        epochs: number of epochs to train.
        save_better_models: if true, save the whole model to `save_path` after the
            first epoch and whenever an epoch is at least as good as every previous
            epoch in BOTH validation accuracy (highest) and validation loss (lowest).
        save_path: checkpoint file written by `torch.save`.
        save_fig: optional path; the curve figure is saved there after every epoch.

    Returns:
        The model after the final epoch (not necessarily the saved checkpoint).
    """
    train_losses = []
    valid_losses = []
    valid_acc = []

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    for epoch in range(epochs):
        # ---- training ----
        model.train()
        balance = []  # per-batch sum of the targets (monitors the class balance)
        losses = []
        start_time = time.time()
        for i, (inputs, target) in enumerate(trainloader):
            # Clear gradients first so that gradients left over from earlier work
            # (e.g. a Grad-CAM backward pass) never leak into the first update.
            optimizer.zero_grad()
            pred = model(inputs.to(device))
            loss = loss_fn(pred, target.to(device))
            loss.backward()
            optimizer.step()

            losses.append(loss.item())
            balance.append(target.sum().item())

            if i % 100 == 0:
                duration = time.time() - start_time
                print('train',i,len(trainloader),np.mean(losses),np.mean(balance),duration)
                start_time = time.time()

        train_losses.append(np.mean(losses))

        # ---- validation ----
        model.eval()
        losses = []
        accs = []
        balance = []
        with torch.no_grad():
            for i, (inputs, target) in enumerate(validloader):
                balance.append(target.sum().item())
                pred = model(inputs.to(device))
                target = target.to(device)
                losses.append(loss_fn(pred, target).item())
                acc = (pred.max(dim=1)[1] == target).sum().float().item() / target.size(0)
                accs.append(acc)
                if i % 100 == 0:
                    print('validation',i,len(validloader),np.mean(losses),np.mean(balance))

        epoch_acc = np.mean(accs)
        epoch_loss = np.mean(losses)

        # "Better" = accuracy not below the best so far AND loss not above the best so
        # far. (The loss test used to compare against the *worst* previous loss in the
        # wrong direction, so improved models were almost never saved.)
        is_best = len(valid_acc) == 0 or (
            epoch_acc >= np.max(valid_acc) and epoch_loss <= np.min(valid_losses))
        if save_better_models and is_best:
            torch.save(model, save_path)

        valid_acc.append(epoch_acc)
        valid_losses.append(epoch_loss)

        # ---- report ----
        fig,ax = plt.subplots(1,2,figsize=(16,4))

        ax[0].plot(train_losses, label='train loss')
        ax[0].plot(valid_losses, label='validation loss')
        ax[1].plot(valid_acc, 'orange',label='validation accuracy')

        ax[0].legend()
        ax[1].legend()

        ax[0].set_xlabel('epoch')
        ax[0].set_ylabel('loss')

        ax[1].set_xlabel('epoch')
        ax[1].set_ylabel('accuracy')

        # Save before showing: some backends release the figure in plt.show()
        if save_fig is not None:
            fig.savefig(save_fig)

        plt.show()
        plt.close(fig)  # do not accumulate one open figure per epoch

        print('CURRENT STATUS:')
        print('train loss:', train_losses[-1])
        print('validation loss:', valid_losses[-1])
        print('accuracy', valid_acc[-1])

        #clear_output(wait=True)

    return model

# Retrain only when the saved checkpoint is not used
if not use_saved_model:
    model = train(
        model,
        loss_fn,
        optimizer,
        trainloader,
        validloader,
        epochs=EPOCHS,
        save_better_models=save_better_models,
        save_path=save_path,
        save_fig='losscurves.png',
    )
    _ = model.eval()

Uncomment and run to save the latest model:

In [None]:
#torch.save(model, save_path)

# Test model

## Test on 10 random images

In [None]:
def single_predict(data, model, get_all=False):
    """Classify a single image tensor.

    The caller is responsible for putting `model` into eval mode.

    Args:
        data: image tensor without batch dimension (a batch dimension of 1 is added).
        model: network returning one row of class logits per input.
        get_all: if True, return the softmax probabilities of all classes.

    Returns:
        If `get_all` is True: 1-D CPU tensor with one probability per class.
        Otherwise: tuple `(probability, class_index)` of the most likely class,
        both as Python numbers.
    """
    # Run on whatever device the model lives on (also correct when the model sits on
    # a GPU other than the current CUDA device).
    device = next(model.parameters()).device

    # Inference only: do not record an autograd graph for every prediction.
    with torch.no_grad():
        prediction = model(data.unsqueeze(0).to(device)).cpu()
    softmax = torch.softmax(prediction, 1)

    if get_all:
        return softmax.squeeze()

    max_pred, max_ind = softmax.max(dim=1)
    return max_pred.item(), max_ind.item()

# Show ten random validation images with true label, predicted label and confidence
fig, ax = plt.subplots(2,5,figsize=(20,10))
print("EKSEMPLER PÅ GÆT:")

for i in range(10):
    row, col = divmod(i, 5)
    d,l = validset[np.random.randint(len(validset))]
    p_p,p_i = single_predict(d,model)
    img = tens2im(d)
    caption = 'is %s\npred: %s with %.2f' % (class_names[l], class_names[p_i], p_p)
    ax[row,col].imshow(img)
    ax[row,col].set_title(caption)
    ax[row,col].axis('off')

## Calculate confusion matrix

In [None]:
# confusion matrix
import seaborn as sn
import pandas as pd

def get_confusion_from_loader(model, loader, classes, fraction=True):
    """Accumulate a confusion matrix over all samples yielded by `loader`.

    Rows are true classes, columns are predicted classes. Every sample is classified
    individually with `single_predict`; the batch index is printed every 100 batches.

    Args:
        model: network passed on to `single_predict`.
        loader: iterable of `(images, labels)` batches.
        classes: list of class names (only its length is used).
        fraction: if True, divide every row by the number of samples of that class
            (rows without samples are divided by 1).

    Returns:
        `(matrix, count)`: the confusion matrix and the per-class sample counts.
    """
    n_classes = len(classes)
    matrix = np.zeros((n_classes, n_classes))
    count = np.zeros((n_classes,))

    for batch_no, (images, labels) in enumerate(loader):
        for sample, label in zip(images, labels):
            _, pred_label = single_predict(sample.squeeze(), model)
            matrix[label, pred_label] += 1
            count[label] += 1

        if batch_no % 100 == 0:
            print(batch_no)

    if fraction:
        matrix /= np.maximum(count[...,np.newaxis], 1.)

    return matrix,count

# NOTE(review): `validloader` draws samples with a class-balancing random sampler, so
# this matrix is computed on a resampled validation set - confirm that
# `validloader_deterministic` is not what was intended here.
matrix,count = get_confusion_from_loader(model, validloader, class_names, fraction=False)

# Heatmap of the raw counts
fig = plt.figure(figsize=(10,7))
frame = pd.DataFrame(matrix, index=class_names, columns=class_names)
sn.heatmap(frame, annot=True)
fig.savefig('confusion_count.png')

# Heatmap of the row-normalised fractions (per true class)
matrix /= np.maximum(count[...,np.newaxis], 1.)
fig = plt.figure(figsize=(10,7))
frame = pd.DataFrame(matrix, index=class_names, columns=class_names)
sn.heatmap(frame, annot=True)
fig.savefig('confusion_frac.png')

## True positives and false negatives

In [None]:
# If True, get_scores() stops early once more than max_num_images/bs batches were scored
limit_number_of_images = False
# Approximate cap on the number of scored images (only used when the flag above is True)
max_num_images = 10000

def get_scores(model, loader):
    """Score every image yielded by `loader` with `single_predict`.

    Reads the notebook-level settings `limit_number_of_images`, `max_num_images`
    and `bs` to optionally stop early. The batch index is printed every 100 batches.

    Args:
        model: network passed on to `single_predict`.
        loader: iterable of `(images, labels)` batches.

    Returns:
        Tuple `(scores, IDs, true_labels, is_correct)` with one entry per image:
        - scores: float array, softmax score of the TRUE class.
        - IDs: list of `(batch_index, index_within_batch)` tuples.
        - true_labels: integer array of true class indices.
        - is_correct: boolean array, True where the predicted class equals the label.
    """
    scores = []       # CNN score of correct class
    IDs = []          # (batch index, position in batch) of each image
    true_labels = []
    is_correct = []   # Is predicted label correct or not?

    for i, (images,labels) in enumerate(loader):
        for k in range(len(labels)):
            data = images[k,:,:,:].squeeze()
            # Store plain Python numbers rather than 0-d tensors so that the arrays
            # built below have well-defined numeric/boolean dtypes.
            label = int(labels[k])
            class_scores = single_predict(data, model, get_all=True)
            max_pred, max_ind = class_scores.max(dim=0) # max_ind = predicted class index

            scores.append(class_scores[label].item()) # score of correct class
            IDs.append((i,k))
            true_labels.append(label)
            is_correct.append(int(max_ind) == label)

        if i%100==0: print(i)
        if limit_number_of_images and i > int(max_num_images/bs): break

    scores = np.asarray(scores)
    true_labels = np.asarray(true_labels)
    is_correct = np.asarray(is_correct, dtype=bool)

    return scores,IDs,true_labels,is_correct

In [None]:
def _select_top_k(scores, true_labels, is_correct, male, tp, top_k):
    """Return the positions (into `scores`) of the requested images, best rank first.

    Keeps the images whose true label is male (0) / female (1) and whose prediction
    is correct (`tp` true) / incorrect (`tp` false). Correct predictions are ranked by
    descending score of the true class, incorrect ones by ascending score. At most
    `top_k` positions are returned - fewer when fewer images match.
    """
    wanted_label = 0 if male else 1
    sel_ix_gender = np.where(true_labels == wanted_label)[0]
    sel_ix_tpfn = np.where(is_correct == bool(tp))[0]
    sel = np.intersect1d(sel_ix_gender, sel_ix_tpfn)  # selected image indices

    order = np.argsort(scores[sel])
    if tp:
        order = order[::-1]
    return sel[order[:top_k]]


# Slow version : iterates over the whole loader once more.
# NOTE(review): the (batch, position) IDs only identify the same images if the loader
# yields them in the same order as during get_scores(); loaders with a random sampler
# or random augmentation presumably do not guarantee that - verify before use.
def get_top_k(scores,
              IDs,
              true_labels,
              is_correct,
              male,        # is male? True/False
              tp,          # True positives (True) or False Negatives (False)
              top_k,
              loader):
    """Collect the top-k images of one gender / outcome by re-iterating `loader`.

    Args:
        scores, IDs, true_labels, is_correct: outputs of `get_scores` for `loader`.
        male: select true label male (True) or female (False).
        tp: select correct predictions (True) or incorrect ones (False).
        top_k: maximum number of images to return.
        loader: iterable of `(images, labels)` batches.

    Returns:
        List of dicts with keys 'data', 'label' and 'score', sorted by rank. The list
        holds at most `top_k` entries and is shorter when fewer images match (it no
        longer contains empty placeholder dicts).
    """
    sel_top_k = _select_top_k(scores, true_labels, is_correct, male, tp, top_k)

    # (batch, position) -> (rank, position in scores) for constant-time lookup
    wanted = {IDs[index]: (rank, index) for rank, index in enumerate(sel_top_k)}
    output_data = [None] * len(sel_top_k)
    remaining = len(wanted)

    for i, (images,labels) in enumerate(loader):
        if remaining == 0:
            break  # everything found; no need to read the rest of the loader
        for k in range(len(labels)):
            hit = wanted.get((i,k))
            if hit is None:
                continue
            rank, index = hit
            output_data[rank] = {
                'data': images[k,:,:,:].squeeze(),
                'label': labels[k],
                'score': scores[index]
            }
            remaining -= 1

    return [entry for entry in output_data if entry is not None]

# Fast version : only works with validloader_deterministic and trainloader_deterministic
def get_top_k_fast(scores,
                   IDs,
                   true_labels,
                   is_correct,
                   male,        # is male? True/False
                   tp,          # True positives (True) or False Negatives (False)
                   top_k,
                   loader):
    """Collect the top-k images of one gender / outcome by direct dataset access.

    Only valid for loaders that iterate their dataset in order, because batch `i`,
    position `k` is mapped to dataset item `i * batch_size + k`.

    Args:
        scores, IDs, true_labels, is_correct: outputs of `get_scores` for `loader`.
        male: select true label male (True) or female (False).
        tp: select correct predictions (True) or incorrect ones (False).
        top_k: maximum number of images to return.
        loader: DataLoader; its `dataset` (and `batch_size`) attributes are used.

    Returns:
        List of dicts with keys 'data', 'label' and 'score', sorted by rank. The list
        holds at most `top_k` entries and is shorter when fewer images match (it no
        longer contains empty placeholder dicts).
    """
    sel_top_k = _select_top_k(scores, true_labels, is_correct, male, tp, top_k)

    # Use the loader's own batch size; fall back to the notebook-level `bs` only if
    # the loader does not expose one.
    batch_size = getattr(loader, 'batch_size', None) or bs

    output_data = []
    for index in sel_top_k:
        i, k = IDs[index]
        data, label = loader.dataset[i*batch_size+k] # Note: This is where the deterministic part is exploited
        output_data.append({
            'data': data,
            'label': label,
            'score': scores[index]
        })
    return output_data

def plot_top_k(input_data,title,savefig=None):
    """Show the given images in a grid with 10 columns, titled with their scores.

    Args:
        input_data: list of dicts with keys 'data' (normalised image tensor) and
            'score'; empty dicts are ignored.
        title: text printed above the figure.
        savefig: optional path; the figure is written there.
    """
    print(title)

    entries = [entry for entry in input_data if entry]
    if not entries:
        print('(no images to display)')
        return

    ncol = 10
    nrow = -(-len(entries) // ncol)  # ceiling division: last row may be partly filled
    # squeeze=False keeps `ax` two-dimensional even for a single row
    fig,ax = plt.subplots(nrow, ncol, figsize=(2*ncol,2*nrow), squeeze=False)

    # Hide the frames of all cells, including the unused ones in the last row
    for axis in ax.ravel():
        axis.axis('off')

    for i, entry in enumerate(entries):
        row, col = divmod(i, ncol)
        ax[row,col].imshow(tens2im(entry['data']))
        ax[row,col].set_title('score: %.4f' % (entry['score']))

    # Save before showing: some backends release the figure in plt.show()
    if savefig:
        fig.savefig(savefig)
    plt.show()
    
def save_top_k(input_data,directory):
    """Write the images in `input_data` to `directory` as img_<rank>.jpg.

    Args:
        input_data: list of dicts with key 'data' (normalised image tensor); empty
            dicts are skipped (their rank number is left out).
        directory: output directory; created if it does not exist.
    """
    os.makedirs(directory, exist_ok=True)
    for i, entry in enumerate(input_data):
        if not entry:
            continue
        img = np.uint8(tens2im(entry['data'])*255)
        # OpenCV expects BGR channel order when writing
        img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(directory,'img_'+str(i)+'.jpg'),img)

Get scores for test data

In [None]:
# Score every image of the in-order validation loader once; the results are reused by
# all top-k selections on test data below.
scores,IDs,true_labels,is_correct = get_scores(model, validloader_deterministic)

Get top-k results on test data, then display and save

In [None]:
# (male, tp, figure title, output name) for the four gender / outcome combinations
top_k_cases_test = [
    (True,  True,  'True: Male / Predicted: Male',     'male_true_positive_test'),
    (True,  False, 'True: Male / Predicted: Female',   'male_false_negative_test'),
    (False, True,  'True: Female / Predicted: Female', 'female_true_positive_test'),
    (False, False, 'True: Female / Predicted: Male',   'female_false_negative_test'),
]

for male, tp, title, name in top_k_cases_test:
    input_data = get_top_k_fast(scores,IDs,true_labels,is_correct,male=male,tp=tp,top_k=200,loader=validloader_deterministic)
    plot_top_k(input_data,title=title,savefig=name+'.png')

    # Portable replacement for `mkdir -p <name>` and `rm -f <name>/*.jpg`:
    # start from a directory without images left over from an earlier run.
    os.makedirs(name, exist_ok=True)
    for stale in os.listdir(name):
        stale_path = os.path.join(name, stale)
        if stale.endswith('.jpg') and not stale.startswith('.') and os.path.isfile(stale_path):
            os.remove(stale_path)

    save_top_k(input_data,name)

# GradCAM (= explainable AI)

https://github.com/kazuto1011/grad-cam-pytorch/blob/master/main.py

In [None]:
#!/usr/bin/env python
# coding: utf-8
#
# Author:   Kazuto Nakashima
# URL:      http://kazuto1011.github.io
# Created:  2017-05-26

from collections import Sequence

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm

class _BaseWrapper(object):
    def __init__(self, model):
        super(_BaseWrapper, self).__init__()
        self.device = next(model.parameters()).device
        self.model = model
        self.handlers = []  # a set of hook function handlers

    def _encode_one_hot(self, ids):
        one_hot = torch.zeros_like(self.logits).to(self.device)
        one_hot.scatter_(1, ids, 1.0)
        return one_hot

    def forward(self, image):
        self.image_shape = image.shape[2:]
        self.logits = self.model(image)
        self.probs = F.softmax(self.logits, dim=1)
        return self.probs.sort(dim=1, descending=True)  # ordered results

    def backward(self, ids):
        """
        Class-specific backpropagation
        """
        one_hot = self._encode_one_hot(ids)
        self.model.zero_grad()
        self.logits.backward(gradient=one_hot, retain_graph=True)

    def generate(self):
        raise NotImplementedError

    def remove_hook(self):
        """
        Remove all the forward/backward hook functions
        """
        for handle in self.handlers:
            handle.remove()

class BackPropagation(_BaseWrapper):
    """Wrapper that exposes the gradient of the class score w.r.t. the input image."""

    def forward(self, image):
        """Mark `image` as requiring gradients (in place), then run the base forward."""
        image.requires_grad_()
        self.image = image
        return super().forward(image)

    def generate(self):
        """Return a copy of the accumulated input gradient and reset it to zero."""
        grad = self.image.grad
        gradient = grad.clone()
        grad.zero_()
        return gradient

class GradCAM(_BaseWrapper):
    """
    "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
    https://arxiv.org/pdf/1610.02391.pdf
    Look at Figure 2 on page 4

    Forward hooks store the output feature maps and backward hooks store the
    gradients w.r.t. those outputs for the hooked layers. Call `forward`, then
    `backward`, then `generate(target_layer)`; call `remove_hook` when done.

    Args:
        model: network to explain.
        candidate_layers: optional list of module names (as reported by
            `model.named_modules()`) to hook. If None, every module is hooked.
    """

    def __init__(self, model, candidate_layers=None):
        super(GradCAM, self).__init__(model)
        self.fmap_pool = {}  # module name -> feature maps of the last forward pass
        self.grad_pool = {}  # module name -> output gradients of the last backward pass
        self.candidate_layers = candidate_layers  # list

        def save_fmaps(key):
            def forward_hook(module, input, output):
                self.fmap_pool[key] = output.detach()

            return forward_hook

        def save_grads(key):
            def backward_hook(module, grad_in, grad_out):
                self.grad_pool[key] = grad_out[0].detach()

            return backward_hook

        # If any candidates are not specified, the hook is registered to all the layers.
        for name, module in self.model.named_modules():
            if self.candidate_layers is None or name in self.candidate_layers:
                self.handlers.append(module.register_forward_hook(save_fmaps(name)))
                self.handlers.append(module.register_backward_hook(save_grads(name)))

    def _find(self, pool, target_layer):
        """Return the entry stored for `target_layer`; ValueError if there is none."""
        if target_layer in pool.keys():
            return pool[target_layer]
        else:
            raise ValueError("Invalid layer name: {}".format(target_layer))

    def generate(self, target_layer):
        """Compute the Grad-CAM map of `target_layer` for the last forward/backward pass.

        Returns:
            Tensor of shape (B, 1, H, W), where (H, W) is the spatial size of the
            input image, min-max normalised to [0, 1] per sample.

        Raises:
            ValueError: if no feature maps or gradients are stored for `target_layer`.
        """
        fmaps = self._find(self.fmap_pool, target_layer)
        grads = self._find(self.grad_pool, target_layer)
        # Channel weights = spatially averaged gradients
        weights = F.adaptive_avg_pool2d(grads, 1)

        gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
        gcam = F.relu(gcam)
        gcam = F.interpolate(
            gcam, self.image_shape, mode="bilinear", align_corners=False
        )

        B, C, H, W = gcam.shape
        gcam = gcam.view(B, -1)
        gcam -= gcam.min(dim=1, keepdim=True)[0]
        # A map that is zero everywhere (no positive evidence after the ReLU) would be
        # divided by zero and turn into NaN; clamp the denominator so it stays zero.
        gcam /= gcam.max(dim=1, keepdim=True)[0].clamp(min=1e-12)
        gcam = gcam.view(B, C, H, W)

        return gcam

In [None]:

# Name of the module whose feature maps are visualised:
# 'layer4' for ResNet18, 'features' for VGG16
target_layer = 'layer4' if use_resnet else 'features'

import matplotlib.cm as cm
import os.path as osp

def save_gradcam(gcam,
                 raw_image,
                 target_layer,
                 ground_truth,
                 class_name,
                 paper_cmap=False):
    """Blend a Grad-CAM map with the image it explains.

    Args:
        gcam: 2-D tensor with the Grad-CAM map (values used as colormap input).
        raw_image: image array with values in [0, 1] (scaled to 0..255 here).
        target_layer, ground_truth, class_name: accepted for interface
            compatibility; not used by this function.
        paper_cmap: if True, use the map itself as per-pixel blending weight;
            otherwise average colormap and image.

    Returns:
        uint8 image array after an RGB->BGR channel swap.
        NOTE(review): the callers in this notebook display the result with
        matplotlib, which presumably interprets it as RGB - confirm that the
        channel swap of the photo is intended.
    """
    gcam = gcam.cpu().numpy()
    cmap = cm.jet_r(gcam)[..., :3] * 255.0
    raw_image = np.uint8(raw_image * 255.0)

    if paper_cmap:
        alpha = gcam[..., None]
        gcam = alpha * cmap + (1 - alpha) * raw_image
    else:
        # `np.float` was only an alias of the builtin `float` and no longer exists
        # in current NumPy releases.
        gcam = (cmap.astype(float) + raw_image.astype(float)) / 2
    gcam = cv2.cvtColor(np.uint8(gcam),cv2.COLOR_RGB2BGR)

    return gcam

def save_all(input_data,directory):
    """Plot every image next to its Grad-CAM heatmap for the predicted class.

    Uses the notebook-level `model`, `target_layer` and `class_names`. The figure has
    8 columns (4 image/heatmap pairs per row) and is saved as 'gcam_<directory>.png'.

    Args:
        input_data: list of dicts with keys 'data' (normalised image tensor) and
            'label' (true class index); empty dicts are ignored.
        directory: name used to build the output file name.
    """
    device = get_device()
    model.to(device)
    model.eval()

    entries = [entry for entry in input_data if entry]
    if not entries:
        print('No images to explain for ' + directory)
        return

    topk = 1  # only the most likely class is visualised

    ncol = 8
    nrow = -(-2*len(entries) // ncol)  # ceiling division: last row may be partly filled
    # squeeze=False keeps `ax` two-dimensional even for a single row
    fig,ax = plt.subplots(nrow, ncol, figsize=(2*ncol,2*nrow), squeeze=False)
    for axis in ax.ravel():
        axis.axis('off')

    # Hook only the layer that is visualised instead of every module of the network
    gcam = GradCAM(model=model, candidate_layers=[target_layer])
    try:
        for k, entry in enumerate(entries):
            row, col = divmod(2*k, ncol)

            data = entry['data']
            label = entry['label']

            raw_image = tens2im(data)
            image = torch.stack([data]).to(device)

            ax[row,col].imshow(raw_image)
            ax[row,col].set_title(str(k))

            # One forward pass provides both the class ranking and the feature maps
            probs, ids = gcam.forward(image)  # sorted

            for i in range(topk):

                # Grad-CAM
                gcam.backward(ids=ids[:, [i]])
                regions = gcam.generate(target_layer=target_layer)

                cam = save_gradcam(
                    gcam=regions[0, 0],
                    raw_image=raw_image,
                    target_layer=target_layer,
                    ground_truth=class_names[label],
                    class_name=class_names[ids[0, i]],
                    paper_cmap=False
                )
                ax[row,col+1].imshow(cam)
                ax[row,col+1].set_title(str(k)+' ('+class_names[ids[0, i]]+')')
    finally:
        # Always detach the hooks from the shared model, even after an error
        gcam.remove_hook()

    # Save before showing: some backends release the figure in plt.show()
    fig.savefig('gcam_'+directory+'.png')
    plt.show()

def save_all_old(input_data,directory):
    """Older variant of save_all: one figure per image, saved into `directory`.

    For every entry a 1x3 figure is created: the image, followed by the Grad-CAM
    heatmaps of the two highest-ranked classes (each heatmap is placed at column
    1 + class index). Every figure is saved as '<directory>/gcam_<k>.png'.
    Uses the notebook-level `model`, `target_layer` and `class_names`.

    Args:
        input_data: list of dicts with keys 'data' (normalised image tensor) and
            'label' (true class index).
        directory: existing directory that receives the figures.
    """
    device = get_device()
    model.to(device)
    model.eval()

    bp = BackPropagation(model=model)
    # No candidate layers given: hooks are registered on every module of the model
    gcam = GradCAM(model=model)
    topk = 2
    
    for k in range(len(input_data)):

        data = input_data[k]['data']
        label = input_data[k]['label']
        
        raw_image = tens2im(data)
        image = torch.stack([data]).to(device)

        fig,ax = plt.subplots(1,3,figsize=(8,5))
        ax[0].imshow(raw_image)
        ax[0].set_title('Image')
        ax[0].axis('off')
    
        # First forward pass: class ranking. Second pass: fills the Grad-CAM feature maps.
        probs, ids = bp.forward(image)  # sorted
        _ = gcam.forward(image)

        p = ['' for i in range(topk)]  # not used below
        for i in range(topk):

            # Grad-CAM
            gcam.backward(ids=ids[:, [i]])
            regions = gcam.generate(target_layer=target_layer)

            cam = save_gradcam(
                gcam=regions[0, 0],
                raw_image=raw_image,
                target_layer=target_layer,
                ground_truth=class_names[label],
                class_name=class_names[ids[0, i]],
                paper_cmap=False
            )
            # Column is chosen by class index, not by rank
            ax[1+ids[0, i]].imshow(cam)
            ax[1+ids[0, i]].set_title(class_names[ids[0, i]])
            ax[1+ids[0, i]].axis('off')
        
        plt.show()
        
        fig.savefig(os.path.join(directory,'gcam_'+str(k)+'.png'))
    
    # Hooks are only removed when the loop finishes without an exception
    gcam.remove_hook()

In [None]:
# Grad-CAM figures for the four gender / outcome combinations on test data:
# (male, tp, output name)
gradcam_cases_test = (
    (True,  True,  'male_true_positive_test'),
    (True,  False, 'male_false_negative_test'),
    (False, True,  'female_true_positive_test'),
    (False, False, 'female_false_negative_test'),
)

for male, tp, name in gradcam_cases_test:
    input_data = get_top_k_fast(scores,IDs,true_labels,is_correct,male=male,tp=tp,top_k=200,loader=validloader_deterministic)
    save_all(input_data,name)

# Generate similar results on training images

Get scores for training data

In [None]:
# Score every image of the in-order, non-augmented training loader once; the results are
# reused by all top-k selections on training data below.
scores_tr,IDs_tr,true_labels_tr,is_correct_tr = get_scores(model, trainloader_deterministic)

Get top-k results on training data, then display and save

In [None]:
# (male, tp, figure title, output directory) for the four gender / outcome combinations
top_k_cases_train = [
    (True,  True,  'True: Male / Predicted: Male',     'male_true_positive_train'),
    (True,  False, 'True: Male / Predicted: Female',   'male_false_negative_train'),
    (False, True,  'True: Female / Predicted: Female', 'female_true_positive_train'),
    (False, False, 'True: Female / Predicted: Male',   'female_false_negative_train'),
]

for male, tp, title, name in top_k_cases_train:
    input_data = get_top_k_fast(scores_tr,IDs_tr,true_labels_tr,is_correct_tr,male=male,tp=tp,top_k=200,loader=trainloader_deterministic)
    plot_top_k(input_data,title=title)
    # Portable replacement for `mkdir -p <name>`
    os.makedirs(name, exist_ok=True)
    save_top_k(input_data,name)

In [None]:
# Grad-CAM figures for the four gender / outcome combinations on training data:
# (male, tp, output name)
gradcam_cases_train = (
    (True,  True,  'male_true_positive_train'),
    (True,  False, 'male_false_negative_train'),
    (False, True,  'female_true_positive_train'),
    (False, False, 'female_false_negative_train'),
)

for male, tp, name in gradcam_cases_train:
    input_data = get_top_k_fast(scores_tr,IDs_tr,true_labels_tr,is_correct_tr,male=male,tp=tp,top_k=100,loader=trainloader_deterministic)
    save_all(input_data,name)