# Inference Playground

This notebook is a playground for exploring approximate inference algorithms for the intersection of two PCFGs, i.e., the renormalized product of their string distributions $p_1$ and $p_2$:

$$
p(x) = \frac{1}{Z}\, p_1(x) p_2(x)  \quad\text{where}\quad Z = \sum_x p_1(x) p_2(x)
$$

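Approximate methods draw samples from a proposal distribution $q$ and correct for the mismatch between $q$ and $p$:

$$
x^{(1)}, \ldots, x^{(M)} \overset{\mathrm{i.i.d.}}{\sim} q
$$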

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd, numpy as np
import nest_asyncio; nest_asyncio.apply()
from genparse import CFGLM, Float
from genparse.util import display_table
from genparse.steer import run, BruteForceGlobalProductOfExperts

In [3]:
# Length bound handed to both the brute-force reference and the sampler (`run`).
MAX_LENGTH = 15
# Number of particles requested from the sampler.
N_PARTICLES = 5000

In [5]:
# lm1: even-length palindromes over {a, b}. Each wrapping step (a S a or
# b S b) has weight 0.45, and stopping has weight 0.1.
lm1 = CFGLM.from_string("""

0.45: S -> a S a
0.45: S -> b S b
0.1: S ->

""")

# Alternative second expert, kept for experimentation: strings (ab)^n with
# continue/stop weights 0.5 each. Assign it to `lm2` to try it instead.
lm2_ab = CFGLM.from_string("""

0.5: S -> a b S
0.5: S ->

""")

# lm2: the grammar is not recursive, so its language is finite -- exactly
# a^6, a^2 and the empty string, with relative weights 20 : 1 : 1.
lm2 = CFGLM.from_string("""

20: S -> a a a a a a
1: S -> a a
1: S ->

""")

In [6]:
# Exact reference for the product of experts lm1 * lm2, bounded by MAX_LENGTH
# (presumably by enumerating candidate strings -- see genparse.steer).
ref = BruteForceGlobalProductOfExperts(lm1, lm2, MAX_LENGTH)

# Show the target with each token sequence joined into a single string key.
ref_by_string = ref.target.project(''.join)
ref_by_string

0,1
key,value
aaaaaa▪,0.5569136745607334
aa▪,0.13750954927425513
▪,0.30557677616501144


In [7]:
# Sampling method passed to `run`. The values used in this notebook are 'is'
# (presumably importance sampling) and 'smc' (presumably sequential Monte
# Carlo). Change it here rather than editing the call below.
METHOD = 'is'

# TODO: sampling is stochastic and no seed is set, so the estimates below
# change from run to run; seed the RNG genparse uses to make them reproducible.
particles = run(
    lm1,
    lm2,
    MAX_LENGTH=MAX_LENGTH,
    n_particles=N_PARTICLES,
    METHOD=METHOD,
)

In [8]:
# Self-normalized importance estimate of the target over distinct strings.
# Each particle contributes lm1(ys) * lm2(ys) * exp(-p.Q), where p.Q is
# presumably the particle's log proposal probability (confirm in genparse.steer).
# Particles that the target assigns zero mass are skipped.
w = Float.chart()
for particle in particles:
    key = tuple(particle.ys)
    unnormalized = lm1(key) * lm2(key)
    if not unnormalized > 0:
        continue
    w[key] += unnormalized * np.exp(-particle.Q)
w = w.normalize()

In [9]:
# Compare the importance-sampling estimate ("self") against the exact
# reference ("other"), with both keyed by the joined string.
estimate_by_string = w.project(''.join).trim()
reference_by_string = ref.target.project(''.join).trim()
estimate_by_string.compare(reference_by_string).sort_values('key', ascending=False)

Unnamed: 0,key,self,other,metric
1,▪,0.34647,0.305577,0.040893
0,aa▪,0.110635,0.13751,0.026874
2,aaaaaa▪,0.542895,0.556914,0.014019


## Rejection Sampling

In [None]:
from genparse.inference import TraceSWOR
from arsenal import iterview

In [None]:
# Number of joint traces to draw through the tracer.
N_TRACES = 500

# Draw lm1 and lm2 inside a single tracer context, so both samples belong to
# one joint trace. TraceSWOR presumably draws traces without replacement;
# verify in genparse.inference. A trace is kept only when the two samples
# agree, and it contributes its joint probability p1 * p2 to that string.
tracer = TraceSWOR()
R = Float.chart()
for _ in iterview(range(N_TRACES)):
    with tracer:
        y1, p1 = lm1.sample(draw=tracer, prob=True)
        y2, p2 = lm2.sample(draw=tracer, prob=True)
        if y1 == y2:
            R[y1] += p1 * p2
R = R.normalize()
R.sort()

In [None]:
# truncate the reference distribution to the support set of the sample; 
# renamed the keys to handle the minor discrepancy in the EOS symbol
# Restrict the exact reference to strings the sampler actually produced, then
# renormalize. Reference keys carry a trailing EOS symbol that R's keys lack,
# so that last element is dropped (key[:-1]) both when matching and when
# projecting for the comparison.
ref_on_support = (
    ref.target
    .filter(lambda key: key[:-1] in R)
    .normalize()
    .sort()
)
ref_on_support.project(lambda key: key[:-1]).compare(R).sort_values('key')