In [24]:
import torch 
from primitives import eval_env
from graph_utils import topological_sort
import distributions

ENV = eval_env()

import json


def _load_graph(path):
    """Read one program graph from a JSON file; ``with`` closes the handle even if parsing fails."""
    with open(path, 'rb') as fh:
        return json.load(fh)


# Program graphs 1-4 (note: program 3 is read from p3_hw3.json).
g1 = _load_graph('graphs/p1.json')
g2 = _load_graph('graphs/p2.json')
g3 = _load_graph('graphs/p3_hw3.json')
g4 = _load_graph('graphs/p4.json')



In [23]:
# Black-box variational inference (BBVI)
#
# NOTE(review): this cell repeats imports from the cell above and rebuilds ENV.
# ``distributions`` (used by make_q and later cells) is only imported in the
# cell above, so that cell must run first. ``datetime`` appears unused in the
# visible code.

import torch 
from primitives import eval_env
from graph_utils import topological_sort
import datetime
import wandb
import time
import numpy as np

# Module-level evaluation environment read by deterministic_evaluate; init_Q and
# bbvi_train rebind it with the graph's user-defined functions.
ENV = eval_env()

def init_Q(graph):
    """
    Initialize proposal distributions for BBVI by copying each latent's prior.

    Sample nodes are visited in topological order. Each prior is evaluated with
    the values drawn so far, and one value is drawn from it so that later priors
    can depend on it. ``make_q`` turns the prior into a detached proposal.

    Side effect: rebinds the module-level ``ENV`` to a fresh ``eval_env()``
    extended with the graph's user-defined functions.

    Args:
        graph: graph[0] is a dict of user-defined functions (name -> definition
            whose [1] is the parameter list and [2] the body); graph[1] is a
            dict with keys 'V', 'A' and 'P'.
    Returns:
        dict mapping each ``sample*`` node to its initial proposal.
    """
    fn_defs, meta = graph[0], graph[1]

    # Register user-defined functions as (params, body) pairs in a fresh ENV.
    global ENV
    ENV = eval_env()
    for f_name, f_def in fn_defs.items():
        ENV.update({f_name: (f_def[1], f_def[2])})

    P = meta['P']
    Q = {}
    drawn = {}  # values sampled from the priors so far, keyed by node name
    for v in topological_sort(meta['A'], meta['V']):
        task, expr = P[v][0], P[v][1]
        if task == "sample*":
            dist, _ = deterministic_evaluate(expr, drawn)
            drawn.update({v: dist.sample()})
            Q.update({v: make_q(dist)})

    return Q


def make_q(d):
    """
    Create an independent proposal distribution initialised from prior ``d``.

    The prior's parameters are detached and cloned, so the proposal shares no
    storage or autograd history with the prior. Families are matched with
    ``isinstance``, so subclasses of the supported families are accepted too
    (e.g. torch's ``Chi2``, which subclasses ``Gamma``).

    Args:
        d: a torch.distributions Bernoulli, Categorical, Normal, Gamma or
           Dirichlet instance (or an instance of a subclass of one of them).
    Returns:
        the matching proposal from the project ``distributions`` module.
    Raises:
        TypeError: if ``d`` is not one of the supported families. This is raised
            explicitly rather than via ``assert False``, which ``python -O``
            strips and which would then leave the proposal unbound.
    """
    td = torch.distributions
    if isinstance(d, td.Bernoulli):
        return distributions.Bernoulli(probs=d.probs.detach().clone())
    if isinstance(d, td.Categorical):
        return distributions.Categorical(probs=d.probs.detach().clone())
    if isinstance(d, td.Normal):
        return distributions.Normal(loc=d.loc.detach().clone(),
                                    scale=d.scale.detach().clone())
    if isinstance(d, td.Gamma):
        return distributions.Gamma(concentration=d.concentration.detach().clone(),
                                   rate=d.rate.detach().clone())
    if isinstance(d, td.Dirichlet):
        return distributions.Dirichlet(concentration=d.concentration.detach().clone())
    raise TypeError("Unknown distribution type: {}".format(type(d)))


