In [None]:
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms, utils
import matplotlib.pyplot as plt
import time
import os
import pandas as pd
from skimage import io, transform
from torch.utils.data import Dataset, DataLoader

import warnings
# Silence every warning category for the rest of the session.
# NOTE(review): a blanket filter may also hide library deprecation warnings
# (e.g. pandas' `.ix` / `.as_matrix()`); consider narrowing it.
warnings.filterwarnings("ignore")


# Interactive mode: figures are drawn as they are updated instead of only on plt.show().
plt.ion()

In [None]:
landmarks_frame = pd.read_csv('faces/face_landmarks.csv')
n = 65  # row index of the sample to inspect
# Column 0 holds the image file name; the remaining columns are landmark
# coordinates stored as consecutive (x, y) pairs.
# .iloc / .to_numpy() replace .ix / .as_matrix(), which were removed in pandas 1.0.
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].to_numpy().astype('float')
landmarks = landmarks.reshape(-1, 2)

print('Img name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First four landmarks are: {}'.format(landmarks[:4]))

In [None]:
def show_landmarks(image, landmarks):
    """Draw ``image`` and overlay its landmark points as small red dots.

    Args:
        image: anything ``plt.imshow`` accepts.
        landmarks: 2-D array; column 0 is plotted as x, column 1 as y.
    """
    xs = landmarks[:, 0]
    ys = landmarks[:, 1]
    plt.imshow(image)
    plt.scatter(xs, ys, s=10, marker='.', c='r')
    # A tiny pause gives an interactive backend the chance to redraw.
    plt.pause(0.001)

# Load the image picked above and draw its landmarks on a fresh figure.
face_image = io.imread(os.path.join('faces/', img_name))
plt.figure()
show_landmarks(face_image, landmarks)
plt.show()

In [None]:
# Image file name of the first row (column 0 holds the file name).
# .iloc replaces .ix, which was removed in pandas 1.0.
print(landmarks_frame.iloc[0, 0])
# Presumably the expected CSV column count: 1 image-name column + 68 (x, y)
# landmark pairs. TODO confirm against landmarks_frame.shape[1].
68 * 2 + 1

In [None]:
class FaceLandmarksDataset(Dataset):
    '''Face landmarks dataset backed by a CSV file.

    Each CSV row holds an image file name (column 0) followed by landmark
    coordinates stored as consecutive (x, y) pairs.

    Args:
        csv_file (str): path to the CSV with image names and landmarks.
        root_dir (str): directory the image file names are relative to.
        transformation (callable, optional): applied to each sample dict
            ``{'image': ..., 'landmarks': ...}`` before it is returned.
    '''

    def __init__(self, csv_file, root_dir, transformation=None):
        self.root_dir = root_dir
        self.landmarks_frame = pd.read_csv(csv_file)
        self.transformation = transformation

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, index):
        # Read from this dataset's own frame (previously a notebook global,
        # which silently ignored `csv_file`). Positional .iloc / .to_numpy()
        # replace .ix / .as_matrix(), which were removed in pandas 1.0.
        row = self.landmarks_frame.iloc[index]
        img_name = os.path.join(self.root_dir, row.iloc[0])
        image = io.imread(img_name)
        landmarks = row.iloc[1:].to_numpy().astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transformation:
            sample = self.transformation(sample)
        return sample

In [None]:
face_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
                                    root_dir='faces')
fig = plt.figure()

# Preview (at most) the first four samples side by side, one subplot each.
for i in range(min(4, len(face_dataset))):
    img = face_dataset[i]
    print('Image shape: {}, landmarks shape: {}'.format(img['image'].shape,
                                                       img['landmarks'].shape))

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample: {}'.format(i))
    ax.axis('off')
    show_landmarks(img['image'], img['landmarks'])

