In [None]:
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
# # Uncomment this block if running on colab.research.google.com
# !pip install git+https://github.com/deepmind/PGMax.git

This notebook uses the Smooth Dual LP-MAP (SDLP) solver to (1) run inference on an Ising model and compare its results with BP, and (2) extract sparse feature activations from visually complex binary scenes.

In [None]:
from collections import defaultdict
import time
from tqdm import tqdm

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from scipy.special import logit

############
# Load PGMax
from pgmax import factor, fgraph, fgroup, infer, vgroup

# 1. Inference in an Ising model

We reuse the Ising model from the example notebook.

### 1.1 Create the factor graph

In [None]:
# Side length of the square lattice of binary variables
grid_size = 50

variables = vgroup.NDVarArray(num_states=2, shape=(grid_size, grid_size))
fg = fgraph.FactorGraph(variable_groups=variables)

# Pair every site with the site one row below and the site one column to the
# right; the modulo wraps the last row/column around to the first one.
variables_for_factors = []
for row in range(grid_size):
  row_below = (row + 1) % grid_size
  for col in range(grid_size):
    col_right = (col + 1) % grid_size
    variables_for_factors += [
        [variables[row, col], variables[row_below, col]],
        [variables[row, col], variables[row, col_right]],
    ]

# A single 2x2 log-potential matrix is shared by all the pairwise factors
coupling = 0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]])
factor_group = fgroup.PairwiseFactorGroup(
    variables_for_factors=variables_for_factors,
    log_potential_matrix=coupling,
)
fg.add_factors(factor_group)

### 1.2 Run inference with BP

First, we run inference with Belief Propagation.

In [None]:
rng = jax.random.PRNGKey(0)
evidence_updates={variables: jax.random.gumbel(rng, shape=(grid_size, grid_size, 2))}
num_iters = 2_000

# Initialize the BP solver
bp = infer.build_inferer(fg.bp_state, backend="bp")

# Run inference
bp_arrays = bp.init(evidence_updates=evidence_updates)
# Warm-up call, so that the timed call below does not pay for compilation.
# JAX dispatches work asynchronously: wait for the warm-up to finish so that
# none of it leaks into the timed region.
jax.block_until_ready(bp.run(bp_arrays, num_iters=num_iters, temperature=0))
start = time.time()
# Block on the result before reading the clock; otherwise the timer may only
# measure how long it takes to enqueue the computation.
bp_arrays = jax.block_until_ready(
    bp.run(bp_arrays, num_iters=num_iters, temperature=0)
)
bp_time = time.time() - start
print(f"BP time (after compiling): {bp_time:.4}s")

# Get the BP decoding
beliefs = bp.get_beliefs(bp_arrays)
bp_decoding = infer.decode_map_states(beliefs)

### 1.3 Run inference with SDLP

Second, we run inference with the Smooth Dual LP solver using the same evidence.

Note that both the BP and SDLP solvers support the same interface, which is initialized via:
```
infer.build_inferer(fg.bp_state, backend=backend)
```

For the same number of iterations, the timings of BP and SDLP are comparable. However, in practice, BP needs fewer iterations than SDLP.

In [None]:
# Initialize the SDLP solver
sdlp = infer.build_inferer(fg.bp_state, backend="sdlp")

# Run inference
sdlp_arrays = sdlp.init(evidence_updates=evidence_updates)
# Warm up the very function that is timed below (in case `run` and
# `run_with_objvals` are compiled separately), and wait for it to finish:
# JAX dispatches work asynchronously.
jax.block_until_ready(
    sdlp.run_with_objvals(sdlp_arrays, logsumexp_temp=1e-3, num_iters=num_iters)
)

start = time.time()
# Block on the outputs before reading the clock so that the whole computation
# is included in the measurement.
sdlp_arrays, smooth_dual_objvals = jax.block_until_ready(
    sdlp.run_with_objvals(
        sdlp_arrays,
        logsumexp_temp=1e-3,
        num_iters=num_iters,
    )
)
sdlp_time = time.time() - start
print(f"SDLP time (after compiling): {sdlp_time:.4}s")

# Get the SDLP decoding from the SDLP inferer itself (not from the BP one);
# the second returned value is not needed here.
sdlp_unaries_decoding, _ = sdlp.decode_primal_unaries(sdlp_arrays)

