In [None]:
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

This notebook trains a noisy-OR Bayesian network with max-product on the 2D blind deconvolution experiment, and reproduces some of the results presented in Section 6.6 of the [paper](https://arxiv.org/pdf/2302.00099.pdf).

In [None]:
# # Uncomment this block if running on colab.research.google.com
# !pip install git+https://github.com/deepmind/max_product_noisy_or.git
# !wget https://raw.githubusercontent.com/deepmind/PGMax/main/examples/example_data/conv_problem.npz
# !mkdir data
# !mv conv_problem.npz  data/

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Load max_product_noisy_or
from mp_noisy_or import config, noisy_or_bp, results_utils

### Load data

In [None]:
# Load the ground-truth features (W) and the convolved images (X).
folder_name = "data/"
# Passing the path (instead of an already-open file object) lets NumPy own the
# file handle, and the `with` block closes the archive once the arrays are read.
# NOTE: allow_pickle=True permits unpickling arbitrary objects, which can run
# code on load; only use it with an archive obtained from a trusted source.
with np.load(folder_name + "conv_problem.npz", allow_pickle=True) as data:
    # Arrays must be pulled out while the archive is still open.
    W_gt = data["W"][0]
    # Keep only the first 20 entries of X.
    X_gt = data["X"][:20]

# Preview the first 8 entries of X (index 0 along the second axis).
results_utils.plot_images(X_gt[:8, 0], nr=2)
_ = plt.title("Convolved images", fontsize=20)

In [None]:
# Display the ground-truth features loaded from the data archive.
gt_features_title = "Ground truth features"
results_utils.plot_images(W_gt, nr=1)
_ = plt.title(gt_features_title, fontsize=22)

### Train noisy-OR Bayesian network with BP

In [None]:
# Start from the configuration returned by config.get_config_BP_2Ddeconv()
this_config = config.get_config_BP_2Ddeconv()

# Here, we modify the default parameters to accelerate convergence
this_config.learning.num_iters = 600  # number of training iterations (reported by the metrics cell)
this_config.learning.proba_init = 0.9  # NOTE(review): presumably an initialization probability; confirm in mp_noisy_or.config

# Training should take 3min on a GPU
# The trainer instance is kept because later cells read NoisyOR.Xv_gt_test.
NoisyOR = noisy_or_bp.NoisyOR_BP(this_config)
# results_BP is indexed by string keys in later cells
# (e.g. "all_train_avg_elbos", "log_potentials", "test_X_samples").
results_BP = NoisyOR.train()

In [None]:
# Plot the training ELBO curve recorded under "all_train_avg_elbos"
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(results_BP["all_train_avg_elbos"])
ax.set_xlabel("Training iteration", fontsize=16)
_ = ax.set_title("Training Elbo", fontsize=18)

### Visualize the learned potentials

In [None]:
# Threshold on the learned log-potentials used to binarize them
LP_THRE = np.log(2)

# Keep the first 5 entries along axis 0 and the first 6 along axes 1 and 2
log_potentials = np.array(results_BP["log_potentials"])[:5, :6, :6]
# Binarize: 1.0 where a log-potential exceeds the threshold, 0.0 elsewhere
W_learned = np.where(log_potentials > LP_THRE, 1.0, 0.0)

# Upper bound of the display range passed to plot_images
potentials_display_max = -np.log(0.01)
results_utils.plot_images(
    log_potentials,
    nr=1,
    images_min=0,
    images_max=potentials_display_max,
)
_ = plt.title("Parameters learned with BP", fontsize=22)

# Same features after thresholding
results_utils.plot_images(W_learned, nr=1)
_ = plt.title("Binary features learned with BP", fontsize=22)

### Compute metrics

In [None]:
print(f"After {this_config.learning.num_iters} training iterations")

# Test Elbo: last value recorded under "all_test_avg_elbos_mode"
test_elbo_history = results_BP["all_test_avg_elbos_mode"]
test_avg_elbo_mode = test_elbo_history[-1]
print(f"Test elbo : {round(test_avg_elbo_mode, 3)}")

# Test reconstruction error: BD_reconstruction returns three values; the
# second (reconstructed images) and third (error ratio) are kept.
reconstruction_outputs = results_utils.BD_reconstruction(
    NoisyOR.Xv_gt_test,
    results_BP["test_X_samples"],
    W_learned,
)
_, test_rec_X, test_rec_ratio = reconstruction_outputs
print(f"Test rec. error: {round(100 * test_rec_ratio, 3)}%")

# IOU matching between the ground-truth and the learned binary features
iou_matching = results_utils.features_iou(W_gt, W_learned)
print(f"IOU matching : {round(iou_matching, 3)}")

In [None]:
# Display the test images reconstructed by results_utils.BD_reconstruction
reconstructed_title = "Test images reconstructed with PMP"
img = results_utils.plot_images(test_rec_X, nr=5)
_ = plt.title(reconstructed_title, fontsize=22)

In [None]:
# Display the ground-truth test images held by the trainer, for comparison
# with the reconstructions.
ground_truth_title = "Ground truth test images"
img = results_utils.plot_images(NoisyOR.Xv_gt_test, nr=5)
_ = plt.title(ground_truth_title, fontsize=22)