In [19]:
# Standard lib
from collections import OrderedDict
import glob
from multiprocessing import cpu_count
import os
from pathlib import Path
from pdb import set_trace 
import re
from string import ascii_lowercase
import subprocess
import sys

# Utils
from imageio import imread
from IPython.display import display, HTML
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL.Image
from tqdm import tqdm_notebook as tqdm

# Model training
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as T

from catalyst.dl.callbacks import AccuracyCallback, AUCCallback, F1ScoreCallback
from catalyst.dl.runner import SupervisedRunner
from catalyst.utils import set_global_seed, prepare_cudnn

import pretrainedmodels

In [5]:
# Fix the global random seed (via catalyst) and request deterministic cuDNN kernels
# so repeated runs give the same results.
seed = 1
set_global_seed(seed)
prepare_cudnn(deterministic=True)

# KAGGLE_URL_BASE is used as the "running on Kaggle" signal: there the competition data
# is mounted under ../input; everywhere else it is expected in ~/data/protein.
if os.environ.get('KAGGLE_URL_BASE', False):
    ROOT = Path.cwd().parent/'input'/'recursion-cellular-image-classification'
else:
    ROOT = Path.home()/'data'/'protein'

# Make the rxrx1-utils helpers importable, cloning them next to the notebook on first use.
# Guard on sys.path membership: `import rxrx.io as rio` binds only `rio`, so probing for a
# name `rxrx` never succeeds and each re-run would append a duplicate sys.path entry.
RXRX_UTILS_DIR = 'rxrx1-utils'
if RXRX_UTILS_DIR not in sys.path:
    if not os.path.exists(RXRX_UTILS_DIR):
        print('Cloning RxRx repository...')
        # check=True surfaces a failed clone here instead of as a confusing ImportError below.
        subprocess.run(
            ['git', 'clone', 'https://github.com/recursionpharma/rxrx1-utils', RXRX_UTILS_DIR],
            check=True,
        )
    print('Adding to the search path.')
    sys.path.append(RXRX_UTILS_DIR)
    print('Done!')

import rxrx.io as rio

Adding to the search path.
Done!


In [6]:
# Sanity check: list the dataset root to confirm the expected CSVs and image folders are present.
!ls -l {ROOT}

total 2464
-rw-rw-r--  1 ck ck   35620 авг 24 14:33 recursion_dataset_license.pdf
-rw-rw-r--  1 ck ck  367018 авг 24 14:33 sample_submission.csv
drwxrwxr-x 20 ck ck    4096 авг 24 16:42 test
-rw-rw-r--  1 ck ck  114364 авг 24 14:33 test_controls.csv
-rw-rw-r--  1 ck ck  574862 авг 24 14:33 test.csv
drwxrwxr-x  3 ck ck    4096 авг 24 18:44 tmp
drwxrwxr-x 35 ck ck    4096 авг 24 16:44 train
-rw-rw-r--  1 ck ck  208866 авг 24 14:33 train_controls.csv
-rw-------  1 ck ck 1203816 июн 26 08:01 train.csv


In [7]:
# --- Metadata / submission CSVs in the dataset root ---
SUBMIT_FILE = ROOT / 'sample_submission.csv'
TRAIN_FILE, TRAIN_CTRL_FILE = ROOT / 'train.csv', ROOT / 'train_controls.csv'
TEST_FILE, TEST_CTRL_FILE = ROOT / 'test.csv', ROOT / 'test_controls.csv'

# NOTE(review): pixel_stats.csv does not appear in the `ls` output of ROOT; confirm which
# step produces it before relying on this path.
STATS_FILE = ROOT / 'pixel_stats.csv'

# --- Image folders (one sub-folder per experiment, then per plate) ---
TRAIN_DIR, TEST_DIR = ROOT / 'train', ROOT / 'test'

In [8]:
# Combined train + test metadata, one row per imaged site; flatten the index into columns.
meta = rio.combine_metadata().reset_index()
meta.head()

Unnamed: 0,id_code,cell_type,dataset,experiment,plate,sirna,site,well,well_type
0,HEPG2-08_1_B02,HEPG2,test,HEPG2-08,1,1138.0,1,B02,negative_control
1,HEPG2-08_1_B02,HEPG2,test,HEPG2-08,1,1138.0,2,B02,negative_control
2,HEPG2-08_1_B03,HEPG2,test,HEPG2-08,1,,1,B03,treatment
3,HEPG2-08_1_B03,HEPG2,test,HEPG2-08,1,,2,B03,treatment
4,HEPG2-08_1_B04,HEPG2,test,HEPG2-08,1,,1,B04,treatment


