In [1]:
import numpy as np
import torch
import pyro
from pyro import distributions as dst
import matplotlib.pylab as plt

In [2]:
import inspect, ast, astor

In [3]:
def model(X):
    N, D = X.shape
    with pyro.plate('D', D):
        loc = pyro.sample('loc', dst.Normal(0.,10.))
        scale = pyro.sample('scale', dst.LogNormal(0.,4.))
        with pyro.plate('N', N):
            X = pyro.sample('obs', dst.Normal(loc,scale))
    return X

In [4]:
def model2(X):
    N, D = X.shape
    K = D
    locloc = 0.
    locscale = 1.
    scaleloc = 0.
    scalescale = 1.
    cov_factor_loc = torch.zeros(K,D)
    cov_factor_scale = torch.ones(K,D)*10
    with pyro.plate('D', D):
        loc = pyro.sample('loc', dst.Normal(locloc, locscale))
        cov_diag = pyro.sample('scale', dst.LogNormal(scaleloc, scalescale))
        with pyro.plate('K', K):
            cov_factor = pyro.sample('cov_factor', dst.Normal(cov_factor_loc,cov_factor_scale))
        cov_factor = cov_factor.transpose(0,1)
    with pyro.plate('N', N):
        X = pyro.sample('obs', dst.LowRankMultivariateNormal(loc, cov_factor=cov_factor, cov_diag=cov_diag))
    return X

In [5]:
class AddToPlate(ast.NodeTransformer):
    """
    Identifies plate with specified name, and appends specified code to its body. 
    If code is None, deletes plate and assigns to self.plate
    """
    def __init__(self, plate_name, code):
        self.plate_name = plate_name
        self.code = code
        super().__init__()
    def visit_With(self, node):
        # we want to visit child nodes, so visit it
        self.generic_visit(node)
        withexpr = node.items[0].context_expr
        if withexpr.func.attr == 'plate' and withexpr.args[0].s == self.plate_name:
            self.plate = node
            if self.code is None:
                return
            else:
                newnode = node
                newnode.body.append(self.code)
                return newnode
        else:
            return node

In [6]:
class AddToFunctionBody(ast.NodeTransformer):
    """
    Adds code to either the beginning (default) or end of a function
    """
    def __init__(self, code, head=True):
        self.code = code
        self.head = head
        super().__init__()
    def visit_FunctionDef(self, node):
        self.generic_visit(node)
        if self.head:
            #insert the whole code statement first in the body of the function
            node.body = [self.code] + node.body
        else:
            #append the whole code statement last in the body of the function
            node.body.append(self.code)
            # switch code and return statement so it's last
            node.body[-1], node.body[-2] = node.body[-2], node.body[-1]
        return node

In [7]:
#class AddPlateUnderPlate(ast.NodeTransformer):
#    def __init__(self, plate, plate_name):
#        self.plate = plate # plate to add
#        self.plate_name = plate_name # plate under which to add
#        super().__init__()
#    def visit_With(self, node):
#        self.generic_visit(node)
#        withexpr = node.items[0].context_expr
#        if withexpr.func.attr == 'plate' and withexpr.args[0].s == self.plate_name:
#            node.body.append(self.plate)
#        else:
#            return node
        
#def cutpasteplate(tree, plate_name):
#    """
#    Cut a plate from any nesting level and paste it in the top level of the function,
#    e.g. when transforming an independent Gaussian to a factor model, where features are no longer independent
#    """
#    cut = ModifyPlate(plate_name, insert_code=None) # cut plate by adding nothing under plane
#    cut.visit(tree)
#    paste = AddPlateTopLevel(cut.plate)
#    paste.visit(tree)
    

In [8]:
class ChangeObservationModel(ast.NodeTransformer):
    """
    Identifies sampling site with name 'obs', and replaces its distribution with a specified one
    """
    def __init__(self, new_obs_model):
        self.new_obs_model = new_obs_model
        super().__init__()
    def visit_Call(self, node):
        # we want to visit child nodes, so visit it
        self.generic_visit(node)
        if node.func.attr == 'sample' and node.args[0].s == 'obs':
            newnode = node
            newnode.args[1] = self.new_obs_model
            return newnode
        else:
            return node
        
def change_observation_model_to_LowRankMultivariateNormal(tree):
    lowrank_normal_obs_model = ast.Call(func=ast.Attribute(value=ast.Name(id='dst', ctx=ast.Load()), attr='LowRankMultivariateNormal', ctx=ast.Load()),\
                                args=[ast.Name(id='loc', ctx=ast.Load())],\
                                keywords=[ast.keyword(arg='cov_factor', value=ast.Name(id='cov_factor', ctx=ast.Load())),\
                                ast.keyword(arg='cov_diag', value=ast.Name(id='cov_diag', ctx=ast.Load()))]) 
    ChangeObservationModel(lowrank_normal_obs_model).visit(tree)

In [9]:
# the following functions modify observation models

