Matmul code and speed improvements (#1483)
* Initial commit: remove code duplication

* Fix transposition order bug

* Remove the mysterious assert
pmichel31415 authored and neubig committed Oct 25, 2018
1 parent bc3ff25 commit 946200c
Showing 3 changed files with 10 additions and 138 deletions.
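
For orientation, the algebra behind the shared helpers that the three files below now delegate to — a sketch assuming the standard matrix-product gradients spelled out in the in-diff comments ("dA = dy * B^T", "dB = A^T * dy"):

y = A B                      (forward; MatrixMultiply accumulates into fx, scaled by kSCALAR_ZERO or kSCALAR_ONE)
dE/dA += (dE/dy) * B^T       (left-argument gradient; MatrixMultiplyTranspAcc)
dE/dB += A^T * (dE/dy)       (right-argument gradient; MatrixTranspMultiplyAcc)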
1 change: 0 additions & 1 deletion dynet/matrix-multiply.h
@@ -112,7 +112,6 @@ inline void MatrixMultiplyTranspAcc(const dynet::Device_GPU & dev, const dynet::
r.v, r.d.rows(),
dev.kSCALAR_ONE, y.v, y.d.rows()));
} else {
DYNET_ARG_CHECK(false, "MatrixMultiplyTranspAcc");
CUBLAS_CHECK(cublasSgemmStridedBatched(dev.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T,
y.d.rows(), y.d.cols(), l.d.cols(),
dev.kSCALAR_ONE,
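The deleted DYNET_ARG_CHECK(false, ...) is the "mysterious assert" from the commit message: a check whose condition is the literal false fails whenever it is evaluated, so the cublasSgemmStridedBatched branch below it could never run for batched inputs. A minimal C++ sketch of that behaviour, using a hypothetical ARG_CHECK macro rather than DyNet's real DYNET_ARG_CHECK definition:

#include <stdexcept>

// Hypothetical stand-in for an argument-checking macro: throws when the condition is false.
#define ARG_CHECK(cond, msg) do { if (!(cond)) throw std::invalid_argument(msg); } while (0)

void batched_branch() {
  // Before this commit, the equivalent line threw here, so the strided-batched
  // GEMM that follows it in matrix-multiply.h was unreachable.
  ARG_CHECK(false, "MatrixMultiplyTranspAcc");
  // ... cublasSgemmStridedBatched(...) follows in the real code ...
}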
71 changes: 4 additions & 67 deletions dynet/nodes-affinetransform.cc
@@ -100,23 +100,11 @@ void AffineTransform::forward_dev_impl(const MyDevice & dev, const vector<const
}

// Perform multiplication
#ifdef __CUDACC__
for (unsigned i = 1; i < xs.size(); i += 2)
for (unsigned i = 1; i < xs.size(); i += 2) {
DYNET_ASSERT(xs[i+1]->d.bd == 1 || xs[i+1]->d.bd == xs[i]->d.bd, "Failed dimension check in AffineTransform::forward");
// fx = (acc_scalar)*fx + xs[i] * xs[i+1]
MatrixMultiply(dev, *xs[i], *xs[i + 1], fx, dev.kSCALAR_ONE);
#else
// Multiply
for (unsigned i = 1; i < xs.size(); i += 2) {
if(xs[i]->d.bd == 1 && xs[i+1]->d.bd == fx.d.bd) {
colbatch_matrix(fx).noalias() += mat(*xs[i]) * colbatch_matrix(*xs[i+1]);
} else {
DYNET_ASSERT(xs[i+1]->d.bd == 1 || xs[i+1]->d.bd == xs[i]->d.bd, "Failed dimension check in AffineTransform::forward");
for(unsigned b = 0; b < fx.d.bd; ++b) {
batch_matrix(fx, b).noalias() += batch_matrix(*xs[i], b) * batch_matrix(*xs[i+1], b);
}
}
}
#endif
}
}
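
For comparison with the #else branch deleted in this hunk, a minimal sketch of the CPU behaviour the shared MatrixMultiply helper is assumed to provide (l and r stand for *xs[i] and *xs[i+1]; the actual implementation lives in dynet/matrix-multiply.h and may differ):

// Assumed CPU path of MatrixMultiply(dev, l, r, fx, acc): fx += l * r, batched.
if (l.d.bd == 1 && r.d.bd == fx.d.bd) {
  // Unbatched left operand: one column-batched product covers all batch elements.
  colbatch_matrix(fx).noalias() += mat(l) * colbatch_matrix(r);
} else {
  // Otherwise multiply batch element by batch element.
  for (unsigned b = 0; b < fx.d.bd; ++b)
    batch_matrix(fx, b).noalias() += batch_matrix(l, b) * batch_matrix(r, b);
}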

@@ -160,60 +148,9 @@ void AffineTransform::backward_dev_impl(const MyDevice & dev,

// Left argument of matrix multiply
} else if (i % 2 == 1) {
int max_b = max(dEdf.d.bd, xs[i+1]->d.bd);
#ifdef __CUDACC__
if(dEdxi.d.bd == 1 && (dEdf.d.bd == xs[i+1]->d.bd)) {
CUBLAS_CHECK(cublasSgemm(dev.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T,
dEdxi.d.rows(), dEdxi.d.cols(), dEdf.d.cols() * dEdf.d.batch_elems(),
dev.kSCALAR_ONE,
dEdf.v, dEdf.d.rows(),
xs[i+1]->v, xs[i+1]->d.rows(),
dev.kSCALAR_ONE, dEdxi.v, dEdxi.d.rows()));
} else {
for(int b = 0; b < max_b; ++b)
CUBLAS_CHECK(cublasSgemm(dev.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T,
dEdxi.d.rows(), dEdxi.d.cols(), dEdf.d.cols(),
dev.kSCALAR_ONE,
dEdf.batch_ptr(b), dEdf.d.rows(),
xs[i+1]->batch_ptr(b), xs[i+1]->d.rows(),
dev.kSCALAR_ONE, dEdxi.batch_ptr(b), dEdxi.d.rows()));
}
#else
if(dEdxi.d.bd == 1 && (dEdf.d.bd == xs[i+1]->d.bd)) {
mat(dEdxi).noalias() += colbatch_matrix(dEdf) * colbatch_matrix(*xs[i+1]).transpose();
} else {
for(int b = 0; b < max_b; ++b)
batch_matrix(dEdxi, b).noalias() += batch_matrix(dEdf, b) * batch_matrix(*xs[i+1], b).transpose();
}
#endif
MatrixMultiplyTranspAcc(dev, dEdf, *xs[i+1], dEdxi);
} else { // right argument of matrix multiply
int max_b = max(xs[i-1]->d.bd, dEdf.d.bd);
#ifdef __CUDACC__
// Do a single multiply if xs[i-1] has one batch
if(xs[i-1]->d.bd == 1 && dEdxi.d.bd == dEdf.d.bd) {
CUBLAS_CHECK(cublasSgemm(dev.cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N,
dEdxi.d.rows(), dEdxi.d.cols()*dEdxi.d.batch_elems(), xs[i-1]->d.rows(),
dev.kSCALAR_ONE,
xs[i-1]->v, xs[i-1]->d.rows(),
dEdf.v, dEdf.d.rows(),
dev.kSCALAR_ONE, dEdxi.v, dEdxi.d.rows()));
} else {
for(int b = 0; b < max_b; ++b)
CUBLAS_CHECK(cublasSgemm(dev.cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N,
dEdxi.d.rows(), dEdxi.d.cols(), xs[i-1]->d.rows(),
dev.kSCALAR_ONE,
xs[i-1]->batch_ptr(b), xs[i-1]->d.rows(),
dEdf.batch_ptr(b), dEdf.d.rows(),
dev.kSCALAR_ONE, dEdxi.batch_ptr(b), dEdxi.d.rows()));
}
#else
if(xs[i-1]->d.bd == 1 && dEdxi.d.bd == dEdf.d.bd) {
colbatch_matrix(dEdxi).noalias() += mat(*xs[i-1]).transpose() * colbatch_matrix(dEdf);
} else {
for(int b = 0; b < max_b; ++b)
batch_matrix(dEdxi, b).noalias() += batch_matrix(*xs[i-1], b).transpose() * batch_matrix(dEdf, b);
}
#endif
MatrixTranspMultiplyAcc(dev, *xs[i-1], dEdf, dEdxi);
}
}
DYNET_NODE_INST_DEV_IMPL(AffineTransform)
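
Likewise, a minimal sketch of the CPU behaviour the two accumulating helpers called above are assumed to provide, mirroring the #else branches deleted from this backward pass (l, r, and max_b are placeholders for the corresponding arguments; the real implementations live in dynet/matrix-multiply.h):

// Assumed CPU path of MatrixMultiplyTranspAcc(dev, dEdf, r, dEdxi): dEdxi += dEdf * r^T
if (dEdxi.d.bd == 1 && dEdf.d.bd == r.d.bd) {
  mat(dEdxi).noalias() += colbatch_matrix(dEdf) * colbatch_matrix(r).transpose();
} else {
  for (int b = 0; b < max_b; ++b)
    batch_matrix(dEdxi, b).noalias() += batch_matrix(dEdf, b) * batch_matrix(r, b).transpose();
}
// Assumed CPU path of MatrixTranspMultiplyAcc(dev, l, dEdf, dEdxi): dEdxi += l^T * dEdf
if (l.d.bd == 1 && dEdxi.d.bd == dEdf.d.bd) {
  colbatch_matrix(dEdxi).noalias() += mat(l).transpose() * colbatch_matrix(dEdf);
} else {
  for (int b = 0; b < max_b; ++b)
    batch_matrix(dEdxi, b).noalias() += batch_matrix(l, b).transpose() * batch_matrix(dEdf, b);
}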
76 changes: 6 additions & 70 deletions dynet/nodes-matrixmultiply.cc
@@ -50,25 +50,9 @@ std::vector<int> MatrixMultiply::autobatch_concat(const ComputationGraph & cg) c
template<class MyDevice>
void MatrixMultiply::forward_dev_impl(const MyDevice & dev, const vector<const Tensor*>& xs, Tensor& fx) const {
DYNET_ASSERT(xs.size() == 2, "Failed dimension check in MatrixMultiply::forward");
#ifdef __CUDACC__
DYNET_ARG_CHECK(fx.d.bd == max(xs[0]->d.bd, xs[1]->d.bd), "Failed dimension check in MatrixMultiply::forward");
// fx = mat(fx0) + xs[0] * xs[1]
dynet::MatrixMultiply(dev, *xs[0], *xs[1], fx, dev.kSCALAR_ZERO);
#else
DYNET_ARG_CHECK(fx.d.bd == max(xs[0]->d.bd, xs[1]->d.bd), "Failed dimension check in MatrixMultiply::forward");
if(xs[0]->d.bd == 1) {
// If the left side has one batch, multiply by columns
// [x, z, b] = [x, y] * [y, z, b]
// -> [x, z*b] = [x, y], [y, z*b]
colbatch_matrix(fx).noalias() = mat(*xs[0]) * colbatch_matrix(*xs[1]);
} else {
// Otherwise, loop over the batches
DYNET_ARG_CHECK(xs[1]->d.bd == 1 || xs[1]->d.bd == xs[0]->d.bd,
"Number of batch elements in matrix multiply must match, but got:"
<< xs[0]->d.bd << " != " << xs[1]->d.bd);
for(unsigned b = 0; b < xs[0]->d.bd; ++b)
batch_matrix(fx, b).noalias() = batch_matrix(*xs[0], b) * batch_matrix(*xs[1], b);
}
#endif
}

template<class MyDevice>
@@ -79,62 +63,14 @@ void MatrixMultiply::backward_dev_impl(const MyDevice & dev,
unsigned i,
Tensor& dEdxi) const {
DYNET_ASSERT(i < 2, "Failed dimension check in MatrixMultiply::backward");
int max_b = max(xs[0]->d.bd, xs[1]->d.bd);
#ifdef __CUDACC__
// y = A * B
if (i == 0) {
if(dEdxi.d.bd == 1 && (dEdf.d.bd == xs[1]->d.bd)) {
CUBLAS_CHECK(cublasSgemm(dev.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T,
dEdxi.d.rows(), dEdxi.d.cols(), dEdf.d.cols() * dEdf.d.batch_elems(),
dev.kSCALAR_ONE,
dEdf.v, dEdf.d.rows(),
xs[1]->v, xs[1]->d.rows(),
dev.kSCALAR_ONE, dEdxi.v, dEdxi.d.rows()));
} else {
for(int b = 0; b < max_b; ++b)
CUBLAS_CHECK(cublasSgemm(dev.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T,
dEdxi.d.rows(), dEdxi.d.cols(), dEdf.d.cols(),
dev.kSCALAR_ONE,
dEdf.batch_ptr(b), dEdf.d.rows(),
xs[1]->batch_ptr(b), xs[1]->d.rows(),
dev.kSCALAR_ONE, dEdxi.batch_ptr(b), dEdxi.d.rows()));
}
// dA = dy * B^T
MatrixMultiplyTranspAcc(dev, dEdf, *xs[1], dEdxi);
} else {
// Do a single multiply if xs[0] has one batch
if(xs[0]->d.bd == 1) {
// colbatch_matrix(dEdxi).noalias() += (mat(*xs[0])).transpose() * colbatch_matrix(dEdf);
CUBLAS_CHECK(cublasSgemm(dev.cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N,
dEdxi.d.rows(), dEdxi.d.cols()*dEdxi.d.batch_elems(), xs[0]->d.rows(),
dev.kSCALAR_ONE,
xs[0]->v, xs[0]->d.rows(),
dEdf.v, dEdf.d.rows(),
dev.kSCALAR_ONE, dEdxi.v, dEdxi.d.rows()));
} else {
for(int b = 0; b < max_b; ++b)
CUBLAS_CHECK(cublasSgemm(dev.cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N,
dEdxi.d.rows(), dEdxi.d.cols(), xs[0]->d.rows(),
dev.kSCALAR_ONE,
xs[0]->batch_ptr(b), xs[0]->d.rows(),
dEdf.batch_ptr(b), dEdf.d.rows(),
dev.kSCALAR_ONE, dEdxi.batch_ptr(b), dEdxi.d.rows()));
}
// dB = A^T * dy
MatrixTranspMultiplyAcc(dev, *xs[0], dEdf, dEdxi);
}
#else
if (i == 0) {
if(dEdxi.d.bd == 1 && (dEdf.d.bd == xs[1]->d.bd)) {
(mat(dEdxi)).noalias() += colbatch_matrix(dEdf) * colbatch_matrix(*xs[1]).transpose();
} else {
for(int b = 0; b < max_b; ++b)
batch_matrix(dEdxi, b).noalias() += batch_matrix(dEdf, b) * batch_matrix(*xs[1], b).transpose();
}
} else {
if(xs[0]->d.bd == 1) {
colbatch_matrix(dEdxi).noalias() += (mat(*xs[0])).transpose() * colbatch_matrix(dEdf);
} else {
for(int b = 0; b < max_b; ++b)
batch_matrix(dEdxi, b).noalias() += batch_matrix(*xs[0], b).transpose() * batch_matrix(dEdf, b);
}
}
#endif
}
DYNET_NODE_INST_DEV_IMPL(MatrixMultiply)
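
Finally, a small usage sketch showing where the two refactored nodes come from in user code, written against the DyNet C++ expression API as used in the library's examples (affine_transform, operator*, parameter, input); this program is only an illustration and is not part of the commit:

#include "dynet/dynet.h"
#include "dynet/expr.h"
#include "dynet/init.h"
#include "dynet/model.h"
#include <vector>
using namespace dynet;

int main(int argc, char** argv) {
  initialize(argc, argv);
  ParameterCollection pc;
  Parameter p_W = pc.add_parameters({3, 4});
  Parameter p_b = pc.add_parameters({3});
  ComputationGraph cg;
  Expression W = parameter(cg, p_W);
  Expression b = parameter(cg, p_b);
  std::vector<float> xv(4, 1.f);
  Expression x = input(cg, {4}, xv);
  Expression y_mm = W * x;                        // builds a MatrixMultiply node
  Expression y_at = affine_transform({b, W, x});  // builds an AffineTransform node (b + W*x)
  cg.forward(y_at);                               // runs forward_dev_impl from the files above
  return 0;
}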
