In [79]:
import json
import torch
from primitives import eval_env

In [80]:
def _load_graph(path):
    """Read the JSON-encoded graph stored at `path` and return the decoded object."""
    # `with` closes the handle even when json.load raises on malformed input.
    with open(path, 'rb') as f:
        return json.load(f)


# One decoded graph per file in graphs/.
g1 = _load_graph('graphs/1.json')
g2 = _load_graph('graphs/2.json')
g3 = _load_graph('graphs/3.json')
g4 = _load_graph('graphs/4.json')

In [81]:
# Global symbol table built by primitives.eval_env(); deterministic_evaluate
# looks a string symbol up here before it looks at the local bindings.
ENV = eval_env()

# utilities for graph based sampling

def _assert_is_distribution(dist):
    """Assert that `dist` comes from a module under `torch.distributions`."""
    # `or ''` makes an object without a `__module__` attribute fail the
    # assertion below instead of raising AttributeError on `None.split`.
    module_path = (getattr(dist, '__module__', None) or '').split('.')[:2]
    assert module_path == ['torch', 'distributions']


def deterministic_evaluate(e, l, sig=None):
    """Evaluate expression `e` and return the pair `(value, sig)`.

    Args:
        e: the expression. A string is looked up as a symbol (first in the
            global ENV, then in `l`); an int/float becomes a float tensor;
            anything else is indexed as a sequence whose first element selects
            the form: 'if', 'let', 'sample', 'observe', or a procedure call.
        l: mapping of locally bound symbols to their values.
        sig: opaque value handed back as the second element of the result.
            It is threaded through every sub-evaluation, so every branch
            returns the caller's `sig` (no branch currently modifies it).

    Returns:
        A `(value, sig)` tuple.

    Raises:
        AssertionError: on an unknown symbol, when a 'sample'/'observe'
            argument is not a torch distribution, or when a user-defined
            procedure is called with the wrong number of arguments.
    """
    # variable reference OR procedure OR just a string
    if isinstance(e, str):
        # global procedures take precedence over locally defined vars
        if e in ENV:
            return ENV[e], sig
        if e in l:
            return l[e], sig
        # could allow for hashmaps with string keys; for debugging setting this to fail
        assert False, "Unknown symbol: {}".format(e)

    # constant number
    if isinstance(e, (int, float)):
        return torch.tensor(float(e)), sig

    head = e[0]

    # if statements: only the selected branch is evaluated
    if head == 'if':
        (_, test, conseq, alt) = e
        test_value, sig = deterministic_evaluate(test, l, sig)
        return deterministic_evaluate(conseq if test_value else alt, l, sig)

    # let statements: bind e[1][0] to the value of e[1][1], then evaluate e[2]
    if head == 'let':
        symbol = e[1][0]
        value, sig = deterministic_evaluate(e[1][1], l, sig)
        return deterministic_evaluate(e[2], {**l, symbol: value}, sig)

    # sample statement
    if head == 'sample':
        dist, sig = deterministic_evaluate(e[1], l, sig)
        # make sure it is a distribution object
        _assert_is_distribution(dist)
        return dist.sample(), sig

    # observe statements
    # TODO: change this, maybe in this hw or for hw3
    if head == 'observe':
        dist, sig = deterministic_evaluate(e[1], l, sig)  # get dist
        _y, sig = deterministic_evaluate(e[2], l, sig)    # get observed value
        # make sure it is a distribution object
        _assert_is_distribution(dist)
        # TODO: do something with observed value; for now it is evaluated but
        # ignored and a fresh sample is returned instead.
        return dist.sample(), sig

    # procedure call, either primitive or user-defined
    proc, sig = deterministic_evaluate(head, l, sig)
    if not callable(proc):
        # user-defined procedures are stored as (parameter names, body);
        # check the arity before any argument is evaluated
        v_is, e0 = proc
        assert(len(v_is) == len(e[1:]))

    args = []
    for arg in e[1:]:
        value, sig = deterministic_evaluate(arg, l, sig)
        args.append(value)

    # primitives are functions
    if callable(proc):
        return proc(*args), sig

    # as written in algorithm 6: evaluate the body with parameters bound to args
    return deterministic_evaluate(e0, {**l, **dict(zip(v_is, args))}, sig)

