<a href="https://colab.research.google.com/github/jonbarron/hist_thresh/blob/master/interactive_viewer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
##### Copyright 2020 Google LLC. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Load an image
import ipywidgets as widgets
from IPython.display import display
# Restrict the file picker to images; the next cell decodes the upload with PIL.
uploader = widgets.FileUpload(accept='image/*', multiple=False)
display(uploader)

In [None]:
#@title An interactive viewer for tuning hyperparameters

# Load the image
from PIL import Image
import io
import numpy as np

# FileUpload.value is a dict keyed by filename in ipywidgets 7 but a tuple of
# dicts in ipywidgets 8 (where 'content' is a memoryview); accept either form.
_uploads = uploader.value
_uploads = list(_uploads.values()) if isinstance(_uploads, dict) else list(_uploads)
if not _uploads:
  raise ValueError('No image uploaded: run the cell above and choose a file first.')
image_data = bytes(_uploads[0]['content'])
_pil_image = Image.open(io.BytesIO(image_data))
# Palette images would otherwise load as palette indices rather than colors.
if _pil_image.mode == 'P':
  _pil_image = _pil_image.convert('RGBA')
im = np.array(_pil_image)

# A fast numpy reference implementation of GHT, as per
# "A Generalization of Otsu's Method and Minimum Error Thresholding"
# Jonathan T. Barron, ECCV, 2020

def csum(z):
  """Running sums of z without the grand total: entry i is sum(z[:i + 1])."""
  return np.cumsum(z)[:-1]

def dsum(z):
  """Reverse running sums aligned with csum: entry i is sum(z[i + 1:])."""
  return np.cumsum(z[::-1])[::-1][1:]

def argmax(x, f):
  """Return the x[:-1] value where f peaks, averaging over tied maxima."""
  is_best = f == np.max(f)
  return np.mean(x[:-1][is_best])

def clip(z):
  """Floor z at 1e-30 so later divisions by, and logs of, z stay finite."""
  return np.maximum(1e-30, z)

def preliminaries(n, x):
  """Compute the per-split class statistics shared by the thresholding methods.

  Split i puts bins 0..i in class 0 and bins i+1..end in class 1, so every
  returned array (other than x) has len(n) - 1 entries.

  Args:
    n: non-negative histogram counts.
    x: non-decreasing bin centers, or None for 0, 1, ..., len(n) - 1 in n's dtype.

  Returns:
    (x, w0, w1, p0, p1, mu0, mu1, d0, d1): the bin centers; the class masses w
    (floored at 1e-30 by clip); the class fractions p = w / (w0 + w1); the class
    means mu; and d, each class's sum of n * x**2 minus w * mu**2 (its unnormalized
    sum of squared deviations from the mean).
  """
  assert np.all(n >= 0)
  if x is None:
    x = np.arange(len(n), dtype=n.dtype)
  assert np.all(x[1:] >= x[:-1])

  w0, w1 = clip(csum(n)), clip(dsum(n))
  total = w0 + w1
  p0, p1 = w0 / total, w1 / total

  first_moment = n * x
  second_moment = n * x**2
  mu0, mu1 = csum(first_moment) / w0, dsum(first_moment) / w1
  d0 = csum(second_moment) - w0 * mu0**2
  d1 = dsum(second_moment) - w1 * mu1**2
  return x, w0, w1, p0, p1, mu0, mu1, d0, d1

