Skip to content

Commit

Permalink
sped up construction of large sum reps, extended linear operators, re…
Browse files Browse the repository at this point in the history
…factor fix
  • Loading branch information
mfinzi committed Mar 1, 2021
1 parent b500a4a commit 0796116
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 124 deletions.
116 changes: 69 additions & 47 deletions emlp/solver/linear_operator_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,12 @@ def __pow__(self, p):
def __add__(self, x):
if isinstance(x, LinearOperator):
return _SumLinearOperator(self, x)
elif isinstance(x,np.ndarray) and len(x.shape)==2:
return _SumLinearOperator(self, Lazy(x))
else:
return NotImplemented

def __radd__(self,x):
return self.__add__(x)
def __neg__(self):
return _ScaledLinearOperator(self, -1)

Expand Down Expand Up @@ -717,49 +720,68 @@ def _adjoint(self):
return self


def aslinearoperator(A):
"""Return A as a LinearOperator.
'A' may be any of the following types:
- ndarray
- matrix
- sparse matrix (e.g. csr_matrix, lil_matrix, etc.)
- LinearOperator
- An object with .shape and .matvec attributes
See the LinearOperator documentation for additional information.
Notes
-----
If 'A' has no .dtype attribute, the data type is determined by calling
:func:`LinearOperator.matvec()` - set the .dtype attribute to prevent this
call upon the linear operator creation.
Examples
--------
>>> from scipy.sparse.linalg import aslinearoperator
>>> M = np.array([[1,2,3],[4,5,6]], dtype=np.int32)
>>> aslinearoperator(M)
<2x3 MatrixLinearOperator with dtype=int32>
"""
if isinstance(A, LinearOperator):
return A

elif isinstance(A, np.ndarray):
if A.ndim > 2:
raise ValueError('array must have ndim <= 2')
return MatrixLinearOperator(A)

else:
if hasattr(A, 'shape') and hasattr(A, 'matvec'):
rmatvec = None
rmatmat = None
dtype = None

if hasattr(A, 'rmatvec'):
rmatvec = A.rmatvec
if hasattr(A, 'rmatmat'):
rmatmat = A.rmatmat
if hasattr(A, 'dtype'):
dtype = A.dtype
return LinearOperator(A.shape, A.matvec, rmatvec=rmatvec,
rmatmat=rmatmat, dtype=dtype)

else:
raise TypeError('type not understood')
class Lazy(LinearOperator):
def __init__(self,dense_matrix):
self.A = dense_matrix
super().__init__(self.A.dtype,self.A.shape)

def _matmat(self,V):
return self.A@V
def _matvec(self,v):
return self.A@v
def _rmatmat(self,V):
return self.A.T@V
def _rmatvec(self,v):
return self.A.T@v
def to_dense(self):
return self.A
def invT(self):
return Lazy(np.linalg.inv(self.A).T)


# def aslinearoperator(A):
# """Return A as a LinearOperator.
# 'A' may be any of the following types:
# - ndarray
# - matrix
# - sparse matrix (e.g. csr_matrix, lil_matrix, etc.)
# - LinearOperator
# - An object with .shape and .matvec attributes
# See the LinearOperator documentation for additional information.
# Notes
# -----
# If 'A' has no .dtype attribute, the data type is determined by calling
# :func:`LinearOperator.matvec()` - set the .dtype attribute to prevent this
# call upon the linear operator creation.
# Examples
# --------
# >>> from scipy.sparse.linalg import aslinearoperator
# >>> M = np.array([[1,2,3],[4,5,6]], dtype=np.int32)
# >>> aslinearoperator(M)
# <2x3 MatrixLinearOperator with dtype=int32>
# """
# if isinstance(A, LinearOperator):
# return A

# elif isinstance(A, np.ndarray):
# if A.ndim > 2:
# raise ValueError('array must have ndim <= 2')
# return MatrixLinearOperator(A)

