Skip to content

Commit

Permalink
import error?
Browse files Browse the repository at this point in the history
  • Loading branch information
mfinzi committed Feb 26, 2021
1 parent 1f9f0da commit b607c1f
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 4 deletions.
2 changes: 2 additions & 0 deletions emlp/solver/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# from .product_sum_reps import SumRep,DeferredSumRep,ProductRep,DeferredProductRep,DirectProduct
# __all__=["SumRep","DeferredSumRep","ProductRep","DeferredProductRep","DirectProduct"]
2 changes: 1 addition & 1 deletion emlp/solver/linear_operator_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def to_dense(self):
""" Default implementation of to_dense which produces the dense
matrix corresponding to the given lazy matrix. Defaults to
multiplying by the identity """
return self@jnp.eye(self.shape[-1])
return self@np.eye(self.shape[-1])


class _CustomLinearOperator(LinearOperator):
Expand Down
4 changes: 2 additions & 2 deletions emlp/solver/linear_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class ConcatLazy(LinearOperator):
a collection of matrices Ms along axis=0 """
def __init__(self,Ms):
self.Ms = Ms
assert all(M.shape==Ms.shape[0] for M in Ms),\
assert all(M.shape[0]==Ms[0].shape[0] for M in Ms),\
f"Trying to concatenate matrices of different sizes {[M.shape for M in Ms]}"
self.shape = (sum(M.shape[0] for M in Ms),Ms[0].shape[1])

Expand All @@ -116,7 +116,7 @@ def _rmatmat(self,V):
return sum([self.Ms[i].T@Vs[i] for i in range(len(self.Ms))])
def to_dense(self):
dense_Ms = [M.to_dense() if isinstance(M,LinearOperator) else M for M in self.Ms]
return jnp.concatenate(dense_ms,axis=0)
return jnp.concatenate(dense_Ms,axis=0)

class LazyDirectSum(LinearOperator):
def __init__(self,Ms,multiplicities=None):
Expand Down
2 changes: 1 addition & 1 deletion emlp/solver/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
import functools
import random
import logging
import emlp.solver
import math
from jax.ops import index, index_add, index_update
import matplotlib.pyplot as plt
from collections import Counter
from functools import reduce
import emlp.solver
#TODO: add rep,v = flatten({'Scalar':..., 'Vector':...,}), to_dict(rep,vector) returns {'Scalar':..., 'Vector':...,}
#TODO and simpler rep = flatten({Scalar:2,Vector:10,...}),
# Do we even want + operator to implement non canonical orderings?
Expand Down

0 comments on commit b607c1f

Please sign in to comment.