## Applying different triplet-loss strategies

In [1]:
LOCAL = True

In [2]:
%load_ext autoreload
%autoreload 2
%load_ext skip_cell

In [3]:
%%skip $LOCAL
# Mount Google Drive so project files can be copied from it (Colab only; the whole
# cell is skipped when LOCAL is True).

import zipfile  # NOTE(review): not used in the visible notebook
from google.colab import drive

drive.mount('/content/drive/')

In [4]:
%%skip $LOCAL
# Copy the `triplets` directory from Drive into the working directory (presumably the package imported below).
!cp -a "/content/drive/My Drive/triplets/" .

Setting up TensorBoard for PyTorch in Colab

In [5]:
%%skip $LOCAL

!pip install -q tf-nightly-2.0-preview
%load_ext tensorboard
import os
logs_base_dir = "runs"
os.makedirs(logs_base_dir, exist_ok=True)

# Imports

In [6]:
import copy
import random
import time
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score, make_scorer
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
import torch
from torch import optim
from torch import nn
from torch.utils.data import random_split
from torch.utils.data.sampler import BatchSampler
# from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import SVHN
from torchvision.models import resnet18
from typing import Union
from PIL import Image

In [7]:
from triplets.datasets import TripletSVHN
from triplets.losses import TripletLoss, TripletSoftLoss, BatchAllTripletLoss, BatchHardTripletLoss, OnlineTripletLoss
from triplets.metrics import mean_average_precision
from triplets.nets import TripletNet
from triplets.train import train
from triplets.extractor import FeatureExtractor
from triplets.samplers import BalancedBatchSampler
from triplets.selectors import AllTripletSelector,HardestNegativeTripletSelector, RandomNegativeTripletSelector, \
                               SemihardNegativeTripletSelector # Strategies for selecting triplets within a minibatch
from triplets.utils import freeze_layers

In [8]:
# Experiment configuration
SEED = 100
n_classes = 10            # SVHN digit classes
n_features = 512          # feature dimension passed to FeatureExtractor
validation_split = 0.2    # fraction of the SVHN train split to hold out for validation
shuffle_dataset = True    # whether training data should be shuffled
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [9]:
# Seed every RNG the notebook touches: Python, NumPy, then PyTorch (CPU and CUDA).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [10]:
# Raw SVHN (no transform): the TripletSVHN experiments apply `preprocess` themselves.
dataset = SVHN(root = 'data/', download=True, split='train')
# Split size comes from the config cell instead of a hard-coded 0.8.
train_size = int((1 - validation_split) * len(dataset))
valid_size = len(dataset) - train_size
# Re-seed right before splitting so the split is reproducible on its own and does not
# depend on how much global RNG state earlier cells consumed.
torch.manual_seed(SEED)
dataset_train, dataset_valid = random_split(dataset, [train_size, valid_size])
dataset_test = SVHN(root = 'data/', download=True, split='test');

Using downloaded and verified file: data/train_32x32.mat
Using downloaded and verified file: data/test_32x32.mat


# Training the model with triplet loss

In [11]:
# Training hyper-parameters
epochs = 20
batch_size = 32
num_triplets = 1  # NOTE(review): not referenced in the visible notebook

In [12]:
# ImageNet-pretrained ResNet-18 backbone in inference mode (eval() returns the module itself).
model_base = resnet18(pretrained=True).eval()

Define a feature extractor (custom `FeatureExtractor` class) that takes the features from the last layer of the pretrained CNN.

In [13]:
# Wrap the pretrained backbone; n_remove_layers=1 presumably drops the final FC head
# (the resulting extractor has 9 children, see the probe further down).
extractor = FeatureExtractor(
    model=model_base,
    n_remove_layers=1,
    n_features=n_features,
    device=device,
)
extracted_resnet = freeze_layers(extractor.prepare_model(), 2)

Freeze all layers except the last two children of the extractor: the final residual stage (which contains two convolutional blocks) and the pooling layer.

