Let's go crazy and build matrix multiplication from nothing but `vmap`.

Currently we need a few manual transposes, because `vmap` always vectorizes over the first (leading) dimension of each mapped argument.

In [1]:
from pangolin import Given, d, t, I, IID, vmap, sample, E, P, var, std, cov, corr, makerv, jags_code
from matplotlib import pyplot as plt
import numpy as np
np.set_printoptions(formatter={'float': '{:6.2f}'.format}) # print nicely

# Small concrete inputs, kept as plain numpy arrays so the pangolin results
# below can be compared against numpy's own `@`.
A0 = np.array([[1, 2, 3],
               [4, 5, 6],
               [7, 8, 9]])
B0 = np.array([[5, 2, 2],
               [2, 5, 7],
               [2, 2, 0]])
x0 = np.array([1, 2, 3])
y0 = np.array([4, 8, 10])

# Wrap each array with makerv so it can be used in pangolin expressions.
# The calls run in the order A, B, x, y; the JAGS names in the generated code
# below (v0v, v1v, v2v, v3v) appear to follow this creation order.
A, B, x, y = (makerv(arr) for arr in (A0, B0, x0, y0))

# Elementwise product of two vectors: both arguments are mapped over their
# leading axis (the generated JAGS code is a single loop a[i] * b[i]).
elementwise = vmap(lambda a, b: a * b, [True, True])


def inner(a, b):
    """Inner product of two vectors, built as sum(elementwise(a, b))."""
    return t.sum(elementwise(a, b))


print('E(inner(x,y))', E(inner(x, y)))
# Reference value from numpy on the raw arrays (not the pangolin RVs x, y).
print('x0 @ y0      ', x0 @ y0)

print()  # blank separator line between sections

# Matrix-vector product: `inner` is mapped over the leading axis (rows) of the
# matrix, while the vector is passed unmapped to every call.
mat_times_vec = vmap(inner, [True, False])

print('E(mat_times_vec(A,x))', E(mat_times_vec(A, x)))
print('E(A @ x)             ', E(A @ x))
print('A0 @ x0              ', A0 @ x0)

print()  # blank separator line between sections

# vmap only maps over the leading axis, so mapping the second argument gives
#   matT_times_mat(M, N)[i] = mat_times_vec(M, N[i]) = M @ N[i],
# i.e. matT_times_mat(M, N) == N @ M.T.
matT_times_mat = vmap(mat_times_vec, [False, True])


def mat_times_mat(A, B):
    """Matrix product A @ B built only from vmap.

    Because matT_times_mat(M, N) == N @ M.T, passing M = B.T and N = A yields
    A @ (B.T).T == A @ B.
    """
    return matT_times_mat(B.T, A)


print('E(mat_times_mat(A,B))\n', E(mat_times_mat(A, B)))
print('E(A @ B)             \n', E(A @ B))
print('A0 @ B0              \n', A0 @ B0)

E(inner(x,y)) 50.0
x @ y         50

E(mat_times_vec(A,x)) [ 14.00  32.00  50.00]
E(A @ x)              [ 14.00  32.00  50.00]
A0 @ x0               [14 32 50]

E(mat_times_mat(A,B))
 [[ 15.00  18.00  16.00]
 [ 42.00  45.00  43.00]
 [ 69.00  72.00  70.00]]
E(A @ B              
 [[ 15.00  18.00  16.00]
 [ 42.00  45.00  43.00]
 [ 69.00  72.00  70.00]]
A0 @ B0              
 [[15 18 16]
 [42 45 43]
 [69 72 70]]


It's instructive to compare the JAGS code that pangolin generates for the native operators and for the vmap-based versions.

In [2]:
# Inner product x·y lowered to JAGS's native matrix-multiplication operator %*%.
native_inner_jags = jags_code(x @ y)
print(native_inner_jags)

model{
v31v<-(v2v[1:3])%*%(v3v[1:3]);
}



In [3]:
# Inner product built from vmap: an explicit elementwise-product loop, then sum().
vmap_inner_jags = jags_code(inner(x, y))
print(vmap_inner_jags)

model{
for (i0 in 1:3){
  v35v[i0]<-(v2v[i0])*(v3v[i0]);
}
v36v<-sum(v35v[1:3]);
}



In [4]:
# Matrix-vector product A·x lowered to JAGS's native %*% operator.
native_mat_vec_jags = jags_code(A @ x)
print(native_mat_vec_jags)

model{
v37v[1:3]<-(v0v[1:3,1:3])%*%(v2v[1:3]);
}



In [5]:
# Matrix-vector product from vmap on top of vmap: a doubly nested product
# loop, followed by a per-row sum.
vmap_mat_vec_jags = jags_code(mat_times_vec(A, x))
print(vmap_mat_vec_jags)

model{
for (i1 in 1:3){
  for (i0 in 1:3){
    v44v[i1,i0]<-(v0v[i1,i0])*(v2v[i0]);
  }
}
for (i0 in 1:3){
  v45v[i0]<-sum(v44v[i0,1:3]);
}
}



In [6]:
# Matrix-matrix product A·B lowered to JAGS's native %*% operator.
native_mat_mat_jags = jags_code(A @ B)
print(native_mat_mat_jags)

model{
v46v[1:3,1:3]<-(v0v[1:3,1:3])%*%(v1v[1:3,1:3]);
}



In [7]:
# Matrix-matrix product from three stacked vmaps: transpose B, fill a triply
# nested loop of products, then sum over the last axis.
vmap_mat_mat_jags = jags_code(mat_times_mat(A, B))
print(vmap_mat_mat_jags)

model{
v47v[1:3,1:3]<-t(v1v[1:3,1:3]);
for (i2 in 1:3){
  for (i1 in 1:3){
    for (i0 in 1:3){
      v57v[i2,i1,i0]<-(v47v[i1,i0])*(v0v[i2,i0]);
    }
  }
}
for (i1 in 1:3){
  for (i0 in 1:3){
    v58v[i1,i0]<-sum(v57v[i1,i0,1:3]);
  }
}
}

