In [1]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from tqdm import tqdm
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from timeit import default_timer as timer
from datetime import timedelta
import pandas as pd

# Switch matplotlib to interactive mode so figures are drawn/updated without a
# blocking plt.show().
plt.ion()

In [2]:
# Workspace layout (Windows paths). Every directory string keeps its trailing
# separator, presumably so file names can be appended by plain concatenation.
directory = 'C:\\Data_Competitions\\Facebook image matching\\FB_image_matching_competition\\'
data_directory = directory + 'data\\'
query_image_path = data_directory + 'query_images\\'
ref_image_path = data_directory + 'reference_images\\'
training_image_path = data_directory + 'training_images\\'
ground_truth_csv = directory + 'public_ground_truth.csv'
from FBImageTriplet import FBImgMatchingDataSetTriplet

# Per-split preprocessing. Both pipelines finish with the commonly used
# ImageNet per-channel mean/std normalisation.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
)

In [3]:
from resnet_triplet import Resnet18Triplet, Resnet34Triplet

# Build the ResNet-18 based triplet model (its repr is shown in the output).
model = Resnet18Triplet()

# Use the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
model.to(device)

cuda:0


Resnet18Triplet(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tra

In [4]:
def set_optimizer(optimizer, model, learning_rate):
    """Create a torch optimizer for ``model`` by name.

    Args:
        optimizer: name of the optimizer; one of "sgd", "adagrad", "rmsprop"
            or "adam".
        model: module whose ``parameters()`` will be optimized.
        learning_rate: initial learning rate, passed as ``lr``.

    Returns:
        The configured ``torch.optim`` optimizer. Every variant applies
        ``weight_decay=1e-5``; the other hyper-parameters are fixed per
        optimizer in the table below.

    Raises:
        ValueError: if ``optimizer`` is not one of the supported names.
    """
    # name -> (optimizer class, optimizer-specific keyword arguments)
    optimizer_specs = {
        "sgd": (optim.SGD, dict(momentum=0.9, dampening=0, nesterov=False)),
        "adagrad": (optim.Adagrad, dict(lr_decay=0, initial_accumulator_value=0.1, eps=1e-10)),
        "rmsprop": (optim.RMSprop, dict(alpha=0.99, eps=1e-08, momentum=0, centered=False)),
        "adam": (optim.Adam, dict(betas=(0.9, 0.999), eps=1e-08, amsgrad=False)),
    }

    # Fail loudly on a typo instead of hitting an unbound local variable.
    if optimizer not in optimizer_specs:
        raise ValueError(
            "Unsupported optimizer {!r}; expected one of: {}".format(
                optimizer, ", ".join(sorted(optimizer_specs))
            )
        )

    optimizer_class, extra_kwargs = optimizer_specs[optimizer]
    return optimizer_class(
        params=model.parameters(),
        lr=learning_rate,
        weight_decay=1e-5,
        **extra_kwargs
    )



# SGD (momentum 0.9, weight decay 1e-5) with an initial learning rate of 0.075.
# NOTE: if an optimizer state_dict is loaded afterwards, its saved
# param_groups (including lr) replace these values.
optimizer_model = set_optimizer("sgd", model, 0.075)

In [5]:
# Resume from a saved checkpoint: restore the model weights and the optimizer
# state (its param_groups, e.g. lr, plus any per-parameter buffers it saved).
# map_location puts every tensor on the device chosen above, so a checkpoint
# written on a GPU can also be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
resume_path = directory + 'triplet_loss\\resnet18_semihard12.pt'
checkpoint = torch.load(resume_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer_model.load_state_dict(checkpoint['optimizer_model_state_dict'])

In [6]:
from torch.nn.modules.distance import PairwiseDistance

def generate_triplets(anchor_embedding, pos_embedding, neg_embedding, margin=0.3, use_semihard_negatives=True):
    """Mine the (anchor, positive, negative) triplets that violate the margin.

    Anchor ``i`` is paired with positive ``i`` and compared against every
    negative ``j``. The pair ``(i, j)`` is kept when

        d(a_i, n_j) - d(a_i, p_i) < margin

    and, if ``use_semihard_negatives`` is True, additionally when

        d(a_i, p_i) < d(a_i, n_j)        (semi-hard: negative is farther)

    With ``use_semihard_negatives=False`` the second test is dropped, so both
    hard negatives (closer than the positive) and semi-hard ones are kept.

    Args:
        anchor_embedding: tensor of shape [num_anchors, embedding_dim].
        pos_embedding: tensor of shape [num_anchors, embedding_dim]; row ``i``
            is the positive for anchor ``i``.
        neg_embedding: tensor of shape [num_negatives, embedding_dim].
        margin: triplet margin used for selection (default 0.3).
        use_semihard_negatives: keep only semi-hard negatives when True.

    Returns:
        ``(selected_anchors, selected_positives, selected_negatives)``, each of
        shape [num_selected_triplets, embedding_dim] (possibly 0 rows). They
        are gathered from the inputs by index, so gradients flow back to the
        original embeddings.
    """
    # The distances only drive the boolean selection mask; keep them out of the
    # autograd graph to save memory.
    with torch.no_grad():
        # [num_anchors]: L2 distance between each anchor and its own positive.
        pos_dist = PairwiseDistance(p=2)(anchor_embedding, pos_embedding)
        # [num_anchors, num_negatives]: L2 distance from every anchor to every negative.
        neg_dist = torch.cdist(anchor_embedding, neg_embedding)

        # Broadcast each anchor's positive distance across its row of negatives.
        pos_dist_col = pos_dist.unsqueeze(1)
        keep = (neg_dist - pos_dist_col) < margin
        if use_semihard_negatives:
            keep = torch.logical_and(keep, pos_dist_col < neg_dist)

        # Row index -> anchor/positive, column index -> negative.
        anchor_idx, neg_idx = torch.nonzero(keep, as_tuple=True)

    return anchor_embedding[anchor_idx], pos_embedding[anchor_idx], neg_embedding[neg_idx]

In [7]:
from FBImageTriplet import FBImgMatchingDataSetTriplet, TripletLoss
dataset = FBImgMatchingDataSetTriplet(query_image_path, ref_image_path, training_image_path, ground_truth_csv, data_transforms['train'])

# Training hyper-parameters.
tt_epoch = 12  # epochs to run in this session
margin = 0.3   # triplet margin, shared by triplet mining and the loss

# Neither the loader nor the loss object depends on the epoch, so build them
# once. (Without persistent_workers the loader still starts fresh worker
# processes on every pass, as before.)
# NOTE(review): shuffle=False replays the same sample order every epoch; unless
# FBImgMatchingDataSetTriplet randomises its own sampling, shuffle=True is the
# usual choice for SGD.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=25, shuffle=False, num_workers=2)
triplet_loss_fn = TripletLoss(margin=margin)

model.train()
for epoch in range(tt_epoch):
    progress_bar = tqdm(dataloader, desc="epoch {}".format(epoch))
    num_valid_training_triplets = 0
    total_triplet_loss = 0
    number_batches = 0
    for anchor_imgs, pos_imgs, neg_imgstacks, idxs in progress_bar:
        anchor_img_size = anchor_imgs.size(0)
        pos_img_size = pos_imgs.size(0)
        # Each sample yields a stack of negatives; merge the batch and stack
        # dimensions so the negatives form one flat batch of images.
        neg_imgs = torch.flatten(neg_imgstacks, start_dim=0, end_dim=1)

        # Embed anchors, positives and negatives in one forward pass, then
        # slice the result back into the three groups.
        all_imgs = torch.cat((anchor_imgs, pos_imgs, neg_imgs)).to(device)
        embeddings = model(all_imgs)
        anchor_embedding = embeddings[:anchor_img_size]
        pos_embedding = embeddings[anchor_img_size:anchor_img_size + pos_img_size]
        neg_embedding = embeddings[anchor_img_size + pos_img_size:]

        # Online triplet mining within this batch.
        selected_anchor_embeddings, selected_pos_embeddings, selected_neg_embeddings = generate_triplets(
            anchor_embedding, pos_embedding, neg_embedding, margin=margin
        )
        num_selected = selected_anchor_embeddings.size(0)
        if num_selected == 0:
            # No triplet violates the margin, so there is nothing to learn from
            # this batch. A loss reduced over an empty selection would likely
            # be NaN and poison the epoch average, so skip the batch.
            continue

        triplet_loss = triplet_loss_fn.forward(
            anchor=selected_anchor_embeddings,
            positive=selected_pos_embeddings,
            negative=selected_neg_embeddings
        )

        # Statistics: only batches that produced a loss are counted.
        batch_loss = triplet_loss.item()
        num_valid_training_triplets += num_selected
        total_triplet_loss += batch_loss
        number_batches += 1
        progress_bar.set_postfix(loss=batch_loss, triplets=num_selected)

        # Backward pass and parameter update.
        optimizer_model.zero_grad()
        triplet_loss.backward()
        optimizer_model.step()

    # Guard against an epoch in which every batch was skipped.
    average_loss = total_triplet_loss / number_batches if number_batches else float('nan')
    print("epoch: {}, total triplets: {}, average loss per batch: {}".format(epoch, num_valid_training_triplets, average_loss))

# store model
# NOTE(review): the model is a Resnet18Triplet (possibly resumed from
# 'resnet18_semihard12.pt'), yet it is saved as 'resnet34_semihard12.pt', and
# 'epoch' records only this session's epochs, not a cumulative count. Confirm
# the intended file name and epoch bookkeeping before relying on this file.
state = {
    'epoch': tt_epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_model_state_dict': optimizer_model.state_dict()
}

torch.save(state, 'resnet34_semihard12.pt')

detect 50000 jpg images under query directory C:\Data_Competitions\Facebook image matching\FB_image_matching_competition\data\query_images\
detect 1000000 jpg images under reference directory C:\Data_Competitions\Facebook image matching\FB_image_matching_competition\data\reference_images\


  0%|                                                                                          | 0/200 [00:00<?, ?it/s]

detect 1000000 jpg images under directory C:\Data_Competitions\Facebook image matching\FB_image_matching_competition\data\training_images\
detect 4991 number of ground truth pairs


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [05:30<00:00,  1.65s/it]
  0%|                                                                                          | 0/200 [00:00<?, ?it/s]

epoch: 0, total triplets: 197227, average loss per batch: 0.14587855949997902


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [05:28<00:00,  1.64s/it]
  0%|                                                                                          | 0/200 [00:00<?, ?it/s]

epoch: 1, total triplets: 193997, average loss per batch: 0.14512887582182885


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [05:35<00:00,  1.68s/it]
  0%|                                                                                          | 0/200 [00:00<?, ?it/s]

epoch: 2, total triplets: 193931, average loss per batch: 0.144635114595294


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [05:26<00:00,  1.63s/it]
  0%|                                                                                          | 0/200 [00:00<?, ?it/s]

epoch: 3, total triplets: 191736, average loss per batch: 0.14432828724384308


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [05:23<00:00,  1.62s/it]
  0%|                                                                                          | 0/200 [00:00<?, ?it/s]

epoch: 4, total triplets: 189784, average loss per batch: 0.14409638337790967


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [05:28<00:00,  1.64s/it]
  0%|                                                                                          | 0/200 [00:00<?, ?it/s]

epoch: 5, total triplets: 188776, average loss per batch: 0.14460732735693455


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [06:09<00:00,  1.85s/it]
  0%|                                                                                          | 0/200 [00:00<?, ?it/s]

epoch: 6, total triplets: 187818, average loss per batch: 0.14440894782543182


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [06:11<00:00,  1.86s/it]
  0%|                                                                                          | 0/200 [00:00<?, ?it/s]

epoch: 7, total triplets: 186284, average loss per batch: 0.14465362809598445


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [05:44<00:00,  1.72s/it]
  0%|                                                                                          | 0/200 [00:00<?, ?it/s]

epoch: 8, total triplets: 186588, average loss per batch: 0.14404310181736946


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [05:35<00:00,  1.68s/it]
  0%|                                                                                          | 0/200 [00:00<?, ?it/s]

epoch: 9, total triplets: 185658, average loss per batch: 0.14445452190935612


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [05:34<00:00,  1.67s/it]
  0%|                                                                                          | 0/200 [00:00<?, ?it/s]

epoch: 10, total triplets: 184480, average loss per batch: 0.14378141567111016


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [05:37<00:00,  1.69s/it]

epoch: 11, total triplets: 183896, average loss per batch: 0.143377720490098





In [8]:
# Collect unreachable Python objects first, so CUDA tensors they still hold are
# returned to PyTorch's caching allocator. empty_cache() can then release those
# now-unused cached blocks instead of leaving them reserved.
import gc

gc.collect()
torch.cuda.empty_cache()

In [9]:
# Force a full collection across all generations; the notebook echoes the
# number of unreachable objects found.
import gc
gc.collect(2)

118