In [None]:
import os
import pickle

import scipy.io as sio
import numpy as np
import torch
import torchvision
import h5py
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib import patches,  lines
import time

%matplotlib inline

## Preprocess the raw data

In [None]:
# Source: https://stackoverflow.com/questions/41176258/h5py-access-data-in-datasets-in-svhn

def get_box_data(index, hdf5_data):
    """
    Read the bounding-box annotations of one image from the digitStruct data.

    :param index: position of the image inside ``/digitStruct/bbox``.
    :param hdf5_data: open HDF5 file (or compatible mapping) that contains
        the ``/digitStruct/bbox`` entries.
    :return: dict mapping every attribute found in the image's bbox group
        (expected: ``height``, ``label``, ``left``, ``top``, ``width``) to a
        list holding one ``int`` per digit in the image.
    """
    # Pre-seed the expected keys so they exist even if the group lacks one.
    meta_data = {key: [] for key in ('height', 'label', 'left', 'top', 'width')}

    def collect_attr(name, obj):
        # A dataset with a single row holds the value itself; with several
        # rows each row holds a reference that must be looked up in the file.
        # Both branches yield plain ints so the label type does not depend on
        # how many digits the image contains.
        if obj.shape[0] == 1:
            vals = [int(obj[0][0])]
        else:
            vals = [int(hdf5_data[obj[k][0]][0][0]) for k in range(obj.shape[0])]
        meta_data[name] = vals

    box = hdf5_data['/digitStruct/bbox'][index]
    # visititems calls collect_attr(name, obj) for every member of the group.
    hdf5_data[box[0]].visititems(collect_attr)
    return meta_data

def get_name(index, hdf5_data):
    """
    Return the file name stored for image ``index``.

    :param index: position of the image inside ``/digitStruct/name``.
    :param hdf5_data: open HDF5 file (or compatible mapping) that contains
        the ``/digitStruct/name`` entries.
    :return: the name as a ``str``, rebuilt from the stored character codes.
    """
    names = hdf5_data['/digitStruct/name']
    # ``Dataset.value`` was removed in h5py 3.0; indexing with ``()`` reads
    # the whole dataset and works on old and new h5py versions alike.
    char_codes = hdf5_data[names[index][0]][()]
    return ''.join(chr(int(code[0])) for code in char_codes)


def aggregate_data(index, hdf5_data):
    """
    Gather the file name and box/label metadata of one image.

    :param index: position of the image in the digitStruct data.
    :param hdf5_data: open HDF5 file handed on to ``get_name`` and
        ``get_box_data``.
    :return: dict with the keys ``'filename'`` (image file name) and
        ``'metadata'`` (the dict produced by ``get_box_data``, with label 10
        remapped to 0).
    """
    # Use the file passed in by the caller, not a notebook-level global.
    image_id = get_name(index, hdf5_data)
    labels = get_box_data(index, hdf5_data)

    # Convert label 10 to label 0 for digit 0
    if 10 in labels['label']:
        labels['label'] = [0 if x == 10 else x for x in labels['label']]

    return {'filename': image_id, 'metadata': labels}


def _pickle_path(root_dir, filename):
    """Build the ``<root_dir>/<filename>.pkl`` path shared by save_obj and load_obj."""
    # os.path.join works whether or not root_dir ends with a path separator.
    return os.path.join(root_dir, filename + '.pkl')


def save_obj(obj, root_dir, filename):
    """
    Pickle ``obj`` to ``<root_dir>/<filename>.pkl``.

    :param obj: any picklable object.
    :param root_dir: directory in which the file is written.
    :param filename: file name without the ``.pkl`` extension.
    """
    with open(_pickle_path(root_dir, filename), 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)


def load_obj(root_dir, filename):
    """
    Load the object pickled at ``<root_dir>/<filename>.pkl``.

    :param root_dir: directory containing the file.
    :param filename: file name without the ``.pkl`` extension.
    :return: the unpickled object.
    """
    # NOTE: unpickling can execute arbitrary code; only load files that were
    # written by save_obj or that come from a trusted source.
    with open(_pickle_path(root_dir, filename), 'rb') as f:
        return pickle.load(f)


In [None]:
# Directory used below for digitStruct.mat, the training images and the
# pickled labels file (relative to the notebook's working directory).
root_dir = 'data/SVHN/train/'

In [None]:
# Step 1: download the SVHN data into data/SVHN/ (relative to this notebook).
# Step 2: parse all metadata from digitStruct.mat into the `metadata` dict (long!).

start_time = time.time()
file1 = os.path.join(root_dir, 'digitStruct.mat')

metadata = {}
# Open read-only (older h5py versions default to a writable mode) and let the
# context manager close the file once parsing is done.
with h5py.File(file1, 'r') as mat_data:
    dataset_size = mat_data['/digitStruct/name'].size

    for index in range(dataset_size):
        metadata[index] = aggregate_data(index, mat_data)

        # Lightweight progress indicator
        if index % 5000 == 0:
            print(index)

