# Load datasets

Download the datasets from the storage bucket

In [1]:
# !gsutil -m cp -r gs://kex-data/casia-144000 data/ > logs.txt 2>&1
# !gsutil -m cp -r gs://kex-data/digiface_subjects_0-1999_72_imgs/ data/ > logs2.txt 2>&1
# !rm data/casia-144000/.DS_Store

In [4]:
import os
# Print how many entries each downloaded dataset directory contains.
for dataset_dir in ('data/casia-144000', 'data/digiface_subjects_0-1999_72_imgs'):
    print(len(os.listdir(dataset_dir)))

2094
2000


In [5]:
import torch
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader

# Training-time preprocessing: random horizontal flip for augmentation, then
# conversion to a tensor and per-channel normalisation with mean 0.5 / std 0.5.
# (Resize, RandomCrop and RandomErasing were tried earlier and are disabled.)
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Deterministic preprocessing for measuring accuracy on the training images:
# same tensor conversion and normalisation, but no random augmentation, so
# repeated evaluations of the same weights give the same number.
train_eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Dataset roots (earlier experiments used 'data/casia-webface' and
# 'data/webface-10').
train_data_path = 'data/casia-144000'
train_data_path_test = 'data/casia-144000'

batch_size = 256

# Loader used for optimisation: augmented and reshuffled every epoch.
train_data = datasets.ImageFolder(train_data_path, transform=train_transform)
train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)

# Loader used only for evaluation: no augmentation and a fixed order.
train_data_test = datasets.ImageFolder(
    train_data_path_test, transform=train_eval_transform)
train_loader_test = DataLoader(
    train_data_test, shuffle=False, batch_size=batch_size)

In [6]:
# Make sure dataset is loaded correctly

In [7]:
# Pull one batch from the training loader and show the shape of its first
# image together with that image's class index.
iterator = iter(train_loader)
first_batch = next(iterator)
image, label = first_batch

for caption, value in (("image", image[0].shape), ("label", label[0])):
    print(caption, value)

image torch.Size([3, 112, 112])
label tensor(524)


In [8]:
# Init the model

In [9]:
from torchvision.models import resnet18, ResNet18_Weights
from torch import nn, optim
import os
# ResNet-18 with randomly initialised weights. The ImageNet-pretrained variant
# would be resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).
model = resnet18()

# Take the class count from the dataset that actually produces the labels.
# Counting directory entries with os.listdir would also count stray files
# (for example a .DS_Store), making the classifier head the wrong size.
num_classes = len(train_data.classes)
print(num_classes)

# Prefer CUDA, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(device)

# Replace the classification head; read its input width from the existing
# layer instead of hard-coding 512.
model.fc = nn.Linear(model.fc.in_features, num_classes)
model.to(device)

2094
cuda


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [10]:
import matplotlib.pyplot as plt
import numpy as np

def imshow(img, unnormalize=False):
    """Display an image tensor with matplotlib.

    Args:
        img: Image tensor. Its axes are permuted with (1, 2, 0) before
            plotting, i.e. the first axis is moved last (assumes a
            channel-first (C, H, W) layout -- TODO confirm for other inputs).
        unnormalize: When True, apply ``img / 2 + 0.5`` first, which undoes a
            Normalize with mean 0.5 and std 0.5. Defaults to False, so the
            tensor is shown as given.
    """
    if unnormalize:
        img = img / 2 + 0.5
    # detach() and cpu() make this work for tensors that track gradients or
    # live on an accelerator; a plain .numpy() raises for both.
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

In [11]:
from PIL import Image

# img = Image.open('./data/webface-10/000000/00000001.jpg')
# timg = transforms.ToTensor()(img)
# imshow(timg)
# print(device)
# timg = timg.unsqueeze(0).to(device)
# # model.to(device)
# print(model(timg))

Load LFW

In [12]:
batch_size_test = 256

# Evaluation preprocessing: tensor conversion and the same mean 0.5 / std 0.5
# normalisation used for training, without any augmentation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

lfw_test = datasets.LFWPairs(root='data', split='10fold', transform=test_transform, download=True)

# Keep the dataset order. The accuracy evaluation later in this notebook
# splits the pairs into ten contiguous index ranges, which only line up with
# the dataset's own fold order when the loader does not shuffle. A fixed order
# also makes repeated evaluations reproducible.
test_loader = DataLoader(lfw_test, batch_size=batch_size_test, shuffle=False)

