Skip to content

Commit

Permalink
Updated Householder QR based solver and found a bug
Browse files Browse the repository at this point in the history
- The solver works for A with dim nRows >= nCols
- Fixed a bug related to the case w/o reuse feature
- Disabled export of the solver in case reuse it TRUE, until bugs are
  fixed
- Added calculation of determinant, the solver return determinant of
  matrix R now.
  • Loading branch information
mvukov committed Aug 12, 2013
1 parent b4b3436 commit d0219e4
Showing 1 changed file with 152 additions and 90 deletions.
242 changes: 152 additions & 90 deletions src/code_generation/linear_solvers/householder_qr_export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,138 +92,194 @@ returnValue ExportHouseholderQR::getFunctionDeclarations( ExportStatementBlock&
returnValue ExportHouseholderQR::getCode( ExportStatementBlock& code
)
{
if (REUSE == BT_TRUE)
return ACADOERRORTEXT(RET_NOT_IMPLEMENTED_YET, "There is a bug in the QR based lin. solver with reusing the factorization. Please use LU solver instead.");

uint run1, run2, run3;
ExportIndex i( "i" );
ExportIndex j( "j" );
ExportIndex k( "k" );
if( !UNROLLING ) {
solve.addIndex( i );
solve.addIndex( j );
solve.addIndex( k );
}

solve.addStatement( determinant == 1.0 );

if( UNROLLING || nRows <= 5 )
{
// Start the factorization:
for( run1 = 0; run1 < (dim-1); run1++ ) {
for( run2 = run1; run2 < dim; run2++ ) {
solve.addStatement( rk_temp.getCol( run2 ) == A.getSubMatrix( run2,run2+1,run1,run1+1 ) );
for (run1 = 0; run1 < nCols; run1++)
{
for (run2 = run1; run2 < nRows; run2++)
{
solve.addStatement(
rk_temp.getCol(run2)
== A.getSubMatrix(run2, run2 + 1, run1,
run1 + 1));
}
// calculate norm:
solve.addStatement( rk_temp.getCol( dim ) == rk_temp.getCol( run1 )*rk_temp.getCol( run1 ) );
for( run2 = run1+1; run2 < dim; run2++ ) {
solve.addStatement( rk_temp.getCol( dim ) += rk_temp.getCol( run2 )*rk_temp.getCol( run2 ) );
}
solve.addStatement( rk_temp.getFullName() << "[" << String( dim ) << "] = sqrt(" << rk_temp.getFullName() << "[" << String( dim ) << "]);\n" );

solve.addStatement(rk_temp.getCol(nRows) ==
rk_temp.getCols(run1, nRows) * rk_temp.getTranspose().getRows(run1, nRows));
solve.addStatement(
rk_temp.getFullName() << "[" << String(nRows) << "] = sqrt("
<< rk_temp.getFullName() << "[" << String(nRows)
<< "]);\n");

// update first element:
solve.addStatement( rk_temp.getFullName() << "[" << String( run1 ) << "] += (" << rk_temp.getFullName() << "[" << String( run1 ) << "] < 0 ? -1 : 1)*" << rk_temp.getFullName() << "[" << String( dim ) << "];\n" );

solve.addStatement(
rk_temp.getFullName() << "[" << String(run1) << "] += ("
<< rk_temp.getFullName() << "[" << String(run1)
<< "] < 0 ? -1 : 1)*" << rk_temp.getFullName()
<< "[" << String(nRows) << "];\n");

// calculate norm:
solve.addStatement( rk_temp.getCol( dim ) == rk_temp.getCol( run1 )*rk_temp.getCol( run1 ) );
for( run2 = run1+1; run2 < dim; run2++ ) {
solve.addStatement( rk_temp.getCol( dim ) += rk_temp.getCol( run2 )*rk_temp.getCol( run2 ) );
}
solve.addStatement( rk_temp.getFullName() << "[" << String( dim ) << "] = sqrt(" << rk_temp.getFullName() << "[" << String( dim ) << "]);\n" );

solve.addStatement(rk_temp.getCol(nRows) ==
rk_temp.getCols(run1, nRows) * rk_temp.getTranspose().getRows(run1, nRows));
solve.addStatement(
rk_temp.getFullName() << "[" << String(nRows) << "] = sqrt("
<< rk_temp.getFullName() << "[" << String(nRows)
<< "]);\n");

// normalization:
for( run2 = run1; run2 < dim; run2++ ) {
solve.addStatement( rk_temp.getFullName() << "[" << String( run2 ) << "] = " << rk_temp.getFullName() << "[" << String( run2 ) << "]/" << rk_temp.getFullName() << "[" << String( dim ) << "];\n" );
for (run2 = run1; run2 < nRows; run2++)
{
solve.addStatement(
rk_temp.getFullName() << "[" << String(run2) << "] = "
<< rk_temp.getFullName() << "[" << String(run2)
<< "]/" << rk_temp.getFullName() << "["
<< String(nRows) << "];\n");
}

// update current column:
solve.addStatement( rk_temp.getCol( dim ) == rk_temp.getCols( run1,dim )*A.getSubMatrix( run1,dim,run1,run1+1 ) );
solve.addStatement( rk_temp.getFullName() << "[" << String( dim ) << "] *= 2;\n" );
solve.addStatement( A.getSubMatrix( run1,run1+1,run1,run1+1 ) -= rk_temp.getCol( run1 )*rk_temp.getCol( dim ) );
if( REUSE ) {
solve.addStatement(
rk_temp.getCol(nRows)
== rk_temp.getCols(run1, nRows)
* A.getSubMatrix(run1, nRows, run1,
run1 + 1));
solve.addStatement(
rk_temp.getFullName() << "[" << String(nRows) << "] *= 2;\n");
solve.addStatement(
A.getSubMatrix(run1, run1 + 1, run1, run1 + 1) -=
rk_temp.getCol(run1) * rk_temp.getCol(nRows));

solve.addStatement( determinant == determinant * A.getElement(run1, run1) );

if (REUSE)
{
// replace zeros by results that can be reused:
for( run2 = run1; run2 < dim-1; run2++ ) {
solve.addStatement( A.getSubMatrix( run2+1,run2+2,run1,run1+1 ) == rk_temp.getCol( run2 ) );
for (run2 = run1; run2 < dim - 1; run2++)
{
solve.addStatement(
A.getSubMatrix(run2 + 1, run2 + 2, run1, run1 + 1)
== rk_temp.getCol(run2));
}
}

// update following columns:
for( run2 = run1+1; run2 < dim; run2++ ) {
solve.addStatement( rk_temp.getCol( dim ) == rk_temp.getCols( run1,dim )*A.getSubMatrix( run1,dim,run2,run2+1 ) );
solve.addStatement( rk_temp.getFullName() << "[" << String( dim ) << "] *= 2;\n" );
for( run3 = run1; run3 < dim; run3++ ) {
solve.addStatement( A.getSubMatrix( run3,run3+1,run2,run2+1 ) -= rk_temp.getCol( run3 )*rk_temp.getCol( dim ) );
for (run2 = run1 + 1; run2 < nCols; run2++)
{
solve.addStatement(
rk_temp.getCol(nRows)
== rk_temp.getCols(run1, nRows)
* A.getSubMatrix(run1, nRows, run2,
run2 + 1));
solve.addStatement(
rk_temp.getFullName() << "[" << String(nRows)
<< "] *= 2;\n");
for (run3 = run1; run3 < nRows; run3++)
{
solve.addStatement(
A.getSubMatrix(run3, run3 + 1, run2, run2 + 1) -=
rk_temp.getCol(run3) * rk_temp.getCol(nRows));
}
}
// update right-hand side:
solve.addStatement( rk_temp.getCol( dim ) == rk_temp.getCols( run1,dim )*b.getRows( run1,dim ) );
solve.addStatement( rk_temp.getFullName() << "[" << String( dim ) << "] *= 2;\n" );
for( run3 = run1; run3 < dim; run3++ ) {
solve.addStatement( b.getRow( run3 ) -= rk_temp.getCol( run3 )*rk_temp.getCol( dim ) );
solve.addStatement(
rk_temp.getCol(nRows)
== rk_temp.getCols(run1, nRows)
* b.getRows(run1, nRows));
solve.addStatement(
rk_temp.getFullName() << "[" << String(nRows) << "] *= 2;\n");
for (run3 = run1; run3 < nRows; run3++)
{
solve.addStatement( b.getRow(run3) -= rk_temp.getCol(run3) * rk_temp.getCol(nRows));
}

if( REUSE ) {

if (REUSE)
{
// store last element to be reused:
solve.addStatement( rk_temp.getCol( run1 ) == rk_temp.getCol( dim-1 ) );
solve.addStatement(
rk_temp.getCol(run1) == rk_temp.getCol(dim - 1));
}
}
}
else {
solve.addStatement( String( "for( i=0; i < " ) << String( dim-1 ) << "; i++ ) {\n" );
solve.addStatement( String( " for( j=i; j < " ) << String( dim ) << "; j++ ) {\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[j] = A[j*" << String( dim ) << "+i];\n" );
else
{
ExportIndex i( "i" );
ExportIndex j( "j" );
ExportIndex k( "k" );

solve.addIndex( i );
solve.addIndex( j );
solve.addIndex( k );

solve.addStatement( String( "for( i=0; i < " ) << String( nCols ) << "; i++ ) {\n" );
solve.addStatement( String( " for( j=i; j < " ) << String( nRows ) << "; j++ ) {\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[j] = A[j*" << String( nCols ) << "+i];\n" );
solve.addStatement( String( " }\n" ) );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( dim ) << "] = " << rk_temp.getFullName() << "[i]*" << rk_temp.getFullName() << "[i];\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( nRows ) << "] = " << rk_temp.getFullName() << "[i]*" << rk_temp.getFullName() << "[i];\n" );
solve.addStatement( String( " for( j=i+1; j < " ) << String( dim ) << "; j++ ) {\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( dim ) << "] += " << rk_temp.getFullName() << "[j]*" << rk_temp.getFullName() << "[j];\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( nRows ) << "] += " << rk_temp.getFullName() << "[j]*" << rk_temp.getFullName() << "[j];\n" );
solve.addStatement( String( " }\n" ) );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( dim ) << "] = sqrt(" << rk_temp.getFullName() << "[" << String( dim ) << "]);\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( nRows ) << "] = sqrt(" << rk_temp.getFullName() << "[" << String( nRows ) << "]);\n" );
// update first element:
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[i] += (" << rk_temp.getFullName() << "[i] < 0 ? -1 : 1)*" << rk_temp.getFullName() << "[" << String( dim ) << "];\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( dim ) << "] = " << rk_temp.getFullName() << "[i]*" << rk_temp.getFullName() << "[i];\n" );
solve.addStatement( String( " for( j=i+1; j < " ) << String( dim ) << "; j++ ) {\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( dim ) << "] += " << rk_temp.getFullName() << "[j]*" << rk_temp.getFullName() << "[j];\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[i] += (" << rk_temp.getFullName() << "[i] < 0 ? -1 : 1)*" << rk_temp.getFullName() << "[" << String( nRows ) << "];\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( nRows ) << "] = " << rk_temp.getFullName() << "[i]*" << rk_temp.getFullName() << "[i];\n" );
solve.addStatement( String( " for( j=i+1; j < " ) << String( nRows ) << "; j++ ) {\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( nRows ) << "] += " << rk_temp.getFullName() << "[j]*" << rk_temp.getFullName() << "[j];\n" );
solve.addStatement( String( " }\n" ) );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( dim ) << "] = sqrt(" << rk_temp.getFullName() << "[" << String( dim ) << "]);\n" );
solve.addStatement( String( " for( j=i; j < " ) << String( dim ) << "; j++ ) {\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[j] = " << rk_temp.getFullName() << "[j]/" << rk_temp.getFullName() << "[" << String( dim ) << "];\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( nRows ) << "] = sqrt(" << rk_temp.getFullName() << "[" << String( nRows ) << "]);\n" );
solve.addStatement( String( " for( j=i; j < " ) << String( nRows ) << "; j++ ) {\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[j] = " << rk_temp.getFullName() << "[j]/" << rk_temp.getFullName() << "[" << String( nRows ) << "];\n" );
solve.addStatement( String( " }\n" ) );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( dim ) << "] = " << rk_temp.getFullName() << "[i]*A[i*" << String( dim ) << "+i];\n" );
solve.addStatement( String( " for( j=i+1; j < " ) << String( dim ) << "; j++ ) {\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( dim ) << "] += " << rk_temp.getFullName() << "[j]*A[j*" << String( dim ) << "+i];\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( nRows ) << "] = " << rk_temp.getFullName() << "[i]*A[i*" << String( nCols ) << "+i];\n" );
solve.addStatement( String( " for( j=i+1; j < " ) << String( nRows ) << "; j++ ) {\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( nRows ) << "] += " << rk_temp.getFullName() << "[j]*A[j*" << String( nCols ) << "+i];\n" );
solve.addStatement( String( " }\n" ) );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( dim ) << "] *= 2;\n" );
solve.addStatement( String( " A[i*" ) << String( dim ) << "+i] -= " << rk_temp.getFullName() << "[i]*" << rk_temp.getFullName() << "[" << String( dim ) << "];\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( nRows ) << "] *= 2;\n" );
solve.addStatement( String( " A[i*" ) << String( nCols ) << "+i] -= " << rk_temp.getFullName() << "[i]*" << rk_temp.getFullName() << "[" << String( nRows ) << "];\n" );

solve.addStatement( String( " " ) << determinant.getFullName() << " *= " << " A[i * " << String( nCols ) << " + i];\n" );

if( REUSE ) {
solve.addStatement( String( " for( j=i; j < (" ) << String( dim ) << "-1); j++ ) {\n" );
solve.addStatement( String( " A[(j+1)*" ) << String( dim ) << "+i] = " << rk_temp.getFullName() << "[j];\n" );
solve.addStatement( String( " }\n" ) );
}
solve.addStatement( String( " for( j=i+1; j < " ) << String( dim ) << "; j++ ) {\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( dim ) << "] = " << rk_temp.getFullName() << "[i]*A[i*" << String( dim ) << "+j];\n" );
solve.addStatement( String( " for( k=i+1; k < " ) << String( dim ) << "; k++ ) {\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( dim ) << "] += " << rk_temp.getFullName() << "[k]*A[k*" << String( dim ) << "+j];\n" );
solve.addStatement( String( " for( j=i+1; j < " ) << String( nCols ) << "; j++ ) {\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( nRows ) << "] = " << rk_temp.getFullName() << "[i]*A[i*" << String( nCols ) << "+j];\n" );
solve.addStatement( String( " for( k=i+1; k < " ) << String( nRows ) << "; k++ ) {\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( nRows ) << "] += " << rk_temp.getFullName() << "[k]*A[k*" << String( nCols ) << "+j];\n" );
solve.addStatement( String( " }\n" ) );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( dim ) << "] *= 2;\n" );
solve.addStatement( String( " for( k=i; k < " ) << String( dim ) << "; k++ ) {\n" );
solve.addStatement( String( " A[k*" ) << String( dim ) << "+j] -= " << rk_temp.getFullName() << "[k]*" << rk_temp.getFullName() << "[" << String( dim ) << "];\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( nRows ) << "] *= 2;\n" );
solve.addStatement( String( " for( k=i; k < " ) << String( nRows ) << "; k++ ) {\n" );
solve.addStatement( String( " A[k*" ) << String( nCols ) << "+j] -= " << rk_temp.getFullName() << "[k]*" << rk_temp.getFullName() << "[" << String( nRows ) << "];\n" );
solve.addStatement( String( " }\n" ) );
solve.addStatement( String( " }\n" ) );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( dim ) << "] = " << rk_temp.getFullName() << "[i]*b[i];\n" );
solve.addStatement( String( " for( k=i+1; k < " ) << String( dim ) << "; k++ ) {\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( dim ) << "] += " << rk_temp.getFullName() << "[k]*b[k];\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( nRows ) << "] = " << rk_temp.getFullName() << "[i]*b[i];\n" );
solve.addStatement( String( " for( k=i+1; k < " ) << String( nRows ) << "; k++ ) {\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( nRows ) << "] += " << rk_temp.getFullName() << "[k]*b[k];\n" );
solve.addStatement( String( " }\n" ) );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( dim ) << "] *= 2;\n" );
solve.addStatement( String( " for( k=i; k < " ) << String( dim ) << "; k++ ) {\n" );
solve.addStatement( String( " b[k] -= " ) << rk_temp.getFullName() << "[k]*" << rk_temp.getFullName() << "[" << String( dim ) << "];\n" );
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[" << String( nRows ) << "] *= 2;\n" );
solve.addStatement( String( " for( k=i; k < " ) << String( nRows ) << "; k++ ) {\n" );
solve.addStatement( String( " b[k] -= " ) << rk_temp.getFullName() << "[k]*" << rk_temp.getFullName() << "[" << String( nRows ) << "];\n" );
solve.addStatement( String( " }\n" ) );
if( REUSE ) {
solve.addStatement( String( " " ) << rk_temp.getFullName() << "[i] = " << rk_temp.getFullName() << "[" << String( dim-1 ) << "];\n" );
}
solve.addStatement( String( "}\n" ) );
}
// updates last column:
solve.addStatement( String( "A[" ) << String( dim*dim-1 ) << "] *= -1;\n" );
solve.addStatement( String( "b[" ) << String( dim-1 ) << "] *= -1;\n" );