In [14]:
# Freeze every child except the last `n_trainable_children`. The cutoff is derived from
# the module list instead of the hard-coded index 7, so it stays correct if the
# extractor's depth changes; with the 9-child extractor this freezes children 0-6,
# exactly as before.
n_trainable_children = 2
resnet_children = list(extracted_resnet.children())
for child in resnet_children[:max(len(resnet_children) - n_trainable_children, 0)]:
    for param in child.parameters():
        param.requires_grad = False

In [15]:
# ImageNet-style preprocessing for the pretrained ResNet-18: upscale the 32x32 SVHN
# digits, center-crop to 224, convert to a tensor and normalise with ImageNet statistics.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

## Training with softmax triplet loss

In [16]:
# Triplet datasets over the raw SVHN train split. Both receive the same train/valid
# index lists; the mode argument ('train' / 'val') presumably selects which one is used.
split_indices = (dataset_train.indices, dataset_valid.indices)
triplet_train = TripletSVHN(dataset, *split_indices, preprocess, 'train', SEED)
triplet_valid = TripletSVHN(dataset, *split_indices, preprocess, 'val', SEED)

In [17]:
# Shuffle training triplets each epoch (controlled by `shuffle_dataset` from the config
# cell, which was previously unused); validation order stays fixed.
dataloader_train = torch.utils.data.DataLoader(triplet_train, batch_size=batch_size, shuffle=shuffle_dataset)
dataloader_valid = torch.utils.data.DataLoader(triplet_valid, batch_size=batch_size)
dataloaders = {'train': dataloader_train, 'val': dataloader_valid}

In [18]:
# Train a copy of the extractor so this experiment does not mutate the shared
# `extracted_resnet` that later experiments start from.
model = TripletNet(copy.deepcopy(extracted_resnet))
criterion = TripletSoftLoss()
# Only the unfrozen parameters are optimized (earlier layers have requires_grad=False).
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.001)
# Decay LR by a factor of 0.1 every 7 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

In [19]:
%%skip $LOCAL

writer = SummaryWriter()

In [20]:
# Train the softmax-triplet setup for `epochs` epochs on the train/val triplet loaders.
train(model, dataloaders, criterion, optimizer, scheduler, epochs, device)

Epoch 0/19


KeyboardInterrupt: 

The softmax triplet loss used above is taken from the paper "Deep Metric Learning Using Triplet Network" (Hoffer & Ailon):
<br>
https://arxiv.org/pdf/1412.6622.pdf

## Training with the triplet loss from the FaceNet paper

In [None]:
# Triplet datasets for the FaceNet-loss experiment, built exactly like the softmax ones.
split_indices = (dataset_train.indices, dataset_valid.indices)
triplet_train = TripletSVHN(dataset, *split_indices, preprocess, 'train', SEED)
triplet_valid = TripletSVHN(dataset, *split_indices, preprocess, 'val', SEED)

In [None]:
# Shuffle training triplets each epoch (controlled by `shuffle_dataset` from the config
# cell, which was previously unused); validation order stays fixed.
dataloader_train = torch.utils.data.DataLoader(triplet_train, batch_size=batch_size, shuffle=shuffle_dataset)
dataloader_valid = torch.utils.data.DataLoader(triplet_valid, batch_size=batch_size)
dataloaders = {'train': dataloader_train, 'val': dataloader_valid}

In [None]:
# Train a copy of the extractor so this experiment does not mutate the shared
# `extracted_resnet` that later experiments start from.
model = TripletNet(copy.deepcopy(extracted_resnet))
criterion = TripletLoss()
# Only the unfrozen parameters are optimized (earlier layers have requires_grad=False).
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.001)
# Decay LR by a factor of 0.1 every 7 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

In [None]:
%%skip $LOCAL

writer = SummaryWriter()

Train with the FaceNet triplet loss. **TODO:** optimize the margin using the kNN model chosen in the previous experiment.

In [None]:
# Train the FaceNet-triplet-loss setup for `epochs` epochs on the train/val triplet loaders.
train(model, dataloaders, criterion, optimizer, scheduler, epochs, device)