In [9]:
# Columns identifying one imaged site, followed by its siRNA label.
key = ['experiment', 'plate', 'well', 'site', 'sirna']

meta_train, meta_test = (meta[meta.dataset == split] for split in ('train', 'test'))

# Plain-tuple rows of the key columns for each split.
train_experiments, test_experiments = (
    list(frame[key].itertuples(index=False, name=None))
    for frame in (meta_train, meta_test)
)

In [10]:
# Number of distinct siRNA labels among training sites (the int cast raises if any label is missing).
NUM_OF_CLASSES = len(meta_train['sirna'].astype(int).unique())

In [12]:
# Sanity check: load a single training site as a channel-stacked array and show its shape.
# Arguments are presumably (dataset, experiment, plate, well, site) — TODO confirm against rxrx.io.
rio.load_site('train', 'RPE-05', 3, 'D19', 2).shape

(512, 512, 6)

In [13]:
def _trailing_int(token):
    """Return the integer formed by the trailing digits of `token` (e.g. 'Plate12' -> 12, 's2' -> 2).

    Raises:
        ValueError: if `token` does not end with a digit.
    """
    match = re.search(r'(\d+)$', token)
    if match is None:
        raise ValueError(f'Expected {token!r} to end with a number')
    return int(match.group(1))


def collect_records(basedir):
    """Index the per-channel PNG files under `basedir` into a DataFrame.

    Files are expected exactly two folders deep:
    ``<basedir>/<experiment>/<plate folder ending in its number>/<well>_<site tag>_<channel tag>.png``,
    where the site and channel tags end with their numbers.

    Args:
        basedir: dataset split folder (str or Path), e.g. the train or test image folder.

    Returns:
        DataFrame with columns experiment, plate, well, site, channel, filename
        (plate/site/channel as ints), ordered by file path so the result does not
        depend on filesystem listing order. Empty, with those columns, if no PNG is found.

    Raises:
        ValueError: if a PNG sits at a different depth or its name does not match the layout.
    """
    columns = ['experiment', 'plate', 'well', 'site', 'channel', 'filename']
    # Escape the base folder so characters like '[' in it are not treated as glob syntax.
    pattern = os.path.join(glob.escape(os.fspath(basedir)), '**', '*.png')
    rows = []
    # sorted(): glob order is filesystem-dependent; sorting makes the output reproducible.
    for path in sorted(glob.glob(pattern, recursive=True)):
        # Path.parts splits on the platform separator (a literal '/' split breaks on Windows).
        parts = Path(os.path.relpath(path, start=basedir)).parts
        if len(parts) != 3:
            raise ValueError(f'Unexpected image location (expected <experiment>/<plate>/<file>.png): {path}')
        exp, plate, filename = parts
        basename, _ = os.path.splitext(filename)
        well, site, channel = basename.split('_')
        # Parse every trailing number, so multi-digit plates are not truncated to their last digit.
        rows.append([exp, _trailing_int(plate), well, _trailing_int(site), _trailing_int(channel), path])
    return pd.DataFrame(rows, columns=columns)

In [28]:
def join_channels(records, meta, output_dir):
    """Sequentially stack each site's channel PNGs and save one RGB rendering per site.

    Args:
        records: DataFrame as returned by `collect_records` (columns experiment, plate,
            well, site, channel, filename).
        meta: site-level metadata with `id_code`, `site` and `sirna` columns, used to
            label each output file.
        output_dir: destination folder, created if missing.

    Returns:
        `output_dir`. Files are named ``{id_code}_s{site}_{sirna}.png``. Sites whose `sirna`
        is missing (e.g. test sites) get label 0, the same convention as
        `parallel_join_channels`.
    """
    # Use the `meta` argument rather than a global frame, so test metadata (or any
    # subset) can be passed in and is actually honoured.
    labels = meta.set_index(['id_code', 'site'])
    os.makedirs(output_dir, exist_ok=True)
    print(f'Saving joined channels into folder: {output_dir}')

    site_keys = ['experiment', 'plate', 'well', 'site']
    for (exp, plate, well, site), group in tqdm(records.groupby(site_keys)):
        # Channel numbers are 1-based; channel c goes into slot c - 1, so row order doesn't matter.
        x = np.zeros((512, 512, 6), dtype=np.uint8)
        for channel, filename in zip(group['channel'], group['filename']):
            x[:, :, channel - 1] = np.asarray(imread(filename))

        id_code = f'{exp}_{plate}_{well}'
        sirna = labels.loc[(id_code, site)].sirna
        y = 0 if pd.isna(sirna) else int(sirna)
        output_path = os.path.join(output_dir, f'{id_code}_s{site}_{y}.png')
        rgb = rio.convert_tensor_to_rgb(x)
        PIL.Image.fromarray(rgb.astype(np.uint8)).save(output_path)

    return output_dir

