/
matmul.py
244 lines (192 loc) · 6.92 KB
/
matmul.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
from ...core import Add, Expr, Mul, Number
from ...core.logic import _fuzzy_group
from ...core.strategies import do_one, exhaust, flatten, rm_id, typed, unpack
from ...core.sympify import sympify
from ...functions import adjoint
from ..matrices import MatrixBase, ShapeError
from .matexpr import Identity, MatrixExpr, ZeroMatrix
from .transpose import transpose
class MatMul(MatrixExpr):
    """
    A product of matrix expressions.

    Scalar (non-matrix) arguments may appear next to the matrix factors;
    they are collected into a single coefficient by ``as_coeff_matrices``.

    Examples
    ========

    >>> A = MatrixSymbol('A', 5, 4)
    >>> B = MatrixSymbol('B', 4, 3)
    >>> C = MatrixSymbol('C', 3, 6)
    >>> MatMul(A, B, C)
    A*B*C

    """

    is_MatMul = True

    def _eval_is_commutative(self):
        # Combine the args' ``is_commutative`` flags with ``_fuzzy_group``
        # (presumably a fuzzy AND over the args -- see core.logic to confirm).
        return _fuzzy_group((a.is_commutative for a in self.args),
                            quick_exit=True)

    def __new__(cls, *args, **kwargs):
        """Create the product of ``args`` (each one is sympified first).

        Unless ``check=False`` is passed, the matrix factors are checked with
        ``validate``, which raises ``ShapeError`` when two neighbouring
        factors are not aligned.
        """
        check = kwargs.get('check', True)
        args = list(map(sympify, args))
        obj = Expr.__new__(cls, *args)
        # Scalar args play no role in shape checking; only the matrix
        # factors are validated.
        _, matrices = obj.as_coeff_matrices()
        if check:
            validate(*matrices)
        return obj

    @property
    def shape(self):
        """``(rows, cols)`` of the product.

        Rows come from the first matrix factor and columns from the last one;
        scalar args are skipped.
        """
        matrices = [arg for arg in self.args if arg.is_Matrix]
        return matrices[0].rows, matrices[-1].cols

    def _entry(self, i, j, expand=True):
        """Return the ``(i, j)`` entry of the product.

        The product is viewed as ``coeff * X * Y``, where ``X`` is the first
        matrix factor and ``Y`` is the ``MatMul`` of the remaining ones. The
        entry is ``coeff * sum_k X[i, k]*Y[k, j]``.

        If either side contains an ``ImmutableMatrix``, the sum is written out
        as an explicit ``Add``. Otherwise a ``Sum`` over a dummy index is
        built. That ``Sum`` is evaluated with ``doit()`` only when ``expand``
        is true and ``X.cols`` is a number.
        """
        coeff, matrices = self.as_coeff_matrices()
        if len(matrices) == 1:  # situation like 2*X, matmul is just X
            return coeff * matrices[0][i, j]
        head, tail = matrices[0], matrices[1:]
        X = head
        Y = MatMul(*tail)
        # Function-level imports -- presumably to avoid import cycles; confirm.
        from ...concrete import Sum
        from ...core import Dummy
        from .. import ImmutableMatrix
        k = Dummy('k', integer=True)
        if X.has(ImmutableMatrix) or Y.has(ImmutableMatrix):
            # Explicit entries: add the terms directly. The comprehension's
            # ``k`` is a plain int that shadows the Dummy only inside it.
            return coeff*Add(*[X[i, k]*Y[k, j] for k in range(X.cols)])
        result = Sum(coeff*X[i, k]*Y[k, j], (k, 0, X.cols - 1))
        if not X.cols.is_number:
            # Don't waste time in result.doit() if the sum bounds are symbolic
            expand = False
        return result.doit() if expand else result

    def as_coeff_matrices(self):
        """Split the args into ``(coeff, matrices)``.

        ``coeff`` is the ``Mul`` of every non-matrix arg. ``matrices`` is the
        list of matrix args in their original order.
        """
        scalars = [x for x in self.args if not x.is_Matrix]
        matrices = [x for x in self.args if x.is_Matrix]
        coeff = Mul(*scalars)
        return coeff, matrices

    def as_coeff_mmul(self):
        """Like ``as_coeff_matrices``, but return the matrix factors wrapped
        in a single ``MatMul``."""
        coeff, matrices = self.as_coeff_matrices()
        return coeff, MatMul(*matrices)

    def _eval_transpose(self):
        # (A*B*C).T == C.T*B.T*A.T: reverse the factors and transpose each.
        return MatMul(*[transpose(arg) for arg in self.args[::-1]]).doit()

    def _eval_adjoint(self):
        # Same reversal rule as the transpose, applied with ``adjoint``.
        return MatMul(*[adjoint(arg) for arg in self.args[::-1]]).doit()

    def _eval_trace(self):
        """Pull the scalar coefficient out of the trace.

        Raises NotImplementedError when the coefficient is 1, because then
        there is nothing to simplify.
        """
        factor, mmul = self.as_coeff_mmul()
        if factor != 1:
            from .trace import trace
            return factor * trace(mmul.doit())
        else:
            raise NotImplementedError("Can't simplify any further")

    def _eval_determinant(self):
        """Use det(c*M) == c**n * det(M) for an n-row product.

        det(M) is written as the product of the determinants of the square
        sub-products returned by ``only_squares``.
        """
        from .determinant import Determinant
        factor, matrices = self.as_coeff_matrices()
        square_matrices = only_squares(*matrices)
        return factor**self.rows * Mul(*list(map(Determinant, square_matrices)))

    def _eval_inverse(self):
        """Use (A*B)**-1 == B**-1 * A**-1: invert each arg in reverse order.

        ``MatrixExpr`` args use ``.inverse()``; every other arg uses ``**-1``.
        If a ``ShapeError`` is raised, an unevaluated ``Inverse(self)`` is
        returned instead.
        """
        try:
            return MatMul(*[
                arg.inverse() if isinstance(arg, MatrixExpr) else arg**-1
                for arg in self.args[::-1]]).doit()
        except ShapeError:
            from .inverse import Inverse
            return Inverse(self)

    def doit(self, **kwargs):
        """Evaluate the product.

        With ``deep=True`` (the default), ``doit()`` is called on every arg
        first. The rebuilt ``MatMul`` is then simplified with
        ``canonicalize``.
        """
        deep = kwargs.get('deep', True)
        if deep:
            args = [arg.doit(**kwargs) for arg in self.args]
        else:
            args = self.args
        return canonicalize(MatMul(*args))
