In [None]:
# Copyright 2022 Intrinsic Innovation LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
# # Uncomment this block if running on colab.research.google.com
# !pip install git+https://github.com/deepmind/PGMax.git
# !wget https://raw.githubusercontent.com/deepmind/PGMax/main/examples/example_data/rcn.npz
# !mkdir example_data
# !mv rcn.npz  example_data/

In [None]:
%matplotlib inline
import os
import time
from typing import Optional, Tuple

import jax
import matplotlib.pyplot as plt
import numpy as np
from jax import numpy as jnp
from jax import tree_util
from joblib import Memory
from scipy.ndimage import maximum_filter
from scipy.signal import fftconvolve
import tensorflow_datasets as tfds

############
# Load PGMax
from pgmax import fgraph, fgroup, infer, vgroup

# Ask XLA not to preallocate device memory and to use the "platform" allocator.
# NOTE(review): these variables are presumably read when JAX first initializes
# its backend, so they rely on `import jax` above not having done that yet —
# confirm, or move these lines above the jax import.
os.environ.update(
    {
        "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    }
)

# 1. Load the data

A Recursive Cortical Network (RCN) is a neuroscience-inspired probabilistic generative model for vision
published in [Science 2017](https://www.science.org/doi/10.1126/science.aag2612),
whose performance on object recognition and scene segmentation tasks is comparable to that of deep learning approaches, while being interpretable and orders of magnitude more data efficient.

In this notebook, we load two-level RCN models pre-trained on a small subset of 20 images from the MNIST training set (one model per training image).
We test these models on 20 images from the MNIST test set.

In [None]:
from typing import Tuple
def fetch_mnist_dataset(test_size: int, seed: int = 5) -> Tuple[np.ndarray, np.ndarray]:
  """Returns test images randomly sampled from the MNIST test dataset.

  Sampling is class-balanced and without replacement: every digit class gets
  `test_size // 10` distinct images and the first `test_size % 10` classes get
  one extra, so exactly `test_size` images are returned. Each 28x28 image is
  upsampled to 112x112 (bicubic) and zero-padded by 44 pixels on every side,
  giving a 200x200 image.

  Args:
    test_size: Desired number of test images.
    seed: Random seed. A local generator is used, so the global NumPy random
      state is left untouched.

  Returns:
    test_set: Array of shape (test_size, 200, 200) with the processed images.
    test_labels: Array of shape (test_size,) with the corresponding digit labels.

  Raises:
    ValueError: If a class has fewer MNIST test images than requested.
  """
  num_classes = 10
  upsampled_size = 112  # 28 * 4
  pad = 44  # 112 + 2 * 44 = 200
  rng = np.random.default_rng(seed)
  per_class, remainder = divmod(test_size, num_classes)

  dataset = tfds.as_numpy(tfds.load("mnist", split="test", batch_size=-1))
  print("Successfully downloaded the MNIST dataset")

  full_mnist_test_images = dataset["image"]
  full_mnist_test_labels = dataset["label"].astype("int")

  test_set = []
  test_labels = []
  for label in range(num_classes):
    num_samples = per_class + (1 if label < remainder else 0)
    candidates = np.flatnonzero(full_mnist_test_labels == label)
    # Sample without replacement so the test set never contains duplicates.
    idxs = rng.choice(candidates, size=num_samples, replace=False)
    for idx in idxs:
      img = full_mnist_test_images[idx].reshape(28, 28)
      img = jax.image.resize(
          image=img, shape=(upsampled_size, upsampled_size), method="bicubic"
      )
      img = jnp.pad(
          img,
          pad_width=((pad, pad), (pad, pad)),
          mode="constant",
          constant_values=0,
      )

      test_set.append(img)
      test_labels.append(label)

  return np.array(test_set), np.array(test_labels)

In [None]:
# Number of MNIST test images to classify; the same number of pre-trained
# models (one per training image) is loaded below.
test_size = 20
train_size = test_size
test_set, test_labels = fetch_mnist_dataset(test_size)

# 2. Load the model

We load pre-trained RCN models that have been trained using the code [here](https://github.com/vicariousinc/science_rcn/tree/master/science_rcn). The variables are:
- train_set and train_labels - A sample of the MNIST training set containing 100 training images and their labels (we keep `train_size` of them, evenly spaced).
- frcs and edges - The learned RCN graphical models: for each model, `frcs` lists its features as (feature, row, column) triples and `edges` lists its edges as (vertex 1, vertex 2, perturb radius) triples.
- suppression_masks and filters - Saved numpy arrays that are used to detect the presence or absence of an oriented/directed edge in an image. Please refer to the function get_bu_msg to see how they are used.

In [None]:
# Load data
folder_name = "example_data/"
# `allow_pickle=True` is needed to read the object arrays stored in this file
# (presumably the per-model `frcs`/`edges`, whose sizes differ between models).
# Unpickling can execute arbitrary code, so only load trusted files.
# The context manager closes the underlying file once the arrays are read.
with np.load(os.path.join(folder_name, "rcn.npz"), allow_pickle=True) as data:
  # Keep exactly `train_size` of the saved models, evenly spaced over the file.
  num_saved_models = data["train_labels"].shape[0]
  idxs = np.arange(0, num_saved_models, num_saved_models // train_size)[:train_size]

  train_set, train_labels, frcs, edges, suppression_masks, filters = (
      data["train_set"][idxs, :, :],
      data["train_labels"][idxs],
      data["frcs"][idxs],
      data["edges"][idxs],
      data["suppression_masks"],
      data["filters"],
  )

We initialize the following hyper-parameters.
- hps and vps - Horizontal and vertical pool sizes, respectively, of the RCN models. Each is the radius of the window around a pool vertex, so a pool vertex can be activated by an input pixel anywhere in a rectangle of size [2*hps+1, 2*vps+1].
- num_orients - The number of different orientations at which edges are detected (one per filter).
- brightness_diff_threshold - The minimum filter response (brightness difference) at a pixel required to declare the presence of an edge.

In [None]:
# Pool window radii: a pool covers a (2*hps+1) x (2*vps+1) window.
hps = 12
vps = 12
# One edge orientation per filter in the loaded filter bank.
num_orients = filters.shape[0]
# Minimum filter response for an edge to be declared present (see get_bu_msg).
brightness_diff_threshold = 40.0

# 3. Visualize loaded model

In RCN, a learned model is a weighted graph.

The weights (the 'perturb radius' of each edge) constrain how much the positions of the two vertices connected by an edge are allowed to vary relative to each other during inference.

In [None]:
img_size = 200
pad = 44
img_idx = 4

# Crop away the `pad`-pixel border around the digit.
crop = slice(pad, img_size - pad)
model_img = np.ones((img_size, img_size))
fig, ax = plt.subplots(1, 2, figsize=(20, 10))

frc, edge, train_img = frcs[img_idx], edges[img_idx], train_set[img_idx]
ax[0].imshow(train_img[crop, crop], cmap="gray")
ax[0].axis("off")
ax[0].set_title("Example training image", fontsize=40)

for i1, i2, w in edge:
  # (i1, i2) index two features of the model; w is the edge's perturb radius.
  r1, c1 = frc[i1][1:]
  r2, c2 = frc[i2][1:]

  # Mark both feature locations and draw the edge labelled with its radius.
  model_img[r1, c1] = 0
  model_img[r2, c2] = 0
  ax[1].text(
      (c1 + c2) // 2 - pad, (r1 + r2) // 2 - pad, str(w), color="green", fontsize=25
  )
  ax[1].plot([c1 - pad, c2 - pad], [r1 - pad, r2 - pad], color="green", linewidth=0.5)

ax[1].axis("off")
ax[1].imshow(model_img[crop, crop], cmap="gray")
ax[1].set_title("Corresponding RCN template", fontsize=40)

fig.tight_layout()

## 3.1 Visualize the filters

The filters are used to detect the oriented edges on a given image. They are pre-computed using Gabor filters.

In [None]:
# Lay the filter bank out on a near-square grid (4x4 when there are 16 filters).
num_filters = filters.shape[0]
num_cols = int(np.ceil(np.sqrt(num_filters)))
num_rows = int(np.ceil(num_filters / num_cols))
fig, ax = plt.subplots(num_rows, num_cols, figsize=(10, 10), squeeze=False)
for i, axis in enumerate(ax.flat):
  if i < num_filters:
    axis.imshow(filters[i], cmap="gray")
  # Also hides any unused grid cells.
  axis.axis("off")

fig.tight_layout()

# 4. Make pgmax graph

Convert the pre-trained RCN models into a PGMax factor graph so that we can run inference.

## 4.1 Make variables

In [None]:
# Every model needs both its features (frcs) and its edges.
assert frcs.shape[0] == edges.shape[0]

# Number of states of every variable: the number of pool choices, i.e. one per
# position in the (2*hps+1) x (2*vps+1) pool window.
M = (2 * hps + 1) * (2 * vps + 1)

# One variable array per RCN model, with one variable per feature.
variables_all_models = [
    vgroup.NDVarArray(num_states=M, shape=(model_frcs.shape[0],))
    for model_frcs in frcs
]

## 4.2 Make factors

### 4.2.1 Pre-compute the valid configs for different perturb radii.

In [None]:
def valid_configs(r: int, hps: int, vps: int) -> np.ndarray:
  """Enumerates the allowed state pairs of a pairwise factor with perturb radius r.

  Each variable's state indexes a position in a (2*hps+1) x (2*vps+1) pool
  window, flattened in row-major order. A pair of states (i, j) is valid when
  the two positions differ by at most r along both the row and column axes.

  Args:
    r: Perturb radius.
    hps: Pool radius along the first (row) axis of the window.
    vps: Pool radius along the second (column) axis of the window.

  Returns:
    An integer array of shape (num_valid_configs, 2) listing every valid
    (i, j) pair, sorted by i and then by j.
  """
  window_shape = (2 * hps + 1, 2 * vps + 1)
  rows, cols = np.unravel_index(
      np.arange(window_shape[0] * window_shape[1]), window_shape
  )
  # Pairwise closeness of every two window positions along each axis.
  rows_close = np.abs(rows[:, np.newaxis] - rows[np.newaxis, :]) <= r
  cols_close = np.abs(cols[:, np.newaxis] - cols[np.newaxis, :]) <= r
  # argwhere scans (i, j) in row-major order, i.e. sorted by i then j.
  return np.argwhere(rows_close & cols_close)

# The maximum perturb radius (inclusive) for which to pre-compute the valid configs.
max_perturb_radius = 22
# valid_configs_list[r] holds the valid configs for radius r = 0..max_perturb_radius.
valid_configs_list = [
    valid_configs(r, hps, vps) for r in range(max_perturb_radius + 1)
]
# Fail early if an edge radius has no pre-computed configs. A negative radius
# would otherwise silently wrap around when indexing the list.
assert all(
    0 <= int(r) <= max_perturb_radius
    for model_edges in edges
    for _, _, r in model_edges
), "Found an edge whose perturb radius is outside [0, max_perturb_radius]"

### 4.2.2 Make the factor graph

In [None]:
fg = fgraph.FactorGraph(variables_all_models)

# Each RCN edge (i1, i2, r) becomes a pairwise enumeration factor between
# features i1 and i2 of its model, restricted to the state pairs allowed by
# perturb radius r.
for model_vars, model_edges in zip(variables_all_models, edges):
  for i1, i2, r in model_edges:
    fg.add_factors(
        fgroup.EnumFactorGroup(
            variables_for_factors=[[model_vars[i1], model_vars[i2]]],
            factor_configs=valid_configs_list[r],
        )
    )

# 5. Run inference

## 5.1 Helper functions to initialize the evidence for a given image

In [None]:
def get_bu_msg(
    img: np.ndarray,
    filter_bank: Optional[np.ndarray] = None,
    masks: Optional[np.ndarray] = None,
    threshold: Optional[float] = None,
    n_orients: Optional[int] = None,
) -> np.ndarray:
  """Computes the bottom-up messages given a test image.

  Args:
    img: Grayscale image of shape [H, W] to compute bottom-up messages on.
    filter_bank: Edge filters of shape [F, kh, kw]. Defaults to the
      notebook-level `filters`.
    masks: One non-max-suppression footprint per filter. Defaults to the
      notebook-level `suppression_masks`.
    threshold: Minimum filter response for an edge to be declared present.
      Defaults to the notebook-level `brightness_diff_threshold`.
    n_orients: Number of orientation channels used for circular
      cross-channel pooling. Defaults to the notebook-level `num_orients`.

  Returns:
    An array of shape [F, H, W] (F = number of filters) with values in
    {+1, -1}: +1 where an oriented edge, or one at a neighbouring orientation,
    is detected at that location, and -1 elsewhere.
  """
  filter_bank = filters if filter_bank is None else filter_bank
  masks = suppression_masks if masks is None else masks
  threshold = brightness_diff_threshold if threshold is None else threshold
  n_orients = num_orients if n_orients is None else n_orients

  # Convolve the image with each edge filter.
  filtered = np.zeros((filter_bank.shape[0],) + img.shape, dtype=np.float32)
  for i, kern in enumerate(filter_bank):
    filtered[i] = fftconvolve(img, kern, mode="same")

  # Non-max suppression: keep a response only if it is the maximum within its
  # filter's suppression footprint and not exceeded by another orientation.
  localized = np.zeros_like(filtered)
  # Taken before negative responses are clipped to zero (order matters).
  cross_orient_max = filtered.max(0)
  filtered[filtered < 0] = 0
  for i, (layer, suppress_mask) in enumerate(zip(filtered, masks)):
    competitor_maxs = maximum_filter(layer, footprint=suppress_mask, mode="nearest")
    localized[i] = competitor_maxs <= layer
  localized[cross_orient_max > filtered] = 0

  # Threshold and binarize: only responses >= threshold survive, as exactly 1.
  localized *= (filtered / threshold).clip(0, 1)
  localized[localized < 1] = 0

  # Cross-channel pooling: channel i is on if channel i-1, i or i+1
  # (circularly, modulo n_orients) is on.
  pooled_channels = [-np.ones_like(sf) for sf in localized]
  for i, pc in enumerate(pooled_channels):
    for channel_offset in [0, -1, 1]:
      ch = (i + channel_offset) % n_orients
      np.maximum(pc, localized[ch], pc)

  # Remap the elements to the set {+1, -1}.
  bu_msg = np.array(pooled_channels)
  bu_msg[bu_msg == 0] = -1
  return bu_msg

### 5.1.1 Visualizing bu_msg for a sample image

bu_msg has shape (16, H, W), where each channel f (0 <= f < 16) encodes the presence (+1) or absence (-1) of an oriented edge at each pixel.

In [None]:
r_test_img = test_set[4]
r_bu_msg = get_bu_msg(r_test_img)

# White canvas the size of the input; pixels where any channel reports an
# edge (+1) are drawn black.
edge_map = np.ones(r_test_img.shape)
edge_map[(r_bu_msg > 0).any(axis=0)] = 0

fig, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].imshow(r_test_img, cmap="gray")
ax[0].axis("off")
ax[0].set_title("Input image", fontsize=40)

ax[1].imshow(edge_map, cmap="gray")
ax[1].axis("off")
ax[1].set_title("Max filter response across 16 channels", fontsize=40)
fig.tight_layout()

## 5.2 Run MAP inference on all test images

In [None]:
def get_evidence(
    bu_msg: np.ndarray,
    frc: np.ndarray,
    row_radius: Optional[int] = None,
    col_radius: Optional[int] = None,
) -> np.ndarray:
  """Returns the evidence of one RCN model, of shape (n_frcs, M).

  Row v holds the bottom-up messages of feature v's channel inside the pool
  window centred on the feature's location. The window is flattened in
  row-major order, so entry k corresponds to pool state k.

  Args:
    bu_msg: Array of shape (n_features, H, W), e.g. the output of get_bu_msg.
    frc: Integer array of shape (n_frcs, 3); each row is (feature, row, col).
    row_radius: Half-height of the pool window. Defaults to the
      notebook-level `hps`.
    col_radius: Half-width of the pool window. Defaults to the
      notebook-level `vps`.

  Returns:
    Float64 array of shape (n_frcs, (2*row_radius+1) * (2*col_radius+1)).
    With the default radii this is (n_frcs, M).

  Raises:
    ValueError: If a feature's pool window extends beyond the image.
  """
  row_radius = hps if row_radius is None else row_radius
  col_radius = vps if col_radius is None else col_radius
  height, width = bu_msg.shape[1:]
  num_states = (2 * row_radius + 1) * (2 * col_radius + 1)

  evidence = np.zeros((frc.shape[0], num_states))
  for v, (f, r, c) in enumerate(frc):
    # Negative slice starts would silently wrap, so check bounds explicitly.
    if not (
        row_radius <= r < height - row_radius
        and col_radius <= c < width - col_radius
    ):
      raise ValueError(
          f"Pool window of feature {v} at ({r}, {c}) exceeds the image bounds "
          f"({height}, {width})."
      )
    evidence[v] = bu_msg[
        f, r - row_radius : r + row_radius + 1, c - col_radius : c + col_radius + 1
    ].ravel()
  return evidence

In [None]:
# Map each model's variable group to its features, so evidence can be built per model.
frcs_dict = {
    variables_all_models[model_idx]: frcs[model_idx]
    for model_idx in range(frcs.shape[0])
}
bp = infer.build_inferer(fg.bp_state, backend="bp")
scores = np.zeros((len(test_set), frcs.shape[0]))
map_states_dict = {}

for test_idx in range(len(test_set)):
  img = test_set[test_idx]

  # Initializing evidence: one (n_frcs, M) array per model, keyed by its variables.
  bu_msg = get_bu_msg(img)
  evidence_updates = tree_util.tree_map(
      lambda frc: get_evidence(bu_msg, frc), frcs_dict
  )

  # Max-product inference (temperature=0.0). NOTE: the first image presumably
  # also pays for JIT compilation.
  start = time.time()
  bp_arrays = bp.run(
      bp.init(evidence_updates=evidence_updates),
      num_iters=30,
      temperature=0.0,
  )
  map_states = infer.decode_map_states(bp.get_beliefs(bp_arrays))
  # JAX dispatches work asynchronously; wait for the results so the measured
  # time covers the actual computation.
  map_states = jax.block_until_ready(map_states)
  end = time.time()
  print(f"Max product inference took {end-start:.3f} seconds for image {test_idx}")

  map_states_dict[test_idx] = map_states
  # A model's score is the total evidence at its MAP pool states.
  score = tree_util.tree_map(
      lambda evidence, states: jnp.sum(evidence[jnp.arange(states.shape[0]), states]),
      evidence_updates,
      map_states,
  )
  for model_idx in range(frcs.shape[0]):
    scores[test_idx, model_idx] = score[variables_all_models[model_idx]]

# 6. Compute metrics (accuracy)

In [None]:
# Each test image is assigned the label of its highest-scoring training model.
best_model_idx = scores.argmax(axis=1)
test_preds = train_labels[best_model_idx]
accuracy = np.mean(test_preds == test_labels)
print(f"accuracy = {accuracy}")

# 7. Visualize predictions - backtrace for the top model

In [None]:
# One panel per test image on a 4-column grid (5x4 for 20 test images).
num_test = len(test_set)
num_cols = 4
num_rows = max(1, int(np.ceil(num_test / num_cols)))
fig, ax = plt.subplots(num_rows, num_cols, figsize=(16, 4 * num_rows), squeeze=False)
for test_idx, axis in enumerate(ax.flat):
  axis.axis("off")
  if test_idx >= num_test:
    continue  # Unused grid cell.

  model_idx = best_model_idx[test_idx]
  map_state = map_states_dict[test_idx][variables_all_models[model_idx]]
  # Convert each feature's MAP pool state into a (row, col) offset from the
  # feature's location in the template.
  offsets = np.array(
      np.unravel_index(map_state, (2 * hps + 1, 2 * vps + 1))
  ).T - np.array([hps, vps])

  activations = frcs[model_idx][:, 1:] + offsets
  axis.plot(activations[:, 1], activations[:, 0], "r.")

  axis.imshow(test_set[test_idx], cmap="gray")
  axis.set_title(
      f"Ground Truth: {test_labels[test_idx]}, Pred: {test_preds[test_idx]}",
      fontsize=20,
  )

fig.tight_layout()