In [7]:
!pip install -U seaborn

Collecting seaborn
  Downloading seaborn-0.13.0-py3-none-any.whl (294 kB)
     -------------------------------------- 294.6/294.6 kB 4.5 MB/s eta 0:00:00
Installing collected packages: seaborn
  Attempting uninstall: seaborn
    Found existing installation: seaborn 0.12.2
    Uninstalling seaborn-0.12.2:
      Successfully uninstalled seaborn-0.12.2
Successfully installed seaborn-0.13.0


In [1]:
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import os
import sys
import pickle
import hydra
import omegaconf
import csv
import torch

import pyro.distributions as dist
import pyro.contrib.gp as gp

sys.path.append("../")
from models.pyro_extensions.infer import SDVI
from models.pyro_extensions.resource_allocation import SuccessiveHalving
from models import normal_model
from run_baselines import NormalModel

torch.set_default_dtype(torch.float64)
matplotlib.style.use('default')

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path="baselines_config", config_name="config")


In [2]:
from cycler import cycler
line_cycler = (cycler(color=["#E69F00", "#56B4E9", "#009E73", "#0072B2", "#D55E00", "#CC79A7", "#F0E442"]) +
               cycler(linestyle=["-", "--", "-.", ":", "-", "--", "-."]))

In [3]:
algo2config = {
    "DCC": {"color": "#E69F00", "linestyle": "--"},
    "BBVI": {"color": "#56B4E9", "linestyle": "-."},
    "Pyro AutoGuide": {"color": "#009E73", "linestyle": ":"},
    "SDVI": {"color": "#D55E00", "linestyle": "-"},
    "Stochastic SDVI": {"color": "#0072B2", "linestyle": ":"},
    "S-SDVI": {"color": "#0072B2", "linestyle": ":"}
}

In [4]:
major = 8.0
minor = 5.0

major_tick_width = 2.0

update_rc_params = {
    'font.family': "serif", 
    'font.size': 24, 
    'legend.fontsize': 15,
    'text.usetex': True,
    'xtick.major.size': major,
    'xtick.major.width': major_tick_width,
    'xtick.minor.size': minor,
    'ytick.major.size': major,
    'ytick.major.width': major_tick_width,
    'ytick.minor.size': minor,
    'axes.linewidth': 2.0,
}

plt.rcParams.update(update_rc_params)
plt.rc("axes", prop_cycle=line_cycler)


# Figure 1

In [5]:
BRANCH1_PRIOR_MEAN = -3
BRANCH2_PRIOR_MEAN = 3
PRIOR_STD = 1
LIKELIHOOD_STD = 2

OBSERVED_DATA = 2

def marginal_likelihood(data, likelihood_std, prior_mean, prior_std):
    """Calculate the marginal likelihood of a branch. Assumes we observe only
    a single data point.

    Taken from Section 2.5 at https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf.
    """
    likelihood_var = math.pow(likelihood_std, 2)
    prior_var = math.pow(prior_std, 2)

    first_term = likelihood_std / (
        (math.sqrt(2 * math.pi) * likelihood_std)
        * math.sqrt(prior_var + likelihood_var)
    )
    second_term = math.exp(
        -(math.pow(data, 2) / (2 * likelihood_var))
        - (math.pow(prior_mean, 2) / (2 * prior_var))
    )
    third_term = math.exp(
        (
            (prior_var * math.pow(data, 2) / likelihood_var)
            + (likelihood_var * math.pow(prior_mean, 2) / prior_var)
            + 2 * data * prior_mean
        )
        / (2 * (prior_var + likelihood_var))
    )
    return first_term * second_term * third_term


def posterior_params(data, likelihood_std, prior_mean, prior_std):
    """Calculate the posterior mean and standard deviation of a branch. Assumes we
    observe only a single data point."""
    prior_precision = 1 / math.pow(prior_std, 2)
    likelihood_precision = 1 / math.pow(likelihood_std, 2)
    post_mean = (prior_precision * prior_mean + likelihood_precision * data) / (
        prior_precision + likelihood_precision
    )
    post_std = 1 / (prior_precision + likelihood_precision)
    return post_mean, post_std