solve.addLinebreak();

solve.addFunctionCall( solveTriangular, A, b );

// // updates last column: Something is wrong here!
// solve.addStatement(String("A[") << String(nCols * nCols - 1) << "] *= -1.0;\n");
// solve.addStatement(String("b[") << String(nCols - 1) << "] *= -1.0;\n");

solve.addFunctionCall(solveTriangular, A, b);
code.addFunction( solve );

code.addLinebreak( 2 );
Expand All @@ -249,13 +305,19 @@ returnValue ExportHouseholderQR::getCode( ExportStatementBlock& code
}

// Solve the upper triangular system of equations:
for( run1 = dim; run1 > 0; run1--) {
for( run2 = dim-1; run2 > (run1-1); run2--) {
solveTriangular.addStatement( b.getRow( (run1-1) ) -= A.getSubMatrix( (run1-1),(run1-1)+1,run2,run2+1 ) * b.getRow( run2 ) );
for (run1 = nCols; run1 > (nCols - nBacksolves); run1--)
{
for (run2 = nCols - 1; run2 > (run1 - 1); run2--)
{
solveTriangular.addStatement(
b.getRow(run1 - 1) -= A.getSubMatrix((run1 - 1), run1, run2, run2 + 1) * b.getRow(run2));
}
solveTriangular.addStatement( String( "b[" ) << String( (run1-1) ) << "] = b[" << String( (run1-1) ) << "]/A[" << String( (run1-1)*dim+(run1-1) ) << "];\n" );
solveTriangular.addStatement(
String("b[") << String((run1 - 1)) << "] = b["
<< String((run1 - 1)) << "]/A["
<< String((run1 - 1) * nCols + (run1 - 1)) << "];\n");
}
code.addFunction( solveTriangular );
code.addFunction(solveTriangular);

return SUCCESSFUL_RETURN;
}
Expand Down

0 comments on commit d0219e4

Please sign in to comment.