Files already downloaded and verified


In [13]:
print(len(lfw_test))
print(len(test_loader))

6000
24


Feature extractor:

In [14]:
import torch.nn.functional as F

def model_to_extractor(model):
    """Build a feature extractor from ``model`` by dropping its last child.

    The model is switched to evaluation mode as a side effect. The returned
    ``nn.Sequential`` wraps the same child module objects as ``model``, minus
    the final one (the class-prediction layer).

    Args:
        model: Module whose direct children are chained in order.

    Returns:
        ``nn.Sequential`` of all direct children except the last.
    """
    model.eval()
    layers = list(model.children())
    del layers[-1]  # discard the final class-prediction layer
    return nn.Sequential(*layers)

extractor = model_to_extractor(model)

def extract_features(verbose=False):
    """Run the module-level ``extractor`` over every batch of ``test_loader``.

    Args:
        verbose: When True, print a progress line every 10 batches. Defaults
            to False (no printing).

    Returns:
        Tuple ``(feat_pair1, feat_pair2, labels_list)`` of three lists with
        one entry per batch: extractor outputs for the first images,
        extractor outputs for the second images (both left on ``device``),
        and the label tensors as yielded by the loader.
    """
    labels_list = []
    feat_pair1 = []
    feat_pair2 = []
    extractor.to(device)
    extractor.eval()

    # no_grad already yields tensors without autograd history, so the outputs
    # are stored directly.
    with torch.no_grad():
        for n, (imgs1, imgs2, labels) in enumerate(test_loader):
            imgs1, imgs2 = imgs1.to(device), imgs2.to(device)
            feat_pair1.append(extractor(imgs1))
            feat_pair2.append(extractor(imgs2))
            labels_list.append(labels)
            if verbose and n % 10 == 0:
                print('done extracting for iteration', n)

    return feat_pair1, feat_pair2, labels_list


In [15]:
def tlist_to_squeezed(tensor_list):
    """Concatenate per-batch tensors and collapse each sample to one vector.

    The tensors are joined along dim 0. For inputs with two or more
    dimensions, everything after the first axis is flattened, so
    ``(N, C, 1, 1)`` becomes ``(N, C)``. The first axis is always kept, even
    when ``N == 1``; a bare ``squeeze()`` would drop it and break per-row
    reductions downstream.

    Args:
        tensor_list: Non-empty list of tensors that agree on all dimensions
            except the first.

    Returns:
        Tensor of shape ``(N, features)`` for inputs with >= 2 dimensions.
        Inputs with fewer dimensions are squeezed as before.
    """
    merged_tensor = torch.cat(tensor_list, dim=0)
    if merged_tensor.dim() < 2:
        return merged_tensor.squeeze()
    return merged_tensor.flatten(start_dim=1)

def tlist_distance(tlist1, tlist2):
    """Row-wise norm of the difference between two lists of feature batches.

    Each list is merged with ``tlist_to_squeezed``; the norm of the
    element-wise difference is then taken along dim 1.
    """
    merged_a = tlist_to_squeezed(tlist1)
    merged_b = tlist_to_squeezed(tlist2)
    difference = merged_a - merged_b
    return torch.norm(difference, dim=1)

Validate the model with the ROC AUC method

In [16]:
import matplotlib.pyplot as plt
import sklearn.metrics

def get_test_accuracy(should_print=False):
    """Compute the ROC-AUC of negated feature distances against the labels.

    Features come from ``extract_features()``. Distances are negated before
    scoring so that a smaller distance gives a larger score. The ROC curve is
    always passed to ``plot_roc`` with the path ``./figures/roc_auc.png``.

    Args:
        should_print: When True, print the ROC-AUC value.

    Returns:
        The ROC-AUC as returned by ``sklearn.metrics.roc_auc_score``.
    """
    pair1_feats, pair2_feats, label_batches = extract_features()

    # Eval metrics
    scores = -tlist_distance(pair1_feats, pair2_feats).cpu().numpy()
    gt = torch.cat(label_batches).numpy()

    roc_auc = sklearn.metrics.roc_auc_score(gt, scores)
    fpr, tpr, _ = sklearn.metrics.roc_curve(gt, scores)

    if should_print:
        print ('ROC-AUC: %.04f' % roc_auc)

    # Plot and save ROC curve
    plot_roc(fpr, tpr, roc_auc, "./figures/roc_auc.png")
    return roc_auc

