# SYMPAIS Torus Demo
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ethanluoyc/sympais/blob/master/notebooks/torus_demo.ipynb)

This notebook provides a visual illustration of the SYMPAIS algorithm.

## Setup

In [None]:
# Detect whether the notebook runs inside Google Colab. Only a failed import
# means "not in Colab"; any other error should surface instead of being
# silently swallowed by a bare `except:`.
try:
  import google.colab  # noqa: F401  (imported only to probe availability)
  IN_COLAB = True
except ImportError:
  IN_COLAB = False

In [None]:
# Optional GitHub access token. When non-empty it is embedded in the pip
# install URL below. NOTE(review): presumably only needed if the repository is
# private; the token may appear in the cell output, so avoid sharing executed
# copies of this notebook that contain it.
GIT_TOKEN = ""
if IN_COLAB:
    # Upgrade the packaging toolchain, then install SYMPAIS from GitHub.
    !pip install -U pip setuptools wheel
    if GIT_TOKEN:
        !pip install git+https://{GIT_TOKEN}@github.com/ethanluoyc/sympais.git#egg=sympais
    else:
        !pip install git+https://github.com/ethanluoyc/sympais.git#egg=sympais


In [None]:
# In Colab, download a `realpaver` executable from Google Drive, make it
# executable and copy it to /usr/local/bin (typically on PATH).
# NOTE(review): presumably required by the `init='realpaver'` runs in the
# "Effect of Initialization" section — confirm.
if IN_COLAB:
    !curl -L "https://drive.google.com/uc?export=download&id=1_Im0Ot5TjkzaWfid657AV_gyMpnPuVRa" -o realpaver
    !chmod u+x realpaver
    !cp realpaver /usr/local/bin

In [None]:
import jax
import jax.numpy as jnp

from sympais import tasks
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
import numpy as onp
import math

# Automatically reload edited modules (e.g. sympais) before executing each
# cell, and render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Configure z3 with randomized but seeded search settings.
# NOTE(review): presumably so that the z3-based initializer (init="z3")
# produces varied yet reproducible solutions — confirm.
import z3

# (option name, value) pairs, applied in this exact order.
_Z3_OPTIONS = (
    ("smt.arith.random_initial_value", True),
    ("auto_config", False),
    ("smt.phase_selection", 5),
    ("smt.random_seed", 42),
)
for _option_name, _option_value in _Z3_OPTIONS:
    z3.set_option(_option_name, _option_value)

## Implementation



### SYMPAIS
Here we show how to implement SYMPAIS from the different components described in the paper.

In [None]:
from typing import Optional

import jax
import jax.numpy as jnp
import numpy as onp

from sympais import constraint
from sympais import logger as logger_lib
from sympais import tasks
from sympais.infer import importance
from sympais.infer import utils
from sympais.initializer import Initializer
from sympais.methods.importance import ProposalBuilder
from sympais.methods.importance import RandomWalkMetropolisKernel
from sympais.methods.importance import refine_domains
from sympais.methods.importance import sample_chain
from sympais.methods.importance import WindowedScaleAdaptor