Defining triplet datasets and dataloaders

## Adding hard mining

In [30]:
# Rebuild the extractor for the online-mining experiment.
# NOTE(review): whether prepare_model() copies `model_base` or shares its modules is not visible here.
extractor = FeatureExtractor(
    model=model_base,
    n_remove_layers=1,
    n_features=n_features,
    device=device,
)
extracted_resnet = freeze_layers(extractor.prepare_model(), 2)

In [34]:
# Sanity check: the extractor's type (its children are iterated when freezing layers).
type(extracted_resnet)

torch.nn.modules.container.Sequential

In [33]:
# Number of top-level child modules in the extractor (relevant for choosing which to freeze).
sum(1 for _ in extracted_resnet.children())

9

In [22]:
# NOTE(review): superseded — the sampler is rebuilt (with n_samples=25) after `dataset`
# and the split are recreated in the next cells; this earlier version is never used.
train_batch_sampler = BalancedBatchSampler(dataset.labels[dataset_train.indices], n_classes=n_classes, n_samples=10)

In [23]:
# Re-create SVHN with `preprocess` applied on load: the online-mining loaders below feed
# these images straight to the extractor instead of going through TripletSVHN.
dataset = SVHN(root = 'data/', download=True, split='train', transform=preprocess)
# Split size comes from the config cell instead of a hard-coded 0.8.
train_size = int((1 - validation_split) * len(dataset))
valid_size = len(dataset) - train_size
# Re-seed right before splitting so the split is reproducible on its own and does not
# depend on how much global RNG state earlier cells consumed.
torch.manual_seed(SEED)
dataset_train, dataset_valid = random_split(dataset, [train_size, valid_size])
dataset_test = SVHN(root = 'data/', download=True, split='test', transform=preprocess);

Using downloaded and verified file: data/train_32x32.mat
Using downloaded and verified file: data/test_32x32.mat


In [24]:
# We'll create mini batches by sampling labels that will be present in the mini batch and number of examples from each class
n_samples_per_class = 25
train_batch_sampler = BalancedBatchSampler(dataset.labels[dataset_train.indices], n_classes=n_classes, n_samples=n_samples_per_class)
valid_batch_sampler = BalancedBatchSampler(dataset.labels[dataset_valid.indices], n_classes=n_classes, n_samples=n_samples_per_class)

online_train_loader = torch.utils.data.DataLoader(dataset_train, batch_sampler=train_batch_sampler)
online_valid_loader = torch.utils.data.DataLoader(dataset_valid, batch_sampler=valid_batch_sampler)

margin = 1.
lr = 1e-3
# The optimizer and LR scheduler are built in the model-setup cell, next to the model they
# optimize. Building them here too was dead code: they were overwritten before any use.
n_epochs = epochs  # train() is called with `epochs`; keep this alias in sync with it
log_interval = 50  # NOTE(review): not referenced in the visible notebook

In [25]:
# Online-mining loaders keyed by phase name, like the earlier experiments.
dataloaders = dict(train=online_train_loader, val=online_valid_loader)

In [26]:
# Online triplet mining trains the (partially frozen) extractor directly.
model = extracted_resnet
criterion = OnlineTripletLoss(margin, SemihardNegativeTripletSelector(margin))
# Optimize only the unfrozen parameters, with this experiment's `lr` (set above) and a
# 1e-4 weight decay; previously a copy-pasted optimizer silently replaced both.
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=1e-4)
# Decay LR by a factor of 0.1 every 7 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

In [27]:
%%skip $LOCAL

writer = SummaryWriter()

In [28]:
# Sanity check: pull one balanced batch from the online validation loader and unpack it.
test = next(iter(online_valid_loader))
data, labels = test

In [29]:
# Train the extractor with online semi-hard triplet mining for `epochs` epochs.
train(model, dataloaders, criterion, optimizer, scheduler, epochs, device)

Epoch 0/19


KeyboardInterrupt: 