In [35]:
import torch
import math
import operator as op

def pure_hashmap_update(d, k, v):
    """Return a copy of hash-map ``d`` with key ``k`` bound to ``v``.

    ``d`` itself is never modified (functional update). A tensor key is
    unwrapped with ``.item()`` so it matches the plain-scalar keys used by
    the ``hash-map`` primitive.
    """
    key = k.item() if isinstance(k, torch.Tensor) else k
    updated = d.copy()
    updated[key] = v
    return updated
    

# inspired by https://norvig.com/lispy.html
def eval_env():
    """Build the table of FOPPL primitive procedures.

    Returns:
        A new dict mapping each primitive's name, as it appears in the
        desugared program, to a Python callable. Vector primitives work on
        torch tensors (indices are tensors converted with ``.long()``);
        hash-map primitives work on plain dicts and always return a new dict
        instead of mutating their input.
    """

    def _key(k):
        # torch tensors hash by identity, not by value, so hash-map keys are
        # stored -- and looked up / removed -- as plain Python scalars.
        return k.item() if isinstance(k, torch.Tensor) else k

    def _vector(*items):
        try:
            return torch.tensor(items)
        except (ValueError, RuntimeError):
            # torch.tensor() can reject a sequence of multi-element tensors
            # (a vector of vectors); stack them into a higher-rank tensor.
            return torch.stack(items)

    def _hash_map(*items):
        # arguments alternate: key, value, key, value, ...
        return dict(zip([_key(k) for k in items[::2]], items[1::2]))

    def _get(x, y):
        if isinstance(x, torch.Tensor):
            return x[y.long()]
        return x[_key(y)]

    def _put(x, y, z):
        if isinstance(x, torch.Tensor):
            i = y.long()
            return torch.cat((x[:i], torch.tensor([z]), x[i + 1:]))
        return pure_hashmap_update(x, y, z)

    def _remove(x, y):
        if isinstance(x, torch.Tensor):
            i = y.long()
            return torch.cat((x[:i], x[i + 1:]))
        target = _key(y)
        return {k: v for k, v in x.items() if k != target}

    return {
        '+': op.add,
        '-': op.sub,
        '*': op.mul,
        '/': op.truediv,
        '>': op.gt,
        '<': op.lt,
        '>=': op.ge,
        '<=': op.le,
        '=': op.eq,
        'sqrt': torch.sqrt,
        'vector': _vector,
        'hash-map': _hash_map,
        'get': _get,
        'put': _put,
        'append': lambda x, y: torch.cat((x, torch.tensor([y]))),
        'first': lambda x: x[0],
        'last': lambda x: x[-1],
        'remove': _remove,
    }

In [36]:
ENV = None

def evaluate_program(ast, return_sig=False):
    """Evaluate a program as desugared by daphne, generate a sample from the prior

    Args:
        ast: json FOPPL program
        return_sig: if True, return the ``(value, sig)`` pair produced by
            ``evaluate`` instead of the value alone.
    Returns: sample from the prior of ast, or ``(sample, sig)`` when
        ``return_sig`` is True.
    """
    global ENV
    # Build a fresh table of primitives for this run.
    ENV = eval_env()
    # TODO: deal with `defn` forms -- they are not handled yet.
    local_env = {}
    # `evaluate` takes the local bindings as its second argument; calling it
    # without them raised TypeError on a fresh kernel.
    ret, sig = evaluate(ast, local_env)
    return (ret, sig) if return_sig else ret

# inspired by https://norvig.com/lispy.html
def evaluate(e, l=None, sig=None):
    """Recursively evaluate a desugared FOPPL expression.

    Args:
        e: the expression -- a symbol (str), a numeric constant, or a list
            whose head is the special form ``'if'`` or a procedure.
        l: local bindings, consulted before the global ``ENV``; defaults to
            an empty dict.
        sig: side information threaded through every recursive call and
            returned; no rule in this evaluator modifies it yet.
    Returns:
        ``(value, sig)``.
    """
    if l is None:
        l = {}
    # variable reference: local bindings shadow global primitives
    if isinstance(e, str):
        if e in l:
            return l[e], sig
        return ENV[e], sig
    # constant number
    if isinstance(e, (int, float)):
        return torch.tensor(float(e)), sig
    # root of tree: a single wrapped expression
    # TODO: account for defn forms; note this branch also treats a nullary
    # call such as ['vector'] as a bare symbol.
    if isinstance(e, list) and len(e) == 1:
        return evaluate(e[0], l, sig)
    # if statements: only the selected branch is evaluated
    if e[0] == 'if':
        _, test, conseq, alt = e
        cond, sig = evaluate(test, l, sig)
        return evaluate(conseq if cond else alt, l, sig)
    # procedure call: evaluate head and arguments left to right
    proc, sig = evaluate(e[0], l, sig)
    args = []
    for arg in e[1:]:
        value, sig = evaluate(arg, l, sig)
        args.append(value)
    return proc(*args), sig

