In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np

import sys
import os

sys.path.append(os.environ['REPO_DIR'] + '/utilities')
from utilities2015 import *

import matplotlib.pyplot as plt
%matplotlib inline

import pandas as pd
import mxnet as mx

from joblib import Parallel, delayed
import time

# Shared progress widget reused by the section loops below; each loop resets
# its min/max to the stack's section range before displaying it.
# NOTE(review): FloatProgress is not imported explicitly in this cell —
# presumably it arrives via `from utilities2015 import *`; verify.
progress_bar = FloatProgress(min=0, max=np.inf)

In [2]:
# Roots for the per-stack patch-index tables (read as HDF5 below) and for the
# saved MXNet checkpoints / mean images.
patches_rootdir, model_dir = ('/home/yuncong/CSHL_data_patches/',
                              '/home/yuncong/mxnet_models/')

In [3]:
# Class names used by the classifier. Index 0 ('BackG', presumably background)
# has no surround variant; every other name X also gets an 'X_surround' class.
labels = ['BackG', '5N', '7n', '7N', '12N', 'Pn', 'VLL',
          '6N', 'Amb', 'R', 'Tz', 'RtTg', 'LRt', 'LC', 'AP', 'sp5']

labels_index = {name: idx for idx, name in enumerate(labels)}

structures = labels[1:]
surround_names = [name + '_surround' for name in structures]

# 'X_surround' -> 'X'
labels_from_surround = dict(zip(surround_names, structures))

# All non-background names first, then their surround names, in matching order.
labels_surroundIncluded_list = structures + surround_names
labels_surroundIncluded = set(labels_surroundIncluded_list)
labels_surroundIncluded_index = {name: idx for idx, name in enumerate(labels_surroundIncluded_list)}

del structures, surround_names

# One random RGB triple per class (unseeded: differs from run to run).
colors = np.random.randint(0, 255, (len(labels_index), 3))

In [4]:
# Mean image subtracted from every patch before inference (see the feature
# loop); presumably the mean of the saturation training patches.
# (An older variant loaded 'mean_224.nd' via mx.nd.load instead.)
mean_img_path = model_dir + '/saturation_mean_224.npy'
mean_img = np.load(mean_img_path)

In [5]:
model_name = 'Sat16ClassFinetuned'
model_iteration = 10

# Load the fine-tuned checkpoint, then rebuild a predictor whose output is the
# internal 'flatten_output' layer, so predict() returns that layer's activations
# instead of the network's final output.
trained_model = mx.model.FeedForward.load(os.path.join(model_dir, model_name),
                                          model_iteration, ctx=mx.gpu())
flatten_output = trained_model.symbol.get_internals()['flatten_output']

model = mx.model.FeedForward(
    ctx=mx.gpu(),
    symbol=flatten_output,
    num_epoch=model_iteration,  # only relevant for training; unused by predict()
    arg_params=trained_model.arg_params,
    aux_params=trained_model.aux_params,
    # the truncated symbol needs only a subset of the checkpoint's parameters
    allow_extra_params=True
)
del trained_model

In [6]:
# One output root per model name, so features from different models never mix.
test_features_rootdir = '/home/yuncong/CSHL_patch_features_%(model_name)s' % dict(model_name=model_name)
create_if_not_exists(test_features_rootdir)

'/home/yuncong/CSHL_patch_features_Sat16ClassFinetuned'

In [7]:
# Root for saturation images, laid out as
# <sat_rootdir>/<stack>_saturation/<stack>_<sec:04d>_sat.jpg by the cells below.
sat_rootdir = '/home/yuncong/CSHL_data_saturation/'
create_if_not_exists(sat_rootdir)

'/home/yuncong/CSHL_data_saturation/'

In [8]:
def convert_to_saturation(fn, out_fn, rescale=True):
    
    img = imread(fn)
    
    m = img/255.
    ma = m.max(axis=-1)
    mi = m.min(axis=-1)
    s = (ma-mi)/ma
    s = 1-s
    
    if rescale:
        pmax = s.max()
        pmin = s.min()
        s = (s - pmin) / (pmax - pmin)
    
    sat = (s*255).astype(np.uint8)    
    cv2.imwrite(out_fn, sat)
    
    del m, ma, mi, s, img, sat