def GHT(n, x=None, nu=0, tau=0, kappa=0, omega=0.5, prelim=None):
  """Generalized Histogram Thresholding: score every split and pick the best.

  Args:
    n: histogram counts.
    x: bin centers; ignored when `prelim` is given (its own x is used instead).
    nu, tau: each class's variance is a blend of tau**2 (weight p * nu) and the
      class's empirical variance d / w (weight w).
    kappa, omega: kappa * omega (class 0) and kappa * (1 - omega) (class 1) are
      added to the weight on each class's log-mass term.
    prelim: optional cached result of preliminaries(n, x).

  Returns:
    (threshold, score): the x[:-1] value maximizing score (averaged over ties),
    and the objective for every split.
  """
  assert nu >= 0 and tau >= 0 and kappa >= 0
  assert 0 <= omega <= 1
  if not prelim:
    prelim = preliminaries(n, x)
  x, w0, w1, p0, p1, _, _, d0, d1 = prelim

  def class_term(p, w, d, mix):
    # Regularized variance for one side of the split, then its contribution.
    var = clip((p * nu * tau**2 + d) / (p * nu + w))
    return -d / var - w * np.log(var) + 2 * (w + kappa * mix) * np.log(w)

  score = class_term(p0, w0, d0, omega) + class_term(p1, w1, d1, 1 - omega)
  return argmax(x, score), score

def im2hist(im, zero_extents=False):
  """Convert an image to grayscale and histogram it, one bin per integer level.

  Grayscale is the per-pixel maximum over the color channels. An alpha channel
  is excluded: a 2-channel image keeps only its first channel (grayscale+alpha,
  e.g. PIL mode 'LA'), and any other channel count keeps at most the first three
  (so RGBA becomes RGB).

  Args:
    im: integer-typed image of shape (H, W) or (H, W, C).
    zero_extents: if True, zero the counts of the first and last bins.

  Returns:
    (n, x, im_bw): counts for levels 0..max of im's dtype, those levels, and the
    grayscale image.
  """
  max_val = np.iinfo(im.dtype).max
  x = np.arange(max_val + 1)
  # Half-integer bin edges give each integer level its own bin.
  e = np.arange(-0.5, max_val + 1.5)
  assert len(im.shape) in [2, 3]
  if len(im.shape) == 3:
    # Without this, a grayscale+alpha image would take max(gray, alpha), which
    # for opaque pixels is just the alpha value.
    color = im[..., :1] if im.shape[-1] == 2 else im[..., :3]
    im_bw = np.amax(color, -1)
  else:
    im_bw = im
  n = np.histogram(im_bw, e)[0]
  if zero_extents:
    n[0] = 0
    n[-1] = 0
  return n, x, im_bw

# Histogram the image once and cache the cumulative statistics, so that each
# slider move only re-runs the cheap per-split scoring in GHT.
n, x, im_bw = im2hist(im)
prelim = preliminaries(n, x)

import matplotlib.pyplot as plt

# Default hyperparameters: nu and kappa equal the total pixel count
# (the histogram's sum), tau = sqrt(1/12), omega = 0.5.
default_nu = np.sum(n)
default_tau = np.sqrt(1/12)
default_kappa = default_nu
default_omega = 0.5

# Current hyperparameter values; update() overwrites these.
_nu, _tau, _kappa, _omega = default_nu, default_tau, default_kappa, default_omega

# nu, tau and kappa get log-scale sliders whose min/max are base-10 exponents
# (1e-30 .. 1e30); omega is linear on [0, 1].
continuous_update = True
nu_slider, tau_slider, kappa_slider = (
    widgets.FloatLogSlider(min=-30, max=30, value=start, continuous_update=continuous_update)
    for start in (_nu, _tau, _kappa))
omega_slider = widgets.FloatSlider(min=0, max=1, value=_omega, step=0.01, continuous_update=continuous_update)

