# Refactoring HistoGAN: Validation and Benchmark

The refactored histogram blocks extract common functionality and data into a separate base class.
Here we test whether the results match the original implementation and assess potential performance benefits.

## Reference Implementation

Below are the reference implementations of all three histogram blocks: RGB-uv, rg-chroma, and Lab.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Small constant added inside log/sqrt and to denominators so that zero-valued
# pixels or empty histograms never produce log(0) or a division by zero.
EPS = 1e-6

class OriginalRGBuvHistBlock(nn.Module):
  """Reference RGB-uv histogram block, kept verbatim as the baseline.

  For every image it builds h x h histograms over the log-chroma pairs
  (u, v) = (log I_c - log I_c', log I_c - log I_c'') for c = 0, 1, 2 (only
  c = 1 when ``green_only`` is set) and normalises each sample's histograms
  to sum to ~1. The validation and benchmark cells compare the refactored
  block against this code, so its behaviour is deliberately left unchanged.
  """
  def __init__(self, h=64, insz=150, resizing='interpolation',
               method='inverse-quadratic', sigma=0.02, intensity_scale=True,
               hist_boundary=None, green_only=False, device='cuda'):
    super().__init__()
    self.h = h
    self.insz = insz
    self.device = device
    self.resizing = resizing
    self.method = method
    self.intensity_scale = intensity_scale
    self.green_only = green_only
    if hist_boundary is None:
      hist_boundary = [-3, 3]
    # NOTE(review): sorts in place, i.e. mutates a list passed in by the caller.
    hist_boundary.sort()
    self.hist_boundary = hist_boundary
    if self.method == 'thresholding':
      # forward() keeps bins within eps / 2. The abs-sum equals the boundary
      # range only when hist_boundary[0] <= 0 <= hist_boundary[1].
      self.eps = (abs(hist_boundary[0]) + abs(hist_boundary[1])) / h
    else:
      self.sigma = sigma

  def forward(self, x):
    x = torch.clamp(x, 0, 1)
    # Only inputs taller or wider than insz are resized.
    if x.shape[2] > self.insz or x.shape[3] > self.insz:
      if self.resizing == 'interpolation':
        x_sampled = F.interpolate(x, size=(self.insz, self.insz),
                                  mode='bilinear', align_corners=False)
      elif self.resizing == 'sampling':
        # NOTE(review): keeps self.h evenly spaced rows/columns (not
        # self.insz); the np.linspace positions are truncated by LongTensor.
        inds_1 = torch.LongTensor(
          np.linspace(0, x.shape[2], self.h, endpoint=False)).to(
          device=self.device)
        inds_2 = torch.LongTensor(
          np.linspace(0, x.shape[3], self.h, endpoint=False)).to(
          device=self.device)
        x_sampled = x.index_select(2, inds_1)
        x_sampled = x_sampled.index_select(3, inds_2)
      else:
        raise Exception(
          f'Wrong resizing method. It should be: interpolation or sampling. '
          f'But the given value is {self.resizing}.')
    else:
      x_sampled = x

    L = x_sampled.shape[0]  # size of mini-batch
    # Channels beyond the first three are dropped.
    if x_sampled.shape[1] > 3:
      x_sampled = x_sampled[:, :3, :, :]
    X = torch.unbind(x_sampled, dim=0)
    # Output: (N, 3, h, h), or (N, 1, h, h) with green_only.
    hists = torch.zeros((x_sampled.shape[0], 1 + int(not self.green_only) * 2,
                         self.h, self.h)).to(device=self.device)
    for l in range(L):
      # One row per pixel, one column per channel: (pixels, 3).
      I = torch.t(torch.reshape(X[l], (3, -1)))
      II = torch.pow(I, 2)
      # Per-pixel weight sqrt(I0^2 + I1^2 + I2^2 + EPS), or 1 when disabled.
      if self.intensity_scale:
        Iy = torch.unsqueeze(torch.sqrt(II[:, 0] + II[:, 1] + II[:, 2] + EPS),
                             dim=1)
      else:
        Iy = 1
      if not self.green_only:
        # Channel 0: u = log I0 - log I1, v = log I0 - log I2.
        Iu0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 1] +
                                                                   EPS), dim=1)
        Iv0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 2] +
                                                                   EPS), dim=1)
        # |coordinate - bin centre| for each of the h centres. The centres
        # come from np.linspace (float64), so for float32 input these
        # distances stay float64 until the float32 cast below.
        diff_u0 = abs(
          Iu0 - torch.unsqueeze(torch.tensor(np.linspace(
            self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
            dim=0).to(self.device))
        diff_v0 = abs(
          Iv0 - torch.unsqueeze(torch.tensor(np.linspace(
            self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
            dim=0).to(self.device))
        if self.method == 'thresholding':
          diff_u0 = torch.reshape(diff_u0, (-1, self.h)) <= self.eps / 2
          diff_v0 = torch.reshape(diff_v0, (-1, self.h)) <= self.eps / 2
        elif self.method == 'RBF':
          diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_u0 = torch.exp(-diff_u0)  # Radial basis function
          diff_v0 = torch.exp(-diff_v0)
        elif self.method == 'inverse-quadratic':
          diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_u0 = 1 / (1 + diff_u0)  # Inverse quadratic
          diff_v0 = 1 / (1 + diff_v0)
        else:
          raise Exception(
            f'Wrong kernel method. It should be either thresholding, RBF,'
            f' inverse-quadratic. But the given value is {self.method}.')
        diff_u0 = diff_u0.type(torch.float32)
        diff_v0 = diff_v0.type(torch.float32)
        # (h, pixels) @ (pixels, h): weighted joint histogram over (u, v).
        a = torch.t(Iy * diff_u0)
        hists[l, 0, :, :] = torch.mm(a, diff_v0)

      # Channel 1: u = log I1 - log I0, v = log I1 - log I2.
      Iu1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 0] + EPS),
                            dim=1)
      Iv1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 2] + EPS),
                            dim=1)
      diff_u1 = abs(
        Iu1 - torch.unsqueeze(torch.tensor(np.linspace(
          self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
          dim=0).to(self.device))
      diff_v1 = abs(
        Iv1 - torch.unsqueeze(torch.tensor(np.linspace(
          self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
          dim=0).to(self.device))

      # NOTE(review): no else-raise here; with green_only and an unknown
      # method, the raw distances are used as kernel values.
      if self.method == 'thresholding':
        diff_u1 = torch.reshape(diff_u1, (-1, self.h)) <= self.eps / 2
        diff_v1 = torch.reshape(diff_v1, (-1, self.h)) <= self.eps / 2
      elif self.method == 'RBF':
        diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u1 = torch.exp(-diff_u1)  # Gaussian
        diff_v1 = torch.exp(-diff_v1)
      elif self.method == 'inverse-quadratic':
        diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u1 = 1 / (1 + diff_u1)  # Inverse quadratic
        diff_v1 = 1 / (1 + diff_v1)

      diff_u1 = diff_u1.type(torch.float32)
      diff_v1 = diff_v1.type(torch.float32)
      a = torch.t(Iy * diff_u1)
      # Channel 1 is the only output channel when green_only is set.
      if not self.green_only:
        hists[l, 1, :, :] = torch.mm(a, diff_v1)
      else:
        hists[l, 0, :, :] = torch.mm(a, diff_v1)

      if not self.green_only:
        # Channel 2: u = log I2 - log I0, v = log I2 - log I1.
        Iu2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 0] +
                                                                   EPS), dim=1)
        Iv2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 1] +
                                                                   EPS), dim=1)
        diff_u2 = abs(
          Iu2 - torch.unsqueeze(torch.tensor(np.linspace(
            self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
            dim=0).to(self.device))
        diff_v2 = abs(
          Iv2 - torch.unsqueeze(torch.tensor(np.linspace(
            self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
            dim=0).to(self.device))
        if self.method == 'thresholding':
          diff_u2 = torch.reshape(diff_u2, (-1, self.h)) <= self.eps / 2
          diff_v2 = torch.reshape(diff_v2, (-1, self.h)) <= self.eps / 2
        elif self.method == 'RBF':
          diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_u2 = torch.exp(-diff_u2)  # Gaussian
          diff_v2 = torch.exp(-diff_v2)
        elif self.method == 'inverse-quadratic':
          diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_u2 = 1 / (1 + diff_u2)  # Inverse quadratic
          diff_v2 = 1 / (1 + diff_v2)
        diff_u2 = diff_u2.type(torch.float32)
        diff_v2 = diff_v2.type(torch.float32)
        a = torch.t(Iy * diff_u2)
        hists[l, 2, :, :] = torch.mm(a, diff_v2)

    # normalization
    # Divide by each sample's total over all channels and bins.
    hists_normalized = hists / (
        ((hists.sum(dim=1)).sum(dim=1)).sum(dim=1).view(-1, 1, 1, 1) + EPS)

    return hists_normalized

class OriginalrgChromaHistBlock(nn.Module):
  """Reference rg-chroma histogram block, kept verbatim as the baseline.

  For every image it builds one h x h histogram over the chromaticities
  r = I0 / (I0 + I1 + I2 + EPS) and g = I1 / (I0 + I1 + I2 + EPS) and
  normalises it to sum to ~1. The refactored block is validated against this
  code, so its behaviour is deliberately left unchanged.
  """
  def __init__(self, h=64, insz=150, resizing='interpolation',
               method='inverse-quadratic', sigma=0.02, intensity_scale=False,
               hist_boundary=None, device='cuda'):
    super().__init__()
    self.h = h
    self.insz = insz
    self.device = device
    self.resizing = resizing
    self.method = method
    self.intensity_scale = intensity_scale
    if hist_boundary is None:
      hist_boundary = [0, 1]
    # NOTE(review): sorts in place, i.e. mutates a list passed in by the caller.
    hist_boundary.sort()
    self.hist_boundary = hist_boundary
    if self.method == 'thresholding':
      # forward() keeps bins within eps / 2. The abs-sum equals the boundary
      # range only when hist_boundary[0] <= 0 <= hist_boundary[1].
      self.eps = (abs(hist_boundary[0]) + abs(hist_boundary[1])) / h
    else:
      self.sigma = sigma

  def forward(self, x):
    x = torch.clamp(x, 0, 1)
    # Only inputs taller or wider than insz are resized.
    if x.shape[2] > self.insz or x.shape[3] > self.insz:
      if self.resizing == 'interpolation':
        x_sampled = F.interpolate(x, size=(self.insz, self.insz),
                                  mode='bilinear', align_corners=False)
      elif self.resizing == 'sampling':
        # NOTE(review): keeps self.h evenly spaced rows/columns (not self.insz).
        inds_1 = torch.LongTensor(
          np.linspace(0, x.shape[2], self.h, endpoint=False)).to(
          device=self.device)
        inds_2 = torch.LongTensor(
          np.linspace(0, x.shape[3], self.h, endpoint=False)).to(
          device=self.device)
        x_sampled = x.index_select(2, inds_1)
        x_sampled = x_sampled.index_select(3, inds_2)
      else:
        raise Exception(
          f'Wrong resizing method. It should be: interpolation or sampling. '
          f'But the given value is {self.resizing}.')
    else:
      x_sampled = x

    L = x_sampled.shape[0]  # size of mini-batch
    # Channels beyond the first three are dropped.
    if x_sampled.shape[1] > 3:
      x_sampled = x_sampled[:, :3, :, :]
    X = torch.unbind(x_sampled, dim=0)
    # Output: (N, 1, h, h).
    hists = torch.zeros((x_sampled.shape[0], 1, self.h, self.h)).to(
      device=self.device)
    for l in range(L):
      # One row per pixel, one column per channel: (pixels, 3).
      I = torch.t(torch.reshape(X[l], (3, -1)))
      II = torch.pow(I, 2)
      # Per-pixel weight sqrt(I0^2 + I1^2 + I2^2 + EPS), or 1 when disabled.
      if self.intensity_scale:
        Iy = torch.unsqueeze(torch.sqrt(II[:, 0] + II[:, 1] + II[:, 2] + EPS),
                             dim=1)
      else:
        Iy = 1

      # Chromaticities: channel share of the per-pixel channel sum.
      Ir = torch.unsqueeze(I[:, 0] / (torch.sum(I, dim=-1) + EPS), dim=1)
      Ig = torch.unsqueeze(I[:, 1] / (torch.sum(I, dim=-1) + EPS), dim=1)

      # |chroma - bin centre|; centres are float64 (np.linspace) until the
      # float32 cast below.
      diff_r = abs(Ir - torch.unsqueeze(torch.tensor(np.linspace(
        self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
        dim=0).to(self.device))
      diff_g = abs(Ig - torch.unsqueeze(torch.tensor(np.linspace(
        self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
        dim=0).to(self.device))

      # NOTE(review): an unknown method is not rejected; the raw distances
      # are then used as kernel values.
      if self.method == 'thresholding':
        diff_r = torch.reshape(diff_r, (-1, self.h)) <= self.eps / 2
        diff_g = torch.reshape(diff_g, (-1, self.h)) <= self.eps / 2
      elif self.method == 'RBF':
        diff_r = torch.pow(torch.reshape(diff_r, (-1, self.h)),
                           2) / self.sigma ** 2
        diff_g = torch.pow(torch.reshape(diff_g, (-1, self.h)),
                           2) / self.sigma ** 2
        diff_r = torch.exp(-diff_r)  # Gaussian
        diff_g = torch.exp(-diff_g)
      elif self.method == 'inverse-quadratic':
        diff_r = torch.pow(torch.reshape(diff_r, (-1, self.h)),
                           2) / self.sigma ** 2
        diff_g = torch.pow(torch.reshape(diff_g, (-1, self.h)),
                           2) / self.sigma ** 2
        diff_r = 1 / (1 + diff_r)  # Inverse quadratic
        diff_g = 1 / (1 + diff_g)

      diff_r = diff_r.type(torch.float32)
      diff_g = diff_g.type(torch.float32)
      # (h, pixels) @ (pixels, h): weighted joint histogram over (r, g).
      a = torch.t(Iy * diff_r)

      hists[l, 0, :, :] = torch.mm(a, diff_g)

    # normalization
    # Divide by each sample's total over all bins.
    hists_normalized = hists / (
        ((hists.sum(dim=1)).sum(dim=1)).sum(dim=1).view(-1, 1, 1, 1) + EPS)

    return hists_normalized

class OriginalLabHistBlock(nn.Module):
  """Reference Lab histogram block, kept verbatim as the baseline.

  For every image it builds one h x h histogram over channels 1 and 2 (a, b),
  optionally weighting each pixel by channel 0 (L), and normalises it to sum
  to ~1. The refactored block is validated against this code, so its
  behaviour is deliberately left unchanged.

  NOTE(review): the input is clamped to [0, 1] and the default bin range is
  [0, 1], so a/b presumably arrive pre-scaled to [0, 1] -- confirm with callers.
  """
  def __init__(self, h=64, insz=150, resizing='interpolation',
               method='inverse-quadratic', sigma=0.02, intensity_scale=False,
               hist_boundary=None, device='cuda'):
    super().__init__()
    self.h = h
    self.insz = insz
    self.device = device
    self.resizing = resizing
    self.method = method
    self.intensity_scale = intensity_scale
    if hist_boundary is None:
      hist_boundary = [0, 1]
    # NOTE(review): sorts in place, i.e. mutates a list passed in by the caller.
    hist_boundary.sort()
    self.hist_boundary = hist_boundary
    if self.method == 'thresholding':
      # forward() keeps bins within eps / 2. The abs-sum equals the boundary
      # range only when hist_boundary[0] <= 0 <= hist_boundary[1].
      self.eps = (abs(hist_boundary[0]) + abs(hist_boundary[1])) / h
    else:
      self.sigma = sigma

  def forward(self, x):
    x = torch.clamp(x, 0, 1)
    # Only inputs taller or wider than insz are resized.
    if x.shape[2] > self.insz or x.shape[3] > self.insz:
      if self.resizing == 'interpolation':
        x_sampled = F.interpolate(x, size=(self.insz, self.insz),
                                  mode='bilinear', align_corners=False)
      elif self.resizing == 'sampling':
        # NOTE(review): keeps self.h evenly spaced rows/columns (not self.insz).
        inds_1 = torch.LongTensor(
          np.linspace(0, x.shape[2], self.h, endpoint=False)).to(
          device=self.device)
        inds_2 = torch.LongTensor(
          np.linspace(0, x.shape[3], self.h, endpoint=False)).to(
          device=self.device)
        x_sampled = x.index_select(2, inds_1)
        x_sampled = x_sampled.index_select(3, inds_2)
      else:
        raise Exception(
          f'Wrong resizing method. It should be: interpolation or sampling. '
          f'But the given value is {self.resizing}.')
    else:
      x_sampled = x

    L = x_sampled.shape[0]  # size of mini-batch
    # Channels beyond the first three are dropped.
    if x_sampled.shape[1] > 3:
      x_sampled = x_sampled[:, :3, :, :]
    X = torch.unbind(x_sampled, dim=0)
    # Output: (N, 1, h, h).
    hists = torch.zeros((x_sampled.shape[0], 1, self.h, self.h)).to(
      device=self.device)
    for l in range(L):
      # One row per pixel, one column per channel: (pixels, 3).
      I = torch.t(torch.reshape(X[l], (3, -1)))
      # Per-pixel weight: channel 0 (L) itself, or 1 when disabled.
      if self.intensity_scale:
        Il = torch.unsqueeze(I[:, 0], dim=1)
      else:
        Il = 1

      Ia = torch.unsqueeze(I[:, 1], dim=1)
      Ib = torch.unsqueeze(I[:, 2], dim=1)

      # |value - bin centre|; centres are float64 (np.linspace) until the
      # float32 cast below.
      diff_a = abs(Ia - torch.unsqueeze(torch.tensor(np.linspace(
        self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
        dim=0).to(self.device))
      diff_b = abs(Ib - torch.unsqueeze(torch.tensor(np.linspace(
        self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
        dim=0).to(self.device))

      # NOTE(review): an unknown method is not rejected; the raw distances
      # are then used as kernel values.
      if self.method == 'thresholding':
        diff_a = torch.reshape(diff_a, (-1, self.h)) <= self.eps / 2
        diff_b = torch.reshape(diff_b, (-1, self.h)) <= self.eps / 2
      elif self.method == 'RBF':
        diff_a = torch.pow(torch.reshape(diff_a, (-1, self.h)),
                           2) / self.sigma ** 2
        diff_b = torch.pow(torch.reshape(diff_b, (-1, self.h)),
                           2) / self.sigma ** 2
        diff_a = torch.exp(-diff_a)  # Gaussian
        diff_b = torch.exp(-diff_b)
      elif self.method == 'inverse-quadratic':
        diff_a = torch.pow(torch.reshape(diff_a, (-1, self.h)),
                           2) / self.sigma ** 2
        diff_b = torch.pow(torch.reshape(diff_b, (-1, self.h)),
                           2) / self.sigma ** 2
        diff_a = 1 / (1 + diff_a)  # Inverse quadratic
        diff_b = 1 / (1 + diff_b)

      diff_a = diff_a.type(torch.float32)
      diff_b = diff_b.type(torch.float32)
      # (h, pixels) @ (pixels, h): weighted joint histogram over (a, b).
      a = torch.t(Il * diff_a)

      hists[l, 0, :, :] = torch.mm(a, diff_b)

    # normalization
    # Divide by each sample's total over all bins.
    hists_normalized = hists / (
        ((hists.sum(dim=1)).sum(dim=1)).sum(dim=1).view(-1, 1, 1, 1) + EPS)

    return hists_normalized

## Refactored Implementation

Next, let's define the refactored versions.
Since pixel sampling, histogram value scaling, and kernel methods are the same across all histogram blocks, it makes sense to extract them into separate functions.
We can also do some micro-optimizations here, like replacing divisions with multiplications where applicable.

### Kernel Functions

We start by defining the kernel functions for intensity scaling, resizing, and pixel counting.

Minor changes: first, squaring the input for intensity scaling has been moved into the intensity scaling function itself.
This means that when `intensity_scale` is set to `False`, the histogram block's `forward()` method no longer includes this
calculation, which saves on memory and computation time (only calculate what we need and when we actually need it).

Another change can be found in the sampling method. Here, we no longer use `LongTensor`, which creates 64-bit indices and
is slow on many consumer as well as professional GPU devices. It's also unnecessary. The sampled values are row and column
indices, so they can only exceed the 32-bit range for images more than 2^31 pixels tall or wide. Even the total pixel count
only passes 2^31 beyond roughly 45k x 45k pixels, i.e. √(2^31) by √(2^31) images. A 3-channel 8-bit image of that size would
already require ~6 GiB of VRAM (about 24 GiB as `float32`), so 32-bit indices are more than sufficient for the time being.

In [None]:
from typing import Union

# Accepted device specification: a device string such as 'cuda' or a torch.device.
Device = Union[str, torch.device]


def no_scaling(_: torch.Tensor) -> int:
    """Uniform pixel weight: ignore the pixel tensor and return the scalar 1."""
    weight = 1
    return weight


def intensity_scaling(X: torch.Tensor) -> torch.Tensor:
    """Per-pixel weight sqrt(c0^2 + c1^2 + c2^2 + EPS) as a column vector.

    ``X`` holds one pixel per row; only columns 0-2 are read. The result
    has shape (rows, 1) so it broadcasts against per-bin kernel values.
    """
    squared = X.pow(2)
    # Keep the explicit left-to-right sum so rounding matches the reference.
    total = squared[:, 0] + squared[:, 1] + squared[:, 2] + EPS
    return torch.sqrt(total).unsqueeze(dim=1)


def resizing_interpolate(max_size: int, X: torch.Tensor) -> torch.Tensor:
    """Bilinearly resize ``X`` to max_size x max_size if it is too large.

    ``X`` is indexed as (N, C, H, W). It is returned unchanged, as the same
    object, when both H and W are at most ``max_size``.
    """
    height, width = X.shape[2:]
    if height <= max_size and width <= max_size:
        return X
    return F.interpolate(
        X, size=(max_size, max_size), mode='bilinear', align_corners=False
    )


def resizing_sample(
    h: int, max_size: int, device: 'Device', X: torch.Tensor
) -> torch.Tensor:
    """Subsample ``h`` evenly spaced rows and columns from oversized inputs.

    ``X`` is indexed as (N, C, H, W). It is returned unchanged when both H
    and W are at most ``max_size``. Otherwise, output row/column ``k`` is
    input index ``floor(k * size / h)``. These are the positions the
    reference blocks take from ``np.linspace(0, size, h, endpoint=False)``
    truncated to integers.

    Args:
        h: number of rows and columns to keep.
        max_size: largest height/width that is passed through untouched.
        device: device the index tensors are moved to (must match ``X``).
        X: image batch to subsample.
    """
    H, W = X.shape[2:]
    if H <= max_size and W <= max_size:
        return X
    # Exact integer arithmetic. An integer-dtype torch.linspace would first
    # truncate its end point (size - size/h) to an int and then interpolate,
    # which drifts off by one from the reference when size % h != 0.
    steps = torch.arange(h, dtype=torch.int32)
    index_H = torch.div(steps * H, h, rounding_mode='floor').to(device)
    index_W = torch.div(steps * W, h, rounding_mode='floor').to(device)
    sampled = X.index_select(dim=2, index=index_H)
    return sampled.index_select(dim=3, index=index_W)


def thresholding_kernel(h: int, eps: float, X: torch.Tensor) -> torch.Tensor:
    """Hard bin membership: 1.0 where a distance is <= ``eps``, else 0.0.

    ``X`` is reshaped to (-1, h) and the result is float32.
    """
    distances = X.reshape(-1, h)
    return distances.le(eps).to(torch.float32)


def rbf_kernel(h: int, inv_sigma_sq: float, X: torch.Tensor) -> torch.Tensor:
    """Gaussian kernel exp(-d^2 / sigma^2), with 1 / sigma^2 precomputed.

    ``X`` holds distances and is reshaped to (-1, h).
    """
    scaled_sq = X.reshape(-1, h).pow(2).mul(inv_sigma_sq)
    return scaled_sq.neg().exp()


def inverse_quadratic_kernel(
    h: int, inv_sigma_sq: float, X: torch.Tensor
) -> torch.Tensor:
    """Inverse-quadratic kernel 1 / (1 + d^2 / sigma^2), with 1 / sigma^2 precomputed.

    ``X`` holds distances and is reshaped to (-1, h).
    """
    denominator = 1. + X.reshape(-1, h).pow(2).mul(inv_sigma_sq)
    return 1. / denominator

### HistBlock Base Class

Next we define a base class for all histogram blocks. The base class ctor selects the kernel functions depending on the provided parameter and precalculates tensors that only depend on ctor arguments. This includes the delta-values used for calculating differences.

We can compute these once and upload them onto the device in a suitable data format. Factory functions are used to map function names to actual kernel functions.
Partial function application lets us bind kernel function parameters that don't depend on the input tensor and can thus be precalculated.

In [None]:
from functools import partial
from typing import Callable, List, Sequence, Union

# Single-tensor callable returned by the factories below (resize and kernel functions).
_KernelMethod = Callable[[torch.Tensor], torch.Tensor]
# Same union as `Device`: a device string such as 'cuda' or a torch.device.
_Device = Union[str, torch.device]

def _get_resizing(
    mode: str, h: int, max_size: int, device: Device
) -> _KernelMethod:
    """Map a resizing mode name to a one-argument resize function.

    'interpolation' binds ``max_size`` to ``resizing_interpolate``;
    'sampling' binds ``h``, ``max_size`` and ``device`` to ``resizing_sample``.

    Raises:
        ValueError: if ``mode`` is neither of the above.
    """
    if mode == 'interpolation':
        return partial(resizing_interpolate, max_size)
    if mode == 'sampling':
        return partial(resizing_sample, h, max_size, device)
    raise ValueError(
        f'Unknown resizing method: "{mode}". Supported methods are '
        '"interpolation" or "sampling"'
    )


def _get_scaling(intensity_scale: bool):
    """Pick the per-pixel weighting function.

    A truthy ``intensity_scale`` selects ``intensity_scaling``; anything
    falsy selects ``no_scaling`` (a constant weight of 1).
    """
    if intensity_scale:
        return intensity_scaling
    return no_scaling


def _get_kernel(
    method: str, h: int, sigma: float, boundary: Sequence[int]
) -> _KernelMethod:
    """Map a kernel method name to a one-argument kernel with bound parameters.

    Args:
        method: 'thresholding', 'RBF', or 'inverse-quadratic'.
        h: number of bins; kernels reshape their input to (-1, h).
        sigma: bandwidth for 'RBF' and 'inverse-quadratic' (unused otherwise).
        boundary: sorted bin-centre range; only the first two values are read.

    Raises:
        ValueError: for an unrecognised method name.
    """
    if method == 'thresholding':
        # Half the nominal bin width. NOTE: the reference blocks compute
        # (|b0| + |b1|) / h / 2, which agrees only when b0 <= 0 <= b1 (true
        # for the default boundaries [-3, 3] and [0, 1]).
        half_width = (boundary[1] - boundary[0]) / (2 * h)
        return partial(thresholding_kernel, h, half_width)
    if method in ('RBF', 'inverse-quadratic'):
        kernel_fn = rbf_kernel if method == 'RBF' else inverse_quadratic_kernel
        return partial(kernel_fn, h, 1 / sigma ** 2)
    raise ValueError(
        f'Unknown kernel method: "{method}". Supported methods are '
        '"thresholding", "RBF", or "inverse-quadratic".'
    )

class HistBlock(nn.Module):
    """Shared setup for the refactored histogram blocks.

    The ctor resolves the resizing, intensity-scaling, and kernel functions
    once and precomputes the ``h`` bin centres (``delta``, shape ``(1, h)``,
    float32) on ``device``. Subclasses implement ``forward``.

    Args:
        h: number of histogram bins per axis.
        insz: inputs taller or wider than this are resized.
        resizing: 'interpolation' or 'sampling'.
        method: 'thresholding', 'RBF', or 'inverse-quadratic'.
        sigma: kernel bandwidth; unused by 'thresholding'.
        intensity_scale: truthy selects ``intensity_scaling``, falsy
            ``no_scaling``.
        hist_boundary: bin-centre range. It is sorted and its first two values
            are used; the caller's sequence itself is left untouched.
        device: device for ``delta``.

    Raises:
        ValueError: for an unknown ``resizing`` or ``method``.
    """

    def __init__(
        self, h: int, insz: int, resizing: str, method: str, sigma: float,
        intensity_scale: bool, hist_boundary: Sequence[float], device: _Device
    ) -> None:
        super().__init__()
        # Sort a copy: list.sort() would silently reorder a list the caller
        # still owns (and would reject tuples).
        boundary = sorted(hist_boundary)
        start, end = boundary[:2]
        self.h = h
        self.device = torch.device(device)
        self.resize = _get_resizing(resizing, h, insz, self.device)
        self.kernel = _get_kernel(method, h, sigma, boundary)
        self.scaling = _get_scaling(intensity_scale)
        # Bin centres as a (1, h) row, so a (pixels, 1) column minus delta
        # broadcasts to (pixels, h) in the subclasses.
        self.delta = torch.linspace(
            start, end, steps=h, device=self.device, dtype=torch.float32
        ).unsqueeze(dim=0)

    def forward(self, _: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()

## Refactored Histogram Blocks

With all the pieces in place, we can modify the original histogram blocks to make use of our common components.

The RGB-uv block can be simplified by observing that the difference calculations for the channels only differ in
tensor indexing. We can extract the calculation into a function and pass these indexes as arguments to the difference
calculation.

We can also speed up histogram normalization by summing all elements in a single reduction. This can cause slightly
more numeric error due to loss of significance when summing unsorted tensor values. Performance is a bit better, though,
and the differences should be minimal; we will get back to this in the validation part.

Just from personal preference and for the sake of consistency, member functions are used on tensors where appropriate
(type hints would also help a lot, but I didn't want to change too much).

In [None]:
class RGBuvHistBlock(HistBlock):
  """Refactored RGB-uv histogram block with the reference block's constructor.

  For every image it builds h x h histograms over the log-chroma pairs
  (u, v) = (log I_c - log I_c', log I_c - log I_c'') for c = 0, 1, 2 (only
  c = 1 when ``green_only`` is set) and normalises each sample's histograms
  to sum to ~1.
  """
  def __init__(self, h=64, insz=150, resizing='interpolation',
               method='inverse-quadratic', sigma=0.02, intensity_scale=True,
               hist_boundary=None, green_only=False, device='cuda'):
    super().__init__(
      h, insz, resizing, method, sigma, intensity_scale,
      hist_boundary or [-3, 3], device
    )
    self.green_only = green_only

  def forward(self, x):
    x_sampled = self.resize(x.clamp(0, 1))

    N = x_sampled.shape[0]  # size of mini-batch
    if x_sampled.shape[1] > 3:
      x_sampled = x_sampled[:, :3, :, :]
    X = torch.unbind(x_sampled, dim=0)
    C = 1 + int(not self.green_only) * 2
    hists = torch.zeros(N, C, self.h, self.h, device=self.device)
    for n in range(N):
      Ix = X[n].reshape(3, -1).t()  # (pixels, 3)
      Iy = self.scaling(Ix)
      # Every u/v coordinate is a difference of two channel logs, so take the
      # three logs once per image instead of re-deriving them for each pair.
      log_I = (Ix + EPS).log()
      if not self.green_only:
        Du, Dv = self._uv_kernels(log_I, i=0, j=1, k=2)
        hists[n, 0, :, :] = torch.mm((Iy * Du).t(), Dv)

      Du, Dv = self._uv_kernels(log_I, i=1, j=0, k=2)
      # Green is the only output channel when green_only is set.
      hists[n, int(not self.green_only), :, :] = torch.mm((Iy * Du).t(), Dv)

      if not self.green_only:
        Du, Dv = self._uv_kernels(log_I, i=2, j=0, k=1)
        hists[n, 2, :, :] = torch.mm((Iy * Du).t(), Dv)

    # normalization: a single reduction over all channels and bins per sample
    norm = hists.view(-1, C * self.h * self.h).sum(dim=1).view(-1, 1, 1, 1)
    hists_normalized = hists / (norm + EPS)

    return hists_normalized

  def _diff_uv(self, X: torch.Tensor, i: int, j: int, k: int):
    """Kernel values for u = log I_i - log I_j and v = log I_i - log I_k.

    ``X`` holds raw intensities, one pixel per row. Kept for compatibility;
    ``forward`` calls ``_uv_kernels`` with logs computed once per image.
    """
    return self._uv_kernels((X + EPS).log(), i, j, k)

  def _uv_kernels(self, log_X: torch.Tensor, i: int, j: int, k: int):
    """Like ``_diff_uv`` but takes ``log(X + EPS)`` precomputed.

    Returns two (pixels, h) tensors: the kernel applied to |u - delta| and to
    |v - delta|.
    """
    U = (log_X[:, i] - log_X[:, j]).unsqueeze(dim=1)
    V = (log_X[:, i] - log_X[:, k]).unsqueeze(dim=1)
    Du = self.kernel((U - self.delta).abs())
    Dv = self.kernel((V - self.delta).abs())
    return Du, Dv

class rgChromaHistBlock(HistBlock):
  """Refactored rg-chroma histogram block with the reference block's constructor.

  Each image contributes one h x h histogram over the chromaticities
  r = c0 / (c0 + c1 + c2 + EPS) and g = c1 / (c0 + c1 + c2 + EPS); every
  sample's histogram is then normalised to sum to ~1.
  """
  def __init__(self, h=64, insz=150, resizing='interpolation',
               method='inverse-quadratic', sigma=0.02, intensity_scale=False,
               hist_boundary=None, device='cuda'):
    default_boundary = [0, 1]
    super().__init__(
      h, insz, resizing, method, sigma, intensity_scale,
      hist_boundary or default_boundary, device
    )

  def forward(self, x):
    resized = self.resize(x.clamp(0, 1))
    if resized.shape[1] > 3:
      # only the first three channels take part
      resized = resized[:, :3, :, :]
    batch_size = resized.shape[0]
    hists = torch.zeros(batch_size, 1, self.h, self.h, device=self.device)

    for idx, image in enumerate(torch.unbind(resized, dim=0)):
      pixels = image.reshape(3, -1).t()  # (pixels, 3)
      channel_sum = pixels.sum(dim=-1) + EPS
      chroma_r = (pixels[:, 0] / channel_sum).unsqueeze(dim=1)
      chroma_g = (pixels[:, 1] / channel_sum).unsqueeze(dim=1)

      kernel_r = self.kernel((chroma_r - self.delta).abs())
      kernel_g = self.kernel((chroma_g - self.delta).abs())
      weight = self.scaling(pixels)
      # (h, pixels) @ (pixels, h): weighted joint histogram over (r, g)
      hists[idx, 0, :, :] = torch.mm((weight * kernel_r).t(), kernel_g)

    totals = hists.view(-1, self.h * self.h).sum(dim=1).view(-1, 1, 1, 1)
    return hists / (totals + EPS)

class LabHistBlock(HistBlock):
  """Refactored Lab histogram block with the reference block's constructor.

  Builds one h x h histogram per image over channels 1 and 2 (a, b),
  normalised to sum to ~1. With ``intensity_scale``, each pixel is weighted by
  channel 0 (L), as in ``OriginalLabHistBlock``. This replaces the
  sqrt-of-squares magnitude the base class would select.
  """
  def __init__(self, h=64, insz=150, resizing='interpolation',
               method='inverse-quadratic', sigma=0.02, intensity_scale=False,
               hist_boundary=None, device='cuda'):
    super().__init__(
      h, insz, resizing, method, sigma, intensity_scale,
      hist_boundary or [0, 1], device
    )
    if intensity_scale:
      # The base class picked sqrt(c0^2 + c1^2 + c2^2); Lab weights by L.
      self.scaling = self._l_channel_scaling

  @staticmethod
  def _l_channel_scaling(X: torch.Tensor) -> torch.Tensor:
    """Per-pixel weight: column 0 of a (pixels, 3) tensor, shaped (pixels, 1)."""
    return X[:, 0].unsqueeze(dim=1)

  def forward(self, x):
    x_sampled = self.resize(x.clamp(0, 1))

    N = x_sampled.shape[0]  # size of mini-batch
    if x_sampled.shape[1] > 3:
      x_sampled = x_sampled[:, :3, :, :]
    X = torch.unbind(x_sampled, dim=0)
    hists = torch.zeros(N, 1, self.h, self.h, device=self.device)
    for n in range(N):
      Ix = X[n].reshape(3, -1).t()  # (pixels, 3)

      Ia = Ix[:, 1].unsqueeze(dim=1)
      Ib = Ix[:, 2].unsqueeze(dim=1)

      diff_a = self.kernel((Ia - self.delta).abs())
      diff_b = self.kernel((Ib - self.delta).abs())
      # 1, or the L channel when intensity_scale is set
      Il = self.scaling(Ix)
      a = (Il * diff_a).t()

      hists[n, 0, :, :] = torch.mm(a, diff_b)

    # normalization
    norm = hists.view(-1, self.h * self.h).sum(dim=1).view(-1, 1, 1, 1) + EPS
    hists_normalized = hists / norm

    return hists_normalized

## Validation

In order to validate our work, let's run A-B-tests for each possible parameter combination.
We can keep most values at their default, but the resizing, kernel method, and intensity scaling options (plus `green_only` for RGB-uv) should be thoroughly tested.
We can define a `dict` that holds all test cases - histogram block classes and the tested parameters.

Next we run both reference and refactored models with a batch of randomly generated images and compare the results.
For the comparison, we use the *arctangent absolute percentage error* (AAPE) as proposed in

    Sungil Kim, Heeyoung Kim,
    "A new metric of absolute percentage error for intermittent demand forecasts",
    International Journal of Forecasting,
    Volume 32, Issue 3,
    2016,
    Pages 669-679,
    https://doi.org/10.1016/j.ijforecast.2015.12.003

with the AAPE rescaled from its original [0, ½π] range to [0, 100] to obtain more readable percentages.

In [None]:
# Validation settings.
# The RNG is seeded for reproducibility; use torch.random.seed() instead to
# explore other inputs. This particular seed produces a more colourful output ;)
RANDOM_SEED = 4793
DEVICE = 'cuda'          # device to run the tests on (e.g. 'cuda' or 'cpu')
BATCHES = 8              # how many samples per mini-batch
SAMPLE_SIZE = 256        # sample image size in pixels
ERR_THRESHOLD = 0.05     # validation error threshold in percent

In [None]:
from dataclasses import dataclass, field
from itertools import product, repeat
import matplotlib.pyplot as plt


RESIZING_VALS = ['interpolation', 'sampling']
BOOL_VALS = [False, True]
METHOD_VALS = ['thresholding', 'RBF', 'inverse-quadratic']
# Per histogram block: reference class 'A', refactored class 'B', and the
# parameter grid whose full Cartesian product is validated.
TESTS = {
    'RGB-uv': {
        'A': OriginalRGBuvHistBlock, 'B': RGBuvHistBlock,
        'params': {'resizing': RESIZING_VALS, 'method': METHOD_VALS, 'intensity_scale': BOOL_VALS, 'green_only': BOOL_VALS}
    },
    'rg-chroma': {
        'A': OriginalrgChromaHistBlock, 'B': rgChromaHistBlock,
        'params': {'resizing': RESIZING_VALS, 'method': METHOD_VALS, 'intensity_scale': BOOL_VALS}
    },
    'Lab': {
        # The Lab pair itself -- not the rg-chroma classes -- must be compared.
        'A': OriginalLabHistBlock, 'B': LabHistBlock,
        'params': {'resizing': RESIZING_VALS, 'method': METHOD_VALS, 'intensity_scale': BOOL_VALS}
    }
}

def _to_dict(names, values):
    return {key: val for key, val in zip(names, values)}

def _param_info(p):
    if isinstance(p[1], bool):
        return f"{'+' if p[1] else '-'}{p[0][:3].upper()}"
    return p[1][:4].upper()

@dataclass
class Result:
    """Error statistics (AAPE, in percent) for one validated parameter combination.

    ``outcome`` is derived rather than passed in. It is 'PASS' when
    ``avg_err`` is below ERR_THRESHOLD and 'FAILED' otherwise (including a
    NaN average).
    """
    test: str
    params: str
    min_err: float
    max_err: float
    avg_err: float
    median_err: float
    outcome: str = field(init=False)

    def __post_init__(self):
        if self.avg_err < ERR_THRESHOLD:
            self.outcome = 'PASS'
        else:
            self.outcome = 'FAILED'

def _to_row(result: Result):
    """Table cells for a Result: name, params, four errors at 6 d.p., outcome."""
    stats = (result.min_err, result.max_err, result.avg_err, result.median_err)
    formatted = [f'{value:.6f}' for value in stats]
    return [result.test, result.params, *formatted, result.outcome]

def _aape(X, Y):
    eps = torch.full_like(input=X, fill_value=1.1921e-7).float()
    phi = ((X - Y).abs() / torch.maximum(X.abs(), eps)).arctan()
    return phi * (2 / torch.pi) * 100

# Seed the RNG, then draw one fixed mini-batch of random images whose values
# are multiples of 1/255 in [0, 1]; every _validate() call compares on it.
torch.random.manual_seed(RANDOM_SEED)
samples = torch.randint(low=0, high=256, size=(BATCHES, 3, SAMPLE_SIZE, SAMPLE_SIZE))
samples = (samples / 255.).float().to(DEVICE)

def _validate(name, A_, B_, model_args):
    """Compare reference class A_ against refactored class B_ on `samples`.

    Both models are built with ``model_args`` plus ``device=DEVICE`` and run
    without gradients. Returns a Result holding the AAPE statistics (percent)
    of B_'s output relative to A_'s.
    """
    # Pass the device explicitly: the blocks default to device='cuda' and
    # would otherwise allocate their internal tensors on CUDA even when
    # DEVICE is 'cpu'.
    A = A_(**model_args, device=DEVICE).eval().to(DEVICE)
    B = B_(**model_args, device=DEVICE).eval().to(DEVICE)
    with torch.no_grad():
        Ay = A(samples)
        By = B(samples)
    err = _aape(Ay, By)
    params = ','.join(map(_param_info, model_args.items()))
    return Result(name, params, err.min().item(), err.max().item(), err.mean().item(), err.median().item())

validation_results: List[Result] = []
for name, args in TESTS.items():
    A, B, params = args['A'], args['B'], args['params']
    # One validation per combination in the Cartesian product of the grid.
    for values in product(*params.values()):
        validation_results.append(_validate(name, A, B, _to_dict(params, values)))

Let's visualize the validation results next - a table will do for now.

The tested parameter combinations are listed by their first four letters in caps. Boolean flags are indicated by -*FLAG*
if the parameter is `False` and +*FLAG* if the parameter is set to `True`.

In [None]:
if any(result.outcome == 'FAILED' for result in validation_results):
    print(f'Got some failures; RANDOM_SEED to reproduce: {RANDOM_SEED}')

# plot the results
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
col_labels = ['HIST BLOCK', 'PARAMS', 'MIN ERR %', 'MAX ERR %', 'AVG ERR %', 'MEDIAN ERR %', 'RESULT']
n_cols = len(col_labels)
ax.axis('tight')
ax.axis('off')
tbl = ax.table(
    cellText=[_to_row(result) for result in validation_results], colLabels=col_labels, loc='center',
    colColours=['slategrey'] * n_cols
)
# format table: bold header cells (table row 0 holds the column labels)
for col in range(n_cols):
    tbl[(0, col)].set_text_props(fontweight='bold')
# data rows start at table row 1
for row, item in enumerate(validation_results, start=1):
    tbl[(row, 0)].set_facecolor('lightsteelblue')
    # columns 2-5 hold min/max/avg/median error; flag those above the threshold
    errors = (item.min_err, item.max_err, item.avg_err, item.median_err)
    for col, err in enumerate(errors, start=2):
        if err > ERR_THRESHOLD:
            tbl[(row, col)].set_facecolor('darkorange')
    tbl[(row, 6)].set_facecolor('g' if item.outcome == 'PASS' else 'r')
    tbl[(row, 6)].set_text_props(fontweight='bold')
tbl.auto_set_font_size(False)
tbl.set_fontsize(10)
tbl.scale(2, 2)
plt.show()

## Validation Result Discussion

The chosen criterion for passing validation is a somewhat arbitrary but low average error threshold
(across all samples in the mini-batch).
The reported maximum error might exceed this threshold (I've seen values as high as 10%), but we
need to keep two things in mind here:

First, the maximum error refers to a single bucket value in the histogram. A single outlier in a
64x64 histogram shouldn't account for a validation failure.

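Secondly, I noticed that the deviations occur in conjunction with interpolation only. Match the
sample dimensions (`SAMPLE_SIZE`) with the maximum histogram input size (`insz`) and the errors
go away. Since the resize function is identical to the original version in every way, I'm a bit at a
loss as to why that is. One difference is worth checking, though. The reference builds its bin centres
with `torch.tensor(np.linspace(...))`, which yields `float64`, so its kernel distances are evaluated in
double precision before the cast to `float32`. The refactored blocks use a `float32` `delta` throughout.
Interpolated pixels take arbitrary values rather than multiples of 1/255. That could make this precision
difference more visible, especially with the hard cut-off of the thresholding kernel.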

## Benchmark

With that out of the way, let's get some performance numbers to see whether we actually improved things.
We assess the performance differences by running each histogram block a given number of times on a
mini-batch of random samples. Inference time is measured and results are plotted to a diagram.

In [None]:
# Benchmark settings (these rebind the validation values of the same name).
RANDOM_SEED = 0     # random seed for producing sample data (again, for reproducibility)
DEVICE = 'cuda'     # computation device to run the benchmark on (e.g. 'cuda' or 'cpu')
ITERATIONS = 100    # number of benchmark passes per model
BATCHES = 16        # mini-batch size
SAMPLE_SIZE = 256   # sample image size for benchmarking

In [None]:
from tqdm import tqdm
from time import perf_counter
from typing import Dict
import numpy as np


# Pair every test name with its baseline ('A') and refactored ('B') model class.
BASELINE = [(test_name, variants['A']) for test_name, variants in TESTS.items()]
REFACTORED = [(test_name, variants['B']) for test_name, variants in TESTS.items()]

# Seed torch's global RNG so the random benchmark inputs are reproducible.
torch.manual_seed(RANDOM_SEED)


def _gen_minibatch():
    """Yield an endless stream of random image mini-batches.

    Every batch has shape ``(BATCHES, 3, SAMPLE_SIZE, SAMPLE_SIZE)``. It holds
    random integer intensities 0..255 rescaled to the range [0, 1] as float
    values, and is moved to ``DEVICE``.
    """
    batch_shape = (BATCHES, 3, SAMPLE_SIZE, SAMPLE_SIZE)
    while True:
        pixels = torch.randint(low=0, high=256, size=batch_shape)
        yield pixels.div(255.).float().to(DEVICE)


def _benchmark(model, sample):
    start = perf_counter()
    _ = model(sample)
    return (perf_counter() - start) * 1_000


from itertools import repeat
from typing import List

# Untimed forward passes per model before measuring. They keep one-off start-up
# costs (e.g. CUDA initialisation, memory allocation) out of the timings of
# whichever model happens to be benchmarked first.
WARMUP_ITERATIONS = 5


def _sync_device():
    """Wait for all queued work on ``DEVICE`` to finish (no-op for non-CUDA devices)."""
    device = torch.device(DEVICE)
    if device.type == 'cuda':
        torch.cuda.synchronize(device)


def _run_suite(suite, label):
    """Benchmark every model class in ``suite``.

    Each class is instantiated without arguments, put into eval mode and moved
    to ``DEVICE``. It then runs ``WARMUP_ITERATIONS`` untimed passes followed by
    ``ITERATIONS`` timed passes, with gradients disabled.

    Args:
        suite: sequence of ``(name, ModelClass)`` pairs.
        label: suite label shown in the progress-bar description.

    Returns:
        Mapping from model name to its list of ``ITERATIONS`` per-pass
        timings in milliseconds (as returned by ``_benchmark``).
    """
    results: Dict[str, List[float]] = {}
    for name, Model in suite:
        # Re-seed per model so every model is fed the same sequence of inputs.
        torch.random.manual_seed(RANDOM_SEED)
        samples = _gen_minibatch()
        with torch.no_grad():
            model = Model().eval().to(DEVICE)
            for _ in range(WARMUP_ITERATIONS):
                model(next(samples))
            # Make sure no warm-up work is still running when timing begins.
            _sync_device()
            models = repeat(model, times=ITERATIONS)
            runs = tqdm(models, total=ITERATIONS, desc=f'Benchmarking {label} {name}')
            # map() pulls a new mini-batch per call; generation happens outside
            # the timed region inside _benchmark.
            results[name] = list(map(_benchmark, runs, samples))
    return results


def _summarize(timings):
    """Return per-row (min, max, mean, std) of a 2-D timing array along axis 1."""
    return (timings.min(axis=1), timings.max(axis=1),
            timings.mean(axis=1), timings.std(axis=1))


baseline_results: Dict[str, List[float]] = _run_suite(BASELINE, 'baseline')
refactored_results: Dict[str, List[float]] = _run_suite(REFACTORED, 'refactored')

# Timing matrices of shape (number of models, ITERATIONS): one row per model,
# in suite order.
a = np.array(list(baseline_results.values()))
a_mins, a_maxes, a_means, a_std = _summarize(a)

b = np.array(list(refactored_results.values()))
b_mins, b_maxes, b_means, b_std = _summarize(b)

With our data collected, let's print the average speed-up of each refactored block (the ratio of mean execution times) and plot some charts.

In [None]:
from itertools import chain

N = len(baseline_results)  # number of benchmarked tests per suite

# Ratio of mean execution times: a value > 1 means the refactored block is faster.
for name, speedup in zip(TESTS, a_means / b_means):
    if speedup >= 1:
        print(f'Refactored {name}: {speedup:.1f}x faster on average')
    else:
        print(f'Refactored {name}: {1 / speedup:.1f}x slower on average')

fig, ax = plt.subplots(3, 1, figsize=(12, 18))

# Panel 1: spread of the per-iteration times. Baseline (even x) and refactored
# (odd x) variants of each test sit next to each other. Thick bars show +/-1 std;
# thin whiskers show the min/max range.
pair_x = np.arange(N) * 2
tick_labels = list(chain.from_iterable(
    (f'{n} (baseline)', f'{n} (refactored)') for n in TESTS))
ax[0].errorbar(pair_x, a_means, a_std, fmt='_k', lw=3, ms=11, capsize=3)
ax[0].errorbar(pair_x, a_means, [a_means - a_mins, a_maxes - a_means], fmt='.k', ecolor='grey', lw=1, capsize=3)
ax[0].errorbar(pair_x + 1, b_means, b_std, fmt='_b', lw=3, ms=11, capsize=3)
ax[0].errorbar(pair_x + 1, b_means, [b_means - b_mins, b_maxes - b_means], fmt='.k', ecolor='lightsteelblue', lw=1, capsize=3)
ax[0].set_xticks(np.arange(2 * N), minor=False)
ax[0].set_xticklabels(tick_labels)
ax[0].set_title(f'Benchmark results for {ITERATIONS} iterations and mini-batch size of {BATCHES}')
ax[0].set_ylabel('Iteration time in ms')

# Panels 2 and 3 use grouped bars: the two variants of each test are offset
# half a bar width to either side of the tick, so neither bar hides the other.
labels = list(TESTS)
x = np.arange(N)
width = 0.35

ax[1].bar(x - width / 2, a_means, width, yerr=a_std, label='Baseline', capsize=3)
ax[1].bar(x + width / 2, b_means, width, yerr=b_std, label='Refactored', capsize=3)
ax[1].set_xticks(x)
ax[1].set_xticklabels(labels)
ax[1].set_ylabel('Iteration time in ms')
ax[1].set_title('Execution time difference')
ax[1].legend()

# Throughput in iterations per second, derived from the per-iteration times in ms.
a_mean_its = 1_000 / a_means
b_mean_its = 1_000 / b_means
a_it_std = (1_000 / a).std(axis=1)
b_it_std = (1_000 / b).std(axis=1)
ax[2].bar(x - width / 2, a_mean_its, width, yerr=a_it_std, label='Baseline', capsize=3)
ax[2].bar(x + width / 2, b_mean_its, width, yerr=b_it_std, label='Refactored', capsize=3)
ax[2].set_xticks(x)
ax[2].set_xticklabels(labels)
ax[2].set_ylabel('Iterations per second')
ax[2].set_title('Performance difference')
ax[2].legend()

plt.show()