In [4]:
# Primitive functions
import torch
import torch.distributions as dist

class Normal(dist.Normal):
    """Normal distribution whose scale is optimised in an unconstrained space.

    The instance keeps ``optim_scale``, a leaf tensor that requires grad, and
    derives the actual (positive) scale as ``softplus(optim_scale)``.

    Args:
        alpha: address value passed as the first argument; it is stored on the
            instance so that copies can be rebuilt with the same address.
        loc: mean of the distribution (tensor).
        scale: positive standard deviation (tensor, scalar or multi-element).
    """

    def __init__(self, alpha, loc, scale):
        # Kept so make_copy_with_grads can call the constructor again.
        self.alpha = alpha

        # Inverse softplus, element-wise. Above 20 torch's softplus (default
        # threshold) is the identity, so the value is stored unchanged there.
        # torch.where also makes multi-element scales work, and expm1 keeps
        # precision for very small scales.
        unconstrained = torch.where(
            scale > 20.,
            scale,
            torch.log(torch.expm1(scale)),
        )
        self.optim_scale = unconstrained.clone().detach().requires_grad_()

        super().__init__(loc, torch.nn.functional.softplus(self.optim_scale))

    def Parameters(self):
        """Return a list of parameters for the distribution"""
        return [self.loc, self.optim_scale]

    def make_copy_with_grads(self):
        """
        Return a copy of the distribution, with parameters that require_grad.

        The copy is rebuilt from the stored address, a detached clone of
        ``loc`` and the current *constrained* scale, because ``__init__``
        expects a real scale and applies the inverse softplus itself.
        """
        loc = self.loc.clone().detach().requires_grad_()
        scale = torch.nn.functional.softplus(self.optim_scale).clone().detach()

        return Normal(self.alpha, loc, scale)

    def log_prob(self, x):
        """Return the log-density of ``x`` using the current ``optim_scale``."""
        # Refresh the scale so changes made to optim_scale since construction
        # are reflected in the density.
        self.scale = torch.nn.functional.softplus(self.optim_scale)

        return super().log_prob(x)

def push_addr(alpha, value):
    """Return ``alpha`` extended by ``value`` using the ``+`` operator."""
    extended = alpha + value
    return extended

def vector(*arg):
    """Build a vector from the given elements.

    Returns an empty tensor for no arguments, a stacked tensor when
    ``torch.stack`` accepts the elements, a concatenation along the last axis
    of the first element when stacking raises ``RuntimeError``, and a plain
    list when stacking raises ``TypeError``.
    """
    if not arg:
        return torch.tensor([])

    try:
        stacked = torch.stack(arg, dim=0)
    except RuntimeError:
        # stacking failed: join the tensors end to end instead
        last_axis = arg[0].dim() - 1
        return torch.cat(arg, dim=last_axis)
    except TypeError:
        # elements that torch.stack rejects (e.g. distribution objects)
        return list(arg)
    else:
        return stacked

def get(v, i):
    """Return the element of ``v`` addressed by ``i``.

    Args:
        v: a container supporting ``v[key]`` (tensor, list or dict).
        i: a string key, a one-element tensor, or a plain number.

    Returns:
        ``v[i]`` for string keys. For numeric keys, the exact key is used when
        ``v`` is a dict that contains it (hashmap stores ``.item()`` values,
        which may be non-integer floats); otherwise the key is truncated to an
        int and used as an index.
    """
    if isinstance(i, str):
        return v[i]

    # Accept tensors (and other objects exposing .item()) as well as plain numbers.
    key = i.item() if hasattr(i, 'item') else i

    # Truncating first would make key 1.5 silently read the entry for key 1.
    if isinstance(v, dict) and key in v:
        return v[key]
    return v[int(key)]

def put(v, i, c):
    """Store ``c`` in ``v`` at position/key ``i`` and return ``v``.

    ``v`` is modified in place.

    Args:
        v: a container supporting item assignment (tensor, list or dict).
        i: a string key, a one-element tensor, or a plain number.
        c: the value to store.

    Returns:
        The same object ``v`` after the assignment.
    """
    if isinstance(i, str):
        v[i] = c
        return v

    # Accept tensors (and other objects exposing .item()) as well as plain numbers.
    key = i.item() if hasattr(i, 'item') else i

    if isinstance(v, dict):
        # Keep the exact numeric key, as hashmap does; truncating would write
        # key 1.5 into the entry for key 1.
        v[key] = c
    else:
        v[int(key)] = c
    return v

def first(v):
    """Return the element of ``v`` at index 0."""
    head = v[0]
    return head

