Permalink
Find file
Fetching contributors…
Cannot retrieve contributors at this time
executable file 302 lines (223 sloc) 7.53 KB
#!/usr/bin/env python
# encoding: utf-8
"""
Author(s): Matthew Loper
See LICENCE.txt for licensing and contact information.
"""
__all__ = ['inv', 'svd', 'det', 'slogdet', 'pinv', 'lstsq', 'norm']
import numpy as np
import scipy.sparse as sp
from ch import Ch, depends_on, NanDivide
from utils import row, col
import ch
# Best-effort export of ``tensorinv``: the source text of NumPy's
# ``np.linalg.tensorinv`` is re-executed inside this module's namespace.
# NOTE(review): presumably this is done so that the re-defined function
# resolves ``asarray`` (bound just below to ``ch.asarray``) and this module's
# ``inv`` instead of NumPy's -- confirm against the NumPy version in use.
# NOTE(review): ``exec`` runs whatever source ``inspect`` returns for the
# installed NumPy; that is trusted library code, not external input.
try:
    asarray = ch.asarray
    import inspect
    exec(''.join(inspect.getsourcelines(np.linalg.tensorinv)[0]))
    __all__.append('tensorinv')
except Exception:
    # Deliberately best-effort: if the source cannot be fetched or executed,
    # ``tensorinv`` is simply not exported.  Only ``Exception`` is caught so
    # that KeyboardInterrupt / SystemExit are no longer swallowed here.
    pass
def norm(x, ord=None, axis=None):
    """Return the square root of the sum of the squared entries of `x`.

    Only the default call form is supported.

    Raises:
        NotImplementedError: if `ord` or `axis` is not None.
    """
    defaults_only = ord is None and axis is None
    if not defaults_only:
        raise NotImplementedError("'ord' and 'axis' should be None for now.")

    squared = x**2
    return ch.sqrt(ch.sum(squared))
# This version works but derivatives are too slow b/c of nested loop in Svd implementation.
# def lstsq(a, b):
# u, s, v = Svd(a)
# x = (v.T / s).dot(u.T.dot(b))
# residuals = NotImplementedError # ch.sum((a.dot(x) - b)**2, axis=0)
# rank = NotImplementedError
# s = NotImplementedError
# return x, residuals, rank, s
def lstsq(a, b, rcond=-1):
    """Least-squares solve of ``a . x = b`` built from `pinv`.

    Returns:
        A 4-tuple ``(x, residuals, rank, s)`` where ``x`` is a `Ch` object
        evaluating ``pinv(a).dot(b)`` and ``residuals`` is the sum over axis 0
        of the squared entries of ``a.dot(x) - b``.  ``rank`` and ``s`` are
        not computed; the ``NotImplementedError`` class is returned in their
        place as a placeholder.

    Raises:
        Exception: if `rcond` differs from its default of -1.
    """
    if rcond != -1:
        raise Exception('non-default rcond not yet implemented')

    # The lambda's parameter names double as the attribute names assigned
    # just below, so they must stay `a` and `b`.
    solution = Ch(lambda a, b: pinv(a).dot(b))
    solution.a = a
    solution.b = b

    misfit = solution.a.dot(solution) - solution.b
    residuals = ch.sum(misfit**2, axis=0)

    # Placeholders only: neither value is computed by this implementation.
    rank = NotImplementedError
    s = NotImplementedError
    return solution, residuals, rank, s
def Svd(x, full_matrices=0, compute_uv=1):
    """Build the three SVD terms (`SvdU`, `SvdD`, `SvdV`) for `x`.

    When `x` has fewer rows than columns the factorisation is set up on
    ``x.T`` and the roles of the two outer factors are swapped in the result.

    Returns:
        ``(svd_u, svd_d, svd_v.T)`` normally, or ``(svd_v, svd_d, svd_u.T)``
        when the transposed input was used.

    Raises:
        Exception: if `full_matrices` is not 0 or `compute_uv` is not 1.
    """
    if full_matrices != 0:
        raise Exception('full_matrices must be 0')
    if compute_uv != 1:
        raise Exception('compute_uv must be 1')

    wide = x.shape[0] < x.shape[1]
    source = x.T if wide else x

    d_term = SvdD(x=source)
    v_term = SvdV(x=source, svd_d=d_term)
    u_term = SvdU(x=source, svd_d=d_term, svd_v=v_term)

    if wide:
        return v_term, d_term, u_term.T
    return u_term, d_term, v_term.T