def run(task: tasks.Task,
        seed: int,
        num_samples: int = int(1e6),
        num_proposals: int = 100,
        num_samples_per_iter: int = 5,
        proposal_scale_multiplier: float = 0.5,
        rmh_scale: float = 1.0,
        init: str = "z3",
        tune: bool = True,
        num_warmup_steps: int = 500,
        window_size: int = 100,
        resample: bool = True,
        proposal_std_num_samples: int = 100,
        logger: Optional[logger_lib.Logger] = None):
  """Run SYMPAIS on `task` and return the final estimate plus diagnostics.

  Pipeline: build the constrained target density, find feasible initial
  chain states, optionally warm up (and tune) random-walk Metropolis chains,
  then run PI-MAIS iterations until the sampling budget is spent.

  Args:
    task: provides the input `profile` p(x), the path-condition
      `constraints` and the variable `domains`.
    seed: seed for `jax.random.PRNGKey`.
    num_samples: total sampling budget, warm-up included.
    num_proposals: number of parallel chains, i.e. mixture components of the
      importance-sampling proposal.
    num_samples_per_iter: samples drawn from each mixture component per
      PI-MAIS iteration.
    proposal_scale_multiplier: forwarded to `ProposalBuilder`.
    rmh_scale: random-walk scale given to every chain's RMH kernel.
    init: initialization strategy forwarded to `Initializer` (this notebook
      uses "z3" and "realpaver").
    tune: adapt the kernel during warm-up with `WindowedScaleAdaptor`.
    num_warmup_steps: warm-up transitions per chain; values < 1 skip warm-up.
    window_size: adaptation window for `WindowedScaleAdaptor`.
    resample: forwarded to `Initializer`.
    proposal_std_num_samples: forwarded to `ProposalBuilder`.
    logger: if given, receives `{"mean": <final estimate>}`.

  Returns:
    A dict with the per-iteration PI-MAIS states and extras stacked along a
    new leading axis ("pimais_states", "pimais_extras"), the built
    "constraint_fn" and "target_log_prob_fn", the "initial_chain_state",
    the "proposal_builder", and the final probability estimate "prob".

  Raises:
    ValueError: if the budget left after warm-up cannot pay for a single
      PI-MAIS iteration.
  """
  # Compute the number of PI-MAIS iterations from the total sampling budget
  # up front, so an infeasible budget fails before any expensive work:
  # 1) subtract the samples used during warm-up (one per chain per step);
  # 2) every iteration draws `num_samples_per_iter` samples from each of the
  #    `num_proposals` mixture components, plus one extra sample per chain
  #    for its single MCMC transition.
  num_samples_warmup = num_proposals * max(num_warmup_steps, 0)
  samples_per_iteration = num_proposals * (num_samples_per_iter + 1)
  num_iterations = (num_samples - num_samples_warmup) // samples_per_iteration
  if num_iterations < 1:
    raise ValueError(
        f"num_samples={num_samples} leaves no budget for a PI-MAIS iteration "
        f"after {num_samples_warmup} warm-up samples (each iteration costs "
        f"{samples_per_iteration} samples).")

  profile = task.profile
  pcs = task.constraints
  domains = task.domains
  # Build callable for 1_PC(x).
  constraint_fn = constraint.build_constraint_fn(pcs, domains)
  # Build the unnormalized density \bar{p}(x) = 1_PC(x) p(x).
  target_log_prob_fn = constraint.build_target_log_prob_fn(
      profile, domains, constraint_fn)
  # Find a coarse approximation of the solution space.
  # This is used both by the RMH kernel for making proposals and
  # by the IS proposal for proposing from truncated distributions.
  refined_domains = refine_domains(pcs, domains)

  # Every random consumer gets its own key; reusing a key would correlate
  # the proposal builder's randomness with the initializer's.
  key = jax.random.PRNGKey(seed)
  key, proposal_key, init_key, warmup_key = jax.random.split(key, 4)

  # The proposal builder is a callable that constructs the importance
  # sampling proposal distribution q(x) used for MIS at every PI-MAIS
  # iteration.
  proposal_builder = ProposalBuilder(profile, refined_domains,
                                     proposal_scale_multiplier,
                                     proposal_std_num_samples, proposal_key)
  # The initializer finds initial feasible solutions to bootstrap the
  # MCMC chains.
  initializer_ = Initializer(profile, pcs, domains, init, resample)
  initial_chain_state = initializer_(num_proposals, init_key)
  # Construct an RMH transition kernel (one scale entry per chain).
  kernel = RandomWalkMetropolisKernel(target_log_prob_fn,
                                      jnp.ones(num_proposals) * rmh_scale,
                                      refined_domains)
  kernel.step = jax.jit(kernel.step)

  # Initialize kernel parameters, then run warm-up with optional adaptation.
  params = kernel.init()
  if num_warmup_steps < 1:
    print("Not running warmup")
    chain_state = initial_chain_state
  else:
    if tune:
      print("Tuning the kernel")
      params, chain_state, _ = WindowedScaleAdaptor(kernel, window_size)(
          warmup_key, params, initial_chain_state, num_warmup_steps)
    else:
      print("Not tuning the kernel")
      chain_state, _ = sample_chain(kernel, params, warmup_key,
                                    initial_chain_state, num_warmup_steps)
    print("Finished warm-up")

  # Initialize the state for running PI-MAIS.
  state = importance.pimais_init(chain_state)

  @jax.jit
  def pimais_step_fn(params, rng, state):
    kernel_fn = lambda k, s: kernel.step(params, k, s)
    return importance.pimais_step(rng, target_log_prob_fn, kernel_fn,
                                  proposal_builder, num_samples_per_iter, state)

  # Run the PI-MAIS iterations, keeping every intermediate result.
  states = []
  extras = []
  for rng in jax.random.split(key, num_iterations):
    state, extra = pimais_step_fn(params, rng, state)
    states.append(state)
    extras.append(extra)
    # Wait for JAX's asynchronous dispatch so each step finishes in order.
    utils.block_until_ready((state, extra))

  if logger is not None:
    logger.write({"mean": state.Ztot})
  print("Final estimated probability {}".format(state.Ztot))

  def _stack(trees):
    # Stack a list of identically-structured pytrees along a new axis 0.
    # (`jax.tree_multimap` no longer exists; `tree_map` takes extra trees.)
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs, 0), *trees)

  # Collect some intermediate results for post-processing and viz.
  return {
      "pimais_states": _stack(states),
      "pimais_extras": _stack(extras),
      "constraint_fn": constraint_fn,
      "target_log_prob_fn": target_log_prob_fn,
      "initial_chain_state": initial_chain_state,
      "proposal_builder": proposal_builder,
      "prob": state.Ztot,
  }

