In [1]:
import os
import time
import copy
from collections import defaultdict
import csv

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import TripletMarginLoss
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms

import numpy as np
import pandas as pd

import cv2
from PIL import Image

from tqdm import tqdm_notebook as tqdm

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize, LabelEncoder, OneHotEncoder

In [2]:
triplet_size = 5
image_size = 224
embedding_dim = 250
batch_size = 32

data_root = '../../data/humpback-whale-identification/'
train_csv = 'train.csv'
train_triplet_file = 'train_triplets.txt'
valid_triplet_file = 'valid_triplets.txt'

In [3]:
class TripletsDataset(Dataset):
    
    def __init__(self, data_root, df, triplet_file_path, transform=None):
        '''
            data_root: The root path of the data folder.
            df: Dataframe of the dataset
            transform: the transforms to preprocess the input dataset
        '''
        super(TripletsDataset, self).__init__()
        
        self.df = df 
        self.triplet_file_path = triplet_file_path
        
        # 1. Get image name and corresponding labels
        self.image_names, self.image_labels = self.df.Image.values, self.df.Id.values

        self.label_to_imagenames = defaultdict(list)       
        for image_name, label in zip(list(self.image_names), list(self.image_labels)):
            #if id_ == other_id:
            #    self.new_whale_list.append(imagename)
            #else:
            self.label_to_imagenames[label].append(image_name)
        
        
        # 2. Make triplets list
        self.triplets = []
        
        self.make_triplet_list()
        for line in open(os.path.join(data_root, self.triplet_file_path)):
            img1, img2, img3 = line.split()
            self.triplets.append([img1, img2, img3])
        
        # 3. Set data transform
        self.transform = transform
        
    def __len__(self):
        return len(self.triplets)
    
    def __getitem__(self, index):
        img1, img2, img3 = self.triplets[index][0], self.triplets[index][1], self.triplets[index][2]
        
        img1 = Image.open(os.path.join(data_root, 'train/'+img1)).convert('RGB')
        img2 = Image.open(os.path.join(data_root, 'train/'+img2)).convert('RGB')
        img3 = Image.open(os.path.join(data_root, 'train/'+img3)).convert('RGB')
        
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            img3 = self.transform(img3)
        
        return img1, img2, img3
    
    def make_triplet_list(self):
        
        print('Generating triplets...')
        
        triplets = []
        
        for label in set(self.image_labels):       
            anchor = None
            positive = None
            
            if len(self.label_to_imagenames[label]) > 1:
                positives = np.random.choice((self.label_to_imagenames[label]), size=2, replace=False)
                anchor = positives[0]
                positive = positives[1]
            
                negative = None
                while negative == None or self.df[self.df['Image'] == negative]['Id'].values == label:
                    negative = np.random.choice(self.df['Image'], 1)[0]
            
                triplets.append([anchor, positive, negative])
                
        with open(os.path.join(data_root, self.triplet_file_path), "w") as f:
            writer = csv.writer(f, delimiter=' ')
            writer.writerows(triplets)
        print('Done!')
        
        
    def prepare_labels(self, y):
        values = np.array(y)
        label_encoder = LabelEncoder()
        integer_encoded = label_encoder.fit_transform(values)
        
        onehot_encoder = OneHotEncoder(sparse=False)
        integer_encoded = integer_encoded.reshape(len(integer_encoded), 1)
        onehot_encoded = onehot_encoder.fit_transform(integer_encoded)
        
        y = onehot_encoded
        
        return y, label_encoder

In [4]:
train_transform = transforms.Compose([transforms.Resize((image_size, image_size)),
                                 transforms.ToTensor(),
                                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                       std=[0.229, 0.224, 0.225])])

valid_transform = transforms.Compose([transforms.Resize((image_size, image_size)),
                                 transforms.ToTensor(),
                                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                       std=[0.229, 0.224, 0.225])])

In [5]:
data_df = pd.read_csv(os.path.join(data_root, train_csv))
train_df,  valid_df = train_test_split(data_df, train_size=0.7, test_size=0.3, random_state=43)

