In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np

import sys
import os

sys.path.append(os.environ['REPO_DIR'] + '/utilities')
from utilities2015 import *

import matplotlib.pyplot as plt
%matplotlib inline

import pandas as pd
import mxnet as mx

from joblib import Parallel, delayed
import time



In [2]:
# Root directories for the extracted patch tables and for the trained MXNet models.
# Each can be overridden through an environment variable so the notebook is not tied
# to a single machine; the defaults are the original hard-coded locations.
patches_rootdir = os.environ.get('CSHL_PATCHES_ROOTDIR', '/home/yuncong/CSHL_data_patches/')
model_dir = os.environ.get('MXNET_MODEL_DIR', '/home/yuncong/mxnet_models/')

In [3]:
# Ordered label names; a label's position in this list is its integer index.
labels = [
    'BackG', '5N', '7n', '7N', '12N', 'Pn', 'VLL', '6N',
    'Amb', 'R', 'Tz', 'RtTg', 'LRt', 'LC', 'AP', 'sp5',
]

# Reverse lookup: label name -> position in `labels`.
label_dict = {name: index for index, name in enumerate(labels)}

In [4]:
# Read the array stored under the 'mean_img' key of mean_224.nd and convert it to NumPy.
mean_img_path = os.path.join(model_dir, 'mean_224.nd')
mean_img = mx.nd.load(mean_img_path)['mean_img'].asnumpy()

In [5]:
# Progress widget reused by the extraction loop, which overwrites these bounds with the
# section range before displaying the widget.
progress_bar = FloatProgress(max=float('inf'), min=0)

In [6]:
from collections import defaultdict

In [47]:
# Collect mean-subtracted image patches for every annotated label.
# patches_allClasses maps label -> list of patch arrays, one array per (section, label) pair.
patches_allClasses = defaultdict(list)

# Upper bound on the number of patches drawn for one label on one section.
num_sample_each_polygon = 100

# Dedicated, seeded generator: makes the random subsampling below reproducible without
# touching NumPy's global random state.
sampling_rng = np.random.RandomState(0)

for stack in ['MD589']:

    dm = DataManager(stack=stack)
    stack_has_annotation = True

    if not stack_has_annotation:
        # Both tables read below come from the annotation file; without it there is nothing to sample.
        sys.stderr.write('no annotation for stack: %s\n' % stack)
        continue

    table_filepath = os.path.join(patches_rootdir, '%(stack)s_indices_allLandmarks_allSections.h5'%{'stack':stack})
    # Table indexed by label with one column per section; each cell holds indices into sample_locations.
    indices_allLandmarks_allSections = pd.read_hdf(table_filepath, 'indices_allLandmarks_allSections')
    grid_parameters = pd.read_hdf(table_filepath, 'grid_parameters')

    patch_size, stride, w, h = grid_parameters.tolist()
    # Floor division reproduces the Python 2 integer-division result under Python 3 as well.
    half_size = patch_size // 2

    # Regular grid of patch centres: x in [half_size, w-half_size), y in [half_size, h-half_size),
    # spaced by `stride`. Each row of sample_locations is one (x, y) centre.
    ys, xs = np.meshgrid(np.arange(half_size, h-half_size, stride), np.arange(half_size, w-half_size, stride),
                     indexing='xy')
    sample_locations = np.c_[xs.flat, ys.flat]

    first_detect_sec, last_detect_sec = detect_bbox_range_lookup[stack]

    progress_bar.min = first_detect_sec
    progress_bar.max = last_detect_sec
    display(progress_bar)

    for sec in range(first_detect_sec, last_detect_sec+1):

        if sec not in indices_allLandmarks_allSections.columns:
            continue

        progress_bar.value = sec

        # Labels annotated on this section (rows with no entry are dropped).
        indices_by_label = indices_allLandmarks_allSections[sec].dropna()
        if len(indices_by_label.index) == 0:
            continue

        # Load the section image only once we know there is something to sample from it.
        dm.set_slice(sec)
        dm._load_image(['rgb-jpg'])

        for label in indices_by_label.index:
            if label == 'bg':
                continue

            print(label)
            indices_roi = indices_by_label[label]

            n2 = len(indices_roi)
            print('%d samples' % n2)

            if n2 == 0:
                # An empty selection cannot be stacked into the 4-D array that np.rollaxis expects.
                continue

            # Cap the number of patches per label and section by sampling without replacement.
            indices_roi = sampling_rng.choice(indices_roi, min(num_sample_each_polygon, n2), replace=False)

            n = len(indices_roi)
            print('%d used samples' % n)

            sample_locations_roi = sample_locations[indices_roi]

            # Crop a patch_size x patch_size window centred on each sampled (x, y) location.
            patches2 = np.asarray([dm.image_rgb_jpg[y-half_size:y+half_size, x-half_size:x+half_size]
                                  for x, y in sample_locations_roi])

            # Move the last axis to position 1
            # (presumably (N, H, W, C) -> (N, C, H, W) to match mean_img -- TODO confirm).
            patches = np.rollaxis(patches2, 3, 1)
            patches_allClasses[label].append(patches - mean_img)

            del patches, patches2, sample_locations_roi

    del sample_locations