class ToyModel:
    observed_data = torch.tensor(OBSERVED_DATA)

    branch1_prior_mean = BRANCH1_PRIOR_MEAN
    branch2_prior_mean = BRANCH2_PRIOR_MEAN
    prior_std = PRIOR_STD
    likelihood_std = LIKELIHOOD_STD

    def __init__(self, cut_point=0.0):
        self.branch1_post_mean, self.branch1_post_std = posterior_params(
            self.observed_data,
            self.likelihood_std,
            self.branch1_prior_mean,
            self.prior_std,
        )
        self.branch1_Z = marginal_likelihood(
            self.observed_data,
            self.likelihood_std,
            self.branch1_prior_mean,
            self.prior_std,
        )
        self.branch2_post_mean, self.branch2_post_std = posterior_params(
            self.observed_data,
            self.likelihood_std,
            self.branch2_prior_mean,
            self.prior_std,
        )
        self.branch2_Z = marginal_likelihood(
            self.observed_data,
            self.likelihood_std,
            self.branch2_prior_mean,
            self.prior_std,
        )

        self.cut_point = cut_point

        z0_prior = dist.Normal(0, 1)
        self.branch1_prior = z0_prior.cdf(torch.tensor(cut_point))
        self.marginal_likelihood = self.branch1_prior * self.branch1_Z + (1 - self.branch1_prior) * self.branch2_Z

        self.branch1_post_prob = (self.branch1_prior * self.branch1_Z) / self.marginal_likelihood

    def __call__(self):
        z0 = pyro.sample("z0", dist.Normal(0, 1))
        if z0 < self.cut_point:
            z1 = pyro.sample("z1", dist.Normal(self.branch1_prior_mean, self.prior_std))
        else:
            z1 = pyro.sample("z2", dist.Normal(self.branch2_prior_mean, self.prior_std))

        x = pyro.sample(
            "x", dist.Normal(z1, self.likelihood_std), obs=self.observed_data
        )
        return z0.item(), z1, x

In [6]:
def merge_weights(file_list):
    weight_keys = ["br2_weights_pyro", "br2_weights_bbvi", "br2_weights_sdvi"]
    weights_dict = {
        "br2_weights_pyro": [],
        "br2_weights_bbvi": [],
        "br2_weights_sdvi": [],
    }

    for filename in file_list:
        with open(filename, "rb") as f:
            weights = pickle.load(f)
        for k in weight_keys:
            weights_dict[k].append(weights[k])
    
    for k in weight_keys:
        weights_dict[k] = torch.cat(weights_dict[k], 0)
    
    return weights_dict

In [7]:
weights_file_list = [
    # TODO: Fill out with list of files which are output of ../scripts/make_motivating_example_plot.py.
]

weights_dict = merge_weights(weights_file_list)

br2_weights_pyro = weights_dict["br2_weights_pyro"]
br2_weights_bbvi = weights_dict["br2_weights_bbvi"]
br2_weights_sdvi = weights_dict["br2_weights_sdvi"]

RuntimeError: torch.cat(): expected a non-empty list of Tensors

In [None]:
algo2config = {
    "BBVI": {"color": "#56B4E9", "linestyle": "-."},
    "Pyro": {"color": "#009E73", "linestyle": ":"},
    "SDVI": {"color": "#D55E00", "linestyle": "-"},
}

algo2config = {
    "BBVI": {"color": "#56B4E9"},
    "Pyro": {"color": "#009E73"},
    "SDVI": {"color": "#D55E00"},
}

font = {'size' : 16}

matplotlib.rc('font', **font)

update_rc_params = {
    'font.family': "serif", 
    'text.usetex': True,
    'font.size': 20, 
    'legend.fontsize': 15,
}
plt.rcParams.update(update_rc_params)