def second(v):
    """Return the element of ``v`` at index 1."""
    item = v[1]
    return item

def last(v):
    """Return the element of ``v`` at index -1."""
    tail = v[-1]
    return tail

def append(v, c):
    """Return a new collection consisting of ``v`` followed by ``c``.

    Args:
        v: a tensor, or a Python list (``vector``, ``l`` and ``cons`` can all
            return lists).
        c: the element to add; for tensor ``v`` it is given a new leading
            axis before concatenation.

    Returns:
        A new list when ``v`` is a list (``v`` is not modified), otherwise the
        concatenation of ``v`` and ``c.unsqueeze(0)`` along axis 0.
    """
    if isinstance(v, list):
        # Lists cannot go through torch.cat; build a new list instead.
        return v + [c]
    return torch.cat((v, c.unsqueeze(dim=0)), dim=0)

def hashmap(*v):
    """Build a dict from alternating key/value arguments.

    String keys are used as they are; any other key is converted with its
    ``.item()`` method before being stored.
    """
    table = {}
    for pos in range(0, len(v), 2):
        key, val = v[pos], v[pos + 1]
        table[key if type(key) is str else key.item()] = val
    return table

def less_than(*args):
    """Return ``args[0] < args[1]``; any further arguments are ignored."""
    lhs = args[0]
    rhs = args[1]
    return lhs < rhs

def rest(v):
    """Return ``v`` without its first element (the slice ``v[1:]``)."""
    remainder = v[1:]
    return remainder

def l(*arg):
    """Return the positional arguments as a new Python list."""
    return [*arg]

def cons(x, l):
    """Return a new list with ``x`` placed in front of the elements of ``l``."""
    head = [x]
    return head + l

def equal(x, y):
    """Compare the ``.item()`` values of ``x`` and ``y``; return a boolean tensor."""
    lhs = x.item()
    rhs = y.item()
    return torch.tensor(lhs == rhs)

def and_fn(x, y):
    """Python ``and`` of two operands: ``y`` when ``x`` is truthy, else ``x``."""
    return y if x else x

def or_fn(x, y):
    """Python ``or`` of two operands: ``x`` when ``x`` is truthy, else ``y``."""
    return x if x else y

def dirac(x):
    """Return a Normal centred at ``x`` with standard deviation 0.001.

    This stands in for a point mass at ``x``: a very narrow normal is used
    instead of a true Dirac distribution.
    """
    narrow = torch.distributions.Normal(loc=x, scale=0.001)
    return narrow

def greater_than(x, y):
    """Return the result of ``x > y``."""
    outcome = x > y
    return outcome

def empty(v):
    """Return True when ``len(v)`` is zero, False otherwise."""
    return not len(v)

def peek(v):
    """Return the element of ``v`` at index -1."""
    top = v[-1]
    return top

# Lookup table from primitive names to the Python functions defined above.
# standard_env() reads individual entries from this table by key.
# Note that "conj" and "append" share one implementation, and that the
# less-than primitive is stored under "less_than" while the greater-than
# primitive is stored under ">".
funcprimitives = {
    "vector": vector,
    "get": get,
    "put": put,
    "first": first,
    "last": last,
    "append": append,
    "hash-map": hashmap,
    "less_than": less_than,
    "second": second,
    "rest": rest,
    "conj": append,
    "list": l,
    "cons": cons,
    "=": equal,
    "and": and_fn,
    "or": or_fn,
    "dirac": dirac,
    ">": greater_than,
    "empty?": empty,
    "peek": peek,
}

NameError: name 'dist' is not defined

In [2]:
class Env(dict):
    """A variable scope: a dict of name -> value bindings linked to an outer Env."""

    def __init__(self, parms=(), args=(), outer=None):
        # Pair each parameter name with the argument at the same position.
        for name, value in zip(parms, args):
            self[name] = value
        self.outer = outer

    def find(self, var):
        """Return the innermost Env in the chain that binds ``var``.

        The search delegates to ``self.outer`` when ``var`` is not bound here;
        with ``outer`` set to None that attribute access raises AttributeError.
        """
        if var in self:
            return self
        return self.outer.find(var)

class Procedure(object):
    """A user-defined procedure: parameter names, a body expression and an Env.

    Args:
        parms: sequence of parameter names.
        body: the expression evaluated when the procedure is called.
        env: the Env used as the outer scope of every call.
    """

    def __init__(self, parms, body, env):
        self.parms = parms
        self.body = body
        self.env = env

    def __call__(self, *args):
        """Evaluate the body with ``parms`` bound to ``args``.

        The bindings live in a fresh Env whose outer scope is ``self.env``.
        The debug print of the body that ran on every call has been removed.
        """
        call_env = Env(self.parms, args, self.env)
        return evaluate_helper(self.body, call_env)

