Skip to content

Commit

Permalink
refactor constraint matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
mfinzi committed Feb 26, 2021
1 parent 5a85d5a commit dcc1945
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 145 deletions.
114 changes: 67 additions & 47 deletions emlp/solver/linear_operators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .linear_operator_jax import LinearOperator
import jax.numpy as jnp
import numpy as np
import jax.jit as jit
from jax import jit
import jax
from functools import reduce

Expand All @@ -18,45 +18,20 @@ def _rmatmat(self,V):
return self.A.T@V
def _rmatvec(self,v):
return self.A.T@v
def to_dense(self):
return self.A

class LazyDirectSum(LinearOperator):
def __init__(self,Ms,multiplicities=None):
self.Ms = [jax.device_put(M.astype(np.float32)) if isinstance(M,(np.ndarray,jnp.ndarray)) else M for M in Ms]
self.multiplicities = [1 for M in Ms] if multiplicities is None else multiplicities
self.shape = (sum(Mi.shape[0]*c for Mi,c in zip(Ms,multiplicities)),
sum(Mi.shape[0]*c for Mi,c in zip(Ms,multiplicities)))
#self.dtype=Ms[0].dtype
self.dtype=jnp.dtype('float32')

def _matvec(self,v):
return self._matmat(v.reshape(v.shape[0],-1)).reshape(-1)

def _matmat(self,v): # (n,k)
return lazy_direct_matmat(v,self.Ms,self.multiplicities)
class I(LinearOperator):
def __init__(self,d):
self.shape = (d,d)
def _matmat(self,V): #(c,k)
return V
def _matvec(self,V):
return V
def _adjoint(self):
return LazyDirectSum([Mi.T for Mi in self.Ms])
return self
def invT(self):
return LazyDirectSum([M.invT() for M in self.Ms])
def to_dense(self):
Ms_all = [self.Ms[i] for i in range(len(self.Ms)) for _ in range(self.multiplicities[i])]
Ms_all = [Mi.to_dense() if isinstance(Mi,LinearOperator) else Mi for Mi in Ms_all]
return jax.scipy.linalg.block_diag(*Ms_all)
def __new__(cls,Ms,multiplicities=None):
if len(Ms)==1 and multiplicities is None: return Ms[0]
return super().__new__(cls)

def lazy_direct_matmat(v,Ms,mults):
n,k = v.shape
i=0
y = []
for M, multiplicity in zip(Ms,mults):
if not M.shape[-1]: continue
i_end = i+multiplicity*M.shape[-1]
elems = M@v[i:i_end].T.reshape(-1,M.shape[-1]).T
y.append(elems.T.reshape(k,multiplicity*M.shape[0]).T)
i = i_end
y = jnp.concatenate(y,axis=0) #concatenate over rep axis
return y
return self

class LazyKron(LinearOperator):

Expand Down Expand Up @@ -110,7 +85,6 @@ def _matmat(self,v):
out += jnp.moveaxis(Mev_front,0,i)
return out.reshape(self.shape[0],ev.shape[-1])


def _adjoint(self):
return LazyKronsum([Mi.T for Mi in self.Ms])
def to_dense(self):
Expand All @@ -126,17 +100,63 @@ def __new__(cls,Ms):
# rprod = np.cumprod([1]+[mi.shape[-1] for mi in reversed(Ms)])[::-1]
# return reduce(lambda a,b: a+b,[lazy_kron([I(lprod[i]),Mi,I(rprod[i+1])]) for i,Mi in enumerate(Ms)])

class I(LinearOperator):
def __init__(self,d):
self.shape = (d,d)
def _matmat(self,V): #(c,k)
return V
def _matvec(self,V):
return V
class ConcatLazy(LinearOperator):
""" Produces a linear operator equivalent to concatenating
a collection of matrices Ms along axis=0 """
def __init__(self,Ms):
self.Ms = Ms
assert all(M.shape==Ms.shape[0] for M in Ms),\
f"Trying to concatenate matrices of different sizes {[M.shape for M in Ms]}"
self.shape = (sum(M.shape[0] for M in Ms),Ms[0].shape[1])