train_dataset = TripletsDataset(data_root, train_df, train_triplet_file, train_transform)
valid_dataset = TripletsDataset(data_root, valid_df, valid_triplet_file, valid_transform)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

Generating triplets...
Done!
Generating triplets...
Done!


In [6]:
input1, input2, input3 = next(iter(train_dataloader))

In [7]:
input1.shape, input2.shape, input3.shape

(torch.Size([32, 3, 224, 224]),
 torch.Size([32, 3, 224, 224]),
 torch.Size([32, 3, 224, 224]))

Create the triplet model.

In [8]:
class TripletModel(nn.Module):
    
    def __init__(self):
        super(TripletModel, self).__init__()
        self.base_model = models.resnet50(pretrained=True)
        
        #for param in self.base_model.parameters():
        #    param.requires_grad = False
        
        self.base_model.fc = nn.Linear(2048, embedding_dim)
        
    def forward(self, x, y, z):
        x = self.base_model(x)
        y = self.base_model(y)
        z = self.base_model(z)
        
        return x, y, z

In [9]:
trip_model = TripletModel()

#criterion = torch.nn.MarginRankingLoss(margin = 0.2)
optimizer = optim.SGD(trip_model.parameters(), lr=0.001, momentum=0.5)

In [10]:
def train_model(model, dataloaders, optimizer, num_epochs=25):
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    since = time.time()
    
    val_loss_history = []
    
    model.to(device)
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 999999
    
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)
        
        for phase in ['train', 'valid']:
            
            steps = 0
            
            if phase == 'train':
                model.train()
            else:
                model.eval()
                
            running_loss = 0.0
            
            for inputs in tqdm(dataloaders[phase]):
                inputs[0] = inputs[0].to(device)
                inputs[1] = inputs[1].to(device)
                inputs[2] = inputs[2].to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    x, y, z = model(*inputs)
                    
                    trip_loss = TripletMarginLoss()
                    loss = trip_loss(x, y, z)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                # statistics      
                running_loss += loss.item()
                
                steps += 1
            
            epoch_loss = running_loss / steps
            
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            
            # deep copy the model
            if phase == 'valid' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'valid':
                val_loss_history.append(best_loss)
        
        print()
    
    time_elapsed = time.time() - since
    
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Loss: {:4f}'.format(best_loss))
    
    # load best model 
    model.load_state_dict(best_model_wts)
    
    return model, val_loss_history

In [11]:
dataloaders = {'train': train_dataloader, 'valid': valid_dataloader}

train_model(trip_model, dataloaders, optimizer)

Epoch 1/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.5944


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.4762

Epoch 2/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.3237


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.3578

Epoch 3/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.1914


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.3126

Epoch 4/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.1202


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2851

Epoch 5/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0765


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2716

Epoch 6/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0522


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2609

Epoch 7/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0333


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2643

Epoch 8/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0230


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2481

Epoch 9/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0184


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2466

Epoch 10/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0121


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2503

Epoch 11/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0109


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2399

Epoch 12/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0088


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2405

Epoch 13/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0061


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2405

Epoch 14/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0068


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2416

Epoch 15/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0057


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2378

Epoch 16/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0037


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2379

Epoch 17/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0051


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2365

Epoch 18/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0038


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2479

Epoch 19/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0038


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2223

Epoch 20/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0030


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2426

Epoch 21/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0034


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2223

Epoch 22/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0023


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2376

Epoch 23/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0028


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2393

Epoch 24/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0025


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2302

Epoch 25/25
----------


HBox(children=(IntProgress(value=0, max=67), HTML(value='')))


train Loss: 0.0016


HBox(children=(IntProgress(value=0, max=29), HTML(value='')))


valid Loss: 0.2216

Training complete in 14m 1s
Best val Loss: 999999.000000


(TripletModel(
   (base_model): ResNet(
     (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace)
     (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
     (layer1): Sequential(
       (0): Bottleneck(
         (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
         (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
         (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
         (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (relu): ReLU(inplace)
         (downsample): Sequenti