In [37]:
result = evaluate_program([['+', 5, 2]])
assert torch.equal(result, torch.tensor(7.))
result

tensor(7.)

In [38]:
result = evaluate_program([['sqrt', 2]])
assert torch.allclose(result, torch.tensor(math.sqrt(2)))
result

tensor(1.4142)

In [39]:
result = evaluate_program([['*', 3.0, 8.0]])
assert torch.equal(result, torch.tensor(24.))
result

tensor(24.)

In [40]:
result = evaluate_program([['/', 2, 8]])
assert torch.equal(result, torch.tensor(0.25))
result

tensor(0.2500)

In [41]:
result = evaluate_program([['/', 2, ['+', 3, ['*', 3, 2.7]]]])
assert torch.allclose(result, torch.tensor(2 / (3 + 3 * 2.7)))
result

tensor(0.1802)

In [42]:
result = evaluate_program([['vector', 2, 3, 4, 5]])
assert torch.equal(result, torch.tensor([2., 3., 4., 5.]))
result

tensor([2., 3., 4., 5.])

In [43]:
result = evaluate_program([['get', ['vector', 2, 3, 4, 5], 2]])
assert torch.equal(result, torch.tensor(4.))
result

tensor(4.)

In [44]:
result = evaluate_program([['put', ['vector', 2, 3, 4, 5], 2, 3]])
assert torch.equal(result, torch.tensor([2., 3., 3., 5.]))
result

tensor([2., 3., 3., 5.])

In [45]:
result = evaluate_program([['first', ['vector', 2, 3, 4, 5]]])
assert torch.equal(result, torch.tensor(2.))
result

tensor(2.)

In [46]:
result = evaluate_program([['last', ['vector', 2, 3, 4, 5]]])
assert torch.equal(result, torch.tensor(5.))
result

tensor(5.)

In [47]:
result = evaluate_program([['append', ['vector', 2, 3, 4, 5], 3.14]])
assert torch.allclose(result, torch.tensor([2., 3., 4., 5., 3.14]))
result

tensor([2.0000, 3.0000, 4.0000, 5.0000, 3.1400])

In [48]:
result = evaluate_program([['get', ['hash-map', 6, 5.3, 1, 3.2], 6]])
assert torch.allclose(result, torch.tensor(5.3))
result

tensor(5.3000)

In [49]:
result = evaluate_program([['put', ['hash-map', 6, 5.3, 1, 3.2], 6, 2]])
assert list(result) == [6.0, 1.0]
assert torch.equal(result[6.0], torch.tensor(2.))
result

{6.0: tensor(2.), 1.0: tensor(3.2000)}

In [50]:
result = evaluate_program([['remove', ['vector', 2, 3, 4, 5], 3]])
assert torch.equal(result, torch.tensor([2., 3., 4.]))
result

tensor([2., 3., 4.])

In [51]:
result = evaluate_program([['if', ['>', 1, 0], ['*', 2, 2], ['/', 2, 2]]])
assert torch.equal(result, torch.tensor(4.))
result

tensor(4.)

In [52]:
result = evaluate_program([['if', ['<', 1, 0], ['*', 2, 2], ['/', 2, 2]]])
assert torch.equal(result, torch.tensor(1.))
result

tensor(1.)

In [53]:
result = evaluate_program([['remove', ['hash-map', 6, 5.3, 1, 3.2], 6]])
assert list(result) == [1.0]
assert torch.allclose(result[1.0], torch.tensor(3.2))
result

{1.0: tensor(3.2000)}