In [22]:
import torch
from torch.utils import data
import cv2
import numpy as np
import pandas as pd
import os

In [49]:
# Human-readable name for each target class. The position in the list below
# is the class index (0..27), so the order of entries must not change.
label_names = dict(enumerate([
    "Nucleoplasm",
    "Nuclear membrane",
    "Nucleoli",
    "Nucleoli fibrillar center",
    "Nuclear speckles",
    "Nuclear bodies",
    "Endoplasmic reticulum",
    "Golgi apparatus",
    "Peroxisomes",
    "Endosomes",
    "Lysosomes",
    "Intermediate filaments",
    "Actin filaments",
    "Focal adhesion sites",
    "Microtubules",
    "Microtubule ends",
    "Cytokinetic bridge",
    "Mitotic spindle",
    "Microtubule organizing center",
    "Centrosome",
    "Lipid droplets",
    "Plasma membrane",
    "Cell junctions",
    "Mitochondria",
    "Aggresome",
    "Cytosol",
    "Cytoplasmic bodies",
    "Rods & rings",
]))

def target2onehot(target, num_classes=None):
    """Convert a whitespace-separated list of class indices into a multi-hot vector.

    Parameters
    ----------
    target : str or int
        Class indices separated by whitespace, e.g. ``"7 1 2 0"``. A bare
        integer (e.g. if pandas parsed a column of single labels as ints)
        is also accepted. An empty string yields an all-zero vector.
    num_classes : int, optional
        Length of the output vector. Defaults to ``len(label_names)``.

    Returns
    -------
    numpy.ndarray
        Float64 vector of length ``num_classes`` with 1.0 at every listed
        index and 0.0 elsewhere.

    Raises
    ------
    ValueError
        If a token is not an integer.
    IndexError
        If an index is outside ``[0, num_classes)``.
    """
    if num_classes is None:
        num_classes = len(label_names)
    y = np.zeros(num_classes)
    # split() with no argument tolerates leading/trailing/repeated whitespace;
    # split(' ') would produce '' tokens and make int() fail.
    indices = [int(tok) for tok in str(target).split()]
    # Reject negatives explicitly: numpy would otherwise wrap them around and
    # silently mark the wrong class.
    bad = [i for i in indices if not 0 <= i < num_classes]
    if bad:
        raise IndexError(
            "class indices %r out of range for %d classes" % (bad, num_classes))
    y[np.asarray(indices, dtype=np.intp)] = 1
    return y


class HmDataset(data.Dataset):
    """Map-style dataset yielding 4-channel image tensors and their label vectors.

    Each sample ``ID`` is stored as four single-channel PNG files in
    ``data_dir``: ``{ID}_green.png``, ``{ID}_red.png``, ``{ID}_blue.png`` and
    ``{ID}_yellow.png``. They are stacked in that order into one
    ``(4, H, W)`` tensor.
    """

    # Channel file suffixes, in the order they are stacked into the sample.
    CHANNELS = ("green", "red", "blue", "yellow")

    def __init__(self, list_IDs, labels, data_dir):
        """Store the sample ids, their labels, and the image directory.

        Parameters
        ----------
        list_IDs : sequence of str
            Image ids (e.g. a pandas Series); ``list_IDs[index]`` selects the
            sample returned by ``__getitem__(index)``.
        labels : mapping
            Maps each id to its label vector (e.g. the multi-hot arrays
            produced by ``target2onehot``).
        data_dir : str or os.PathLike
            Directory containing the per-channel PNG files.
        """
        self.labels = labels  # id -> label vector (one-hot / multi-hot)
        self.list_IDs = list_IDs
        self.data_dir = data_dir

    def __len__(self):
        """Return the number of samples (one per id)."""
        return len(self.list_IDs)

    def _load_channel(self, ID, channel):
        """Read one channel PNG as a 2-D grayscale array; raise if it cannot be read."""
        path = os.path.join(self.data_dir, ID + "_" + channel + ".png")
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals a missing/unreadable file by returning None
        # rather than raising, so check explicitly.
        if img is None:
            raise FileNotFoundError("could not read image: " + path)
        return img

    def __getitem__(self, index):
        """Return ``(X, y)`` for the sample at ``index``.

        ``X`` is a float64 tensor of shape ``(4, H, W)`` with channels in
        ``CHANNELS`` order and unnormalized pixel values as returned by
        ``cv2.imread`` in grayscale mode; ``y`` is ``torch.tensor(labels[ID])``.

        Raises
        ------
        FileNotFoundError
            If any of the four channel images cannot be read.
        """
        ID = self.list_IDs[index]
        y = self.labels[ID]
        # Stack the channels rather than filling a preallocated 512x512
        # buffer, so any (consistent) image size works; np.stack raises if the
        # channel images disagree in size.
        channels = [self._load_channel(ID, c) for c in self.CHANNELS]
        X = np.stack(channels, axis=0).astype(np.float64)
        return torch.tensor(X), torch.tensor(y)

In [28]:
# Load the training label table. The preview below shows an `Id` column and a
# `Target` column of space-separated class indices (e.g. "7 1 2 0").
train_labels = pd.read_csv(filepath_or_buffer='data/train.csv')
train_labels.head(n=5)

Unnamed: 0,Id,Target
0,00070df0-bbc3-11e8-b2bc-ac1f6b6435d0,16 0
1,000a6c98-bb9b-11e8-b2b9-ac1f6b6435d0,7 1 2 0
2,000a9596-bbc4-11e8-b2bc-ac1f6b6435d0,5
3,000c99ba-bba4-11e8-b2b9-ac1f6b6435d0,1
4,001838f8-bbca-11e8-b2bc-ac1f6b6435d0,18


In [18]:
# Build the id -> multi-hot label lookup consumed by HmDataset.
ids = train_labels.Id
# dict(zip(...)) would silently keep only the last row for a repeated id and
# drop the earlier labels; fail loudly instead.
assert ids.is_unique, "duplicate image ids in train.csv"
onehot = [target2onehot(t) for t in train_labels.Target]
labels = dict(zip(ids, np.array(onehot)))

In [50]:
data_dir = "data/train"
# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
# cudnn.benchmark = True

# Parameters
params = {'batch_size': 64,
          'shuffle': True,
          'num_workers': 8}
max_epochs = 100

# Datasets
partition = ids
labels = labels

# Generators
training_set = HmDataset(ids, labels, data_dir)
training_generator = data.DataLoader(training_set, **params)

In [41]:
# Pull a single batch to sanity-check shapes. next(iter(...)) raises
# StopIteration on an empty loader instead of silently leaving `batch` undefined.
batch, target = next(iter(training_generator))
batch, target = batch.to(device), target.to(device)

In [48]:
# Shape of the sampled batch: (batch_size, channels, H, W).
batch.size()

torch.Size([64, 4, 512, 512])