In [None]:
# Copyright 2022 Intrinsic Innovation LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

A GMRF is an 8-connected, grid-arranged Markov random field with hidden variables, originally proposed in the [Query Training](https://ojs.aaai.org/index.php/AAAI/article/view/17004) AAAI 2021 paper. 

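In this notebook we demonstrate inference and gradient-based learning of a GMRF using PGMax. The task is to recover, for every pixel of an MNIST digit, whether it lies on the digit's contour, outside it, or inside it. The model only observes a noisy version of the contour, in which some contour pixels are erased and spurious edges are added.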

In [None]:
%matplotlib inline
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax.example_libraries import optimizers
from tqdm.notebook import tqdm
import tensorflow_datasets as tfds

############
# Load PGMax
from pgmax import fgraph, fgroup, infer, vgroup

# Create the noisy MNIST dataset

In [None]:
import numpy as np
from scipy.ndimage.morphology import binary_dilation

def contour_mnist(X):
  """Converts grayscale MNIST digits into 3-class contour label images.

  A pixel is "in" the digit when its normalized intensity exceeds 0.5. Contour
  (border) pixels are the background pixels added by a one-pixel dilation of
  the digit with a cross-shaped structuring element. Every image is then padded
  with a one-pixel frame of "out" pixels.

  Args:
    X: Array of shape (N, H, W) with intensities in [0, 255] (e.g. uint8 MNIST
      images). Not modified.

  Returns:
    Integer array of shape (N, H + 2, W + 2) with labels 0 (border / contour),
    1 (out) and 2 (in).
  """
  digits = (X.astype(float) / 255. > 0.5).astype(int)
  # Cross-shaped structuring element that is non-zero only in the middle slice
  # along the batch axis, so each image is dilated independently.
  s = np.zeros((3, 3, 3))
  s[1, 1, :3] = 1
  s[1, :3, 1] = 1
  # The dilation contains the digit itself, so levels are:
  # 0 = background, 1 = one-pixel ring around the digit, 2 = digit.
  levels = digits + binary_dilation(digits, s)
  # Relabel 0 -> 1 (out), 1 -> 0 (border), 2 -> 2 (in).
  labels = np.array([1, 0, 2])[levels]
  # Pad each image with a one-pixel frame of "out" (1) pixels; the size follows
  # the input instead of assuming 28x28 images.
  return np.pad(labels, ((0, 0), (1, 1), (1, 1)), constant_values=1)
  
def add_noise_and_remove_contours(
    images,
    n_spurious_per_image,
    p_contour_deletion,
    n_add_contour_tries=1000,
    seed=0
):
  """Corrupts contour label images with spurious edges and random contour deletions.

  Args:
    images: Integer label images of shape (N, H, W) with values 0 (border /
      contour), 1 (out) and 2 (in), as produced by `contour_mnist`. Not modified.
    n_spurious_per_image: Maximum number of 3x3 spurious-edge patterns added to
      each image. A pattern is only placed where a 5x5 window is entirely "out",
      so fewer may be added if `n_add_contour_tries` runs out.
    p_contour_deletion: Probability that each contour pixel is erased.
    n_add_contour_tries: Number of random window positions tried per image when
      placing spurious edges.
    seed: Seed for NumPy's *global* random state. It is re-seeded here, which
      also makes later `np.random` calls by the caller deterministic.

  Returns:
    noisy_images: Binary observations of shape (N, H, W): 0 where a contour is
      observed, 1 elsewhere.
    p_contour: Length-3 array with the empirical probability of observing a
      contour (value 0) given that the true label is 0 (border), 1 (out) or
      2 (in).
  """
  np.random.seed(seed)

  # Straight 3-pixel line segments (vertical, horizontal, two diagonals). The
  # value 3 temporarily marks "spurious contour" until the final binarization.
  noise_patterns = np.array([
      [[1, 3, 1], [1, 3, 1], [1, 3, 1]],
      [[1, 1, 1], [3, 3, 3], [1, 1, 1]],
      [[3, 1, 1], [1, 3, 1], [1, 1, 3]],
      [[1, 1, 3], [1, 3, 1], [3, 1, 1]]
  ])
  assert noise_patterns.shape[1:] == (3, 3)

  N, H, W = images.shape
  noisy_images = images.copy()

  # Add spurious edges
  for n in range(N):
    n_spurious = 0
    for _ in range(n_add_contour_tries):
      # Checked before drawing so that n_spurious_per_image=0 adds nothing; for
      # positive targets this consumes exactly the same random draws as
      # checking after each placement.
      if n_spurious >= n_spurious_per_image:
        break
      r, c = np.random.randint(H - 4), np.random.randint(W - 4)
      if (noisy_images[n, r: r + 5, c: c + 5] == 1).all():  # all out
        idx = np.random.randint(len(noise_patterns))
        noisy_images[n, r + 1:r + 4, c + 1:c + 4] = noise_patterns[idx]
        n_spurious += 1

  # Remove contours: erase each contour pixel independently.
  n, r, c = (noisy_images == 0).nonzero()
  mask = np.random.binomial(1, p=p_contour_deletion, size=len(r)).astype(bool)
  noisy_images[n[mask], r[mask], c[mask]] = 1

  # Empirical P(observe contour | true label).
  # Border: fraction of contour pixels that survived deletion. When there are no
  # contour pixels, use the nominal rate instead of the NaN mean of an empty mask.
  p_border = 1 - mask.mean() if mask.size else 1 - p_contour_deletion
  # Out: fraction of spurious pixels (3) among pixels currently marked 1 or 3.
  # NOTE(review): erased contour pixels were just set to 1, so they are counted
  # in this denominator although their true label is border.
  n_spurious_pixels = (noisy_images == 3).sum()
  p_out = (n_spurious_pixels + 0.0) / ((noisy_images == 1).sum() + n_spurious_pixels)
  # In: contours are never observed inside the digit. A tiny value is used
  # rather than 0, presumably to keep the log-evidence finite.
  p_contour = np.array([p_border, p_out, 1e-10])

  # Binarize: spurious edges become observed contours (0); "out" and "in"
  # pixels become 1.
  noisy_images[noisy_images == 3] = 0
  noisy_images[noisy_images > 0] = 1
  return noisy_images, p_contour


def get_noisy_mnist(dataset="test", n_samples=100):
  """Loads MNIST and returns clean contour labels plus noisy binary observations.

  Args:
    dataset: TFDS split name to load (e.g. "train" or "test").
    n_samples: Number of leading images to keep; None keeps the whole split.

  Returns:
    target_images: Label images from `contour_mnist` (0 border, 1 out, 2 in).
    noisy_images: Binary observations (0 = contour observed, 1 = not).
    p_contour: Per-label probability of observing a contour, as estimated by
      `add_noise_and_remove_contours`.
  """
  data = tfds.as_numpy(tfds.load("mnist", split=dataset, batch_size=-1))
  # Slicing with n_samples=None keeps every image; drop the channel axis.
  X = data["image"][:n_samples, :, :, 0]
  target_images = contour_mnist(X)
  noisy_images, p_contour = add_noise_and_remove_contours(
      target_images,
      n_spurious_per_image=8,
      p_contour_deletion=0.2
  )
  # Report the actual number of images: n_samples may be None or exceed the
  # size of the split.
  print(f"Noisy {dataset} MNIST generated for {X.shape[0]} samples")
  return target_images, noisy_images, p_contour

# Visualize a trained GMRF

In [None]:
# Noisy test images (get_noisy_mnist's default n_samples=100); p_contour is the
# per-label contour-observation probability estimated on this data.
target_images, noisy_images, p_contour = get_noisy_mnist("test")

In [None]:
# Load data: pretrained GMRF parameters.
folder_name = "example_data/"
# SECURITY NOTE: allow_pickle=True lets np.load execute arbitrary code from a
# malicious file; only load parameter files from a trusted source.
with open(folder_name + "gmrf_log_potentials.npz", "rb") as potentials_file:
  loaded = np.load(potentials_file, allow_pickle=True)
  # .npz members may be read lazily from the open file, so copy out the arrays
  # used in this notebook before the file is closed.
  grmf_log_potentials = {
      key: loaded[key]
      for key in ("n_clones", "top_down", "left_right", "diagonal0", "diagonal1")
  }

# Number of clone states per label, in the label order of p_contour.
n_clones = grmf_log_potentials["n_clones"]
# Per-state contour-observation probability: each label's value is repeated for
# every one of its clones.
# NOTE: this overwrites p_contour, so the cell is not idempotent; re-run the
# data-generation cell before re-running it.
p_contour = jax.device_put(np.repeat(p_contour, n_clones))
# prototype_targets[k] is a 0/1 mask over all states selecting the clones of
# label k (identity rows, each column repeated n_clones times).
prototype_targets = jax.device_put(
    np.repeat(np.eye(len(n_clones), dtype=int), n_clones, axis=1)
)

In [None]:
# Image height and width: the last two axes of the stacked label images.
M, N = target_images.shape[-2:]
# Every pixel gets one hidden variable whose states enumerate all clones of all labels.
num_states = np.sum(n_clones)
variables = vgroup.NDVarArray(shape=(M, N), num_states=num_states)
fg = fgraph.FactorGraph(variables)

In [None]:
# Pairwise factors connect every pixel to its 8-connected neighbours. Each factor
# group covers one neighbour direction (di, dj), i.e. all in-bounds pairs
# (variables[ii, jj], variables[ii + di, jj + dj]).
def _neighbor_pairs(di, dj):
  """Lists every in-bounds variable pair at offset (di, dj), in row-major order."""
  rows = range(max(0, -di), M - max(0, di))
  cols = range(max(0, -dj), N - max(0, dj))
  return [[variables[ii, jj], variables[ii + di, jj + dj]] for ii in rows for jj in cols]

# Create top-down factors
top_down = fgroup.PairwiseFactorGroup(variables_for_factors=_neighbor_pairs(1, 0))
# Create left-right factors
left_right = fgroup.PairwiseFactorGroup(variables_for_factors=_neighbor_pairs(0, 1))
# Create diagonal factors (down-right, then up-right)
diagonal0 = fgroup.PairwiseFactorGroup(variables_for_factors=_neighbor_pairs(1, 1))
diagonal1 = fgroup.PairwiseFactorGroup(variables_for_factors=_neighbor_pairs(-1, 1))
del _neighbor_pairs

# Add factors to the factor graph
fg.add_factors([top_down, left_right, diagonal0, diagonal1])

In [None]:
# Belief-propagation inferer over the factor graph. The cells below read
# marginals (infer.get_marginals), hence a nonzero temperature; presumably
# temperature=1.0 selects sum-product BP in PGMax — confirm in the PGMax docs.
bp = infer.BP(fg.bp_state, temperature=1.0)

In [None]:
# Pretrained log potentials, one entry per factor group.
log_potentials = {
    top_down: grmf_log_potentials["top_down"],
    left_right: grmf_log_potentials["left_right"],
    diagonal0: grmf_log_potentials["diagonal0"],
    diagonal1: grmf_log_potentials["diagonal1"],
}

n_plots = 5
indices = np.random.permutation(noisy_images.shape[0])[:n_plots]
# squeeze=False keeps `ax` 2-D even when n_plots == 1.
fig, ax = plt.subplots(n_plots, 3, figsize=(12, 4 * n_plots), squeeze=False)
for plot_idx, idx in tqdm(enumerate(indices), total=n_plots):
  noisy_image = noisy_images[idx]
  target_image = target_images[idx]
  # Per-state log-likelihood of the observation: p_contour if a contour (0) was
  # observed at the pixel, 1 - p_contour otherwise.
  evidence = jnp.log(jnp.where(noisy_image[..., None] == 0, p_contour, 1 - p_contour))
  marginals = infer.get_marginals(
      bp.get_beliefs(
          bp.run_bp(
              bp.init(
                  evidence_updates={variables: evidence},
                  log_potentials_updates=log_potentials,
              ),
              num_iters=15,
              damping=0.0,
          )
      )
  )[variables]

  # Total marginal probability of each label, summed over that label's clone
  # states (prototype_targets[k] masks the states of label k); this works for
  # any number of clones per label.
  label_probs = marginals @ prototype_targets.T
  pred_image = np.argmax(np.asarray(label_probs), axis=-1)

  ax[plot_idx, 0].imshow(noisy_image)
  ax[plot_idx, 0].axis("off")
  ax[plot_idx, 1].imshow(target_image)
  ax[plot_idx, 1].axis("off")
  ax[plot_idx, 2].imshow(pred_image)
  ax[plot_idx, 2].axis("off")
  if plot_idx == 0:
    ax[plot_idx, 0].set_title("Input noisy image", fontsize=30)
    ax[plot_idx, 1].set_title("Ground truth", fontsize=30)
    ax[plot_idx, 2].set_title("GMRF prediction", fontsize=30)

fig.tight_layout()

# Train the model from scratch

The following training loop requires a GPU with at least 11 GB of memory.

In [None]:
@jax.jit
def loss(noisy_image, target_image, log_potentials):
  """Cross-entropy between BP's per-pixel label distribution and the ground truth.

  For every pixel, the marginal probabilities of all clone states belonging to
  the true label are summed; the loss is minus the mean log of that sum.

  Args:
    noisy_image: Binary observation image (0 = contour observed).
    target_image: Ground-truth label image, used to index `prototype_targets`.
    log_potentials: Dict mapping each factor group to its log potentials.

  Returns:
    Scalar loss for the image.
  """
  observed_contour = noisy_image[..., None] == 0
  evidence = jnp.log(jnp.where(observed_contour, p_contour, 1 - p_contour))
  bp_arrays = bp.init(
      evidence_updates={variables: evidence},
      log_potentials_updates=log_potentials,
  )
  bp_arrays = bp.run_bp(bp_arrays, num_iters=15, damping=0.0)
  marginals = infer.get_marginals(bp.get_beliefs(bp_arrays))[variables]
  # 0/1 mask over states: 1 for every clone of the pixel's true label.
  label_mask = prototype_targets[target_image]
  prob_true_label = jnp.sum(label_mask * marginals, axis=-1)
  return -jnp.mean(jnp.log(prob_true_label))


@jax.jit
def batch_loss(noisy_images, target_images, log_potentials):
  """Mean of `loss` over a batch of images sharing the same log potentials."""
  # Map over the leading (batch) axis of the images only; the log potentials
  # are broadcast to every image.
  per_image_loss = jax.vmap(loss, in_axes=(0, 0, None))(
      noisy_images, target_images, log_potentials
  )
  return jnp.mean(per_image_loss)

In [None]:
# Adam with step size 2e-3 over the dict of log potentials.
init_fun, opt_update, get_params = optimizers.adam(2e-3)
# Loss and its gradient w.r.t. the log potentials (argument index 2).
value_and_grad = jax.jit(jax.value_and_grad(batch_loss, argnums=2))

@jax.jit
def update(step, batch_noisy_images, batch_target_images, opt_state):
  """Runs one optimizer step on the log potentials.

  Returns:
    (value, opt_state): the batch loss at the pre-update parameters and the
    updated optimizer state.
  """
  current_params = get_params(opt_state)
  value, grad = value_and_grad(batch_noisy_images, batch_target_images, current_params)
  return value, opt_update(step, grad, opt_state)

In [None]:
# Get the training data (the whole train split: n_samples=None)
target_images_train, noisy_images_train, _ = get_noisy_mnist(dataset="train", n_samples=None)

# Initialize the optimizer with standard-normal log potentials, one
# (num_states, num_states) matrix per factor group, drawn from NumPy's global
# random state.
opt_state = init_fun(
    {
        top_down: np.random.randn(num_states, num_states),
        left_right: np.random.randn(num_states, num_states),
        diagonal0: np.random.randn(num_states, num_states),
        diagonal1: np.random.randn(num_states, num_states),
    }
)

# Training loop: n_epochs passes over freshly shuffled data in minibatches; a
# trailing partial batch is dropped.
batch_size = 10
n_epochs = 10
n_batches = noisy_images_train.shape[0] // batch_size
with tqdm(total=n_epochs * n_batches) as pbar:
  for epoch in range(n_epochs):
    indices = np.random.permutation(noisy_images_train.shape[0])
    for idx in range(n_batches):
      batch_indices = indices[idx * batch_size : (idx + 1) * batch_size]
      batch_noisy_images, batch_target_images = (
          noisy_images_train[batch_indices],
          target_images_train[batch_indices],
      )
      step = epoch * n_batches + idx
      value, opt_state = update(
          step, batch_noisy_images, batch_target_images, opt_state
      )
      # Plain Python float so tqdm applies its numeric formatting to the postfix.
      loss_value = float(value)
      # Fail fast on divergence instead of training on NaN/inf parameters for
      # the remaining steps.
      if not np.isfinite(loss_value):
        raise FloatingPointError(f"Non-finite training loss {loss_value} at step {step}")
      pbar.update()
      pbar.set_postfix(loss=loss_value)


# Get the trained parameters
params = get_params(opt_state)
for factor in [top_down, left_right, diagonal0, diagonal1]:
  assert params[factor].shape == (num_states, num_states)