def _matmat(self,V):
return jnp.concatenate([M@V for M in self.Ms],axis=0)
def _rmatmat(self,V):
Vs = jnp.split(V,len(self.Ms))
return sum([self.Ms[i].T@Vs[i] for i in range(len(self.Ms))])
def to_dense(self):
dense_Ms = [M.to_dense() if isinstance(M,LinearOperator) else M for M in self.Ms]
return jnp.concatenate(dense_ms,axis=0)

class LazyDirectSum(LinearOperator):
def __init__(self,Ms,multiplicities=None):
self.Ms = [jax.device_put(M.astype(np.float32)) if isinstance(M,(np.ndarray,jnp.ndarray)) else M for M in Ms]
self.multiplicities = [1 for M in Ms] if multiplicities is None else multiplicities
self.shape = (sum(Mi.shape[0]*c for Mi,c in zip(Ms,multiplicities)),
sum(Mi.shape[0]*c for Mi,c in zip(Ms,multiplicities)))
#self.dtype=Ms[0].dtype
self.dtype=jnp.dtype('float32')

def _matvec(self,v):
return self._matmat(v.reshape(v.shape[0],-1)).reshape(-1)

def _matmat(self,v): # (n,k)
return lazy_direct_matmat(v,self.Ms,self.multiplicities)
def _adjoint(self):
return self
return LazyDirectSum([Mi.T for Mi in self.Ms])
def invT(self):
return self
return LazyDirectSum([M.invT() for M in self.Ms])
def to_dense(self):
Ms_all = [self.Ms[i] for i in range(len(self.Ms)) for _ in range(self.multiplicities[i])]
Ms_all = [Mi.to_dense() if isinstance(Mi,LinearOperator) else Mi for Mi in Ms_all]
return jax.scipy.linalg.block_diag(*Ms_all)
def __new__(cls,Ms,multiplicities=None):
if len(Ms)==1 and multiplicities is None: return Ms[0]
return super().__new__(cls)

def lazy_direct_matmat(v,Ms,mults):
n,k = v.shape
i=0
y = []
for M, multiplicity in zip(Ms,mults):
if not M.shape[-1]: continue
i_end = i+multiplicity*M.shape[-1]
elems = M@v[i:i_end].T.reshape(-1,M.shape[-1]).T
y.append(elems.T.reshape(k,multiplicity*M.shape[0]).T)
i = i_end
y = jnp.concatenate(y,axis=0) #concatenate over rep axis
return y



class LazyPerm(LinearOperator):
Expand Down
29 changes: 2 additions & 27 deletions emlp/solver/product_sum_reps.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def compute_canonical(rep_cnters,rep_perms):
permlist.append(shifted_perms[i][ids[i]:ids[i]+c*rep.size()])
ids[i]+=+c*rep.size()
merged_cnt[rep]+=c
#print(permlist)
return dict(merged_cnt),np.concatenate(permlist)