### Visualization
Some helper functions for visualization

In [None]:
def _plot_proposals(axes, initial_proposal_state, proposal_states, proposal_fn,
                    time_steps=(0, 10, 100)):
    """Plot the proposal density and chain states at selected iterations.

    Args:
      axes: sequence of matplotlib axes, one per entry of ``time_steps``
        (extra time steps beyond ``len(axes)`` are ignored).
      initial_proposal_state: proposal state shown for ``t = 0``.
      proposal_states: stacked PI-MAIS states (``run(...)["pimais_states"]``);
        ``proposal_states.proposal_state`` holds one entry per iteration
        along its leading axis.
      proposal_fn: callable building a distribution (with ``log_prob``) from
        a proposal state.
      time_steps: iterations to show. Steps beyond the number of recorded
        iterations are clamped to the last one, and the title shows the
        iteration actually plotted.
    """
    stacked = proposal_states.proposal_state
    num_recorded = jax.tree_util.tree_leaves(stacked)[0].shape[0]
    for ax, t in zip(axes, time_steps):
        # JAX clamps out-of-bounds indices silently; clamp explicitly so the
        # title does not claim an iteration that was never run.
        t = min(t, num_recorded)
        if t == 0:
            proposal_state = initial_proposal_state
        else:
            proposal_state = jax.tree_util.tree_map(
                lambda x: x[t - 1], stacked
            )
        proposal_dist = proposal_fn(proposal_state)
        plot_proposal_dist(ax, proposal_dist)
        plot_proposal_state(ax, proposal_state)
        _plot_torus_boundaries(ax)
        ax.set_xlim(-5, 5)
        ax.set_ylim(-5, 5)
        ax.set(title=f"$t = {t}$")


def _plot_torus_boundaries(ax):
    """Draw the two dashed torus boundary circles on ``ax``.

    Both circles are centred at the origin: radius 2 (labelled "inner") and
    radius 4 (unlabelled).
    """
    boundary_style = dict(
        fill=False, linestyle="--", linewidth=.5, edgecolor="black"
    )
    ax.add_artist(plt.Circle((0, 0), 2, label="inner", **boundary_style))
    ax.add_artist(plt.Circle((0, 0), 4, **boundary_style))


