Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 529549340
Change-Id: Idf92eaa9cea7c63a8e4fb8daa634a675e029ed39
  • Loading branch information
PGMax team authored and antoine-dedieu committed May 5, 2023
1 parent a2af982 commit e07db52
Show file tree
Hide file tree
Showing 40 changed files with 4,035 additions and 706 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ to install JAX for GPU.
Here are a few self-contained Colab notebooks to help you get started on using PGMax:

- [Tutorial on basic PGMax usage](https://colab.research.google.com/github/deepmind/PGMax/blob/master/examples/rbm.ipynb)
- [LBP inference on Ising model](https://colab.research.google.com/github/deepmind/PGMax/blob/master/examples/ising_model.ipynb)
- [Implementing max-product LBP](https://colab.research.google.com/github/deepmind/PGMax/blob/master/examples/rcn.ipynb)
for [Recursive Cortical Networks](https://www.science.org/doi/10.1126/science.aag2612)
- [End-to-end differentiable LBP for gradient-based PGM training](https://colab.research.google.com/github/deepmind/PGMax/blob/master/examples/gmrf.ipynb)
- [2D binary deconvolution](https://colab.research.google.com/github/deepmind/PGMax/blob/master/examples/pmp_binary_deconvolution.ipynb)
- [Alternative inference with Smooth Dual LP-MAP](https://colab.research.google.com/github/deepmind/PGMax/blob/master/examples/sdlp_examples.ipynb)

## Citing PGMax

Expand Down
21 changes: 19 additions & 2 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,15 @@ Infer
.. autosummary::
bp
bp_state
dual_lp
energy
inferer

bp
~~~~~~~~~~~~~

.. autoclass:: BeliefPropagation
.. autofunction:: BP
.. autofunction:: decode_map_states
.. autofunction:: get_marginals

bp_state
Expand All @@ -119,11 +120,24 @@ bp_state
.. autoclass:: Evidence
.. autoclass:: BPState

dual_lp
~~~~~~~~~~~~~

.. autoclass:: SmoothDualLP
.. autofunction:: SDLP

energy
~~~~~~~~~~~~~

.. autofunction:: compute_energy

inferer
~~~~~~~~~~~~~

.. autofunction:: decode_map_states
.. autoclass:: Inferer
.. autoclass:: InfererContext


Vgroup
===================
Expand Down Expand Up @@ -152,4 +166,7 @@ vdict

Utils
===================
.. automodule:: pgmax.utils
.. automodule:: pgmax.utils

.. autosummary::
primal_lp
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ pre-commit>=2.13.0
pytest-cov>=2.12.1
tensorflow_datasets>=4.6.0
tensorflow>=2.9.1
cvxpy>=1.2.0
8 changes: 5 additions & 3 deletions examples/gmrf.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@
},
"outputs": [],
"source": [
"bp = infer.BP(fg.bp_state, temperature=1.0)"
"bp = infer.build_inferer(fg.bp_state, backend=\"bp\")"
]
},
{
Expand Down Expand Up @@ -314,13 +314,14 @@
" target = prototype_targets[target_image]\n",
" marginals = infer.get_marginals(\n",
" bp.get_beliefs(\n",
" bp.run_bp(\n",
" bp.run(\n",
" bp.init(\n",
" evidence_updates={variables: evidence},\n",
" log_potentials_updates=log_potentials,\n",
" ),\n",
" num_iters=15,\n",
" damping=0.0,\n",
" temperature=1.0\n",
" )\n",
" )\n",
" )[variables]\n",
Expand Down Expand Up @@ -376,13 +377,14 @@
" target = prototype_targets[target_image]\n",
" marginals = infer.get_marginals(\n",
" bp.get_beliefs(\n",
" bp.run_bp(\n",
" bp.run(\n",
" bp.init(\n",
" evidence_updates={variables: evidence},\n",
" log_potentials_updates=log_potentials,\n",
" ),\n",
" num_iters=15,\n",
" damping=0.0,\n",
" temperature=1.0\n",
" )\n",
" )\n",
" )\n",
Expand Down
65 changes: 44 additions & 21 deletions examples/ising_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -95,33 +95,31 @@
"id": "Nq7Z-aAzprQi"
},
"source": [
"### Run inference and visualize results"
"### Run inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2T2qLWXFpqwV"
"id": "xqLfZzxhu3Vl"
},
"outputs": [],
"source": [
"bp = infer.BP(fg.bp_state, temperature=0)"
"evidence_updates={variables: np.random.gumbel(size=(50, 50, 2))}\n",
"\n",
"inferer = infer.build_inferer(fg.bp_state, backend=\"bp\")\n",
"inferer_arrays = inferer.init(evidence_updates=evidence_updates)\n",
"inferer_arrays, msgs_deltas = inferer.run_with_diffs(inferer_arrays, num_iters=3000, temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {
"id": "sS0-wqntpxYC"
"id": "MNiOlx0v53n_"
},
"outputs": [],
"source": [
"bp_arrays = bp.init(\n",
" evidence_updates={variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}\n",
")\n",
"bp_arrays = bp.run_bp(bp_arrays, num_iters=3000)\n",
"beliefs = bp.get_beliefs(bp_arrays)"
"### Visualize the decoding and compute its energy"
]
},
{
Expand All @@ -132,22 +130,46 @@
},
"outputs": [],
"source": [
"# Get the map states\n",
"beliefs = inferer.get_beliefs(inferer_arrays)\n",
"map_states = infer.decode_map_states(beliefs)\n",
"\n",
"# Compute the energy\n",
"decoding_energy = (\n",
" infer.compute_energy(fg.bp_state, inferer_arrays, map_states)[0]\n",
")\n",
"print(\"The energy of the decoding is\", decoding_energy)\n",
"\n",
"# Plot the image\n",
"img = map_states[variables]\n",
"fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
"ax.imshow(img)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tAtQ1Mat5MGa"
},
"source": [
"### Assess BP convergence"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KFnpQgMJmsGW"
"id": "-cHbxe2a42O3"
},
"outputs": [],
"source": [
"decoding_energy = infer.compute_energy(fg.bp_state, bp_arrays, map_states)[0]\n",
"print(\"The energy of the decoding is\", decoding_energy)"
"assert np.max(msgs_deltas[-10:]) \u003c 1e-3\n",
"\n",
"plt.figure(figsize=(8, 5))\n",
"plt.plot(msgs_deltas)\n",
"plt.title(\"Max-product convergence\", fontsize=18)\n",
"plt.xlabel(\"BP iteration\", fontsize=16)\n",
"plt.ylabel(\"Max abs msgs diff\", fontsize=16)"
]
},
{
Expand All @@ -168,11 +190,12 @@
"outputs": [],
"source": [
"def loss(log_potentials_updates, evidence_updates):\n",
" bp_arrays = bp.init(\n",
" log_potentials_updates=log_potentials_updates, evidence_updates=evidence_updates\n",
" inferer_arrays = inferer.init(\n",
" log_potentials_updates=log_potentials_updates,\n",
" evidence_updates=evidence_updates\n",
" )\n",
" bp_arrays = bp.run_bp(bp_arrays, num_iters=3000)\n",
" beliefs = bp.get_beliefs(bp_arrays)\n",
" inferer_arrays = inferer.run(inferer_arrays, num_iters=3000)\n",
" beliefs = inferer.get_beliefs(inferer_arrays)\n",
" loss = -jnp.sum(beliefs[variables])\n",
" return loss\n",
"\n",
Expand All @@ -189,7 +212,7 @@
},
"outputs": [],
"source": [
"batch_loss(None, {variables: jax.device_put(np.random.gumbel(size=(10, 50, 50, 2)))})"
"batch_loss(None, {variables: np.random.gumbel(size=(10, 50, 50, 2))})"
]
},
{
Expand Down Expand Up @@ -223,7 +246,7 @@
},
"outputs": [],
"source": [
"bp_state = bp.to_bp_state(bp_arrays)\n",
"bp_state = inferer.to_bp_state(inferer_arrays)\n",
"\n",
"# Query evidence for variable (0, 0)\n",
"bp_state.evidence[variables[0, 0]]"
Expand Down
4 changes: 2 additions & 2 deletions examples/pmp_binary_deconvolution.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@
},
"outputs": [],
"source": [
"bp = infer.BP(fg.bp_state, temperature=0.0)"
"bp = infer.build_inferer(fg.bp_state, backend=\"bp\")"
]
},
{
Expand Down Expand Up @@ -400,7 +400,7 @@
")\n",
"\n",
"bp_arrays = jax.vmap(\n",
" functools.partial(bp.run_bp, num_iters=100, damping=0.5),\n",
" functools.partial(bp.run, num_iters=100, damping=0.5, temperature=0.0),\n",
" in_axes=0,\n",
" out_axes=0,\n",
")(bp_arrays)\n",
Expand Down

0 comments on commit e07db52

Please sign in to comment.