# Imputation experiments

In [None]:
import glob
import numpy as np
import os
import torch
from matplotlib import pyplot
from treecat_exp.util import TEST, load_object
%matplotlib inline
%config InlineBackend.rc = {'figure.facecolor': (1, 1, 1, 1)}
# %config InlineBackend.figure_format = 'svg'

The following results were generated by `cleanup.py`; to reproduce them, run:
```sh
python main.py
```

In [None]:
# Collect every pickled metrics object matching "cleanup.*.pkl" under TEST,
# visiting the files in sorted path order, and echo the configuration
# (delete_percent, dataset, model) stored on each one's "args" entry.
paths = sorted(glob.glob(os.path.join(TEST, "cleanup.*.pkl")))
print("Loading {} experimental results:".format(len(paths)))
results = []
for path in paths:
    metrics = load_object(path)
    results.append(metrics)
    args = metrics["args"]
    summary = (args.delete_percent, args.dataset, args.model)
    print("  {} {} {}".format(*summary))

In [None]:
# One figure per dataset. Within a figure each model contributes a labelled
# curve of the mean of "losses" against the deletion fraction, plus one faint
# curve per index of "losses" drawn in the same color.
for dataset in sorted({r["args"].dataset for r in results}):
    dataset_results = [r for r in results if r["args"].dataset == dataset]
    pyplot.figure(figsize=(9, 6))
    for model in sorted({r["args"].model for r in dataset_results}):
        # Runs of this model on this dataset, ordered by deletion percentage.
        ms = sorted((r for r in dataset_results if r["args"].model == model),
                    key=lambda r: r["args"].delete_percent)
        X = [r["args"].delete_percent / 100 for r in ms]
        Y = [np.mean(r["losses"]) for r in ms]
        # Keep the line handle so the per-index curves can reuse its color.
        p, = pyplot.plot(X, Y, label=model)
        # The number of per-index curves is taken from the first run.
        for f in range(len(ms[0]["losses"])):
            Y = [r["losses"][f] for r in ms]
            pyplot.plot(X, Y, color=p.get_color(), lw=1, alpha=0.3)
    # Place a tick at every deletion fraction seen for this dataset,
    # regardless of which model it was run with.
    X = sorted({r["args"].delete_percent / 100 for r in dataset_results})
    tick_labels = ["{:0.2g}".format(x) for x in X]
    pyplot.yscale("log")
    pyplot.xticks(X, labels=tick_labels)
    pyplot.title("Accuracy of imputing {} data (lower is better)".format(dataset))
    pyplot.xlabel("Missing Probability")
    pyplot.ylabel("Error (loss / cell)")
    pyplot.legend(loc="best")
    pyplot.tight_layout()

In [None]:
# One figure per (dataset, model) pair for which at least one result carries
# a "posterior_predictive" entry: the mean of that entry is drawn as a dashed
# line against the deletion fraction, with a shaded band of +/- one standard
# deviation around it.
for dataset in sorted(set(m["args"].dataset for m in results)):
    for model in sorted(set(m["args"].model for m in results
                            if m["args"].dataset == dataset
                            if "posterior_predictive" in m)):
        # Apply the same "posterior_predictive" filter used to pick the model:
        # a model may have some runs without that entry, and indexing those
        # below would raise KeyError.
        ms = [m for m in results
              if m["args"].dataset == dataset
              if m["args"].model == model
              if "posterior_predictive" in m]
        ms.sort(key=lambda m: m["args"].delete_percent)
        X = [m["args"].delete_percent / 100 for m in ms]
        # Mean and spread are both rescaled by num_rows / sum(num_cleaned);
        # presumably this converts to a per-imputed-cell quantity, matching
        # the y-axis label -- TODO confirm against the code that writes these.
        Y = np.array([m["posterior_predictive"].mean().item()
                      / sum(m["num_cleaned"]) * m["num_rows"]
                      for m in ms])
        dY = np.array([m["posterior_predictive"].std().item()
                       / sum(m["num_cleaned"]) * m["num_rows"]
                       for m in ms])

        pyplot.figure(figsize=(9, 6))
        pyplot.fill_between(X, Y - dY, Y + dY, alpha=0.3)
        pyplot.plot(X, Y, 'k--')
        pyplot.xticks(X, labels=["{:0.2g}".format(x) for x in X])
        pyplot.title("Accuracy of {} imputing {} data (higher is better)"
                     .format(model, dataset))
        pyplot.xlabel("Missing Probability")
        pyplot.ylabel("Posterior Predictive  (nats / imputed cell)")
        pyplot.tight_layout()