In [72]:
def get_example(load_example, eval_tracker, model, get_offsets):
  """Generates individual training examples.

  Args:
    load_example: callable returning a tuple of image and label ndarrays
                  as well as the seed coordinate and volume name of the example
    eval_tracker: EvalTracker object
    model: FFNModel object
    get_offsets: iterable of (x, y, z) offsets to investigate within the
        training patch

  Yields:
    tuple of:
      seed array, shape [1, z, y, x, 1]
      image array, shape [1, z, y, x, 1]
      label array, shape [1, z, y, x, 1]
  """
  seed_shape = train_canvas_size(model).tolist()[::-1]

  while True:
    full_patches, full_labels, loss_weights, coord, volname = load_example()
    # Always start with a clean seed.
    seed = logit(mask.make_seed(seed_shape, 1, pad=FLAGS.seed_pad))

    for off in get_offsets(model, seed):
      predicted = mask.crop_and_pad(seed, off, model.input_seed_size[::-1])
      patches = mask.crop_and_pad(full_patches, off, model.input_image_size[::-1])
      labels = mask.crop_and_pad(full_labels, off, model.pred_mask_size[::-1])
      weights = mask.crop_and_pad(loss_weights, off, model.pred_mask_size[::-1])

      # Necessary, since the caller is going to update the array and these
      # changes need to be visible in the following iterations.
      assert predicted.base is seed
      yield predicted, patches, labels, weights

    eval_tracker.add_patch(
        full_labels, seed, loss_weights, coord, volname, full_patches)
    break

In [73]:
def get_batch(load_example, eval_tracker, model, batch_size, get_offsets):
  """Generates batches of training examples.

  Args:
    load_example: callable returning a tuple of image and label ndarrays
                  as well as the seed coordinate and volume name of the example
    eval_tracker: EvalTracker object
    model: FFNModel object
    batch_size: desired batch size
    get_offsets: iterable of (x, y, z) offsets to investigate within the
        training patch

  Yields:
    tuple of:
      seed array, shape [b, z, y, x, 1]
      image array, shape [b, z, y, x, 1]
      label array, shape [b, z, y, x, 1]

    where 'b' is the batch_size.
  """
  def _batch(iterable):
    for batch_vals in iterable:
      # `batch_vals` is sequence of `batch_size` tuples returned by the
      # `get_example` generator, to which we apply the following transformation:
      #   [(a0, b0), (a1, b1), .. (an, bn)] -> [(a0, a1, .., an),
      #                                         (b0, b1, .., bn)]
      # (where n is the batch size) to get a sequence, each element of which
      # represents a batch of values of a given type (e.g., seed, image, etc.)
      yield zip(*batch_vals)

  # Create a separate generator for every element in the batch. This generator
  # will automatically advance to a different training example once the allowed
  # moves for the current location are exhausted.
  for seeds, patches, labels, weights in _batch(six.moves.zip(
      *[get_example(load_example, eval_tracker, model, get_offsets) for _
        in range(batch_size)])):

    batched_seeds = np.concatenate(seeds)

    yield (batched_seeds, np.concatenate(patches), np.concatenate(labels),
           np.concatenate(weights))

    # batched_seed is updated in place with new predictions by the code
    # calling get_batch. Here we distribute these updated predictions back
    # to the buffer of every generator.
    for i in range(batch_size):
      seeds[i][:] = batched_seeds[i, ...]

In [141]:
import six
import numpy as np


In [142]:
def _batch(iterable):
    for batch_vals in iterable:
        yield zip(*batch_vals)
        

In [181]:
def get_example():
    while True:
        for i in range(27): # offset duplicates
#             print(i, "it is i")
            predicted = np.full((1, 49, 49, 49, 1), 0)
            patches = np.full((1, 49, 49, 49, 1), 1)
            labels = np.full((1, 49, 49, 49, 1), 2)
            yield predicted, patches, labels
    

In [162]:
print(next(get_example())[0].shape) # seed array
print(next(get_example())[1].shape) # image array
print(next(get_example())[2].shape) # label array

(1, 49, 49, 49, 1)
(1, 49, 49, 49, 1)
(1, 49, 49, 49, 1)


In [163]:
six.moves.zip(*[get_example() for _ in range(batch_size)])

<itertools.izip at 0x7fbb9d05a908>

In [164]:
_batch(six.moves.zip(*[get_example() for _ in range(batch_size)]))

<generator object _batch at 0x7fbb9cc2b280>

In [227]:
seeds, patches, labels = next(_batch(six.moves.zip(*[get_example() for _ in range(batch_size)])))

len(seeds)


8

