In [None]:
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import PIL.Image
import torch
import torchvision.transforms as transforms
from sklearn.preprocessing import MultiLabelBinarizer
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from tqdm import tqdm_notebook

In [None]:
class MultiBandMultiLabelDataset(Dataset):
    """Map-style dataset of 4-band images with optional multi-label targets.

    Each row of ``images_df`` describes one sample. ``Id`` is the file stem
    (relative to ``base_path``) shared by the per-band PNGs. ``Target`` is a
    space-separated string of integer class ids. The four band files
    (``<Id>_red.png``, ``<Id>_green.png``, ``<Id>_blue.png``,
    ``<Id>_yellow.png``) are each loaded as 8-bit grayscale and stacked into
    one (H, W, 4) array.

    Pass ``collate_func`` as the DataLoader ``collate_fn``. It stacks the
    transformed images and, in train mode, multi-hot encodes the targets.
    """

    BANDS_NAMES = ['_red.png', '_green.png', '_blue.png', '_yellow.png']

    def __len__(self):
        return len(self.images_df)

    def __init__(self, images_df,
                 base_path,
                 image_transform,
                 augmentator=None,
                 train_mode=True,
                 classes=None,
                 ):
        """
        Parameters
        ----------
        images_df : pandas.DataFrame
            Must have an ``Id`` column, plus ``Target`` when ``train_mode``.
            It is copied, so the caller's frame is not modified.
        base_path : str or pathlib.Path
            Directory that holds the band PNGs.
        image_transform : callable
            Applied last to every (H, W, 4) array. It must return a tensor,
            because ``collate_func`` combines samples with ``torch.stack``.
        augmentator : callable, optional
            Called as ``augmentator(X)`` on the raw array, before
            ``image_transform``.
        train_mode : bool
            When False, targets are not read and labels are ``None``.
        classes : list, optional
            Class ids for the multi-hot encoding. Defaults to the keys of the
            notebook-level ``LABEL_MAP``.
        """
        if not isinstance(base_path, pathlib.Path):
            base_path = pathlib.Path(base_path)
        if classes is None:
            # Resolved at construction time, so the class can be defined
            # before LABEL_MAP exists.
            classes = list(LABEL_MAP.keys())

        self.images_df = images_df.copy()
        self.image_transform = image_transform
        self.augmentator = augmentator
        self.images_df.Id = self.images_df.Id.apply(lambda x: base_path / x)
        self.mlb = MultiLabelBinarizer(classes=classes)
        self.train_mode = train_mode

    def __getitem__(self, index):
        """Return ``(X, y)`` for positional row ``index``.

        ``X`` is ``image_transform`` applied to the (optionally augmented)
        (H, W, 4) band array. In train mode ``y`` is the list of int class ids
        parsed from ``Target``; otherwise it is ``None``.
        """
        y = None
        X = self._load_multiband_image(index)
        if self.train_mode:
            y = self._load_multilabel_target(index)

        # NOTE(review): the augmentator receives the array positionally. Some
        # libraries (e.g. imgaug) expect keyword arguments, so wrap them if so.
        if self.augmentator is not None:
            X = self.augmentator(X)

        X = self.image_transform(X)

        return X, y

    def _load_multiband_image(self, index):
        """Load the four band PNGs of row ``index`` as an (H, W, 4) uint8 array."""
        row = self.images_df.iloc[index]
        prefix = str(row.Id.absolute())
        image_bands = []
        for band_name in self.BANDS_NAMES:
            # Close each file promptly. convert() returns a fully loaded copy,
            # so the band stays usable after the file is closed. Without this,
            # handles pile up until garbage collection (multiplied by workers).
            with PIL.Image.open(prefix + band_name) as band_file:
                image_bands.append(band_file.convert('L'))

        # Pretend the bands are an RGBA image so PIL can merge 4 channels.
        band4image = PIL.Image.merge('RGBA', bands=image_bands)
        return np.array(band4image)

    def _load_multilabel_target(self, index):
        """Parse the space-separated ``Target`` of row ``index`` into ints."""
        return list(map(int, self.images_df.iloc[index].Target.split(' ')))

    def collate_func(self, batch):
        """Collate ``(X, y)`` samples into ``(images, labels)`` for a DataLoader.

        ``images`` is ``torch.stack`` of the per-sample tensors. In train mode
        ``labels`` is a float tensor of shape (batch, n_classes) holding
        multi-hot rows in the order of the configured classes; otherwise it is
        ``None``.
        """
        labels = None
        images = [x[0] for x in batch]

        if self.train_mode:
            labels = [x[1] for x in batch]
            # The classes are fixed at construction, so fitting per batch only
            # sets classes_ and the column order never changes.
            labels_one_hot = self.mlb.fit_transform(labels)
            labels = torch.FloatTensor(labels_one_hot)

        return torch.stack(images), labels

    def visualize_sample(self, sample_size):
        """Plot every band of ``sample_size`` randomly chosen rows.

        Draws one column per sampled row and one row per entry of
        ``BANDS_NAMES``. The shape of each loaded array is printed.
        ``self.images_df`` is not modified.
        """
        if sample_size < 1:
            raise ValueError('sample_size must be >= 1')
        n_rows = len(self.images_df)
        # Sample without replacement unless more samples than rows are requested.
        indices = np.random.choice(n_rows, sample_size,
                                   replace=sample_size > n_rows)
        n_bands = len(self.BANDS_NAMES)
        fig, axs = plt.subplots(n_bands, sample_size, squeeze=False,
                                figsize=(3 * sample_size, 3 * n_bands))
        for col, index in enumerate(indices):
            image = self._load_multiband_image(index)
            print('Image shape: ', image.shape)
            stem = self.images_df.iloc[index].Id.name
            for band_idx, band_name in enumerate(self.BANDS_NAMES):
                ax = axs[band_idx, col]
                ax.imshow(image[:, :, band_idx], cmap='gray')
                ax.set_title(stem + band_name, fontsize=8)
                ax.axis('off')
        fig.tight_layout()