def bbvi_train(graph, T, L, base_string, Q=None, time_based=False, time_T=3600, lr=0.1, no_b=False, logging=True):
    """
    Trains BBVI proposal distributions.

    Each outer iteration draws L importance samples from the current proposals.
    It forms a noisy ELBO estimate (mean log-weight) and a gradient estimate,
    and remembers the proposals if that estimate is the best seen so far. It
    then takes one gradient step via update_Q with iteration number t+1.

    Args:
        graph: the graph denoting the problem; graph[0] holds user-defined
            functions, graph[1] is a dict with keys 'V', 'A', 'P', 'Y'
        T: maximum number of outer training loops
        L: number of samples to use in each gradient estimate
        base_string: wandb project name (only used when logging=True)
        Q: proposal distributions to start from... allows for
           continuation of training from previously trained Q; when None the
           proposals are initialised from the priors via init_Q
        time_based: if True, also stop once time_T seconds have elapsed
        time_T: wall-clock budget in seconds, checked after each iteration
        lr: base learning rate passed to update_Q
        no_b: if True, skip the control-variate baseline (b = 0)
        logging: if True, log the ELBO each iteration to wandb and finish the
            run when training ends (including on error)
    Returns:
        the dictionary Q whose noisy ELBO estimate was the highest. If no
        estimate ever exceeds -inf (every log-weight -inf, or T == 0), the
        starting Q is returned instead of raising UnboundLocalError.
    """
    global ENV

    start = time.time()
    if logging:
        wandb.init(project=base_string, entity="lone-duck")

    try:
        if Q is None:
            Q = init_Q(graph)

        # set up ENV with the graph's user-defined functions as (params, body)
        fn_defs = graph[0]
        ENV = eval_env()
        for f_name, f_def in fn_defs.items():
            ENV.update({f_name: (f_def[1], f_def[2])})

        # get P, Y, latent keys, sorted V
        P = graph[1]['P']
        V = graph[1]['V']
        Y = {k: torch.tensor(v).float() for k, v in graph[1]['Y'].items()}
        Xkeys = list(set(V) - set(Y.keys()))
        sorted_V = topological_sort(graph[1]['A'], V)

        # Start from the initial proposals so best_Q is always bound.
        best_Q, best_elbo = Q, -np.inf

        # for T iterations, or until time_T seconds have passed
        for t in range(T):
            # "evaluate", i.e. sample from the proposals and get logWs, Gs
            logWs = [None] * L
            Gs = [None] * L
            for l in range(L):
                logWs[l], Gs[l] = evaluation(P, Q, Y, sorted_V)
            # noisy ELBO gradient and ELBO estimate
            g = elbo_gradients(logWs, Gs, L, Xkeys, no_b)
            elbo = torch.mean(torch.stack(logWs))
            if elbo > best_elbo:
                best_Q = Q
                best_elbo = elbo
                print("new best elbo: {}".format(elbo.item()))

            # do an update
            Q = update_Q(Q, g, t + 1, lr)
            if logging:
                wandb.log({"ELBO": elbo.item()})
            if time_based and time.time() - start > time_T:
                break
    finally:
        if logging:
            # close the run so that the next call starts a fresh one
            wandb.finish()

    return best_Q


def update_Q(Q, g, t, lr):
    """
    Take one gradient-ascent step on every proposal in Q.

    The step size is lr / sqrt(t). Each proposal is rebuilt as
    ``type(q)(*new_params, copy=True)`` from its stepped Parameters(); a new
    dict is returned and Q itself is not modified.
    """
    step = lr / torch.sqrt(torch.tensor(1.0 * t))

    def _stepped(q, grads):
        # detached so the new proposal carries no autograd history
        return [(param + step * grad).clone().detach()
                for param, grad in zip(q.Parameters(), grads)]

    return {name: type(q)(*_stepped(q, g[name]), copy=True) for name, q in Q.items()}

    
