# Inference Playground

This notebook is a playground for exploring approximate inference algorithms from PCFG + PCFG intersection.

$$
p(x) = \frac{1}{Z}\, p_1(x) p_2(x)  \quad\text{where}\quad Z = \sum_x p_1(x) p_2(x)
$$

$$
x^{(1)}, \ldots, x^{(M)} \overset{\mathrm{i.i.d}}{\sim} q
$$

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd, numpy as np
import nest_asyncio; nest_asyncio.apply()
from genparse import CFGLM, Float
from genparse.util import display_table
from genparse.steer import run, BruteForceGlobalProductOfExperts

In [3]:
MAX_LENGTH = 15
N_PARTICLES = 5_000

In [4]:
lm1 = CFGLM.from_string("""

0.45: S -> a S a
0.45: S -> b S b
0.1: S ->

""")

lm2 = CFGLM.from_string("""

0.5: S -> a b S
0.5: S ->

""")

# strings of just a's (with a geometric decay on length)
lm2 = CFGLM.from_string("""

20: S -> a a a a a a
1: S -> a a
1: S ->

""")

In [5]:
ref = BruteForceGlobalProductOfExperts(lm1, lm2, MAX_LENGTH)
ref.target.project(''.join)

0,1
key,value
aaaaaa▪,0.5569136745607334
aa▪,0.13750954927425513
▪,0.30557677616501144


In [6]:
particles = run(
    lm1,
    lm2,
    MAX_LENGTH = MAX_LENGTH,
    n_particles = N_PARTICLES,
    METHOD = 'is',
    #METHOD = 'smc',    
)

In [7]:
w = Float.chart()
for p in particles:
    w[tuple(p.ys)] += np.exp(p.weight)

w = w.normalize()

In [8]:
w.project(''.join).trim().compare(ref.target.project(''.join).trim()).sort_values('key', ascending=False)

Unnamed: 0,key,self,other,metric
1,▪,0.345323,0.305577,0.039746
0,aa▪,0.133131,0.13751,0.004379
2,aaaaaa▪,0.521546,0.556914,0.035367


In [9]:
w = Float.chart()
for p in particles:
    print()
    p.P = lm1.cfg(p.ys) * lm2.cfg(p.ys)
    
    want_weight = p.P / np.exp(p.Q) if p.P > 0 else 0
    have_weight = np.exp(p.weight)
    
    print(p)
    print('weights:', have_weight, want_weight)

    have_P = np.exp(p.weight - p.Q)
    want_P = p.P
    print('numerator:', have_P, want_P, have_P / want_P, want_P/have_P)
    
    if p.P > 0:
        w[tuple(p.ys)] += p.P / np.exp(p.Q)

w = w.normalize()
w.project(''.join)


['a', 'a', 'a', 'a', 'a', 'a', '▪']
weights: 0.008540879578650499 0.008540879575242742
numerator: 0.008805628130802878 0.00828409090916865 1.062956482171991 0.9407722863278947

['a', 'a', 'a', 'a', 'a', 'a', '▪']
weights: 0.008540879578650499 0.008540879575242742
numerator: 0.008805628130802878 0.00828409090916865 1.062956482171991 0.9407722863278947

['a', 'a', 'a', 'a', 'a', 'a', '▪']
weights: 0.008540879578650499 0.008540879575242742
numerator: 0.008805628130802878 0.00828409090916865 1.062956482171991 0.9407722863278947

['a', 'a', 'a', 'a', 'a', 'a', '▪']
weights: 0.008540879578650499 0.008540879575242742
numerator: 0.008805628130802878 0.00828409090916865 1.062956482171991 0.9407722863278947

['a', 'a', 'a', 'a', 'a', 'a', '▪']
weights: 0.008540879578650499 0.008540879575242742
numerator: 0.008805628130802878 0.00828409090916865 1.062956482171991 0.9407722863278947

['a', 'a', 'a', 'a', 'a', 'a', '▪']
weights: 0.008540879578650499 0.008540879575242742
numerator: 0.00880562813080

0,1
key,value
aaaaaa▪,0.5215464334677079
▪,0.3453228599835308
aa▪,0.13313070654876136


## Rejection Sampling

In [11]:
from genparse.inference import TraceSWOR
from arsenal import iterview

In [12]:
tracer = TraceSWOR()
R = Float.chart()
for _ in iterview(range(500)):
    with tracer:
        y1, p1 = lm1.sample(draw=tracer, prob=True)
        y2, p2 = lm2.sample(draw=tracer, prob=True)
        if y1 == y2:
            print(y1)
            R[y1] += p1 * p2
R = R.normalize()
R.sort()

Output()

0,1
key,value
(),0.30557677623111484
"('a', 'a')",0.1375095492829182
"('a', 'a', 'a', 'a', 'a', 'a')",0.556913674485967


In [13]:
# truncate the reference distribution to the support set of the sample; 
# renamed the keys to handle the minor discrepancy in the EOS symbol
tmp = ref.target.filter(lambda k: k[:-1] in R).normalize().sort()
tmp.project(lambda k: k[:-1]).compare(R).sort_values('key')

Unnamed: 0,key,self,other,metric
2,(),0.305577,0.305577,6.61034e-11
0,"(a, a)",0.13751,0.13751,8.66307e-12
1,"(a, a, a, a, a, a)",0.556914,0.556914,7.476642e-11