# else:
# if hasattr(A, 'shape') and hasattr(A, 'matvec'):
# rmatvec = None
# rmatmat = None
# dtype = None

# if hasattr(A, 'rmatvec'):
# rmatvec = A.rmatvec
# if hasattr(A, 'rmatmat'):
# rmatmat = A.rmatmat
# if hasattr(A, 'dtype'):
# dtype = A.dtype
# return LinearOperator(A.shape, A.matvec, rmatvec=rmatvec,
# rmatmat=rmatmat, dtype=dtype)

# else:
# raise TypeError('type not understood')
44 changes: 17 additions & 27 deletions emlp/solver/linear_operators.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,20 @@
from .linear_operator_jax import LinearOperator
from .linear_operator_jax import LinearOperator,Lazy
import jax.numpy as jnp
import numpy as np
from jax import jit
import jax
from functools import reduce
from .utils import prod as product

class Lazy(LinearOperator):
def __init__(self,dense_matrix):
self.A = dense_matrix
self.shape = self.A.shape
self.dtype = self.A.dtype
def _matmat(self,V):
return self.A@V
def _matvec(self,v):
return self.A@v
def _rmatmat(self,V):
return self.A.T@V
def _rmatvec(self,v):
return self.A.T@v
def to_dense(self):
return self.A
def lazify(x):
if isinstance(x,LinearOperator): return x
elif isinstance(x,(jnp.ndarray,np.ndarray)): return Lazy(x)
else: raise NotImplementedError

class I(LinearOperator):
def __init__(self,d):
self.shape = (d,d)
shape = (d,d)
super().__init__(None, shape)
def _matmat(self,V): #(c,k)
return V
def _matvec(self,V):
Expand Down Expand Up @@ -121,15 +111,16 @@ def to_dense(self):