The SDLP solver gives access to (1) an upper bound of the optimal objective value of the relaxed LP-MAP problem and (2) a lower bound of the optimal objective of the MAP problem. If both bounds are equal, then the LP relaxation is tight and we are at the MAP solution.

In [None]:
# Bound obtained from the dual solution
primal_upper_bound = sdlp.get_primal_upper_bound(sdlp_arrays)
print(f"Upper bound for LP-MAP {primal_upper_bound:.3f}")

# Bound obtained by evaluating the decoded state
primal_lower_bound = sdlp.get_map_lower_bound(sdlp_arrays, sdlp_unaries_decoding)
print(f"Lower bound for MAP {primal_lower_bound:.3f}")

# Distance between the two bounds, as a percentage of the upper bound
bound_gap = abs(primal_upper_bound - primal_lower_bound)
relative_gap_pct = 100 * bound_gap / abs(primal_upper_bound)
print(f"Gap: {relative_gap_pct:.3f}%")

We plot the Smooth Dual objective value at each gradient step to visualize its convergence.

We additionally vary the temperature and observe that, at a higher temperature, the objective value converges faster but ends up farther away from the optimal MAP objective value.



In [None]:
# Solve the dual again, from a fresh initialization, for a higher
# log-sum-exp temperature
fresh_arrays_higher = sdlp.init(evidence_updates=evidence_updates)
_, smooth_dual_objvals_higher = sdlp.run_with_objvals(
    fresh_arrays_higher, logsumexp_temp=1e-2, num_iters=num_iters
)
# Same from a fresh initialization for a log-sum-exp temperature of 0
# (subgradient descent)
fresh_arrays_T0 = sdlp.init(evidence_updates=evidence_updates)
_, smooth_dual_objvals_T0 = sdlp.run_with_objvals(
    fresh_arrays_T0, logsumexp_temp=0.0, num_iters=num_iters
)

# Plot: the reference curve first, then the two bounds as single markers at
# the last iteration, then the two thin comparison curves
plt.figure(figsize=(12, 6))
plt.plot(smooth_dual_objvals, c='b', label="SDLP objval for T=1e-3")

bound_markers = (
    (primal_upper_bound, 'g', 200, "*", "MAP upper bound from the solution at T=1e-3"),
    (primal_lower_bound, 'r', 100, "x", "MAP lower bound from the solution at T=1e-3"),
)
for bound_value, bound_color, bound_size, bound_marker, bound_label in bound_markers:
  plt.scatter(
      num_iters,
      bound_value,
      c=bound_color,
      s=bound_size,
      marker=bound_marker,
      label=bound_label,
  )

thin_curves = (
    (smooth_dual_objvals_higher, 'g', "SDLP objval for T=1e-2"),
    (smooth_dual_objvals_T0, 'r', "SDLP objval for T=0 (with SGD)"),
)
for curve_values, curve_color, curve_label in thin_curves:
  plt.plot(curve_values, label=curve_label, c=curve_color, linewidth=0.5)

plt.legend(fontsize=16)
plt.xlabel("Gradient steps", fontsize=16)
plt.ylabel("Objective value", fontsize=16)
_ = plt.title("Smooth Dual LP MAP for an Ising model", fontsize=18)

### 1.4 Compare the BP and the SDLP solutions

We plot the BP and SDLP decodings.

In [None]:
# Two panels side by side: BP decoding on the left, SDLP decoding on the right
fig, ax = plt.subplots(1, 2, figsize=(10, 10))
bp_panel, sdlp_panel = ax

bp_panel.imshow(bp_decoding[variables])
bp_panel.axis("off")
bp_panel.set_title("BP decoding", fontsize=18)

sdlp_panel.imshow(sdlp_unaries_decoding[variables])
sdlp_panel.axis("off")
sdlp_panel.set_title("Dual LP decoding", fontsize=18)

We compare the energies of the BP and SDLP solutions (lower is better).

In [None]:
# Energy of the BP decoding: first entry of what compute_energy returns
bp_energy = infer.compute_energy(fg.bp_state, bp_arrays, bp_decoding)[0]
# The SDLP energy is taken as the negated MAP lower bound computed earlier
sdlp_energy = - primal_lower_bound

