Skip to content

Commit

Permalink
python 3.7 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
mfinzi committed Feb 25, 2021
1 parent 610c36d commit 0174452
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 51 deletions.
6 changes: 3 additions & 3 deletions emlp/models/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
def gated(sumrep):
return sumrep+sum([Scalar(rep.G) for rep in sumrep.reps if rep!=Scalar and not rep.is_regular])

@cache
@cache(maxsize=None)
def gate_indices(sumrep): #TODO: add regular
""" Indices for scalars, and also additional scalar gates
added by gated(sumrep)"""
Expand All @@ -31,7 +31,7 @@ def gate_indices(sumrep): #TODO: add regular
i+=rep.size()
return indices

@cache
@cache(maxsize=None)
def scalar_mask(sumrep):
channels = sumrep.size()
mask = np.ones(channels)>0
Expand All @@ -41,7 +41,7 @@ def scalar_mask(sumrep):
i+=rep.size()
return mask

@cache
@cache(maxsize=None)
def regular_mask(sumrep):
channels = sumrep.size()
mask = np.ones(channels)<0
Expand Down
17 changes: 9 additions & 8 deletions emlp/solver/product_sum_reps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax import device_put
import collections,itertools
from functools import lru_cache as cache
from .utils import disk_cache
from .utils import prod
import scipy as sp
import scipy.linalg
import functools
Expand All @@ -22,6 +22,7 @@
import objax



class SumRep(Rep):
concrete=True
atomic=False
Expand Down Expand Up @@ -214,7 +215,7 @@ def distribute_product(reps,extra_perm=None):

axis_sizes = [len(perm) for perm in perms]

order = np.arange(math.prod(axis_sizes)).reshape(tuple(len(perm) for perm in perms))
order = np.arange(prod(axis_sizes)).reshape(tuple(len(perm) for perm in perms))
for i,perm in enumerate(perms):
order = np.swapaxes(np.swapaxes(order,0,i)[perm,...],0,i)
order = order.reshape(-1)
Expand All @@ -235,7 +236,7 @@ def distribute_product(reps,extra_perm=None):
for prod in itertools.product(*[rep.reps.items() for rep in reps]):
rs,cs = zip(*prod)
#import pdb; pdb.set_trace()
prod_rep,canonicalizing_perm = (math.prod(cs)*reduce(lambda a,b: a*b,rs)).canonicalize()
prod_rep,canonicalizing_perm = (prod(cs)*reduce(lambda a,b: a*b,rs)).canonicalize()
#print(f"{rs}:{cs} in distribute yield prod_rep {prod_rep}")
ordered_reps.append(prod_rep)
shape = []
Expand All @@ -254,7 +255,7 @@ def distribute_product(reps,extra_perm=None):
return SumRep(*ordered_reps,extra_perm=total_perm,viz_shape_hint=axis_sizes)


@cache()
@cache(maxsize=None)
def rep_permutation(repsizes_all):
"""Permutation from block ordering to flattened ordering"""
size_cumsums = [np.cumsum([0] + [size for size in repsizes]) for repsizes in repsizes_all]
Expand Down Expand Up @@ -349,7 +350,7 @@ def __hash__(self):
def __eq__(self, other): #TODO: worry about non canonical?
return isinstance(other,ProductRep) and self.reps==other.reps# and self.perm == other.perm
def size(self):
return math.prod([rep.size()**count for rep,count in self.reps.items()])
return prod([rep.size()**count for rep,count in self.reps.items()])
@property
def T(self): #TODO: reavaluate if this needs to change the order (it does not)
""" only swaps to adjoint representation, does not reorder elems"""
Expand All @@ -364,7 +365,7 @@ def compute_canonical(rep_cnters,rep_perms):
""" given that rep1_perm and rep2_perm are the canonical orderings for
rep1 and rep2 (ie v[rep1_perm] is in canonical order) computes
the canonical order for rep1 * rep2"""
order = np.arange(math.prod(len(perm) for perm in rep_perms))
order = np.arange(prod(len(perm) for perm in rep_perms))
# First: merge counters
unique_reps = sorted(reduce(lambda a,b: a|b,[cnter.keys() for cnter in rep_cnters]))
merged_cnt = defaultdict(int)
Expand Down Expand Up @@ -551,7 +552,7 @@ class lazy_kron(LinearOperator):

def __init__(self,Ms):
self.Ms = Ms
self.shape = math.prod([Mi.shape[0] for Mi in Ms]), math.prod([Mi.shape[1] for Mi in Ms])
self.shape = prod([Mi.shape[0] for Mi in Ms]), prod([Mi.shape[1] for Mi in Ms])
#self.dtype=Ms[0].dtype
self.dtype=jnp.dtype('float32')