In [15]:
# Convert the 'rgb-jpg' image of every detection section of each stack into a
# saturation image, using 4 parallel joblib workers per stack.
for stack in ['MD602', 'MD592', 'MD585', 'MD590', 'MD591', 'MD595', 'MD598']:
    dm = DataManager(stack=stack, data_dir='/media/yuncong/BstemAtlasData/CSHL_data_processed')

    sat_dir = os.path.join(sat_rootdir, '%(stack)s_saturation' % {'stack': stack})
    create_if_not_exists(sat_dir)

    first_detect_sec, last_detect_sec = detect_bbox_range_lookup[stack]
    sections = range(first_detect_sec, last_detect_sec + 1)

    # One delayed call per section: (input RGB path, output saturation path).
    jobs = []
    for sec in sections:
        in_fn = dm._get_image_filepath(section=sec, version='rgb-jpg')
        out_fn = sat_dir + '/%(stack)s_%(sec)04d_sat.jpg' % {'stack': stack, 'sec': sec}
        jobs.append(delayed(convert_to_saturation)(in_fn, out_fn))

    t = time.time()
    Parallel(n_jobs=4)(jobs)
    sys.stderr.write('convert to saturation: %.2f seconds\n' % (time.time() - t)) # ~2500s

convert to saturation: 2303.44 seconds
convert to saturation: 1654.70 seconds
convert to saturation: 1149.62 seconds
convert to saturation: 1929.48 seconds
convert to saturation: 2364.96 seconds
convert to saturation: 2462.03 seconds
convert to saturation: 1822.06 seconds


In [None]:
# Extract CNN features for every ROI patch of every detection section.
# For each section this writes, under <test_features_rootdir>/<stack>/<sec:04d>/:
#   <stack>_<sec>_roi1_labels.npy   - label index per ROI patch (-1 = unlabelled)
#   <stack>_<sec>_roi1_features.hdf - 'flatten_output' activations per ROI patch
ANNOTATED_STACKS = ('MD589', 'MD594')
BATCH_SIZE = 256 # increasing to 500 does not save any time