class Pinv(Ch):
    """Pseudo-inverse of `mtx`, assembled from `Inv` and matrix products.

    With more columns than rows the expression is
    ``mtx.T . Inv(mtx . mtx.T)``; otherwise it is ``Inv(mtx.T . mtx) . mtx.T``.
    """
    dterms = 'mtx'

    def on_changed(self, which):
        """Rebuild the expression stored in `_result` from the current `mtx`."""
        m = self.mtx
        more_cols_than_rows = m.shape[1] > m.shape[0]
        if more_cols_than_rows:
            expression = m.T.dot(Inv(m.dot(m.T)))
        else:
            expression = Inv(m.T.dot(m)).dot(m.T)
        self._result = expression

    def compute_r(self):
        """Return the value of the stored expression."""
        return self._result.r

    def compute_dr_wrt(self, wrt):
        """Delegate the derivative to the stored expression.

        Returns None when `wrt` is not `mtx`.
        """
        if wrt is not self.mtx:
            return None
        return self._result.dr_wrt(self.mtx)
# Couldn't make the SVD version of pinv work yet...
#
# class Pinv(Ch):
# dterms = 'mtx'
#
# def on_changed(self, which):
# u, s, v = Svd(self.mtx)
# result = (v.T * (NanDivide(1.,row(s)))).dot(u.T)
# self.add_dterm('_result', result)
#
# def compute_r(self):
# return self._result.r
#
# def compute_dr_wrt(self, wrt):
# if wrt is self._result:
# return 1
class LogAbsDet(Ch):
    """Second output of ``np.linalg.slogdet`` applied to `x`.

    The first output (the sign) is kept on the instance as `sign`.
    """
    dterms = 'x'

    def on_changed(self, which):
        """Recompute and cache both outputs of ``np.linalg.slogdet``."""
        outputs = np.linalg.slogdet(self.x.r)
        self.sign, self.slogdet = outputs

    def compute_r(self):
        """Return the cached log-abs-determinant."""
        return self.slogdet

    def compute_dr_wrt(self, wrt):
        """Return ``row(inv(x).T)`` for `wrt` being `x`, otherwise None."""
        if wrt is not self.x:
            return None
        inverse = np.linalg.inv(self.x.r)
        return row(inverse.T)
class SignLogAbsDet(Ch):
    """Exposes the `sign` attribute of a `LogAbsDet` term as its value."""
    dterms = ('logabsdet',)

    def compute_r(self):
        """Return `logabsdet.sign` after reading `logabsdet.r`."""
        # The value read here is discarded.
        # NOTE(review): presumably the read forces the wrapped term to
        # refresh so that its `sign` attribute is current -- confirm against
        # Ch's caching behaviour.
        _discarded = self.logabsdet.r
        return self.logabsdet.sign

    def compute_dr_wrt(self, wrt):
        """No derivative is provided for any `wrt`."""
        return None
class Det(Ch):
    """Determinant of `x`, computed with ``np.linalg.det``."""
    dterms = 'x'

    def compute_r(self):
        """Return ``np.linalg.det`` of the current value of `x`."""
        return np.linalg.det(self.x.r)

    def compute_dr_wrt(self, wrt):
        """Return ``row(det * inv(x).T)`` for `wrt` being `x`, otherwise None."""
        if wrt is not self.x:
            return None
        determinant = self.r
        inverse_t = np.linalg.inv(self.x.r).T
        return row(determinant * inverse_t)
