In [35]:
import os
import time
import copy
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms

import numpy as np
import pandas as pd

from sklearn.cross_validation import train_test_split
from sklearn.preprocessing import normalize

In [28]:
# Kaggle humpback-whale-identification layout; adjust DATA_DIR to your checkout.
DATA_DIR = '../../data/humpback-whale-identification'
# Label file with Image / Id columns (the original path misspelled it as 'tain.csv').
DATA_PATH = os.path.join(DATA_DIR, 'train.csv')
# Directory of the training images named in the CSV's Image column (assumed layout — TODO confirm).
path_train = os.path.join(DATA_DIR, 'train')

image_size = 224    # side length (pixels) images are resized to
embedding_dim = 50  # size of the learned embedding vector
batch_size = 32     # triplets per batch built by gen(); tune to available memory

In [23]:
class TripletModel(nn.Module):
    """Siamese ResNet-50 that embeds the three members of a triplet with shared weights.

    The ImageNet-pretrained backbone is frozen; only the replacement ``fc`` head
    (``in_features -> embedding_dim``) has trainable parameters.
    """

    def __init__(self, embedding_dim=50):
        super(TripletModel, self).__init__()
        self.base_model = models.resnet50(pretrained=True)

        # Freeze the backbone *before* swapping the head, so the new fc stays trainable.
        for param in self.base_model.parameters():
            param.requires_grad = False

        # nn.Linear takes (in_features, out_features); the original passed an
        # ``embedding_dim=`` keyword, which nn.Linear does not accept (TypeError).
        # Reading in_features avoids hard-coding ResNet-50's 2048.
        self.base_model.fc = nn.Linear(self.base_model.fc.in_features, embedding_dim)

    def forward(self, input):
        """Embed each slot of a stacked triplet with the shared backbone.

        Args:
            input: tensor sliced as ``input[:, k]`` for k = 0, 1, 2, with each slice
                fed to the ResNet. Presumably (batch, 3, C, H, W) — TODO confirm
                against the data loader.

        Returns:
            Tuple ``(output1, output2, output3)`` of the backbone outputs for
            slots 0, 1 and 2, each of shape (batch, embedding_dim).
        """
        output1 = self.base_model(input[:, 0])
        output2 = self.base_model(input[:, 1])
        output3 = self.base_model(input[:, 2])

        return output1, output2, output3

In [21]:
def triplet_loss(inputs, dist='sqeuclidean', margin='maxplus'):
    """Mean triplet loss over a batch of (anchor, positive, negative) embeddings.

    Args:
        inputs: 3-sequence ``(anchor, positive, negative)`` of equally shaped
            NumPy arrays or torch tensors. The last axis is the embedding axis.
            Torch inputs stay in torch, so the loss remains differentiable.
        dist: ``'sqeuclidean'`` (squared L2 distance) or ``'euclidean'`` (L2 distance).
        margin: ``'maxplus'`` -> ``max(0, 1 + d_pos - d_neg)``;
            ``'softplus'`` -> ``log(1 + exp(d_pos - d_neg))``.

    Returns:
        Mean of the per-triplet losses: a NumPy scalar for array inputs, or a
        0-d tensor for tensor inputs.

    Raises:
        ValueError: if ``dist`` or ``margin`` is not one of the values above.
            Previously these fell through and silently produced a wrong loss.
    """
    if dist not in ('sqeuclidean', 'euclidean'):
        raise ValueError("dist must be 'sqeuclidean' or 'euclidean', got {!r}".format(dist))
    if margin not in ('maxplus', 'softplus'):
        raise ValueError("margin must be 'maxplus' or 'softplus', got {!r}".format(margin))

    anchor, positive, negative = inputs

    if torch.is_tensor(anchor):
        # NumPy ufuncs cannot backpropagate; use the torch equivalents.
        # NOTE: torch.sqrt has an infinite gradient at 0, so dist='euclidean' with
        # identical embeddings may yield NaN gradients.
        def sq_dist(other):
            return torch.sum((anchor - other) ** 2, dim=-1, keepdim=True)

        sqrt = torch.sqrt
        relu = torch.relu
        softplus = torch.nn.functional.softplus
    else:
        def sq_dist(other):
            return np.sum(np.square(anchor - other), axis=-1, keepdims=True)

        def relu(x):
            return np.maximum(0.0, x)

        def softplus(x):
            # logaddexp(0, x) == log(1 + exp(x)) without overflowing for large x.
            return np.logaddexp(0.0, x)

        sqrt = np.sqrt  # the original misspelled this as np.aqrt

    positive_distance = sq_dist(positive)
    negative_distance = sq_dist(negative)

    if dist == 'euclidean':
        positive_distance = sqrt(positive_distance)
        negative_distance = sqrt(negative_distance)

    loss = positive_distance - negative_distance

    if margin == 'maxplus':
        loss = relu(1.0 + loss)
    else:
        loss = softplus(loss)

    return loss.mean()