fig, ax = plt.subplots()
#  Pyro weights
weight_mean = br2_weights_pyro.mean(dim=0).numpy()
weight_std = br2_weights_pyro.std(dim=0).numpy()
ax.plot(weight_mean, lw=4, label="Pyro AutoGuide", **algo2config["Pyro"])
ax.fill_between(
    torch.arange(weight_mean.shape[0]),
    weight_mean - weight_std,
    weight_mean + weight_std,
    alpha=0.5,
    color=algo2config["Pyro"]["color"],
)

#  BBVI weights
weight_mean = br2_weights_bbvi.mean(dim=0).numpy()
weight_std = br2_weights_bbvi.std(dim=0).numpy()
ax.plot(weight_mean, lw=4, label="BBVI", **algo2config["BBVI"])
ax.fill_between(
    torch.arange(weight_mean.shape[0]),
    weight_mean - weight_std,
    weight_mean + weight_std,
    alpha=0.5,
    color=algo2config["BBVI"]["color"],
)

#  SDVI weights
weight_mean = br2_weights_sdvi[:,::8].mean(dim=0).numpy()
weight_std = br2_weights_sdvi[:,::8].std(dim=0).numpy()
ax.plot(torch.arange(weight_mean.shape[0]) * 16, weight_mean, lw=4, label="SDVI", **algo2config["SDVI"])
ax.fill_between(
    torch.arange(weight_mean.shape[0]) * 16,
    weight_mean - weight_std,
    weight_mean + weight_std,
    alpha=0.5,
    color=algo2config["SDVI"]["color"],
)

model = ToyModel(cut_point=0.0)
branch2_prob = (1 - model.branch1_post_prob)
ax.axhline(branch2_prob, ls="--", color="black", lw=4, label="Ground Truth")

