Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

BlockDiagMatrix a class, 2x2 inverses, Tests

BlockDiagMatrices now subclass BlockMatrix

Solution for 2x2 inverse of a blockmatrix baked in though not activated by default

Inverses of identities now identities

Added tests for BlockMatrices
  • Loading branch information...
commit 060f44f01b76229f9e3867dccca3e4c836ed6373 1 parent 4fa1567
@mrocklin authored
View
71 sympy/matrices/blockmatrix.py
@@ -10,6 +10,7 @@
class BlockMatrix(MatrixExpr):
is_BlockMatrix = True
+ is_BlockDiagMatrix = False
def __new__(cls, mat):
if not isinstance(mat, Matrix):
mat = Matrix(mat)
@@ -72,11 +73,20 @@ def eval_transpose(self):
def transpose(self):
return self.eval_transpose()
- def eval_inverse(self):
- # Inverse of size one block matrix is easy
- if len(self.mat.mat)==1:
+ def eval_inverse(self, expand=False):
+ # Inverse of one by one block matrix is easy
+ if self.blockshape==(1,1):
mat = Matrix(1, 1, (Inverse(self.mat[0]), ))
return BlockMatrix(mat)
+ # Inverse of a two by two block matrix is known
+ elif expand and self.blockshape==(2,2):
+ # Cite: The Matrix Cookbook Section 9.1.3
+ A11, A12, A21, A22 = self[0,0], self[0,1], self[1,0], self[1,1]
+ C1 = A11 - A12*Inverse(A22)*A21
+ C2 = A22 - A21*Inverse(A11)*A12
+ mat = Matrix([[Inverse(C1), Inverse(-A11)*A12*Inverse(C2)],
+ [-Inverse(C2)*A21*Inverse(A11), Inverse(C2)]])
+ return BlockMatrix(mat)
else:
raise NotImplementedError()
@@ -101,20 +111,51 @@ def is_Identity(self):
def is_structurally_symmetric(self):
return self.rowblocksizes == self.colblocksizes
-def BlockDiagMatrix(mats):
- data_matrix = eye(len(mats))
- for i, mat in enumerate(mats):
- data_matrix[i,i] = mat
+class BlockDiagMatrix(BlockMatrix):
+ is_BlockDiagMatrix = True
+ def __new__(cls, *mats):
+ data_matrix = eye(len(mats))
+ for i, mat in enumerate(mats):
+ data_matrix[i,i] = mat
+
+ for r in range(len(mats)):
+ for c in range(len(mats)):
+ if r == c:
+ continue
+ n = mats[r].n
+ m = mats[c].m
+ data_matrix[r, c] = ZeroMatrix(n, m)
+
+ shape = Tuple(*sympify(mat.shape))
+ data = Tuple(*data_matrix.mat)
+ obj = Basic.__new__(cls, data, shape, Tuple(*mats))
+ obj.mat = data_matrix
+ return obj
+
+ @property
+ def diag(self):
+ return self.args[2]
- for r in range(len(mats)):
- for c in range(len(mats)):
- if r == c:
- continue
- n = mats[r].n
- m = mats[c].m
- data_matrix[r, c] = ZeroMatrix(n, m)
+ def eval_inverse(self):
+ return BlockDiagMatrix(*[Inverse(mat) for mat in self.diag])
- return BlockMatrix(data_matrix)
+ def _blockmul(self, other):
+ if (other.is_Matrix and other.is_BlockDiagMatrix and
+ self.blockshape[1] == other.blockshape[0] and
+ self.colblocksizes == other.rowblocksizes):
+ return BlockDiagMatrix(*[a*b for a, b in zip(self.diag,other.diag)])
+ else:
+ return BlockMatrix._blockmul(self, other)
+
+ def _blockadd(self, other):
+
+ if (other.is_Matrix and other.is_BlockDiagMatrix and
+ self.blockshape == other.blockshape and
+ self.rowblocksizes == other.rowblocksizes and
+ self.colblocksizes == other.colblocksizes):
+ return BlockDiagMatrix(*[a+b for a, b in zip(self.diag,other.diag)])
+ else:
+ return BlockMatrix._blockadd(self, other)
def block_collapse(expr):
View
7 sympy/matrices/inverse.py
@@ -5,13 +5,13 @@
class Inverse(MatPow):
is_Inverse = True
- def __new__(cls, mat):
+ def __new__(cls, mat, **kwargs):
if not mat.is_Matrix:
return mat**(-1)
try:
- return mat.eval_inverse()
+ return mat.eval_inverse(**kwargs)
except (AttributeError, NotImplementedError):
pass
@@ -24,6 +24,9 @@ def __new__(cls, mat):
if mat.is_Inverse:
return mat.arg
+ if mat.is_Identity:
+ return mat
+
if not mat.is_square:
raise ShapeError("Inverse of non-square matrix %s"%mat)
View
3  sympy/matrices/matexpr.py
@@ -66,8 +66,7 @@ def __rpow__(self, other):
@_sympifyit('other', NotImplemented)
@call_highest_priority('__rdiv__')
def __div__(self, other):
- raise NotImplementedError()
- return MatMul(self, Pow(other, S.NegativeOne))
+ return MatMul(self, other**S.NegativeOne)
@_sympifyit('other', NotImplemented)
@call_highest_priority('__div__')
def __rdiv__(self, other):
View
40 sympy/matrices/tests/test_matrix_exprs.py
@@ -43,7 +43,45 @@ def test_matexpr():
assert (x*A).__class__ == MatMul
assert 2*A - A - A == S.Zero
+def test_BlockMatrix():
+ n,m,l,k = symbols('n m l k', integer=True)
+ A = MatrixSymbol('A', n, m)
+ B = MatrixSymbol('B', n, k)
+ C = MatrixSymbol('C', l, m)
+ D = MatrixSymbol('D', l, k)
+ X = BlockMatrix(Matrix([[A,B],[C,D]]))
+
+ assert X.shape == (l+n, k+m)
+ assert Transpose(X) == BlockMatrix(Matrix([[A.T, C.T], [B.T, D.T]]))
+ assert Transpose(X).shape == X.shape[::-1]
+ assert X.blockshape == (2,2)
+
+ E = MatrixSymbol('E', m, 1)
+ F = MatrixSymbol('F', k, 1)
+
+ Y = BlockMatrix(Matrix([[E], [F]]))
+
+ assert (X*Y).shape = (l+n, 1)
+ assert block_collapse(X*Y)[0,0] == A*E + B*F
+ assert block_collapse(X*Y)[1,0] == C*E + D*F
+ assert Transpose(block_collapse(Transpose(X*Y))) == block_collapse(X*Y)
+
+def test_BlockDiagMatrix():
+ n,m,l = symbols('n m l', integer=True)
+ A = MatrixSymbol('A', n, n)
+ B = MatrixSymbol('B', m, m)
+ C = MatrixSymbol('C', l, l)
+
+ X = BlockDiagMatrix(A,B,C)
+
+ assert X[1,1] == B
+ assert X.shape == (n+m+l, n+m+l)
+ assert all(X[i,j].is_ZeroMatrix if i!=j else X[i,j] in [A,B,C]
+ for i in range(3) for j in range(3))
+
+ assert block_collapse(X.I * X).is_Identity
+
+ assert block_collapse(X*X) == BlockDiagMatrix(A**2, B**2, C**2)
- pass
Please sign in to comment.
Something went wrong with that request. Please try again.