This notebook demonstrates how models could be constructed directly in pangolin's IR. **Note that
as a user you almost certainly don't want to create models this way,** as it's a low-level notation. But it
is useful for understanding how models are represented under the hood. You will need to understand it if you plan to write a new algorithm that manipulates these models or to implement a new backend.

In [1]:
from pangolin import ir

In [17]:
# An Op for the constant value 0.
# The same object is shown in two ways below:
#   repr() -> fully explicit, constructor-style text
#   str()  -> compact, math-like text (this is also what print() uses)
d = ir.Constant(0)

print(repr(d))
print(str(d))

Constant(0)
0


In [18]:
# Wrap a constant Op in a random variable (RV) node.
# str() of the RV looks just like the Op's; repr() also shows the RV wrapper.
z = ir.RV(ir.Constant(0))

print(repr(z))
print(str(z))

RV(Constant(0))
0


In [19]:
# An RV whose Op is Add(). The Op itself is constructed with no arguments;
# the two values to add are supplied as parent RVs (x and y).
x, y = ir.RV(ir.Constant(0)), ir.RV(ir.Constant(2))
z = ir.RV(ir.Add(), x, y)

print(repr(z))
print(str(z))

RV(Add(), RV(Constant(0)), RV(Constant(2)))
add(0, 2)


In [43]:
# z has parents, exposed as a tuple in z.parents; you can access them directly.
# NOTE: this cell relies on the x, y, z defined in the Add cell above. Later
# cells reassign z, so re-running this cell after them compares a different z
# against x and y (and prints False).
# The `=` specifier in these f-strings prints the expression text, then its value.

print(f"{z.parents         = }")
print(f"{z.parents[0] == x = }")
print(f"{z.parents[1] == y = }")


z.parents         = (RV(Constant(0)), RV(Constant(2)))
z.parents[0] == x = True
z.parents[1] == y = True


In [21]:
# Print every node upstream of z (including z itself), one statement per line,
# together with each node's shape.
# Nodes are labelled a, b, c, ... by print_upstream itself; it does not know
# the Python variable names (x, y, z) used in this notebook.

ir.print_upstream(z)

shape | statement
----- | ---------
()    | a = 0
()    | b = 2
()    | c = add(a,b)


In [7]:
# A normal-distribution Op. Like Add(), it is constructed with no arguments:
# its inputs (loc and scale) are supplied later as parent RVs.
# As before, repr() is explicit while str()/print() are compact and math-like.
d = ir.Normal()

print(repr(d))
print(str(d))

Normal()
normal


In [8]:
# Create a normal RV with loc 0.5 and scale 2.5.
# (This is not a *standard* normal, which would have loc 0 and scale 1.)
# The parameters are themselves RVs (here, constants), passed as parents in
# the order (loc, scale).
loc = ir.RV(ir.Constant(0.5))
scale = ir.RV(ir.Constant(2.5))
z = ir.RV(ir.Normal(), loc, scale)

print(repr(z))
print(z)

RV(Normal(), RV(Constant(0.5)), RV(Constant(2.5)))
normal(0.5, 2.5)


In [9]:
# Again, you can print the upstream nodes.
# Deterministic nodes (like constants) are shown with "=", while a node whose
# Op is a distribution (like normal) is shown with "~".

ir.print_upstream(z)

shape | statement
----- | ---------
()    | a = 0.5
()    | b = 2.5
()    | c ~ normal(a,b)


In [10]:
# Create a vmapped normal-distribution Op.
# in_axes=(0, 0) maps over axis 0 of both inputs (loc and scale), in the style
# of jax.vmap. in_axes is passed by keyword, matching the cells below; the
# resulting Op is the same as with the positional form.
d = ir.VMap(ir.Normal(), in_axes=(0, 0))

print(repr(d))
print(d)

VMap(Normal(), (0, 0))
vmap(normal, (0, 0))


In [30]:
# A diagonal normal RV: one normal per element, where element i uses locs[i]
# and scales[i] (both inputs are mapped over axis 0).
locs = ir.RV(ir.Constant([-1, 1, 2]))
scales = ir.RV(ir.Constant([3, 4, 5]))
diag_normal = ir.VMap(ir.Normal(), in_axes=(0, 0))
z = ir.RV(diag_normal, locs, scales)