def symmetric_basis(self):
Expand Down Expand Up @@ -329,7 +328,7 @@ def __eq__(self, other): #TODO: worry about non canonical?
def size(self):
return product([rep.size()**count for rep,count in self.reps.items()])
@property
def T(self): #TODO: reavaluate if this needs to change the order (it does not)
def T(self):
""" only swaps to adjoint representation, does not reorder elems"""
return self.__class__(*[rep.T for rep,c in self.reps.items() for _ in range(c)],extra_perm=self.perm)
#return self.__class__(counter={rep.T:c for rep,c in self.reps.items()},extra_perm=self.perm)
Expand Down Expand Up @@ -435,7 +434,6 @@ def __init__(self,*reps):
def __call__(self,G):
if G is None: return self
return sum([rep(G) for rep in self.to_sum])
#return SumRep(*[rep(G) for rep in self.to_sum])
def __repr__(self):
return '('+"+".join(f"{rep}" for rep in self.to_sum)+')'
def __str__(self):
Expand All @@ -454,33 +452,10 @@ def __init__(self,*reps):
def __call__(self,G):
if G is None: return self
return reduce(lambda a,b:a*b,[rep(G) for rep in self.to_prod])
#return ProductRep(*[rep(G) for rep in self.to_prod])
def __repr__(self):
return "⊗".join(f"{rep}" for rep in self.to_prod)
def __str__(self):
return repr(self)
@property
def T(self):
return DeferredProductRep(*[rep.T for rep in self.to_prod])#TODO: need to reverse the order?

# class ProductGroupTensorRep(Rep): # Eventually will need to handle reordering to canonical G1,G2, etc (from hash?)
# atomic=False
# # TO be used like (T(0) + T(1))(SO(3))*T(1)(S(5)) -> T(2)(SO(3))
# def __init__(self,rep_dict):
# assert len(rep_dict)>1, "Not product rep?"
# self.reps = rep_dict
# #for rep in rep_dict.values():
# #self.ordering =
# def rho(self,Ms):
# rhos = [rep.rho(Ms[G]) for (G,rep) in self.reps.items()]
# return functools.reduce(jnp.kron,rhos,1)
# def drho(self,As):
# drhos = [rep.drho(As[G]) for (G,rep) in self.reps.items()]
# raise functools.reduce(kronsum,drhos,0)

# def __eq__(self, other):
# if not isinstance(other,ProductGroupTensorRep): return False
# return len(self.reps)==len(other.reps) \
# and all(Ga==Gb for Ga,Gb in zip(self.reps,other.reps)) \
# and all(ra==rb for ra,rb in zip(self.reps.values(),other.reps.values())) \

return DeferredProductRep(*[rep.T for rep in self.to_prod])
82 changes: 11 additions & 71 deletions emlp/solver/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from functools import lru_cache as cache
from .utils import ltqdm,prod
from .linear_operator_jax import LinearOperator
from .linear_operators import Lazy
from .linear_operators import Lazy,ConcatLazy
import scipy as sp
import scipy.linalg
import functools
Expand Down Expand Up @@ -49,38 +49,17 @@ def rho_dense(self,M):
def drho_dense(self,A):
rho = self.drho(M)
return drho.to_dense() if isinstance(drho,LinearOperator) else drho
# def rho(self,M): # Group representation of matrix M (n,n)
# if hasattr(self,'_rho'): return self._rho(M)
# elif hasattr(self,'_rho_lazy'): return self._rho_lazy(M)@jnp.eye(self.size())
# else: raise NotImplementedError
# def drho(self,A):# Lie Algebra representation of matrix A (n,n)
# if hasattr(self,'_drho'): return self._drho(M)
# elif hasattr(self,'_drho_lazy'): return self._drho_lazy(M)@jnp.eye(self.size())
# else: raise NotImplementedError
# def rho_lazy(self,M): # Group representation of matrix M (n,n)
# if hasattr(self,'_rho_lazy'): return self._rho_lazy(M)
# elif hasattr(self,'_rho'): return self._rho(M)
# else: raise NotImplementedError
# def drho_lazy(self,A):# Lie Algebra representation of matrix A (n,n)
# if hasattr(self,'_drho_lazy'): return self._drho_lazy(M)
# elif hasattr(self,'_drho'): return self._drho
# else: raise NotImplementedError
# def rho_lazy(self,M): raise NotImplementedError # Lazy version of rho
# def drho_lazy(self,M): raise NotImplementedError # Lazy version of drho
# def constraint_matrix(self):
# """ Given a sequence of exponential generators [A1,A2,...]
# and a tensor rank (p,q), the function concatenates the representations
# [drho(A1), drho(A2), ...] into a single large projection matrix.
# Input: [generators seq(tensor(d,d))], [rank tuple(p,q)], [d int] """
# constraints = [] # Multiply by identity to convert lazy to dense matrices
# constraints.extend([self.drho(A)@jnp.eye(self.size()) for A in self.G.lie_algebra])
# constraints.extend([self.rho(h)@jnp.eye(self.size())-jnp.eye(self.size()) for h in self.G.discrete_generators])
# P = jnp.concatenate(constraints,axis=0) if constraints else jnp.zeros((1,self.size()))
# return P