def elbo_gradients(logWs, Gs, L, Xkeys, no_b):
    """
    Noisy score-function estimate of the ELBO gradient for each latent variable.

    For every v in Xkeys:
        g[v] = (1/L) * sum_l (G_l[v] * logW_l - b * G_l[v])
    G_l[v] is the stacked list of proposal-parameter gradients from sample l.
    b comes from compute_b, or is 0 when no_b is True.

    Args:
        logWs: list of L log importance weights (tensors)
        Gs: list of L dicts mapping variable name -> list of parameter gradients
        L: number of samples
        Xkeys: latent variable names to compute gradients for
        no_b: if True, skip the control variate (b = 0)
    Returns:
        dict mapping each v in Xkeys to its gradient tensor. A sample with no
        gradient entry for v contributes zeros of the same shape as the other
        samples. If no sample has an entry for v, g[v] is tensor(0.).
    """
    g = {}

    for v in Xkeys:
        # stack each sample's list of parameter gradients once
        per_sample = [torch.stack(Gs[l][v]) if v in Gs[l] else None for l in range(L)]
        present = [grad for grad in per_sample if grad is not None]
        if not present:
            # no sample produced a gradient for v: nothing to estimate
            g[v] = torch.tensor(0.)
            continue
        # missing samples must match the present shape so torch.stack works;
        # they contribute exact zeros to both F and G (not 0 * logW, which is
        # nan when logW is -inf)
        zeros = torch.zeros_like(present[0])
        Fv = torch.stack([zeros if grad is None else grad * logWs[l]
                          for l, grad in enumerate(per_sample)])
        Gv = torch.stack([zeros if grad is None else grad for grad in per_sample])
        b = 0 if no_b else compute_b(Fv, Gv)
        g[v] = torch.sum(Fv - b * Gv, dim=0) / L

    return g

def compute_b(F, G):
    """
    Control-variate scale b for the BBVI gradient estimator.

        b = sum_i Cov(F_i, G_i) / sum_i Var(G_i)

    Both covariance and variance use sample (L - 1) normalisation, and i runs
    over every gradient component. Dimension 0 indexes samples; all remaining
    dimensions are flattened. Inputs of any rank are therefore accepted, e.g.
    (L,), (L, d), or (L, 1, k) as produced by stacking a single vector-valued
    parameter gradient.

    Args:
        F: tensor of shape (L, ...), the per-sample G * logW terms
        G: tensor with the same shape as F, the per-sample score gradients
    Returns:
        0-dim tensor b (nan/inf if L < 2 or every G component is constant)
    """
    n = F.shape[0]
    F_flat = F.reshape(n, -1)
    G_flat = G.reshape(n, -1)

    F_centered = F_flat - F_flat.mean(dim=0)
    G_centered = G_flat - G_flat.mean(dim=0)

    num = (F_centered * G_centered).sum() / (n - 1)
    den = G_flat.var(dim=0).sum()  # unbiased, matching torch.std(...)**2

    return num / den


def evaluation(P, Q, Y, sorted_V):
    """
    Draw one importance sample from the proposals and score it.

    Walks the graph in topological order:
      * ``sample*`` node: evaluate its prior given the values drawn so far and
        draw c from the proposal Q[v]. Add log p(c) - log q(c) to the
        log-weight, and record the gradient of log q(c) w.r.t. the parameters
        of a gradient-enabled copy of Q[v].
      * ``observe*`` node: add the prior log-likelihood of Y[v] to the log-weight.

    Args:
        P: dict mapping node name -> [task, expression]
        Q: dict mapping sampled node name -> proposal (project type providing
           make_copy_with_grads / Parameters / sample / log_prob)
        Y: dict mapping observed node name -> observed value
        sorted_V: node names in topological order
    Returns:
        (logW, G). logW is a 0-dim tensor with the total log importance weight.
        Each log-prob term is summed, so vector-valued nodes still give a
        scalar weight and an empty graph gives tensor(0.). G maps each sampled
        node to the list of gradients of log q with respect to Parameters().
    """
    logW = torch.tensor(0.)
    G = {}
    values = {}
    for v in sorted_V:
        task, expr = P[v][0], P[v][1]
        if task == "sample*":
            # prior for v given the values of its ancestors
            d, _ = deterministic_evaluate(expr, values)
            q = Q[v]
            q_with_grad = q.make_copy_with_grads()
            c = q.sample()
            values[v] = c
            with torch.no_grad():
                logW = logW + (d.log_prob(c) - q.log_prob(c)).sum()
            # summed so backward() also works when log q is not a scalar
            q_with_grad.log_prob(c).sum().backward()
            G[v] = [param.grad for param in q_with_grad.Parameters()]
        elif task == "observe*":
            d, _ = deterministic_evaluate(expr, values)
            with torch.no_grad():
                logW = logW + d.log_prob(Y[v]).sum()

    return logW, G