# Lower is better
print(f"Energy of the BP decoding: {bp_energy:.3f}")
print(f"Energy of the Smooth Dual LP decoding: {sdlp_energy:.3f}")

### 1.5 (Optional) Compare with the primal LP-MAP solver

Last, we use the LP solver `cvxpy` to solve the LP-MAP problem to optimality.
The LP solver is much slower than the other methods.

In [None]:
# Imported in this cell because this section of the notebook is optional
from pgmax.utils import primal_lp

# Time the primal LP-MAP solve
cvxpy_start = time.time()
cvxpy_lp_vgroups_solution, cvxpy_lp_objval = primal_lp.primal_lp_solver(
    fg, evidence_updates
)
cvxpy_time = time.time() - cvxpy_start
print(f"Primal LP-MAP time: {cvxpy_time:.4}s")

# The optimal objective value is lower than the upper bound derived from the dual solution
assert cvxpy_lp_objval <= primal_upper_bound

# Decode the LP solution and evaluate the energy of the decoded state
decoding_from_cvxpy = infer.decode_map_states(cvxpy_lp_vgroups_solution)
energy_from_cvxpy = infer.compute_energy(
    fg.bp_state, bp_arrays, decoding_from_cvxpy
)[0]
print(f"Energy from the optimal LP decoding: {energy_from_cvxpy:.3f}")

# 2. Sparsification of a binary scene

We use the SDLP solver to sparsify the scenes of the PMP Binary Deconvolution notebook example.

We assume that the binary features W are known and we try to recover the binary indicator S from the binary scenes X.

### 2.1 Load the data

In [None]:
# # Uncomment this block if running on colab.research.google.com
# !wget https://raw.githubusercontent.com/deepmind/PGMax/main/examples/example_data/conv_problem.npz
# !mkdir example_data
# !mv conv_problem.npz  example_data/

In [None]:
def plot_images(images, display=True, nr=None):
  """Tiles a stack of images into a single RGB mosaic and optionally shows it.

  Args:
    images: Array-like of shape (n_images, H, W). Any boolean, integer or
      floating dtype is accepted; the values are rescaled to [0, 1] on a float
      copy, so the input is left untouched.
    display: If True, the mosaic is shown in a new 10x10 matplotlib figure.
    nr: Number of tile rows. If None, a square grid with
      ceil(sqrt(n_images)) rows and columns is used and the trailing tiles may
      stay empty. Otherwise n_images must be a multiple of nr.

  Returns:
    A float array of shape ((H + 1) * nr + 1, (W + 1) * nc + 1, 3) holding the
    mosaic, where nc is the number of tile columns.
  """
  # Work on a float NumPy copy: the rescaling below is not defined for boolean
  # arrays (subtraction) nor for integer arrays (in-place true division).
  images = np.asarray(images, dtype=float)
  n_images, H, W = images.shape
  images = images - images.min()
  # The small constant avoids a division by zero for constant images
  images = images / (images.max() + 1e-10)

  if nr is None:
    nr = nc = np.ceil(np.sqrt(n_images)).astype(int)
  else:
    nc = n_images // nr
    assert n_images == nr * nc

  # Black canvas with one spare pixel between tiles and around the border
  big_image = np.zeros(((H + 1) * nr + 1, (W + 1) * nc + 1, 3))
  # Horizontal separator lines between the rows of tiles
  big_image[:: H + 1] = [0.5, 0, 0.5]

  im = 0
  for r in range(nr):
    for c in range(nc):
      if im < n_images:
        # The gray levels are broadcast over the three color channels
        big_image[
            (H + 1) * r + 1 : (H + 1) * r + 1 + H,
            (W + 1) * c + 1 : (W + 1) * c + 1 + W,
            :,
        ] = images[im, :, :, None]
        im += 1

  if display:
    plt.figure(figsize=(10, 10))
    plt.imshow(big_image, interpolation="none")
    plt.axis("off")
  return big_image

In [None]:
# Load data
folder_name = "example_data/"
# Let np.load open the archive itself and close it when the block exits, so no
# file handle is leaked. The arrays are read while the archive is still open.
# NOTE(review): allow_pickle=True can execute arbitrary code when loading an
# untrusted file. It is kept because it cannot be told from here whether the
# archive stores object arrays -- drop it if it does not.
with np.load(folder_name + "conv_problem.npz", allow_pickle=True) as data:
  W_gt = data["W"]
  X_gt = data["X"]