def constraint_matrix(self):
""" A lazy version of constraint_matrix"""
return ConstraintMatrixLazy(self.G,self.rho,self.drho,self.size())
""" Given a sequence of exponential generators [A1,A2,...]
and a tensor rank (p,q), the function concatenates the representations
[drho(A1), drho(A2), ...] into a single large constraint matrix C.
Input: [generators seq(tensor(d,d))], [rank tuple(p,q)], [d int] """
n = self.size()
constraints = []
constraints.extend([self.rho(h)-I(n) for h in self.G.discrete_generators])
constraints.extend([self.drho(A) for A in self.G.lie_algebra])
return ConcatLazy(constraints) if constraints else jnp.zeros(1,n)

#@disk_cache('./_subspace_cache_jax.dat')
solcache = {}
Expand All @@ -100,8 +79,7 @@ def symmetric_basis(self):
if prod(C_lazy.shape)>3e7: #Too large to use SVD
result = krylov_constraint_solve(C_lazy)
else:
C_dense = C_lazy@jnp.eye(C_lazy.shape[-1])
result = orthogonal_complement(C_dense)
result = orthogonal_complement(C_lazy.to_dense())
self.solcache[canon_rep]=result
return self.solcache[canon_rep][invperm]

Expand Down Expand Up @@ -250,44 +228,6 @@ def __lt__(self,other):
def T(p,q=0,G=None):
return (V**p*V.T**q)(G)

class ConstraintMatrixLazy(LinearOperator):
def __init__(self,group,rho_lazy,drho_lazy,size):
self.d = group.d
self.rho_lazy=rho_lazy
self.drho_lazy=drho_lazy
# if group.discrete_generators_lazy is not NotImplemented:
# self.hi = group.discrete_generators_lazy
# else:
# self.hi=group.discrete_generators
# logging.debug(f"no discrete lazy found for {group}")
# if group.lie_algebra_lazy is not NotImplemented:
# self.Ai = group.lie_algebra_lazy
# else:
# self.Ai = group.lie_algebra
# logging.debug(f"no Lie Algebra lazy found for {group}")
self.hi = group.discrete_generators
self.Ai = group.lie_algebra
self.G=group
self.n_constraints= len(self.hi)+len(self.Ai)
if not self.n_constraints: raise NotImplementedError
self.c = size
self.dtype=np.float32
@property
def shape(self):
return (self.c*self.n_constraints,self.c)
def _matmat(self,V): #(c,k)
constraints = []
constraints.extend([self.drho_lazy(A)@V for A in self.Ai])
constraints.extend([self.rho_lazy(h)@V-V for h in self.hi])
CV = jnp.concatenate(constraints,axis=0)
return CV
def _rmatmat(self,V):
n_constraints = len(self.hi)+len(self.Ai)
Vi = jnp.split(V,self.n_constraints)
out = 0
out += sum([self.drho_lazy(A).T@Vi[i] for i,A in enumerate(self.Ai)])
out += sum([self.rho_lazy(h).T@Vi[i+len(self.Ai)] for i,h in enumerate(self.hi)])
return out

def orthogonal_complement(proj):
""" Computes the orthogonal complement to a given matrix proj"""
Expand Down

0 comments on commit dcc1945

Please sign in to comment.