In [None]:
class Rescale(object):
    '''Resize the image in a sample and scale its landmarks to match.

    Args:
        output_size (int or tuple): an int sets the length of the shorter
            image side, preserving aspect ratio; a tuple is used directly as
            the (height, width) of the output.
    '''

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image = sample['image']
        landmarks = sample['landmarks']
        height, width = image.shape[:2]

        if not isinstance(self.output_size, int):
            target_h, target_w = self.output_size
        elif height > width:
            target_h, target_w = self.output_size * height / width, self.output_size
        else:
            target_h, target_w = self.output_size, self.output_size * width / height

        target_h = int(target_h)
        target_w = int(target_w)
        resized = transform.resize(image, (target_h, target_w))

        # Landmark column 0 follows the width axis, column 1 the height axis.
        scaled_landmarks = landmarks * [target_w / width, target_h / height]
        return {'image': resized, 'landmarks': scaled_landmarks}

In [None]:
class ToTensor(object):
    '''Convert the ndarrays in a sample into torch tensors.

    The image is reordered from (H, W, C) to channels-first (C, H, W).
    A 2-D (grayscale) image first gets a singleton channel axis, so it becomes
    (1, H, W) instead of failing in ``transpose``.
    '''

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        if image.ndim == 2:
            image = image[:, :, np.newaxis]
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}

In [None]:
class RandomCrop(object):
    '''Randomly crop the image in a sample and shift its landmarks to match.

    Args:
        output_size (int or tuple): (height, width) of the crop; an int gives
            a square crop.

    Raises:
        ValueError: (on call) if the crop is larger than the image.
    '''

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError('crop size {} is larger than image size {}'.format(
                (new_h, new_w), (h, w)))

        # randint's upper bound is exclusive: +1 allows a crop flush with the
        # bottom/right edge, including a crop equal to the full image.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)
        image = image[top: top + new_h, left: left + new_w]
        # Landmarks are (x, y) pairs: x moves with `left`, y with `top`.
        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}

In [None]:
scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256), RandomCrop(244)])

fig = plt.figure()
sample = face_dataset[65]

# One subplot per transform, each applied independently to the same sample.
for i, t in enumerate((scale, crop, composed)):
    transformed_sample = t(sample)
    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(t).__name__)
    show_landmarks(transformed_sample['image'], transformed_sample['landmarks'])

In [None]:
transformed_dataset = FaceLandmarksDataset(
    csv_file='faces/face_landmarks.csv',
    root_dir='faces',
    transformation=transforms.Compose([Rescale(256), RandomCrop(244), ToTensor()]),
)

# Sanity-check (at most) the first four transformed samples: both fields
# should now be tensors with the expected sizes.
for i in range(min(4, len(transformed_dataset))):
    sample = transformed_dataset[i]
    print('type of image: {}, landmark: {}'.format(type(sample['image']),
                                                   type(sample['landmarks'])))
    print('Image shape: {}, landmarks shape: {}'.format(sample['image'].size(),
                                                        sample['landmarks'].size()))

In [None]:
# Shuffled batches of 4 samples, loaded by 4 worker processes.
# NOTE(review): under the "spawn" start method (Windows; macOS on Python 3.8+)
# worker processes may fail to unpickle dataset/transform classes defined only
# in this notebook — use num_workers=0 there if loading errors appear.
dataloader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=4)

def show_landmarks_batched(sample_batched, padding=2):
    """Show a batch of images as one grid with each image's landmarks overlaid.

    Args:
        sample_batched: dict with 'image' (a channels-first batch tensor, as
            produced by ToTensor + DataLoader) and 'landmarks' (indexed as
            [sample, point, x/y]).
        padding: border, in pixels, that ``utils.make_grid`` puts around each
            tile. Landmarks are shifted by the same amount so they line up
            with their image; 2 matches make_grid's default.
    """
    images_batch, landmarks_batch = sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_width = images_batch.size(3)

    grid = utils.make_grid(images_batch, padding=padding)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        # Tile i starts after i images and (i + 1) borders horizontally, and
        # after one border vertically. This assumes a single grid row, i.e.
        # batch_size <= make_grid's nrow (default 8).
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_width + (i + 1) * padding,
                    landmarks_batch[i, :, 1].numpy() + padding,
                    s=10, marker='.', c='r')
    plt.title('Dataloader batch')

# Print batch sizes until the fourth batch, then plot that batch and stop.
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(), sample_batched['landmarks'].size())
    if i_batch != 3:
        continue
    plt.figure()
    show_landmarks_batched(sample_batched)
    plt.axis('off')
    break