In [1]:
# Evaluation-based likelihood weighting

from daphne import daphne
from primitives import eval_env
import torch
import pickle

ENV = None

def likelihood_weighting_with_report(ast, L):
    """
    Run likelihood weighting and print the self-normalized weighted mean of the returns.
    Args:
        ast: json FOPPL program, passed through to likelihood_weighting
        L: number of samples to generate
    Return:
        None; the estimated expectation is printed to stdout
    """
    result = likelihood_weighting(ast, L)
    returns = torch.stack(result['returns'])
    # Log-weights may arrive as plain Python numbers rather than tensors;
    # wrap those so torch.stack accepts them.
    log_weights = torch.stack([
        w if torch.is_tensor(w) else torch.tensor(float(w))
        for w in result['log_weights']
    ])

    # Subtract the largest log-weight before exponentiating. The constant
    # cancels in the normalization, but it keeps exp() from underflowing to 0
    # for every sample (which would turn the division into 0/0 = NaN).
    shifted_weights = torch.exp(log_weights - torch.max(log_weights))
    normalized_weights = shifted_weights / torch.sum(shifted_weights)

    # Give the weights one trailing singleton axis per extra dimension of the
    # returns so they broadcast against scalar, vector or higher-rank returns.
    trailing_axes = [1] * (returns.dim() - 1)
    broadcast_weights = normalized_weights.reshape(-1, *trailing_axes)
    expectations = torch.sum(broadcast_weights * returns, dim=0)

    print(expectations)

def likelihood_weighting_and_save(ast, L, filename):
    """
    Run likelihood weighting and pickle the resulting dictionary to disk.
    Args:
        ast: json FOPPL program, passed through to likelihood_weighting
        L: number of samples to generate
        filename: path of the file to write; it is opened in binary write mode
    Return:
        None
    """
    result = likelihood_weighting(ast, L)
    # The context manager closes the file even when pickling raises.
    with open(filename, 'wb') as f:
        pickle.dump(result, f)


def likelihood_weighting(ast, L):
    """
    Draw L weighted samples by evaluating a Daphne-desugared program L times.
    Args:
        ast: json FOPPL program
        L: number of samples to generate
    Return:
        dictionary with two lists of length L: the program return values
        under 'returns' and the matching log-weights under 'log_weights'
    """
    sampled_values = []
    sampled_log_weights = []

    # Each evaluation yields one (return value, log-weight) pair.
    for _ in range(L):
        value, log_weight = evaluate_program(ast)
        sampled_values.append(value)
        sampled_log_weights.append(log_weight)

    return {'returns': sampled_values, 'log_weights': sampled_log_weights}


def evaluate_program(ast):
    """Evaluate a program as desugared by daphne, generate a sample from the prior
    Args:
        ast: json FOPPL program; every element but the last is treated as a
             function definition, the last element is the expression to evaluate
    Returns:
        (return value, log-weight) tuple; the log-weight is always a tensor
    """
    global ENV
    # Start from a fresh environment on every run, then register the
    # user-defined functions as (parameter names, body) pairs.
    ENV = eval_env()
    for defn in ast[:-1]:
        f_name = defn[1]
        f_v_is = defn[2]
        f_expr = defn[3]
        ENV.update({f_name: (f_v_is, f_expr)})
    l = {}
    # Start the log-weight as a tensor rather than the int 0 so the result
    # has a uniform type (and can be passed to torch.stack) even when the
    # program never adds to it.
    sig = {'logW': torch.tensor(0.0)}
    ret, sig = evaluate(ast[-1], l, sig)
    return ret, sig['logW']