def addFactor(model):
    """
    Given an independent Gaussian, changes it to a LowRankMultivariateNormal with one factor
    Given a lowrank Gaussian, adds a factor
    """
    # read model code
    source = inspect.getsource(model)
    # parse into abstract syntax tree
    tree = ast.parse(source)
    
    # check if model is independent Gaussian or factor
    # if independent Gaussian:
    
    # move the N plate out from the D plate to top level
    cut = AddToPlate('N', code=None) # delete observation plate, assign to cut.plate
    cut.visit(tree)
    paste = AddToFunctionBody(code=cut.plate, head=False) # add plate at the bottom of the function
    paste.visit(tree)
    
    # replace Normal observation model with LowRankMultivariateNormal
    change_observation_model_to_LowRankMultivariateNormal(tree)
    
    # under D plate, add K plate, add cov_factor under K plate
    cov_factor = ast.Assign(targets=[ast.Name(id='cov_factor', ctx=ast.Store())],\
                                    value=ast.Call(func=ast.Attribute(value=ast.Name(id='pyro', ctx=ast.Load()), attr='sample', ctx=ast.Load()),\
                                        args=[ast.Str(s='cov_factor'),\
                                            ast.Call(func=ast.Attribute(value=ast.Name(id='dst', ctx=ast.Load()), attr='Normal', ctx=ast.Load()),\
                                                args=[ast.Name(id='cov_factor_loc', ctx=ast.Load()), ast.Name(id='cov_factor_scale', ctx=ast.Load())],\
                                                keywords=[])],\
                                        keywords=[]))
    AddToPlate('D', code=cov_factor).visit(tree)
    
    # under D plate, transpose it
    
    # add locloc, locscale, scaleloc, scalescale, cov_factor_loc and cov_factor_scale
    cov_factor_loc = ast.Assign(targets=[ast.Name(id='cov_factor_scale', ctx=ast.Store())],
                    value=ast.BinOp(
                        left=ast.Call(func=ast.Attribute(value=ast.Name(id='torch', ctx=ast.Load()), attr='ones'),
                            args=[ast.Name(id='K', ctx=ast.Load()), ast.Name(id='D', ctx=ast.Load())],
                            keywords=[]),
                        op=ast.Mult,
                        right=ast.Num(n=10))),

    AddToFunctionBody(code = cov_factor_scale).visit(tree)
    AddToFunctionBody(code = cov_factor_loc).visit(tree)
    AddToFunctionBody(code = ast.Assign(targets=[ast.Name(id='locloc', ctx=ast.Store())], value=ast.Num(n=0.0))).visit(tree)
    AddToFunctionBody(code = ast.Assign(targets=[ast.Name(id='locloc', ctx=ast.Store())], value=ast.Num(n=0.0))).visit(tree)
    AddToFunctionBody(code = ast.Assign(targets=[ast.Name(id='locloc', ctx=ast.Store())], value=ast.Num(n=0.0))).visit(tree)
    # add K = D
    ast.fix_missing_locations(tree)
    
    #print(astor.dump_tree(tree))
    print(astor.to_source(tree))
    code = compile(tree, '', 'exec')
    context = {}
    exec(code, globals(), context)
    return context['model']

def removeFactor(model):
    return model

def splitCluster(model):
    """
    Given an independent Gaussian, changes it to a mixture of two independent Gaussians
    Given a lowrank Gaussian, changes it to a mixture of two lowrank Gaussians with shared covariance
    """
    return model
    
def mergeCluster(model):
    return model
    
def scaleCovariances(model):
    """
    Given a mixture with shared covariances, adds a scaling parameter to each component covariance
    """
    return model
    
def decoupleCovariances(model):
    """
    Given a mixture with shared covariances, endows each component with its own covariance
    """
    return model
    
def shareCovariances(model):
    return model

def addLayer(model):
    return model

In [10]:
#source = inspect.getsource(model)
#tree = ast.parse(source)
print(astor.to_source(ast.parse(inspect.getsource(model))))

def model(X):
    N, D = X.shape
    with pyro.plate('D', D):
        loc = pyro.sample('loc', dst.Normal(0.0, 10.0))
        scale = pyro.sample('scale', dst.LogNormal(0.0, 4.0))
        with pyro.plate('N', N):
            X = pyro.sample('obs', dst.Normal(loc, scale))
    return X



In [11]:
newmodel = addFactor(model)

NameError: name 'cov_factor_scale' is not defined

In [None]:
newmodel(np.random.randn(1000,2))

In [None]:
# because newmodel was built interactively, there's no source code anywhere to read
inspect.getsource(newmodel)

In [None]:
model(np.random.randn(1000,2))

In [None]:
print(astor.to_source(ast.parse(inspect.getsource(newmodel))))

In [None]:
class NumberFinder(ast.NodeVisitor):
    def visit_Num(self, node):
        print("Found number literal", node.n)

def find_numbers(tree):
    NumberFinder().visit(tree)

In [None]:
find_numbers(tree)

In [None]:
class FuncFinder(ast.NodeVisitor):
    def visit_Call(self, node):
        print("Found function", astor.dump_tree(node.func), '\nwith arguments \n', astor.dump_tree(node.args), "\n")

def find_func(tree):
    FuncFinder().visit(tree)

In [None]:
find_func(tree)

In [None]:
source2 = inspect.getsource(model2)

In [None]:
tree2 = ast.parse(source2)

In [None]:
print(astor.dump(tree2))

In [None]:
find_func(tree2)

In [None]:
print(astor.dump_tree(tree))

In [None]:
print(astor.to_source(tree))

In [None]:
class SampleFinder(ast.NodeVisitor):
    def visit_Call(self, node):
        if node.func.attr == 'sample':
            print(astor.dump_tree(astor.to_source(node.args[0])), '\nwith arguments \n', astor.dump_tree(node.args), "\n")

def find_sample(tree):
    SampleFinder().visit(tree)

In [None]:
plt.scatter(*model(np.random.randn(1000,2)).detach().numpy().T, alpha=.1);

In [None]:
plt.scatter(*model2(np.random.randn(1000,2)).detach().numpy().T, alpha=.1)