# Partially Observed Clustering Plots

This notebook makes plots similar to the ones from the partially observed clustering experiments in the paper. They show clustering accuracy as a function of the fraction of input pixels that are observed.

In [None]:
import os

# We want to be in the project's root directory, not the "notebooks" directory.
# Only move up when we are actually inside "notebooks": an unconditional
# os.chdir("..") would climb one more level every time this cell is re-run.
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")

In [None]:
import pickle
import json

import jax
import jax.numpy as jnp
import haiku as hk
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import tensorflow_datasets as tfds

from posterior_matching.models.vade import PosteriorMatchingVADE
from posterior_matching.clustering import clustering_accuracy

In the cell below, set `RUN_DIR` to a run directory that was created by the `train_pm_vade.py` script.

In [None]:
RUN_DIR = "runs/pm-vade-mnist-20220305-170841"

# Files written by train_pm_vade.py into the run directory.
train_state_path = os.path.join(RUN_DIR, "train_state.pkl")
model_config_path = os.path.join(RUN_DIR, "model_config.json")

# NOTE: unpickling can execute arbitrary code -- only point RUN_DIR at runs you trust.
with open(train_state_path, "rb") as state_file:
    train_state = pickle.load(state_file)

with open(model_config_path, "r") as config_file:
    model_config = json.load(config_file)

Here, we load the test split of the data and rescale pixel values to [0, 1]. If using a dataset other than MNIST, change the dataset name passed to `tfds.load` below.

In [None]:
def rescale(x):
    """Convert the batch's uint8 pixel values to float32 in [0, 1], in place on the dict."""
    x["image"] = tf.cast(x["image"], tf.float32) / 255.0
    return x


# Test split, in fixed-size batches of 32. drop_remainder=True discards the final
# partial batch, so a few test examples may be left out of the evaluation.
ds = (
    tfds.load("mnist", split="test")
    .batch(32, drop_remainder=True)
    .map(rescale)
)

## Evaluation

You can change `NUM_SAMPLES` to determine how many samples are used when estimating the cluster probabilities.

In [None]:
NUM_SAMPLES = 50


def _predict_clusters(batch):
    """Return the argmax cluster index for each example given a partially observed batch.

    `batch` must provide "image" and "mask" entries; the model estimates cluster
    probabilities from NUM_SAMPLES samples.
    """
    # Haiku modules must be constructed inside the transformed function.
    vade = PosteriorMatchingVADE.from_config(model_config)
    cluster_probs = vade.partial_predict_cluster(
        batch["image"], batch["mask"], num_samples=NUM_SAMPLES
    )
    return jnp.argmax(cluster_probs, axis=-1)


# Jitted apply with signature (params, state, rng, batch) -> (predictions, new_state).
predict_fn = jax.jit(hk.transform_with_state(_predict_clusters).apply)

In [None]:
# Probabilities of observing each pixel: 0.0, 0.025, ..., 1.0.
observed_probs = np.linspace(0.0, 1.0, 41)

SEED = 91
# JAX keys for the model's stochastic sampling.
prng = hk.PRNGSequence(SEED)
# Seeded NumPy generator for the observation masks, so the masks (and therefore
# the accuracies) are reproducible across runs.
mask_rng = np.random.default_rng(SEED)

y_true = []
y_pred = {p: [] for p in observed_probs}

for batch in ds.as_numpy_iterator():
    # Labels do not depend on the mask, so record them once per batch.
    y_true.append(batch["label"])

    for p in observed_probs:
        # Each pixel is observed independently with probability p (1 = observed).
        batch["mask"] = mask_rng.binomial(1, p, size=batch["image"].shape)
        preds, _ = predict_fn(train_state.params, train_state.state, prng.next(), batch)
        y_pred[p].append(preds)

y_true = np.hstack(y_true)
y_pred = {k: np.hstack(v) for k, v in y_pred.items()}

## Plot Clustering Accuracy

In [None]:
# One accuracy per observation probability, in the same order as observed_probs.
accs = [clustering_accuracy(y_true, y_pred[p]) for p in observed_probs]

In [None]:
# observed_probs are fractions in [0, 1]; scale to percent so the axis matches its label.
sns.lineplot(x=observed_probs * 100, y=accs, linewidth=3, color="#8da0cb")
sns.despine()

plt.xlabel("Percent Observed")
plt.ylabel("Clustering Accuracy")
plt.grid(visible=True, axis="y")

plt.show()