forked from WQCG/blitzdg
-
Notifications
You must be signed in to change notification settings - Fork 0
/
DirectSolver.cpp
72 lines (51 loc) · 1.6 KB
/
DirectSolver.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// Copyright (C) 2017-2018 Derek Steinmoeller.
// See COPYING and LICENSE files at project root for more details.
#include <DirectSolver.hpp>
#include <cstddef>
#include <sstream>
#include <stdexcept>
#include <vector>
#include <blitz/array.h>
using namespace blitz;
/**
 * Constructor. Copy-initializes the MatrixConverter member from the given
 * converter; solve() uses it (fullToPodArray / podArrayToFull) to move data
 * between blitz arrays and the flat buffers handed to LAPACK.
 *
 * @param _matrixConverter converter to copy into this solver.
 */
DirectSolver::DirectSolver(SparseMatrixConverter const & _matrixConverter)
    : MatrixConverter(_matrixConverter) {
    // Initializing in the member-initializer list avoids default-constructing
    // MatrixConverter and then immediately copy-assigning over it.
}
/*
 * LAPACK DSGESV, declared with C linkage so the call binds to the Fortran
 * symbol `dsgesv_`. Every argument is a pointer because Fortran passes by
 * reference. Per the LAPACK reference, DSGESV solves A*X = B by factorizing
 * in single precision and refining iteratively. It falls back to a
 * double-precision factorization if refinement fails.
 *
 *   n, nrhs     order of A; number of right-hand sides
 *   a, lda      N x N matrix (column-major) and its leading dimension
 *   ipiv        pivot indices, length N (output)
 *   b, ldb      N x NRHS right-hand sides and their leading dimension
 *   x, ldx      N x NRHS solution (output) and its leading dimension
 *   work        double workspace of N*NRHS entries
 *   swork       float workspace of N*(N+NRHS) entries
 *   iter        < 0 if the double-precision fallback was used, otherwise
 *               the number of refinement iterations (output)
 *   info        0 on success, < 0 illegal argument, > 0 exactly singular U (output)
 */
extern "C" void dsgesv_(int* n, int* nrhs,
                        double* a, int* lda,
                        int* ipiv,
                        double* b, int* ldb,
                        double* x, int* ldx,
                        double* work, float* swork,
                        int* iter, int* info);
/**
 * Solve A*X=B using LAPACK's dsgesv. B and X may have multiple columns
 * (one per right-hand side).
 *
 * @param A  coefficient matrix; must be square (N x N). Not modified.
 * @param B  right-hand side(s); must have N rows (N x NRHS). Not modified.
 * @param X  receives the solution through blitz assignment (it is not resized
 *           here); presumably the caller sizes it N x NRHS beforehand — confirm.
 * @throws std::invalid_argument if A is not square or B's row count differs from A's.
 * @throws std::runtime_error if dsgesv reports a nonzero INFO (an illegal
 *         argument, or an exactly singular U factor), i.e. no solution was computed.
 *
 * If N or NRHS is zero there is nothing to solve and X is left untouched.
 */
void DirectSolver::solve(const Array<double,2> & A, const Array<double, 2> & B, Array<double, 2> & X) {
    firstIndex ii;
    secondIndex jj;

    int sz = A.rows();
    int Nrhs = B.cols();

    // Without this check a shape mismatch makes the blitz assignments below
    // read/write out of bounds (blitz only checks extents in debug builds).
    if (A.cols() != sz || B.rows() != sz) {
        throw std::invalid_argument("DirectSolver::solve: A must be square and B must have the same number of rows as A");
    }

    // Nothing to solve. Returning early also avoids calling dsgesv with LDA = 0,
    // which LAPACK rejects (it requires LDA >= max(1,N)).
    if (sz == 0 || Nrhs == 0) {
        return;
    }

    int lda = sz;
    int ldb = sz;
    int ldx = sz;
    int info = 0;
    int iter = 0;

    // Buffer sizes are computed in size_t so products like sz*(sz+Nrhs)
    // cannot overflow int for large systems.
    const std::size_t n = static_cast<std::size_t>(sz);
    const std::size_t nrhs = static_cast<std::size_t>(Nrhs);

    // Heap storage rather than variable-length arrays: VLAs are not standard
    // C++, and an N x N double array on the stack overflows it for moderately
    // large N. All sizes are >= 1 here, so &v[0] is always valid.
    std::vector<int> ipiv(n);
    std::vector<double> work(n * nrhs);        // DSGESV WORK:  N*NRHS doubles
    std::vector<float> swork(n * (n + nrhs));  // DSGESV SWORK: N*(N+NRHS) floats
    std::vector<double> Apod(n * n);
    std::vector<double> Bpod(n * nrhs);
    std::vector<double> Xpod(n * nrhs);

    // Transpose before flattening. NOTE(review): this presumes fullToPodArray
    // flattens row by row, so the transposed copy yields the column-major
    // layout LAPACK expects — confirm against SparseMatrixConverter.
    Array<double, 2> Atrans(sz, sz);
    Array<double, 2> Btrans(Nrhs, sz);
    Array<double, 2> Xtrans(Nrhs, sz);
    Atrans = A(jj,ii);
    Btrans = B(jj,ii);
    MatrixConverter.fullToPodArray(Atrans, &Apod[0]);
    MatrixConverter.fullToPodArray(Btrans, &Bpod[0]);

    dsgesv_(&sz, &Nrhs, &Apod[0], &lda,
            &ipiv[0], &Bpod[0], &ldb, &Xpod[0], &ldx,
            &work[0], &swork[0], &iter, &info);

    // A nonzero INFO means X was not computed; copying Xpod back would hand
    // the caller meaningless values, so report the failure instead.
    if (info != 0) {
        std::ostringstream msg;
        msg << "DirectSolver::solve: dsgesv failed with info = " << info;
        throw std::runtime_error(msg.str());
    }

    MatrixConverter.podArrayToFull(&Xpod[0], Xtrans);
    X = Xtrans(jj,ii);
}
/**
 * Destructor. Performs no explicit cleanup; members are destroyed by their
 * own destructors.
 */
DirectSolver::~DirectSolver()
{
}