Expand All @@ -576,7 +577,7 @@ class lazy_kronsum(LinearOperator):

def __init__(self,Ms):
self.Ms = Ms
self.shape = math.prod([Mi.shape[0] for Mi in Ms]), math.prod([Mi.shape[1] for Mi in Ms])
self.shape = prod([Mi.shape[0] for Mi in Ms]), prod([Mi.shape[1] for Mi in Ms])
#self.dtype=Ms[0].dtype
self.dtype=jnp.dtype('float32')

Expand Down
8 changes: 4 additions & 4 deletions emlp/solver/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import optax
import collections,itertools
from functools import lru_cache as cache
from emlp.solver.utils import disk_cache,ltqdm
from emlp.solver.linear_operator_jax import LinearOperator
from .utils import ltqdm,prod
from .linear_operator_jax import LinearOperator
import scipy as sp
import scipy.linalg
import functools
Expand Down Expand Up @@ -94,7 +94,7 @@ def symmetric_basis(self):
logging.info(f"Solving basis for {self}"+(f", for G={self.G}" if hasattr(self,"G") else ""))
#if isinstance(group,Trivial): return np.eye(size(rank,group.d))
C_lazy = canon_rep.constraint_matrix()
if math.prod(C_lazy.shape)>3e7: #Too large to use SVD
if prod(C_lazy.shape)>3e7: #Too large to use SVD
result = krylov_constraint_solve(C_lazy)
else:
C_dense = C_lazy@jnp.eye(C_lazy.shape[-1])
Expand Down Expand Up @@ -436,5 +436,5 @@ def lazy_projection(params,x): # (r,), (*c) #TODO: find out why backwards of thi

@jit
def mul_part(bparams,x,bids):
b = math.prod(x.shape[:-1])
b = prod(x.shape[:-1])
return (bparams@x[...,bids].T.reshape(bparams.shape[-1],-1)).reshape(-1,b).T
74 changes: 38 additions & 36 deletions emlp/solver/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from tqdm.auto import tqdm
tqdm.get_lock().locks = []

prod = lambda c: reduce(lambda a,b:a*b,c)

# class TqdmToLogger(io.StringIO):
# """
# Output stream for TQDM which will output to logger module instead of
Expand Down Expand Up @@ -35,47 +37,47 @@ def ltqdm(*args,level='info',**kwargs):
else:
return args[0]

class NoCache(object):
def __enter__(self):
self.settings = CacheSettings.disk_caching
CacheSettings.disk_caching=False
return self
def __exit__(self, *exc):
CacheSettings.disk_caching=self.settings
return False
# class NoCache(object):
# def __enter__(self):
# self.settings = CacheSettings.disk_caching
# CacheSettings.disk_caching=False
# return self
# def __exit__(self, *exc):
# CacheSettings.disk_caching=self.settings
# return False

class CacheSettings(object):
memory_caching=True
disk_caching=True
# class CacheSettings(object):
# memory_caching=True
# disk_caching=True

def make_key(args, kwds, kwd_mark = (object(),)):
key = args
if kwds:
key += kwd_mark
for item in kwds.items():
key += item
return key
# def make_key(args, kwds, kwd_mark = (object(),)):
# key = args
# if kwds:
# key += kwd_mark
# for item in kwds.items():
# key += item
# return key



def disk_cache(file_name):
try:
with open(file_name, 'rb') as f:
cache = pickle.load(f)
except (IOError, ValueError):
cache = {}
# def disk_cache(file_name):
# try:
# with open(file_name, 'rb') as f:
# cache = pickle.load(f)
# except (IOError, ValueError):
# cache = {}

atexit.register(lambda: pickle.dump(cache, open(file_name, 'wb')))
# atexit.register(lambda: pickle.dump(cache, open(file_name, 'wb')))

def decorator(func):
def new_func(*args,**kwargs):
if not CacheSettings.disk_caching: return func(*args,**kwargs)
key = make_key(args,kwargs)
if key not in cache:
logging.info(f"{key} cache miss")
cache[key] = func(*args,**kwargs)
logging.debug(f"{key} cache hit")
return cache[key]
return new_func
# def decorator(func):
# def new_func(*args,**kwargs):
# if not CacheSettings.disk_caching: return func(*args,**kwargs)
# key = make_key(args,kwargs)
# if key not in cache:
# logging.info(f"{key} cache miss")
# cache[key] = func(*args,**kwargs)
# logging.debug(f"{key} cache hit")
# return cache[key]
# return new_func

return decorator
# return decorator

0 comments on commit 0174452

Please sign in to comment.