In [None]:
import csv

import numpy as np
import skimage
import skimage.color
import skimage.io
from PIL import Image
from torch.utils.data import Dataset


class CSVDataset(Dataset):
    """Object-detection dataset described by two CSV files.

    The class file holds one ``class_name,class_id`` row per class.  The
    annotation file holds one ``img_file,x1,y1,x2,y2,class_name`` row per
    bounding box (columns after the sixth are ignored).  A row whose five
    box/class fields are all empty (``img_file,,,,,``) lists an image that
    has no boxes.  Blank rows in either file are skipped.

    Each item is a dict ``{'img': image, 'annot': annotations}``: ``image``
    is a float32 array divided by 255 and ``annotations`` is an ``(N, 5)``
    float64 array of ``[x1, y1, x2, y2, class_id]`` rows (``N`` may be 0).
    """

    def __init__(self, train_file, class_file, transform=None):
        """
        :param train_file: CSV file with training annotations
        :param class_file: CSV file with class mapping
        :param transform: Optional callable applied to each sample dict
        :raises ValueError: if a CSV row is malformed, a class name is
            duplicated, or an annotation references an unknown class
        """
        self.transform = transform

        # newline='' is what the csv module requires for files it parses,
        # so quoted fields containing line breaks are read correctly.
        with open(class_file, newline='') as file:
            self.classes = self.extract_classes(csv.reader(file, delimiter=','))
        # Public name -> id mapping (same content as self.classes).
        self.labels = dict(self.classes)

        with open(train_file, newline='') as file:
            self.data = self.read_annotations(csv.reader(file, delimiter=','))
        # First-appearance order in the annotation file defines the index.
        self.file_names = list(self.data)

    def __getitem__(self, index):
        """Return sample *index*, passed through ``self.transform`` if set."""
        sample = {
            'img': self.load_image(index),
            'annot': self.load_annotation(index),
        }
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        """Return the number of distinct images in the annotation file."""
        return len(self.file_names)

    def load_image(self, index):
        """Load image *index* as a 3-channel float32 array divided by 255.

        Grayscale images are expanded to three channels and an alpha channel
        is dropped, so grayscale, RGB and RGBA files all yield the same
        channel count.  NOTE(review): the /255 scaling assumes 8-bit images.
        """
        image = skimage.io.imread(self.file_names[index])

        if image.ndim == 2:
            image = skimage.color.gray2rgb(image)
        elif image.ndim == 3 and image.shape[2] == 4:
            image = image[:, :, :3]  # RGBA -> RGB
        return image.astype(np.float32) / 255.0

    def load_annotation(self, index):
        """Return the boxes of image *index* as an ``(N, 5)`` float64 array.

        Each row is ``[x1, y1, x2, y2, class_id]``.  An image without boxes
        yields an array of shape ``(0, 5)``.
        """
        rows = [
            [box['x1'], box['y1'], box['x2'], box['y2'], self.labels[box['class']]]
            for box in self.data[self.file_names[index]]
        ]
        # Build the array in one go (np.append in a loop is quadratic);
        # reshape keeps the (0, 5) shape when there are no boxes.
        return np.array(rows, dtype=np.float64).reshape(-1, 5)

    def extract_classes(self, reader):
        """Parse ``class_name,class_id`` rows into a ``{name: id}`` dict.

        :param reader: iterable of CSV rows (lists of strings)
        :raises ValueError: on a malformed row, a non-integer id or a
            duplicated class name
        """
        classes = {}
        for line, row in enumerate(reader, start=1):
            if not row:
                continue  # csv.reader yields [] for blank lines
            try:
                name, raw_id = row
                class_id = int(raw_id)
            except ValueError as err:
                raise ValueError(
                    f"row {line}: expected 'class_name,class_id', got {row!r}"
                ) from err
            if name in classes:
                raise ValueError(f"Duplicate class names: {name!r} (row {line})")
            classes[name] = class_id
        return classes

    def read_annotations(self, reader):
        """Group annotation rows by image file.

        :param reader: iterable of CSV rows (lists of strings)
        :returns: dict mapping image path to a list of
            ``{'x1', 'y1', 'x2', 'y2', 'class'}`` dicts, in file order
        :raises ValueError: on a malformed row, non-integer coordinates or
            a class name missing from ``self.classes``
        """
        result = {}
        for line, row in enumerate(reader, start=1):
            if not row:
                continue  # csv.reader yields [] for blank lines
            try:
                img_file, x_1, y_1, x_2, y_2, class_name = row[:6]
            except ValueError as err:
                raise ValueError(
                    f"row {line}: expected 'img_file,x1,y1,x2,y2,class_name', "
                    f"got {row!r}"
                ) from err

            boxes = result.setdefault(img_file, [])
            if (x_1, y_1, x_2, y_2, class_name) == ('', '', '', '', ''):
                continue  # image listed without any boxes

            try:
                x_1, y_1, x_2, y_2 = (int(v) for v in (x_1, y_1, x_2, y_2))
            except ValueError as err:
                raise ValueError(
                    f"row {line}: box coordinates must be integers, got {row!r}"
                ) from err
            if class_name not in self.classes:
                raise ValueError(f"unknown class name {class_name!r} (row {line})")

            boxes.append({
                'x1': x_1,
                'y1': y_1,
                'x2': x_2,
                'y2': y_2,
                'class': class_name,
            })
        return result

    def num_classes(self):
        """Return the largest class id plus one (0 when there are no classes).

        NOTE(review): this equals the class count only if ids are 0..K-1.
        """
        return max(self.classes.values(), default=-1) + 1

    def image_aspect_ratio(self, image_index):
        """Return width / height of image *image_index* without decoding it."""
        # Context manager closes the underlying file handle promptly.
        with Image.open(self.file_names[image_index]) as image:
            width, height = image.size
        return float(width) / float(height)