# cristicbz/scid forked from DlangScience/scid

Fix issue 51: Invalid multiplication between ColumnVector and Matrix.

1 parent 6dcfdd6 commit 0130741f17047172ef934209e8ea10a134a1d229 dsimcha committed Dec 3, 2011
Showing with 147 additions and 41 deletions.
1. +26 −7 scid/demo.d
2. +42 −1 scid/ops/common.d
3. +13 −2 scid/ops/eval.d
4. +66 −31 scid/ops/expression.d
33 scid/demo.d
 @@ -2,7 +2,6 @@ module scid.demo; version( demo ) { import scid.matvec; - import scid.storage.diagonalmat; import scid.common.traits, scid.common.meta; import scid.internal.regionallocator; @@ -11,11 +10,8 @@ version( demo ) { import std.string, std.math; void main() { - auto x = DiagonalMatrix!double([1.,2,3,4]), y=DiagonalMatrix!double([2.,2.,2.,2.]); - - // auto w = SymmetricMatrix!double([[1.,2.,3.,4],[1.,2.,3.,4],[1.,2.,3.,4],[1.,2.,3.,4]]); - eval( x[0..2][] * y[][0..2] ); - //writeln(z.pretty); + testIssue61(); + testIssue51(); readln(); } @@ -559,6 +555,29 @@ version( demo ) { enforce( invdot == 0.05 ); } + /** Issue 51 - Invalid multiplication between ColumnVector and Matrix */ + void testIssue51()() { + // Testing row vectors, too. + auto col = vector( [1.0, 2, 3] ); + auto row = vector( [4.0, 5, 6] ); + auto mat = matrix( [[4.0, 5, 6]] ); + + auto rowRes = eval( col * row.t ); + auto matRes = eval( col * mat ); + + foreach( res; [rowRes, matRes] ) { + assert( res[0, 0] == 4 ); + assert( res[0, 1] == 5 ); + assert( res[0, 2] == 6 ); + assert( res[1, 0] == 8 ); + assert( res[1, 1] == 10 ); + assert( res[1, 2] == 12 ); + assert( res[2, 0] == 12 ); + assert( res[2, 1] == 15 ); + assert( res[2, 2] == 18 ); + } + } + /** Issue 50 - Matrix slice-slice-assign ends up transposed */ void testIssue50()() { auto a = Matrix!double([[1.0, 2], [4.0, 5]]); @@ -641,4 +660,4 @@ version( demo ) { enforce( mat2.cdata == mat3.cdata ); } -} +}
43 scid/ops/common.d
 @@ -11,6 +11,7 @@ public import scid.ops.eval; import std.complex; import std.math; +import std.traits; import scid.ops.expression; import scid.vector; import scid.matrix; @@ -461,4 +462,44 @@ void generalMatrixScaledAddition( Transpose tr, S, D, E )( T alpha, auto ref S s } } } -} +} + +/** +This function takes a vector and returns an ExternalMatrixView of +the vector. The purpose of this is to support things like vector-matrix +multiplication by treating the vector on the LHS as a matrix. It hacks +around the ref counting system because it's only supposed to be used +internally for evaluating e.g. vector-matrix multiplication. + +If V is already a matrix and not a vector, this is a no-op. +*/ +auto toMatrix( V )( V vec ) { + alias ExternalMatrixView!( typeof( vec[0] ) ) Ret; + immutable n = vec.length; + auto ptr = cast(Unqual!(typeof(*vec.cdata))*) vec.cdata; + + static if( closureOf!V == Closure.ColumnVector ) { + return Ret( 1, ptr[0..n] ); + } else { + return Ret( n, ptr[0..n] ); + } +} + +unittest { + auto col = Vector!double([1, 2, 3]); + auto row = eval(col.t); + + auto colMat = toMatrix(col); + auto rowMat = toMatrix(row); + + assert(colMat[0, 0] == 1); + assert(colMat[1, 0] == 2); + assert(colMat[2, 0] == 3); + + assert(rowMat[0, 0] == 1); + assert(rowMat[0, 1] == 2); + assert(rowMat[0, 2] == 3); + + assert(colMat.cdata is rowMat.cdata); + assert(rowMat.cdata is col.cdata); +}
