Skip to content

Commit

Permalink
added to_dense, removed viz
Browse files Browse the repository at this point in the history
  • Loading branch information
mfinzi committed Feb 26, 2021
1 parent 91cbac9 commit 452c661
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 27 deletions.
6 changes: 6 additions & 0 deletions emlp/solver/linear_operator_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,12 @@ def _transpose(self):
""" Default implementation of _transpose; defers to rmatvec + conj"""
return _TransposedLinearOperator(self)

def to_dense(self):
""" Default implementation of to_dense which produces the dense
matrix corresponding to the given lazy matrix. Defaults to
multiplying by the identity """
return self@jnp.eye(self.shape[-1])


class _CustomLinearOperator(LinearOperator):
"""Linear operator defined in terms of user-specified operations."""
Expand Down
23 changes: 18 additions & 5 deletions emlp/solver/linear_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp
import numpy as np
import jax.jit as jit

from functools import reduce

class Lazy(LinearOperator):
def __init__(self,dense_matrix):
Expand Down Expand Up @@ -36,6 +36,10 @@ def _adjoint(self):
return LazyDirectSum([Mi.T for Mi in self.Ms])
def invT(self):
return LazyDirectSum([M.invT() for M in self.Ms])
def to_dense(self):
Ms_all = [self.Ms[i] for i in range(len(self.Ms)) for _ in range(self.multiplicities[i])]
Ms_all = [Mi.to_dense() if isinstance(Mi,LinearOperator) else Mi for Mi in Ms_all]
return jax.scipy.linalg.block_diag(*Ms_all)
def __new__(cls,Ms,multiplicities=None):
if len(Ms)==1 and multiplicities is None: return Ms[0]
return super().__new__(cls)
Expand Down Expand Up @@ -74,9 +78,16 @@ def _adjoint(self):
return LazyKron([Mi.T for Mi in self.Ms])
def invT(self):
return LazyKron([M.invT() for M in self.Ms])
# def __new__(cls,Ms):
# if len(Ms)==1: return Ms[0]
# return super().__new__(cls)
def to_dense(self):
Ms = [M.to_dense() if isinstance(M,LinearOperator) else M for M in self.Ms]
return reduce(jnp.kron,Ms)
def __new__(cls,Ms):
if len(Ms)==1: return Ms[0]
return super().__new__(cls)

@jit
def kronsum(A,B):
return jnp.kron(A,jnp.eye(B.shape[-1])) + jnp.kron(jnp.eye(A.shape[-1]),B)

class LazyKronsum(LinearOperator):

Expand All @@ -101,7 +112,9 @@ def _matmat(self,v):

def _adjoint(self):
return LazyKronsum([Mi.T for Mi in self.Ms])

def to_dense(self):
Ms = [M.to_dense() if isinstance(M,LinearOperator) else M for M in self.Ms]
return reduce(kronsum,Ms)
def __new__(cls,Ms):
if len(Ms)==1: return Ms[0]
return super().__new__(cls)
Expand Down
11 changes: 2 additions & 9 deletions emlp/solver/product_sum_reps.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
class SumRep(Rep):
concrete=True
atomic=False
def __init__(self,*reps,extra_perm=None,viz_shape_hint=None):#repcounter,repperm=None):
def __init__(self,*reps,extra_perm=None):#repcounter,repperm=None):
""" Constructs a tensor type based on a list of tensor ranks
and possibly the symmetry generators gen."""
# Integers can be used as shorthand for scalars.
Expand All @@ -43,7 +43,6 @@ def __init__(self,*reps,extra_perm=None,viz_shape_hint=None):#repcounter,repperm
self.invperm = np.argsort(self.perm)
self.canonical=(self.perm==self.invperm).all()
self.is_regular = all(rep.is_regular for rep in self.reps.keys())
if viz_shape_hint is not None: self.viz_shape_hint = viz_shape_hint
# if not self.canonical:
# print(self,self.perm,self.invperm)

Expand Down Expand Up @@ -253,7 +252,7 @@ def distribute_product(reps,extra_perm=None):
total_perm = order[block_perm[each_perm]]
if extra_perm is not None: total_perm = extra_perm[total_perm]
#TODO: could achieve additional reduction by canonicalizing at this step, but unnecessary for now
return SumRep(*ordered_reps,extra_perm=total_perm,viz_shape_hint=axis_sizes)
return SumRep(*ordered_reps,extra_perm=total_perm)


@cache(maxsize=None)
Expand Down Expand Up @@ -508,9 +507,3 @@ def T(self):
# and all(Ga==Gb for Ga,Gb in zip(self.reps,other.reps)) \
# and all(ra==rb for ra,rb in zip(self.reps.values(),other.reps.values())) \

@jit
def kronsum(A,B):
return jnp.kron(A,jnp.eye(B.shape[-1])) + jnp.kron(jnp.eye(A.shape[-1]),B)



17 changes: 4 additions & 13 deletions emlp/solver/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ def size(self):
print(self,type(self))
raise NotImplementedError # The dimension of the representation
def rho_dense(self,M):
return self.rho(M)@jnp.eye(self.size())
rho = self.rho(M)
return rho.to_dense() if isinstance(rho,LinearOperator) else rho
def drho_dense(self,A):
return self.rho(A)@jnp.eye(self.size())
rho = self.drho(M)
return drho.to_dense() if isinstance(drho,LinearOperator) else drho
# def rho(self,M): # Group representation of matrix M (n,n)
# if hasattr(self,'_rho'): return self._rho(M)
# elif hasattr(self,'_rho_lazy'): return self._rho_lazy(M)@jnp.eye(self.size())
Expand Down Expand Up @@ -109,17 +111,6 @@ def symmetric_projector(self):
P = Q_lazy@Q_lazy.H
return P

def visualize(self,*shape):
#TODO: add support for non square
rep,perm = self.canonicalize()
Q = rep.symmetric_basis()
A = (sparsify_basis(Q)[np.argsort(perm)]@np.arange(1,Q.shape[-1]+1))
# Q= self.symmetric_basis() #THIS:
# A = sparsify_basis(Q)
if hasattr(self,"viz_shape_hint") and not shape: shape = self.viz_shape_hint
plt.imshow(A.reshape(shape))
plt.axis('off')

def __add__(self, other): # Tensor sum representation R1 + R2
if isinstance(other,int):
if other==0: return self
Expand Down

0 comments on commit 452c661

Please sign in to comment.