In [1]:
import os
import time
import copy
from collections import defaultdict
import csv

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import MarginRankingLoss
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms

import numpy as np
import pandas as pd

import cv2
from PIL import Image

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize, LabelEncoder, OneHotEncoder

In [2]:
triplet_size = 5
image_size = 224
embedding_dim = 50
batch_size = 64

data_root = '../../data/humpback_whale_prediction/'
train_csv = 'train.csv'
train_triplet_file = 'train_triplets.txt'
valid_triplet_file = 'valid_triplets.txt'

In [3]:
class TripletsDataset(Dataset):
    
    def __init__(self, data_root, df, triplet_file_path, transform=None):
        '''
            data_root: The root path of the data folder.
            df: Dataframe of the dataset
            transform: the transforms to preprocess the input dataset
        '''
        super(TripletsDataset, self).__init__()
        
        self.df = df 
        self.triplet_file_path = triplet_file_path
        
        # 1. Get image name and corresponding labels
        self.image_names, self.image_labels = self.df.Image.values, self.df.Id.values

        self.label_to_imagenames = defaultdict(list)       
        for image_name, label in zip(list(self.image_names), list(self.image_labels)):
            #if id_ == other_id:
            #    self.new_whale_list.append(imagename)
            #else:
            self.label_to_imagenames[label].append(image_name)
        
        
        # 2. Make triplets list
        self.triplets = []
        
        self.make_triplet_list()
        for line in open(os.path.join(data_root, self.triplet_file_path)):
            img1, img2, img3 = line.split()
            self.triplets.append([img1, img2, img3])
        
        # 3. Set data transform
        self.transform = transform
        
    def __len__(self):
        return len(self.triplets)
    
    def __getitem__(self, index):
        img1, img2, img3 = self.triplets[index][0], self.triplets[index][1], self.triplets[index][2]
        
        img1 = Image.open(os.path.join(data_root, 'train/'+img1)).convert('RGB')
        img2 = Image.open(os.path.join(data_root, 'train/'+img2)).convert('RGB')
        img3 = Image.open(os.path.join(data_root, 'train/'+img3)).convert('RGB')
        
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            img3 = self.transform(img3)
        
        return img1, img2, img3
    
    def make_triplet_list(self):
        
        print('Generating triplets...')
        
        triplets = []
        
        for label in set(self.image_labels):       
            anchor = None
            positive = None
            
            if len(self.label_to_imagenames[label]) > 1:
                positives = np.random.choice((self.label_to_imagenames[label]), size=2, replace=False)
                anchor = positives[0]
                positive = positives[1]
            
                negative = None
                while negative == None or self.df[self.df['Image'] == negative]['Id'].values == label:
                    negative = np.random.choice(self.df['Image'], 1)[0]
            
                triplets.append([anchor, positive, negative])
                
        with open(os.path.join(data_root, self.triplet_file_path), "w") as f:
            writer = csv.writer(f, delimiter=' ')
            writer.writerows(triplets)
        print('Done!')
        
        
    def prepare_labels(self, y):
        values = np.array(y)
        label_encoder = LabelEncoder()
        integer_encoded = label_encoder.fit_transform(values)
        
        onehot_encoder = OneHotEncoder(sparse=False)
        integer_encoded = integer_encoded.reshape(len(integer_encoded), 1)
        onehot_encoded = onehot_encoder.fit_transform(integer_encoded)
        
        y = onehot_encoded
        
        return y, label_encoder

In [4]:
train_transform = transforms.Compose([transforms.Resize((image_size, image_size)),
                                 transforms.ToTensor(),
                                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                       std=[0.229, 0.224, 0.225])])

valid_transform = transforms.Compose([transforms.Resize((image_size, image_size)),
                                 transforms.ToTensor(),
                                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                       std=[0.229, 0.224, 0.225])])

In [5]:
data_df = pd.read_csv(os.path.join(data_root, train_csv))
train_df,  valid_df = train_test_split(data_df, train_size=0.7, test_size=0.3, random_state=43)

train_dataset = TripletsDataset(data_root, train_df, train_triplet_file, train_transform)
valid_dataset = TripletsDataset(data_root, valid_df, valid_triplet_file, valid_transform)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