def validate(*matrices):
    """Check that consecutive args of a MatMul have compatible shapes.

    Every neighbouring pair ``A, B`` must satisfy ``A.cols == B.rows``.
    Zero or one matrix always passes. Raises ``ShapeError`` naming the first
    pair that is not aligned.
    """
    # Walk the sliding pairs directly instead of indexing with range(len()).
    for A, B in zip(matrices, matrices[1:]):
        if A.cols != B.rows:
            raise ShapeError(f'Matrices {A} and {B} are not aligned')
# Rules
def newmul(*args):
    """Build a MatMul from ``args``, dropping a leading coefficient of 1."""
    drop_unit = args[0] == 1
    return MatMul(*(args[1:] if drop_unit else args))
def any_zeros(mul):
    """Collapse a MatMul that has a zero factor into a ZeroMatrix.

    A factor counts as zero when its ``is_zero`` is truthy, or when it is a
    matrix whose ``is_ZeroMatrix`` is truthy. The ZeroMatrix takes its rows
    from the first matrix factor and its columns from the last one. If there
    is no zero factor, ``mul`` is returned unchanged.
    """
    def _is_zero_factor(arg):
        return arg.is_zero or (arg.is_Matrix and arg.is_ZeroMatrix)

    if not any(map(_is_zero_factor, mul.args)):
        return mul
    matrix_args = [arg for arg in mul.args if arg.is_Matrix]
    return ZeroMatrix(matrix_args[0].rows, matrix_args[-1].cols)