15 scid/ops/eval.d
 @@ -67,13 +67,22 @@ auto eval(string op_, Lhs, Rhs)( auto ref Expression!(op_, Lhs, Rhs) expr ) } /** Evaluate a matrix or vector expression. */ -ExpressionResult!E eval( E )( auto ref E expr ) if( isExpression!E && E.closure != Closure.Scalar ) { +ExpressionResult!E eval( E )( auto ref E expr ) +if( isExpression!E && E.closure != Closure.Scalar && E.operation != Operation.ToMatrix ) { // ExpressionResult gives you the type required to save the result of an expression. typeof(return) result; evalCopy( expr, result ); return result; } +/** +Convert a vector expression to a matrix. +*/ +ExpressionResult!E eval( E )( auto ref E expr ) +if( isExpression!E && E.operation == Operation.ToMatrix ) { + return toMatrix( eval( expr.lhs_ ) ); +} + /** Evaluate a matrix or vector expression in memory allocated with the specified allocator. */ ExpressionResult!(E).Temporary eval( E, Allocator )( auto ref E expr, Allocator allocator ) if( isAllocator!Allocator && E.closure != Closure.Scalar ) { @@ -201,7 +210,9 @@ void evalCopy( Transpose tr = Transpose.no, Source, Dest )( auto ref Source sour evalMatrixVectorProduct!( Transpose.no, Transpose.no )( One!T, source.lhs, source.rhs, Zero!T, dest ); } else static if( op == Operation.MatInverse ) { evalInverse!tr( One!T, source.lhs, Zero!T, dest ); - } else { + } else static if( op == Operation.ToMatrix ) { + evalCopy!tr( eval( source ), dest ); + } else { // can't solve the expression - use a temporary by calling eval RegionAllocator alloc = newRegionAllocator(); evalCopy!tr( eval(source, alloc), dest );
97 scid/ops/expression.d
 @@ -8,6 +8,7 @@ module scid.ops.expression; import scid.common.traits, scid.common.meta; +import scid.matrix : ExternalMatrixView; import std.traits, std.range; import std.conv; import std.typecons; @@ -23,6 +24,8 @@ enum Operation { MatScalProd, // Matrix = Matrix * Scalar MatInverse, // Matrix = Matrix ^^ (-1) MatTrans, // Matrix = Matrix.t + ToMatrix, // Convert vector to matrix for ColumnVector * RowVector + // or ColumnVector * Matrix. RowRowSum, // RowVector = RowVector + RowVector RowScalProd, // RowVector = RowVector * Scalar @@ -50,6 +53,33 @@ enum Closure { Scalar } +// This gives a small compile time boost when checking the closure of a given operation. +private enum operationClosures = [ + Closure.Matrix, // MatMatSum, + Closure.Matrix, // MatMatProd, + Closure.Matrix, // MatScalProd, + Closure.Matrix, // MatInverse, + Closure.Matrix, // MatTrans, + Closure.Matrix, // ToMatrix, + + Closure.RowVector, // RowRowSum, + Closure.RowVector, // RowScalProd, + Closure.RowVector, // RowMatProd, + Closure.RowVector, // ColTrans, + + Closure.ColumnVector, // ColColSum, + Closure.ColumnVector, // ColScalProd, + Closure.ColumnVector, // MatColProd, + Closure.ColumnVector, // RowTrans, + + Closure.Scalar, // DotProd, + Closure.Scalar, // ScalScalSum, + Closure.Scalar, // ScalScalSub, + Closure.Scalar, // ScalScalProd, + Closure.Scalar, // ScalScalDiv, + Closure.Scalar // ScalScalPow, +]; + /** Convinience function to create an expression object. */ Expression!( op, Unqual!Lhs, Unqual!Rhs ) expression( string op, Lhs, Rhs )( auto ref Lhs lhs, auto ref Rhs rhs ) { static assert( isConvertible!(BaseElementType!Lhs, BaseElementType!Rhs) && @@ -129,6 +159,9 @@ struct Expression( string op_, Lhs_, Rhs_ = void ) { size_t rows() const @property { static if( operation == Operation.MatTrans ) return lhs_.columns; + else static if( operation == Operation.ToMatrix ) + return ( closureOf!Lhs == Closure.ColumnVector ) ? + ( lhs_.length ) : 1; else return lhs_.rows; } @@ -139,6 +172,9 @@ struct Expression( string op_, Lhs_, Rhs_ = void ) { return lhs_.rows; else static if( operation == Operation.MatMatProd ) return rhs_.columns; + else static if( operation == Operation.ToMatrix ) + return ( closureOf!Lhs == Closure.RowVector ) ? + ( lhs_.length ) : 1; else return lhs_.columns; } @@ -209,7 +245,25 @@ template Operand( Closure closure_ ) { } auto opBinary( string op, NewRhs )( auto ref NewRhs newRhs ) if( op == "*" ) { - return expression!op( this, newRhs ); + // If we're multiplying a column vector by a row vector or a matrix, + // rewrite the expression as matrix-matrix multiplication. + static if( this.closure == Closure.ColumnVector ) { + static assert( closureOf!NewRhs != Closure.ColumnVector, + "Invalid multiplication between ColumnVector and ColumnVector." ); + + auto thisConverted = expression!"toMatrix"( this ); + + static if( closureOf!NewRhs == Closure.RowVector ) { + auto rhsConverted = expression!"toMatrix"( newRhs ); + } else { + alias newRhs rhsConverted; + } + } else { + alias this thisConverted; + alias newRhs rhsConverted; + } + + return expression!op( thisConverted, rhsConverted ); } auto opBinaryRight( string op, E )( E newLhs ) if( isConvertible!(E,ElementType) && op == "*" ) { @@ -247,7 +301,10 @@ template ExpressionResult( E ) { } else static if( E.closure == Closure.Scalar ) { // if the node results in a scalar then the result is of the same type as the element type alias E.ElementType ExpressionResult; - } else static if( isTransposition!(E.operation) ) { + } else static if( E.operation == Operation.ToMatrix ) { + // ToMatrix converts a vector to an ExternalMatrixView. + alias ExternalMatrixView!( E.Lhs.ElementType ) ExpressionResult; + } else static if( isTransposition!(E.operation) ) { // if the node is a transposition then the result is the Transposed of the child node alias ExpressionResult!(E.Lhs).Transposed ExpressionResult; } else static if( E.isBinary ) { @@ -346,34 +403,6 @@ private template PromotionImpl( A, B ) { } } - -// This gives a small compile time boost when checking the closure of a given operation. -private enum operationClosures = [ - Closure.Matrix, // MatMatSum, - Closure.Matrix, // MatMatProd, - Closure.Matrix, // MatScalProd, - Closure.Matrix, // MatInverse, - Closure.Matrix, // MatTrans, - - Closure.RowVector, // RowRowSum, - Closure.RowVector, // RowScalProd, - Closure.RowVector, // RowMatProd, - Closure.RowVector, // ColTrans, - - Closure.ColumnVector, // ColColSum, - Closure.ColumnVector, // ColScalProd, - Closure.ColumnVector, // MatColProd, - Closure.ColumnVector, // RowTrans, - - Closure.Scalar, // DotProd, - Closure.Scalar, // ScalScalSum, - Closure.Scalar, // ScalScalSub, - Closure.Scalar, // ScalScalProd, - Closure.Scalar, // ScalScalDiv, - Closure.Scalar // ScalScalPow, -]; - - // Find the operation type given the operator and the closure types of the operands private template operationOf( string op, Closure l, Closure r ) if( op == "*" ) { @@ -431,4 +460,10 @@ private template operationOf( string op, Closure l, Closure r ) if( op == "-" ) private template operationOf( string op, Closure l ) if( op == "inv" ) { static if( l == Closure.Matrix ) { enum operationOf = Operation.MatInverse; } else static assert( false, "Invalid inversion of " ~ to!string(l) ); -} +} + +private template operationOf( string op, Closure l ) if( op == "toMatrix" ) { + static if( l == Closure.RowVector ) { enum operationOf = Operation.ToMatrix; } + else static if( l == Closure.ColumnVector ) { enum operationOf = Operation.ToMatrix; } + else static assert( false, "Can only convert vectors to matrices." ); +}