class LazyDirectSum(LinearOperator):
def __init__(self,Ms,multiplicities=None):
self.Ms = [jax.device_put(M.astype(np.float32)) if isinstance(M,(np.ndarray,jnp.ndarray)) else M for M in Ms]
self.Ms = [jax.device_put(M.astype(np.float32)) if isinstance(M,(np.ndarray)) else M for M in Ms]
self.multiplicities = [1 for M in Ms] if multiplicities is None else multiplicities
self.shape = (sum(Mi.shape[0]*c for Mi,c in zip(Ms,multiplicities)),
shape = (sum(Mi.shape[0]*c for Mi,c in zip(Ms,multiplicities)),
sum(Mi.shape[0]*c for Mi,c in zip(Ms,multiplicities)))
super().__init__(None,shape)
#self.dtype=Ms[0].dtype
self.dtype=jnp.dtype('float32')
#self.dtype=jnp.dtype('float32')

def _matvec(self,v):
return self._matmat(v.reshape(v.shape[0],-1)).reshape(-1)
return lazy_direct_matmat(v,self.Ms,self.multiplicities)

def _matmat(self,v): # (n,k)
return lazy_direct_matmat(v,self.Ms,self.multiplicities)
Expand All @@ -141,25 +132,24 @@ def to_dense(self):
Ms_all = [M for M,c in zip(self.Ms,self.multiplicities) for _ in range(c)]
Ms_all = [Mi.to_dense() if isinstance(Mi,LinearOperator) else Mi for Mi in Ms_all]
return jax.scipy.linalg.block_diag(*Ms_all)
def __new__(cls,Ms,multiplicities=None):
if len(Ms)==1 and multiplicities is None: return Ms[0]
return super().__new__(cls)
# def __new__(cls,Ms,multiplicities=None):
# if len(Ms)==1 and multiplicities is None: return Ms[0]
# return super().__new__(cls)

def lazy_direct_matmat(v,Ms,mults):
n,k = v.shape
n = v.shape[0]
i=0
y = []
for M, multiplicity in zip(Ms,mults):
if not M.shape[-1]: continue
i_end = i+multiplicity*M.shape[-1]
elems = M@v[i:i_end].T.reshape(-1,M.shape[-1]).T
y.append(elems.T.reshape(k,multiplicity*M.shape[0]).T)
y.append(elems.T.reshape(-1,multiplicity*M.shape[0]).T)
i = i_end
y = jnp.concatenate(y,axis=0) #concatenate over rep axis
return y



class LazyPerm(LinearOperator):
def __init__(self,perm):
self.perm=perm
Expand Down
83 changes: 46 additions & 37 deletions emlp/solver/product_sum_reps.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import random
from .representation import Rep
from .linear_operator_jax import LinearOperator
from .linear_operators import LazyPerm,LazyDirectSum,LazyKron,LazyKronsum,I
from .linear_operators import LazyPerm,LazyDirectSum,LazyKron,LazyKronsum,I,lazy_direct_matmat,lazify
import logging
import copy
import math
Expand All @@ -41,7 +41,7 @@ def __init__(self,*reps,extra_perm=None):#repcounter,repperm=None):
self.reps,perm = self.compute_canonical(rep_counters,perms)
self.perm = extra_perm[perm] if extra_perm is not None else perm
self.invperm = np.argsort(self.perm)
self.canonical=(self.perm==self.invperm).all()
self.canonical=(self.perm==np.arange(len(self.perm))).all()
self.is_regular = all(rep.is_regular for rep in self.reps.keys())
# if not self.canonical:
# print(self,self.perm,self.invperm)
Expand Down Expand Up @@ -104,55 +104,64 @@ def symmetric_basis(self):
to the projection matrix drho(Mi). Function returns both the
dimension of the active subspace (r) and also a function that
maps an array of size (*,r) to a vector v with a representaiton
given by the rnaks that satisfies drho(Mi)v=0 for each i.
given by the ranks that satisfies drho(Mi)v=0 for each i.
Inputs: [generators seq(tensor(d,d))] [ranks seq(tuple(p,q))]
Outputs: [r int] [projection (tensor(r)->tensor(rep_dim))]"""
Qs = {rep: rep.symmetric_basis() for rep in self.reps}
Qs = {rep: (jax.device_put(Q.astype(np.float32)) if isinstance(Q,(np.ndarray)) else Q) for rep,Q in Qs.items()}
active_dims = sum([self.reps[rep]*Qs[rep].shape[-1] for rep in Qs.keys()])
multiplicities = self.reps.values()
def lazy_Q(array):
array = array.T
i=0
Ws = []
for rep, multiplicity in self.reps.items():
Qr = Qs[rep]
if not Qr.shape[-1]: continue
i_end = i+multiplicity*Qr.shape[-1]
elems = Qr@array[...,i:i_end].reshape(-1,Qr.shape[-1]).T
Ws.append(elems.T.reshape(*array.shape[:-1],multiplicity*rep.size()))
i = i_end
Ws = jnp.concatenate(Ws,axis=-1) #concatenate over rep axis
inp_ordered_Ws = Ws[...,self.invperm] #(should it be inverse?) reorder to original rep ordering
return inp_ordered_Ws.T
return lazy_direct_matmat(array,Qs.values(),multiplicities)[self.invperm]
# def lazy_Q(array):
# array = array.T
# i=0
# Ws = []
# for rep, multiplicity in self.reps.items():
# Qr = Qs[rep]
# if not Qr.shape[-1]: continue
# i_end = i+multiplicity*Qr.shape[-1]
# elems = Qr@array[...,i:i_end].reshape(-1,Qr.shape[-1]).T
# Ws.append(elems.T.reshape(*array.shape[:-1],multiplicity*rep.size()))
# i = i_end
# Ws = jnp.concatenate(Ws,axis=-1) #concatenate over rep axis
# inp_ordered_Ws = Ws[...,self.invperm] #(should it be inverse?) reorder to original rep ordering
# return inp_ordered_Ws.T
return LinearOperator(shape=(self.size(),active_dims),matvec=lazy_Q,matmat=lazy_Q)

def symmetric_projector(self):
Ps = {rep:rep.symmetric_projector() for rep in self.reps}

# Apply the projections for each rep, concatenate, and permute back to orig rep order
def lazy_QQT(W):
ordered_W = W[self.perm]
PWs = []
i=0
for rep, multiplicity in self.reps.items():
P = Ps[rep]
i_end = i+multiplicity*rep.size()
PWs.append((P@ordered_W[i:i_end].reshape(multiplicity,rep.size()).T).T.reshape(-1))
i = i_end
#print(rep,multiplicity,i_end)
PWs = jnp.concatenate(PWs,axis=-1) #concatenate over rep axis
inp_ordered_PWs = PWs[self.invperm]
#print(inp_ordered_PWs)
return inp_ordered_PWs # reorder to original rep ordering
return LinearOperator(shape=(self.size(),self.size()),matvec=lazy_QQT)

##TODO: investigate why these more idiomatic definitions with Lazy Tensors end up slower
# def lazy_QQT(W):
# ordered_W = W[self.perm]
# PWs = []
# i=0
# for rep, multiplicity in self.reps.items():
# P = Ps[rep]
# i_end = i+multiplicity*rep.size()
# PWs.append((P@ordered_W[i:i_end].reshape(multiplicity,rep.size()).T).T.reshape(-1))
# i = i_end
# #print(rep,multiplicity,i_end)
# PWs = jnp.concatenate(PWs,axis=-1) #concatenate over rep axis
# inp_ordered_PWs = PWs[self.invperm]
# #print(inp_ordered_PWs)
# return inp_ordered_PWs # reorder to original rep ordering
multiplicities = self.reps.values()
def lazy_P(array):
return lazy_direct_matmat(array,Ps.values(),multiplicities)[self.invperm]
return LinearOperator(shape=(self.size(),self.size()),matvec=lazy_P,matmat=lazy_P)

# ##TODO: investigate why these more idiomatic definitions with Lazy Tensors end up slower
# def symmetric_basis(self):
# Qs = [rep.symmetric_basis() for rep in self.reps]
# Qs = [(jax.device_put(Q.astype(np.float32)) if isinstance(Q,(np.ndarray)) else Q) for Q in Qs]
# multiplicities = self.reps.values()
# return LazyPerm(self.invperm)@LazyDirectSum(Qs,multiplicities)
# Q = I(len(self.perm))
# Q@jnp.zeros((Q.shape[-1],1))
# return Q#LazyPerm(self.invperm)#LazyDirectSum(Qs,multiplicities)#LazyPerm(self.invperm)@LazyDirectSum(Qs,multiplicities)
# def symmetric_projector(self):
# Ps = [rep.symmetric_projector() for rep in self.reps]
# Ps = (jax.device_put(P.astype(np.float32)) if isinstance(P,(np.ndarray)) else P)
# multiplicities = self.reps.values()
# return LazyPerm(self.invperm)@LazyDirectSum(Ps,multiplicities)@LazyPerm(self.perm)
def rho(self,M):
Expand Down Expand Up @@ -201,7 +210,7 @@ def __init__(self,counter,perm=None):
self.perm = np.arange(self.size()) if perm is None else perm
self.reps,self.perm = self.compute_canonical([counter],[self.perm])
self.invperm = np.argsort(self.perm)
self.canonical=(self.perm==self.invperm).all()
self.canonical=(self.perm==np.arange(len(self.perm))).all()
self.is_regular = all(rep.is_regular for rep in self.reps.keys())
# if not self.canonical:
# print(self,self.perm,self.invperm)
Expand Down Expand Up @@ -433,7 +442,7 @@ def __init__(self,*reps):

def __call__(self,G):
if G is None: return self
return sum([rep(G) for rep in self.to_sum])
return SumRep(*[rep(G) for rep in self.to_sum])
def __repr__(self):
return '('+"+".join(f"{rep}" for rep in self.to_sum)+')'
def __str__(self):
Expand Down

0 comments on commit 0796116

Please sign in to comment.