In [None]:
# Copyright 2022 Intrinsic Innovation LLC.
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
# # Uncomment this block if running on colab.research.google.com
# !pip install git+https://github.com/deepmind/PGMax.git

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

############
# Load PGMax
from pgmax import fgraph, fgroup, infer, vgroup

### Construct variable grid, initialize factor graph, and add factors

In [None]:
# Grid side length and pairwise coupling strength of the Ising model.
GRID_SIZE = 50
COUPLING = 0.8

# One binary variable per grid cell.
variables = vgroup.NDVarArray(num_states=2, shape=(GRID_SIZE, GRID_SIZE))
fg = fgraph.FactorGraph(variable_groups=variables)

# Connect every cell to its neighbour below (kk) and to its right (ll).
# Indices wrap around with a modulo, so the grid has periodic boundaries
# (a torus) and every cell takes part in four pairwise factors.
variables_for_factors = []
for ii in range(GRID_SIZE):
  for jj in range(GRID_SIZE):
    kk = (ii + 1) % GRID_SIZE
    ll = (jj + 1) % GRID_SIZE
    variables_for_factors.append([variables[ii, jj], variables[kk, jj]])
    variables_for_factors.append([variables[ii, jj], variables[ii, ll]])

# A single log-potential matrix is shared by all pairwise factors: it adds
# +COUPLING when the two neighbours take equal states and -COUPLING otherwise.
factor_group = fgroup.PairwiseFactorGroup(
    variables_for_factors=variables_for_factors,
    log_potential_matrix=COUPLING * np.array([[1.0, -1.0], [-1.0, 1.0]]),
)
fg.add_factors(factor_group)

### Run inference

In [None]:
# Seed NumPy's global RNG so the random evidence below, and everything derived
# from it (decoding, energy, convergence check), is reproducible under
# Restart & Run All.
np.random.seed(0)

# Random Gumbel-distributed evidence for every (variable, state) pair.
evidence_updates = {variables: np.random.gumbel(size=(50, 50, 2))}

inferer = infer.build_inferer(fg.bp_state, backend="bp")
inferer_arrays = inferer.init(evidence_updates=evidence_updates)
# temperature=0 runs max-product BP. run_with_diffs also returns the
# per-iteration message differences that are used to assess convergence.
inferer_arrays, msgs_deltas = inferer.run_with_diffs(
    inferer_arrays, num_iters=3000, temperature=0
)

### Visualize the decoding and compute its energy

In [None]:
# Decode the MAP state of every variable from the max-product beliefs.
beliefs = inferer.get_beliefs(inferer_arrays)
map_states = infer.decode_map_states(beliefs)

# Energy of the decoded configuration (first element of compute_energy's result).
decoding_energy = (
    infer.compute_energy(fg.bp_state, inferer_arrays, map_states)[0]
)
print("The energy of the decoding is", decoding_energy)

# Plot the decoded grid as an image; the title lets the figure stand alone.
img = map_states[variables]
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(img)
ax.set_title("MAP decoding", fontsize=18);

### Assess BP convergence

In [None]:
# Plot the convergence curve first, so the diagnostic plot is still produced
# if the convergence assertion below fails.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(msgs_deltas)
ax.set_title("Max-product convergence", fontsize=18)
ax.set_xlabel("BP iteration", fontsize=16)
ax.set_ylabel("Max abs msgs diff", fontsize=16)

# Treat BP as converged when the message differences of the last 10
# iterations all fall below the tolerance.
convergence_tol = 1e-3
final_max_delta = np.max(msgs_deltas[-10:])
assert final_max_delta < convergence_tol, (
    "BP did not converge: max msgs diff over the last 10 iterations is "
    f"{final_max_delta} (tolerance {convergence_tol})"
)

### Gradients and batching

In [None]:
def loss(log_potentials_updates, evidence_updates):
  """Returns minus the sum of all beliefs after running BP from the given updates.

  Relies on the module-level `inferer` and `variables`.

  Args:
    log_potentials_updates: Log-potential updates forwarded to `inferer.init`;
      may be None, as in the batched evaluation.
    evidence_updates: Evidence updates forwarded to `inferer.init`.

  Returns:
    A scalar: the negated sum of the beliefs of `variables` after 3000 BP
    iterations.
  """
  # Local arrays; deliberately independent of the global `inferer_arrays`.
  inferer_arrays = inferer.init(
      log_potentials_updates=log_potentials_updates,
      evidence_updates=evidence_updates,
  )
  inferer_arrays = inferer.run(inferer_arrays, num_iters=3000)
  beliefs = inferer.get_beliefs(inferer_arrays)
  # Named `total_loss` (not `loss`) to avoid shadowing this function's name.
  total_loss = -jnp.sum(beliefs[variables])
  return total_loss


# Batched loss: vmap maps over the leading axis of the evidence array, while the
# log-potential updates (in_axes None) are shared across the whole batch.
batch_loss = jax.jit(
    jax.vmap(loss, in_axes=(None, {variables: 0}))
)

# Gradient of `loss` with respect to its first argument, the log-potential updates.
log_potentials_grads = jax.jit(jax.grad(loss))

In [None]:
# Loss for a batch of 10 independent Gumbel evidence draws; axis 0 of the
# evidence array is the batch axis that vmap maps over.
batch_loss(
    None,
    {variables: np.random.gumbel(0.0, 1.0, (10, 50, 50, 2))},
)

In [None]:
# Gradient w.r.t. the log potentials of `factor_group`, evaluated at a 2x2
# identity log-potential matrix and fresh Gumbel evidence moved onto a JAX device.
identity_log_potentials = {factor_group: jnp.identity(2)}
grads = log_potentials_grads(
    identity_log_potentials,
    {variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))},
)

### Message and evidence manipulation

In [None]:
# Convert the arrays from the max-product run into a BP state, whose `evidence`
# attribute supports dictionary-style indexing by variable.
bp_state = inferer.to_bp_state(inferer_arrays)

# Read back the evidence currently stored for variable (0, 0).
bp_state.evidence[variables[0, 0]]

In [None]:
# Overwrite the evidence of variable (0, 0) with two equal entries of 1.0,
# then display the stored value to confirm the write.
bp_state.evidence[variables[0, 0]] = np.ones(2)
bp_state.evidence[variables[0, 0]]

In [None]:
# Overwrite the evidence of every variable at once. The array's first two axes
# index the 50x50 grid and its last axis the two states.
evidence = np.random.standard_normal((50, 50, 2))
bp_state.evidence[variables] = evidence

# Spot-check one variable: its stored evidence should match the array slice.
np.allclose(bp_state.evidence[variables[10, 10]], evidence[10, 10])