In [None]:
from __future__ import print_function
import argparse
import os
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from utils.triplet_image_loader import TripletImageLoader
from utils.openface import prepareOpenFace
from tripletnet import Tripletnet
from visdom import Visdom
import numpy as np
import dlib
from PIL import Image

In [None]:
# Set CUDA_VISIBLE_DEVICES to '1' for this process. This cell runs before any
# .cuda() call in the notebook.
# NOTE(review): the device index is hard-coded -- confirm it matches the machine.
os.environ.update(CUDA_VISIBLE_DEVICES='1')

In [None]:
# Preprocessing for every test image: resize, centre-crop to 96 pixels, then
# convert to a tensor.
test_transform = transforms.Compose([
    transforms.Resize(96),
    transforms.CenterCrop(96),
    transforms.ToTensor(),
])

# Triplet dataset built from the test CSV.
# NOTE(review): presumably yields (anchor, positive, negative) image triplets,
# 100 per individual -- confirm against utils.triplet_image_loader.
test_triplets = TripletImageLoader(
    'name_thumbPaths_test.csv',
    transform=test_transform,
    triplets_per_individual=100,
)

# Batches of 200 triplets, loaded by 12 worker processes.
test_loader = torch.utils.data.DataLoader(
    test_triplets,
    batch_size=200,
    num_workers=12,
)

In [None]:
class Net(nn.Module):
    """Embedding network: SqueezeNet 1.1 feature layers plus a two-layer head.

    The head maps a flattened 2048-value feature vector to a 128-dimensional
    embedding. Attribute names and the order of layers inside ``embedding``
    determine the ``state_dict`` keys, so they must stay as they are for saved
    checkpoints to load.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Only the `.features` part of torchvision's SqueezeNet 1.1 is kept.
        self.features = models.squeezenet1_1().features
        # NOTE(review): 2048 inputs presumably correspond to 512 channels x 2 x 2
        # after the adaptive pooling in forward() -- confirm.
        head_layers = [
            nn.Linear(2048, 512),
            nn.Dropout(),
            nn.ReLU(),
            nn.Linear(512, 128),
        ]
        self.embedding = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the embedding for a batch of images ``x``."""
        # Pool each feature map down to a 2x2 grid regardless of input size.
        pooled = F.adaptive_max_pool2d(self.features(x), 2)
        # Flatten everything except the batch dimension.
        flat = pooled.view(pooled.size(0), -1)
        return self.embedding(flat)

In [None]:
# Build the network and restore the weights stored under 'state_dict' in the
# best checkpoint.
net = Net()
# map_location keeps every stored tensor on the CPU while the checkpoint is
# deserialised, so a checkpoint written on a GPU machine still loads when CUDA
# is unavailable. The weights are moved to the GPU below when one is available.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
net.load_state_dict(
    torch.load(
        'runs/TripletNet/model_best.pth.tar',
        map_location=lambda storage, location: storage,
    )['state_dict']
)
if torch.cuda.is_available():
    net.cuda()

# NOTE(review): not referenced by the other cells visible in this notebook (the
# counting helpers take their own `margin` argument) -- confirm before removing.
margin = .5

In [None]:
# Iterator over the test loader; each next(test_iter) call yields one batch,
# unpacked later as (anchor, positive, negative).
test_iter = iter(test_loader)

In [None]:
def normalize(x, eps=1e-12):
    """Scale each row of ``x`` to unit L2 norm.

    Args:
        x: 2-D tensor; the norm is taken along dimension 1.
        eps: lower bound applied to the norm before dividing. An all-zero row
            is returned as zeros instead of NaN (0 / 0).

    Returns:
        Tensor with the same shape as ``x`` whose non-zero rows have L2 norm 1.
    """
    row_norm = x.norm(2, dim=1, keepdim=True)
    # Clamp so that a zero-length vector does not trigger a division by zero.
    return x / row_norm.clamp(min=eps)

In [None]:
# Switch to evaluation mode (disables dropout in the embedding head).
net.eval()
anchor, positive, negative = next(test_iter)

if torch.cuda.is_available():
    anchor, positive, negative = anchor.cuda(), positive.cuda(), negative.cuda()

# `Variable(..., volatile=True)` no longer disables autograd (PyTorch >= 0.4
# ignores the flag and warns), so the forward passes would build a graph for
# all three batches. torch.no_grad() is the supported way to run inference
# without gradient bookkeeping.
with torch.no_grad():
    # compute output
    embedded_anchor = net(anchor)
    embedded_positive = net(positive)
    embedded_negative = net(negative)

    # normalize embeddings to unit length
    norm_anc = normalize(embedded_anchor)
    norm_pos = normalize(embedded_positive)
    norm_neg = normalize(embedded_negative)

    # Cosine similarity anchor<->positive (should be high) and
    # anchor<->negative (should be low).
    sim = torch.nn.CosineSimilarity(dim=1)
    correct_similarity = sim(norm_anc, norm_pos)
    wrong_similarity = sim(norm_anc, norm_neg)

    # Euclidean (L2) distance between the same pairs.
    pdist = nn.PairwiseDistance(p=2)
    euc_pos = pdist(norm_anc, norm_pos)
    euc_neg = pdist(norm_anc, norm_neg)

In [None]:
# Sanity check. For unit-length vectors ||a - p||^2 = 2 - 2*cos(a, p), so it is
# the SQUARED Euclidean distance that should match 2 - 2*similarity. The
# residual below should be ~0 for every pair (not exactly 0: floating-point
# error, plus the small eps nn.PairwiseDistance adds for numerical stability).
euc_pos.squeeze() ** 2 - (2 - 2 * correct_similarity)