def plot_roc(fpr, tpr, roc_auc, fig_path, show=False):
    """Plot a ROC curve on its own figure and save it to ``fig_path``.

    A new figure is created for every call and closed afterwards, so repeated
    calls do not stack curves and legend entries on a shared figure. The
    directory part of ``fig_path`` is created if it does not exist.

    Args:
        fpr: False-positive rates (x values).
        tpr: True-positive rates (y values).
        roc_auc: Area under the curve, shown in the legend.
        fig_path: Destination file for the saved figure.
        show: When True, also display the figure with ``plt.show()``.
    """
    fig, ax = plt.subplots()
    ax.set_title('ROC - lfw dev-test')
    ax.plot(fpr, tpr, lw=2, label='ROC (auc = %0.4f)' % roc_auc)
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.grid()
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.legend(loc='lower right')
    fig.tight_layout()

    # savefig does not create missing directories.
    fig_dir = os.path.dirname(fig_path)
    if fig_dir:
        os.makedirs(fig_dir, exist_ok=True)
    fig.savefig(fig_path, bbox_inches='tight')

    if show:
        plt.show()
    # Release the figure so figures do not accumulate across calls.
    plt.close(fig)

# get_test_accuracy()

Validate the model with the SynFace method

In [17]:
import time 

def extractDeepFeature(img, model, device):
    """Run ``model`` on ``img`` and return the output on the CPU.

    Outputs with two or more dimensions are flattened per sample to
    ``(N, features)``. The first axis is kept even for a single-sample batch;
    a bare ``squeeze()`` would drop it and break the per-row computations
    performed on the result.

    Args:
        img: Input tensor; moved to ``device`` before the forward pass.
        model: Callable applied to the input. It is not moved here, so it is
            expected to already be on ``device``.
        device: Device on which the forward pass runs.

    Returns:
        CPU tensor of shape ``(N, features)`` for outputs with >= 2
        dimensions. Outputs with fewer dimensions are squeezed as before.
    """
    img = img.to(device)
    fc = model(img).to('cpu')
    if fc.dim() >= 2:
        return fc.flatten(start_dim=1)
    return fc.squeeze()

def KFold(n=6000, n_folds=10):
    """Split the indices ``0 .. n-1`` into contiguous train/test folds.

    Fold ``i`` tests on ``[int(i*n/n_folds), int((i+1)*n/n_folds))`` and
    trains on every other index. Both lists are in ascending order; the train
    list is built by slicing, so its order does not depend on set iteration
    order.

    Args:
        n: Number of items to split.
        n_folds: Number of folds.

    Returns:
        List of ``[train_indices, test_indices]`` pairs, one per fold.
    """
    folds = []
    indices = list(range(n))
    for i in range(n_folds):
        lo = int(i * n / n_folds)
        hi = int((i + 1) * n / n_folds)
        # Everything outside [lo, hi) is training data for this fold.
        folds.append([indices[:lo] + indices[hi:], indices[lo:hi]])
    return folds


def eval_acc(threshold, diff):
    """Accuracy of thresholding the scores in ``diff`` at ``threshold``.

    A row is predicted as 1 when its first column is strictly greater than
    ``threshold`` and 0 otherwise; the prediction is compared with the second
    column truncated to an integer. The comparison is vectorised with NumPy,
    since this function is called once per candidate threshold per fold.

    Args:
        threshold: Decision threshold applied to column 0.
        diff: Array-like of rows ``(score, label)``.

    Returns:
        Fraction of rows whose prediction equals the label, as a float.

    Raises:
        ValueError: If ``diff`` contains no rows.
    """
    pairs = np.asarray(diff, dtype=float)
    if pairs.size == 0:
        raise ValueError('eval_acc requires at least one (score, label) pair')
    y_predict = (pairs[:, 0] > threshold).astype(int)
    y_true = pairs[:, 1].astype(int)
    accuracy = 1.0 * np.count_nonzero(y_true == y_predict) / len(y_true)
    return accuracy