def deterministic_evaluate(e, l, sig=None):
    """
    Recursively evaluate an expression ``e`` of the compiled graph language.

    Args:
        e: a symbol (str), a numeric constant (int/float), or a list whose first
           element selects the form: 'if', 'let', 'sample', 'observe', or
           anything else, which is treated as a procedure call.
        l: dict of local bindings (variable name -> value).
        sig: returned unchanged by the symbol, constant, sample and observe
           branches; no branch reads it. The if/let/procedure-call branches
           recurse without it and therefore return None in its place.
    Returns:
        (value, sig) tuple.
    Raises:
        AssertionError: unknown symbol, a sample/observe distribution whose class
            is not from torch.distributions, or a user-defined procedure called
            with the wrong number of arguments. These are plain ``assert``s, so
            they are skipped under ``python -O``.
    """
    # variable reference OR procedure OR just a string
    if isinstance(e, str):
        # global procedures take precedence over locally defined vars, so a
        # local variable that shares a name with an ENV entry is shadowed
        if e in ENV:
            return ENV[e], sig
        elif e in l:
            return l[e], sig
        # could allow for hashmaps with string keys; for debugging setting this to fail
        else:
            assert False, "Unknown symbol: {}".format(e)
    # constant number (bool also lands here, since bool is an int subclass)
    elif isinstance(e, (int, float)):
        return torch.tensor(float(e)), sig
    # if statements: only the selected branch is evaluated
    elif e[0] == 'if':
        (_, test, conseq, alt) = e
        exp = (conseq if deterministic_evaluate(test, l)[0] else alt)
        return deterministic_evaluate(exp, l)
    # let statements
    elif e[0] == 'let':
        # get symbol
        symbol = e[1][0]
        # get value of e1
        value, _ = deterministic_evaluate(e[1][1], l)
        # evaluate e2 with value bound in a new (copied) local scope
        return deterministic_evaluate(e[2], {**l, symbol: value})
    # sample statement (every branch above returns, so this plain `if` acts like `elif`)
    if e[0] == 'sample':
        dist = deterministic_evaluate(e[1], l)[0]
        # make sure it is a distribution object
        # NOTE(review): the check is on the class's __module__, so classes defined
        # outside torch.distributions (e.g. a project-local subclass) fail it, and
        # an object without __module__ raises AttributeError on .split.
        assert getattr(dist, '__module__', None).split('.')[:2] == ['torch', 'distributions']
        return dist.sample(), sig
    # observe statements
    # TODO: change this, maybe in this hw or for hw3
    if e[0] == 'observe':
        dist = deterministic_evaluate(e[1], l)[0] # get dist
        y = deterministic_evaluate(e[2], l)[0]    # get observed value
        # make sure it is a distribution object
        assert getattr(dist, '__module__', None).split('.')[:2] == ['torch', 'distributions']
        # TODO: do something with observed value
        # currently y is evaluated but ignored, and a fresh sample is returned
        return dist.sample(), sig
    # procedure call, either primitive or user-defined
    else:
        result = deterministic_evaluate(e[0], l)
        proc, sig = result
        # primitives are functions
        if callable(proc):
            args = [deterministic_evaluate(arg, l)[0] for arg in e[1:]]
            result, sig = proc(*args), sig
            return result, sig
        # user defined functions are not: ENV stores them as (params, body) pairs
        else:
            # as written in algorithm 6
            v_is, e0 = proc
            assert(len(v_is) == len(e[1:]))
            c_is = [deterministic_evaluate(arg, l)[0] for arg in e[1:]]
            l_proc = dict(zip(v_is, c_is))
            # parameters shadow caller locals of the same name
            return deterministic_evaluate(e0, {**l, **l_proc})