In [36]:
class TripletsDataset(Dataset):
    """Map-style dataset of (positive, negative, positive) image-name triplets.

    There is one item per class id that owns at least one image. Ids equal to
    ``other_id`` (e.g. ``"new_whale"``) never act as the anchor class, but their
    images can still be drawn as negatives. ``__getitem__`` samples with NumPy's
    global RNG, so seed it for reproducible triplets.

    NOTE(review): items are ordered (positive_example1, negative_example,
    positive_example2), whereas ``triplet_loss`` unpacks (anchor, positive,
    negative). Reorder before combining the two.
    """

    def __init__(self, imagename_id_mapping, other_id="new_whale", transform=None):
        # Optional callable applied to the whole name triplet in __getitem__.
        self.transform = transform

        # image name ---> class id
        self.imagename_id_mapping = imagename_id_mapping
        # class id ---> image names (ids other than other_id only)
        self.id_to_imagenames = defaultdict(list)
        # image names labelled with other_id
        self.new_whale_list = []
        # list of all unique image names
        self.all_imagenames_list = list(imagename_id_mapping.keys())
        # positions into all_imagenames_list (the original read this attribute
        # before assigning it, raising AttributeError)
        self.all_imagenames_range = list(range(len(self.all_imagenames_list)))

        for imagename, id_ in imagename_id_mapping.items():
            if id_ == other_id:
                self.new_whale_list.append(imagename)
            else:
                self.id_to_imagenames[id_].append(imagename)

        # Only ids with images can supply positives; other_id previously landed here
        # with zero images and crashed __getitem__. First-seen order (rather than
        # hash-randomised set order) keeps item ``idx`` on the same id every run.
        self.id_list = list(self.id_to_imagenames)
        self.all_id_range = range(len(self.id_list))

        # Each id's share of the images. __getitem__ does not use it (ids are
        # selected by ``idx``); it is kept as an attribute.
        self.id_weight = np.array([len(self.id_to_imagenames[id_]) for id_ in self.id_list])
        self.id_weight = self.id_weight / np.sum(self.id_weight)

    def __len__(self):
        return len(self.id_list)

    def __getitem__(self, idx):
        """Return a triplet whose positives share the id at ``self.id_list[idx]``.

        Raises:
            IndexError: if ``idx`` is out of range for ``id_list``.
            ValueError: if every image has this id, so no negative exists.
        """
        # The original used an undefined ``id_idx``; the item index selects the id.
        id_ = self.id_list[idx]
        candidates = self.id_to_imagenames[id_]

        # Two distinct images when available; a single-image id reuses its image.
        first, second = np.random.choice(len(candidates), 2, replace=len(candidates) < 2)
        positive_example1 = candidates[first]
        positive_example2 = candidates[second]

        n_images = len(self.all_imagenames_list)
        # Without this guard the rejection loop below would never terminate.
        if len(candidates) == n_images:
            raise ValueError("no image with an id other than {!r} to use as a negative".format(id_))

        # Rejection-sample uniformly over all images until the id differs.
        negative_example = positive_example1
        while self.imagename_id_mapping[negative_example] == id_:
            negative_example = self.all_imagenames_list[np.random.randint(n_images)]

        sample = (positive_example1, negative_example, positive_example2)

        # The original tested ``self.transforms``, which is never set (AttributeError).
        if self.transform is not None:
            sample = self.transform(sample)

        return sample