def find_best_threshold(thresholds, predicts):
    """Return the threshold giving the highest ``eval_acc`` on ``predicts``.

    Candidates are tried in the given order. A candidate that ties the best
    accuracy so far replaces it, so the last of several equally good
    thresholds is returned. Returns 0 when ``thresholds`` is empty.
    """
    best_threshold = 0
    best_acc = 0
    for candidate in thresholds:
        candidate_acc = eval_acc(candidate, predicts)
        if candidate_acc >= best_acc:
            best_acc, best_threshold = candidate_acc, candidate
    return best_threshold

def cosine_similarity(f1, f2):
    """Row-wise cosine similarity of two 2-D tensors, as a NumPy array.

    Both inputs are converted with ``.numpy()``. A constant ``1e-5`` is added
    to the product of the row norms before dividing.
    """
    a = f1.numpy()
    b = f2.numpy()

    dot_products = (a * b).sum(axis=1)
    norm_products = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)

    return dot_products / (norm_products + 1e-5)

def compute_distance(img1, img2, model, flag, device):
    """Score a batch of image pairs and pair each score with its label.

    Despite the name, the score is the cosine *similarity* of the two feature
    batches returned by ``extractDeepFeature``.

    Args:
        img1: Batch of first images.
        img2: Batch of second images.
        model: Model passed to ``extractDeepFeature``.
        flag: Tensor of labels for the pairs; flattened to 1-D.
        device: Device passed to ``extractDeepFeature``.

    Returns:
        NumPy array of shape ``(N, 2)``: column 0 holds the similarity,
        column 1 the label.
    """
    f1 = extractDeepFeature(img1, model, device)
    f2 = extractDeepFeature(img2, model, device)

    # Guard against 1-D features (e.g. a batch of one whose batch axis was
    # squeezed away); the row-wise similarity needs 2-D input.
    if f1.dim() == 1:
        f1 = f1.unsqueeze(0)
    if f2.dim() == 1:
        f2 = f2.unsqueeze(0)

    distance = cosine_similarity(f1, f2)

    # reshape(-1) keeps a one-element batch 1-D; squeeze() would make it 0-D
    # and the stack below would fail.
    flag = flag.reshape(-1).numpy()
    return np.stack((distance, flag), axis=1)

def obtain_acc(predicts, num_class, start, should_print=False):
    """Mean 10-fold accuracy with the threshold chosen on each train split.

    For every fold from ``KFold(n=num_class, n_folds=10)`` the best threshold
    is searched on the train rows of ``predicts`` over
    ``np.arange(-1.0, 1.0, 0.005)`` and then scored on the test rows.

    Args:
        predicts: Array indexable by lists of row indices, with rows in the
            format ``eval_acc`` expects.
        num_class: Number of rows to split into folds (passed to ``KFold``
            as ``n``).
        start: ``time.time()`` timestamp used to report elapsed minutes.
        should_print: When True, print mean/std accuracy, mean threshold and
            elapsed time.

    Returns:
        Mean accuracy over the folds.
    """
    thresholds = np.arange(-1.0, 1.0, 0.005)
    fold_accuracies = []
    fold_thresholds = []

    for train_idx, test_idx in KFold(n=num_class, n_folds=10):
        chosen = find_best_threshold(thresholds, predicts[train_idx])
        fold_accuracies.append(eval_acc(chosen, predicts[test_idx]))
        fold_thresholds.append(chosen)

    time_used = (time.time() - start) / 60.0
    mean_accuracy = np.mean(fold_accuracies)
    if should_print:
        print('LFW_ACC={:.4f} std={:.4f} thd={:.4f} time_used={:.4f} mins'.format(mean_accuracy, np.std(fold_accuracies), np.mean(fold_thresholds), time_used))
    return mean_accuracy


def eval(model, num_classes, test_loader, device, should_print=False):
    """Score every pair from ``test_loader`` and compute 10-fold accuracy.

    Note: this function shadows Python's built-in ``eval`` in the notebook
    namespace.

    Args:
        model: Model to evaluate; switched to evaluation mode here.
        num_classes: Forwarded to ``obtain_acc`` as the number of rows to
            split into folds.
        test_loader: Loader yielding ``(img1, img2, flag)`` batches.
        device: Device used for the forward passes.
        should_print: When True, print the resulting accuracy.

    Returns:
        Tuple ``(accuracy, predicts)`` where ``predicts`` is the
        ``(num_pairs, 2)`` array filled from ``compute_distance``.
    """
    total_pairs = len(test_loader.dataset)
    predicts = np.zeros(shape=(total_pairs, 2))

    model.eval()
    start = time.time()

    # Fill `predicts` batch by batch; `cur` is the next free row.
    cur = 0
    with torch.no_grad():
        for img1, img2, flag in test_loader:
            batch_len = flag.shape[0]
            predicts[cur:cur + batch_len] = compute_distance(img1, img2, model, flag, device)
            cur += batch_len
    assert cur == predicts.shape[0]

    accuracy = np.mean(obtain_acc(predicts, num_classes, start))

    if should_print:
        print('LFW_ACC', accuracy)

    return accuracy, predicts