# inspired by https://norvig.com/lispy.html
def evaluate(e, l, sig):
    """Recursively evaluate one desugared FOPPL expression.
    Args:
        e: expression; a symbol (str), a numeric constant, or a list whose
           first element is a special form ('if', 'let', 'sample', 'observe')
           or an expression that evaluates to a procedure
        l: dict of local variable bindings
        sig: dict of side-effect state; 'logW' accumulates the log-weight and
             is updated in place by 'observe'
    Returns:
        (value, sig) tuple
    Raises:
        AssertionError: on an unknown symbol, a sample/observe target that is
            not a torch.distributions object, or a user-defined procedure
            called with the wrong number of arguments
    """
    # variable reference OR procedure OR just a string
    if isinstance(e, str):
        # global procedures take precedence over locally defined vars
        if e in ENV:
            return ENV[e], sig
        if e in l:
            return l[e], sig
        # could allow for hashmaps with string keys; for debugging setting this to fail
        # (raised explicitly so the check is not stripped under `python -O`)
        raise AssertionError("Unknown symbol: {}".format(e))

    # constant number
    if isinstance(e, (int, float)):
        return torch.tensor(float(e)), sig

    head = e[0]

    # if statements
    if head == 'if':
        (_, test, conseq, alt) = e
        test_value, sig = evaluate(test, l, sig)
        expr = (conseq if test_value else alt)
        return evaluate(expr, l, sig)

    # let statements
    if head == 'let':
        symbol = e[1][0]
        # get value of e1
        value, sig = evaluate(e[1][1], l, sig)
        # evaluate e2 with the new binding added to the local scope
        return evaluate(e[2], {**l, symbol: value}, sig)

    # sample statement
    if head == 'sample':
        dist, sig = evaluate(e[1], l, sig)
        _assert_is_distribution(dist)
        return dist.sample(), sig

    # observe statements
    if head == 'observe':
        dist, sig = evaluate(e[1], l, sig)  # get dist
        y, sig = evaluate(e[2], l, sig)     # get observed value
        _assert_is_distribution(dist)
        sig['logW'] = sig['logW'] + dist.log_prob(y)
        return y, sig

    # procedure call, either primitive or user-defined
    proc, sig = evaluate(head, l, sig)
    operands = e[1:]

    # primitives are functions
    if callable(proc):
        args, sig = _evaluate_args(operands, l, sig)
        return proc(*args), sig

    # user defined functions are (parameter names, body) pairs;
    # as written in algorithm 6
    v_is, e0 = proc
    assert len(v_is) == len(operands)
    c_is, sig = _evaluate_args(operands, l, sig)
    return evaluate(e0, {**l, **dict(zip(v_is, c_is))}, sig)


def _evaluate_args(exprs, l, sig):
    """Evaluate exprs left to right, threading sig through; return (values, sig)."""
    values = []
    for expr in exprs:
        value, sig = evaluate(expr, l, sig)
        values.append(value)
    return values, sig


def _assert_is_distribution(dist):
    """Assert that dist is an object defined under torch.distributions."""
    # Objects without a __module__ attribute fall back to '' so the check
    # fails with the assertion below instead of an AttributeError on None.
    module = getattr(dist, '__module__', None) or ''
    assert module.split('.')[:2] == ['torch', 'distributions'], \
        "Expected a torch.distributions object, got {!r}".format(dist)

In [2]:
import json

# Load the desugared program; the context manager closes the file even if
# JSON parsing fails.
with open('asts/3.json', 'rb') as f:
    ast3 = json.load(f)
#print(ast3)
#print('\n\n')

In [3]:
# Run the sampler once (L=1) to look at a single return / log-weight pair.
result = likelihood_weighting(ast3, 1)

In [4]:
# Display the list of log-weights held in `result`.
result['log_weights']

[tensor(-41.2165)]

In [5]:
# Re-run with 100,000 samples; this rebinds `result`, replacing the earlier run.
result = likelihood_weighting(ast3, 100000)

In [6]:
# Stack the per-sample lists into one tensor each. The returns are cast to
# float -- presumably because the raw values are not floating point; confirm.
returns = torch.stack(result['returns']).float()
log_weights = torch.stack(result['log_weights'])

In [7]:
# Print both stacked tensors for a quick look at their values.
print(returns)
print(log_weights)

tensor([1., 0., 1.,  ..., 0., 1., 1.])
tensor([-1040.0190,  -273.2407, -1250.9554,  ..., -1312.2267,  -360.9262,
         -266.6620])


In [8]:
# Largest (M) and smallest (m) log-weight, used as shift constants when the
# weights are normalized in the following cell.
M = torch.max(log_weights)
m = torch.min(log_weights)

In [9]:
# Three normalizations of the importance weights that differ only in the
# constant subtracted from the log-weights before exp(). Mathematically the
# constant cancels; numerically it does not.
# No shift: exp() of very negative log-weights can underflow to 0.
og_weights = torch.exp(log_weights)/torch.sum(torch.exp(log_weights))
# Shift by the max: every exponent is <= 0 and the largest term is exp(0) = 1.
M_weights = torch.exp(log_weights-M)/torch.sum(torch.exp(log_weights-M))
# Shift by the min: every exponent is >= 0; the NaNs printed further down are
# consistent with exp() overflowing to inf and the division yielding inf/inf.
m_weights = torch.exp(log_weights-m)/torch.sum(torch.exp(log_weights-m))

In [10]:
# Count the nonzero entries of each weight vector. NaN entries are not equal
# to zero, so they are included in the count.
print(torch.nonzero(og_weights).size())
print(torch.nonzero(M_weights).size())
print(torch.nonzero(m_weights).size())

torch.Size([15131, 1])
torch.Size([15250, 1])
torch.Size([99999, 1])


In [11]:
# Inspect the min-shifted weights directly.
print(m_weights)

tensor([nan, nan, nan,  ..., nan, nan, nan])


In [12]:
# Weighted mean of the returns using the unshifted normalized weights.
torch.sum(og_weights*returns)

tensor(0.0829)

In [13]:
# Same weighted mean using the max-shifted normalized weights.
torch.sum(M_weights*returns)

tensor(0.0829)