In [80]:
import numpy as np
import pandas as pd
from torch.utils import data
import os
import io
from sklearn.preprocessing import LabelEncoder
import cv2

In [84]:
class MuraDataset(data.Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None, columns=[3, 4]):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            columns (int array, optional): an array of length 2. first element
                is the column that specifies path in the csv and second is
                the result
        """
        self.process_csv(csv_file, columns)
        self.root_dir = root_dir
        self.transform = transform
        
    def process_csv(self, file, columns):
        # read and update
        df = pd.read_csv(file)
        df = df.iloc[:, columns]
        df.columns = ['path', 'label']
        
        # encode labels
        le = LabelEncoder()
        df.label = le.fit_transform(df.label)
        
        # save
        self.le = le
        self.mura_frame = df

    def __len__(self):
        return len(self.mura_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir,
                                self.mura_frame.iloc[idx, 0])
        img = cv2.imread(img_name)
        #landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
        #landmarks = landmarks.astype('float').reshape(-1, 2)
        label = self.mura_frame.iloc[idx, 1]
        
        sample = {'image': img, 'label': label}

        if self.transform:
            sample = self.transform(sample)

        return sample

In [85]:
md = MuraDataset(csv_file='../../data/processed/train_all.csv',
                 root_dir='../../data/')

In [86]:
md[0]

{'image': array([[[23, 23, 23],
         [26, 26, 26],
         [11, 11, 11],
         ...,
         [ 8,  8,  8],
         [ 8,  8,  8],
         [ 9,  9,  9]],
 
        [[25, 25, 25],
         [ 9,  9,  9],
         [ 7,  7,  7],
         ...,
         [ 7,  7,  7],
         [ 7,  7,  7],
         [ 7,  7,  7]],
 
        [[ 8,  8,  8],
         [ 6,  6,  6],
         [ 8,  8,  8],
         ...,
         [ 6,  6,  6],
         [ 6,  6,  6],
         [ 7,  7,  7]],
 
        ...,
 
        [[ 8,  8,  8],
         [ 8,  8,  8],
         [ 9,  9,  9],
         ...,
         [ 4,  4,  4],
         [ 4,  4,  4],
         [ 4,  4,  4]],
 
        [[ 7,  7,  7],
         [ 8,  8,  8],
         [ 8,  8,  8],
         ...,
         [ 4,  4,  4],
         [ 4,  4,  4],
         [ 4,  4,  4]],
 
        [[ 8,  8,  8],
         [ 7,  7,  7],
         [ 7,  7,  7],
         ...,
         [ 3,  3,  3],
         [ 3,  3,  3],
         [ 3,  3,  3]]], dtype=uint8), 'label': 5}