Skip to content

Commit

Permalink
name confl
Browse files Browse the repository at this point in the history
  • Loading branch information
mfinzi committed Feb 25, 2021
1 parent 0174452 commit 25bf5fe
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions emlp/solver/product_sum_reps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax import device_put
import collections,itertools
from functools import lru_cache as cache
from .utils import prod
from .utils import prod as product
import scipy as sp
import scipy.linalg
import functools
Expand Down Expand Up @@ -215,7 +215,7 @@ def distribute_product(reps,extra_perm=None):

axis_sizes = [len(perm) for perm in perms]

order = np.arange(prod(axis_sizes)).reshape(tuple(len(perm) for perm in perms))
order = np.arange(product(axis_sizes)).reshape(tuple(len(perm) for perm in perms))
for i,perm in enumerate(perms):
order = np.swapaxes(np.swapaxes(order,0,i)[perm,...],0,i)
order = order.reshape(-1)
Expand All @@ -236,7 +236,7 @@ def distribute_product(reps,extra_perm=None):
for prod in itertools.product(*[rep.reps.items() for rep in reps]):
rs,cs = zip(*prod)
#import pdb; pdb.set_trace()
prod_rep,canonicalizing_perm = (prod(cs)*reduce(lambda a,b: a*b,rs)).canonicalize()
prod_rep,canonicalizing_perm = (product(cs)*reduce(lambda a,b: a*b,rs)).canonicalize()
#print(f"{rs}:{cs} in distribute yield prod_rep {prod_rep}")
ordered_reps.append(prod_rep)
shape = []
Expand Down Expand Up @@ -350,7 +350,7 @@ def __hash__(self):
def __eq__(self, other): #TODO: worry about non canonical?
return isinstance(other,ProductRep) and self.reps==other.reps# and self.perm == other.perm
def size(self):
return prod([rep.size()**count for rep,count in self.reps.items()])
return product([rep.size()**count for rep,count in self.reps.items()])
@property
def T(self): #TODO: reavaluate if this needs to change the order (it does not)
""" only swaps to adjoint representation, does not reorder elems"""
Expand All @@ -365,7 +365,7 @@ def compute_canonical(rep_cnters,rep_perms):
""" given that rep1_perm and rep2_perm are the canonical orderings for
rep1 and rep2 (ie v[rep1_perm] is in canonical order) computes
the canonical order for rep1 * rep2"""
order = np.arange(prod(len(perm) for perm in rep_perms))
order = np.arange(product(len(perm) for perm in rep_perms))
# First: merge counters
unique_reps = sorted(reduce(lambda a,b: a|b,[cnter.keys() for cnter in rep_cnters]))
merged_cnt = defaultdict(int)
Expand Down Expand Up @@ -552,7 +552,7 @@ class lazy_kron(LinearOperator):

def __init__(self,Ms):
self.Ms = Ms
self.shape = prod([Mi.shape[0] for Mi in Ms]), prod([Mi.shape[1] for Mi in Ms])
self.shape = product([Mi.shape[0] for Mi in Ms]), product([Mi.shape[1] for Mi in Ms])
#self.dtype=Ms[0].dtype
self.dtype=jnp.dtype('float32')

Expand All @@ -577,7 +577,7 @@ class lazy_kronsum(LinearOperator):

def __init__(self,Ms):
self.Ms = Ms
self.shape = prod([Mi.shape[0] for Mi in Ms]), prod([Mi.shape[1] for Mi in Ms])
self.shape = product([Mi.shape[0] for Mi in Ms]), product([Mi.shape[1] for Mi in Ms])
#self.dtype=Ms[0].dtype
self.dtype=jnp.dtype('float32')

Expand Down

0 comments on commit 25bf5fe

Please sign in to comment.