def render():
  """Redraw the three-panel figure for the current hyperparameters.

  Panels: the input image, the binarized image (im_bw > threshold), and the
  histogram overlaid with the rescaled GHT score curve. Only reads the
  module-level hyperparameters (_nu, _tau, _kappa, _omega) and precomputed data
  (n, x, prelim, im, im_bw), so no `global` statement is needed.
  """
  t, score = GHT(n, x, _nu, _tau, _kappa, _omega, prelim)

  plt.figure(0, figsize=(16, 5))
  plt.subplot(1, 3, 1)
  plt.imshow(im, cmap='gray')
  plt.axis('off')

  plt.subplot(1, 3, 2)
  plt.imshow(im_bw > t, cmap='gray', vmin=0, vmax=1)
  plt.gca().set_xticks([])
  plt.gca().set_yticks([])

  plt.subplot(1, 3, 3)
  # score[i] rates the split between bins x[i] and x[i + 1], so it is drawn at
  # their midpoint; t is a value from x[:-1] (possibly averaged over ties).
  split_x = (x[:-1] + x[1:]) / 2
  score_min = np.min(score)
  score_range = np.max(score) - score_min
  if score_range == 0:
    # A constant score would otherwise divide by zero; draw it flat.
    curve = np.zeros_like(score)
  else:
    # Rescale the score to the histogram's height so both share one axis.
    curve = (score - score_min) * np.max(n) / score_range
  plt.plot(split_x, curve)
  # Map t into the curve's coordinates so the marker lies on the curve.
  plt.scatter(np.interp(t, x[:-1], split_x), np.interp(t, x[:-1], curve))
  plt.bar(x, n, width=1)
  plt.gca().set_yticks([])

def update(nu=None, tau=None, kappa=None, omega=None):
  """Store new hyperparameter values; arguments left as None keep the current value.

  Uses explicit `is None` checks so that a legitimate 0 (e.g. the omega slider
  at its minimum) is applied rather than mistaken for "not given".
  """
  global _nu, _tau, _kappa, _omega
  _nu = _nu if nu is None else nu
  _tau = _tau if tau is None else tau
  _kappa = _kappa if kappa is None else kappa
  _omega = _omega if omega is None else omega

def reset(nu=None, tau=None, kappa=None, omega=None):
  """Move sliders to the given values; arguments left as None leave their slider alone.

  Uses explicit `is None` checks so a value of 0 (e.g. omega=0) is applied.
  Only attributes of the module-level sliders are assigned, so no `global`
  statement is needed.
  """
  targets = ((nu_slider, nu), (tau_slider, tau), (kappa_slider, kappa), (omega_slider, omega))
  for slider, value in targets:
    if value is not None:
      slider.value = value

def update_and_render(nu=None, tau=None, kappa=None, omega=None):
  """Slider callback: record the new hyperparameters, then redraw the figure."""
  update(nu=nu, tau=tau, kappa=kappa, omega=omega)
  render()


# Preset handlers. Each one moves the sliders through reset(); the sliders are
# connected to update_and_render by widgets.interact at the end of this cell.
# The values 1e-30 and 1e30 stand in for 0 and infinity within the sliders'
# range.
def default_fun(b):
  """Restore every hyperparameter to its default."""
  reset(nu=default_nu, tau=default_tau, kappa=default_kappa, omega=default_omega)

def otsu_fun(b):
  """Otsu's method preset: very large nu, tiny tau and kappa; omega untouched."""
  reset(nu=1e30, tau=1e-30, kappa=1e-30)

def met_fun(b):
  """Minimum Error Thresholding preset: tiny nu and kappa; tau and omega untouched."""
  reset(nu=1e-30, kappa=1e-30)

def percentile_fun(b):
  """Percentile preset: tiny nu, very large kappa; tau and omega untouched."""
  reset(nu=1e-30, kappa=1e30)

default_button = widgets.Button(description="Default")
otsu_button = widgets.Button(description="Otsu's Method")
met_button = widgets.Button(description="MET")
percentile_button = widgets.Button(description="Percentile")

default_button.on_click(default_fun)
otsu_button.on_click(otsu_fun)
met_button.on_click(met_fun)
percentile_button.on_click(percentile_fun)

display(default_button)
display(widgets.HBox([otsu_button, met_button, percentile_button]), widgets.Output())

widgets.interact(update_and_render, nu=nu_slider, tau=tau_slider, kappa=kappa_slider, omega=omega_slider);