print(repr(z))
print(str(z))

RV(VMap(Normal(), (0, 0)), RV(Constant([-1,1,2])), RV(Constant([3,4,5])))
vmap(normal, (0, 0))([-1 1 2], [3 4 5])


In [31]:
# Now printing upstream gives useful shape information: the vmapped node c has
# shape (3,), matching its length-3 inputs a and b.

ir.print_upstream(z)

shape | statement
----- | ---------
(3,)  | a = [-1 1 2]
(3,)  | b = [3 4 5]
(3,)  | c ~ vmap(normal, (0, 0))(a,b)


In [32]:
# A diagonal normal RV whose scale is shared: in_axes=(0, None) maps loc over
# axis 0 but passes the same scalar scale to every element.
locs = ir.RV(ir.Constant([-1, 1, 2]))
scale = ir.RV(ir.Constant(3))
shared_scale_normal = ir.VMap(ir.Normal(), in_axes=(0, None))
z = ir.RV(shared_scale_normal, locs, scale)

print(repr(z))
print(str(z))

RV(VMap(Normal(), (0, None)), RV(Constant([-1,1,2])), RV(Constant(3)))
vmap(normal, (0, None))([-1 1 2], 3)


In [33]:
# The shared scale b keeps shape () while the output c still has shape (3,).
ir.print_upstream(z)

shape | statement
----- | ---------
(3,)  | a = [-1 1 2]
()    | b = 3
(3,)  | c ~ vmap(normal, (0, None))(a,b)


In [36]:
# A diagonal normal RV where loc AND scale are shared by every element.
# With in_axes=(None, None) no input is mapped, so the length of the mapped
# axis cannot be read off the inputs and is given explicitly via axis_size
# (the same convention as jax.vmap).
loc = ir.RV(ir.Constant(-1))
scale = ir.RV(ir.Constant(3))
iid_normal = ir.VMap(ir.Normal(), in_axes=(None, None), axis_size=3)
z = ir.RV(iid_normal, loc, scale)

print(repr(z))
print(str(z))

RV(VMap(Normal(), (None, None), 3), RV(Constant(-1)), RV(Constant(3)))
vmap(normal, (None, None), 3)(-1, 3)


In [35]:
# Both parameters are scalars (shape ()), yet c has shape (3,) because of axis_size=3.
ir.print_upstream(z)

shape | statement
----- | ---------
()    | a = -1
()    | b = 3
(3,)  | c ~ vmap(normal, (None, None), 3)(a,b)


In [41]:
# A slightly bigger model with a shared intermediate node:
#   c[i] = a * b[i]            (a is shared, b is mapped over axis 0)
#   d[i] ~ normal(b[i], c[i])
# b feeds into both c and d, so the nested repr()/str() text repeats b's
# subtree and quickly becomes hard to read.
a = ir.RV(ir.Constant(2))
b = ir.RV(ir.Constant([2.2, 3.3, 4.4]))
c = ir.RV(ir.VMap(ir.Mul(), in_axes=(None, 0)), a, b)
d = ir.RV(ir.VMap(ir.Normal(), in_axes=(0, 0)), b, c)

print(repr(d))
print(d)

RV(VMap(Normal(), (0, 0)), RV(Constant([2.2,3.3,4.4])), RV(VMap(Mul(), (None, 0)), RV(Constant(2)), RV(Constant([2.2,3.3,4.4]))))
vmap(normal, (0, 0))([2.2 3.3 4.4], vmap(mul, (None, 0))(2, [2.2 3.3 4.4]))


In [42]:
# ...but printing upstream gives a readable view: each node is listed once,
# even though b is used twice.
# NOTE: print_upstream picks its own letters, so its `a` is this notebook's `b`
# (the array) and its `b` is this notebook's `a` (the scalar 2).

ir.print_upstream(d)

shape | statement
----- | ---------
(3,)  | a = [2.2 3.3 4.4]
()    | b = 2
(3,)  | c = vmap(mul, (None, 0))(b,a)
(3,)  | d ~ vmap(normal, (0, 0))(a,c)
