# EHT Tutorial 3: Diffusion-Based Imaging

In this tutorial, we use [InverseBench](https://devzhk.github.io/InverseBench/) to perform imaging from EHT M87 data with a pretrained diffusion model. We highlight the [PnP-DM algorithm](https://imaging.cms.caltech.edu/pnpdm/), although other algorithms are available in InverseBench.

## Environment setup

In [None]:
# Install required packages.
!pip install ehtim
!pip install hydra-core
!pip install piq
!pip install torch

In [None]:
# Clone `InverseBench` repo.
!git clone https://github.com/devzhk/InverseBench
%cd InverseBench

# Download data for blackhole problem.
!wget https://sdsc.osn.xsede.org/ini230004-bucket01/zg89b-mpv16/blackhole.zip
!unzip blackhole.zip

# Download preprocessed obs file for M87.
!wget https://github.com/berthyf96/eht_tutorial/raw/refs/heads/main/obs_095_preprocessed.uvfits

# Make a copy of the obs file in the default location.
!cp obs_095_preprocessed.uvfits blackhole/measure/obs.uvfits

# Download weights of pretrained blackhole diffusion model.
!wget https://github.com/devzhk/InverseBench/releases/download/diffusion-prior/blackhole-50k.pt

# Move the checkpoint to the default location.
!mkdir checkpoints
!mv blackhole-50k.pt checkpoints/

In [None]:
# Import libraries.
import os
import pickle
from hydra import initialize_config_dir, compose
from hydra.utils import instantiate
from omegaconf import OmegaConf

import ehtim as eh
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader

from utils.helper import open_url

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

## InverseBench

### Set up forward model and inverse solver

In [None]:
# Initialize config with Hydra.
# Hydra requires an absolute config directory. Resolve it from the current
# working directory (the InverseBench repo root after the `%cd` in the setup
# cell) instead of hard-coding the Colab-specific `/content` prefix, so the
# notebook also runs outside Colab.
abs_config_dir = os.path.abspath('configs')
with initialize_config_dir(version_base='1.3', config_dir=abs_config_dir):
  config = compose(
      config_name='config.yaml',
      overrides=[
          'problem=blackhole',
          'pretrain=blackhole',
          'algorithm=pnpdm',
          # NOTE: any algorithm hyperparameter overrides would go here
          'num_samples=10'
      ]
  )

In [None]:
# Load pre-trained model.
# First try to read the checkpoint as a pickled network object; if that fails
# for any reason, build the network from the config and load a state dict.
# NOTE(security): `pickle.load` can execute arbitrary code, so only use
# checkpoints obtained from a trusted source.
try:
  with open_url(config.problem.prior, 'rb') as f:
    ckpt = pickle.load(f)
    net = ckpt['ema'].to(device)
except Exception:
  # `Exception` (rather than a bare `except`) keeps Ctrl-C / kernel
  # interrupts from being swallowed and silently triggering the fallback.
  net = instantiate(config.pretrain.model)
  ckpt = torch.load(config.problem.prior, map_location=device)
  # Prefer the 'ema' weights when the checkpoint provides them.
  if 'ema' in ckpt:
    net.load_state_dict(ckpt['ema'])
  else:
    net.load_state_dict(ckpt['net'])
  net = net.to(device)

# The checkpoint is no longer needed once the weights are in `net`.
del ckpt
net.eval()

In [None]:
# Fiducial total compact flux density; it is the reference flux of the
# forward operator's flux data term.
zbl = 0.6

# Weights of the individual data terms.
term_weights = {
    'w1': 0.,   # amplitudes
    'w2': 1.,   # closure phases (default=1)
    'w3': 1.,   # log closure amplitudes (default=1)
    'w4': 0.5,  # flux constraint (default=0.5)
}

# Build the forward operator from the problem config.
forward_op = instantiate(
    config.problem.model,
    device=device,
    root='blackhole/measure',
    ttype='fast',
    ref_flux=zbl,
    **term_weights
)

# Build the PnP-DM sampling algorithm on top of the forward operator and the
# pretrained network.
algo = instantiate(config.algorithm.method, forward_op=forward_op, net=net)

### Create a measurement object from the M87 data

In [None]:
def precalibrate_obs(obs_orig, npix, fov, sys_noise=0.0,
                     reverse_taper_uas=0.0, ttype='nfft'):
  """Precalibrate a preprocessed Obsdata as in the eht-imaging M87 pipeline.

  Args:
    obs_orig: Observation to precalibrate. A copy is taken first and all
      later steps start from that copy.
    npix: Number of pixels passed to `eh.image.make_square` for the Gaussian
      model image.
    fov: Field of view passed to `eh.image.make_square` (callers in this
      notebook pass a multiple of `eh.RADPERUAS`).
    sys_noise: Fractional systematic noise passed to `add_fractional_noise`.
    reverse_taper_uas: Reverse-taper size in microarcseconds; the taper steps
      are skipped unless this is > 0.
    ttype: Transform type forwarded to `eh.selfcal`.

  Returns:
    The observation after the LMT calibration table has been applied.
  """
  use_reverse_taper = reverse_taper_uas > 0
  data = obs_orig.copy()

  # A reverse taper enforces a maximum resolution on reconstructed features.
  if use_reverse_taper:
    data = data.reverse_taper(reverse_taper_uas * eh.RADPERUAS)

  # Non-closing systematic noise.
  data = data.add_fractional_noise(sys_noise)

  # Work from a snapshot of the data (after the taper, before any
  # self-calibration) and keep only the short baselines (LMT-SMT).
  # (Refer to Section 4's "Pre-Imaging Considerations")
  short_baselines = data.copy().flag_uvdist(uv_max=2e9)
  if use_reverse_taper:
    # Re-taper so the short-baseline data corresponds to data without the
    # reverse taper.
    short_baselines = short_baselines.taper(reverse_taper_uas * eh.RADPERUAS)

  # Gaussian model of size 60 microarcseconds and total flux 0.6 Jy, chosen to
  # give the LMT-SMT baseline visibility amplitude estimated in Section 4's
  # "Pre-Imaging Considerations".
  gauss_size = 60.0 * eh.RADPERUAS
  gauss_model = eh.image.make_square(data, npix, fov).add_gauss(
      0.6, (gauss_size, gauss_size, 0, 0, 0))

  # Self-calibrate the LMT visibilities against the Gaussian model to enforce
  # the estimated LMT-SMT visibility amplitude.
  lmt_caltab = eh.selfcal(short_baselines, gauss_model, sites=['LM'],
                          gain_tol=1.0, method='both', ttype=ttype,
                          caltable=True)

  # Apply the calibration solution to the full (possibly tapered) dataset.
  return lmt_caltab.applycal(data, interp='nearest', extrapolate=True)

In [None]:
# Read the preprocessed M87 observation and precalibrate it in one step.
# Precalibrating the data was found to help the diffusion model results.
obs = precalibrate_obs(
    eh.obsdata.load_uvfits('obs_095_preprocessed.uvfits'),
    npix=64,
    fov=128 * eh.RADPERUAS,
    sys_noise=0.03,
    ttype='fast'
)

In [None]:
# The diffusion model was trained on images with pixel values in [0, 1].
# The measured amplitudes are therefore divided by a multiplier so that they
# are closer to what the model expects; the model's output is later multiplied
# by the same value to recover an image with the actual total flux.
# Here the multiplier is the peak pixel value of a Gaussian prior image whose
# total flux is the assumed compact flux density.
prior_fwhm = 40 * eh.RADPERUAS
im = eh.image.make_square(obs, npix=64, fov=128 * eh.RADPERUAS).add_gauss(
    zbl, (prior_fwhm, prior_fwhm) + (0, 0, 0))
multiplier = im.ivec.max()

# Uncomment the line below to instead define the flux multiplier based on the
# total flux of a GRMHD image.
# multiplier = forward_op.ref_multiplier

print(multiplier)

In [None]:
# Compute the visibility amplitudes on the observation.
obs.add_amp()

# Turn the amplitudes and their sigmas into tensors on `device`, and rescale
# both by the multiplier so they are closer to the range expected for an image
# with a max pixel value of 1.
amp, sigmaamp = (
    torch.from_numpy(obs.amp[field])[None, None, :, None].float().to(device)
    / multiplier
    for field in ('amp', 'sigma')
)

In [None]:
# Compute the minimal set of closure phases on the observation.
obs.add_cphase(count='min')

# Closure phases and their sigmas as tensors on `device`, scaled by
# `eh.DEGREE`.
cp, sigmacp = (
    torch.from_numpy(obs.cphase[field])[None, None, :, None].float().to(device)
    * eh.DEGREE
    for field in ('cphase', 'sigmacp')
)

In [None]:
# Compute the minimal set of log closure amplitudes on the observation.
obs.add_logcamp(count='min')

# Log closure amplitudes and their sigmas as tensors on `device`.
camp, sigmaca = (
    torch.from_numpy(obs.logcamp[field])[None, None, :, None].float().to(device)
    for field in ('camp', 'sigmaca')
)

In [None]:
# The flux constraint uses the assumed compact flux density, rescaled by the
# same multiplier as the amplitudes so that it is closer to the flux of an
# image whose max pixel value is 1.
flux = torch.tensor([zbl])[None, None, :, None].float().to(device) / multiplier

In [None]:
# Assemble the observation `y` by concatenating along dim 2, in this order:
# amplitudes and sigmas, closure phases and sigmas, log closure amplitudes and
# sigmas, and finally the flux.
y = torch.cat(
    (amp, sigmaamp, cp, sigmacp, camp, sigmaca, flux),
    dim=2,
)

### Inference

In [None]:
# Run the algorithm.
print('Running inference on M87 data...', flush=True)
recon = algo.inference(y, num_samples=config.num_samples)

# Only report GPU memory when CUDA is in use; on the CPU fallback the figure
# would be a meaningless 0.00 GB.
if torch.cuda.is_available():
  print('Peak GPU memory usage: '
        f'{torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GB')

# Keep the measurement together with the unnormalized reconstructions.
result_dict = {
  'observation': y,
  'recon': forward_op.unnormalize(recon),
}

In [None]:
# Compute the chi-squared values of the reconstructions against `y`.
cp_chi2s, logcamp_chi2s = forward_op.evaluate_chisq(
    result_dict['recon'],
    y,
)

# Histogram of the closure phase chi-squared values.
plt.hist(cp_chi2s.cpu().numpy())
plt.ylabel('# samples')
plt.xlabel('closure phase chi2')
plt.show()

In [None]:
# Show the image samples and their chi-squared values.
recon_images = result_dict['recon'].cpu().permute(0, 2, 3, 1).numpy()

# Size the grid from the number of samples instead of hard-coding 10;
# `squeeze=False` keeps `axs` two-dimensional even for a single sample.
num_images = len(recon_images)
fig, axs = plt.subplots(1, num_images, figsize=(2 * num_images, 3),
                        squeeze=False)
for ax, image, cp_chi2, logcamp_chi2 in zip(axs[0], recon_images,
                                            cp_chi2s, logcamp_chi2s):
  ax.imshow(image, cmap='afmhot')
  ax.axis('off')
  # The backslash of `\chi` is escaped explicitly: `\c` is not a valid string
  # escape, while `\n` must remain a real newline.
  ax.set_title(
      f"cp $\\chi^2$: {cp_chi2:.2f}\nlogca $\\chi^2$: {logcamp_chi2:.2f}")
plt.show()

### Look at unconditional samples from the diffusion model

In [None]:
# Compose a second config that selects the unconditional sampler.
uncond_overrides = [
    'problem=blackhole',
    'pretrain=blackhole',
    'algorithm=uncond',
    'num_samples=10',
]
with initialize_config_dir(version_base='1.3', config_dir=abs_config_dir):
  config = compose(config_name='config.yaml', overrides=uncond_overrides)

# Build the unconditional sampling algorithm.
algo = instantiate(config.algorithm.method, forward_op=forward_op, net=net)

# The measurements are dummy zeros for unconditional sampling.
dummy_y = torch.zeros((config.num_samples, 1, 1, 1)).to(device)

# Sample, then renormalize to [0, 1].
uncond_recon = forward_op.unnormalize(
    algo.inference(dummy_y, num_samples=config.num_samples))

# Convert to a channels-last NumPy array for plotting.
uncond_images = uncond_recon.cpu().permute(0, 2, 3, 1).numpy()

In [None]:
# Show unconditional samples.
# Size the grid from the number of samples instead of hard-coding 10;
# `squeeze=False` keeps `axs` two-dimensional even for a single sample.
num_images = len(uncond_images)
fig, axs = plt.subplots(1, num_images, figsize=(2 * num_images, 3),
                        squeeze=False)
for ax, image in zip(axs[0], uncond_images):
  ax.imshow(image, cmap='afmhot')
  ax.axis('off')
plt.show()