In [1]:
import numpy as np
from clt import BinaryCLT, load_dataset
import time

[-1, 6, 0, 5, 13, 7, 2, 6, 6, 7, 14, 10, 8, 14, 12, 12]
[[[-0.1580114  -1.92305371]
  [-0.1580114  -1.92305371]]

 [[-0.08514189 -2.50570499]
  [-0.87785954 -0.53729228]]

 [[-0.15248793 -1.95594512]
  [-1.43702211 -0.27132899]]

 [[-0.21637736 -1.63696999]
  [-1.64768811 -0.21380532]]

 [[-0.36397099 -1.1871529 ]
  [-2.68473667 -0.07067911]]

 [[-0.31299803 -1.31397873]
  [-2.12432577 -0.12728071]]

 [[-0.14277343 -2.01703383]
  [-1.12149374 -0.39421777]]

 [[-0.18551196 -1.77595836]
  [-2.17196823 -0.12098541]]

 [[-0.04623899 -3.09696239]
  [-1.23699809 -0.34284844]]

 [[-0.72476915 -0.66249458]
  [-3.76479366 -0.02344509]]

 [[-0.10575824 -2.29901265]
  [-1.02686311 -0.44336732]]

 [[-0.34527636 -1.23108591]
  [-2.16320683 -0.12211779]]

 [[-0.09057788 -2.44649241]
  [-1.01981627 -0.44732076]]

 [[-0.23140666 -1.57705179]
  [-2.50844619 -0.08489864]]

 [[-0.15478766 -1.94209675]
  [-1.47998337 -0.25830642]]

 [[-0.01277874 -4.36635481]
  [-0.61243786 -0.78094694]]]
[[0 0 0 0 1 0 0 

# Load Data

In [2]:
# Directory handed to load_dataset together with each split's file name.
DIR = "nltcs"

# Load the train split first, then the test split, from the same directory.
train, test = (
    load_dataset(DIR, filename)
    for filename in ("nltcs.train.data", "nltcs.test.data")
)

# Fit Chow-Liu Tree

In [13]:
# Build the model from the training split.
# NOTE(review): `root=0` presumably selects variable 0 as the tree root and
# `alpha` looks like a smoothing constant — confirm against clt.BinaryCLT.
clt = BinaryCLT(train, root=0, alpha=0.01)

# Show the learned structure exactly as get_tree() returns it.
tree = clt.get_tree()
print(tree)

[-1, 6, 0, 5, 13, 7, 2, 6, 6, 7, 14, 10, 8, 14, 12, 12]


# Log conditional probability tables (CPTs)

In [15]:
# Fetch the model's log-parameters and report only their shape rather than
# dumping the full array.
# NOTE(review): presumably one 2x2 log-CPT per variable — confirm in clt.
log_cpts = clt.get_log_params()
cpt_shape = log_cpts.shape
print(cpt_shape)

(16, 2, 2)


# Average log-likelihoods

In [16]:
def mean_log_prob(model, data, exhaustive=False):
    """Return the mean of the values `model.log_prob` produces for `data`.

    Parameters
    ----------
    model : object exposing ``log_prob(data, exhaustive=...)``.
    data : passed to ``model.log_prob`` unchanged.
    exhaustive : forwarded as the ``exhaustive`` keyword of ``log_prob``
        (defaults to False).

    Returns
    -------
    Whatever ``.mean()`` returns on the result of ``model.log_prob``.
    """
    per_row_log_prob = model.log_prob(data, exhaustive=exhaustive)
    return per_row_log_prob.mean()

# Print the average for the train split first, then for the test split.
for split in (train, test):
    print(mean_log_prob(clt, split))

-6.76005596449873
-6.7590743429784395


# Efficient-vs-exhaustive sanity check

In [21]:
# Replace roughly 30% of the test entries with NaN (fixed seed, so the mask is
# reproducible) and score the resulting queries with both inference paths.
rng = np.random.default_rng(0)
queries = test.copy()  # work on a copy so the in-place masking leaves `test` intact
mask = rng.random(queries.shape) < 0.3
# NOTE(review): presumably log_prob treats NaN entries as unobserved — confirm in clt.
queries[mask] = np.nan

ll_fast = clt.log_prob(queries, exhaustive=False)
ll_slow = clt.log_prob(queries, exhaustive=True)

paths_agree = np.allclose(ll_fast, ll_slow)
print(paths_agree)
# Fail loudly: a printed `False` is easy to overlook under Restart & Run All.
assert paths_agree, "efficient and exhaustive log_prob results disagree"

True


# Sampling & plausibility check

In [11]:
# Draw rows from the fitted model and score them with the same model; the
# printed mean can be compared with the train/test averages reported earlier.
# NOTE(review): no seed is passed here and it is not visible whether `sample`
# accepts one — confirm, since the printed value may otherwise vary per run.
n_samples = 1000
samples = clt.sample(n_samples)
samples_mean_ll = mean_log_prob(clt, samples)
print(samples_mean_ll)

-6.656465523583656


# Runtime

In [22]:
# Time one call of each inference path on the same `queries`.
# perf_counter() is used instead of time() because it is monotonic and has the
# highest available resolution, so system clock adjustments cannot distort
# (or make negative) the measured intervals.
t0 = time.perf_counter()
_ = clt.log_prob(queries, exhaustive=False)
t1 = time.perf_counter()

t2 = time.perf_counter()
_ = clt.log_prob(queries, exhaustive=True)
t3 = time.perf_counter()

print("efficient =", t1-t0)
print("exhaustive =", t3-t2)

efficient = 0.5399618148803711
exhaustive = 0.9876770973205566