class Inv(Ch):
    """Matrix inverse of `a`, computed with ``np.linalg.inv``."""
    dterms = 'a'

    def compute_r(self):
        """Return ``np.linalg.inv`` of the current value of `a`."""
        return np.linalg.inv(self.a.r)

    def compute_dr_wrt(self, wrt):
        """Return the derivative with respect to `a`, otherwise None.

        For an inverse with at most two dimensions the result is the dense
        array ``-kron(Ainv, Ainv.T)``.  For higher-dimensional input the
        leading axes are flattened into a batch, the same Kronecker block is
        formed per matrix, and the blocks are placed on the diagonal of a
        scipy sparse matrix.
        """
        if wrt is not self.a:
            return None

        inverse = self.r
        if inverse.ndim <= 2:
            return -np.kron(inverse, inverse.T)

        # Collapse every leading axis so that the first axis indexes matrices.
        batch = np.reshape(inverse, (-1, inverse.shape[-2], inverse.shape[-1]))

        # dstack stacks along a trailing axis; transposing each block before
        # stacking and the whole stack afterwards yields an array whose
        # first axis indexes the per-matrix blocks -kron(m, m.T).
        blocks = np.dstack([-np.kron(m, m.T).T for m in batch]).T
        return sp.block_diag(blocks)
class SvdD(Ch):
    """Singular values of `x`; also caches the full factorisation as `UDV`."""
    dterms = 'x'

    @depends_on('x')
    def UDV(self):
        """Return ``[U, d, V]`` from the reduced SVD of `x`.

        ``V`` is the transpose of the third array returned by
        ``np.linalg.svd``.  Singular values whose magnitude is below
        ``np.spacing(1)`` are set to exactly zero.
        """
        u, d, vt = np.linalg.svd(self.x.r, full_matrices=False)
        d[np.abs(d) < np.spacing(1)] = 0.
        return [u, d, vt.T]

    def compute_r(self):
        """Return the singular values."""
        return self.UDV[1]

    def compute_dr_wrt(self, wrt):
        """Return the derivative of the singular values with respect to `x`.

        Row ``k`` holds the flattened products ``U[i, k] * V[j, k]`` over all
        ``(i, j)``.  Returns None when `wrt` is not `x`.
        """
        if wrt is not self.x:
            return

        u, d, v = self.UDV
        shape = self.x.r.shape
        n_rows, n_cols = shape[0], shape[1]

        u = u[:n_rows, :n_cols]
        v = v[:n_cols, :d.size]

        per_value = np.einsum('ik,jk->kij', u, v)
        return per_value.reshape((per_value.shape[0], -1))
class SvdV(Ch):
    # Term whose value is the third entry of `svd_d.UDV`.
    # `svd_d` is the SvdD term that owns the cached factorisation.
    terms = 'svd_d'
    dterms = 'x'
    def compute_r(self):
        # Value is read straight from the factorisation cached on svd_d.
        return self.svd_d.UDV[2]
    def compute_dr_wrt(self, wrt):
        # Returns an array reshaped to (self.r.size, wrt.r.size) when `wrt`
        # is `x`, and None otherwise.
        if wrt is not self.x:
            return
        U,_D,V = self.svd_d.UDV
        shp = self.svd_d.x.r.shape
        mxsz = max(shp[0], shp[1])
        #mnsz = min(shp[0], shp[1])
        # Singular values zero-padded up to the larger of the two dimensions.
        D = np.zeros(mxsz)
        D[:_D.size] = _D
        # omega has shape (M, N, N, N).  Only entries with k != l are written
        # below, and they are antisymmetric in (k, l); k == l stays zero.
        omega = np.zeros((shp[0], shp[1], shp[1], shp[1]))
        M = shp[0]
        N = shp[1]
        # Input must have at least as many rows as columns.
        assert(M >= N)
        # Four nested Python loops: one 2x2 solve for every (i, j) entry of x
        # and every index pair k < l.
        for i in range(shp[0]):
            for j in range(shp[1]):
                for k in range(N):
                    for l in range(k+1, N):
                        # 2x2 system built from the singular values D[k], D[l].
                        # NOTE(review): its determinant is D[l]**2 - D[k]**2,
                        # so equal singular values make np.linalg.solve raise.
                        mtx = np.array([
                            [D[l],D[k]],
                            [D[k],D[l]]])
                        rhs = np.array([U[i,k]*V[j,l], -U[i,l]*V[j,k]])
                        result = np.linalg.solve(mtx, rhs)
                        # Only the second component of the solution is used.
                        omega[i,j,k,l] = result[1]
                        omega[i,j,l,k] = -result[1]
        #print 'v size is %s' % (str(V.shape),)
        #print 'v omega size is %s' % (str(omega.shape),)
        assert(V.shape[1] == omega.shape[2])
        # Contract -V with omega over k, then flatten to a 2-D array.
        return np.einsum('ak,ijkl->alij', -V, omega).reshape((self.r.size, wrt.r.size))