In [30]:
class sample_gen(object):
    """Draws random (positive, negative, positive) image-name triplets.

    ``get_sample`` does three things:
    - picks an id with probability proportional to its image count (ids equal
      to ``other_id`` get weight 0 and are never picked);
    - picks two images of that id;
    - picks one image, uniformly, of any other id.

    Sampling uses NumPy's global RNG.
    """

    def __init__(self, imagename_id_mapping, other_id="new_whale"):
        # image name ---> class id
        self.imagename_id_mapping = imagename_id_mapping
        # class id ---> image names (ids other than other_id only)
        self.id_to_imagenames = defaultdict(list)
        # image names labelled with other_id
        self.new_whale_list = []
        # list of all unique image names
        self.all_imagenames_list = list(imagename_id_mapping.keys())
        # positions into all_imagenames_list (the original read this attribute
        # before assigning it, raising AttributeError)
        self.all_imagenames_range = list(range(len(self.all_imagenames_list)))

        for imagename, id_ in imagename_id_mapping.items():
            if id_ == other_id:
                self.new_whale_list.append(imagename)
            else:
                self.id_to_imagenames[id_].append(imagename)

        # First-seen order instead of hash-randomised set order, so a fixed NumPy
        # seed reproduces the same triplets across interpreter runs.
        self.id_list = list(dict.fromkeys(self.imagename_id_mapping.values()))

        self.all_id_range = range(len(self.id_list))

        # .get avoids inserting an empty list for other_id into the defaultdict;
        # that id ends up with weight 0.
        self.id_weight = np.array([len(self.id_to_imagenames.get(id_, [])) for id_ in self.id_list])

        self.id_weight = self.id_weight / np.sum(self.id_weight)

    def get_sample(self):
        """Return ``(positive_example1, negative_example, positive_example2)``.

        Raises:
            ValueError: if every image shares the drawn id, so no negative exists.
        """
        id_idx = np.random.choice(self.all_id_range, 1, p=self.id_weight)[0]
        id_ = self.id_list[id_idx]
        candidates = self.id_to_imagenames[id_]

        # Two distinct images when available; a single-image id reuses its image.
        first, second = np.random.choice(len(candidates), 2, replace=len(candidates) < 2)
        positive_example1 = candidates[first]
        positive_example2 = candidates[second]

        n_images = len(self.all_imagenames_list)
        # Without this guard the rejection loop below would never terminate.
        if len(candidates) == n_images:
            raise ValueError("no image with an id other than {!r} to use as a negative".format(id_))

        # randint is O(1) per draw; np.random.choice(list) re-converted the whole
        # index list to an array on every iteration.
        negative_example = positive_example1
        while self.imagename_id_mapping[negative_example] == id_:
            negative_example = self.all_imagenames_list[np.random.randint(n_images)]

        return positive_example1, negative_example, positive_example2

In [31]:
def resize(filepath, size=None):
    """Load an image file as a square RGB float32 array.

    Args:
        filepath: path of the image to open.
        size: output side length in pixels. Defaults to the module-level
            ``image_size``, which is read at call time.

    Returns:
        np.ndarray of dtype float32 holding the RGB image resized to
        ``(size, size)``.

    NOTE(review): relies on a module-level ``Image`` (PIL), which the imports
    cell does not bring in — add ``from PIL import Image`` there.
    """
    if size is None:
        size = image_size

    # The context manager closes the underlying file handle deterministically;
    # convert() returns a new, fully loaded image, so it outlives the ``with``.
    with Image.open(filepath) as img:
        rgb = img.convert('RGB')

    rgb = rgb.resize((size, size))

    return np.array(rgb, dtype="float32")