def _fixup_axes(ax):
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.set(xlabel="$x$", ylabel="$y$", title="$z = 0$")


def plot_proposal_dist(ax, proposal, resolution=100, bounds=(-5, 5)):
    """Draw filled contours of ``proposal``'s density on the z = 0 slice.

    The density is evaluated on a ``resolution`` x ``resolution`` grid that
    spans ``bounds`` in both x and y.

    Args:
      ax: matplotlib axes to draw on.
      proposal: object whose ``log_prob`` accepts a dict of ``(N, 1)`` arrays
        keyed by "x", "y" and "z" (N = resolution**2) and returns an array
        whose last axis is averaged after exponentiation. NOTE(review):
        presumably that axis indexes the mixture components — confirm.
      resolution: number of grid points per axis.
      bounds: ``(low, high)`` extent of the grid in both x and y.
    """
    low, high = bounds
    axis = jnp.linspace(low, high, resolution)
    xx, yy = jnp.meshgrid(axis, axis)
    logp = proposal.log_prob(
        {
            "x": xx.reshape(-1, 1),
            "y": yy.reshape(-1, 1),
            "z": jnp.zeros((resolution * resolution, 1)),
        }
    )
    # Average densities (not log-densities) over the last axis, then restore
    # the grid layout for contour plotting.
    p = jnp.mean(jnp.exp(logp), -1).reshape((resolution, resolution))
    ax.contourf(xx, yy, p, cmap="Blues")


def plot_proposal_state(ax, proposal_state, s=5, linewidth=0.5):
    """Scatter the (x, y) coordinates of ``proposal_state`` on ``ax``.

    Points are drawn as white circles with black edges at 75% opacity;
    ``s`` and ``linewidth`` control marker size and edge width.
    """
    marker_style = {
        "marker": "o",
        "c": "white",
        "edgecolors": "black",
        "linewidth": linewidth,
        "s": s,
        "alpha": 0.75,
    }
    xs = proposal_state["x"]
    ys = proposal_state["y"]
    ax.scatter(xs, ys, **marker_style)

In [None]:
# Run SYMPAIS on the Torus task with z3 initialization, no warm-up and
# `run`'s default budget (num_samples=int(1e6)); `run` prints the final
# probability estimate. Note that `output` is reassigned by the trajectory
# cell further below before it is used.
output = run(
    tasks.Torus(), 
    seed=0,
    init="z3", num_warmup_steps=0, 
    proposal_scale_multiplier=1.0
)

## Visualizing the input distribution and corresponding optimal proposal distribution

In [None]:
# 2 x 2 figure on the z = 0 slice: the input density p(x) (left column) and
# the constrained density titled q*(x) (right column), for the default
# Torus profile (top row) and the "correlated" profile (bottom row).
resolution = 100
x = jnp.linspace(-5, 5, resolution)
y = jnp.linspace(-5, 5, resolution)
xx, yy = jnp.meshgrid(x, y)
grid_points = {
    "x": xx.reshape(-1, 1),
    "y": yy.reshape(-1, 1),
    "z": jnp.zeros((resolution * resolution, 1)),
}

fig, ax = plt.subplots(2, 2, sharey=True)
fig.set_size_inches(6, 6.)

torus_tasks = (tasks.Torus(), tasks.Torus(profile_type="correlated"))
for row, task in enumerate(torus_tasks):
    constraint_fn = constraint.build_constraint_fn(task.constraints, task.domains)
    target_log_prob_fn = constraint.build_target_log_prob_fn(
        task.profile, task.domains, constraint_fn)

    # Left column: the input density p(x).
    logp = task.profile.log_prob(dict(grid_points)).reshape(
        (resolution, resolution))
    ax[row, 0].contourf(xx, yy, jnp.exp(logp), cmap="Blues")

    # Right column: the unnormalized target 1_PC(x) p(x). Grid points with
    # zero density (outside the constraint) are masked so they stay blank.
    logp_gt = target_log_prob_fn(dict(grid_points))
    p_gt = jnp.exp(logp_gt).reshape((resolution, resolution))
    p_gt = onp.ma.masked_where(p_gt == 0.0, p_gt)
    ax[row, 1].contourf(xx, yy, p_gt, cmap="Blues")

    for a in ax[row]:
        _plot_torus_boundaries(a)
        _fixup_axes(a)