# Only the first 20 scenes are used
X_gt = X_gt[:20]

# Show channel 0 of the first 8 scenes on 2 rows
_ = plot_images(X_gt[:8, 0], nr=2)
_ = plt.title("Convolved images", fontsize=20)

### 2.2 Create the factor graph

We use a factor graph similar to the one in the Binary Deconvolution notebook.

In [None]:
# Sizes read off the ground-truth features and scenes
n_feat, feat_height, feat_width = W_gt.shape[1:]
n_images, n_chan, im_height, im_width = X_gt.shape

# Number of feature placements that fit along each image axis
s_height = im_height - feat_height + 1
s_width = im_width - feat_width + 1

# Binary features
W = vgroup.NDVarArray(
    num_states=2, shape=(n_chan, n_feat, feat_height, feat_width)
)

# Binary indicators of features locations
S = vgroup.NDVarArray(
    num_states=2, shape=(n_images, n_feat, s_height, s_width)
)

# Auxiliary binary variables combining W and S: one per (image pixel, feature pixel) pair
SW_shape = (n_images, n_chan, im_height, im_width, n_feat, feat_height, feat_width)
SW = vgroup.NDVarArray(num_states=2, shape=SW_shape)

# Binary images obtained by convolution
X = vgroup.NDVarArray(num_states=2, shape=X_gt.shape)

In [None]:
# Factor graph
fg = fgraph.FactorGraph(variable_groups=[S, W, SW, X])

# Collect, for every (feature placement, feature pixel) pair, the variables of
# its AND factor and, for every image pixel, the inputs of its OR factor
variables_for_ANDFactors = []
variables_for_ORFactors_dict = defaultdict(list)
for img in tqdm(range(n_images)):
  for chan in range(n_chan):
    for s_row in range(s_height):
      for s_col in range(s_width):
        for feat in range(n_feat):
          S_var = S[img, feat, s_row, s_col]
          for f_row in range(feat_height):
            for f_col in range(feat_width):
              # Image pixel covered by this feature pixel at this placement
              pix_row = s_row + f_row
              pix_col = s_col + f_col
              SW_var = SW[img, chan, pix_row, pix_col, feat, f_row, f_col]

              # AND factor over (S, W, SW), with SW listed last
              variables_for_ANDFactors.append(
                  [S_var, W[chan, feat, f_row, f_col], SW_var]
              )

              # Every SW variable landing on a pixel feeds that pixel's OR factor
              X_var = X[img, chan, pix_row, pix_col]
              variables_for_ORFactors_dict[X_var].append(SW_var)

# Add ANDFactorGroup, which is computationally efficient
AND_factor_group = fgroup.ANDFactorGroup(variables_for_ANDFactors)
fg.add_factors(AND_factor_group)

# Define the ORFactors: the SW inputs of a pixel followed by the pixel itself
variables_for_ORFactors = [
    [*SW_inputs, pixel_var]
    for pixel_var, SW_inputs in variables_for_ORFactors_dict.items()
]

# Add ORFactorGroup, which is computationally efficient
OR_factor_group = fgroup.ORFactorGroup(variables_for_ORFactors)
fg.add_factors(OR_factor_group)

### 2.3 Run inference with SDLP

In [None]:
# Build an SDLP inferer for the deconvolution factor graph. This rebinds the
# `sdlp` name that section 1 used for the Ising model.
sdlp = infer.build_inferer(fg.bp_state, backend="sdlp")

We define the unaries to address the posterior sampling query *p(S | X, W)*

In [None]:
# Prior probability that an indicator in S is on
pS = 1e-5
# Probability used to clamp the known W and X; logit(pX) is a large negative number
pX = 1e-100

# Unaries for the known W and X: state 0 is penalized where the ground truth is
# 1 and favored where it is 0 (state 1 keeps a zero unary)
uW = np.zeros((W.shape) + (2,))
uW[..., 0] = (2 * W_gt - 1) * logit(pX)

uX = np.zeros((X_gt.shape) + (2,))
uX[..., 0] = (2 * X_gt - 1) * logit(pX)

