Skip to content

Commit

Permalink
Merge pull request #16076 from kronbichler/full_matrix_inverse
Browse files Browse the repository at this point in the history
  • Loading branch information
masterleinad committed Oct 3, 2023
2 parents 93b1f88 + 6a119d5 commit 9d46a3b
Showing 1 changed file with 132 additions and 173 deletions.
305 changes: 132 additions & 173 deletions include/deal.II/lac/full_matrix.templates.h
Original file line number Diff line number Diff line change
Expand Up @@ -1342,182 +1342,141 @@ FullMatrix<number>::invert(const FullMatrix<number2> &M)
Assert(this->n_rows() == M.n_rows(),
ExcDimensionMismatch(this->n_rows(), M.n_rows()));

if (PointerComparison::equal(&M, this))
switch (this->n_cols())
{
// avoid overwriting source
// by destination matrix:
const FullMatrix<number> M2 = *this;
invert(M2);
}
else
switch (this->n_cols())
{
case 1:
(*this)(0, 0) = number2(1.0) / M(0, 0);
case 1:
(*this)(0, 0) = number2(1.0) / M(0, 0);
return;
case 2:
{
const number2 M00 = M(0, 0);
const number2 M01 = M(0, 1);
const number2 M10 = M(1, 0);
const number2 M11 = M(1, 1);
const number2 t4 = number2(1.0) / (M00 * M11 - M01 * M10);
(*this)(0, 0) = M11 * t4;
(*this)(0, 1) = -M01 * t4;
(*this)(1, 0) = -M10 * t4;
(*this)(1, 1) = M00 * t4;
return;
};

case 3:
{
const number2 M00 = M(0, 0);
const number2 M01 = M(0, 1);
const number2 M02 = M(0, 2);
const number2 M10 = M(1, 0);
const number2 M11 = M(1, 1);
const number2 M12 = M(1, 2);
const number2 M20 = M(2, 0);
const number2 M21 = M(2, 1);
const number2 M22 = M(2, 2);
const number2 t00 = M11 * M22 - M12 * M21;
const number2 t10 = M12 * M20 - M10 * M22;
const number2 t20 = M10 * M21 - M11 * M20;
const number2 inv_det =
number2(1.0) / (M00 * t00 + M01 * t10 + M02 * t20);
(*this)(0, 0) = t00 * inv_det;
(*this)(0, 1) = (M02 * M21 - M01 * M22) * inv_det;
(*this)(0, 2) = (M01 * M12 - M02 * M11) * inv_det;
(*this)(1, 0) = t10 * inv_det;
(*this)(1, 1) = (M00 * M22 - M02 * M20) * inv_det;
(*this)(1, 2) = (M02 * M10 - M00 * M12) * inv_det;
(*this)(2, 0) = t20 * inv_det;
(*this)(2, 1) = (M01 * M20 - M00 * M21) * inv_det;
(*this)(2, 2) = (M00 * M11 - M01 * M10) * inv_det;
return;
case 2:
// this is Maple output,
// thus a bit unstructured
{
const number2 t4 =
number2(1.0) / (M(0, 0) * M(1, 1) - M(0, 1) * M(1, 0));
(*this)(0, 0) = M(1, 1) * t4;
(*this)(0, 1) = -M(0, 1) * t4;
(*this)(1, 0) = -M(1, 0) * t4;
(*this)(1, 1) = M(0, 0) * t4;
return;
};

case 3:
{
const number2 t4 = M(0, 0) * M(1, 1), t6 = M(0, 0) * M(1, 2),
t8 = M(0, 1) * M(1, 0), t00 = M(0, 2) * M(1, 0),
t01 = M(0, 1) * M(2, 0), t04 = M(0, 2) * M(2, 0),
t07 = number2(1.0) /
(t4 * M(2, 2) - t6 * M(2, 1) - t8 * M(2, 2) +
t00 * M(2, 1) + t01 * M(1, 2) - t04 * M(1, 1));
(*this)(0, 0) = (M(1, 1) * M(2, 2) - M(1, 2) * M(2, 1)) * t07;
(*this)(0, 1) = -(M(0, 1) * M(2, 2) - M(0, 2) * M(2, 1)) * t07;
(*this)(0, 2) = -(-M(0, 1) * M(1, 2) + M(0, 2) * M(1, 1)) * t07;
(*this)(1, 0) = -(M(1, 0) * M(2, 2) - M(1, 2) * M(2, 0)) * t07;
(*this)(1, 1) = (M(0, 0) * M(2, 2) - t04) * t07;
(*this)(1, 2) = -(t6 - t00) * t07;
(*this)(2, 0) = -(-M(1, 0) * M(2, 1) + M(1, 1) * M(2, 0)) * t07;
(*this)(2, 1) = -(M(0, 0) * M(2, 1) - t01) * t07;
(*this)(2, 2) = (t4 - t8) * t07;
return;
};

case 4:
{
// with (linalg);
// a:=matrix(4,4);
// evalm(a);
// ai:=inverse(a);
// readlib(C);
// C(ai,optimized,filename=x4);

const number2 t14 = M(0, 0) * M(1, 1);
const number2 t15 = M(2, 2) * M(3, 3);
const number2 t17 = M(2, 3) * M(3, 2);
const number2 t19 = M(0, 0) * M(2, 1);
const number2 t20 = M(1, 2) * M(3, 3);
const number2 t22 = M(1, 3) * M(3, 2);
const number2 t24 = M(0, 0) * M(3, 1);
const number2 t25 = M(1, 2) * M(2, 3);
const number2 t27 = M(1, 3) * M(2, 2);
const number2 t29 = M(1, 0) * M(0, 1);
const number2 t32 = M(1, 0) * M(2, 1);
const number2 t33 = M(0, 2) * M(3, 3);
const number2 t35 = M(0, 3) * M(3, 2);
const number2 t37 = M(1, 0) * M(3, 1);
const number2 t38 = M(0, 2) * M(2, 3);
const number2 t40 = M(0, 3) * M(2, 2);
const number2 t42 = t14 * t15 - t14 * t17 - t19 * t20 + t19 * t22 +
t24 * t25 - t24 * t27 - t29 * t15 + t29 * t17 +
t32 * t33 - t32 * t35 - t37 * t38 + t37 * t40;
const number2 t43 = M(2, 0) * M(0, 1);
const number2 t46 = M(2, 0) * M(1, 1);
const number2 t49 = M(2, 0) * M(3, 1);
const number2 t50 = M(0, 2) * M(1, 3);
const number2 t52 = M(0, 3) * M(1, 2);
const number2 t54 = M(3, 0) * M(0, 1);
const number2 t57 = M(3, 0) * M(1, 1);
const number2 t60 = M(3, 0) * M(2, 1);
const number2 t63 = t43 * t20 - t43 * t22 - t46 * t33 + t46 * t35 +
t49 * t50 - t49 * t52 - t54 * t25 + t54 * t27 +
t57 * t38 - t57 * t40 - t60 * t50 + t60 * t52;
const number2 t65 = number2(1.) / (t42 + t63);
const number2 t71 = M(0, 2) * M(2, 1);
const number2 t73 = M(0, 3) * M(2, 1);
const number2 t75 = M(0, 2) * M(3, 1);
const number2 t77 = M(0, 3) * M(3, 1);
const number2 t81 = M(0, 1) * M(1, 2);
const number2 t83 = M(0, 1) * M(1, 3);
const number2 t85 = M(0, 2) * M(1, 1);
const number2 t87 = M(0, 3) * M(1, 1);
const number2 t101 = M(1, 0) * M(2, 2);
const number2 t103 = M(1, 0) * M(2, 3);
const number2 t105 = M(2, 0) * M(1, 2);
const number2 t107 = M(2, 0) * M(1, 3);
const number2 t109 = M(3, 0) * M(1, 2);
const number2 t111 = M(3, 0) * M(1, 3);
const number2 t115 = M(0, 0) * M(2, 2);
const number2 t117 = M(0, 0) * M(2, 3);
const number2 t119 = M(2, 0) * M(0, 2);
const number2 t121 = M(2, 0) * M(0, 3);
const number2 t123 = M(3, 0) * M(0, 2);
const number2 t125 = M(3, 0) * M(0, 3);
const number2 t129 = M(0, 0) * M(1, 2);
const number2 t131 = M(0, 0) * M(1, 3);
const number2 t133 = M(1, 0) * M(0, 2);
const number2 t135 = M(1, 0) * M(0, 3);
(*this)(0, 0) =
(M(1, 1) * M(2, 2) * M(3, 3) - M(1, 1) * M(2, 3) * M(3, 2) -
M(2, 1) * M(1, 2) * M(3, 3) + M(2, 1) * M(1, 3) * M(3, 2) +
M(3, 1) * M(1, 2) * M(2, 3) - M(3, 1) * M(1, 3) * M(2, 2)) *
t65;
(*this)(0, 1) =
-(M(0, 1) * M(2, 2) * M(3, 3) - M(0, 1) * M(2, 3) * M(3, 2) -
t71 * M(3, 3) + t73 * M(3, 2) + t75 * M(2, 3) - t77 * M(2, 2)) *
t65;
(*this)(0, 2) = (t81 * M(3, 3) - t83 * M(3, 2) - t85 * M(3, 3) +
t87 * M(3, 2) + t75 * M(1, 3) - t77 * M(1, 2)) *
t65;
(*this)(0, 3) = -(t81 * M(2, 3) - t83 * M(2, 2) - t85 * M(2, 3) +
t87 * M(2, 2) + t71 * M(1, 3) - t73 * M(1, 2)) *
t65;
(*this)(1, 0) =
-(t101 * M(3, 3) - t103 * M(3, 2) - t105 * M(3, 3) +
t107 * M(3, 2) + t109 * M(2, 3) - t111 * M(2, 2)) *
t65;
(*this)(1, 1) = (t115 * M(3, 3) - t117 * M(3, 2) - t119 * M(3, 3) +
t121 * M(3, 2) + t123 * M(2, 3) - t125 * M(2, 2)) *
t65;
(*this)(1, 2) =
-(t129 * M(3, 3) - t131 * M(3, 2) - t133 * M(3, 3) +
t135 * M(3, 2) + t123 * M(1, 3) - t125 * M(1, 2)) *
t65;
(*this)(1, 3) = (t129 * M(2, 3) - t131 * M(2, 2) - t133 * M(2, 3) +
t135 * M(2, 2) + t119 * M(1, 3) - t121 * M(1, 2)) *
t65;
(*this)(2, 0) = (t32 * M(3, 3) - t103 * M(3, 1) - t46 * M(3, 3) +
t107 * M(3, 1) + t57 * M(2, 3) - t111 * M(2, 1)) *
t65;
(*this)(2, 1) = -(t19 * M(3, 3) - t117 * M(3, 1) - t43 * M(3, 3) +
t121 * M(3, 1) + t54 * M(2, 3) - t125 * M(2, 1)) *
t65;
(*this)(2, 2) = (t14 * M(3, 3) - t131 * M(3, 1) - t29 * M(3, 3) +
t135 * M(3, 1) + t54 * M(1, 3) - t125 * M(1, 1)) *
t65;
(*this)(2, 3) = -(t14 * M(2, 3) - t131 * M(2, 1) - t29 * M(2, 3) +
t135 * M(2, 1) + t43 * M(1, 3) - t121 * M(1, 1)) *
t65;
(*this)(3, 0) = -(t32 * M(3, 2) - t101 * M(3, 1) - t46 * M(3, 2) +
t105 * M(3, 1) + t57 * M(2, 2) - t109 * M(2, 1)) *
t65;
(*this)(3, 1) = (t19 * M(3, 2) - t115 * M(3, 1) - t43 * M(3, 2) +
t119 * M(3, 1) + t54 * M(2, 2) - t123 * M(2, 1)) *
t65;
(*this)(3, 2) = -(t14 * M(3, 2) - t129 * M(3, 1) - t29 * M(3, 2) +
t133 * M(3, 1) + t54 * M(1, 2) - t123 * M(1, 1)) *
t65;
(*this)(3, 3) = (t14 * M(2, 2) - t129 * M(2, 1) - t29 * M(2, 2) +
t133 * M(2, 1) + t43 * M(1, 2) - t119 * M(1, 1)) *
t65;

break;
}


default:
// if no inversion is
// hardcoded, fall back
// to use the
// Gauss-Jordan algorithm
};

case 4:
{
// Initially derived from the following maple script
//
// with (linalg);
// a:=matrix(4,4);
// evalm(a);
// ai:=inverse(a);
// readlib(C);
// C(ai,optimized,filename=x4);
//
// but then combined re-occurring terms via distributive law in an
// FMA-friendly format
const number2 M00 = M(0, 0);
const number2 M01 = M(0, 1);
const number2 M02 = M(0, 2);
const number2 M03 = M(0, 3);
const number2 M10 = M(1, 0);
const number2 M11 = M(1, 1);
const number2 M12 = M(1, 2);
const number2 M13 = M(1, 3);
const number2 M20 = M(2, 0);
const number2 M21 = M(2, 1);
const number2 M22 = M(2, 2);
const number2 M23 = M(2, 3);
const number2 M30 = M(3, 0);
const number2 M31 = M(3, 1);
const number2 M32 = M(3, 2);
const number2 M33 = M(3, 3);

const number2 t14 = M00 * M11 - M10 * M01;
const number2 t15 = M22 * M33 - M23 * M32;
const number2 t19 = M00 * M21 - M20 * M01;
const number2 t20 = M12 * M33 - M13 * M32;
const number2 t24 = M00 * M31 - M30 * M01;
const number2 t25 = M12 * M23 - M13 * M22;
const number2 t32 = M10 * M21 - M20 * M11;
const number2 t33 = M02 * M33 - M03 * M32;
const number2 t37 = M10 * M31 - M30 * M11;
const number2 t38 = M02 * M23 - M03 * M22;
const number2 t49 = M20 * M31 - M30 * M21;
const number2 t50 = M02 * M13 - M03 * M12;
const number2 det = t14 * t15 - t19 * t20 + t24 * t25 + t32 * t33 -
t37 * t38 + t49 * t50;
const number2 inv_det = number2(1.0) / det;
const number2 t81 = M01 * M12 - M02 * M11;
const number2 t83 = M01 * M13 - M03 * M11;
const number2 t93 = M01 * M22 - M02 * M21;
const number2 t95 = M11 * M23 - M13 * M21;
const number2 t97 = M01 * M23 - M03 * M21;
const number2 t99 = M11 * M22 - M12 * M21;
const number2 t101 = M10 * M22 - M20 * M12;
const number2 t103 = M10 * M23 - M20 * M13;
const number2 t115 = M00 * M22 - M20 * M02;
const number2 t117 = M00 * M23 - M20 * M03;
const number2 t129 = M00 * M12 - M10 * M02;
const number2 t131 = M00 * M13 - M10 * M03;

(*this)(0, 0) = (M11 * t15 - M21 * t20 + M31 * t25) * inv_det;
(*this)(0, 1) = -(M01 * t15 - M21 * t33 + M31 * t38) * inv_det;
(*this)(0, 2) = (M33 * t81 - M32 * t83 + M31 * t50) * inv_det;
(*this)(0, 3) = -(M23 * t81 - M22 * t83 + M21 * t50) * inv_det;
(*this)(1, 0) = -(M33 * t101 - M32 * t103 + M30 * t25) * inv_det;
(*this)(1, 1) = (M33 * t115 - M32 * t117 + M30 * t38) * inv_det;
(*this)(1, 2) = -(M33 * t129 - M32 * t131 + M30 * t50) * inv_det;
(*this)(1, 3) = (M23 * t129 - M22 * t131 + M20 * t50) * inv_det;
(*this)(2, 0) = (M33 * t32 - M31 * t103 + M30 * t95) * inv_det;
(*this)(2, 1) = -(M33 * t19 - M31 * t117 + M30 * t97) * inv_det;
(*this)(2, 2) = (M33 * t14 - M31 * t131 + M30 * t83) * inv_det;
(*this)(2, 3) = -(M23 * t14 - M21 * t131 + M20 * t83) * inv_det;
(*this)(3, 0) = -(M32 * t32 - M31 * t101 + M30 * t99) * inv_det;
(*this)(3, 1) = (M32 * t19 - M31 * t115 + M30 * t93) * inv_det;
(*this)(3, 2) = -(M32 * t14 - M31 * t129 + M30 * t81) * inv_det;
(*this)(3, 3) = (M22 * t14 - M21 * t129 + M20 * t81) * inv_det;

break;
}


default:
// if no inversion is
// hardcoded, fall back
// to use the
// Gauss-Jordan algorithm
if (!PointerComparison::equal(&M, this))
*this = M;
gauss_jordan();
};
gauss_jordan();
};
}


Expand Down

0 comments on commit 9d46a3b

Please sign in to comment.