In [3]:
def standard_env() -> Env:
    """Create the base environment of built-in names.

    The environment binds 'alpha' to the empty string, then the distribution
    constructors, torch arithmetic/matrix operations and the entries taken
    from ``funcprimitives``.
    """
    prim = funcprimitives

    bindings = {
        # distributions and scalar arithmetic
        'normal': dist.Normal,
        'sqrt': torch.sqrt,
        '+': torch.add,
        '-': torch.sub,
        '*': torch.mul,
        '/': torch.div,
        'beta': dist.Beta,
        'gamma': dist.Gamma,
        'dirichlet': dist.Dirichlet,
        'exponential': dist.Exponential,
        'discrete': dist.Categorical,
        'uniform': dist.Uniform,
        'uniform-continuous': dist.Uniform,
        'flip': dist.Bernoulli,
        # container primitives
        'vector': prim["vector"],
        'get': prim["get"],
        'put': prim["put"],
        'hash-map': prim["hash-map"],
        'first': prim["first"],
        'second': prim["second"],
        'last': prim["last"],
        'append': prim["append"],
        'conj': prim["conj"],
        'cons': prim["cons"],
        'list': prim["list"],
        '<': prim["less_than"],
        # matrix operations
        'mat-mul': torch.matmul,
        'mat-repmat': lambda x, y, z: x.repeat((int(y.item()), int(z.item()))),
        'mat-add': torch.add,
        'mat-tanh': torch.tanh,
        'mat-transpose': torch.t,
        # remaining primitives
        'rest': prim["rest"],
        '=': prim["="],
        '>': prim[">"],
        'empty?': prim["empty?"],
        'log': torch.log,
        'peek': prim['peek'],
    }

    env = Env()
    env['alpha'] = ''
    env.update(bindings)
    return env

# Environment built once when this cell runs; evaluate_helper uses it as the
# default value of its `env` parameter.
global_env = standard_env()

NameError: name 'dist' is not defined

In [143]:
def evaluate_helper(x, env=global_env):
    """Evaluate an expression in an environment.

    Args:
        x: the expression. A string (variable reference), an int/float
            constant, a tensor, or a sequence ``[op, *args]``.
        env: the Env in which names are looked up; defaults to ``global_env``.

    Returns:
        - string: the bound value, or the string itself when no scope binds it
        - int/float: a float tensor holding the value
        - tensor: the tensor unchanged
        - ``if``: the value of the selected branch
        - ``fn``: a Procedure closing over a fresh Env whose outer is ``env``
        - ``sample``: a draw from the distribution given by the first argument
        - ``observe``: the value of the last argument
        - anything else: the result of calling the evaluated operator on the
          evaluated arguments
    """
    if type(x) is str and x != 'fn':    # variable reference
        try:
            return env.find(x)[x]
        except AttributeError:
            # Env.find runs off the outermost scope (outer is None) when the
            # name is unbound; such strings evaluate to themselves.
            return x

    elif type(x) in [int, float]:       # constant
        return torch.tensor(float(x))

    elif type(x) is torch.Tensor:
        return x

    op, *args = x

    if op == 'if':                      # conditional
        (test, conseq, alt) = args
        exp = (conseq if evaluate_helper(test, env) else alt)
        return evaluate_helper(exp, env)

    elif op == 'fn':                    # procedure
        (parms, body) = args

        env_inner = Env(outer=env)
        return Procedure(parms, body, env_inner)

    elif op == 'sample':
        # NOTE(review): the distribution is read from the first argument; if
        # sample expressions carry an address before the distribution this
        # needs to read a later argument instead. Confirm against the
        # desugared program format.
        d = evaluate_helper(args[0], env)
        return d.sample()

    elif op == 'observe':
        # Evaluate in the current env; without it the default global_env was
        # used and locally bound variables could not be resolved.
        return evaluate_helper(args[-1], env)

    else:                               # procedure call
        proc = evaluate_helper(op, env)
        vals = [evaluate_helper(arg, env) for arg in args]
        return proc(*vals)

In [None]:
# Evaluating
# `i` selects which program file is loaded: programs/hw5_<i>.daphne.
i = 1
# NOTE(review): `daphne` is not defined in the cells shown here; presumably it
# is a helper that runs the daphne compiler with these arguments and returns
# the desugared program. Confirm where it is imported or defined.
exp = daphne(['desugar-hoppl', '-i', 'programs/hw5_{}.daphne'.format(i)])