Skip to content

Commit

Permalink
Client improvements 2 (ROCm#229)
Browse files Browse the repository at this point in the history
* Improved variables_map emulation

* Prototype Arguments class

* Prototype argument model for gesvd

* Addressed review comments

* Prototype argument model for sygv/hegv

* Print names of arguments not consumed by tests

* Addressed review comment

* change format in messages and outputs (#5)

* format messages and outputs

* format messages and outputs

* review corrections

* Addressed review comments

* New argument model for syev/heev

* New argument model for sygst/hegst

* New argument model for sytrd/hetrd

* New argument model for potrf

* New argument model for getrs

* New argument model for getri

* New argument model for getrf

* New argument model for geqrf

* New argument model for geqlf

* New argument model for gels

* New argument model for gelqf

* New argument model for gebrd

* New argument model for bdsqr, sterf, and steqr

* New argument model for labrd, lacgv, laswp, and latrd

* New argument model for larf, larfb, larfg, and larft

* New argument model for orgxx functions

* New argument model for ormxx functions

* New argument model for HMM test

* Final clean-up

* Consistent style for bench arguments

* Updated changelog

* Alphabetize arguments

* Bug fixes

* Addressed review comments

* Apply clang format

* Addressed review comment

* Addressed some review comments

* Remove static defaults from m, n, k, storev, side, k1, k2, and nu

* Default to square matrices

* Updated help string

* Use m as required parameter instead of n

* Adjust laswp and latrd defaults

Co-authored-by: Juan Zuniga-Anaya <50754207+jzuniga-amd@users.noreply.github.com>
  • Loading branch information
tfalders and jzuniga-amd committed Mar 23, 2021
1 parent 6021fdc commit 12afdab
Show file tree
Hide file tree
Showing 147 changed files with 1,981 additions and 1,409 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@ Full documentation for rocSOLVER is available at [rocsolver.readthedocs.io](http
### Optimizations

### Changed
- Argument names for the benchmark client now match argument names from the public API

### Removed

### Fixed
- Fixed known issues with Thin-SVD. The problem was identified in the test specification, not in the thin-SVD
implementation or the rocBLAS gemm\_batched routines.
- Benchmark client will no longer crash as a result of leading dimension or stride arguments not being provided
on the command line.

### Known Issues

Expand Down
502 changes: 303 additions & 199 deletions clients/benchmarks/client.cpp

Large diffs are not rendered by default.

128 changes: 68 additions & 60 deletions clients/common/lapack_host_reference.cpp

Large diffs are not rendered by default.

25 changes: 13 additions & 12 deletions clients/gtest/bdsqr_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,20 @@ Arguments bdsqr_setup_arguments(bdsqr_tuple tup)

Arguments arg;

arg.M = size[0]; // n
arg.N = size[1]; // nv
arg.K = size[2]; // nu
arg.S4 = size[3]; // nc
rocblas_int n = size[0];
rocblas_int nv = size[1];
rocblas_int nu = size[2];
rocblas_int nc = size[3];
arg.set<rocblas_int>("n", n);
arg.set<rocblas_int>("nv", nv);
arg.set<rocblas_int>("nu", nu);
arg.set<rocblas_int>("nc", nc);

arg.uplo_option = opt[0] ? 'L' : 'U';
arg.set<char>("uplo", opt[0] ? 'L' : 'U');

arg.lda = (arg.N > 0) ? arg.M : 1; // ldv
arg.lda += opt[1] * 10;
arg.ldb = (arg.K > 0) ? arg.K : 1; // ldu
arg.ldb += opt[2] * 10;
arg.ldc = (arg.S4 > 0) ? arg.M : 1; // ldc
arg.ldc += opt[3] * 10;
arg.set<rocblas_int>("ldv", (nv > 0 ? n : 1) + opt[1] * 10);
arg.set<rocblas_int>("ldu", (nu > 0 ? nu : 1) + opt[2] * 10);
arg.set<rocblas_int>("ldc", (nc > 0 ? n : 1) + opt[3] * 10);

arg.timing = 0;

Expand All @@ -97,7 +98,7 @@ class BDSQR : public ::TestWithParam<bdsqr_tuple>
{
Arguments arg = bdsqr_setup_arguments(GetParam());

if(arg.M == 0 && arg.uplo_option == 'L')
if(arg.peek<rocblas_int>("n") == 0 && arg.peek<char>("uplo") == 'L')
testing_bdsqr_bad_arg<T>();

testing_bdsqr<T>(arg);
Expand Down
15 changes: 6 additions & 9 deletions clients/gtest/gebd2_gebrd_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,13 @@ Arguments gebrd_setup_arguments(gebrd_tuple tup)

Arguments arg;

arg.M = matrix_size[0];
arg.N = n_size;
arg.lda = matrix_size[1];
arg.set<rocblas_int>("m", matrix_size[0]);
arg.set<rocblas_int>("n", n_size);
arg.set<rocblas_int>("lda", matrix_size[1]);

arg.timing = 0;
// only testing standard use case/defaults for strides

// only testing standard use case for strides
// strides are ignored in normal and batched tests
arg.bsp = min(arg.M, arg.N);
arg.bsa = arg.lda * arg.N;
arg.timing = 0;

return arg;
}
Expand All @@ -81,7 +78,7 @@ class GEBD2_GEBRD : public ::TestWithParam<gebrd_tuple>
{
Arguments arg = gebrd_setup_arguments(GetParam());

if(arg.M == 0 && arg.N == 0)
if(arg.peek<rocblas_int>("m") == 0 && arg.peek<rocblas_int>("n") == 0)
testing_gebd2_gebrd_bad_arg<BATCHED, STRIDED, BLOCKED, T>();

arg.batch_count = (BATCHED || STRIDED ? 3 : 1);
Expand Down
15 changes: 6 additions & 9 deletions clients/gtest/gelq2_gelqf_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,13 @@ Arguments gelqf_setup_arguments(gelqf_tuple tup)

Arguments arg;

arg.M = matrix_size[0];
arg.N = n_size;
arg.lda = matrix_size[1];
arg.set<rocblas_int>("m", matrix_size[0]);
arg.set<rocblas_int>("n", n_size);
arg.set<rocblas_int>("lda", matrix_size[1]);

arg.timing = 0;
// only testing standard use case/defaults for strides

// only testing standard use case for strides
// strides are ignored in normal and batched tests
arg.bsp = min(arg.M, arg.N);
arg.bsa = arg.lda * arg.N;
arg.timing = 0;

return arg;
}
Expand All @@ -82,7 +79,7 @@ class GELQ2_GELQF : public ::TestWithParam<gelqf_tuple>
{
Arguments arg = gelqf_setup_arguments(GetParam());

if(arg.M == 0 && arg.N == 0)
if(arg.peek<rocblas_int>("m") == 0 && arg.peek<rocblas_int>("n") == 0)
testing_gelq2_gelqf_bad_arg<BATCHED, STRIDED, BLOCKED, T>();

arg.batch_count = (BATCHED || STRIDED ? 3 : 1);
Expand Down
23 changes: 10 additions & 13 deletions clients/gtest/gels_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,18 @@ Arguments gels_setup_arguments(gels_tuple tup)

Arguments arg;

arg.M = std::get<0>(matrix_sizeA);
arg.N = std::get<1>(matrix_sizeA);
arg.lda = std::get<2>(matrix_sizeA);
arg.ldb = std::get<3>(matrix_sizeA);
arg.singular = std::get<4>(matrix_sizeA);
arg.set<rocblas_int>("m", std::get<0>(matrix_sizeA));
arg.set<rocblas_int>("n", std::get<1>(matrix_sizeA));
arg.set<rocblas_int>("lda", std::get<2>(matrix_sizeA));
arg.set<rocblas_int>("ldb", std::get<3>(matrix_sizeA));

arg.K = std::get<0>(matrix_sizeB);
arg.transA_option = std::get<1>(matrix_sizeB);
arg.set<rocblas_int>("nrhs", std::get<0>(matrix_sizeB));
arg.set<char>("trans", std::get<1>(matrix_sizeB));

arg.timing = 0;
// only testing standard use case/defaults for strides

// only testing standard use case for strides
// strides are ignored in normal and batched tests
arg.bsa = arg.lda * arg.N;
arg.bsb = arg.ldb * arg.K;
arg.timing = 0;
arg.singular = std::get<4>(matrix_sizeA);

return arg;
}
Expand All @@ -106,7 +103,7 @@ class GELS : public ::TestWithParam<gels_tuple>
{
Arguments arg = gels_setup_arguments(GetParam());

if(arg.M == 0 && arg.K == 0)
if(arg.peek<rocblas_int>("n") == 0 && arg.peek<rocblas_int>("nrhs") == 0)
testing_gels_bad_arg<BATCHED, STRIDED, T>();

arg.batch_count = (BATCHED || STRIDED ? 3 : 1);
Expand Down
15 changes: 6 additions & 9 deletions clients/gtest/geql2_geqlf_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,13 @@ Arguments geqlf_setup_arguments(geqlf_tuple tup)

Arguments arg;

arg.M = matrix_size[0];
arg.N = n_size;
arg.lda = matrix_size[1];
arg.set<rocblas_int>("m", matrix_size[0]);
arg.set<rocblas_int>("n", n_size);
arg.set<rocblas_int>("lda", matrix_size[1]);

arg.timing = 0;
// only testing standard use case/defaults for strides

// only testing standard use case for strides
// strides are ignored in normal and batched tests
arg.bsp = min(arg.M, arg.N);
arg.bsa = arg.lda * arg.N;
arg.timing = 0;

return arg;
}
Expand All @@ -82,7 +79,7 @@ class GEQL2_GEQLF : public ::TestWithParam<geqlf_tuple>
{
Arguments arg = geqlf_setup_arguments(GetParam());

if(arg.M == 0 && arg.N == 0)
if(arg.peek<rocblas_int>("m") == 0 && arg.peek<rocblas_int>("n") == 0)
testing_geql2_geqlf_bad_arg<BATCHED, STRIDED, BLOCKED, T>();

arg.batch_count = (BATCHED || STRIDED ? 3 : 1);
Expand Down
15 changes: 6 additions & 9 deletions clients/gtest/geqr2_geqrf_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,13 @@ Arguments geqrf_setup_arguments(geqrf_tuple tup)

Arguments arg;

arg.M = matrix_size[0];
arg.N = n_size;
arg.lda = matrix_size[1];
arg.set<rocblas_int>("m", matrix_size[0]);
arg.set<rocblas_int>("n", n_size);
arg.set<rocblas_int>("lda", matrix_size[1]);

arg.timing = 0;
// only testing standard use case/defaults for strides

// only testing standard use case for strides
// strides are ignored in normal and batched tests
arg.bsp = min(arg.M, arg.N);
arg.bsa = arg.lda * arg.N;
arg.timing = 0;

return arg;
}
Expand All @@ -82,7 +79,7 @@ class GEQR2_GEQRF : public ::TestWithParam<geqrf_tuple>
{
Arguments arg = geqrf_setup_arguments(GetParam());

if(arg.M == 0 && arg.N == 0)
if(arg.peek<rocblas_int>("m") == 0 && arg.peek<rocblas_int>("n") == 0)
testing_geqr2_geqrf_bad_arg<BATCHED, STRIDED, BLOCKED, T>();

arg.batch_count = (BATCHED || STRIDED ? 3 : 1);
Expand Down
49 changes: 23 additions & 26 deletions clients/gtest/gesvd_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,49 +87,45 @@ Arguments gesvd_setup_arguments(gesvd_tuple tup)
Arguments arg;

// sizes
arg.M = size[0];
arg.N = size[1];
rocblas_int m = size[0];
rocblas_int n = size[1];
arg.set<rocblas_int>("m", m);
arg.set<rocblas_int>("n", n);

// fast algorithm
if(size[2] == 0)
arg.workmode = 'I';
arg.set<char>("fast_alg", 'I');
else
arg.workmode = 'O';
arg.set<char>("fast_alg", 'O');

// leading dimensions
arg.lda = arg.M; // lda
arg.ldb = arg.M; // ldu
arg.ldv = opt[4] == 2 ? arg.N : min(arg.M, arg.N); // ldv
arg.lda += opt[0] * 10;
arg.ldb += opt[1] * 10;
arg.ldv += opt[2] * 10;
arg.set<rocblas_int>("lda", m + opt[0] * 10);
arg.set<rocblas_int>("ldu", m + opt[1] * 10);
if(opt[4] == 2)
arg.set<rocblas_int>("ldv", n + opt[2] * 10);
else
arg.set<rocblas_int>("ldv", min(m, n) + opt[2] * 10);

// vector options
if(opt[3] == 0)
arg.left_svect = 'O';
arg.set<char>("left_svect", 'O');
else if(opt[3] == 1)
arg.left_svect = 'S';
arg.set<char>("left_svect", 'S');
else if(opt[3] == 2)
arg.left_svect = 'A';
arg.set<char>("left_svect", 'A');
else
arg.left_svect = 'N';
arg.set<char>("left_svect", 'N');

if(opt[4] == 0)
arg.right_svect = 'O';
arg.set<char>("right_svect", 'O');
else if(opt[4] == 1)
arg.right_svect = 'S';
arg.set<char>("right_svect", 'S');
else if(opt[4] == 2)
arg.right_svect = 'A';
arg.set<char>("right_svect", 'A');
else
arg.right_svect = 'N';
arg.set<char>("right_svect", 'N');

// only testing standard use case for strides
// strides are ignored in normal and batched tests
arg.bsa = arg.lda * arg.N; // strideA
arg.bsb = min(arg.M, arg.N); // strideS
arg.bsc = arg.ldb * arg.M; // strideU
arg.bsp = arg.ldv * arg.N; // strideV
arg.bs5 = arg.bsb; // strideE
// only testing standard use case/defaults for strides

arg.timing = 0;

Expand All @@ -148,7 +144,8 @@ class GESVD : public ::TestWithParam<gesvd_tuple>
{
Arguments arg = gesvd_setup_arguments(GetParam());

if(arg.M == 0 && arg.N == 0 && arg.left_svect == 'N' && arg.right_svect == 'N')
if(arg.peek<rocblas_int>("m") == 0 && arg.peek<rocblas_int>("n") == 0
&& arg.peek<char>("left_svect") == 'N' && arg.peek<char>("right_svect") == 'N')
testing_gesvd_bad_arg<BATCHED, STRIDED, T>();

arg.batch_count = (BATCHED || STRIDED ? 3 : 1);
Expand Down
17 changes: 7 additions & 10 deletions clients/gtest/getf2_getrf_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,15 @@ Arguments getrf_setup_arguments(getrf_tuple tup)

Arguments arg;

arg.M = matrix_size[0];
arg.N = n_size;
arg.lda = matrix_size[1];
arg.set<rocblas_int>("m", matrix_size[0]);
arg.set<rocblas_int>("n", n_size);
arg.set<rocblas_int>("lda", matrix_size[1]);

// only testing standard use case/defaults for strides

arg.timing = 0;
arg.singular = matrix_size[2];

// only testing standard use case for strides
// strides are ignored in normal and batched tests
arg.bsp = min(arg.M, arg.N);
arg.bsa = arg.lda * arg.N;

return arg;
}

Expand All @@ -90,7 +87,7 @@ class GETF2_GETRF : public ::TestWithParam<getrf_tuple>
{
Arguments arg = getrf_setup_arguments(GetParam());

if(arg.M == 0 && arg.N == 0)
if(arg.peek<rocblas_int>("m") == 0 && arg.peek<rocblas_int>("n") == 0)
testing_getf2_getrf_bad_arg<BATCHED, STRIDED, BLOCKED, T>();

arg.batch_count = (BATCHED || STRIDED ? 3 : 1);
Expand All @@ -115,7 +112,7 @@ class GETF2_GETRF_NPVT : public ::TestWithParam<getrf_tuple>
{
Arguments arg = getrf_setup_arguments(GetParam());

if(arg.M == 0 && arg.N == 0)
if(arg.peek<rocblas_int>("m") == 0 && arg.peek<rocblas_int>("n") == 0)
testing_getf2_getrf_npvt_bad_arg<BATCHED, STRIDED, BLOCKED, T>();

arg.batch_count = (BATCHED || STRIDED ? 3 : 1);
Expand Down
13 changes: 5 additions & 8 deletions clients/gtest/getri_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,14 @@ Arguments getri_setup_arguments(getri_tuple tup)

Arguments arg;

arg.N = tup[0];
arg.lda = tup[1];
arg.set<rocblas_int>("n", tup[0]);
arg.set<rocblas_int>("lda", tup[1]);

// only testing standard use case/defaults for strides

arg.timing = 0;
arg.singular = tup[2];

// only testing standard use case for strides
// strides are ignored in normal and batched tests
arg.bsp = arg.N;
arg.bsa = arg.lda * arg.N;

return arg;
}

Expand All @@ -68,7 +65,7 @@ class GETRI : public ::TestWithParam<getri_tuple>
{
Arguments arg = getri_setup_arguments(GetParam());

if(arg.N == 0)
if(arg.peek<rocblas_int>("n") == 0)
testing_getri_bad_arg<BATCHED, STRIDED, T>();

arg.batch_count = (BATCHED || STRIDED ? 3 : 1);
Expand Down

0 comments on commit 12afdab

Please sign in to comment.