# End-to-end DeepCell Benchmark

**IMPORTANT** - set notebook runtime to `TensorFlow 2.12 (Local)`  
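This notebook runs an input image through the DeepCell Mesmer application to generate whole-cell cellular segmentation predictions, and benchmarks that run along several measures. The final cell prints the results as CSV: a header row followed by one data row containing the input file id, input size (MB), benchmark datetime (UTC), machine and GPU description, success flag, total / input-load / prediction times (with Mesmer's preprocess, inference and postprocess steps broken out), peak memory (GB), and the `deepcell-tf` version.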

This notebook accepts a numpy array of two input channels: the nuclear & membrane channels in that order. The array must be named `input_channels` in an `npz` file.

The `.npz` file is given as a URL, for example: `gs://a-data-bucket/a-path/input_channels.npz`. It is read with `smart_open`, which in principle supports most cloud storage providers; however, only GCS (Google Cloud Storage) has been tested, and only the GCS dependencies are specified.

The notebook optionally outputs the predicted segments to `predictions_output_path` if provided, into an `npz` file containing a single array: `predictions`.

If the `visualize` parameter is `True`, the notebook displays two visualizations: the input channels rendered in green (nuclear) and blue (membrane), and the same image with the predicted segment outlines overlaid. If `input_png_output_path` and/or `predictions_png_output_path` are set, the corresponding visualization is also saved there as a PNG file.

In [None]:
# Install DeepCell

!pip install deepcell --user

In [None]:
# Report which DeepCell release the kernel actually imports.

import deepcell

print(f"DeepCell version: {deepcell.__version__}")

In [None]:
# Imports

import csv
from datetime import datetime, timezone
import deepcell
from deepcell.applications import Mesmer
import io
import logging
import math
import numpy as np
import os
import psutil
import re
import resource
import smart_open
import tensorflow as tf
import timeit
import urllib.parse

## Parameters

Parameters:

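* `input_channels_path` (required): the URL of the `.npz` file whose `input_channels` array holds the 2 input channels (nuclear & membrane).
* `model_path`: the path to the MultiplexSegmentation TensorFlow Keras model (defaults to `~/.keras/models/MultiplexSegmentation`).
* `predictions_output_path`: the URL to write predictions to, or `None`/`""` to skip saving predictions.
* `notebook_instance_id`, `location`: the notebook instance ID and GCP region when running on Vertex AI; leave `notebook_instance_id` as `None` for local execution.
* `visualize`: if `True`, display the input and the prediction overlay.
* `input_png_output_path`, `predictions_png_output_path`: optional URLs where the two visualizations are saved as PNG files.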

In [None]:
# This cell is a notebook 'parameters' cell: keep it to plain assignments so
# the values below can be overridden when the notebook is run with parameters
# (presumably via papermill or a similar tool — confirm).

# Required: URL of the `.npz` file whose `input_channels` array holds the
# nuclear & membrane channels.
input_channels_path = 'gs://davids-genomics-data-public/cellular-segmentation/deep-cell/vanvalenlab-multiplex-20200810_tissue_dataset/mesmer-sample-3-dev/input_channels.npz'

# Location of the MultiplexSegmentation model, loaded with tf.keras.models.load_model.
model_path = f"{os.path.expanduser('~')}/.keras/models/MultiplexSegmentation"

# Where to write the predictions `.npz`; None or '' skips saving. Example value:
# 'gs://davids-genomics-data-public/cellular-segmentation/deep-cell/vanvalenlab-multiplex-20200810_tissue_dataset/mesmer-sample-3-dev/segmentation_predictions.npz'
predictions_output_path = None

# GCP region + notebook instance ID when running on Vertex AI; leave the
# instance ID as None for local execution.
notebook_instance_id = None
location = 'us-west1'

# Whether to display the input and the prediction overlay. When visualizing,
# the PNG paths may be set to also save the images to storage. Example values:
# 'gs://davids-genomics-data-public/cellular-segmentation/deep-cell/vanvalenlab-multiplex-20200810_tissue_dataset/mesmer-sample-3-dev/input.png'
# 'gs://davids-genomics-data-public/cellular-segmentation/deep-cell/vanvalenlab-multiplex-20200810_tissue_dataset/mesmer-sample-3-dev/segmentation_predictions.png'
visualize = True
input_png_output_path = None
predictions_png_output_path = None

# Benchmark warm-up

It takes tens of seconds to load the Mesmer TensorFlow model. We assume a production environment would already have the model loaded, so this warm-up (priming) step happens here. Like the Python imports above, it runs outside the timed prediction section.

In [None]:
# Model warm-up: load the Keras model once, outside the timed prediction cell.
load_started_at = timeit.default_timer()
model = tf.keras.models.load_model(model_path)
load_elapsed_s = timeit.default_timer() - load_started_at
print(f"Loaded model in {load_elapsed_s} s")

# Wrap the already-loaded model in the Mesmer application.
app = Mesmer(model=model)

# Reset the root logger's configuration; force=True removes any existing root
# handlers first. The next cell's log intercept depends on this to capture
# Mesmer's timing messages. Why it is needed is not established here —
# presumably a handler or level installed by an imported library. TODO confirm.
logging.basicConfig(force=True)

# Prediction

In [None]:
# Load the input channels and run Mesmer prediction, timing each stage.
# Log records emitted during prediction (Mesmer reports its per-step timings
# at DEBUG level) are captured in `logs_buffer` for the next cell to parse.
start_time = timeit.default_timer()

# Load inputs

t = timeit.default_timer()
with smart_open.open(input_channels_path, 'rb') as input_channel_file:
    with np.load(input_channel_file) as loader:
        # Expected: an array of shape [height, width, channel] containing the
        # intensity of the nuclear & membrane channels (in that order).
        input_channels = loader['input_channels']
input_load_time_s = timeit.default_timer() - t

print('Loaded input in %s s' % input_load_time_s)

# Generate predictions

## Intercept logging: temporarily lower the root logger to DEBUG and attach a
## handler that writes records into an in-memory buffer.
logger = logging.getLogger()
old_level = logger.getEffectiveLevel()
logger.setLevel(logging.DEBUG)
logs_buffer = io.StringIO()
buffer_log_handler = logging.StreamHandler(logs_buffer)
logger.addHandler(buffer_log_handler)

## The actual prediction

t = timeit.default_timer()

try:
    # predict() takes a batch, hence the added leading axis and the [0].
    segmentation_predictions = app.predict(input_channels[np.newaxis, ...], image_mpp=0.5)[0]
    prediction_success = True
except Exception as e:
    # Recorded rather than re-raised so the benchmark row is still produced
    # (with Success? = False). logger.exception also logs the traceback.
    prediction_success = False
    logger.exception('Prediction exception: %s', e)
finally:
    prediction_time_s = timeit.default_timer() - t
    total_time_s = timeit.default_timer() - start_time

    ## Undo the log intercept in `finally` so the root logger is restored even
    ## if prediction is interrupted (e.g. by KeyboardInterrupt).
    logger.removeHandler(buffer_log_handler)
    logger.setLevel(old_level)

## Wrap up
print('Prediction finished in %s s' % prediction_time_s)
print('Overall operation finished in %s s' % total_time_s)

In [None]:
# Parse the intercepted debug logs to extract Mesmer's per-step timing.
# Falls back to NaN when the expected lines are absent (e.g. the prediction
# failed or the log wording changed).

debug_logs = logs_buffer.getvalue()

# Each group only captures a decimal number (optionally with an exponent), so
# the float() conversions below cannot fail on unexpected text. DOTALL lets
# `.*` span the lines between the three messages; no other flags are needed.
number = r"([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)"
pattern = (
    rf"Pre-processed data with mesmer_preprocess in {number} s"
    rf".*Model inference finished in {number} s"
    rf".*Post-processed results with mesmer_postprocess in {number} s"
)
match = re.search(pattern, debug_logs, re.DOTALL)

if match:
    preprocess_time_s, inference_time_s, postprocess_time_s = (float(g) for g in match.groups())
else:
    logger.warning("Couldn't parse step timings from debug_logs")
    preprocess_time_s = inference_time_s = postprocess_time_s = math.nan

## Save predictions [optional]

In [None]:
# Optionally save the predictions as a compressed `.npz` holding a single
# `predictions` array. Skipped when this run's prediction failed, because
# `segmentation_predictions` would then be undefined, or stale from an
# earlier run of the notebook.
if predictions_output_path and not prediction_success:
    logging.warning('Prediction failed; not saving predictions to %s', predictions_output_path)
elif predictions_output_path:
    with smart_open.open(predictions_output_path, 'wb') as predictions_file:
        np.savez_compressed(predictions_file, predictions=segmentation_predictions)

# Visualization [optional]

## Input channel visualization

In [None]:
# Render the input channels as an RGB image (nuclear = green, membrane = blue),
# optionally save it as a PNG, and display it inline.
if visualize:
    from deepcell.utils.plot_utils import create_rgb_image
    from matplotlib import pyplot as plt
    from PIL import Image

    nuclear_color = 'green'
    membrane_color = 'blue'

    # create_rgb_image takes a batch, hence the added leading axis and the [0].
    # Its output is assumed to be float RGB normalized to 0..1 — confirm.
    input_rgb = create_rgb_image(input_channels[np.newaxis, ...], channel_colors=[nuclear_color, membrane_color])[0]

    if input_png_output_path:
        # Scale 0..1 floats to 0..255 bytes; clip so any out-of-range value
        # saturates instead of wrapping around in the uint8 cast.
        im = Image.fromarray(np.clip(input_rgb * 255, 0, 255).astype(np.uint8))
        with smart_open.open(input_png_output_path, 'wb') as input_png_file:
            # Name the format explicitly: the opened stream may not expose a
            # file name from which Pillow could infer it.
            im.save(input_png_file, format='PNG')
        del im

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.imshow(input_rgb)
    ax.set_title('Input')
    plt.show()

## Prediction overlay visualization

In [None]:
# Overlay the predicted segment outlines on the RGB input image (`input_rgb`
# from the input-visualization cell), optionally save it as a PNG, and display
# it. Skipped when this run's prediction failed, because
# `segmentation_predictions` would then be undefined or stale.
if visualize and not prediction_success:
    logging.warning('Prediction failed; skipping the prediction overlay visualization')
elif visualize:
    from deepcell.utils.plot_utils import make_outline_overlay
    from matplotlib import pyplot as plt
    from PIL import Image

    # make_outline_overlay takes batches, hence the added leading axes and [0].
    overlay_data = make_outline_overlay(
        rgb_data=input_rgb[np.newaxis, ...],
        predictions=segmentation_predictions[np.newaxis, ...],
    )[0]

    if predictions_png_output_path:
        # Overlay values are assumed to be floats in 0..1 — confirm. Scale to
        # 0..255 and clip so out-of-range values saturate in the uint8 cast.
        im = Image.fromarray(np.clip(overlay_data * 255, 0, 255).astype(np.uint8))
        with smart_open.open(predictions_png_output_path, 'wb') as predictions_png_file:
            # Explicit format: the stream may not expose a file name to infer it from.
            im.save(predictions_png_file, format='PNG')
        del im

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.imshow(overlay_data)
    ax.set_title('Predictions')
    plt.show()

# Benchmark data

In [None]:
# Collect the benchmark measurements for this run and print them as CSV
# (a header row followed by one data row).
import sys

headers = [
    'Input file id', 'File size (MB)', 'Benchmark datetime (UTC)',
    'Machine type', 'GPU type', '# GPUs',
    'Success?', 'Total time (s)', 'Peak memory (GB)',
    'Load time (s)', 'Total prediction time (s)', 'Prediction overhead (s)',
    'Predict preprocess time (s)', 'Predict inference time (s)', 'Predict postprocess time (s)',
    'deepcell-tf version'
]

# "Input file id" is the name of the directory containing the input file
# (presumably because inputs share the file name `input_channels.npz`).
# Short paths fall back to the last path component, then to the raw URL,
# instead of raising IndexError (e.g. a bare local file name).
parsed_url = urllib.parse.urlparse(input_channels_path)
path_parts = [part for part in parsed_url.path.split("/") if part]
if len(path_parts) >= 2:
    filename = path_parts[-2]
elif path_parts:
    filename = path_parts[-1]
else:
    filename = input_channels_path

# NOTE: this is the size of the decoded in-memory array, not of the (possibly
# compressed) .npz file, despite the "File size (MB)" header.
file_size = round(input_channels.nbytes / 1000 / 1000, 2)

if notebook_instance_id:
    # For running on vertex AI:
    # Shell command:
    # $ gcloud notebooks runtimes describe --format="value(virtualMachine.virtualMachineConfig.machineType)" --location=<region> <notebook_id>
    machine_type = 'todo-machine'
else:
    # assume a generic python environment
    # See also:
    # https://docs.python.org/3.10/library/os.html#os.cpu_count
    try:
        num_cpus = len(os.sched_getaffinity(0))
    except AttributeError:
        num_cpus = os.cpu_count()
    total_mem = psutil.virtual_memory().total
    machine_type = 'local {} CPUs {} GB RAM'.format(num_cpus, round(total_mem/1000000000, 1))

# GPUs visible to TensorFlow; the type column lists the distinct device names.
gpus = tf.config.list_physical_devices('GPU')
gpu_count = len(gpus)
gpu_type = ', '.join(sorted({
    tf.config.experimental.get_device_details(gpu).get('device_name', 'unknown')
    for gpu in gpus
})) or 'none'

# ru_maxrss is the peak resident set size of this process since it started,
# so it includes model loading, not only prediction. Its unit is platform
# dependent: bytes on macOS, kilobytes on Linux.
peak_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
peak_mem_b = peak_rss if sys.platform == 'darwin' else peak_rss * 1024

# NaN when the step timings could not be parsed from the logs.
prediction_overhead_s = prediction_time_s - preprocess_time_s - inference_time_s - postprocess_time_s

# Write benchmarking data as CSV:

output = io.StringIO()
writer = csv.writer(output, quoting=csv.QUOTE_NONNUMERIC)
writer.writerow(headers)

deepcell_version = deepcell.__version__

writer.writerow([
    filename, file_size, datetime.now(timezone.utc),
    machine_type, gpu_type, gpu_count,
    prediction_success, round(total_time_s, 2), round(peak_mem_b / 1000000000, 1),
    round(input_load_time_s, 2), round(prediction_time_s, 2), round(prediction_overhead_s, 2),
    round(preprocess_time_s, 2), round(inference_time_s, 2), round(postprocess_time_s, 2),
    deepcell_version,
])

print(output.getvalue())