In [None]:
# Class id -> human-readable label name. The ids are the integers that appear
# in the metadata `Target` column, and they define the multi-hot column order
# used by the dataset. The list position is the id, so keep the order.
LABEL_MAP = dict(enumerate([
    "Nucleoplasm",
    "Nuclear membrane",
    "Nucleoli",
    "Nucleoli fibrillar center",
    "Nuclear speckles",
    "Nuclear bodies",
    "Endoplasmic reticulum",
    "Golgi apparatus",
    "Peroxisomes",
    "Endosomes",
    "Lysosomes",
    "Intermediate filaments",
    "Actin filaments",
    "Focal adhesion sites",
    "Microtubules",
    "Microtubule ends",
    "Cytokinetic bridge",
    "Mitotic spindle",
    "Microtubule organizing center",
    "Centrosome",
    "Lipid droplets",
    "Plasma membrane",
    "Cell junctions",
    "Mitochondria",
    "Aggresome",
    "Cytosol",
    "Cytoplasmic bodies",
    "Rods & rings",
]))

In [None]:
# Input locations, relative to the notebook's working directory.
PATH_TO_IMAGES = './data/full_train/'         # training band PNGs: <Id><band suffix>
PATH_TO_TEST_IMAGES = './data/test/'
PATH_TO_META = './data/full_dev_train.csv'    # needs `Id` and `Target` columns
SAMPLE_SUBMI = './data/sample_submission.csv'

SEED = 666
DEV_MODE = False  # NOTE(review): not referenced in the visible cells
SIZE = 256        # NOTE(review): not referenced in the visible cells (resize is commented out)

# Apply SEED to the RNGs the notebook draws from (e.g. np.random.choice when
# sampling images to plot), so re-runs are reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Training metadata: one row per sample, consumed by the dataset below.
df = pd.read_csv(PATH_TO_META)
print(df.shape[0])
def image_transform(img, mean=None, std=None):
    """Convert a channel-last (H, W, C) image into a float32 (C, H, W) tensor.

    Parameters
    ----------
    img : array-like
        Channel-last image, e.g. the (H, W, 4) band array produced by
        ``MultiBandMultiLabelDataset``. It is copied and never modified.
    mean, std : sequence of float, optional
        Per-channel statistics. When both are given, each channel ``c`` is
        normalised as ``(img[c] - mean[c]) / std[c]``. When both are omitted
        (the default), raw pixel values are returned.

    Returns
    -------
    torch.Tensor
        Contiguous float32 tensor of shape (C, H, W).

    Raises
    ------
    ValueError
        If ``img`` is not 3-D, or only one of ``mean``/``std`` is given.
    """
    if (mean is None) != (std is None):
        raise ValueError('mean and std must be given together')
    # np.array always copies, so the returned tensor never aliases the input.
    array = np.array(img, dtype=np.float32)
    if array.ndim != 3:
        raise ValueError(
            'expected an (H, W, C) image, got shape {}'.format(array.shape))
    tensor = torch.from_numpy(np.ascontiguousarray(array.transpose(2, 0, 1)))
    if mean is not None:
        # Broadcast the (C,) statistics over H and W.
        mean_t = torch.as_tensor(mean, dtype=torch.float32).view(-1, 1, 1)
        std_t = torch.as_tensor(std, dtype=torch.float32).view(-1, 1, 1)
        tensor = (tensor - mean_t) / std_t
    return tensor
 
