In [1]:
import numpy as np
import torch
import pyro
from pyro import distributions as dst
import matplotlib.pylab as plt

In [2]:
import inspect, ast, astor

In [3]:
def model(X):
    N, D = X.shape
    with pyro.plate('D', D):
        loc = pyro.sample('loc', dst.Normal(0.,10.))
        scale = pyro.sample('scale', dst.LogNormal(0.,4.))
        with pyro.plate('N', N):
            X = pyro.sample('obs', dst.Normal(loc,scale))
    return X

In [4]:
def model2(X):
    N, D = X.shape
    K = D
    locloc = 0.
    locscale = 1.
    scaleloc = 0.
    scalescale = 1.
    cov_factor_loc = torch.zeros(K,D)
    cov_factor_scale = torch.ones(K,D)*10
    with pyro.plate('D', D):
        loc = pyro.sample('loc', dst.Normal(locloc, locscale))
        cov_diag = pyro.sample('scale', dst.LogNormal(scaleloc, scalescale))
        with pyro.plate('K', K):
            cov_factor = pyro.sample('cov_factor', dst.Normal(cov_factor_loc,cov_factor_scale))
        cov_factor = cov_factor.transpose(0,1)
    with pyro.plate('N', N):
        X = pyro.sample('obs', dst.LowRankMultivariateNormal(loc, cov_factor=cov_factor, cov_diag=cov_diag))
    return X

In [5]:
# move the N plate out from the D plate to top level
# replace Normal with LowRankMultivariateNormal
# under D plate, add K plate, add cov_factor under K plate
# under D plate, transpose it
# add locloc, locscale, scaleloc, scalescale, cov_factor_loc and cov_factor_scale
# add K = D

In [6]:
def addRank(model):
    """
    Given an independent Gaussian model, changes it to a LowRankMultivariateNormal
    """
    source = inspect.getsource(model)
    tree = ast.parse(source)
    
    return newmodel

In [7]:
source = inspect.getsource(model)

In [8]:
source

"def model(X):\n    N, D = X.shape\n    with pyro.plate('D', D):\n        loc = pyro.sample('loc', dst.Normal(0.,10.))\n        scale = pyro.sample('scale', dst.LogNormal(0.,4.))\n        with pyro.plate('N', N):\n            X = pyro.sample('obs', dst.Normal(loc,scale))\n    return X\n"

In [9]:
tree = ast.parse(source)

In [10]:
tree

<_ast.Module at 0x7fb74a60f2b0>

In [11]:
tree.body

[<_ast.FunctionDef at 0x7fb74a60f4a8>]

In [12]:
print(astor.to_source(tree))

def model(X):
    N, D = X.shape
    with pyro.plate('D', D):
        loc = pyro.sample('loc', dst.Normal(0.0, 10.0))
        scale = pyro.sample('scale', dst.LogNormal(0.0, 4.0))
        with pyro.plate('N', N):
            X = pyro.sample('obs', dst.Normal(loc, scale))
    return X



In [13]:
model.__code__

<code object model at 0x7fb7b542a660, file "<ipython-input-3-f9372a3a3c8f>", line 1>

In [14]:
class NumberFinder(ast.NodeVisitor):
    def visit_Num(self, node):
        print("Found number literal", node.n)

def find_numbers(tree):
    NumberFinder().visit(tree)

In [15]:
find_numbers(tree)

Found number literal 0.0
Found number literal 10.0
Found number literal 0.0
Found number literal 4.0


In [16]:
class FuncFinder(ast.NodeVisitor):
    def visit_Call(self, node):
        print("Found function", astor.dump_tree(node.func), '\nwith arguments \n', astor.dump_tree(node.args), "\n")

def find_func(tree):
    FuncFinder().visit(tree)

In [17]:
find_func(tree)

Found function Attribute(value=Name(id='pyro'), attr='plate') 
with arguments 
 [Str(s='D'), Name(id='D')] 

Found function Attribute(value=Name(id='pyro'), attr='sample') 
with arguments 
 [Str(s='loc'), Call(func=Attribute(value=Name(id='dst'), attr='Normal'), args=[Num(n=0.0), Num(n=10.0)], keywords=[])] 

Found function Attribute(value=Name(id='pyro'), attr='sample') 
with arguments 
 [Str(s='scale'),
    Call(func=Attribute(value=Name(id='dst'), attr='LogNormal'), args=[Num(n=0.0), Num(n=4.0)], keywords=[])] 