ax.set_xlabel("Number of Iterations")
ax.set_ylabel(r"Probability of Branch $x \geq 0$")
ax.set_yticks([0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
# ax.set_xlim(0, 1000)
x0, y0, width, height = 0.3, 0.3, 0.4, 0.1
ax.legend(loc="center", bbox_to_anchor=(x0, y0, width, height), ncol=2)
# ax.legend()
fig.tight_layout()
fig.savefig(
    os.path.join("figures", "motivating_example_weights_sdvi.pdf")
)

# Model selection

## Weight Error

In [8]:
model_selection_gtws, _ = NormalModel().calculate_ground_truth_weights()

In [9]:
def baseline_extract_errors(sweep_dir, ground_truth_weights, method_name, fname="estimated_weights.csv"):
    slp_identifiers = list(ground_truth_weights.keys())
    ground_truth_array = np.stack(
        [np.array(ground_truth_weights[a]) for a in slp_identifiers], axis=0
    )

    experiment_dirs = [f.path for f in os.scandir(sweep_dir) if f.is_dir()]
    dfs = []
    for experiment_dir in experiment_dirs:
        run_id = os.path.basename(experiment_dir)
        if run_id == ".submitit":
            continue
        
        data = np.genfromtxt(os.path.join(experiment_dir, fname), delimiter=",", names=True)
        # Make sure columns are ordered as in ground_truth_array
        data = data[slp_identifiers]
        # For each iteration calcluate error
        num_iterations = data.shape[0]
        errors = np.zeros(num_iterations)
        for ix in range(num_iterations):
            row_vals = np.zeros(len(slp_identifiers))
            for j, id in enumerate(slp_identifiers):
                row_vals[j] = data[id][ix]
            errors[ix] = np.linalg.norm(ground_truth_array - row_vals) ** 2

        # Create pandas dataframe with iteration,error as columns and append to dfs
        errors_pd = pd.DataFrame(data={"weight_error": errors, "iteration": np.arange(num_iterations)})
        errors_pd["run_id"] = int(run_id)
        dfs.append(errors_pd)
    
    # Concat all dataframes to each other
    metrics = pd.concat(dfs)

    # Group-by iteration id and calculate mean and standard deviation so that we get a table "iteration,mean,std"
    grouped_by_iteration = metrics[["iteration", "weight_error"]].groupby("iteration")
    losses_stats = grouped_by_iteration.mean()
    losses_stats.rename(columns={"weight_error": "mean"}, inplace=True)
    losses_stats["std"] = grouped_by_iteration.std()["weight_error"]
    losses_stats["method_name"] = method_name
    return losses_stats


In [10]:
def sdvi_extract_errors(sweep_dir, ground_truth_weights, method_name, fname="exclusive_kl_results.csv"):
    slp_identifiers = list(ground_truth_weights.keys())
    ground_truth_array = np.stack(
        [np.array(ground_truth_weights[a]) for a in slp_identifiers], axis=0
    )

    # Loop through the different run idxs
    experiment_dirs = [f.path for f in os.scandir(sweep_dir) if f.is_dir()]
    dfs = []
    for experiment_dir in experiment_dirs:
        run_id = os.path.basename(experiment_dir)
        if run_id == ".submitit":
            continue
        
        # Load metrics csv into pandas dataframe
        metrics = pd.read_csv(os.path.join(experiment_dir, fname))

        num_iterations = len(metrics.index)
        errors = np.zeros(num_iterations)
        for ix in range(num_iterations):
            row_vals = np.zeros(len(slp_identifiers))
            for j, id in enumerate(slp_identifiers):
                # bt = int(id) * "0" + "1"
                bt = f"u,x_{int(id)}"
                row_vals[j] = metrics[f"weight_{bt}"][ix]
            errors[ix] = np.linalg.norm(ground_truth_array - row_vals) ** 2

        # Create pandas dataframe with iteration,error as columns and append to dfs
        errors_pd = pd.DataFrame(data={"weight_error": errors, "iteration": np.arange(num_iterations)})
        errors_pd["run_id"] = int(run_id)
        dfs.append(errors_pd)

    # Concat all dataframes to each other
    metrics = pd.concat(dfs)
    metrics.describe()

    # Group-by iteration id and calculate mean and standard deviation so that we get a table "iteration,mean,std"
    grouped_by_iteration = metrics[["iteration", "weight_error"]].groupby("iteration")
    losses_stats = grouped_by_iteration.mean()
    losses_stats.rename(columns={"weight_error": "mean"}, inplace=True)
    losses_stats["std"] = grouped_by_iteration.std()["weight_error"]
    losses_stats["method_name"] = method_name
    return losses_stats

In [11]:
sdvi_dirs = [
    ("SDVI", "TODO: Fill out path to results dir."),
]

In [12]:
baselines = [
    ("DCC", "TODO: Fill out path to results dir."),
    ("Pyro AutoGuide", "TODO: Fill out path to results dir."),
    ("BBVI", "TODO: Fill out path to results dir."),
]

method2errors = {
    n: baseline_extract_errors(sweep_dir, model_selection_gtws, n)
    for n, sweep_dir in baselines
}
for n, sweep_dir in sdvi_dirs:
    method2errors[n] = sdvi_extract_errors(sweep_dir, model_selection_gtws, n)


FileNotFoundError: [WinError 3] The system cannot find the path specified: 'TODO: Fill out path to results dir.'

In [None]:
# fig, ax = plt.subplots(figsize=(8,4))
fig, ax = plt.subplots()
for name, errors in method2errors.items():
    ixs = errors.index
    if name == "DCC":
        ixs = ixs * 1000 + 10000
        start_ix = 0
    elif name == "SDVI":
        ixs = ixs * 500 + 2000
        start_ix = 0
    elif name == "Pyro AutoGuide":
        ixs = ixs * 1000
        start_ix = 0
    elif name == "BBVI":
        ixs = ixs * 1000
        start_ix = 0
    ax.plot(
        ixs[start_ix:], 
        errors["mean"][start_ix:], 
        label=name, 
        alpha=1.0, 
        lw=4, 
        **algo2config[name]
    )
    ax.fill_between(
        ixs[start_ix:], 
        errors["mean"][start_ix:]-errors["std"][start_ix:], 
        errors["mean"][start_ix:]+errors["std"][start_ix:], 
        alpha=0.3,
        **algo2config[name]
    )

ax.legend()
ax.set_xlabel("Computational Cost")
ax.set_ylabel("Squared Error")
# ax.set_ylim((0.0, 0.05))
# ax.set_xlim((10^4, 10^5))
ax.set_yscale("log")
ax.set_xscale("log")
ax.set_xlim((1000, 100000))
# ax.set_ylim(ymin=1e-5)
ax.grid(True)
fig.tight_layout()
# fig.savefig("figures/model_selection_slp_weight_error_without_dcc_different_aspect_ratio.pdf")
fig.savefig("figures/model_selection_slp_weight_error.pdf")

## ELBO with Marginal Likelihood

In [13]:
model = normal_model.NormalModel()
sdvi = SDVI(model, 0.1, "MeanFieldNormal", utility_class=SuccessiveHalving(10))
sdvi.find_slps(100)
ground_truth_weights, global_marginal_likelihood = model.calculate_ground_truth_weights(sdvi)

In [14]:
def baseline_extract_elbos(sweep_dir, method_name, fname="elbos.csv"):
    experiment_dirs = [f.path for f in os.scandir(sweep_dir) if f.is_dir()]
    dfs = []
    for experiment_dir in experiment_dirs:
        run_id = os.path.basename(experiment_dir)
        if run_id == ".submitit":
            continue
        
        try:
            data = np.genfromtxt(os.path.join(experiment_dir, fname), delimiter=",", names=True)
        except OSError:
            # File not found
            continue

        data = data["elbos"]
        # For each iteration calcluate error
        num_iterations = data.shape[0]

        # Create pandas dataframe with iteration,error as columns and append to dfs
        errors_pd = pd.DataFrame(data={"elbos": data, "iteration": np.arange(num_iterations)})
        errors_pd["run_id"] = int(run_id)
        dfs.append(errors_pd)
    
    # Concat all dataframes to each other
    metrics = pd.concat(dfs)

    # Group-by iteration id and calculate mean and standard deviation so that we get a table "iteration,mean,std"
    grouped_by_iteration = metrics[["iteration", "elbos"]].groupby("iteration")
    losses_stats = grouped_by_iteration.mean()
    losses_stats.rename(columns={"elbos": "mean"}, inplace=True)
    losses_stats["std"] = grouped_by_iteration.std()["elbos"]
    losses_stats["method_name"] = method_name
    return losses_stats

In [15]:
def sdvi_extract_elbos(sweep_dir, method_name, fname="exclusive_kl_results.csv"):
    # Loop through the different run idxs
    experiment_dirs = [f.path for f in os.scandir(sweep_dir) if f.is_dir()]
    dfs = []
    for experiment_dir in experiment_dirs:
        run_id = os.path.basename(experiment_dir)
        if run_id == ".submitit":
            continue

        # Load metrics csv into pandas dataframe
        metrics = pd.read_csv(os.path.join(experiment_dir, fname))

        num_iterations = len(metrics.index)

        # Create pandas dataframe with iteration,error as columns and append to dfs
        errors_pd = pd.DataFrame(data={"elbos": metrics["global_elbos"], "iteration": np.arange(num_iterations)})
        errors_pd["run_id"] = int(run_id)
        dfs.append(errors_pd)

    # Concat all dataframes to each other
    metrics = pd.concat(dfs)
    metrics.describe()

    # Group-by iteration id and calculate mean and standard deviation so that we get a table "iteration,mean,std"
    grouped_by_iteration = metrics[["iteration", "elbos"]].groupby("iteration")
    losses_stats = grouped_by_iteration.mean()
    losses_stats.rename(columns={"elbos": "mean"}, inplace=True)
    losses_stats["std"] = grouped_by_iteration.std()["elbos"]
    losses_stats["method_name"] = method_name
    return losses_stats

In [16]:
method2elbos = {
    name: baseline_extract_elbos(d, name) 
    for name, d in baselines 
    if name in ["Pyro AutoGuide", "BBVI"]
}
for n, sweep_dir in sdvi_dirs:
    method2elbos[n] = sdvi_extract_elbos(sweep_dir, n)


FileNotFoundError: [WinError 3] The system cannot find the path specified: 'TODO: Fill out path to results dir.'

In [None]:
line_width = 4
fig, ax = plt.subplots(figsize=(8,4))
for name, errors in method2elbos.items():
    ixs = errors.index
    if name == "SDVI":
        ixs = ixs * 500 + 2000
    elif name == "Pyro AutoGuide":
        ixs = ixs * 1000
    elif name == "BBVI":
        ixs = ixs * 1000
    ax.plot(ixs, errors["mean"], alpha=1.0, lw=line_width, **algo2config[name])
    ax.fill_between(ixs, errors["mean"]-errors["std"], errors["mean"]+errors["std"], alpha=0.3, **algo2config[name])

ax.axhline(
    torch.log(global_marginal_likelihood),
    linestyle="--",
    color="black",
    lw=line_width,
    label=r"$\log Z$"
)
ax.set_xlabel("Computational Cost")
ax.set_ylabel("ELBO")
# ax.set_ylim((0.0, 0.05))
# ax.set_yscale("log")
ax.set_xscale("log")
ax.set_xlim((1000, 100000))

ax.grid(True)
ax.legend(loc="lower right")
fig.tight_layout()
fig.savefig("figures/model_selection_elbos.pdf")

# GP Kernel Learning

## GP Posterior Predictive Plot

In [17]:
DATA_FILE = "../data/airline/airline.csv"

In [21]:
def get_id_with_median_lppd(sweep_dir, key="lppds", fname="exclusive_kl_results.csv"):
    # Loop through the different run idxs
    experiment_dirs = [f.path for f in os.scandir(sweep_dir) if f.is_dir()]
    print(experiment_dirs)
    dfs = []
    for experiment_dir in experiment_dirs:
        run_id = os.path.basename(experiment_dir)
        if run_id == ".submitit":
            continue
        # Load metrics csv into pandas dataframe
        metrics = pd.read_csv(os.path.join(experiment_dir, fname))

        # Create pandas dataframe with iteration,error as columns and append to dfs
        errors_pd = pd.DataFrame(data={key: metrics[key].iloc[-1], "iteration": [0]})
        errors_pd["run_id"] = int(run_id)
        dfs.append(errors_pd)

    # Concat all dataframes to each other
    metrics = pd.concat(dfs)
    print(metrics)
    median_iteration_id = metrics.loc[metrics["lppds"] == metrics["lppds"].median()]["run_id"][0]
    return median_iteration_id

In [28]:
# sweep_dir = gp_sdvi_result_dirs[1][1]
sweep_dir = "../experiments/gp_grammar_sdvi/17-25-14"
median_lppd_id = get_id_with_median_lppd(sweep_dir)
# median_lppd_id = 1
with open(os.path.join(sweep_dir, str(median_lppd_id), "sdvi.pickle"), "rb") as f:
    sdvi = pickle.load(f).to("cuda")

['../experiments/gp_grammar_sdvi/17-25-14\\0']
       lppds  iteration  run_id
0 -31.334535          0       0


In [29]:
def load_data(data_path):
    data = torch.tensor(np.loadtxt(data_path, delimiter=","))
    xs = data[:, 0]
    ys = data[:, 1]
    xs -= xs.min()
    xs /= xs.max()
    ys -= ys.mean()
    ys *= 4 / (ys.max() - ys.min())

    # Keep 10 % of data for validation.
    val_ix = round(xs.size(0) * 0.9)
    xs, xs_val = xs[:val_ix], xs[val_ix:]
    ys, ys_val = ys[:val_ix], ys[val_ix:]

    return xs, ys, xs_val, ys_val

def extract_posterior_kernels(posterior_samples):
    post_kernels = [trace.nodes["_RETURN"]["value"] for trace in posterior_samples]
    for ix in range(len(post_kernels)):
        for name, s in posterior_samples[ix].iter_stochastic_nodes():
            if name in ["std", "y"] or "kernel_type" in name:
                continue

            if isinstance(post_kernels[ix], gp.kernels.Sum) or isinstance(
                post_kernels[ix], gp.kernels.Product
            ):
                names = name.split(".")
                kern_mod = post_kernels[ix]._modules[names[0]]
                for jx in range(len(names) - 2):
                    kern_mod = kern_mod._modules[names[jx + 1]]
                setattr(kern_mod, names[-1], s["value"])
            else:
                setattr(post_kernels[ix], name, s["value"])
    return post_kernels

def gp_analytic_posterior(
    kernel_fn: gp.kernels.Kernel,
    X: torch.tensor,
    new_xs: torch.tensor,
    y: torch.tensor,
    noise: torch.tensor,
    jitter: float,
    full_cov: bool = False,
):
    N = X.size(0).to("cuda")
    Kff = kernel_fn(X).contiguous().to("cuda")
    Kff = Kff.type(X.dtype).clone().to("cuda")
    Kff.view(-1)[:: N + 1] += jitter + torch.pow(noise, 2)
    Lff = torch.linalg.cholesky(Kff).to("cuda")

    gp_post_mean, gp_post_cov = gp.util.conditional(
        new_xs, X, kernel_fn, y, Lff=Lff, jitter=jitter, full_cov=full_cov
    )
    if full_cov:
        M = new_xs.size(0).to("cuda")
        gp_post_cov = gp_post_cov.contiguous().to("cuda")
        gp_post_cov.view(-1, M * M)[:, :: M + 1] += torch.pow(noise, 2)
    else:
        gp_post_cov = gp_post_cov + torch.pow(noise, 2).to("cuda")
    return gp_post_mean, gp_post_cov

In [30]:
def plot_posterior_samples(
    posterior_samples, 
    X, 
    y, 
    X_val, 
    y_val, 
    jitter=1e-6, 
    with_noise=True, 
    num_eval_points=500,
    start_ix_data=0,
    figsize=(15, 10)
):
    post_kernels = extract_posterior_kernels(posterior_samples)
    if with_noise:
        noises = [trace.nodes["std"]["value"] for trace in posterior_samples]
    else:
        noises = [torch.tensor(0.0) for _ in range(len(posterior_samples))]

    new_xs = torch.linspace(0, 1, num_eval_points)
    posterior_fs = torch.zeros((len(post_kernels), new_xs.size(0)))
    for ix in range(len(post_kernels)):
        with torch.no_grad():
            gp_post_mean, gp_post_cov = gp_analytic_posterior(
                post_kernels[ix],
                X,
                new_xs,
                y,
                noises[ix],
                jitter,
                full_cov=True,
            )
        posterior_fs[ix, :] = (
            dist.MultivariateNormal(gp_post_mean, gp_post_cov).sample().detach()
        )

    f_post_mean = posterior_fs.mean(dim=0)
    f_post_std = posterior_fs.std(dim=0)

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(new_xs, f_post_mean, color="#0072B2", lw=2)
    ax.fill_between(
        new_xs,
        f_post_mean - 2 * f_post_std,
        f_post_mean + 2 * f_post_std,
        color="#0072B2",
        alpha=0.2,
    )
    num_samples_to_plot = min(0, len(post_kernels))
    for ix in range(num_samples_to_plot):
        ax.plot(new_xs, posterior_fs[ix, :], color="#009E73", alpha=0.3, linestyle="-")

    ax.scatter(X, y, label="Observed Data", color="black")
    ax.scatter(X_val, y_val, label="Held-Out Data", marker="x")
    ax.set_xlim((X[start_ix_data] - 0.01, 1.01))
    ax.set_ylim((y[start_ix_data:].min() - 0.1, ax.get_ylim()[1]))
    ax.legend(loc="upper left")
    
    ax.set_xlabel("Month")
    ax.set_ylabel("Number of Passengers")
    ax.set_xticks(())
    ax.set_yticks(())
    return fig, ax

In [31]:
X, y, X_val, y_val = load_data(DATA_FILE)

In [35]:
!conda activate sdvi

In [None]:
posterior_samples = sdvi.sample_posterior_predictive(1)

In [None]:
fig, ax = plot_posterior_samples(
    posterior_samples, 
    X, 
    y, 
    X_val, 
    y_val, 
    num_eval_points=1000, 
    start_ix_data=70, 
    figsize=(8, 5)
)
fig.savefig("figures/gp_posterior_predictive_median_lppd.png")