5N
121 samples
100 used samples
5N_surround
612 samples
100 used samples
7n
48 samples
48 used samples
7n_surround
515 samples
100 used samples
VLL
83 samples
83 used samples
VLL_surround
535 samples
100 used samples


In [None]:
# Directory that receives the per-label feature files; create it if it is not there yet.
training_features_dir = '/home/yuncong/CSHL_patch_features/train'
training_features_dir_exists = os.path.exists(training_features_dir)
if not training_features_dir_exists:
    os.makedirs(training_features_dir)

In [None]:
# Run the network over the patches of every label and write one feature file per label.
# NOTE(review): `model` is created by the model-loading cell that appears further down in this
# notebook, and `stack` is left over from the patch-extraction loop; both must have been run first.

# Preferred number of patches per forward pass.
max_batch_size = 512 # increasing to 892 does not save any time

for label, patches1 in patches_allClasses.items():

    # Merge the per-section arrays of this label into one array.
    patches = np.concatenate(patches1)

    n = len(patches)

    # Labels with fewer patches than the preferred batch size used to be skipped with only a
    # stderr message, which left them without a feature file; shrink the batch instead.
    batch_size = min(max_batch_size, n)

    # The iterator is given an all-zero label array alongside the patches.
    # Built-in `int` replaces the deprecated `np.int` alias (same dtype).
    train_iter = mx.io.NDArrayIter(
        patches,
        np.zeros((n, ), int),
        batch_size = batch_size,
        shuffle=False
    )

    t = time.time()

    features = model.predict(train_iter)

    bp.pack_ndarray_file(features, training_features_dir + '/%(stack)s_%(label)s_features.bp'% {'stack': stack,
                                                                                               'label': label})

    sys.stderr.write('predict: %.2f seconds\n' % (time.time() - t))

In [None]:
# Checkpoint to restore, identified by a name under model_dir and an iteration number.
model_name = 'experiment0317'
model_iteration = 6

checkpoint_prefix = os.path.join(model_dir, model_name)
model = mx.model.FeedForward.load(checkpoint_prefix, model_iteration, ctx=mx.gpu())

# Pick the internal 'flatten_output' symbol of the loaded network as the new output.
# (Commented-out experiments in earlier revisions of this cell referenced 'fc_output',
# 'fullc_output' and 'softmax_output' instead.)
flatten_output = model.symbol.get_internals()['flatten_output']

# Rebind `model` to a network that ends at flatten_output, reusing the loaded parameters.
# allow_extra_params=True is presumably needed because the loaded parameters also cover
# layers beyond the new output -- TODO confirm.
feature_extractor_kwargs = dict(
    ctx=mx.gpu(),
    symbol=flatten_output,
    num_epoch=model_iteration,
    arg_params=model.arg_params,
    aux_params=model.aux_params,
    allow_extra_params=True,
)
model = mx.model.FeedForward(**feature_extractor_kwargs)

In [None]:
# Extract and store features for a hand-picked subset of labels, using a smaller batch size.
# NOTE(review): relies on `model`, `stack`, `training_features_dir` and `patches_allClasses`
# left in the kernel by other cells.

# Preferred number of patches per forward pass.
max_batch_size = 128 # increasing to 892 does not save any time

for label in ['6N', '6N_surround']:

    # Membership test instead of plain indexing: patches_allClasses is a defaultdict, so indexing
    # a missing label would insert an empty list and np.concatenate would then raise on it.
    if label not in patches_allClasses:
        sys.stderr.write('no patches collected for label: %s\n' % label)
        continue

    # Merge the per-section arrays of this label into one array.
    patches = np.concatenate(patches_allClasses[label])

    n = len(patches)

    # Never ask the iterator for a batch larger than the data set.
    batch_size = min(max_batch_size, n)

    # The iterator is given an all-zero label array alongside the patches.
    # Built-in `int` replaces the deprecated `np.int` alias (same dtype).
    train_iter = mx.io.NDArrayIter(
        patches,
        np.zeros((n, ), int),
        batch_size = batch_size,
        shuffle=False
    )

    t = time.time()

    features = model.predict(train_iter)

    bp.pack_ndarray_file(features, training_features_dir + '/%(stack)s_%(label)s_features.bp'% {'stack': stack,
                                                                                               'label': label})

    sys.stderr.write('predict: %.2f seconds\n' % (time.time() - t))

In [None]:
# Number of patches collected per label.
# patches_allClasses stores a list of per-section arrays for each label, so the patch count is the
# sum of the array lengths. (This cell used to read `patches_allClasses_arr`, which no visible
# cell of this notebook defines.)
[(label, sum(len(section_patches) for section_patches in patch_arrays))
 for label, patch_arrays in patches_allClasses.items()]