In [1]:
import copy
import itertools
import time
from pathlib import Path
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets, models, transforms

from multiprocessing import cpu_count
from multiprocessing.dummy import Pool
import numpy as np
import pandas as pd

from sklearn.preprocessing import MultiLabelBinarizer
import tqdm
from PIL import Image

import matplotlib as mpl
# Figure defaults (size in inches, resolution in DPI) pushed into
# matplotlib's global rcParams a few lines below.
mpl_params = {
    'figure.figsize': (10, 5),
    'figure.dpi': 300,
}
from matplotlib import pyplot as plt
mpl.rcParams.update(mpl_params)

import seaborn as sns
# NOTE(review): sns.set() runs after the rcParams update above; depending on
# the seaborn version it may overwrite some of those values - confirm.
sns.set()

In [2]:
# Directory the dataset reads its `<Id>.png` images from.
HPA_DIR = Path('../input/HPAv18/')
# Directory the compressed `.npz` chunks are written to; created here
# (including parents) if it does not exist yet.
HPA_PROCESSED = Path('../hpa_processed')
HPA_PROCESSED.mkdir(parents=True, exist_ok=True)

In [3]:
# Image table; ProteinDataset reads its `Id` column and, when present, its
# space-separated `Target` column.
hpa_df = pd.read_csv('../HPAv18RBGY_wodpl.csv')

In [4]:
# def plot_labels(df, ax):
#     labels, counts = np.unique(list(map(int, itertools.chain(*df.Target.str.split()))), return_counts=True)  
#     pd.DataFrame(counts, labels).plot(kind='bar', ax=ax)
#     ax.set_ylim([0, 15000])

In [5]:
# aug_labels = [1, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 24, 26, 27]

In [6]:
# selection_df = hpa_df[hpa_df.Target.apply(lambda x: bool(set(aug_labels).intersection(list(map(int, x.split())))))]

In [7]:
# selection_df.head()

In [8]:
# selection_df.to_csv('hpa_select.csv', index=False)

In [9]:
# len(selection_df)

In [10]:
# fig, axes = plt.subplots(nrows=2)
# plot_labels(hpa_df, axes[0])
# plot_labels(selection_df, axes[1])

In [11]:
# Use the first CUDA device when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

device(type='cuda', index=0)

In [12]:
# Class index -> class name. The position of each name in the list below is
# its integer class index (0..27).
LABELS = dict(enumerate([
    'Nucleoplasm',
    'Nuclear membrane',
    'Nucleoli',
    'Nucleoli fibrillar center',
    'Nuclear speckles',
    'Nuclear bodies',
    'Endoplasmic reticulum',
    'Golgi apparatus',
    'Peroxisomes',
    'Endosomes',
    'Lysosomes',
    'Intermediate filaments',
    'Actin filaments',
    'Focal adhesion sites',
    'Microtubules',
    'Microtubule ends',
    'Cytokinetic bridge',
    'Mitotic spindle',
    'Microtubule organizing center',
    'Centrosome',
    'Lipid droplets',
    'Plasma membrane',
    'Cell junctions',
    'Mitochondria',
    'Aggresome',
    'Cytosol',
    'Cytoplasmic bodies',
    'Rods & rings',
]))

# Names only, ordered by class index.
LABEL_NAMES = [LABELS[index] for index in sorted(LABELS)]