# Training dataset plus an unshuffled loader. The dataset's own collate
# function is required to turn the per-sample label lists into one tensor.
gtrain = MultiBandMultiLabelDataset(
    df,
    base_path=PATH_TO_IMAGES,
    image_transform=image_transform,
)
train_load = DataLoader(
    gtrain,
    batch_size=16,
    num_workers=6,
    collate_fn=gtrain.collate_func,
)

In [None]:
# for i, lists in enumerate(tqdm_notebook(train_load)):
#     if i > 1:
#         break
#     img = lists[0].numpy()
#     print(img[0, :, :, 0])


In [None]:
def compute_channel_stats(loader):
    """Return the exact per-channel ``(mean, std)`` over every pixel in ``loader``.

    ``loader`` must yield ``(images, labels)`` pairs whose ``images`` tensor is
    (batch, channel, H, W). That is the layout produced by ``image_transform``
    (channel-first) followed by ``collate_func`` (``torch.stack`` adds the
    batch axis), so the channel is axis 1. Both results are float64 numpy
    arrays with one entry per channel.

    Sums are accumulated in float64 over the whole dataset, so the result is
    the global statistic. Averaging per-batch values instead would over-weight
    a short final batch. It would also underestimate the std, because it
    ignores the variance between batch means.

    Raises
    ------
    ValueError
        If ``loader`` yields no pixels.
    """
    channel_sum = None
    channel_sq_sum = None
    pixels_per_channel = 0
    for images, _ in tqdm_notebook(loader):
        if channel_sum is None:
            n_channels = images.shape[1]
            channel_sum = torch.zeros(n_channels, dtype=torch.float64)
            channel_sq_sum = torch.zeros(n_channels, dtype=torch.float64)
        # Reduce over batch, H and W (axes 0, 2, 3) and keep the channel axis.
        channel_sum += images.sum(dim=(0, 2, 3), dtype=torch.float64)
        channel_sq_sum += (images * images).sum(dim=(0, 2, 3), dtype=torch.float64)
        pixels_per_channel += images.numel() // images.shape[1]
    if pixels_per_channel == 0:
        raise ValueError('loader yielded no pixels')
    mean = channel_sum / pixels_per_channel
    # Var[x] = E[x^2] - E[x]^2. Clamp tiny negative rounding error before sqrt.
    var = (channel_sq_sum / pixels_per_channel - mean * mean).clamp(min=0)
    return mean.numpy(), var.sqrt().numpy()


channel_mean, channel_std = compute_channel_stats(train_load)
mm0, mm1, mm2, mm3 = channel_mean
ss0, ss1, ss2, ss3 = channel_std
print(mm0, mm1, mm2, mm3)
print(ss0, ss1, ss2, ss3)
# Each run appends two newline-terminated lines to the log file.
m = 'mean: ({}, {}, {}, {})\n'.format(mm0, mm1, mm2, mm3)
s = 'std: ({}, {}, {}, {})\n'.format(ss0, ss1, ss2, ss3)
with open('./data/mean&std.txt', 'a+') as f:
    f.write(m)
    f.write(s)
#     for j, img in enumerate(lists[0]):
#         img = img.numpy().transpose((1, 2, 0))
#         if i > 1:
#             break
#         plt.imshow(img[:,:,0], cmap='Greens')
#         plt.show()
#         plt.imshow(img[:,:,1], cmap='Reds')
#         plt.show()
#         plt.imshow(img[:,:,2], cmap='Oranges')
#         plt.show()
#         plt.imshow(img[:,:,3], cmap='Blues')
#         plt.show()