for a in ax.flat:
    a.set_aspect(1)

# Column titles on the top row only; y labels on the left column only.
ax[0, 0].set_title("$p(x)$")
ax[0, 1].set_title("$q^*({x})$")
ax[1, 0].set_title("")
ax[1, 1].set_title("")
for a in ax[:, 1]:
    a.set_ylabel("")

tick_positions = [-4, 0, 4]
for axis in (ax[1, 1].xaxis, ax[0, 0].yaxis, ax[1, 0].xaxis, ax[1, 0].yaxis):
    axis.set_major_locator(matplotlib.ticker.FixedLocator(tick_positions))
# The top row shares the bottom row's x range, so hide its x axes.
for a in ax[0]:
    a.xaxis.set_visible(False)

## Example trajectory of SYMPAIS

In [None]:
# Re-run with a smaller budget (1e5 samples); `plot_trajectory` and the
# interactive widget below visualize this run.
output = run(
    tasks.Torus(), 
    seed=0,
    init="z3", 
    num_warmup_steps=0, 
    num_samples=int(1e5),
    proposal_scale_multiplier=1.0)

In [None]:
def plot_trajectory(output, time_steps=(0, 10, 100, 1000)):
    """Plot the PI-MAIS proposal at selected iterations, one panel per step.

    Each panel shows the proposal density on the z = 0 slice, the chain
    states used as proposal locations, the torus boundaries and, for t > 0,
    the dashed path of entry 0 along the second axis of the stacked states
    (NOTE(review): presumably chain 0) over iterations 1..t.

    Args:
      output: dict returned by ``run``.
      time_steps: iterations to show. Steps beyond the number of recorded
        iterations are clamped to the last one, and the panel title shows
        the iteration actually plotted (JAX would otherwise clamp the index
        silently while the title claimed the requested step).
    """
    stacked = output["pimais_states"].proposal_state
    num_recorded = stacked["x"].shape[0]
    num_panels = len(time_steps)
    # squeeze=False keeps a 2-D array even for a single panel.
    _, axes = plt.subplots(1, num_panels,
                           figsize=matplotlib.figure.figaspect(1 / num_panels),
                           sharey=True, squeeze=False)
    axes = axes[0]

    for ax, t in zip(axes, time_steps):
        t = min(t, num_recorded)
        if t == 0:
            proposal_state = output["initial_chain_state"]
        else:
            proposal_state = jax.tree_util.tree_map(
                lambda x: x[t - 1], stacked
            )
        proposal_dist = output['proposal_builder'](proposal_state)
        plot_proposal_dist(ax, proposal_dist)
        ax.scatter(
            proposal_state['x'],
            proposal_state['y'],
            marker="o",
            c="white",
            edgecolors="black",
            alpha=0.75,
        )
        _plot_torus_boundaries(ax)
        if t > 0:
            ax.plot(
                stacked["x"][:t, 0],
                stacked["y"][:t, 0],
                color="#F2528D",
                linestyle="dashed",
                alpha=0.75,
                linewidth=2,
            )
        ax.set_xlim(-5, 5)
        ax.set_ylim(-5, 5)
        ax.set(title=f"$t = {t}$")

    tick_positions = [-5, 0, 5]
    for ax in axes:
        ax.set_aspect(1)
        ax.set_xlabel("$x$")
        ax.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(tick_positions))
        ax.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(tick_positions))
    axes[0].set(ylabel="$y$")
    plt.subplots_adjust(left=.2, hspace=.3, top=.9, bottom=.15, right=.95)


plot_trajectory(output)

### Interactive visualization of the SYMPAIS trajectory

We also provide an interactive widget for visualizing the SYMPAIS trajectory.

In [None]:
from ipywidgets import widgets