In [16]:
# First BBVI attempt on program 3, starting from prior-initialised proposals.
# In the recorded run this stopped with a RuntimeError after the first update.
Q3 = bbvi_train(
    g3,
    T=10000,
    L=100,
    base_string="haha",
    time_based=True,
    time_T=10 * 60,
    lr=0.001,
    no_b=True,
    logging=False,
)


Iteration: 0
Xs:
mean of logWX
-4.891833305358887
max of logWX
1.5829049348831177
min of logWX
-33.8731575012207
Ys:
mean of logWY
-110693.1015625
max of logWY
-11.218883514404297
min of logWX
-4585668.0
new best elbo:
tensor(-110697.9922)

Iteration: 1
Xs:
mean of logWX
-2.1422218414075837e+30
max of logWX
-1.924517724232744e+30
min of logWX
-2.4389934850784064e+30
Ys:
mean of logWY
-inf
max of logWY
-467.44989013671875
min of logWX
-inf


RuntimeError: invalid multinomial distribution (encountering probability entry < 0)

In [26]:
# Build the prior-initialised proposals for program 3 and show each one.
Q = init_Q(g3)
for name, proposal in Q.items():
    print(name, proposal, sep="\n")

sample3
Gamma(concentration: 1.0, rate: 0.9999999403953552)
sample1
Gamma(concentration: 1.0, rate: 0.9999999403953552)
sample2
Normal(loc: 0.0, scale: 10.0)
sample4
Normal(loc: 0.0, scale: 10.0)
sample6
Dirichlet(concentration: torch.Size([3]))
sample11
Categorical(logits: torch.Size([3]))
sample13
Categorical(logits: torch.Size([3]))
sample19
Categorical(logits: torch.Size([3]))
sample15
Categorical(logits: torch.Size([3]))
sample9
Categorical(logits: torch.Size([3]))
sample7
Categorical(logits: torch.Size([3]))
sample17
Categorical(logits: torch.Size([3]))
sample0
Normal(loc: 0.0, scale: 10.0)
sample5
Gamma(concentration: 1.0, rate: 0.9999999403953552)


In [27]:
# Replace the three Gamma proposals with Gamma(10, 1) starting points, then
# show the resulting proposal set.
alpha = torch.tensor(10.)
beta = torch.tensor(1.)
for name in ('sample3', 'sample1', 'sample5'):
    Q[name] = distributions.Gamma(alpha, beta)
for name, proposal in Q.items():
    print(name, proposal, sep="\n")

sample3
Gamma(concentration: 10.0, rate: 0.9999999403953552)
sample1
Gamma(concentration: 10.0, rate: 0.9999999403953552)
sample2
Normal(loc: 0.0, scale: 10.0)
sample4
Normal(loc: 0.0, scale: 10.0)
sample6
Dirichlet(concentration: torch.Size([3]))
sample11
Categorical(logits: torch.Size([3]))
sample13
Categorical(logits: torch.Size([3]))
sample19
Categorical(logits: torch.Size([3]))
sample15
Categorical(logits: torch.Size([3]))
sample9
Categorical(logits: torch.Size([3]))
sample7
Categorical(logits: torch.Size([3]))
sample17
Categorical(logits: torch.Size([3]))
sample0
Normal(loc: 0.0, scale: 10.0)
sample5
Gamma(concentration: 10.0, rate: 0.9999999403953552)


In [29]:
# 10-minute program-3 run from the modified proposals Q.
Q3 = bbvi_train(
    g3,
    T=10000,
    L=1000,
    base_string="program-3-workingi-think",
    Q=Q,
    time_based=True,
    time_T=10 * 60,
    lr=0.1,
    no_b=True,
    logging=True,
)

