Skip to content

Commit

Permalink
printing: Remove unnecessary print overrides for Matrix subclasses
Browse files Browse the repository at this point in the history
Also misc related fixes in the matrices module.

And see sympy/sympy#19899.
  • Loading branch information
skirpichev committed Jan 25, 2021
1 parent befd5fd commit 49a590a
Show file tree
Hide file tree
Showing 9 changed files with 12 additions and 48 deletions.
2 changes: 1 addition & 1 deletion diofant/matrices/immutable.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __hash__(self):
ImmutableDenseMatrix = ImmutableMatrix


class ImmutableSparseMatrix(Basic, SparseMatrixBase):
class ImmutableSparseMatrix(SparseMatrixBase, Basic):
"""Create an immutable version of a sparse matrix.
Examples
Expand Down
2 changes: 1 addition & 1 deletion diofant/matrices/matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -3901,7 +3901,7 @@ def col_insert(self, pos, mti):
newmat[:, j:] = self[:, i:]
return type(self)(newmat)

def replace(self, F, G):
def replace(self, F, G, exact=False):
"""Replaces Function F in Matrix entries with Function G.
Examples
Expand Down
9 changes: 4 additions & 5 deletions diofant/printing/codeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,10 @@ def _print_Mul(self, expr):
return sign + '*'.join(a_str) + '/(%s)' % '*'.join(b_str)

def _print_not_supported(self, expr):
self._not_supported.add(expr)
try:
self._not_supported.add(expr)
except TypeError:
pass
return self.emptyPrinter(expr)

# The following can not be simply translated into C or Fortran
Expand All @@ -493,9 +496,6 @@ def _print_not_supported(self, expr):
_print_Interval = _print_not_supported
_print_Limit = _print_not_supported
_print_list = _print_not_supported
_print_Matrix = _print_not_supported
_print_ImmutableMatrix = _print_not_supported
_print_MutableDenseMatrix = _print_not_supported
_print_MatrixBase = _print_not_supported
_print_NaN = _print_not_supported
_print_NegativeInfinity = _print_not_supported
Expand All @@ -504,7 +504,6 @@ def _print_not_supported(self, expr):
_print_RootOf = _print_not_supported
_print_RootsOf = _print_not_supported
_print_RootSum = _print_not_supported
_print_SparseMatrix = _print_not_supported
_print_tuple = _print_not_supported
_print_Wild = _print_not_supported
_print_WildFunction = _print_not_supported
2 changes: 0 additions & 2 deletions diofant/printing/latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,8 +1270,6 @@ def _print_MatrixBase(self, expr):
out_str = r'\left' + left_delim + out_str + \
r'\right' + right_delim
return out_str % r'\\'.join(lines)
_print_ImmutableMatrix = _print_MatrixBase
_print_Matrix = _print_MatrixBase

def _print_MatrixSlice(self, expr):
def latexslice(x):
Expand Down
15 changes: 1 addition & 14 deletions diofant/printing/octave.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def _print_MatrixBase(self, A):
return '[%s]' % A.table(self, rowstart='', rowend='',
rowsep=';\n', colsep=' ')

def _print_SparseMatrix(self, A):
def _print_SparseMatrixBase(self, A):
from ..matrices import Matrix
L = A.col_list()
# make row vectors of the indices and entries
Expand All @@ -293,19 +293,6 @@ def _print_SparseMatrix(self, A):
return 'sparse(%s, %s, %s, %s, %s)' % (self._print(I), self._print(J),
self._print(AIJ), A.rows, A.cols)

# FIXME: Str/CodePrinter could define each of these to call the _print
# method from higher up the class hierarchy (see _print_NumberSymbol).
# Then subclasses like us would not need to repeat all this.
_print_Matrix = \
_print_DenseMatrix = \
_print_MutableDenseMatrix = \
_print_ImmutableMatrix = \
_print_ImmutableDenseMatrix = \
_print_MatrixBase
_print_MutableSparseMatrix = \
_print_ImmutableSparseMatrix = \
_print_SparseMatrix

def _print_MatrixElement(self, expr):
return self._print(expr.parent) + '(%s, %s)' % (expr.i + 1, expr.j + 1)

Expand Down
2 changes: 0 additions & 2 deletions diofant/printing/pretty/pretty.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,8 +594,6 @@ def _print_MatrixBase(self, e):
D.baseline = D.height()//2
D = prettyForm(*D.parens('[', ']'))
return D
_print_ImmutableMatrix = _print_MatrixBase
_print_Matrix = _print_MatrixBase

def _print_Trace(self, e):
D = self._print(e.arg)
Expand Down
1 change: 1 addition & 0 deletions diofant/printing/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def _print_Symbol(self, expr):
if symbol not in self.symbols:
self.symbols.append(symbol)
return StrPrinter._print_Symbol(self, expr)
_print_BaseSymbol = StrPrinter._print_BaseSymbol


def python(expr, **settings):
Expand Down
18 changes: 4 additions & 14 deletions diofant/printing/repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,6 @@ def _print_MatrixBase(self, expr):
l[-1].append(expr[i, j])
return f'{expr.__class__.__name__}({self._print(l)})'

_print_SparseMatrix = \
_print_MutableSparseMatrix = \
_print_ImmutableSparseMatrix = \
_print_Matrix = \
_print_DenseMatrix = \
_print_MutableDenseMatrix = \
_print_ImmutableMatrix = \
_print_ImmutableDenseMatrix = \
_print_MatrixBase

def _print_BooleanTrue(self, expr):
return 'true'

Expand All @@ -132,16 +122,14 @@ def _print_Float(self, expr):
r = mlib.to_str(expr._mpf_, repr_dps(expr._prec))
return f"{expr.__class__.__name__}('{r}', dps={dps:d})"

def _print_Symbol(self, expr):
def _print_BaseSymbol(self, expr):
d = expr._assumptions.generator
if d == {}:
return f'{expr.__class__.__name__}({self._print(expr.name)})'
else:
attr = [f'{k}={v}' for k, v in d.items()]
return '%s(%s, %s)' % (expr.__class__.__name__,
self._print(expr.name), ', '.join(attr))
_print_Dummy = _print_Symbol
_print_Wild = _print_Symbol

def _print_str(self, expr):
return repr(expr)
Expand All @@ -157,7 +145,9 @@ def _print_WildFunction(self, expr):

def _print_PolynomialRing(self, ring):
return '%s(%s, %s, %s)' % (ring.__class__.__name__,
self._print(ring.domain), self._print(ring.symbols), self._print(ring.order))
self._print(ring.domain),
self._print(ring.symbols),
self._print(ring.order))

def _print_GMPYIntegerRing(self, expr):
return f'{expr.__class__.__name__}()'
Expand Down
9 changes: 0 additions & 9 deletions diofant/printing/str.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,15 +200,6 @@ def _print_list(self, expr):

def _print_MatrixBase(self, expr):
return expr._format_str(self)
_print_SparseMatrix = \
_print_MutableSparseMatrix = \
_print_ImmutableSparseMatrix = \
_print_Matrix = \
_print_DenseMatrix = \
_print_MutableDenseMatrix = \
_print_ImmutableMatrix = \
_print_ImmutableDenseMatrix = \
_print_MatrixBase

def _print_MatrixElement(self, expr):
return self._print(expr.parent) + f'[{expr.i}, {expr.j}]'
Expand Down

0 comments on commit 49a590a

Please sign in to comment.