In [None]:
def make_interactive_plot(output):
    """Create a slider widget that redraws the proposal at iteration ``T``.

    ``T = 0`` shows the initial chain states. ``T >= 1`` shows the proposal
    state recorded after PI-MAIS iteration ``T`` (index ``T - 1`` of
    ``output["pimais_states"]``). Each redraw creates a new 4x4-inch figure
    with the torus boundaries, the proposal density on the z = 0 slice and
    the chain states.

    Args:
      output: dict returned by ``run``.

    Returns:
      ``update`` after decoration by ``widgets.interact``.
    """
    stacked = output["pimais_states"].proposal_state
    num_recorded = stacked['x'].shape[0]

    # The slider range is inclusive, so the largest index used is
    # num_recorded - 1.
    @widgets.interact(T=(0, num_recorded, 1))
    def update(T):
        _, ax = plt.subplots(figsize=(4, 4))
        boundary_style = dict(
            fill=False, linestyle="--", linewidth=1, edgecolor="black"
        )
        ax.add_artist(plt.Circle((0, 0), 2, label="inner", **boundary_style))
        ax.add_artist(plt.Circle((0, 0), 4, **boundary_style))
        ax.set_xlim(-5, 5)
        ax.set_ylim(-5, 5)
        ax.set(xlabel="$x$", ylabel="$y$", title="$z = 0$")

        if T == 0:
            proposal_state = output['initial_chain_state']
        else:
            # `jax.tree_map` is deprecated/removed in recent JAX releases.
            proposal_state = jax.tree_util.tree_map(
                lambda x: x[T - 1], stacked
            )
        proposal_dist = output['proposal_builder'](proposal_state)
        plot_proposal_dist(ax, proposal_dist)
        ax.scatter(
            proposal_state['x'],
            proposal_state['y'],
            marker="o",
            c="white",
            edgecolors="black",
            alpha=0.75,
        )
        plt.show()
    return update

In [None]:
# Display the slider; the trailing ";" suppresses echoing the return value.
make_interactive_plot(output);

## Effect of Initialization

In this section, we compare the different initialization settings described in the optimization section of our paper.

In [None]:
task = tasks.Torus(profile_type="correlated")
num_warmup_steps = 0
num_samples = int(1e5)
seed = 0

# Settings shared by all three runs; only the initialization differs.
shared_run_kwargs = dict(
    seed=seed,
    num_samples=num_samples,
    num_warmup_steps=num_warmup_steps,
)
# z3 initialization ("Single Solution" row of the figure below).
z3_init_output = run(task, init='z3', **shared_run_kwargs)
# realpaver initialization without resampling ("Diverse Solution" row).
rp_init_output = run(task, init='realpaver', resample=False,
                     **shared_run_kwargs)
# realpaver initialization with resampling ("Re-sample" row).
rp_resample_output = run(task, init='realpaver', resample=True,
                         **shared_run_kwargs)

In [None]:
fig, axes = plt.subplots(3, 3, sharey=True)
fig.set_size_inches(6, 6)

# One row per initialization strategy: (row label, run output). Columns are
# the iterations drawn by `_plot_proposals`.
init_rows = (
    ("Single Solution \n $ y $", z3_init_output),
    ("Diverse Solution \n $ y $", rp_init_output),
    ("Re-sample \n $ y $", rp_resample_output),
)
for row_axes, (row_label, o) in zip(axes, init_rows):
    _plot_proposals(
        row_axes,
        o["initial_chain_state"],
        o["pimais_states"],
        o['proposal_builder'],
    )
    row_axes[0].set_ylabel(row_label)

# Keep the "t = ..." titles on the top row only; x labels on the bottom row.
for ax in axes[1:].flat:
    ax.set(title="")
for ax in axes[-1]:
    ax.set(xlabel="$x$")

plt.setp(axes, aspect=1.0)
# Remove tick marks from the top two rows.
for ax in axes[:2].flat:
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())
plt.subplots_adjust(left=0.2, bottom=0.15, right=0.975, wspace=0.05)
# To export the figure: fig.savefig("images/pimais_init.pdf")