# Sparsity inducing prior for S
uS = np.zeros((S.shape) + (2,))
uS[..., 1] = logit(pS)

# NOTE(review): `rng` is the key created in section 1.2, so this cell relies on
# that cell having been run.
evidence_updates={
    S: uS + jax.random.gumbel(rng, shape=uS.shape),
    W: uW,
    # Uninformative evidence for the auxiliary variables, with the same
    # trailing per-state axis as the other entries
    SW: np.zeros(SW.shape + (2,)),
    X: uX,
}

Run SDLP inference

In [None]:
# Initialize the solver state with the unaries defined above
sdlp_arrays = sdlp.init(evidence_updates=evidence_updates)
# 2,000 iterations at a log-sum-exp temperature of 1e-3
sdlp_arrays = sdlp.run(sdlp_arrays, num_iters=2_000, logsumexp_temp=0.001)
# Keep the decoded states; the second returned value is discarded
primal_unaries_decoded, _ = sdlp.decode_primal_unaries(sdlp_arrays)

### 2.4 Compute the reconstruction

We compute the reconstructed scenes and the reconstruction error.

In [None]:
def or_layer(S, W):
  """OR-combination of the features W placed at the locations S.

  Every output pixel is 1 - prod(1 - S * W) over the features and feature
  pixels that can reach it. For binary S and W this is the logical OR of all
  the feature pixels placed on that pixel.

  Args:
    S: Float array of shape (n_images, n_feat, s_height, s_width).
    W: Array of shape (n_chan, n_feat, feat_height, feat_width).

  Returns:
    Array of shape (n_images, n_chan, s_height + feat_height - 1,
    s_width + feat_width - 1).
  """
  n_feat, s_height, s_width = S.shape[1:]
  n_feat, feat_height, feat_width = W.shape[1:]
  im_height = s_height + feat_height - 1
  im_width = s_width + feat_width - 1

  # Flip both spatial axes of the features so that sliding a window over the
  # padded indicators implements a convolution
  W_flipped = W[:, :, ::-1, ::-1]

  # Zero-pad the two spatial axes of S (its last two axes) by the feature size
  # minus one on each side; the image and feature axes are left alone
  pad_rows, pad_cols = feat_height - 1, feat_width - 1
  S_padded = jax.lax.pad(
      S,
      0.0,
      (
          (0, 0, 0),
          (0, 0, 0),
          (pad_rows, pad_rows, 0),
          (pad_cols, pad_cols, 0),
      ),
  )

  window_shape = (n_feat, feat_height, feat_width)

  def render_pixel(S_padded_single, row, col):
    # Indicators that can reach this pixel, for all the features at once
    window = jax.lax.dynamic_slice(S_padded_single, (0, row, col), window_shape)
    # One value per channel: 1 only if nothing turns the pixel on
    all_off = (1 - window * W_flipped).prod((1, 2, 3))
    return 1 - all_off

  # Vectorize over columns, then rows, then images; channels stay on axis 0
  # of a single image
  over_cols = jax.vmap(render_pixel, in_axes=(None, None, 0), out_axes=1)
  over_rows = jax.vmap(over_cols, in_axes=(None, 0, None), out_axes=1)
  over_images = jax.vmap(over_rows, in_axes=(0, None, None), out_axes=0)
  return over_images(S_padded, jnp.arange(im_height), jnp.arange(im_width))

In [None]:
# Render the scenes implied by the decoded indicators and the known features
S_decoded = primal_unaries_decoded[S].astype(float)
X_rec = or_layer(S_decoded, W_gt)

# Fraction of the pixels on which the reconstruction and the data disagree
mismatch_mask = np.asarray(X_gt != X_rec)
rec_ratio = mismatch_mask.sum() / X_gt.size
print(f"Reconstruction error: {(100 * rec_ratio):.3f}%")

In [None]:
# Channel 0 of every reconstructed scene, tiled on 4 rows
_ = plot_images(X_rec[:, 0], nr=4)
plt.title("Reconstructed images", fontsize=18)

In [None]:
# Channel 0 of every original scene, tiled on 4 rows
_ = plot_images(X_gt[:, 0], nr=4)
plt.title("Original images", fontsize=18)