In [None]:
def countCosineCorrect(margin, pos_sim=None, neg_sim=None):
    """Count triplets classified correctly by a cosine-similarity threshold.

    Args:
        margin: similarity threshold. A positive pair counts as correct when
            its similarity is above it, a negative pair when below it.
        pos_sim: anchor/positive similarities; defaults to the notebook-level
            ``correct_similarity``.
        neg_sim: anchor/negative similarities; defaults to the notebook-level
            ``wrong_similarity``.

    Returns:
        Tuple of Python ints ``(true_pos, true_neg, both)`` where ``both``
        counts triplets for which both conditions hold.
    """
    if pos_sim is None:
        pos_sim = correct_similarity
    if neg_sim is None:
        neg_sim = wrong_similarity
    true_pos = pos_sim > margin
    true_neg = neg_sim < margin
    both = true_neg & true_pos
    # .item() replaces `.data[0]`, which raises IndexError on the 0-dim tensor
    # that sum() returns in current PyTorch.
    return true_pos.sum().item(), true_neg.sum().item(), both.sum().item()

def countEucledianCorrect(margin, pos_dist=None, neg_dist=None):
    """Count triplets classified correctly by a Euclidean-distance threshold.

    Args:
        margin: distance threshold. A positive pair counts as correct when its
            distance is below it, a negative pair when above it.
        pos_dist: anchor/positive distances; defaults to the notebook-level
            ``euc_pos``.
        neg_dist: anchor/negative distances; defaults to the notebook-level
            ``euc_neg``.

    Returns:
        Tuple of Python ints ``(true_pos, true_neg, both)`` where ``both``
        counts triplets for which both conditions hold.
    """
    if pos_dist is None:
        pos_dist = euc_pos
    if neg_dist is None:
        neg_dist = euc_neg
    true_pos = pos_dist < margin
    true_neg = neg_dist > margin
    both = true_neg & true_pos
    # .item() replaces `.data[0]`, which raises IndexError on the 0-dim tensor
    # that sum() returns in current PyTorch.
    return true_pos.sum().item(), true_neg.sum().item(), both.sum().item()
    

# Visualize triplets

In [None]:
# Plotting libraries used by the cells that follow.
import matplotlib.pyplot as plt
import seaborn as sns  # NOTE(review): `sns` is not referenced in the visible cells -- confirm it is still needed
# Default (width, height) for every figure created after this point.
plt.rcParams['figure.figsize'] = 16, 8

In [None]:
# Sweep the cosine-similarity threshold over [0, 1] and count, for each value,
# the correct positives, correct negatives and triplets where both hold.
M = np.linspace(0, 1, 1000)
precision_vs_margin = [countCosineCorrect(m) for m in M]
plt.plot(M, precision_vs_margin)
# One curve per element of the returned tuple; without a legend the three
# lines cannot be told apart.
plt.legend(['True positive', 'True negative', 'AND'])
# Column 2 is the "both correct" count; pick the threshold that maximises it.
best_cos_margin = M[np.argmax(precision_vs_margin, axis=0)[2]]
plt.title('Cosine distance test, max performance at MARGIN=%1.4f'%best_cos_margin)
plt.show()

In [None]:
# Sweep the Euclidean-distance threshold over [0, 2] and count, for each value,
# the correct positives, correct negatives and triplets where both hold.
M = np.linspace(0, 2, 1000)
precision_vs_margin = [countEucledianCorrect(threshold) for threshold in M]

# The third entry of every tuple is the "both correct" count; the best
# threshold is the first one that maximises it.
best_euc_margin = M[np.asarray(precision_vs_margin)[:, 2].argmax()]

plt.plot(M, precision_vs_margin)
plt.legend(['True positive', 'True negative', 'AND'])
plt.title('Eucledian distance test, max performance at MARGIN=%1.4f' % best_euc_margin)
plt.show()

In [None]:
def _to_image(tensor):
    """Convert a channel-first image tensor to a channel-last uint8 array for imshow.

    NOTE(review): assumes values lie in [0, 1], as produced by ToTensor -- confirm.
    """
    return np.uint8(tensor.data.cpu().numpy() * 255).transpose(1, 2, 0)


# Distance threshold used to label each triplet below.
# NOTE(review): hard-coded; presumably taken from an earlier run of the
# Euclidean sweep (`best_euc_margin`) -- confirm.
euc_margin = .93
for i, (a, p, n, d1, d2) in enumerate(zip(anchor, positive, negative, euc_pos, euc_neg)):
    # .item() works for both 0-dim and single-element tensors; `.data[0]`
    # raises IndexError on 0-dim tensors in current PyTorch.
    d1 = d1.item()
    d2 = d2.item()

    # Correct when the positive is closer than the threshold and the negative
    # is farther away.
    correct = d1 < euc_margin and d2 > euc_margin

    print('EXAMPLE %03d' % i, 'CORRECT' if correct else '!!! FAIL')

    # Anchor, positive and negative side by side, titled with the distance to
    # the anchor.
    fig, axes = plt.subplots(1, 3)
    panels = [
        (a, 'Distance: 0'),
        (p, 'Distance: %2.3f' % d1),
        (n, 'Distance: %2.3f' % d2),
    ]
    for ax, (image_tensor, title) in zip(axes, panels):
        ax.imshow(_to_image(image_tensor))
        ax.set_title(title)

    plt.show()
    # Release the figure so a long batch does not accumulate open figures.
    plt.close(fig)