Found function Attribute(value=Name(id='pyro'), attr='plate') 
with arguments 
 [Str(s='N'), Name(id='N')] 

Found function Attribute(value=Name(id='pyro'), attr='sample') 
with arguments 
 [Str(s='obs'),
    Call(func=Attribute(value=Name(id='dst'), attr='Normal'), args=[Name(id='loc'), Name(id='scale')], keywords=[])] 



In [18]:
source2 = inspect.getsource(model2)

In [19]:
tree2 = ast.parse(source2)

In [20]:
print(astor.dump(tree2))

Module(
    body=[
        FunctionDef(name='model2',
            args=arguments(args=[arg(arg='X', annotation=None)],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=None,
                defaults=[]),
            body=[
                Assign(targets=[Tuple(elts=[Name(id='N'), Name(id='D')])],
                    value=Attribute(value=Name(id='X'), attr='shape')),
                Assign(targets=[Name(id='K')], value=Name(id='D')),
                Assign(targets=[Name(id='locloc')], value=Num(n=0.0)),
                Assign(targets=[Name(id='locscale')], value=Num(n=1.0)),
                Assign(targets=[Name(id='scaleloc')], value=Num(n=0.0)),
                Assign(targets=[Name(id='scalescale')], value=Num(n=1.0)),
                Assign(targets=[Name(id='cov_factor_loc')],
                    value=Call(func=Attribute(value=Name(id='torch'), attr='zeros'),
                        args=[Name(id='K'), Name(id='D')],
 

  """Entry point for launching an IPython kernel.


In [21]:
find_func(tree2)

Found function Attribute(value=Name(id='torch'), attr='zeros') 
with arguments 
 [Name(id='K'), Name(id='D')] 

Found function Attribute(value=Name(id='torch'), attr='ones') 
with arguments 
 [Name(id='K'), Name(id='D')] 

Found function Attribute(value=Name(id='pyro'), attr='plate') 
with arguments 
 [Str(s='D'), Name(id='D')] 

Found function Attribute(value=Name(id='pyro'), attr='sample') 
with arguments 
 [Str(s='loc'),
    Call(func=Attribute(value=Name(id='dst'), attr='Normal'),
        args=[Name(id='locloc'), Name(id='locscale')],
        keywords=[])] 

Found function Attribute(value=Name(id='pyro'), attr='sample') 
with arguments 
 [Str(s='scale'),
    Call(func=Attribute(value=Name(id='dst'), attr='LogNormal'),
        args=[Name(id='scaleloc'), Name(id='scalescale')],
        keywords=[])] 

Found function Attribute(value=Name(id='pyro'), attr='plate') 
with arguments 
 [Str(s='K'), Name(id='K')] 

Found function Attribute(value=Name(id='pyro'), attr='sample') 
with argumen

In [22]:
class AddUnderPlate(ast.NodeTransformer):
    """
    Identifies plate with specified name, and appends specified code to its body
    """
    def __init__(self, plate_name, insert_code):
        self.plate_name = plate_name
        self.insert_code = insert_code
        super().__init__()
    def visit_With(self, node):
        # we want to visit child nodes, so visit it
        self.generic_visit(node)
        withexpr = node.items[0].context_expr
        if withexpr.func.attr == 'plate' and withexpr.args[0].s == self.plate_name:
            self.plate = node
            if self.insert_code is None:
                return
            else:
                newnode = node
                newnode.append(self.insert_code)
                return newnode
        else:
            return node

In [23]:
class PastePlate(ast.NodeVisitor):
    def __init__(self, plate):
        self.plate = plate
        super().__init__()
    def visit_FunctionDef(self, node):
        self.generic_visit(node)
        #append the whole plate statement last in the body of the function
        node.body.append(self.plate)
        # switch plate and return statement so it's last
        node.body[-1], node.body[-2] = node.body[-2], node.body[-1]

def cutpasteplate(tree, plate_name):
    """
    Cut a plate from any nesting level and paste it in the top level of the function,
    e.g. when transforming an independent Gaussian to a factor model, where features are no longer independent
    """
    cut = AddUnderPlate(plate_name, insert_code=None) # cut plate by adding nothing under plane
    cut.visit(tree)
    paste = PastePlate(cut.plate)
    paste.visit(tree)

In [24]:
print(astor.dump_tree(tree))

Module(
    body=[
        FunctionDef(name='model',
            args=arguments(args=[arg(arg='X', annotation=None)],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=None,
                defaults=[]),
            body=[
                Assign(targets=[Tuple(elts=[Name(id='N'), Name(id='D')])],
                    value=Attribute(value=Name(id='X'), attr='shape')),
                With(
                    items=[
                        withitem(
                            context_expr=Call(func=Attribute(value=Name(id='pyro'), attr='plate'),
                                args=[Str(s='D'), Name(id='D')],
                                keywords=[]),
                            optional_vars=None)],
                    body=[
                        Assign(targets=[Name(id='loc')],
                            value=Call(func=Attribute(value=Name(id='pyro'), attr='sample'),
                                args=[Str(s='

In [25]:
cutpasteplate(tree, 'N')

In [26]:
print(astor.to_source(tree))

def model(X):
    N, D = X.shape
    with pyro.plate('D', D):
        loc = pyro.sample('loc', dst.Normal(0.0, 10.0))
        scale = pyro.sample('scale', dst.LogNormal(0.0, 4.0))
    with pyro.plate('N', N):
        X = pyro.sample('obs', dst.Normal(loc, scale))
    return X



In [27]:
class SampleFinder(ast.NodeVisitor):
    def visit_Call(self, node):
        if node.func.attr == 'sample':
            print(astor.dump_tree(astor.to_source(node.args[0])), '\nwith arguments \n', astor.dump_tree(node.args), "\n")

def find_sample(tree):
    SampleFinder().visit(tree)

In [28]:
class ChangeObservationModel(ast.NodeTransformer):
    """
    Identifies sampling site with name 'obs', and replaces its distribution with a specified one
    """
    def __init__(self, new_obs_model):
        self.new_obs_model = new_obs_model
        super().__init__()
    def visit_Call(self, node):
        # we want to visit child nodes, so visit it
        self.generic_visit(node)
        if node.func.attr == 'sample' and node.args[0].s == 'obs':
            newnode = node
            newnode.args[1] = self.new_obs_model
            return newnode
        else:
            return node
        
def change_observation_model_to_LowRankMultivariateNormal(tree):
    lowrank_normal_obs_model = ast.Call(func=ast.Attribute(value=ast.Name(id='dst'), attr='LowRankMultivariateNormal'),\
                                args=[ast.Name(id='loc')],\
                                keywords=[ast.keyword(arg='cov_factor', value=ast.Name(id='cov_factor')),\
                                ast.keyword(arg='cov_diag', value=ast.Name(id='cov_diag'))]) 
    ChangeObservationModel(lowrank_normal_obs_model).visit(tree)

In [29]:
change_observation_model_to_LowRankMultivariateNormal(tree)

In [30]:
print(astor.to_source(tree))

def model(X):
    N, D = X.shape
    with pyro.plate('D', D):
        loc = pyro.sample('loc', dst.Normal(0.0, 10.0))
        scale = pyro.sample('scale', dst.LogNormal(0.0, 4.0))
    with pyro.plate('N', N):
        X = pyro.sample('obs', dst.LowRankMultivariateNormal(loc,
            cov_factor=cov_factor, cov_diag=cov_diag))
    return X



In [31]:
                        With(\
                            items=[\
                                withitem(\
                                    context_expr=Call(func=Attribute(value=Name(id='pyro'), attr='plate'),\
                                        args=[Str(s='K'), Name(id='K')],\
                                        keywords=[]),\
                                    optional_vars=None)],\
                            body=[\
                                Assign(targets=[Name(id='cov_factor')],\
                                    value=Call(func=Attribute(value=Name(id='pyro'), attr='sample'),\
                                        args=[Str(s='cov_factor'),\
                                            Call(func=Attribute(value=Name(id='dst'), attr='Normal'),\
                                                args=[Name(id='cov_factor_loc'), Name(id='cov_factor_scale')],\
                                                keywords=[])],\
                                        keywords=[]))])

NameError: name 'With' is not defined

In [None]:
find_sample(tree2)

In [None]:
find_sample(tree)

In [None]:
plt.scatter(*model(np.random.randn(1000,2)).detach().numpy().T, alpha=.1);

In [None]:
plt.scatter(*model2(np.random.randn(1000,2)).detach().numpy().T, alpha=.1)