In [13]:
class ProteinDataset(Dataset):
    """Dataset that pre-processes images and dumps them to ``.npz`` chunks.

    For every row of ``df`` the image ``<images_dir>/<Id>.png`` is loaded,
    optionally passed through ``transform``, converted to a numpy array and
    appended to an in-memory buffer (``self.stack``).  Whenever the buffer
    holds ``chunk_size`` arrays it is written with ``np.savez_compressed`` to
    ``HPA_PROCESSED / '<n>-processed.npz'`` and emptied.

    ``__getitem__`` deliberately returns ``None``: the class is used for its
    side effect of writing the chunks, not for feeding samples to a model.

    Args:
        df: table with an ``Id`` column and, optionally, a ``Target`` column
            of space-separated integer class labels.  A copy is stored.
        images_dir: directory containing ``<Id>.png`` files.
        transform: optional callable applied to the loaded uint8 array.
        n_chunks: number of output files one full pass should be split into
            (default 73, the value previously hard-coded).

    Raises:
        ValueError: if ``n_chunks`` is smaller than 1.
    """

    def __init__(self, df, images_dir, transform=None, n_chunks=73):
        if n_chunks < 1:
            raise ValueError(f'n_chunks must be >= 1, got {n_chunks}')

        self.df = df.copy()
        self._dir = images_dir
        self.transform = transform
        # Single-thread pool intended for per-channel loading via `mp_load`;
        # `__getitem__` does not use it.
        # NOTE(review): the pool is never closed - call `self.p.close()` when
        # the dataset is no longer needed.
        self.p = Pool(1)
        # `classes` is passed by keyword: it is keyword-only in current
        # scikit-learn releases.
        self.mlb = MultiLabelBinarizer(classes=list(range(len(LABELS))))
        self.count = 0
        # Timing accumulators (seconds); only `total_transform` is updated here.
        self.total_load = 0
        self.total_stack = 0
        self.total_transform = 0
        self.colors = ['red', 'green', 'blue', 'yellow']
        self.cache_size = len(self.df)
        self.latest = 0
        self.cache = {}

        # Buffer of processed arrays waiting to be written out.
        self.stack = []
        self.n_chunks = n_chunks
        # Ceiling division so that one pass produces at most `n_chunks` files
        # even when len(df) is not an exact multiple; at least 1 so an empty
        # table does not yield a zero-sized chunk.
        self.chunk_size = max(1, -(-len(self.df) // n_chunks))
        # Unbounded chunk counter.  A finite iterator here would raise
        # StopIteration inside __getitem__ once exhausted, which silently
        # terminates a plain `for` loop over the dataset.
        self.save = itertools.count()

    def __len__(self):
        """Return the number of rows in the table."""
        return len(self.df)

    def mp_load(self, path):
        """Load the image at ``path`` and return it as a uint8 numpy array."""
        # The context manager closes the underlying file handle promptly.
        with Image.open(path) as pil_im:
            return np.array(pil_im, np.uint8)

    def _flush_stack(self):
        """Write the buffered arrays to the next numbered chunk and reset the buffer."""
        if not self.stack:
            return
        np.savez_compressed(HPA_PROCESSED / f'{next(self.save)}-processed.npz', *self.stack)
        self.stack = []

    def __getitem__(self, key):
        """Process row ``key``, buffer the result and return ``None``.

        The buffer is flushed to disk when it reaches ``chunk_size`` entries
        or when the last row of the table has just been processed, so a
        sequential pass never leaves a partial chunk unwritten.
        """
        self.count += 1
        row = self.df.iloc[key]
        id_ = row.Id

        # Read from the directory given to the constructor, not a global.
        rgb = self.mp_load(Path(self._dir) / f'{id_}.png')

        # Multi-hot label vector; computed for parity with the labelled
        # return value this method used to have, currently not returned.
        y = []
        if 'Target' in self.df:
            y = list(map(int, row.Target.split()))
            y = self.mlb.fit_transform([y]).squeeze()

        # Use the transform stored on the instance, not a global of the
        # same name.
        if self.transform is not None:
            t1 = time.time()
            X = self.transform(rgb)
            self.total_transform += time.time() - t1
        else:
            X = rgb

        self.stack.append(np.array(X))
        is_last_row = key in (len(self.df) - 1, -1)
        if len(self.stack) >= self.chunk_size or is_last_row:
            self._flush_stack()

        return None

In [14]:
# Preprocessing applied to every loaded image array, in order.
transform = transforms.Compose([
    transforms.ToPILImage(),
    # (299, 299) InceptionV3 input
    transforms.Resize(size=(299, 299)),
    # To Tensor dtype and convert [0, 255] uint8 to [0, 1] float
    transforms.ToTensor(),
    # Standard image normalization
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [15]:
%%time
# Build the dataset over the full table (the `selection_df` subset is the
# commented-out alternative), reading images from HPA_DIR.
train_ds = ProteinDataset(hpa_df, HPA_DIR, transform=transform)

CPU times: user 0 ns, sys: 4 ms, total: 4 ms
Wall time: 4.49 ms


In [16]:
# Number of rows in the image table.
len(hpa_df)

74606

In [17]:
# Rows per output file when the table is split into 73 chunks - the same
# quantity ProteinDataset compares its buffer length against.
len(hpa_df) / 73

1022.0

In [65]:
# One sample per batch, original order, loaded in the main process.
# NOTE(review): ProteinDataset.__getitem__ returns None, which the default
# collate function presumably cannot batch - the preprocessing loop iterates
# `train_ds` directly rather than this loader. Confirm before using it.
train_dl = DataLoader(dataset=train_ds, batch_size=1, shuffle=False, num_workers=0)

In [66]:
# Walk the dataset once. The items themselves are discarded: the useful work
# is the `.npz` chunk writing that ProteinDataset.__getitem__ performs.
count = 0
t1 = time.time()
for count, _ in enumerate(tqdm.tqdm_notebook(train_ds), start=1):
    if count % 1022:
        continue
    # Every 1022 samples: report how many were processed and the elapsed seconds.
    print(count)
    print(time.time() - t1)

HBox(children=(IntProgress(value=0, max=74606), HTML(value='')))

1022
61.39681553840637
2044
124.64675641059875
3066
186.63933563232422
4088
247.5900421142578
5110
311.1207928657532
6132
372.49723649024963
7154
432.6443660259247
8176
493.8517909049988
9198
555.9773960113525
10220
617.6187725067139
11242
678.2107384204865
12264
738.1803865432739
13286
798.975503206253
14308
860.803448677063
15330
923.1103713512421
16352
983.7217626571655
17374
1045.2637102603912
18396
1106.7033779621124
19418
1167.10830950737
20440
1228.7718935012817
21462
1288.6188311576843
22484
1348.7222998142242
23506
1408.6210677623749
24528
1469.7324120998383
25550
1531.7242126464844
26572
1591.461081981659
27594
1653.1047322750092
28616
1713.6802620887756
29638
1774.8306760787964
30660
1838.1042799949646
31682
1946.7421712875366
32704
2054.5798845291138
33726
2164.3240530490875
34748
2274.1523926258087
35770
2383.8140029907227
36792
2494.583196401596
37814
2606.0963563919067
38836
2715.8928096294403
39858
2824.7196753025055
40880
2934.5826914310455
41902
3044.6406016349792
429

In [142]:
# List the chunk files that were written. Uses the HPA_PROCESSED constant so
# the listing targets the same directory the chunks are saved to, instead of
# a `hpa_processed/` path relative to the current working directory, and does
# not depend on IPython's shell automagic.
sorted(path.name for path in HPA_PROCESSED.iterdir())