In [None]:
# Copyright 2022 Intrinsic Innovation LLC.
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

A GridMRF is an 8-connected, grid-arranged Markov random field with hidden variables, originally proposed in the [Query Training](https://ojs.aaai.org/index.php/AAAI/article/view/17004) AAAI 2021 paper.

In this notebook we demonstrate inference and gradient-based learning of a GridMRF using PGMax.

In [None]:
# # Uncomment this block if running on colab.research.google.com
# !pip install git+https://github.com/deepmind/PGMax.git
# !wget https://raw.githubusercontent.com/deepmind/PGMax/main/examples/example_data/gmrf_log_potentials.npz
# !mkdir example_data
# !mv gmrf_log_potentials.npz  example_data/

In [None]:
%matplotlib inline
import functools
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from jax.example_libraries import optimizers
from tqdm.notebook import tqdm
import tensorflow_datasets as tfds

############
# Load PGMax
from pgmax import fgraph, fgroup, infer, vgroup

# Create the noisy MNIST dataset

In [None]:
import numpy as np
from scipy.ndimage.morphology import binary_dilation

def contour_mnist(X):
  """Convert grayscale digit images into padded 3-label contour maps.

  Each image is binarized at half of the maximum intensity (255) and every
  pixel receives one of three labels:
    0: border (contour) -- outside the digit but 4-adjacent to a digit pixel,
    1: out -- background,
    2: in -- part of the digit.
  A one-pixel frame of "out" labels is then added around every image.

  Args:
    X: Array of shape (n_images, height, width) with intensities in [0, 255].

  Returns:
    Integer array of shape (n_images, height + 2, width + 2) holding the
    labels. For 28x28 MNIST digits this is (n_images, 30, 30).
  """
  digit = X.astype(float) / 255. > 0.5

  # Contours are obtained by dilating the digits. The structuring element is a
  # cross confined to the middle slice of the first axis, so the dilation
  # never leaks from one image of the batch into another.
  structure = np.zeros((3, 3, 3))
  structure[1, 1, :] = 1
  structure[1, :, 1] = 1
  dilated = binary_dilation(digit, structure)

  # The cross contains its own center, hence `dilated` is a superset of
  # `digit`: pixels that are dilated but not digit form the border.
  labels = np.where(digit, 2, np.where(dilated, 0, 1)).astype(int)

  # Surround every image with one row/column of "out" pixels, whatever the
  # input resolution is.
  return np.pad(labels, ((0, 0), (1, 1), (1, 1)), constant_values=1)

def add_noise_and_remove_contours(
    images,
    n_spurious_per_image,
    p_contour_deletion,
    n_add_contour_tries=1000,
    seed=0
):
  """Add spurious edges and remove contour pixels at random.

  The input uses the labels 0 (border/contour), 1 (out) and 2 (in). Spurious
  3-pixel line segments are stamped into regions that are entirely "out", then
  each contour pixel is independently deleted with probability
  `p_contour_deletion`.

  Note: this function reseeds NumPy's global random number generator.

  Args:
    images: Integer array of shape (n_images, height, width) of labels.
    n_spurious_per_image: Maximum number of spurious segments added per image.
    p_contour_deletion: Probability of deleting each contour pixel.
    n_add_contour_tries: Maximum number of random placements tried per image.
    seed: Seed passed to `np.random.seed`.

  Returns:
    noisy_images: Binary array with the shape of `images`; 0 where a contour
      pixel (kept original contour or spurious edge) is observed, 1 elsewhere.
    p_contour: Array of 3 values, the estimated probability of observing a
      contour pixel given that the label is 0 (border), 1 (out) or 2 (in).
  """
  np.random.seed(seed)

  # 3x3 stamps; the value 3 temporarily marks spurious contour pixels so they
  # can be told apart from the original contours (label 0).
  noise_patterns = np.array([
      [[1, 3, 1], [1, 3, 1], [1, 3, 1]],
      [[1, 1, 1], [3, 3, 3], [1, 1, 1]],
      [[3, 1, 1], [1, 3, 1], [1, 1, 3]],
      [[1, 1, 3], [1, 3, 1], [3, 1, 1]]
  ])
  assert noise_patterns.shape[1:] == (3, 3)

  N, H, W = images.shape
  noisy_images = images.copy()

  # Add spurious edges
  for n in range(N):
    n_spurious = 0
    for _ in range(n_add_contour_tries):
      # Check the quota before drawing a location, so that a quota of zero
      # never inserts anything.
      if n_spurious >= n_spurious_per_image:
        break
      r, c = np.random.randint(H - 4), np.random.randint(W - 4)
      # Only stamp the pattern when the surrounding 5x5 window is all "out",
      # which keeps a one-pixel margin around the 3x3 stamp.
      if (noisy_images[n, r: r + 5, c: c + 5] == 1).all():
        idx = np.random.randint(len(noise_patterns))
        noisy_images[n, r + 1:r + 4, c + 1:c + 4] = noise_patterns[idx]
        n_spurious += 1

  # Remove contours
  n, r, c = (noisy_images == 0).nonzero()
  mask = np.random.binomial(1, p=p_contour_deletion, size=len(r)).astype(bool)
  n, r, c = n[mask], r[mask], c[mask]
  noisy_images[n, r, c] = 1

  # Probability of observing 1 (contour) given label is 0 (border), 1 (out) 2(in)
  n_spurious_pixels = (noisy_images == 3).sum()
  p_contour = np.array([
      1 - mask.mean(),
      (n_spurious_pixels + 0.0) / ((noisy_images == 1).sum() + n_spurious_pixels),
      1e-10
  ])

  # Collapse the labels to a binary observation: 0 = contour, 1 = no contour.
  noisy_images[noisy_images == 3] = 0
  noisy_images[noisy_images > 0] = 1
  return noisy_images, p_contour


def get_noisy_mnist(dataset="test", n_samples=100):
  """Load an MNIST split and build its clean and noisy contour images.

  Args:
    dataset: Name of the MNIST split passed to `tfds.load`.
    n_samples: Number of leading images to keep; None keeps the whole split.

  Returns:
    The target contour images, their noisy versions, and the contour
    observation probabilities returned by `add_noise_and_remove_contours`.
  """
  mnist = tfds.as_numpy(tfds.load("mnist", split=dataset, batch_size=-1))

  # A `None` stop keeps every image, so one slice covers both cases. The
  # trailing index drops the channel axis.
  digits = mnist["image"][:n_samples, :, :, 0]

  target_images = contour_mnist(digits)
  noisy_images, p_contour = add_noise_and_remove_contours(
      target_images, n_spurious_per_image=8, p_contour_deletion=0.2
  )
  print(f"Noisy {dataset} MNIST generated for {n_samples} samples")
  return target_images, noisy_images, p_contour

# Visualize a trained GridMRF

In [None]:
# Build the clean/noisy contour images of the first 100 MNIST test digits,
# together with the contour observation probabilities.
target_images_test, noisy_images_test, p_contour = get_noisy_mnist(
    dataset="test",
    n_samples=100,
)

In [None]:
# Load a pretrained large model
folder_name = "example_data/"

# Read every array eagerly and close the archive right away: an NpzFile loads
# lazily and would otherwise keep the underlying file handle open.
# NOTE(review): allow_pickle=True can execute arbitrary code when unpickling
# object arrays -- only load this file from a trusted source.
with np.load(folder_name + "gmrf_log_potentials.npz", allow_pickle=True) as npz_file:
  grmf_log_potentials = {key: npz_file[key] for key in npz_file.files}

# The number of clones defines the number of states of the categorical variables
n_clones = grmf_log_potentials["n_clones"]
num_states = int(np.sum(n_clones))

In [None]:
# Build the factor graph: one categorical variable per pixel of the image grid
M, N = target_images_test.shape[-2:]
variables = vgroup.NDVarArray(num_states=num_states, shape=(M, N))
fg = fgraph.FactorGraph(variables)


def _pairwise_group(rows, cols, d_row, d_col):
  """Connect each variable at (row, col) to the one at (row + d_row, col + d_col)."""
  return fgroup.PairwiseFactorGroup(
      variables_for_factors=[
          [variables[row, col], variables[row + d_row, col + d_col]]
          for row in rows
          for col in cols
      ],
  )


# Vertical neighbors: (ii, jj) -- (ii + 1, jj)
top_down = _pairwise_group(range(M - 1), range(N), 1, 0)

# Horizontal neighbors: (ii, jj) -- (ii, jj + 1)
left_right = _pairwise_group(range(M), range(N - 1), 0, 1)

# The two diagonal directions: (ii, jj) -- (ii + 1, jj + 1) and
# (ii, jj) -- (ii - 1, jj + 1)
diagonal0 = _pairwise_group(range(M - 1), range(N - 1), 1, 1)
diagonal1 = _pairwise_group(range(1, M), range(N - 1), -1, 1)

# Register the four factor groups, which together give the 8-connectivity
fg.add_factors([top_down, left_right, diagonal0, diagonal1])

In [None]:
# Belief-propagation inferer built from the factor graph's BP state
bp = infer.build_inferer(fg.bp_state, backend="bp")

# Expand the per-label contour probabilities to one entry per hidden state
# (each label owns `n_clones` states); the unaries are derived from this.
p_contour_per_state = np.repeat(p_contour, n_clones)
p_contour_augmented = jax.device_put(p_contour_per_state)

In [None]:
def run_inference_and_plot(noisy_images, target_images, log_potentials, n_plots=5):
  """Run inference on up to `n_plots` randomly selected images and plot the predictions.

  For every selected image, one figure row shows the noisy input, the ground
  truth and the labels decoded from the belief-propagation marginals.

  Args:
    noisy_images: Array of noisy binary images, indexed along the first axis.
    target_images: Array of ground-truth label images, aligned with
      `noisy_images`.
    log_potentials: Log potentials passed to `bp.init` as
      `log_potentials_updates`.
    n_plots: Number of figure rows; fewer images are drawn when the dataset
      holds less than `n_plots` images.
  """
  # squeeze=False keeps `ax` two-dimensional even when n_plots == 1, so the
  # `ax[row, col]` indexing below is always valid.
  fig, ax = plt.subplots(n_plots, 3, figsize=(7, 2 * n_plots), squeeze=False)

  indices = np.random.permutation(noisy_images.shape[0])[:n_plots]
  for plot_idx, idx in tqdm(enumerate(indices), total=len(indices)):
    noisy_image = noisy_images[idx]
    target_image = target_images[idx]

    # Update the evidence: per-state log-likelihood of the observed pixel,
    # which is a contour pixel when its value is 0.
    evidence = jnp.log(
        jnp.where(
            noisy_image[..., None] == 0,
            p_contour_augmented,
            1 - p_contour_augmented
        )
    )

    # Run sum-product to estimate the marginals
    marginals = infer.get_marginals(
        bp.get_beliefs(
            bp.run(
                bp.init(
                    evidence_updates={variables: evidence},
                    log_potentials_updates=log_potentials,
                ),
                num_iters=15,
                damping=0.0,
                temperature=1.0
            )
        )
    )[variables]

    # Look at the decoded image: pool all states but the last two into the
    # first class and keep the last two states as their own classes.
    # NOTE(review): presumably the last two labels own a single clone each --
    # verify against `n_clones`.
    pred_image = np.argmax(
        np.stack(
            [
                np.sum(marginals[..., :-2], axis=-1),
                marginals[..., -2],
                marginals[..., -1],
            ],
            axis=-1,
        ),
        axis=-1,
    )
    ax[plot_idx, 0].imshow(noisy_image)
    ax[plot_idx, 0].axis("off")
    ax[plot_idx, 1].imshow(target_image)
    ax[plot_idx, 1].axis("off")
    ax[plot_idx, 2].imshow(pred_image)
    ax[plot_idx, 2].axis("off")
    if plot_idx == 0:
      ax[plot_idx, 0].set_title("Input noisy image", fontsize=18)
      ax[plot_idx, 1].set_title("Ground truth", fontsize=18)
      ax[plot_idx, 2].set_title("GridMRF prediction", fontsize=18)

  fig.tight_layout()

In [None]:
# Map every factor group to its pretrained log potentials
log_potentials_pretrained = {
    factor_group: grmf_log_potentials[name]
    for factor_group, name in (
        (top_down, "top_down"),
        (left_right, "left_right"),
        (diagonal0, "diagonal0"),
        (diagonal1, "diagonal1"),
    )
}

# Decode a few random noisy test images with the pretrained model
run_inference_and_plot(noisy_images_test, target_images_test, log_potentials_pretrained)

# Finetune a perturbed model

We now illustrate how we can train a small GridMRF.

To keep this example fast, we do not train the parameters from scratch. Instead, we perturb the pretrained parameters and finetune the model for one epoch on a small set of 500 training samples.

In [None]:
# Targets used by the loss: row k is the indicator of label k, widened so that
# each label's entry is repeated once per clone of that label.
label_one_hot = np.eye(3, dtype=int)
prototype_targets = jax.device_put(np.repeat(label_one_hot, n_clones, axis=1))

In [None]:
@jax.jit
def loss(noisy_image, target_image, log_potentials):
  """Cross-entropy between the BP marginals and the ground-truth labels of one image.

  Args:
    noisy_image: Noisy binary image; a pixel equal to 0 is an observed contour.
    target_image: Ground-truth label image, used to index `prototype_targets`.
    log_potentials: Log potentials passed to `bp.init` as
      `log_potentials_updates`.

  Returns:
    The negative mean, over pixels, of the log of the marginal mass assigned
    to the states of the ground-truth label.
  """
  # Per-state log-likelihood of each observed pixel
  contour_observed = noisy_image[..., None] == 0
  state_likelihood = jnp.where(
      contour_observed, p_contour_augmented, 1 - p_contour_augmented
  )
  evidence = jnp.log(state_likelihood)

  # Run sum-product to estimate the marginals
  bp_arrays = bp.init(
      evidence_updates={variables: evidence},
      log_potentials_updates=log_potentials,
  )
  bp_arrays = bp.run(bp_arrays, num_iters=15, damping=0.0, temperature=1.0)
  marginals = infer.get_marginals(bp.get_beliefs(bp_arrays))

  # Marginal mass on the states belonging to the true label of every pixel
  state_mask = prototype_targets[target_image]
  p_true_label = jnp.sum(state_mask * marginals[variables], axis=-1)
  return -jnp.mean(jnp.log(p_true_label))


@jax.jit
def batch_loss(noisy_images, target_images, log_potentials):
  """Mean of `loss` over a batch of images sharing the same log potentials."""
  # Map over the leading axis of both image arrays; the potentials are not batched.
  per_image_loss = jax.vmap(loss, in_axes=(0, 0, None), out_axes=0)
  return jnp.mean(per_image_loss(noisy_images, target_images, log_potentials))

In [None]:
@functools.partial(jax.jit, static_argnames="opt")
def update(log_potentials, batch_noisy_images, batch_target_images, opt, opt_state):
  """Apply one optimizer step to the log potentials.

  Returns:
    The batch loss evaluated before the step, the updated log potentials and
    the updated optimizer state.
  """
  # Differentiate the batch loss with respect to its third argument, the potentials
  loss_and_grad = jax.value_and_grad(batch_loss, argnums=2)
  loss_value, grads = loss_and_grad(
      batch_noisy_images, batch_target_images, log_potentials
  )
  param_updates, next_opt_state = opt.update(grads, opt_state, log_potentials)
  next_log_potentials = optax.apply_updates(log_potentials, param_updates)
  return loss_value, next_log_potentials, next_opt_state

In [None]:
# Load the training data
target_images_train, noisy_images_train, _ = get_noisy_mnist(dataset="train", n_samples=500)

# Perturb the pretrained potentials with Gaussian noise of standard deviation `temp`
temp = 0.2
log_potentials_finetuned = {
    factor_group: grmf_log_potentials[name] + temp * np.random.randn(num_states, num_states)
    for factor_group, name in (
        (top_down, "top_down"),
        (left_right, "left_right"),
        (diagonal0, "diagonal0"),
        (diagonal1, "diagonal1"),
    )
}


# Create the optimizer
opt = optax.adam(learning_rate=3e-3)
opt_state = opt.init(log_potentials_finetuned)

# Training loop
batch_size = 10
n_epochs = 1
n_batches = noisy_images_train.shape[0] // batch_size

losses = []
with tqdm(total=n_epochs * n_batches) as pbar:
  for epoch in range(n_epochs):
    # Visit the training images in a fresh random order at every epoch
    indices = np.random.permutation(noisy_images_train.shape[0])
    for idx in range(n_batches):
      batch_indices = indices[idx * batch_size : (idx + 1) * batch_size]
      batch_noisy_images, batch_target_images = (
          noisy_images_train[batch_indices],
          target_images_train[batch_indices],
      )
      # Do not name this value `loss`: that would rebind the module-level
      # `loss` function, which `batch_loss` looks up whenever it is retraced.
      batch_loss_value, log_potentials_finetuned, opt_state = update(
          log_potentials_finetuned,
          batch_noisy_images,
          batch_target_images,
          opt,
          opt_state
      )
      # Store plain Python floats for the progress bar and the plot
      batch_loss_value = float(batch_loss_value)
      pbar.update()
      pbar.set_postfix(loss=batch_loss_value)
      losses.append(batch_loss_value)

In [None]:
# Plot the cross-entropy loss recorded at every training iteration
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(losses)
ax.set_xlabel("Training iteration", fontsize=16)
ax.set_ylabel("Cross-entropy loss", fontsize=16)

In [None]:
# Decode a few random noisy test images with the finetuned potentials
run_inference_and_plot(
    noisy_images=noisy_images_test,
    target_images=target_images_test,
    log_potentials=log_potentials_finetuned,
)