In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.autograd import Function
import numpy as np
import random
from torchvision.models import resnet18
from torchvision import transforms
from torchvision import datasets
from PIL import Image
# Opt-in GPU switch: FaceNet.train moves input tensors to CUDA only when this
# is True and torch.cuda.is_available() also reports a device.
use_cuda = False

In [4]:
def pairwise_distance(x1, x2):
    """Return the squared Euclidean distance between matching rows of x1 and x2.

    The element-wise difference is squared and summed over dimension 1, so the
    result has one value per row of the inputs.
    """
    squared_diff = (x1 - x2).abs().pow(2)
    return squared_diff.sum(dim=1)

In [5]:
class TripletLoss(nn.Module):
    """Hinge-style triplet loss built on ``pairwise_distance``.

    For each triplet the loss is ``max(d(a, p) - d(a, n) + margin, 0)``, and the
    values are averaged over the batch.
    """

    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        # Minimum gap required between the positive and negative distances.
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the mean triplet loss for the given embeddings."""
        d_pos = pairwise_distance(anchor, positive)
        d_neg = pairwise_distance(anchor, negative)
        # Triplets that already satisfy the margin contribute zero.
        hinge = F.relu(d_pos - d_neg + self.margin)
        return hinge.mean()


def triplet_loss(anchor, positive, negative, margin=0.2):
    """Functional shortcut: build a ``TripletLoss(margin)`` and apply it once."""
    criterion = TripletLoss(margin)
    return criterion(anchor, positive, negative)

In [6]:
def generate_random_triplets(class_map, num_triplets, num_classes):
    """Draw random (anchor, positive, negative) triplets from a class -> images map.

    Args:
        class_map: Mapping from class index to the list of images (paths) that
            belong to that class. Indices in ``range(num_classes)`` that have
            no entry are treated as empty classes.
        num_triplets: Number of triplets to generate. Zero or a negative value
            yields an empty list.
        num_classes: Class indices are drawn from ``range(num_classes)``.

    Returns:
        A list of ``[anchor, positive, negative, anchor_class, negative_class]``
        entries. ``anchor`` and ``positive`` are two distinct images of
        ``anchor_class``; ``negative`` is an image of a different class.

    Raises:
        ValueError: If triplets are requested but no class has two distinct
            images, or fewer than two classes contain any image. Rejection
            sampling could never terminate in those cases.
    """
    triplets = []
    if num_triplets <= 0:
        return triplets

    # Materialise each class as a de-duplicated list. random.sample() no longer
    # accepts sets (deprecated in Python 3.9, TypeError from 3.11), and a list
    # keeps the draw reproducible under random.seed(), unlike set ordering.
    images_by_class = {}
    for class_idx in range(num_classes):
        try:
            images = class_map[class_idx]
        except (KeyError, IndexError):
            images = ()
        images_by_class[class_idx] = list(dict.fromkeys(images))

    anchor_classes = [c for c, imgs in images_by_class.items() if len(imgs) >= 2]
    populated_classes = [c for c, imgs in images_by_class.items() if imgs]

    if not anchor_classes:
        raise ValueError("need at least one class with two distinct images to form an anchor/positive pair")
    if len(populated_classes) < 2:
        raise ValueError("need at least two non-empty classes to pick a negative sample")

    for _ in range(num_triplets):
        # Need an anchor class with at least two images.
        anchor_class = random.choice(anchor_classes)

        # Need a negative class that differs from the anchor class; at least two
        # populated classes exist, so this loop always terminates.
        negative_class = random.choice(populated_classes)
        while negative_class == anchor_class:
            negative_class = random.choice(populated_classes)

        anchor, positive = random.sample(images_by_class[anchor_class], 2)
        negative = random.choice(images_by_class[negative_class])

        triplets.append([anchor, positive, negative, anchor_class, negative_class])

    return triplets
        
class TripletDataset(datasets.ImageFolder):
    """ImageFolder variant that serves pre-sampled image triplets.

    At construction time the images found under ``root`` are grouped by class
    label and ``num_triplets`` random triplets are drawn once with
    ``generate_random_triplets``. Indexing the dataset returns the loaded
    (and, if a transform is set, transformed) images of one triplet.

    Args:
        root: Directory passed to ``ImageFolder``.
        num_triplets: Number of triplets to pre-generate.
        transform: Optional callable applied to every loaded image.
        *arg, **kw: Forwarded to ``ImageFolder.__init__`` after ``transform``
            (previously these were accepted but silently dropped).
    """

    def __init__(self, root, num_triplets, transform=None, *arg, **kw):
        super(TripletDataset, self).__init__(root, transform, *arg, **kw)

        self.num_triplets = num_triplets

        # Map class_idx -> [image_path, image_path2, ...]
        class_map = {}
        for image_path, class_label in self.imgs:
            class_map.setdefault(class_label, []).append(image_path)
        self.class_map = class_map

        self.triplets = generate_random_triplets(class_map, self.num_triplets, len(self.classes))

    def __len__(self):
        """Number of pre-generated triplets (not the number of images)."""
        return len(self.triplets)

    def __getitem__(self, item):
        """Return ``(anchor_image, positive_image, negative_image, anchor_class, negative_class)``.

        Raises ``IndexError`` for an out-of-range ``item``; plain ``for`` loops
        over the dataset rely on that to stop.
        """
        def get_image(image_path):
            image = self.loader(image_path)
            # transform defaults to None; return the raw loaded image in that
            # case instead of failing on a None call.
            if self.transform is None:
                return image
            return self.transform(image)

        anchor, positive, negative, anchor_class, negative_class = self.triplets[item]
        anchor_image = get_image(anchor)
        positive_image = get_image(positive)
        negative_image = get_image(negative)

        return anchor_image, positive_image, negative_image, anchor_class, negative_class

# Preprocessing shared by every image: resize to 224x224, convert to a tensor,
# then normalise each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Despite its name this is a TripletDataset (20 pre-sampled triplets), not a
# torch DataLoader.
train_data_loader = TripletDataset(
    root="./dataset/lfw/lfw-train",
    num_triplets=20,
    transform=transform,
)

In [7]:
class FaceNet(nn.Module):
    """Face-embedding network: a ResNet-18 whose last layer outputs an embedding.

    The pretrained torchvision ``resnet18`` has its ``fc`` layer replaced by a
    freshly initialised ``nn.Linear`` producing ``embedding_dimensions`` values
    per image. Training hyper-parameters (batch size, triplet counts, learning
    rate, loss function, optimizer) are stored on the instance.

    Args:
        embedding_dimensions: Size of the output embedding (default 64).
    """

    def __init__(self, embedding_dimensions=64):
        super(FaceNet, self).__init__()
        self.embedding_dimensions = embedding_dimensions
        self.model = resnet18(pretrained = True)
        resnet_fc_in = self.model.fc.in_features
        self.model.fc = nn.Linear(resnet_fc_in, embedding_dimensions)

        #Initialize weights
        self.model.fc.weight.data.normal_(0.0, 0.02)
        self.model.fc.bias.data.fill_(0)

        self.batch_size = 10
        self.num_triplets_train = self.batch_size
        self.num_triplets_test = self.num_triplets_train // 10
        self.lr = 1e-4
        self.loss_fn = triplet_loss
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr = self.lr)

    def forward(self, input_images):
        """Return the embeddings produced by the backbone for a batch of images."""
        output = self.model(input_images)
        return output

    def train(self, mode=None):
        """Run the training loop, or switch training mode when ``mode`` is given.

        This method's name collides with ``nn.Module.train(mode)``, which
        ``nn.Module.eval()`` calls as ``self.train(False)``. To keep both uses
        working:

        * ``train()`` with no argument runs the full triplet training loop
          (the historical behaviour of this class) and returns ``None``.
        * ``train(True)`` / ``train(False)`` delegates to ``nn.Module.train``
          and returns ``self``, so ``eval()`` no longer raises ``TypeError``.
        """
        if mode is not None:
            return super(FaceNet, self).train(mode)
        self._run_training()
        return None

    @staticmethod
    def _select_device():
        """Pick CUDA only when the notebook flag allows it and a device exists."""
        use_gpu = torch.cuda.is_available() and use_cuda
        return torch.device("cuda" if use_gpu else "cpu")

    @staticmethod
    def _load_gallery(dataset, transform):
        """Load every image of ``dataset`` as one tensor plus its class names."""
        # Use the dataset's own loader, the same one TripletDataset.__getitem__
        # uses, rather than a bare Image.open; torchvision's default loader is
        # expected to close the file and return an RGB image.
        images = torch.stack([transform(dataset.loader(path)) for path, _ in dataset.imgs])
        # Class names, not indices: the train and test folders are indexed by
        # two independent ImageFolder instances, so indices need not line up.
        labels = [dataset.classes[class_idx] for _, class_idx in dataset.imgs]
        return images, labels

    def _embed_in_batches(self, images, device):
        """Embed ``images`` in chunks of ``self.batch_size`` to bound memory use."""
        chunks = []
        for start in range(0, len(images), self.batch_size):
            batch = images[start:start + self.batch_size].to(device)
            chunks.append(self.forward(batch))
        return torch.cat(chunks)

    def _triplet_loss_for(self, sample, device):
        """Compute the triplet loss of one dataset item (a single triplet)."""
        anchor_img, pos_img, neg_img = (img.unsqueeze(0).to(device) for img in sample[:3])
        anchor_emb, pos_emb, neg_emb = self.forward(anchor_img), self.forward(pos_img), self.forward(neg_img)
        return self.loss_fn(anchor_emb, pos_emb, neg_emb)

    def _run_training(self):
        """Train on pre-sampled triplets and evaluate every tenth epoch."""
        transform = transforms.Compose([transforms.Resize((224, 224)),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                        ])
        train_data_loader = TripletDataset("./dataset/lfw/train_dataset", self.num_triplets_train, transform)
        train_images, train_labels = self._load_gallery(train_data_loader, transform)
        test_data_loader = TripletDataset("./dataset/lfw/test_dataset", self.num_triplets_test, transform)
        test_images, test_labels = self._load_gallery(test_data_loader, transform)

        # Inputs are moved to this device below, so the model must live there too.
        device = self._select_device()
        self.model.to(device)

        break_batches = False  # debugging aid: stop after the first triplet
        total_epochs = 100
        last_saved_epoch = 0

        for epoch in range(last_saved_epoch, total_epochs):
            self.model.train()
            train_loss = 0.0
            print("Epoc: ", epoch)
            for sample in train_data_loader:
                loss = self._triplet_loss_for(sample, device)
                # .item() keeps a plain float; summing the tensors would keep
                # every step's autograd history referenced for the whole epoch.
                train_loss += loss.item()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                print(loss)
                if break_batches:
                    break

            print("Train %.2f"% (train_loss))

            if epoch % 10 == 0:
                test_loss = 0.0
                self.model.eval()
                correct_ct, total_ct = 0, 0
                with torch.no_grad():
                    #Get updated embeddings
                    train_embeddings = self._embed_in_batches(train_images, device)
                    test_embeddings = self._embed_in_batches(test_images, device)
                    # Nearest-neighbour classification against the train set.
                    for test_embedding, test_truth in zip(test_embeddings, test_labels):
                        dist = torch.pow(train_embeddings - test_embedding, 2).sum(1)
                        train_index = torch.argmin(dist).item()
                        pred_label = train_labels[train_index]
                        if pred_label == test_truth:
                            correct_ct += 1
                    total_ct = len(test_embeddings)
                    accuracy = correct_ct / total_ct

                    for sample in test_data_loader:
                        test_loss += self._triplet_loss_for(sample, device).item()
                        if break_batches:
                            break

#                     print("Test Loss %.2f, Accuracy : %.2f" % (test_loss, accuracy))
                    print("Test Loss %.2f" % (test_loss))
    

In [8]:
# Build the network (loads the pretrained ResNet-18 backbone) and run the
# training loop; per-step losses and per-epoch totals are printed below.
fn = FaceNet()
fn.train()

Epoc:  0
tensor(0.2010, grad_fn=<MeanBackward1>)
tensor(0.2799, grad_fn=<MeanBackward1>)
tensor(0.2774, grad_fn=<MeanBackward1>)
tensor(0.1899, grad_fn=<MeanBackward1>)
tensor(0.2536, grad_fn=<MeanBackward1>)
tensor(0.1186, grad_fn=<MeanBackward1>)
tensor(0.2191, grad_fn=<MeanBackward1>)
tensor(0.2085, grad_fn=<MeanBackward1>)
tensor(0.1458, grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
Train 1.89
Test Loss 0.00
Epoc:  1
tensor(0.2429, grad_fn=<MeanBackward1>)
tensor(0.2068, grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0.0820, grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0.1843, grad_fn=<MeanBackward1>)
tensor(0.1432, grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
Train 0.86
Epoc:  2
tensor(0.1757, grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0.0379, grad_fn=<MeanBackward1>)
tensor(0

tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
Train 0.00
Epoc:  22
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
Train 0.00
Epoc:  23
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<MeanBackward1>)
tensor(0., grad_fn=<Me

KeyboardInterrupt: 

In [None]:
# fn = FaceNet()
# fn.train()
# image1_pil = Image.open("./dataset/lfw/lfw-deepfunneled/Bob_Graham/Bob_Graham_0002.jpg")
# image2_pil = Image.open("./dataset/lfw/lfw-deepfunneled/Bob_Graham/Bob_Graham_0002.jpg")
# normalizer =  transforms.Compose([transforms.Resize((224, 224)), 
#                                     transforms.ToTensor(),
#                                     transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
#                                 ])
# image1 = normalizer(image1_pil)
# print(image1.shape)
# input_images = image1.unsqueeze(0)
# print(input_images.shape)
# x = fn.forward(input_images)
# print(x)
# print(x.shape)

# print("#2")
# images = [image1_pil, image2_pil]
# images = torch.stack([normalizer(image) for image in images])
# print(images.shape)
# x = fn.forward(images)
# print(x)
# print(x.shape)