Skip to content

Commit

Permalink
added to_dense to base linear operators
Browse files Browse the repository at this point in the history
  • Loading branch information
mfinzi committed Feb 26, 2021
1 parent 4cb4826 commit 5a85d5a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 23 deletions.
14 changes: 14 additions & 0 deletions emlp/solver/linear_operator_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,11 @@ def _adjoint(self):
def invT(self):
A,B = self.args
return A.invT()*B.invT()
def to_dense(self):
A,B = self.args
A = A.to_dense() if isinstance(A,LinearOperator) else A
B = B.to_dense() if isinstance(B,LinearOperator) else B
return A@B


class _ScaledLinearOperator(LinearOperator):
Expand Down Expand Up @@ -616,6 +621,12 @@ def _matmat(self, x):
def _adjoint(self):
A, alpha = self.args
return A.H * np.conj(alpha)
def invT(self):
A, alpha = self.args
return (1/alpha)*A.T
def to_dense(self):
A, alpha = self.args
return alpha*A.to_dense()


class _PowerLinearOperator(LinearOperator):
Expand Down Expand Up @@ -651,6 +662,9 @@ def _matmat(self, x):
def _adjoint(self):
A, p = self.args
return A.H ** p
def invT(self):
A, p = self.args
return A.invT()**p


class MatrixLinearOperator(LinearOperator):
Expand Down
23 changes: 0 additions & 23 deletions emlp/solver/product_sum_reps.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,29 +298,6 @@ def __init__(self,*reps,extra_perm=None,counter=None):
assert len(Gs)==1, f"Multiple different groups {Gs} in product rep {self}"
self.G= Gs[0]
self.is_regular = all(rep.is_regular for rep in self.reps.keys())

# if not self.canonical:
# print(self,self.perm,self.invperm)
# def __new__(cls,*reps,extra_perm=None,counter=None):

# if counter is not None: reps = counter.keys()
# unique_groups = set(rep.G for rep in reps if hasattr(rep,'G'))
# if len(unique_groups)>1 and len(unique_groups)!=len(reps):
# assert counter is None
# # write as ProductRep of separate ProductReps each with only one Group
# reps,perms = zip(*[rep.canonicalize() for rep in reps])
# rep_counters = [rep.reps if isinstance(rep,ProductRep) else {rep:1} for rep in reps]
# reps,perm = cls.compute_canonical(rep_counters,perms) # so that reps is sorted by group
# perm = extra_perm[perm] if extra_perm is not None else perm
# group_dict = defaultdict(dict)
# for rep,c in reps.items():
# group_dict[rep.G][rep]=c
# sub_products = {ProductRep(counter=repdict):1 for G,repdict in group_dict.items()}
# print(f"calling with {sub_products}")
# return ProductRep(counter=sub_products,extra_perm=perm)
# #init is being called twice because ProductRepFromCollection is a subclass
# else:
# return super().__new__(cls)

def canonicalize(self):
"""Returns a canonically ordered rep with order np.arange(self.size()) and the
Expand Down

0 comments on commit 5a85d5a

Please sign in to comment.