In [50]:
from joblib import Parallel, delayed

In [65]:
def parallel_join_channels(image_groups, output_dir):
    """Build one channel-stacked image per site in parallel and save its RGB rendering.

    Args:
        image_groups: iterable of site groups. Each group is a list of row dicts (as made by
            ``DataFrame.to_dict('records')``) with at least ``filename``, ``channel``,
            ``id_code``, ``site`` and ``sirna`` keys; presumably one group per
            (id_code, site), as built in the caller cells.
        output_dir: destination folder, created if it does not exist.

    Returns:
        List of the written file paths, one per group, in the order of `image_groups`.
    """
    os.makedirs(output_dir, exist_ok=True)

    def worker(channel_group, dest_dir):
        # Assemble a (512, 512, 6) uint8 stack; channel numbers are 1-based.
        stack = np.zeros((512, 512, 6), dtype=np.uint8)
        for row in channel_group:
            stack[:, :, row['channel'] - 1] = np.asarray(imread(row['filename']))

        # Rows of one group share id/site/label, so the last row supplies them.
        # Unlabelled sites (NaN sirna, e.g. test) are written with label 0.
        # NOTE(review): if 0 is also a real sirna class, the filename alone cannot tell the two apart.
        label = row['sirna']
        y = int(label) if not pd.isna(label) else 0
        output_path = os.path.join(dest_dir, f"{row['id_code']}_s{row['site']}_{y}.png")

        rgb = rio.convert_tensor_to_rgb(stack)
        PIL.Image.fromarray(rgb.astype(np.uint8)).save(output_path)
        return output_path

    with Parallel(n_jobs=cpu_count()) as pool:
        return pool(delayed(worker)(group, output_dir) for group in tqdm(image_groups))

In [55]:
# Index every training PNG and attach its site metadata (id_code, sirna, ...).
# TRAIN_DIR follows ROOT (Kaggle or local) instead of a hard-coded home path.
trn_records = collect_records(TRAIN_DIR)
# validate: each (experiment, plate, well, site) must match at most one metadata row,
# otherwise channel rows would be silently duplicated.
trn_info = pd.merge(trn_records, meta_train, on=['experiment', 'plate', 'well', 'site'],
                    validate='many_to_one')
# One group per imaged site: the list of its per-channel rows.
trn_groups = [group.to_dict('records') for _, group in trn_info.groupby(['id_code', 'site'])]

In [57]:
# Write one joined RGB image per training site. The folder is derived from ROOT instead of a
# hard-coded home path (same location locally).
# NOTE(review): on Kaggle ROOT points at the input mount, which may be read-only — confirm.
trn_paths = parallel_join_channels(trn_groups, ROOT/'tmp'/'train')

HBox(children=(IntProgress(value=0, max=81224), HTML(value='')))

In [59]:
# Index every test PNG and attach its site metadata (sirna is missing for treatment wells).
# TEST_DIR follows ROOT (Kaggle or local) instead of a hard-coded home path.
tst_records = collect_records(TEST_DIR)
# validate: each (experiment, plate, well, site) must match at most one metadata row,
# otherwise channel rows would be silently duplicated.
tst_info = pd.merge(tst_records, meta_test, on=['experiment', 'plate', 'well', 'site'],
                    validate='many_to_one')
# One group per imaged site: the list of its per-channel rows.
tst_groups = [group.to_dict('records') for _, group in tst_info.groupby(['id_code', 'site'])]

In [66]:
# Write one joined RGB image per test site. The folder is derived from ROOT instead of a
# hard-coded home path (same location locally).
# NOTE(review): on Kaggle ROOT points at the input mount, which may be read-only — confirm.
tst_paths = parallel_join_channels(tst_groups, ROOT/'tmp'/'test')

HBox(children=(IntProgress(value=0, max=44286), HTML(value='')))