# inspired by https://www.geeksforgeeks.org/python-program-for-topological-sorting/
# TODO: update to python 3.9 and use graphlib instead
def topological_sort(A, V):
    """Return the vertices of V ordered so that each vertex precedes its children.

    A maps a vertex to the list of vertices it points to; vertices that are
    absent from A are treated as having no outgoing edges.
    """
    seen = dict.fromkeys(V, False)
    ordering = []

    for node in V:
        if not seen[node]:
            topo_sort_util(node, A, V, seen, ordering)

    return ordering


def topo_sort_util(v, A, V, visited, stack):
    """Depth-first visit from v, prepending v to `stack` after all its descendants.

    `visited` and `stack` are updated in place.
    """
    visited[v] = True

    # descend into every child that has not been reached yet
    for child in (A[v] if v in A else ()):
        if not visited[child]:
            topo_sort_util(child, A, V, visited, stack)

    # v goes in front of everything reachable from it
    stack[:0] = [v]

# Graph 3

In [None]:
# Work on the third graph for the rest of the notebook and show its contents.
graph = g3
print(graph)

In [82]:
def sample_from_prior(graph):
    """Sample every "sample*" vertex of `graph` once, in topological order.

    Args:
        graph: sequence whose element 0 maps procedure names to their
            definitions and whose element 1 is a mapping with keys 'V'
            (vertices), 'A' (adjacency) and 'P' (per-vertex
            `[task, expression, ...]` entries).

    Returns:
        dict mapping each "sample*" vertex name to the value drawn for it.
        Vertices with any other task are skipped.

    Side effects:
        Rebinds the global ENV to a fresh `eval_env()` extended with the
        graph's procedure definitions.
    """
    # get contents of graph
    fn_defs = graph[0]
    V = graph[1]['V']
    A = graph[1]['A']
    P = graph[1]['P']

    # deal with fn_defs
    global ENV
    ENV = eval_env()
    for f_name, f_def in fn_defs.items():
        # f_def[1] holds the parameter names and f_def[2] the body expression;
        # stored as the (v_is, e0) pair that deterministic_evaluate unpacks
        ENV.update({f_name: (f_def[1], f_def[2])})

    # compute each value in order, so parents are available in `l`
    # by the time a child's expression is evaluated
    l = {}
    for v in topological_sort(A, V):
        task, expr = P[v][0], P[v][1]
        if task == "sample*":
            dist, _ = deterministic_evaluate(expr, l)
            l[v] = dist.sample()

    return l

In [83]:
# Draw one value for every "sample*" vertex of the selected graph.
X0 = sample_from_prior(graph)

In [84]:
# Show the sampled values, keyed by vertex name.
print(X0)

{'sample3': tensor(0.0404), 'sample1': tensor(2.1421), 'sample2': tensor(-6.0278), 'sample4': tensor(-5.6020), 'sample6': tensor([0.0235, 0.9545, 0.0221]), 'sample11': tensor(1), 'sample13': tensor(1), 'sample19': tensor(1), 'sample15': tensor(1), 'sample9': tensor(1), 'sample7': tensor(1), 'sample17': tensor(1), 'sample0': tensor(-0.7293), 'sample5': tensor(0.3977)}


In [85]:
# gibbs step
# NOTE: X is the same dict object as X0 (no copy is made here).
X = X0
# per-vertex [task, expression, ...] entries of the graph
P = graph[1]['P']
# Uniform(0, 1) distribution; not used by any cell visible in this notebook section.
unif = torch.distributions.Uniform(0,1)
# variable names of X as a list, so one can be picked by position
Xkeys = list(X.keys())

In [92]:
# Pick one variable by position and look up its graph entry.
x = Xkeys[5]
task, expr = P[x][0], P[x][1]
assert task == "sample*", "Found observed variable in X???"
# q: result of evaluating x's expression with the current values X as local bindings
q, _ = deterministic_evaluate(expr, X)
# Shallow copy of X in which only the entry for x is replaced by a draw from q.
Xprime = X.copy()
Xprime[x] = q.sample()
print(x)
print(q)
print(X)
print(Xprime)

