In [None]:
from __future__ import print_function
import argparse
import os
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from utils.triplet_image_loader import TripletImageLoader
from utils.openface import prepareOpenFace
from tripletnet import Tripletnet
from visdom import Visdom
import numpy as np
import dlib
from PIL import Image

In [None]:
# Restrict this process to GPU 1. The variable is read when CUDA is first
# initialised, so it must be set before any CUDA call (e.g. `.cuda()`).
os.environ.update(CUDA_VISIBLE_DEVICES='1')

In [None]:
# Loader of (anchor, positive, negative) image triplets, resized and
# center-cropped to 96x96 and converted to tensors; batches are not shuffled.
# NOTE(review): despite the `test_` name this reads the *training* CSV —
# confirm whether a held-out split was intended.
test_loader = torch.utils.data.DataLoader(
    dataset=TripletImageLoader(
        'name_photoPaths_train.csv',
        triplets_per_individual=100,
        transform=transforms.Compose([
            transforms.Resize(96),
            transforms.CenterCrop(96),
            transforms.ToTensor(),
        ]),
    ),
    num_workers=12,
    batch_size=200,
)

In [None]:
class Net(nn.Module):
    """Image embedding network: SqueezeNet 1.1 conv trunk plus a small MLP head.

    Maps a batch of images to 128-dimensional embedding vectors, which are
    compared with pairwise L2 distances for triplet evaluation.
    """

    def __init__(self):
        super(Net, self).__init__()
        # `models.squeezenet1_1` is a constructor function, not a model
        # instance, so it must be *called* before `.features` exists.
        # Weights are restored from a checkpoint afterwards, so no
        # pretrained weights are requested here.
        self.features = models.squeezenet1_1().features
        # 2048 inputs = trunk channels (512) x 2 x 2 after the adaptive
        # max-pool in `forward`.
        self.embedding = nn.Sequential(
            nn.Linear(2048, 512),
            nn.Dropout(),
            nn.ReLU(),
            nn.Linear(512, 128)
        )

    def forward(self, x):
        """Embed an image batch.

        Args:
            x: image batch; presumably (batch, 3, H, W) float tensors from
               the loader's ToTensor transform — TODO confirm.

        Returns:
            Tensor of shape (batch, 128).
        """
        x = self.features(x)
        # Pool to a fixed 2x2 grid so the head's input size is independent
        # of the input image resolution.
        x = nn.functional.adaptive_max_pool2d(x, 2)
        batch_size = x.size(0)
        x = x.view(batch_size, -1)
        return self.embedding(x)

In [None]:
# Build the embedding network and restore the best checkpoint's weights.
# Storages are mapped to CPU at load time so a checkpoint saved on a GPU also
# loads on a CPU-only machine; the model is moved to the GPU afterwards when
# one is available.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
net = Net()
checkpoint = torch.load('runs/TripletNet/model_best.pth.tar',
                        map_location=lambda storage, loc: storage)
net.load_state_dict(checkpoint['state_dict'])
if torch.cuda.is_available():
    net.cuda()

In [None]:
# Keep one persistent iterator so that each re-run of the evaluation cell
# pulls the next batch via next(test_iter) instead of starting over.
test_iter = iter(test_loader)

In [None]:
# Embed one batch of triplets and measure anchor-positive (dist1) and
# anchor-negative (dist2) L2 distances.
# switch to evaluation mode (disables dropout in the embedding head)
net.eval()
anchor, positive, negative = next(test_iter)

if torch.cuda.is_available():
    anchor, positive, negative = anchor.cuda(), positive.cuda(), negative.cuda()

# torch.no_grad() replaces the removed `Variable(..., volatile=True)` idiom:
# since PyTorch 0.4 `volatile` is ignored (with a warning), so autograd would
# otherwise record the whole forward pass for this batch.
with torch.no_grad():
    # compute output
    embedded_anchor = net(anchor)
    embedded_positive = net(positive)
    embedded_negative = net(negative)

    pdist = nn.PairwiseDistance(p=2)
    dist1 = pdist(embedded_anchor, embedded_positive)
    dist2 = pdist(embedded_anchor, embedded_negative)

# Visualize triplets

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Display triplets where the model ranks the positive no closer than the
# negative. Set to False to also display correctly ordered triplets.
SHOW_ONLY_FAILURES = True


def _to_hwc_uint8(image):
    """Convert a CHW image tensor (values assumed in [0, 1], as from ToTensor) to an HWC uint8 array."""
    return np.uint8(image.data.cpu().numpy() * 255).transpose(1, 2, 0)


def _to_float(distance):
    """Return the single value of a one-element distance tensor as a Python float.

    `.view(-1)` makes this work whether each element is 0-d (PairwiseDistance
    returning shape (N,)) or has shape (1,) (older (N, 1) output).
    """
    return distance.data.cpu().view(-1).tolist()[0]


for i, (a, p, n, d1, d2) in enumerate(zip(anchor, positive, negative, dist1, dist2)):
    d1 = _to_float(d1)
    d2 = _to_float(d2)
    correct = d1 < d2

    if correct and SHOW_ONLY_FAILURES:
        continue
    print('EXAMPLE %03d' % i, 'CORRECT' if correct else '!!! FAIL')

    # Images are converted only for triplets that are actually shown.
    fig, axes = plt.subplots(1, 3)
    panels = (
        (a, 'Distance: 0'),
        (p, 'Distance: %2.3f' % d1),
        (n, 'Distance: %2.3f' % d2),
    )
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(_to_hwc_uint8(img))
        ax.set_title(title)

    plt.show()
    # Release the figure so long loops don't accumulate open figures.
    plt.close(fig)