class SvdU(Ch):
    """Term whose value is the first entry of `svd_d.UDV`."""
    dterms = 'x'
    terms = 'svd_d', 'svd_v'

    def compute_r(self):
        """Return the first entry of the factorisation cached on `svd_d`."""
        return self.svd_d.UDV[0]

    def compute_dr_wrt(self, wrt):
        """Return the derivative with respect to `x`, otherwise None.

        The derivative is taken from the expression
        ``NanDivide(x . svd_v, svd_d reshaped to one row)``, differentiated
        with respect to `svd_d.x`.
        """
        if wrt is not self.x:
            return None

        source = self.svd_d.x
        projected = source.dot(self.svd_v)
        singular_row = self.svd_d.reshape((1, -1))
        # NanDivide is used here in place of a plain `/` division.
        quotient = NanDivide(projected, singular_row)
        return quotient.dr_wrt(source)
# Lower-case aliases for the classes/functions above; these are the names
# listed in __all__.
inv = Inv
svd = Svd
det = Det
pinv = Pinv
def slogdet(*args):
    """Return ``(sign, logabsdet)`` terms for one or more matrices.

    With a single argument both results are single terms (`SignLogAbsDet`,
    `LogAbsDet`).  With any other number of arguments the sign is a list of
    `SignLogAbsDet` terms and the log-abs-determinants are joined with
    ``ch.concatenate``.
    """
    if len(args) == 1:
        logabs = LogAbsDet(x=args[0])
        sign = SignLogAbsDet(logabs)
        return sign, logabs

    logabs_terms = [LogAbsDet(x=matrix) for matrix in args]
    signs = [SignLogAbsDet(term) for term in logabs_terms]
    return signs, ch.concatenate(logabs_terms)
def main():
    """Print a quick comparison of `slogdet` against ``np.linalg.slogdet``.

    Uses a random 10x10 matrix and prints, in order: the log-abs-determinant
    from this module and from NumPy; the change in NumPy's value under a
    small random perturbation next to this module's derivative applied to
    that perturbation; and the sign from NumPy and from this module.

    Each ``print`` call takes a single parenthesised argument, so the output
    is the same on Python 2 while the syntax also parses on Python 3.
    """
    tmp = ch.random.randn(100).reshape((10, 10))
    print('chumpy version: ' + str(slogdet(tmp)[1].r))
    print('old version:' + str(np.linalg.slogdet(tmp.r)[1]))

    # Small perturbation for comparing a finite difference against the
    # derivative returned by dr_wrt.
    eps = 1e-10
    diff = np.random.rand(100) * eps
    diff_reshaped = diff.reshape((10, 10))
    print(np.linalg.slogdet(tmp.r + diff_reshaped)[1] - np.linalg.slogdet(tmp.r)[1])
    print(slogdet(tmp)[1].dr_wrt(tmp).dot(diff))

    print(np.linalg.slogdet(tmp.r)[0])
    print(slogdet(tmp)[0])
# Run the slogdet comparison only when this file is executed as a script.
if __name__ == '__main__':
    main()