def augment(image):
    """Randomly mirror ``image`` left-to-right.

    A single U(0, 1) draw from NumPy's global RNG decides the outcome. Draws
    above 0.9 (about 10% of calls) return ``np.fliplr(image)``. Otherwise the
    same input object is returned unchanged.
    """
    draw = np.random.uniform(0, 1)
    return np.fliplr(image) if draw > 0.9 else image

In [None]:
def gen(triplet):
    """Yield batches of triplet images forever.

    Reads the module-level ``batch_size`` and ``path_train`` at call time. For
    each of ``batch_size`` draws from ``triplet.get_sample()``, every image name
    is joined onto ``path_train``, loaded with ``resize`` and passed through
    ``augment``.

    Yields:
        float32 array of shape (batch_size, 3, channels, height, width).
        Axis 1 is the triplet slot in ``get_sample`` order (positive, negative,
        positive). Images are channel-first, as torchvision backbones expect.
    """
    while True:
        # One list of images per triplet slot.
        slots = ([], [], [])

        for _ in range(batch_size):
            for slot, name in zip(slots, triplet.get_sample()):
                # ``join`` was never imported; os.path.join is what was meant.
                slot.append(augment(resize(os.path.join(path_train, name))))

        # The original np.array(batch_size, (...)) passed the lists as a dtype
        # (TypeError). Stack to (batch, slot, H, W, C) — resize returns HWC
        # arrays — then move the channel axis in front of the spatial axes.
        batch = np.stack([np.stack(slot) for slot in slots], axis=1)

        yield np.ascontiguousarray(batch.transpose(0, 1, 4, 2, 3))

In [None]:
data = pd.read_csv(DATA_PATH)
train, valid = train_test_split(data, train_size=0.7, random_state=1337)

# image name -> whale id, one mapping per split
imagename_id_mapping_train = dict(zip(train.Image.values, train.Id.values))
imagename_id_mapping_valid = dict(zip(valid.Image.values, valid.Id.values))

dataset_train = TripletsDataset(imagename_id_mapping_train, transform=None)
dataset_valid = TripletsDataset(imagename_id_mapping_valid, transform=None)

# Keys match the phase names train_model iterates over ('train', 'val').
# NOTE(review): train_model unpacks (inputs, labels) per batch, but these
# datasets yield image-name triplets; a loading transform is still needed.
dataloaders = {
    'train': DataLoader(dataset_train, shuffle=True),  # vary id order per epoch
    'val': DataLoader(dataset_valid),
}
dataloader_train = dataloaders['train']
dataloader_valid = dataloaders['val']

In [None]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, device=None):
    """Train ``model`` and restore the weights of its best validation epoch.

    Args:
        model: nn.Module whose outputs are treated as class scores: argmax over
            dim 1 is compared with ``labels``.
        dataloaders: dict with ``'train'`` and ``'val'`` loaders. Each yields
            ``(inputs, labels)`` batches and exposes ``.dataset`` for the epoch
            averages.
        criterion: callable ``(outputs, labels) -> scalar loss tensor``.
        optimizer: optimizer over ``model``'s trainable parameters.
        num_epochs: number of train + val passes.
        device: torch.device to run on. Defaults to CUDA when available,
            otherwise CPU.

    Returns:
        ``(model, val_acc_history)``. ``model`` holds the weights of the epoch
        with the highest validation accuracy. ``val_acc_history`` lists the
        per-epoch validation accuracy as Python floats.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Batches are moved to ``device`` below, so the parameters must live there too
    # (the original never moved the model). Module.to works in place, so the
    # optimizer's parameter references stay valid.
    model = model.to(device)

    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            # The original had a full-width '：' here, which is a SyntaxError.
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Build the autograd graph only while training.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics — assumes criterion averages over the batch, hence the re-weighting
                running_loss += loss.item() * inputs.size(0)
                # .item() keeps the count a Python int (and avoids .double() on an int
                # when a loader is empty)
                running_corrects += torch.sum(preds == labels).item()

            n_samples = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / n_samples
            epoch_acc = running_corrects / n_samples

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history