Skip to content

Commit

Permalink
Make assumptions about indexed symbols (sympy#13992)
Browse files Browse the repository at this point in the history
``Indexed`` and ``IndexedBase`` now accepts a new keyword argument
allowing the specification of assumptions on indexed symbols.
  • Loading branch information
bsamseth committed Feb 26, 2019
1 parent b68caad commit 81e3a14
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions sympy/tensor/indexed.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,11 @@

from __future__ import print_function, division

from sympy.core.assumptions import StdFactKB
from sympy.core import Expr, Tuple, Symbol, sympify, S
from sympy.core.compatibility import (is_sequence, string_types, NotIterable,
Iterable)
from sympy.core.logic import fuzzy_bool
from sympy.core.sympify import _sympify
from sympy.functions.special.tensor_functions import KroneckerDelta

Expand Down Expand Up @@ -137,7 +139,7 @@ class Indexed(Expr):
is_symbol = True
is_Atom = True

def __new__(cls, base, *args, **kw_args):
def __new__(cls, base, *args, assumptions=None, **kw_args):
from sympy.utilities.misc import filldedent
from sympy.tensor.array.ndim_array import NDimArray
from sympy.matrices.matrices import MatrixBase
Expand All @@ -156,7 +158,15 @@ def __new__(cls, base, *args, **kw_args):
else:
return base[args]

return Expr.__new__(cls, base, *args, **kw_args)

# Apply assumptions, with same logic as for Symbol.__new__.
obj = Expr.__new__(cls, base, *args, **kw_args)
assumptions = dict() if assumptions is None else assumptions
is_commutative = fuzzy_bool(assumptions.get('commutative', True))
assumptions['commutative'] = is_commutative
obj._assumptions = StdFactKB(assumptions)

return obj

@property
def name(self):
Expand Down Expand Up @@ -377,7 +387,7 @@ class IndexedBase(Expr, NotIterable):
is_symbol = True
is_Atom = True

def __new__(cls, label, shape=None, **kw_args):
def __new__(cls, label, shape=None, assumptions=None, **kw_args):
from sympy import MatrixBase, NDimArray

if isinstance(label, string_types):
Expand Down Expand Up @@ -407,6 +417,7 @@ def __new__(cls, label, shape=None, **kw_args):
obj._offset = offset
obj._strides = strides
obj._name = str(label)
obj._assumptions = dict() if assumptions is None else assumptions
return obj

@property
Expand All @@ -418,11 +429,11 @@ def __getitem__(self, indices, **kw_args):
# Special case needed because M[*my_tuple] is a syntax error.
if self.shape and len(self.shape) != len(indices):
raise IndexException("Rank mismatch.")
return Indexed(self, *indices, **kw_args)
return Indexed(self, *indices, assumptions=self._assumptions, **kw_args)
else:
if self.shape and len(self.shape) != 1:
raise IndexException("Rank mismatch.")
return Indexed(self, indices, **kw_args)
return Indexed(self, indices, assumptions=self._assumptions, **kw_args)

@property
def shape(self):
Expand Down

0 comments on commit 81e3a14

Please sign in to comment.