# Train and validate the model

In [None]:
import time
import datetime
import torch
import torch.nn as nn
import torch.optim as optim

# Optimiser hyperparameters.
learning_rate = 0.1
momentum = 0.9
weight_decay = 0.0
num_epochs = 40

# Logging options.
output_to_file = True
epoch_output_threshold = 20

# The scheduler multiplies the learning rate by lr_gamma at each milestone.
# NOTE(review): an earlier comment here described decay at epochs 24, 30 and
# 36 (40 epochs in total); the milestones below differ -- confirm which
# schedule is intended.
lr_milestones = [6, 12, 18, 36]
lr_gamma = 0.1

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=weight_decay,
)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=lr_milestones, gamma=lr_gamma)

# Each run logs to its own timestamped text file.
if output_to_file:
    run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    outputfile = open('output_' + run_stamp + '.txt', 'w')

def output(text):
    """Write one line to the run log, or print it when file logging is off.

    The value is converted with ``str()`` before the newline is appended, so
    non-string values (floats, tuples) are logged instead of raising a
    ``TypeError``.

    Args:
        text: Value to log.
    """
    if output_to_file:
        outputfile.write(str(text) + '\n')
    else:
        print(text)

def get_train_accuracy(model: nn.Module, test=True):
    """Top-1 classification accuracy of ``model`` over one pass of a loader.

    The model is put in evaluation mode and left that way on return.

    Args:
        model: Classifier returning one score per class for each input.
        test: When True, iterate ``train_loader_test``; otherwise iterate
            ``train_loader``.

    Returns:
        Fraction of samples whose highest-scoring class equals the label.
    """
    loader = train_loader_test if test else train_loader
    correct = 0
    total = 0
    # Evaluation mode only needs to be set once, not per batch.
    model.eval()
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            # Named `logits` so the module-level output() helper is not shadowed.
            logits = model(imgs)
            pred = logits.max(1)[1]
            correct += pred.eq(labels).sum().item()
            total += imgs.shape[0]
    return correct / total

# Main training loop. The try/finally closes the log file even when training
# is interrupted or raises.
try:
    for epoch in range(num_epochs):
        epoch_tic = time.perf_counter()
        # Evaluation at the end of an epoch leaves the model in eval mode, so
        # training mode is restored at the start of every epoch.
        model.train()
        for n, (imgs, labels) in enumerate(train_loader, start=1):
            tic = time.perf_counter()
            imgs = imgs.to(device)
            labels = labels.to(device)

            out = model(imgs)
            loss = loss_fn(out, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if n % 50 == 0:
                toc = time.perf_counter()
                # loss.item() logs the plain number rather than the tensor repr.
                output('epoch: {}, iter: {}, loss: {}, time: {}'.format(epoch, n, loss.item(), toc - tic))
                if output_to_file:
                    outputfile.flush()

        # Periodic evaluation every `epoch_output_threshold` epochs. The
        # original condition was `epoch % 0`, which raises ZeroDivisionError.
        if epoch > 0 and epoch % epoch_output_threshold == 0:
            output('train acc: {}'.format(get_train_accuracy(model, test=False)))
            output('roc auc: {}'.format(get_test_accuracy()))
            output('lfw acc: {}'.format(eval(model, len(lfw_test), test_loader, device)[0]))

        lr_scheduler.step()
        epoch_toc = time.perf_counter()
        output('epoch time: {}'.format(epoch_toc - epoch_tic))
        if output_to_file:
            outputfile.flush()

    # Final evaluation after the last epoch; every metric is logged.
    output('train acc: {}'.format(get_train_accuracy(model, test=False)))
    output('roc auc: {}'.format(get_test_accuracy()))
    output('lfw acc: {}'.format(eval(model, len(lfw_test), test_loader, device)[0]))
finally:
    if output_to_file:
        outputfile.close()