Generating triplets...
Done!
Generating triplets...
Done!


In [6]:
input1, input2, input3 = next(iter(train_dataloader))

In [7]:
input1.shape, input2.shape, input3.shape

(torch.Size([64, 3, 224, 224]),
 torch.Size([64, 3, 224, 224]),
 torch.Size([64, 3, 224, 224]))

Create the triplet model.

In [8]:
class TripletModel(nn.Module):
    
    def __init__(self):
        super(TripletModel, self).__init__()
        self.base_model = models.resnet50(pretrained=True)
        
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        self.base_model.fc = nn.Linear(2048, embedding_dim)
        
    def forward(self, x, y, z):
        x = self.base_model(x)
        y = self.base_model(y)
        z = self.base_model(z)
        
        dist_a = F.pairwise_distance(x, y, 2)
        dist_b = F.pairwise_distance(x, z, 2)
        
        return dist_a, dist_b, x, y, z

In [9]:
def triplet_loss(inputs, dist='sqeuclidean', margin='maxplus'):
    anchor, positive, negative = inputs
    positive_distance = np.square(anchor - positive)
    negative_distance = np.square(anchor - negative)
    
    if dist == 'euclidean':
        positive_distance = np.sqrt(np.sum(positive_distance, axis=-1, keepdims=True))
        negative_distance = np.aqrt(np.sum(negative_distance, axis=-1, keepdims=True))
        
    elif dist == 'sqeuclidean':
        positive_distance = np.sum(positive_distance, axis=-1, keepdims=True)
        negative_distance = np.sum(negative_distance, axis=-1, keepdims=True)
        
    loss = positive_distance - negative_distance
    
    if margin == 'maxplus':
        loss = np.maximum(0.0, 1 + loss)
    elif margin == 'softplus':
        loss = np.log(1 + np.exp(loss))
        
    return np.mean(loss)

In [10]:
trip_model = TripletModel()

criterion = torch.nn.MarginRankingLoss(margin = 0.2)
optimizer = optim.SGD(trip_model.parameters(), lr=0.01, momentum=0.5)

In [11]:
def train_model(model, dataloaders, optimizer, num_epochs=25):
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    since = time.time()
    
    val_acc_history = []
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs-1))
        print('-' * 10)
        
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
                
            running_loss = 0.0
            running_corrects = 0
            
            for inputs in dataloaders[phase]:
                inputs[0] = inputs[0].to(device)
                inputs[1] = inputs[1].to(device)
                inputs[2] = inputs[2].to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    dist_a, dist_b, x, y, z = model(*inputs)
                    target = torch.ones(batch_size)
                    
                    #loss = np.maximum(dist_a - dist_b + 0.2, 0.0)
                    loss = nn.MarginRankingLoss(margin=1.0).forward(dist_a, dist_b, target)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                print(loss)
                # statistics      
                #running_loss += loss.item() * inputs.size(0)
                #running_corrects += torch.sum(preds == labels.data)
            
            #epoch_loss = running_loss / len(dataloaders[phase].dataset)
            #epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
            #print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            
            # deep copy the model
            #if phase == 'val' and epoch_acc > best_acc:
            #    best_acc = epoch_acc
            #    best_model_wts = copy.deepcopy(model.state_dict())
            #if phase == 'val':
            #    val_acc_history.append(epoch_acc)
        
        print()
    
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    #print('Best val Acc: {:4f}'.format(best_acc))
    
    # load best model 
    #model.load_state_dict(best_model_wts)
    return None#model#, val_acc_history

In [12]:
dataloaders = {'train': train_dataloader, 'valid': valid_dataloader}

train_model(trip_model, dataloaders, optimizer)

Epoch 0/24
----------
tensor(1.1178, grad_fn=<MeanBackward1>)
tensor(1.1376, grad_fn=<MeanBackward1>)
tensor(1.1304, grad_fn=<MeanBackward1>)
tensor(1.1034, grad_fn=<MeanBackward1>)
tensor(1.0912, grad_fn=<MeanBackward1>)
tensor(1.0862, grad_fn=<MeanBackward1>)


Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._

KeyboardInterrupt: 