Skip to content

Commit

Permalink
Added matrix multiply (doesn't work for row major matrices). Fixed Is…
Browse files Browse the repository at this point in the history
…sue 39.
  • Loading branch information
cristicbz committed Aug 13, 2011
1 parent 344e1d9 commit ea99ace
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 55 deletions.
Binary file modified scid.suo
Binary file not shown.
30 changes: 29 additions & 1 deletion scid/blas.d
Expand Up @@ -186,7 +186,7 @@ struct blas {
static if( isFortranType!T && !forceNaive )
blas_.gemm( transa, transb, toi(m), toi(n), toi(k), alpha, a, toi(lda), b, toi(ldb), beta, c, toi(ldc) );
else
static assert( false, "There is no naive implementation of gemm available." );
naive_.gemm!( transa, transb )( m, n, k, alpha, a, lda, b, ldb, beta, c, ldc );

debug( blasCalls )
writeln( matrixToString( 'N', m, n, c, ldc ) );
Expand Down Expand Up @@ -609,6 +609,34 @@ private struct naive_ {

// Level 3

static void gemm( char transa_, char transb_, T )( size_t m, size_t n, size_t k, T alpha, const(T)* a, size_t lda, const(T) *b, size_t ldb, T beta, T *c, size_t ldc ) {
enum transa = cast(char)toUpper(transa_);
enum transb = cast(char)toUpper(transb_);

T geta( size_t i, size_t j ) {
static if ( transa == 'N' ) return a[ i + j * lda ];
else static if ( transa == 'T' ) return a[ j + i * lda ];
else static if ( transa == 'C' ) return blas.xconj( a[ j + i * lda ] );
}

T getb( size_t i, size_t j ) {
static if ( transb == 'N' ) return b[ i + j * ldb ];
else static if ( transb == 'T' ) return b[ j + i * ldb ];
else static if ( transb == 'C' ) return blas.xconj( b[ j + i * ldb ] );
}

void setc( T rhs, size_t i, size_t j ) { c[ i + j * ldc ] = rhs; }
T getc( size_t i, size_t j ) { return c[ i + j * ldc ]; }

foreach( col ; 0 .. n ) foreach( row ; 0 .. m ) {
T x = getc( row, col ) * beta;
T tmp = Zero!T;
foreach( ki ; 0 .. k )
tmp += geta( row, ki ) * getb( ki, col );
setc( x * beta + tmp * alpha, row, col );
}
}

static void trsm( char side_, char uplo_, char trans_, char diag_,T )( size_t m, size_t n, T alpha, const(T)* a, size_t lda, T* b, size_t ldb ) {
enum side = cast(char)toUpper(side_);
enum uplo = cast(char)toUpper(uplo_);
Expand Down
13 changes: 12 additions & 1 deletion scid/demo.d
Expand Up @@ -345,6 +345,17 @@ version( demo ) {
}

void main() {
opTest();
auto a = Matrix!double( [[1,2,3],[3,4,5]] );
auto b = Matrix!double( [[5,6],[7,8],[9,10]] );
Matrix!(double, StorageOrder.RowMajor) c;
c[] = a * b; writeln( c.pretty );
c[] = b.t * a.t; writeln( c.pretty );
c[] = a.t * b.t; writeln( c.pretty );


readln();

//opTest();
//dMatInvTest();
}
}
5 changes: 2 additions & 3 deletions scid/internal/assertmessages.d
Expand Up @@ -81,8 +81,7 @@ string stridedToString( S )( const(S)* ptr, size_t len, size_t stride ) {
auto app = appender!string("[");
app.put( to!string(*ptr) );
auto e = ptr + len * stride;
ptr += stride;
for( ; ptr < e ; ptr += stride ) {
for( ptr += stride; ptr < e ; ptr += stride ) {
app.put( ", " );
app.put( to!string( *ptr ) );
}
Expand Down Expand Up @@ -110,7 +109,7 @@ in {
} else {
app.put( stridedToString(a, m, 1) );
auto e = a + n * lda;
for( ++ a; a < e ; a += lda ) {
for( a += lda ; a < e ; a += lda ) {
app.put( ", " );
app.put( stridedToString( a, m, 1 ) );
}
Expand Down
29 changes: 15 additions & 14 deletions scid/lapack.d
Expand Up @@ -4,7 +4,7 @@ import scid.common.traits, scid.common.meta;
import std.algorithm, std.math, std.conv;
import std.ascii, std.exception;

//debug = lapackCalls;
// debug = lapackCalls;
//version = nodeps;

debug( lapackCalls ) {
Expand Down Expand Up @@ -277,7 +277,7 @@ private struct naive_ {

T ajj;
if( uplo == 'U' ) {
for( int j = 0; j < n; j++ ) {
for( size_t j = 0; j < n; j++ ) {
if( diag == 'N' ) {
// assert( get( j, j ) != Zero!T, "fbti: Singular matrix in inverse." );
if( get( j, j ) == Zero!T ) {
Expand All @@ -295,7 +295,7 @@ private struct naive_ {
blas.scal( j, ajj, a + j*lda, 1 );
}
} else {
for( int j = toi(n) - 1 ; j >= 0 ; -- j ) {
for( size_t j = n - 1 ; j >= 0 ; -- j ) {
if( diag == 'N' ) {
// assert( get( j, j ) != Zero!T, "fbti: Singular matrix in inverse." );
set( One!T / get(j,j), j, j );
Expand Down Expand Up @@ -325,10 +325,10 @@ private struct naive_ {
mixin("a[ j * lda + i ] "~op~"= x;");
}

for( int k = 0; k < n; k++ ) {
for( size_t k = 0; k < n; k++ ) {
pivot[ k ] = k;
T maxSoFar = abs( get( k, k ) );
for( int j = k + 1; j < n; j++ ) {
for( size_t j = k + 1; j < n; j++ ) {
T cur = abs( get(j, k) );
if( maxSoFar <= cur ) {
maxSoFar = cur;
Expand All @@ -344,18 +344,19 @@ private struct naive_ {
}
}

if( get(k,k) == Zero!T )
info = k;

foreach( i ; k + 1 .. n )
set!"/"( get(k,k), i, k );
if( get(k,k) != Zero!T ) {

foreach( i ; k + 1 .. n )
set!"/"( get(k,k), i, k );

foreach( i ; k + 1 .. n ) {
foreach( j ; k + 1 .. n ) {
set!"-"( get(i,k) * get(k,j), i, j );
foreach( i ; k + 1 .. n ) {
foreach( j ; k + 1 .. n ) {
set!"-"( get(i,k) * get(k,j), i, j );
}
}
} else if( info == 0 ) {
info = k + 1;
}

++ pivot[ k ]; // convert to FORTRAN index
}
}
Expand Down
86 changes: 55 additions & 31 deletions scid/ops/common.d
Expand Up @@ -16,6 +16,7 @@ import scid.vector;
import scid.matrix;

import std.typecons;
import std.exception;

import scid.common.traits;

Expand Down Expand Up @@ -114,6 +115,7 @@ mixin template GeneralMatrixScalingAndAddition() {

void invert() {
import scid.lapack;
import std.exception;

int info;
size_t n = this.rows;
Expand All @@ -126,12 +128,13 @@ mixin template GeneralMatrixScalingAndAddition() {
lapack.getrf( n, n, this.data, this.leading, ipiv.ptr, info );
lapack.getri( n, this.data, this.leading, ipiv.ptr, work.ptr , work.length, info );

assert( info == 0, "Inversion of singular matrix." );
enforce( info == 0, "Inversion of singular matrix." );
}

void solveRight( Transpose transM = Transpose.no, Side side, Dest )( auto ref Dest dest ) if( isStridedVectorStorage!Dest || isGeneralMatrixStorage!Dest ) {
import scid.blas;
import scid.lapack;
import std.exception;

size_t n = this.rows; // check that the matrix is square
assert( n == this.columns, "Inversion of non-square matrix." );
Expand Down Expand Up @@ -188,54 +191,75 @@ mixin template GeneralMatrixScalingAndAddition() {

enum chSide = (side == Side.Left) ? 'L' : 'R';

lapack.getrf( n, n, a, n, ipiv, info ); // perform LU decomposition
lapack.getrf( n, n, a, n, ipiv, info ); // perform LU decomposition
enforce( info == 0, "Inversion of singular matrix." );
lapack.xgetrs!(chTrans, chSide)( n, nrhs, a, n, ipiv, b, ldb, info ); // perform inv-mult

assert( info == 0, "Singular matrix in inversion." );

static if( vectorRhs ) {
// copy the data back to dest if needed
if( dest.stride != 1 )
blas.copy( ldb, b, 1, dest.data, dest.stride );
}
}
/*

void matrixProduct( Transpose transA = Transpose.no, Transpose transB = Transpose.no, A, B )
( ElementType alpha, auto ref A a, auto ref B b, ElementType beta ) if( isGeneralMatrixStorage!A && isGeneralMatrixStorage!B ) {
import scid.blas;

enum orderA = transposeStorageOrder!( storageOrderOf!A, transA );
enum orderB = transposeStorageOrder!( storageOrderOf!B, transB );
enum orderC = storageOrder;

static if( !isComplexScalar!ElementType ) {
static if( (orderA != orderC) || (orderB != orderC) )
matrixProduct!( transNot!transB, transNot!transA )( alpha, b, a, beta );
else {
enum chTransA = (orderA != orderB) ^ transA ? 't' : 'n';
enum chTransB = (orderB != orderA) ^ transB ? 't' : 'n';
static if( !transA )
auto m = a.rows, ak = a.columns;
else
auto m = a.columns, ak = a.rows;
static if( !transB )
auto n = b.columns, bk = b.rows;
else
auto n = b.rows, bk = b.columns;
assert( ak == bk, format("Inner dimensions do not match in matrix product: %d vs. %d", ak, bk) );
if( beta )
assert( this.rows == m && this.columns == n, dimMismatch_(m,n,"addition") );
else
this.resize( m, n, null );
assert( a.cdata && b.cdata );
blas.gemm( chTransA, chTransB, m, n, ak, alpha, a.cdata, a.leading, b.cdata, b.leading, beta, this.data, this.leading );
enum complexElems = isComplexScalar!ElementType;

static if( orderC == StorageOrder.RowMajor ) {
pragma( msg, "Row-Major gemm is unsupoorted, fallback will be used" );
static assert( false );
}

static if( !complexElems ) {
enum chA = orderA == orderC ? 'N' : 'T';
enum chB = orderB == orderC ? 'N' : 'T';
enum doConj = false;
} else {
static if( orderA == orderC && transA && orderB == orderC && transB ) {
enum bool doConj = true;
enum chA = 'N';
enum chB = 'N';
} else static if( orderA == orderC && transA ) {
enum bool doConj = true;
enum chA = 'N';
enum chB = orderB == orderC ? 'N' : (transb ? 'T' : 'C');
} else static if( orderB == orderC && transB ) {
enum bool doConj = true;
enum chA = orderA == orderC ? 'N' : (transa ? 'T' : 'C');
enum chB = 'N';
} else {
enum bool doConj = false;
enum chA = orderA == orderC ? 'N' : (transa ? 'C' : 'T');
enum chB = orderB == orderC ? 'N' : (transb ? 'C' : 'T');
}
}

static if( chA != 'N' )
size_t m = a.major, k = a.minor;
else
size_t m = a.minor, k = a.major;

static if( chB != 'N' ) {
size_t n = b.minor;
assert( k == b.major );
} else {
assert( false );
size_t n = b.major;
assert( k == b.minor );
}
}*/

resize( m, n, null );

blas.gemm!( chA, chB )( m, n, k, alpha, a.cdata, a.leading, b.cdata, b.leading, beta, this.data, this.leading );
static if( doConj )
blas.xgecopyc( m, n, this.data, this.leading, this.data, this.leading );
}
}

/** Compute the dot product of a row and a column of possibly transposed matrices. */
Expand Down
6 changes: 3 additions & 3 deletions scid/scid.visualdproj
Expand Up @@ -99,7 +99,7 @@
<oneobj>0</oneobj>
<trace>0</trace>
<quiet>0</quiet>
<verbose>1</verbose>
<verbose>0</verbose>
<vtls>0</vtls>
<symdebug>0</symdebug>
<optimize>1</optimize>
Expand All @@ -121,7 +121,7 @@
<useSwitchError>0</useSwitchError>
<useUnitTests>0</useUnitTests>
<useInline>1</useInline>
<release>0</release>
<release>1</release>
<preservePaths>0</preservePaths>
<warnings>1</warnings>
<infowarnings>1</infowarnings>
Expand Down Expand Up @@ -207,8 +207,8 @@
<File path="common\storagetraits.d" />
</Folder>
<Folder name="internal">
<File path="internal\assertmessages.d" />
<File path="internal\regionallocator.d" />
<File path="internal\assertmessages.d" />
</Folder>
<Folder name="storage">
<File path="storage\array.d" />
Expand Down
4 changes: 2 additions & 2 deletions scid/vector.d
Expand Up @@ -160,8 +160,8 @@ struct BasicVector( Storage_ ) {
}

void opSliceAssign( Rhs )( auto ref Rhs rhs ) {
static if( is( Rhs E : E[] ) && isConvertible( E, ElementType ) )
evalCopy( BasicVector(rhs), this );
static if( is( Rhs E : E[] ) && isConvertible!( E, ElementType ) )
evalCopy( ExternalVectorView!( E, vectorType )( rhs ), this );
else static if( closureOf!Rhs == Closure.Scalar )
evalCopy( relatedConstant( rhs, this ), this );
else
Expand Down

0 comments on commit ea99ace

Please sign in to comment.