sample11
Categorical(probs: torch.Size([3]))
{'sample3': tensor(0.1025), 'sample1': tensor(1.0373), 'sample2': tensor(-0.5321), 'sample4': tensor(-1.5413), 'sample6': tensor([0.5471, 0.3383, 0.1146]), 'sample11': tensor(0), 'sample13': tensor(0), 'sample19': tensor(0), 'sample15': tensor(0), 'sample9': tensor(0), 'sample7': tensor(2), 'sample17': tensor(2), 'sample0': tensor(-11.5455), 'sample5': tensor(2.3074)}
{'sample3': tensor(0.1025), 'sample1': tensor(1.0373), 'sample2': tensor(-0.5321), 'sample4': tensor(-1.5413), 'sample6': tensor([0.5471, 0.3383, 0.1146]), 'sample11': tensor(1), 'sample13': tensor(0), 'sample19': tensor(0), 'sample15': tensor(0), 'sample9': tensor(0), 'sample7': tensor(2), 'sample17': tensor(2), 'sample0': tensor(-11.5455), 'sample5': tensor(2.3074)}


In [94]:
# Print, for every variable, the elementwise comparison of its value in X and in Xprime.
print(x)
for k in X.keys():
    print(k)
    print(X[k] == Xprime[k])

sample11
sample3
tensor(True)
sample1
tensor(True)
sample2
tensor(True)
sample4
tensor(True)
sample6
tensor([True, True, True])
sample11
tensor(False)
sample13
tensor(True)
sample19
tensor(True)
sample15
tensor(True)
sample9
tensor(True)
sample7
tensor(True)
sample17
tensor(True)
sample0
tensor(True)
sample5
tensor(True)


In [97]:
P = graph[1]['P']
# Initialised here but not updated anywhere in this cell.
log_alpha = 0.0
task, expr = P[x][0], P[x][1]
assert task == "sample*", "Found observed variable in X???"
# x's expression evaluated under the current values (q) and under the proposed values (qprime)
q, _ = deterministic_evaluate(expr, X)
qprime, _ = deterministic_evaluate(expr, Xprime)
# log-probability of the current value of x under qprime
print(qprime.log_prob(X[x]))
# log-probability of the proposed value of x under q
print(q.log_prob(Xprime[x]))

tensor(-0.6031)
tensor(-1.0838)


In [104]:
# NOTE(review): this cell is unfinished. The inner `for` has no body, so the
# cell does not parse as written; the output shown below it presumably comes
# from an earlier version of the cell - TODO confirm and complete.
# NOTE(review): `A` and `Y` are not bound at notebook scope in the cells
# visible here (`A` is only a local of sample_from_prior) - TODO confirm
# where they are meant to come from.
for x in X.keys():
    # x's entry in A followed by x itself
    V_x = A[x] + [x]
    # current / proposed values merged with Y; entries of Y win on key clashes
    XUY = {**X, **Y}
    XprimeUY = {**Xprime, **Y}
    for v in V_x:
        # NOTE(review): loop body missing
        
    

['observe18', 'observe12', 'observe14', 'observe10', 'observe8', 'observe16', 'observe20', 'sample3']
True
['observe18', 'observe12', 'observe14', 'observe10', 'observe8', 'observe16', 'observe20', 'sample1']
True
['observe18', 'observe12', 'observe14', 'observe10', 'observe8', 'observe16', 'observe20', 'sample2']
True
['observe18', 'observe12', 'observe14', 'observe10', 'observe8', 'observe16', 'observe20', 'sample4']
True
['sample17', 'sample7', 'sample9', 'sample15', 'sample19', 'sample13', 'sample11', 'sample6']
True
['observe12', 'sample11']
True
['observe14', 'sample13']
True
['observe20', 'sample19']
True
['observe16', 'sample15']
True
['observe10', 'sample9']
True
['observe8', 'sample7']
True
['observe18', 'sample17']
True
['observe18', 'observe12', 'observe14', 'observe10', 'observe8', 'observe16', 'observe20', 'sample0']
True
['observe18', 'observe12', 'observe14', 'observe10', 'observe8', 'observe16', 'observe20', 'sample5']
True


In [101]:
# Convert every value of Y to a float tensor.
# NOTE(review): Y must already exist when this cell runs, but no visible cell
# binds it at notebook scope, and this cell sits after the cell that first
# uses Y - TODO confirm the intended definition and cell order.
Y = {k: torch.tensor(v).float() for k,v in Y.items()}