new best elbo: -54.2967414855957
new best elbo: -47.37205505371094
new best elbo: -46.908111572265625
new best elbo: -46.227821350097656
new best elbo: -45.91982650756836
new best elbo: -45.1109619140625
new best elbo: -45.01436996459961
new best elbo: -44.99113845825195
new best elbo: -44.721702575683594
new best elbo: -43.193580627441406
new best elbo: -43.021026611328125
new best elbo: -42.897735595703125
new best elbo: -42.81101608276367
new best elbo: -42.2492790222168
new best elbo: -42.17620849609375


In [30]:
# One-hour program-3 run, again starting from Q (not from the previous Q3).
Q3 = bbvi_train(
    g3,
    T=10000,
    L=1000,
    base_string="program-3-for-real",
    Q=Q,
    time_based=True,
    time_T=60 * 60,
    lr=0.1,
    no_b=True,
    logging=True,
)

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
ELBO,▁▅▆▆▆▆▆▆▆▇▇▇▇▇▆▇▇▇▇▇▇▇▇▇▇▇██▇█▇█████▇█▇█

0,1
ELBO,-43.0001


new best elbo: -54.12262725830078
new best elbo: -49.41806411743164
new best elbo: -48.05881881713867
new best elbo: -47.204193115234375
new best elbo: -47.1398811340332
new best elbo: -46.136531829833984
new best elbo: -45.60818099975586
new best elbo: -45.54535675048828
new best elbo: -45.35262680053711
new best elbo: -45.2760009765625
new best elbo: -44.29511642456055
new best elbo: -43.80609130859375
new best elbo: -43.663204193115234
new best elbo: -43.190738677978516
new best elbo: -43.10811233520508
new best elbo: -42.590362548828125
new best elbo: -42.29479217529297
new best elbo: -42.29214096069336
new best elbo: -42.19688034057617
new best elbo: -41.78074645996094
new best elbo: -41.583656311035156
new best elbo: -41.40888214111328
new best elbo: -40.95470428466797
new best elbo: -40.914371490478516
new best elbo: -40.54847717285156
new best elbo: -40.12153244018555
new best elbo: -40.092140197753906
new best elbo: -40.07582092285156
new best elbo: -39.98765182495117
new best

In [31]:
import pickle

In [32]:
# Persist the best program-3 proposals.
with open("pickles/Q3", 'wb') as q3_file:
    pickle.dump(Q3, q3_file)

# Program 4

In [34]:
# One-hour program-4 run from prior-initialised proposals, with the baseline b enabled.
Q4 = bbvi_train(
    g4,
    T=10000,
    L=500,
    base_string="program-4-first-attempt",
    Q=None,
    time_based=True,
    time_T=60 * 60,
    lr=0.1,
    no_b=False,
    logging=True,
)

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
ELBO,▁▁▂▂▃▃▃▄▄▄▄▄▅▄▅▅▅▅▅▅▅▆▆▆▇▆▆▇▇▇▇▇▇▇▇▇▇█▇█

0,1
ELBO,-36.56314


new best elbo: -494.9900207519531
new best elbo: -478.92364501953125
new best elbo: -474.1729431152344
new best elbo: -473.4201354980469
new best elbo: -468.9661560058594
new best elbo: -467.98541259765625
new best elbo: -465.643798828125
new best elbo: -464.61077880859375
new best elbo: -462.5359191894531
new best elbo: -461.53887939453125
new best elbo: -461.4224853515625
new best elbo: -460.1002502441406
new best elbo: -459.70843505859375
new best elbo: -459.4884033203125
new best elbo: -459.27801513671875
new best elbo: -458.6831359863281
new best elbo: -458.4217529296875
new best elbo: -457.86798095703125
new best elbo: -457.08087158203125
new best elbo: -456.9406433105469
new best elbo: -456.7578125
new best elbo: -456.1025695800781
new best elbo: -455.6830749511719
new best elbo: -455.2585754394531
new best elbo: -455.1347351074219
new best elbo: -455.0375061035156
new best elbo: -454.815185546875
new best elbo: -454.416259765625
new best elbo: -454.380615234375
new best elbo: -

In [35]:
# Persist the best program-4 proposals.
with open("pickles/Q4_100", 'wb') as q4_file:
    pickle.dump(Q4, q4_file)