# Inference Playground

This notebook is a playground for exploring approximate inference algorithms from PCFG + PCFG intersection.

$$
p(x) = \frac{1}{Z}\, p_1(x) p_2(x)  \quad\text{where}\quad Z = \sum_x p_1(x) p_2(x)
$$

$$
x^{(1)}, \ldots, x^{(M)} \overset{\mathrm{i.i.d}}{\sim} q
$$

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd, numpy as np
import nest_asyncio; nest_asyncio.apply()
from genparse import CFGLM, Float
from genparse.util import display_table
from genparse.steer import run, BruteForceGlobalProductOfExperts

In [3]:
MAX_LENGTH = 10
N_PARTICLES = 5_000

In [4]:
lm1 = CFGLM.from_string("""

0.45: S -> a S a
0.45: S -> b S b
0.1: S ->

""")

lm2 = CFGLM.from_string("""

0.5: S -> a b S
0.5: S ->

""")

In [5]:
ref = BruteForceGlobalProductOfExperts(lm1, lm2, MAX_LENGTH)
ref.target.project(''.join)

0,1
key,value
▪,1.0


In [8]:
particles = run(
    lm1,
    lm2,
    MAX_LENGTH = MAX_LENGTH,
    n_particles = N_PARTICLES,
    #METHOD = 'is',
    METHOD = 'smc',    
)

w = Float.chart()
for p in particles:
    w[tuple(p.ys)] += np.exp(p.weight)

w = w.normalize()
w.project(''.join)

0,1
key,value
▪,0.9999782457737258
abababababa▪,2.175422627407215e-05


In [27]:
particles = run(lm1, lm2, MAX_LENGTH = MAX_LENGTH, n_particles = 100, METHOD = 'is')

w = Float.chart()
for p in particles:
    print()
    p.P = lm1.cfg(p.ys) * lm2.cfg(p.ys)
    
    want_weight = p.P / np.exp(p.Q) if p.P > 0 else 0
    have_weight = np.exp(p.weight)
    
    print(p)
    print('weights:', have_weight, want_weight)

    have_P = np.exp(p.weight - p.Q)
    want_P = p.P
    print('numerator:', have_P, want_P, have_P / want_P, want_P/have_P)
    
    if p.P > 0:
        w[tuple(p.ys)] += p.P / np.exp(p.Q)

w = w.normalize()
w.project(''.join)


['a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', 'a', '▪']
weights: 1.3168016639741699e-06 0
numerator: inf 0.0 inf 0.0

['a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', 'a', '▪']
weights: 1.3168016639741699e-06 0
numerator: inf 0.0 inf 0.0

['a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', 'a', '▪']
weights: 1.3168016639741699e-06 0
numerator: inf 0.0 inf 0.0

['a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', 'a', '▪']
weights: 1.3168016639741699e-06 0
numerator: inf 0.0 inf 0.0

['a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', 'a', '▪']
weights: 1.3168016639741699e-06 0
numerator: inf 0.0 inf 0.0

['a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', 'a', '▪']
weights: 1.3168016639741699e-06 0
numerator: inf 0.0 inf 0.0

['▪']
weights: 0.27500000001881464 0.27499999981293355
numerator: 1.5124999990576704 0.05000000000056015 30.249999980814515 0.033057851260635734

['▪']
weights: 0.27500000001881464 0.27499999981293355
numerator: 1.5124999990576704 0.05000000000056015 30.2499999

0,1
key,value
▪,1.0
