In [10]:
#export
import glob
import json
import os
from pdb import set_trace
import pickle
import sys

import pandas as pd
from tqdm import tqdm_notebook as tqdm

try:
    extended
except NameError:
    sys.path.insert(0, 'rxrx1-utils')
    import rxrx.io as rio
    
from basedir import ROOT, TRAIN, TEST, SAMPLE

In [2]:
#export
def _trailing_int(name):
    """Return the integer formed by the trailing digits of `name` (e.g. 'Plate12' -> 12).

    Raises ValueError if `name` does not end with at least one digit.
    """
    stem = name.rstrip('0123456789')
    digits = name[len(stem):]
    if not digits:
        raise ValueError(f'no trailing number in folder name {name!r}')
    return int(digits)


def collect_records(basedir):
    """Glob the folder with images and build a data frame of image paths plus meta-information.

    Every image is expected at ``<basedir>/<experiment>/<plate>/<well>_<site>_<channel>.png``
    (for example ``HEPG2-01/Plate1/B02_s1_w1.png``). The plate number is the trailing
    integer of the plate folder name; site and channel numbers are the integers that
    follow their one-letter prefix (``s1`` -> 1, ``w6`` -> 6).

    Parameters
    ----------
    basedir : str
        Root folder that is searched recursively for ``*.png`` files.

    Returns
    -------
    pd.DataFrame
        One row per PNG file with columns ``site`` (int), ``channel`` (int),
        ``filename`` (path as returned by glob) and ``id_code``
        (``<experiment>_<plate>_<well>``). Empty if no PNG files are found.

    Raises
    ------
    ValueError
        If a file is not exactly three levels below `basedir`, or its name or
        plate folder does not follow the layout above.
    """
    columns = ['experiment', 'plate', 'well', 'site', 'channel', 'filename']
    rows = []
    for path in glob.glob(os.path.join(basedir, '**', '*.png'), recursive=True):
        # relpath returns OS-native separators, so split on os.sep rather than '/'.
        rel_parts = os.path.relpath(path, start=basedir).split(os.sep)
        if len(rel_parts) != 3:
            raise ValueError(f'unexpected image location (want experiment/plate/file): {path}')
        exp, plate, filename = rel_parts
        basename, _ = os.path.splitext(filename)
        well, site, channel = basename.split('_')
        # Use the whole trailing number: plate[-1] would turn 'Plate10' into 0.
        rows.append([exp, _trailing_int(plate), well, int(site[1:]), int(channel[1:]), path])
    records = pd.DataFrame(rows, columns=columns)
    # Vectorised concatenation instead of a row-wise apply (hundreds of thousands of rows).
    records['id_code'] = (
        records['experiment'] + '_' + records['plate'].astype(str) + '_' + records['well'])
    return records.drop(columns=['experiment', 'plate', 'well'])

In [3]:
#export
def build_files_index():
    """Index every train and test image file and attach the dataset metadata.

    Image records are collected from the TRAIN and TEST folders (in that order),
    tagged with a ``dataset`` column ('train' / 'test'), and left-joined with the
    output of ``rio.combine_metadata`` on (id_code, site, dataset).

    Returns
    -------
    pd.DataFrame
        One row per image file, with the join keys restored as ordinary columns.
    """
    join_keys = ['id_code', 'site', 'dataset']
    folders = {'train': TRAIN, 'test': TEST}
    per_subset = [collect_records(folder).assign(dataset=subset)
                  for subset, folder in folders.items()]
    files = pd.concat(per_subset, axis='rows').set_index(join_keys)
    # NOTE(review): assumes combine_metadata exposes id_code/site/dataset once
    # its index is reset — confirm against rxrx1-utils.
    metadata = rio.combine_metadata(base_path=ROOT).reset_index().set_index(join_keys)
    return files.join(metadata).reset_index()

In [4]:
#export
def generate_samples(files_df, progress=None):
    """Collapse per-channel image rows into one sample dict per (id_code, site, dataset).

    Parameters
    ----------
    files_df : pd.DataFrame
        One row per image file; must contain the columns ``id_code``, ``site``,
        ``dataset``, ``channel``, ``filename``, ``sirna``, ``cell_type``,
        ``experiment``, ``well_type`` and ``plate``.
    progress : callable, optional
        Wraps the group iterator to report progress. Defaults to the module-level
        ``tqdm``; pass e.g. ``lambda it: it`` to run without a progress bar.

    Returns
    -------
    list of dict
        One dict per group, in sorted group-key order, with keys ``images``
        (list of ``(channel, filename)`` tuples ordered by channel), ``sirna``
        (int; 0 when the value is missing), ``site``, ``cell_type``,
        ``experiment``, ``well_type``, ``plate`` and ``subset`` (the ``dataset`` value).
    """
    if progress is None:
        progress = tqdm
    # Sort once up front instead of once per group: groupby preserves row order
    # within each group, so every group is already ordered by channel.
    by_channel = files_df.sort_values(by='channel', kind='mergesort')
    samples = []
    for _, group in progress(by_channel.groupby(['id_code', 'site', 'dataset'])):
        images = list(zip(group['channel'], group['filename']))
        # Sample-level fields are read from the first row only (assumes they are
        # identical across a group's channels). Converting just that row with
        # to_dict keeps the values as plain Python scalars, which json.dump needs.
        first = group.head(1).to_dict(orient='records')[0]
        sirna = first['sirna']
        samples.append(dict(
            images=images,
            sirna=0 if pd.isna(sirna) else int(sirna),
            site=first['site'], cell_type=first['cell_type'],
            experiment=first['experiment'], well_type=first['well_type'],
            plate=first['plate'], subset=first['dataset']))
    return samples

In [5]:
#export
def train_test(items, progress=None):
    """Split sample dicts into train and test lists by their ``subset`` key.

    The input dicts are not modified. Each returned item is a shallow copy with
    ``subset`` removed, so re-running this on the same `items` gives the same result.

    Parameters
    ----------
    items : iterable of dict
        Samples that each carry a ``subset`` key.
    progress : callable, optional
        Wraps the iterator to report progress. Defaults to the module-level
        ``tqdm``; pass e.g. ``lambda it: it`` to run without a progress bar.

    Returns
    -------
    (list of dict, list of dict)
        ``(train, test)``. Items whose subset is exactly ``'train'`` go to train;
        every other item goes to test.

    Raises
    ------
    KeyError
        If an item has no ``subset`` key.
    """
    if progress is None:
        progress = tqdm
    train, test = [], []
    for item in progress(items):
        # Pop from a copy: popping from `item` itself would strip 'subset' from the
        # caller's data and make a second call fail with KeyError.
        sample = dict(item)
        subset = sample.pop('subset')
        (train if subset == 'train' else test).append(sample)
    return train, test

In [6]:
#export
# One row per train/test image file, joined with the metadata from rio.combine_metadata.
files_index = build_files_index()

In [7]:
#export
# Group the per-channel file rows into one sample dict per (id_code, site, dataset).
dataset = generate_samples(files_df=files_index)

HBox(children=(IntProgress(value=0, max=125510), HTML(value='')))




In [8]:
#export
# Route each sample to the train or test list according to its 'subset' key.
train, test = train_test(items=dataset)

HBox(children=(IntProgress(value=0, max=125510), HTML(value='')))




In [11]:
# Persist each split as a JSON list of sample dicts (train first, then test).
for json_path, split_items in (('train.json', train), ('test.json', test)):
    with open(json_path, 'w') as f:
        json.dump(split_items, f)