def merge_explicit(matmul):
    """Multiply out neighbouring explicit factors of a MatMul.

    Each run of adjacent args that are explicit matrices (``MatrixBase``) or
    ``Number`` objects is replaced by its product. Any other factor breaks
    the run, so explicit matrices on opposite sides of a symbolic factor are
    not combined. A MatMul with no explicit matrix is returned unchanged.

    Examples
    ========

    >>> A = MatrixSymbol('A', 2, 2)
    >>> B = Matrix([[1, 1], [1, 1]])
    >>> C = Matrix([[1, 2], [3, 4]])
    >>> X = MatMul(A, B, C)
    >>> pprint(X, use_unicode=False)
      [1  1] [1  2]
    A*[    ]*[    ]
      [1  1] [3  4]
    >>> pprint(merge_explicit(X), use_unicode=False)
      [4  6]
    A*[    ]
      [4  6]
    >>> X = MatMul(B, A, C)
    >>> pprint(X, use_unicode=False)
    [1  1]   [1  2]
    [    ]*A*[    ]
    [1  1]   [3  4]
    >>> pprint(merge_explicit(X), use_unicode=False)
    [1  1]   [1  2]
    [    ]*A*[    ]
    [1  1]   [3  4]

    """
    factors = matmul.args
    if all(not isinstance(f, MatrixBase) for f in factors):
        return matmul

    explicit = (MatrixBase, Number)
    # The last element of ``merged`` is the run currently being accumulated.
    merged = [factors[0]]
    for factor in factors[1:]:
        if isinstance(factor, explicit) and isinstance(merged[-1], explicit):
            merged[-1] = merged[-1] * factor
        else:
            merged.append(factor)
    return MatMul(*merged)
def xxinv(mul):
    """Replace X * X.inverse() with Identity.

    Each adjacent pair of square matrix factors ``left, right`` is tested
    with ``left == right.inverse()``. The first pair that matches is replaced
    by ``Identity(left.rows)``. A ``ValueError`` raised while handling a pair
    (for example, ``right`` not being invertible) skips that pair. If nothing
    matches, ``mul`` is returned unchanged.
    """
    factor, matrices = mul.as_coeff_matrices()
    for pos, (left, right) in enumerate(zip(matrices, matrices[1:])):
        try:
            if not (left.is_square and right.is_square):
                continue
            if left == right.inverse():
                before, after = matrices[:pos], matrices[pos + 2:]
                return newmul(factor, *before, Identity(left.rows), *after)
        except ValueError:  # ``right`` might not be invertible
            pass
    return mul
def remove_ids(mul):
    """Drop Identity factors from a MatMul.

    This is a variant of diofant.core.strategies.rm_id. That strategy cannot
    be applied to the MatMul directly, because a MatMul may hold scalar Exprs
    next to its MatrixExprs. The scalar coefficient is therefore split off
    first and put back afterwards.

    See Also
    ========

    diofant.core.strategies.rm_id

    """
    factor, mmul = mul.as_coeff_mmul()
    strip_identities = rm_id(lambda arg: arg.is_Identity is True)
    stripped = strip_identities(mmul)
    if stripped == mmul:
        # Nothing was removed: hand back the original object untouched.
        return mul
    return newmul(factor, *stripped.args)
def factor_in_front(mul):
    """Rebuild ``mul`` with its scalar coefficient before the matrix factors.

    Returns ``mul`` itself when the coefficient equals 1.
    """
    factor, matrices = mul.as_coeff_matrices()
    if factor == 1:
        return mul
    return newmul(factor, *matrices)
# Simplification rules used by ``canonicalize``, listed in priority order.
# The rules are combined with ``do_one`` / ``exhaust`` / ``typed`` from
# core.strategies, presumably as "apply the first rule that changes a MatMul,
# and repeat until nothing changes" -- confirm in core.strategies.
rules = (
    any_zeros,
    remove_ids,
    xxinv,
    unpack,
    rm_id(lambda arg: arg == 1),  # strip args equal to 1
    merge_explicit,
    factor_in_front,
    flatten,
)
canonicalize = exhaust(typed({MatMul: do_one(rules)}))
def only_squares(*matrices):
    """Group a chain of matrix factors into square sub-products.

    The factors are scanned left to right. A group closes as soon as its
    latest factor's column count equals the row count of the group's first
    factor. Each closed group is multiplied out with ``MatMul(...).doit()``.

    Raises RuntimeError if the chain as a whole is not square, i.e. the
    first factor's rows differ from the last factor's cols.
    """
    first, last = matrices[0], matrices[-1]
    if first.rows != last.cols:
        raise RuntimeError('Invalid matrices being multiplied')
    squares = []
    group_start = 0
    # ``stop`` is one past the current index, i.e. the exclusive slice end.
    for stop, mat in enumerate(matrices, start=1):
        if mat.cols == matrices[group_start].rows:
            squares.append(MatMul(*matrices[group_start:stop]).doit())
            group_start = stop
    return squares