In [222]:
def get_batch():
    batch_size = 8
    step = 0

    for seeds, patches, labels in _batch(six.moves.zip(*[get_example() for _ in range(batch_size)])): 

        batched_seeds = np.concatenate(seeds)
        print("seeds shape", len(seeds))

        yield (batched_seeds, np.concatenate(patches), np.concatenate(labels))
        
        for i in range(batch_size):
            print("batched_seeds[i, ...]", i,batched_seeds[i, ...].shape)
            seeds[i][:] = batched_seeds[i, ...]
            

In [226]:
seed, patches, labels = next(get_batch())
    

('seeds shape', 8)


In [224]:
print(seed.shape)
print(patches.shape)
print(labels.shape)


(8, 49, 49, 49, 1)
(8, 49, 49, 49, 1)
(8, 49, 49, 49, 1)


In [180]:
tmp = []
for i, j, k in get_batch():
    print(i.shape, j.shape, k.shape)
#     tmp.append([i.shape, j.shape, k.shape])

('seeds shape', 8)
((8, 49, 49, 49, 1), (8, 49, 49, 49, 1), (8, 49, 49, 49, 1))
('batched_seeds[i, ...]', 0, (49, 49, 49, 1))
('batched_seeds[i, ...]', 1, (49, 49, 49, 1))
('batched_seeds[i, ...]', 2, (49, 49, 49, 1))
('batched_seeds[i, ...]', 3, (49, 49, 49, 1))
('batched_seeds[i, ...]', 4, (49, 49, 49, 1))
('batched_seeds[i, ...]', 5, (49, 49, 49, 1))
('batched_seeds[i, ...]', 6, (49, 49, 49, 1))
('batched_seeds[i, ...]', 7, (49, 49, 49, 1))
('seeds shape', 8)
((8, 49, 49, 49, 1), (8, 49, 49, 49, 1), (8, 49, 49, 49, 1))
('batched_seeds[i, ...]', 0, (49, 49, 49, 1))
('batched_seeds[i, ...]', 1, (49, 49, 49, 1))
('batched_seeds[i, ...]', 2, (49, 49, 49, 1))
('batched_seeds[i, ...]', 3, (49, 49, 49, 1))
('batched_seeds[i, ...]', 4, (49, 49, 49, 1))
('batched_seeds[i, ...]', 5, (49, 49, 49, 1))
('batched_seeds[i, ...]', 6, (49, 49, 49, 1))
('batched_seeds[i, ...]', 7, (49, 49, 49, 1))
('seeds shape', 8)
((8, 49, 49, 49, 1), (8, 49, 49, 49, 1), (8, 49, 49, 49, 1))
('batched_seeds[i, ...]'

((8, 49, 49, 49, 1), (8, 49, 49, 49, 1), (8, 49, 49, 49, 1))
('batched_seeds[i, ...]', 0, (49, 49, 49, 1))
('batched_seeds[i, ...]', 1, (49, 49, 49, 1))
('batched_seeds[i, ...]', 2, (49, 49, 49, 1))
('batched_seeds[i, ...]', 3, (49, 49, 49, 1))
('batched_seeds[i, ...]', 4, (49, 49, 49, 1))
('batched_seeds[i, ...]', 5, (49, 49, 49, 1))
('batched_seeds[i, ...]', 6, (49, 49, 49, 1))
('batched_seeds[i, ...]', 7, (49, 49, 49, 1))
('seeds shape', 8)
((8, 49, 49, 49, 1), (8, 49, 49, 49, 1), (8, 49, 49, 49, 1))
('batched_seeds[i, ...]', 0, (49, 49, 49, 1))
('batched_seeds[i, ...]', 1, (49, 49, 49, 1))
('batched_seeds[i, ...]', 2, (49, 49, 49, 1))
('batched_seeds[i, ...]', 3, (49, 49, 49, 1))
('batched_seeds[i, ...]', 4, (49, 49, 49, 1))
('batched_seeds[i, ...]', 5, (49, 49, 49, 1))
('batched_seeds[i, ...]', 6, (49, 49, 49, 1))
('batched_seeds[i, ...]', 7, (49, 49, 49, 1))
('seeds shape', 8)
((8, 49, 49, 49, 1), (8, 49, 49, 49, 1), (8, 49, 49, 49, 1))
('batched_seeds[i, ...]', 0, (49, 49, 49, 1

In [228]:
from scipy.special import logit


In [240]:
p = 0.95

seed = np.full((1, 49, 49, 49, 1), 0.05)
seed[:, 24, 24, 24, :] = p
seed[:, 24, 24, 24, :]
logit_seed = logit(seed)
logit_seed[:, 24, 24, 24, :]
np.where(logit_seed == logit(0.95))


(array([0]), array([24]), array([24]), array([24]), array([0]))