end_time = time.time()

print("Total time :", end_time - start_time)

print("Saving metadata dict ...")
filename = 'labels'
save_obj(metadata, root_dir, filename)

## Create a custom PyTorch dataset

Reference: <https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel>

In [None]:
from torch.utils import data

class SVHNDataset(data.Dataset):
    """Dataset that pairs each image file with its parsed metadata."""

    def __init__(self, metadata, root_dir, transform=None):
        """
        Args:
            metadata (dict): Maps an integer index to a dict with the keys
                'filename' (image file name) and 'metadata' (the labels and
                boxes of that image).
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample (the dict returned by ``__getitem__``).
        """
        self.metadata = metadata
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of entries in the metadata dict."""
        return len(self.metadata)

    def __getitem__(self, index):
        '''
        Generate one sample of data.

        Parameters
        ----------
        index : int
            The index of the dataset

        Returns
        -------
        sample : dict
            ``{'image': <PIL image>, 'metadata': <dict>}`` where the metadata
            dict is the one associated to the image. When a transform was
            given, its return value for this dict is returned instead.

        '''
        entry = self.metadata[index]
        img_name = os.path.join(self.root_dir, entry['filename'])

        # Load data and get labels
        image = Image.open(img_name)
        sample = {'image': image, 'metadata': entry['metadata']}

        # The transform was previously accepted but silently ignored.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample



In [None]:
# Base name (without the .pkl extension) of the pickled metadata file
filename = 'labels'

# Reload the pickled metadata dict and wrap it, together with the image
# directory, in the dataset class (no transform is passed).
metadata = load_obj(root_dir, filename)
traindata = SVHNDataset(metadata, root_dir)

In [None]:
## Draft code to extract bboxes
## Inspiration: https://github.com/matterport/Mask_RCNN/blob/master/mrcnn/visualize.py


def visualize_sample(dataset, idx=None, bbox=True, captions=True):
    """
    Display one sample of ``dataset`` with its bounding boxes and labels.

    Parameters
    ----------
    dataset : indexable with ``len()``
        ``dataset[i]`` must return a dict with the keys 'image' and
        'metadata'; the metadata dict holds the per-digit lists 'label',
        'top', 'left', 'height' and 'width'.
    idx : int, optional
        Index of the sample to show. A random index is drawn when None.
    bbox : bool
        Draw a dashed red rectangle around every digit.
    captions : bool
        Write every digit's label near the top-left corner of its box.
    """
    # Compare against None explicitly: `not idx` would also discard a
    # legitimate request for index 0.
    if idx is None:
        idx = np.random.randint(len(dataset))

    sample = dataset[idx]
    img = sample['image']
    meta = sample['metadata']

    # Display image
    _, ax = plt.subplots(1)
    ax.axis('off')
    ax.imshow(img)

    # One entry per digit present in the image
    digits = zip(meta['label'], meta['top'], meta['left'],
                 meta['height'], meta['width'])

    for label, top, left, height, width in digits:
        # Show bounding box
        if bbox:
            rect = patches.Rectangle((left, top), width, height, linewidth=2,
                                     alpha=0.7, linestyle="dashed",
                                     edgecolor='red', facecolor='none')
            ax.add_patch(rect)

        # Show label slightly below the top edge of the box
        if captions:
            ax.text(left, top + 8, label,
                    color='w', size=11, backgroundcolor="none")



In [None]:
# No index is given, so visualize_sample picks a sample at random.
visualize_sample(traindata)

In [None]:
# Collect the width and height of every training image
# (used below to find the smallest images)

im_width = []
im_height = []
for jj in range(len(traindata)):
    # PIL exposes the dimensions as `size == (width, height)` without having
    # to decode the pixel data; the `with` block closes the file right away.
    with traindata[jj]['image'] as image:
        width, height = image.size
    im_height.append(height)
    im_width.append(width)

im_width = np.asarray(im_width)
im_height = np.asarray(im_height)

In [None]:
# Dataset exploration to guide the cleaning step

# Smallest width and smallest height found among the images
print("minimum image width", im_width.min())
print("minimum image height", im_height.min())


# Count the images whose height or width is below 28 pixels
total = np.sum((im_height < 28) | (im_width < 28))

print('total number of images that are too small', total)

In [None]:
## sample image that is too small

# Index of the image with the smallest height
index = np.argmin(im_height)
visualize_sample(traindata, idx=index)

sample = traindata[index]
# Labels of the digits annotated in that image
print(sample['metadata']['label'])
# Last expression of the cell: shape of the image once converted to an array
np.asarray(sample['image']).shape

In [None]:
# TODO: add an example of at least one transform
# (e.g. using imgaug)