Skip to content
Browse files

Fix issue 51: Invalid multiplication between ColumnVector and Matrix.

  • Loading branch information...
1 parent 6dcfdd6 commit 0130741f17047172ef934209e8ea10a134a1d229 @dsimcha dsimcha committed Dec 3, 2011
Showing with 147 additions and 41 deletions.
  1. +26 −7 scid/demo.d
  2. +42 −1 scid/ops/common.d
  3. +13 −2 scid/ops/eval.d
  4. +66 −31 scid/ops/expression.d
View
33 scid/demo.d
@@ -2,7 +2,6 @@ module scid.demo;
version( demo ) {
import scid.matvec;
- import scid.storage.diagonalmat;
import scid.common.traits, scid.common.meta;
import scid.internal.regionallocator;
@@ -11,11 +10,8 @@ version( demo ) {
import std.string, std.math;
void main() {
- auto x = DiagonalMatrix!double([1.,2,3,4]), y=DiagonalMatrix!double([2.,2.,2.,2.]);
-
- // auto w = SymmetricMatrix!double([[1.,2.,3.,4],[1.,2.,3.,4],[1.,2.,3.,4],[1.,2.,3.,4]]);
- eval( x[0..2][] * y[][0..2] );
- //writeln(z.pretty);
+ testIssue61();
+ testIssue51();
readln();
}
@@ -559,6 +555,29 @@ version( demo ) {
enforce( invdot == 0.05 );
}
+ /** Issue 51 - Invalid multiplication between ColumnVector and Matrix */
+ void testIssue51()() {
+ // Testing row vectors, too.
+ auto col = vector( [1.0, 2, 3] );
+ auto row = vector( [4.0, 5, 6] );
+ auto mat = matrix( [[4.0, 5, 6]] );
+
+ auto rowRes = eval( col * row.t );
+ auto matRes = eval( col * mat );
+
+ foreach( res; [rowRes, matRes] ) {
+ assert( res[0, 0] == 4 );
+ assert( res[0, 1] == 5 );
+ assert( res[0, 2] == 6 );
+ assert( res[1, 0] == 8 );
+ assert( res[1, 1] == 10 );
+ assert( res[1, 2] == 12 );
+ assert( res[2, 0] == 12 );
+ assert( res[2, 1] == 15 );
+ assert( res[2, 2] == 18 );
+ }
+ }
+
/** Issue 50 - Matrix slice-slice-assign ends up transposed */
void testIssue50()() {
auto a = Matrix!double([[1.0, 2], [4.0, 5]]);
@@ -641,4 +660,4 @@ version( demo ) {
enforce( mat2.cdata == mat3.cdata );
}
-}
+}
View
43 scid/ops/common.d
@@ -11,6 +11,7 @@ public import scid.ops.eval;
import std.complex;
import std.math;
+import std.traits;
import scid.ops.expression;
import scid.vector;
import scid.matrix;
@@ -461,4 +462,44 @@ void generalMatrixScaledAddition( Transpose tr, S, D, E )( T alpha, auto ref S s
}
}
}
-}
+}
+
+/**
+This function takes a vector and returns an ExternalMatrixView of
+the vector. The purpose of this is to support things like vector-matrix
+multiplication by treating the vector on the LHS as a matrix. It hacks
+around the ref counting system because it's only supposed to be used
+internally for evaluating e.g. vector-matrix multiplication.
+
+If V is already a matrix and not a vector, this is a no-op.
+*/
+auto toMatrix( V )( V vec ) {
+ alias ExternalMatrixView!( typeof( vec[0] ) ) Ret;
+ immutable n = vec.length;
+ auto ptr = cast(Unqual!(typeof(*vec.cdata))*) vec.cdata;
+
+ static if( closureOf!V == Closure.ColumnVector ) {
+ return Ret( 1, ptr[0..n] );
+ } else {
+ return Ret( n, ptr[0..n] );
+ }
+}
+
+unittest {
+ auto col = Vector!double([1, 2, 3]);
+ auto row = eval(col.t);
+
+ auto colMat = toMatrix(col);
+ auto rowMat = toMatrix(row);
+
+ assert(colMat[0, 0] == 1);
+ assert(colMat[1, 0] == 2);
+ assert(colMat[2, 0] == 3);
+
+ assert(rowMat[0, 0] == 1);
+ assert(rowMat[0, 1] == 2);
+ assert(rowMat[0, 2] == 3);
+
+ assert(colMat.cdata is rowMat.cdata);
+ assert(rowMat.cdata is col.cdata);
+}
View
15 scid/ops/eval.d
@@ -67,13 +67,22 @@ auto eval(string op_, Lhs, Rhs)( auto ref Expression!(op_, Lhs, Rhs) expr )
}
/** Evaluate a matrix or vector expression. */
-ExpressionResult!E eval( E )( auto ref E expr ) if( isExpression!E && E.closure != Closure.Scalar ) {
+ExpressionResult!E eval( E )( auto ref E expr )
+if( isExpression!E && E.closure != Closure.Scalar && E.operation != Operation.ToMatrix ) {
// ExpressionResult gives you the type required to save the result of an expression.
typeof(return) result;
evalCopy( expr, result );
return result;
}
+/**
+Convert a vector expression to a matrix.
+*/
+ExpressionResult!E eval( E )( auto ref E expr )
+if( isExpression!E && E.operation == Operation.ToMatrix ) {
+ return toMatrix( eval( expr.lhs_ ) );
+}
+
/** Evaluate a matrix or vector expression in memory allocated with the specified allocator. */
ExpressionResult!(E).Temporary eval( E, Allocator )( auto ref E expr, Allocator allocator )
if( isAllocator!Allocator && E.closure != Closure.Scalar ) {
@@ -201,7 +210,9 @@ void evalCopy( Transpose tr = Transpose.no, Source, Dest )( auto ref Source sour
evalMatrixVectorProduct!( Transpose.no, Transpose.no )( One!T, source.lhs, source.rhs, Zero!T, dest );
} else static if( op == Operation.MatInverse ) {
evalInverse!tr( One!T, source.lhs, Zero!T, dest );
- } else {
+ } else static if( op == Operation.ToMatrix ) {
+ evalCopy!tr( eval( source ), dest );
+ } else {
// can't solve the expression - use a temporary by calling eval
RegionAllocator alloc = newRegionAllocator();
evalCopy!tr( eval(source, alloc), dest );
View
97 scid/ops/expression.d
@@ -8,6 +8,7 @@
module scid.ops.expression;
import scid.common.traits, scid.common.meta;
+import scid.matrix : ExternalMatrixView;
import std.traits, std.range;
import std.conv;
import std.typecons;
@@ -23,6 +24,8 @@ enum Operation {
MatScalProd, // Matrix = Matrix * Scalar
MatInverse, // Matrix = Matrix ^^ (-1)
MatTrans, // Matrix = Matrix.t
+ ToMatrix, // Convert vector to matrix for ColumnVector * RowVector
+ // or ColumnVector * Matrix.
RowRowSum, // RowVector = RowVector + RowVector
RowScalProd, // RowVector = RowVector * Scalar
@@ -50,6 +53,33 @@ enum Closure {
Scalar
}
+// This gives a small compile time boost when checking the closure of a given operation.
+private enum operationClosures = [
+ Closure.Matrix, // MatMatSum,
+ Closure.Matrix, // MatMatProd,
+ Closure.Matrix, // MatScalProd,
+ Closure.Matrix, // MatInverse,
+ Closure.Matrix, // MatTrans,
+ Closure.Matrix, // ToMatrix,
+
+ Closure.RowVector, // RowRowSum,
+ Closure.RowVector, // RowScalProd,
+ Closure.RowVector, // RowMatProd,
+ Closure.RowVector, // ColTrans,
+
+ Closure.ColumnVector, // ColColSum,
+ Closure.ColumnVector, // ColScalProd,
+ Closure.ColumnVector, // MatColProd,
+ Closure.ColumnVector, // RowTrans,
+
+ Closure.Scalar, // DotProd,
+ Closure.Scalar, // ScalScalSum,
+ Closure.Scalar, // ScalScalSub,
+ Closure.Scalar, // ScalScalProd,
+ Closure.Scalar, // ScalScalDiv,
+ Closure.Scalar // ScalScalPow,
+];
+
/** Convinience function to create an expression object. */
Expression!( op, Unqual!Lhs, Unqual!Rhs ) expression( string op, Lhs, Rhs )( auto ref Lhs lhs, auto ref Rhs rhs ) {
static assert( isConvertible!(BaseElementType!Lhs, BaseElementType!Rhs) &&
@@ -129,6 +159,9 @@ struct Expression( string op_, Lhs_, Rhs_ = void ) {
size_t rows() const @property {
static if( operation == Operation.MatTrans )
return lhs_.columns;
+ else static if( operation == Operation.ToMatrix )
+ return ( closureOf!Lhs == Closure.ColumnVector ) ?
+ ( lhs_.length ) : 1;
else
return lhs_.rows;
}
@@ -139,6 +172,9 @@ struct Expression( string op_, Lhs_, Rhs_ = void ) {
return lhs_.rows;
else static if( operation == Operation.MatMatProd )
return rhs_.columns;
+ else static if( operation == Operation.ToMatrix )
+ return ( closureOf!Lhs == Closure.RowVector ) ?
+ ( lhs_.length ) : 1;
else
return lhs_.columns;
}
@@ -209,7 +245,25 @@ template Operand( Closure closure_ ) {
}
auto opBinary( string op, NewRhs )( auto ref NewRhs newRhs ) if( op == "*" ) {
- return expression!op( this, newRhs );
+ // If we're multiplying a column vector by a row vector or a matrix,
+ // rewrite the expression as matrix-matrix multiplication.
+ static if( this.closure == Closure.ColumnVector ) {
+ static assert( closureOf!NewRhs != Closure.ColumnVector,
+ "Invalid multiplication between ColumnVector and ColumnVector." );
+
+ auto thisConverted = expression!"toMatrix"( this );
+
+ static if( closureOf!NewRhs == Closure.RowVector ) {
+ auto rhsConverted = expression!"toMatrix"( newRhs );
+ } else {
+ alias newRhs rhsConverted;
+ }
+ } else {
+ alias this thisConverted;
+ alias newRhs rhsConverted;
+ }
+
+ return expression!op( thisConverted, rhsConverted );
}
auto opBinaryRight( string op, E )( E newLhs ) if( isConvertible!(E,ElementType) && op == "*" ) {
@@ -247,7 +301,10 @@ template ExpressionResult( E ) {
} else static if( E.closure == Closure.Scalar ) {
// if the node results in a scalar then the result is of the same type as the element type
alias E.ElementType ExpressionResult;
- } else static if( isTransposition!(E.operation) ) {
+ } else static if( E.operation == Operation.ToMatrix ) {
+ // ToMatrix converts a vector to an ExternalMatrixView.
+ alias ExternalMatrixView!( E.Lhs.ElementType ) ExpressionResult;
+ } else static if( isTransposition!(E.operation) ) {
// if the node is a transposition then the result is the Transposed of the child node
alias ExpressionResult!(E.Lhs).Transposed ExpressionResult;
} else static if( E.isBinary ) {
@@ -346,34 +403,6 @@ private template PromotionImpl( A, B ) {
}
}
-
-// This gives a small compile time boost when checking the closure of a given operation.
-private enum operationClosures = [
- Closure.Matrix, // MatMatSum,
- Closure.Matrix, // MatMatProd,
- Closure.Matrix, // MatScalProd,
- Closure.Matrix, // MatInverse,
- Closure.Matrix, // MatTrans,
-
- Closure.RowVector, // RowRowSum,
- Closure.RowVector, // RowScalProd,
- Closure.RowVector, // RowMatProd,
- Closure.RowVector, // ColTrans,
-
- Closure.ColumnVector, // ColColSum,
- Closure.ColumnVector, // ColScalProd,
- Closure.ColumnVector, // MatColProd,
- Closure.ColumnVector, // RowTrans,
-
- Closure.Scalar, // DotProd,
- Closure.Scalar, // ScalScalSum,
- Closure.Scalar, // ScalScalSub,
- Closure.Scalar, // ScalScalProd,
- Closure.Scalar, // ScalScalDiv,
- Closure.Scalar // ScalScalPow,
-];
-
-
// Find the operation type given the operator and the closure types of the operands
private template operationOf( string op, Closure l, Closure r ) if( op == "*" ) {
@@ -431,4 +460,10 @@ private template operationOf( string op, Closure l, Closure r ) if( op == "-" )
private template operationOf( string op, Closure l ) if( op == "inv" ) {
static if( l == Closure.Matrix ) { enum operationOf = Operation.MatInverse; }
else static assert( false, "Invalid inversion of " ~ to!string(l) );
-}
+}
+
+private template operationOf( string op, Closure l ) if( op == "toMatrix" ) {
+ static if( l == Closure.RowVector ) { enum operationOf = Operation.ToMatrix; }
+ else static if( l == Closure.ColumnVector ) { enum operationOf = Operation.ToMatrix; }
+ else static assert( false, "Can only convert vectors to matrices." );
+}

0 comments on commit 0130741

Please sign in to comment.
Something went wrong with that request. Please try again.