# Debugging Feature Initialization

In [None]:
import numpy as np
import torch

import pyro
import pyro.poutine as poutine
from pyro.contrib.tabular import TreeCat
from pyro.optim import Adam
from treecat_exp.preprocess import load_data, partition_data
from treecat_exp.util import TRAIN

# Notebook-wide setup. These calls are independent of one another:
# turn on Pyro's validation checks, empty the global parameter store,
# and seed the random number generators so runs are repeatable.
pyro.enable_validation(True)
pyro.clear_param_store()
pyro.set_rng_seed(1)
# Show numpy arrays with a precision of 4 when printing parameters.
np.set_printoptions(precision=4)

from matplotlib import pyplot
%matplotlib inline
%config InlineBackend.rc = {'figure.facecolor': (1, 1, 1, 1)}
# %config InlineBackend.figure_format = 'svg'

In [None]:
# Lightweight attribute container (a dynamically created class named "Args")
# holding the settings read by the cells below: `load_data(args)`,
# `args.capacity` and `args.init_size`.
args = type(
    "Args",
    (),
    {
        "dataset": "molecules",
        "max_num_rows": 9999999999,
        "capacity": 8,
        "init_size": 100000,
    },
)

Load data.

In [None]:
# Load the dataset described by `args`. The loader returns the feature
# objects plus the data and mask sequences (presumably one entry per
# feature column -- confirm against treecat_exp.preprocess.load_data).
features, data, mask = load_data(args)

# Report the size of the loaded table.
num_features = len(features)
num_rows = len(data[0])
num_cells = num_rows * num_features
summary = "loaded {} rows x {} features = {} cells".format(
    num_rows, num_features, num_cells)
print(summary)

# List every feature on its own line beneath a "Features:" header.
feature_lines = ["Features:"]
feature_lines.extend(str(feature) for feature in features)
print("\n".join(feature_lines))

Initialize the model.

In [None]:
# Empty the global parameter store before building the model, so the listing
# printed at the end of this cell starts from a clean store.
pyro.clear_param_store()
model = TreeCat(features, args.capacity)
trainer = model.trainer("map", optim=Adam({}))

# Only the first partition produced by `partition_data` is used for
# initialization; the loop exits right after binding it.
for batch_data, batch_mask in partition_data(data, mask, args.init_size):
    break

# Order matters here: initialize through the trainer first, then run the
# guide once on the same batch.
trainer.init(batch_data, batch_mask)
model.guide(batch_data, batch_mask)  # initializes groups

# Dump every parameter in the store, sorted by name, one "name = value" per line.
param_lines = [
    "{} = {}".format(key, value.data.cpu().numpy())
    for key, value in sorted(pyro.get_param_store().items())
]
print("\n".join(param_lines))

In [None]:
@torch.no_grad()
def plot_feature(name):
    """Plot one feature's data histogram against its per-component densities.

    Looks up the feature called ``name`` in the notebook-level ``features``
    list together with its column in ``batch_data``, prints the column's
    mean and standard deviation and every entry of the Pyro param store
    whose name starts with ``auto_<name>_``, then draws a density-normalized
    histogram of the column overlaid with ``exp(log_prob(X))`` of the value
    distribution of each of the ``args.capacity`` components.

    Relies on the notebook globals ``features``, ``batch_data``,
    ``batch_mask``, ``model`` and ``args``.

    Args:
        name: Name of the feature to plot (compared against ``f.name``).

    Raises:
        ValueError: If there is not exactly one feature called ``name``, or
            if the feature's class is not ``Real``, ``Boolean`` or
            ``Discrete``.
    """
    matches = [(f, col) for f, col in zip(features, batch_data) if f.name == name]
    if len(matches) != 1:
        # Same exception type as the bare tuple-unpacking failure, but with a
        # message that says what went wrong.
        raise ValueError("expected exactly one feature named {!r}, found {}"
                         .format(name, len(matches)))
    (f, col), = matches

    # Record one run of the guide and replay it, presumably so that the
    # feature's shared/group sample sites take the guide's values.
    guide_trace = poutine.trace(model.guide).get_trace(batch_data, batch_mask)
    with poutine.replay(trace=guide_trace):
        shared = f.sample_shared()
        with pyro.plate("components", args.capacity):
            group = f.sample_group(shared)

    print("data mean = {:0.3g}, std = {:0.3g}".format(col.float().mean(),
                                                      col.float().std()))
    print("\n".join("{} = {}".format(key, value.data.cpu().numpy())
                    for key, value in sorted(pyro.get_param_store().items())
                    if key.startswith("auto_{}_".format(name))))

    # Build the evaluation grid before opening a figure, so an unsupported
    # feature type fails without leaving a half-drawn plot behind.
    datatype = type(f).__name__
    if datatype == "Real":  # exact match; `in "Real"` was a substring test
        x0 = col.min().item()
        x1 = col.max().item()
        X = torch.linspace(x0, x1, 100)
    elif datatype == "Boolean":
        X = torch.arange(2.)
    elif datatype == "Discrete":
        X = torch.arange(f.cardinality)
    else:
        raise ValueError(type(f))
    x_values = X.numpy()

    pyplot.figure(figsize=(9, 8), dpi=300)
    pyplot.hist(col.numpy(), alpha=0.3, label='data', bins=20, density=True)
    for i in range(args.capacity):
        d = f.value_dist(group, i)
        Y = d.log_prob(X).exp().numpy()
        if datatype == "Real":
            # Only real-valued components get a legend entry (loc and scale).
            pyplot.plot(x_values, Y,
                        label='loc={:0.2g}, scale={:0.2g}'.format(d.loc, d.scale))
        else:
            pyplot.plot(x_values, Y)
    pyplot.title(name)
    pyplot.legend(loc='best')
    pyplot.tight_layout()

In [None]:
# Feature to inspect; the later plotting cell reuses this name.
FEATURE = "b1"
plot_feature(FEATURE)
# Put the y axis of the current axes on a log scale.
pyplot.gca().set_yscale('log')

Now load the trained model.

In [None]:
# Load the saved parameters for this dataset/capacity into the global param
# store, mapping tensors onto the CPU.
trained_model_path = "results/train/{}.treecatnuts.{}.model.pyro".format(
    args.dataset, args.capacity)
pyro.get_param_store().load(trained_model_path, map_location='cpu')

In [None]:
# Plot the same feature again, now using the parameters loaded above.
plot_feature(FEATURE)
pyplot.gca().set_yscale('log')
# Optional: uncomment to clamp the y range.
# pyplot.ylim(1e-4, 1e2)