for stack in ['MD593', 'MD602', 'MD592', 'MD585', 'MD590', 'MD591', 'MD595', 'MD598']:

    stack_has_annotation = stack in ANNOTATED_STACKS

    table_filepath = os.path.join(patches_rootdir, '%(stack)s_indices_allROIs_allSections.h5'%{'stack':stack})
    indices_allROIs_allSections = pd.read_hdf(table_filepath, 'indices_allROIs_allSections')
    grid_parameters = pd.read_hdf(table_filepath, 'grid_parameters')

    patch_size, stride, w, h = grid_parameters.tolist()
    # Floor division: identical to Python 2 int '/', and keeps slice bounds
    # integral under Python 3.
    half_size = patch_size // 2

    # meshgrid(h-range, w-range, indexing='xy') gives arrays of shape
    # (len(w-range), len(h-range)); flattening therefore enumerates x in the
    # outer loop and y in the inner loop. indices_roi below indexes into this
    # exact ordering, so do not change it.
    ys, xs = np.meshgrid(np.arange(half_size, h-half_size, stride), np.arange(half_size, w-half_size, stride),
                         indexing='xy')
    sample_locations = np.c_[xs.flat, ys.flat]
    n_grid = len(sample_locations)

    if stack_has_annotation:
        table_filepath = os.path.join(patches_rootdir, '%(stack)s_indices_allLandmarks_allSections.h5'%{'stack':stack})
        indices_allLandmarks_allSections = pd.read_hdf(table_filepath, 'indices_allLandmarks_allSections')

    first_detect_sec, last_detect_sec = detect_bbox_range_lookup[stack]

    progress_bar.min = first_detect_sec
    progress_bar.max = last_detect_sec
    display(progress_bar)

    for sec in range(first_detect_sec, last_detect_sec+1):

        if sec not in indices_allROIs_allSections.columns:
            continue

        progress_bar.value = sec

        indices_roi = indices_allROIs_allSections[sec]['roi1']

        n = len(indices_roi)
        print('%d roi samples' % n)

        sec_dir = test_features_rootdir + '/%(stack)s/%(sec)04d' % {'stack': stack, 'sec': sec}
        out_prefix = sec_dir + '/%(stack)s_%(sec)04d_roi1' % {'stack': stack, 'sec': sec}

        ######################
        t = time.time()

        # Label per grid point (-1 = unlabelled). Sized to the grid instead of a
        # fixed 99999 so large grids cannot overflow it; builtin int because
        # np.int was removed in NumPy 1.24.
        true_labels = np.full(n_grid, -1, dtype=int)
        if stack_has_annotation and sec in indices_allLandmarks_allSections:
            landmarks_sec = indices_allLandmarks_allSections[sec]
            # Order-preserving filter instead of `Index & set`, which is a
            # deprecated set operation in pandas. It also keeps the overwrite order
            # deterministic when a grid point appears in several label lists.
            known = [l for l in landmarks_sec.dropna().keys() if l in labels_surroundIncluded]
            for l in known:
                true_labels[landmarks_sec[l]] = labels_surroundIncluded_index[l]
        patch_labels = true_labels[indices_roi]

        create_if_not_exists(sec_dir)
        np.save(out_prefix + '_labels.npy', patch_labels)

        sys.stderr.write('get true labels: %.2f seconds\n' % (time.time() - t)) # ~ 0s

        if n == 0:
            # Nothing to extract; patches[:, None, :, :] below would fail on an empty array.
            continue

        ######################

        sample_locations_roi = sample_locations[indices_roi]

        t = time.time()

        sat = imread(sat_rootdir + '/%(stack)s_saturation/%(stack)s_%(sec)04d_sat.jpg' % {'stack': stack, 'sec': sec})

        sys.stderr.write('load saturation image: %.2f seconds\n' % (time.time() - t)) # ~ 2s

        t = time.time()

        patches = np.array([sat[y-half_size:y+half_size, x-half_size:x+half_size]
                            for x, y in sample_locations_roi]) # n x 224 x 224
        patches = patches - mean_img
        patches = patches[:, None, :, :] # n x 1 x 224 x 224 (add a singleton channel axis)

        sys.stderr.write('extract, reshape, normalize: %.2f seconds\n' % (time.time() - t)) # ~ 6s

        data_iter = mx.io.NDArrayIter(
            patches,
            np.zeros((n, ), int), # labels are not important since it is just feed-forward
            batch_size = BATCH_SIZE,
            shuffle=False
        )

        t = time.time()

        features = model.predict(data_iter)

        sys.stderr.write('predict: %.2f seconds\n' % (time.time() - t)) # ~40s

        t = time.time()

        save_hdf(features, out_prefix + '_features.hdf')

        sys.stderr.write('save: %.2f seconds\n' % (time.time() - t)) # ~.5s

        del sat, patches, sample_locations_roi, features

    del sample_locations


get true labels: 0.00 seconds
load saturation image: 1.66 seconds


In [None]:
# Serial (single-process) variant of the saturation conversion, kept for spot
# checks: it converts only every SERIAL_SECTION_STEP-th section of one stack.
# `stack` previously leaked from whichever loop ran last ('MD598' after a full
# run); it is set explicitly so the cell also works on a fresh kernel.
stack = 'MD598'
SERIAL_SECTION_STEP = 10

dm = DataManager(stack=stack)

sat_dir = os.path.join(sat_rootdir, '%(stack)s_saturation' % {'stack': stack})
create_if_not_exists(sat_dir)

first_detect_sec, last_detect_sec = detect_bbox_range_lookup[stack]

progress_bar.min = first_detect_sec
progress_bar.max = last_detect_sec
display(progress_bar)

for sec in range(first_detect_sec, last_detect_sec+1, SERIAL_SECTION_STEP):

    progress_bar.value = sec

    dm.set_slice(sec)
    dm._load_image(['rgb-jpg'])

    m = dm.image_rgb_jpg/255.
    ma = m.max(axis=-1)
    mi = m.min(axis=-1)
    # Black pixels (max == 0) get saturation 0 instead of NaN; a single NaN
    # would poison the min/max rescale below and corrupt the whole image.
    with np.errstate(divide='ignore', invalid='ignore'):
        s = np.where(ma > 0, (ma-mi)/ma, 0.)
    s = 1-s

    pmax = s.max()
    pmin = s.min()
    if pmax > pmin:  # constant image: nothing to stretch, avoid 0/0
        s = (s - pmin) / (pmax - pmin)

    sat = (s*255).astype(np.uint8)

    cv2.imwrite(sat_dir + '/%(stack)s_%(sec)04d_sat.jpg' % {'stack': stack, 'sec': sec}, sat)