diff --git a/blas/testblas/dtrsm.go b/blas/testblas/dtrsm.go index d5ec6258a6..47c49e056a 100644 --- a/blas/testblas/dtrsm.go +++ b/blas/testblas/dtrsm.go @@ -7,6 +7,7 @@ package testblas import ( "testing" + "golang.org/x/exp/rand" "gonum.org/v1/gonum/blas" "gonum.org/v1/gonum/floats" ) @@ -16,7 +17,8 @@ type Dtrsmer interface { alpha float64, a []float64, lda int, b []float64, ldb int) } -func DtrsmTest(t *testing.T, blasser Dtrsmer) { +func DtrsmTest(t *testing.T, impl Dtrsmer) { + rnd := rand.New(rand.NewSource(1)) for i, test := range []struct { s blas.Side ul blas.Uplo @@ -27,7 +29,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { alpha float64 a [][]float64 b [][]float64 - ans [][]float64 + want [][]float64 }{ { s: blas.Left, @@ -47,7 +49,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {4, 7}, {5, 8}, }, - ans: [][]float64{ + want: [][]float64{ {1, 3.4}, {-0.5, -0.5}, {2, 3.2}, @@ -71,7 +73,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {4, 7}, {5, 8}, }, - ans: [][]float64{ + want: [][]float64{ {60, 96}, {-42, -66}, {10, 16}, @@ -95,7 +97,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {4, 7, 1, 3}, {5, 8, 9, 10}, }, - ans: [][]float64{ + want: [][]float64{ {1, 3.4, 1.2, 13}, {-0.5, -0.5, -4, -3.5}, {2, 3.2, 3.6, 4}, @@ -119,7 +121,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {4, 7, 1, 3}, {5, 8, 9, 10}, }, - ans: [][]float64{ + want: [][]float64{ {60, 96, 126, 146}, {-42, -66, -88, -94}, {10, 16, 18, 20}, @@ -143,7 +145,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {4, 7}, {5, 8}, }, - ans: [][]float64{ + want: [][]float64{ {4.5, 9}, {-0.375, -1.5}, {-0.75, -12.0 / 7}, @@ -167,7 +169,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {4, 7}, {5, 8}, }, - ans: [][]float64{ + want: [][]float64{ {9, 18}, {-15, -33}, {60, 132}, @@ -191,7 +193,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {4, 7, 1, 3}, {5, 8, 9, 10}, }, - ans: [][]float64{ + want: [][]float64{ {4.5, 9, 3, 13.5}, {-0.375, -1.5, -1.5, -63.0 / 8}, {-0.75, -12.0 / 7, 3, 39.0 / 28}, @@ -215,7 +217,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {4, 7, 1, 3}, {5, 8, 9, 10}, }, - ans: [][]float64{ + want: [][]float64{ {9, 18, 6, 27}, {-15, -33, -15, -72}, {60, 132, 87, 327}, @@ -239,7 +241,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {4, 7}, {5, 8}, }, - ans: [][]float64{ + want: [][]float64{ {4.5, 9}, {-0.30, -1.2}, {-6.0 / 35, -24.0 / 35}, @@ -263,7 +265,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {4, 7}, {5, 8}, }, - ans: [][]float64{ + want: [][]float64{ {9, 18}, {-15, -33}, {69, 150}, @@ -287,7 +289,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {4, 7, 8, 9}, {5, 8, 10, 11}, }, - ans: [][]float64{ + want: [][]float64{ {4.5, 9, 9, 10.5}, {-0.3, -1.2, -0.6, -0.9}, {-6.0 / 35, -24.0 / 35, -12.0 / 35, -18.0 / 35}, @@ -311,7 +313,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {4, 7, 8, 9}, {5, 8, 10, 11}, }, - ans: [][]float64{ + want: [][]float64{ {9, 18, 18, 21}, {-15, -33, -30, -36}, {69, 150, 138, 165}, @@ -335,7 +337,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {4, 7}, {5, 8}, }, - ans: [][]float64{ + want: [][]float64{ {-0.46875, 0.375}, {0.1875, 0.75}, {1.875, 3}, @@ -359,7 +361,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {4, 7}, {5, 8}, }, - ans: [][]float64{ + want: [][]float64{ {168, 267}, {-78, -123}, {15, 24}, @@ -383,7 +385,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {4, 7, 4, 5}, {5, 8, 6, 7}, }, - ans: [][]float64{ + want: [][]float64{ {-0.46875, 0.375, -2.0625, -1.78125}, {0.1875, 0.75, -0.375, -0.1875}, {1.875, 3, 2.25, 2.625}, @@ -407,7 +409,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {4, 7, 4, 5}, {5, 8, 6, 7}, }, - ans: [][]float64{ + want: [][]float64{ {168, 267, 204, 237}, {-78, -123, -96, -111}, {15, 24, 18, 21}, @@ -432,7 +434,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {16, 17, 18}, {19, 20, 21}, }, - ans: [][]float64{ + want: [][]float64{ {15, -2.4, -48.0 / 35}, {19.5, -3.3, -66.0 / 35}, {24, -4.2, -2.4}, @@ -458,7 +460,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {16, 17, 18}, {19, 20, 21}, }, - ans: [][]float64{ + want: [][]float64{ {30, -57, 258}, {39, -75, 339}, {48, -93, 420}, @@ -482,7 +484,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {10, 11, 12}, {13, 14, 15}, }, - ans: [][]float64{ + want: [][]float64{ {15, -2.4, -48.0 / 35}, {19.5, -3.3, -66.0 / 35}, }, @@ -504,7 +506,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {10, 11, 12}, {13, 14, 15}, }, - ans: [][]float64{ + want: [][]float64{ {30, -57, 258}, {39, -75, 339}, }, @@ -528,7 +530,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {16, 17, 18}, {19, 20, 21}, }, - ans: [][]float64{ + want: [][]float64{ {4.2, 1.2, 4.5}, {5.775, 1.65, 5.625}, {7.35, 2.1, 6.75}, @@ -554,7 +556,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {16, 17, 18}, {19, 20, 21}, }, - ans: [][]float64{ + want: [][]float64{ {435, -183, 36}, {543, -228, 45}, {651, -273, 54}, @@ -578,7 +580,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {10, 11, 12}, {13, 14, 15}, }, - ans: [][]float64{ + want: [][]float64{ {4.2, 1.2, 4.5}, {5.775, 1.65, 5.625}, }, @@ -600,7 +602,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {10, 11, 12}, {13, 14, 15}, }, - ans: [][]float64{ + want: [][]float64{ {435, -183, 36}, {543, -228, 45}, }, @@ -624,7 +626,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {16, 17, 18}, {19, 20, 21}, }, - ans: [][]float64{ + want: [][]float64{ {4.2, 1.2, 4.5}, {5.775, 1.65, 5.625}, {7.35, 2.1, 6.75}, @@ -650,7 +652,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {16, 17, 18}, {19, 20, 21}, }, - ans: [][]float64{ + want: [][]float64{ {435, -183, 36}, {543, -228, 45}, {651, -273, 54}, @@ -674,7 +676,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {10, 11, 12}, {13, 14, 15}, }, - ans: [][]float64{ + want: [][]float64{ {4.2, 1.2, 4.5}, {5.775, 1.65, 5.625}, }, @@ -696,7 +698,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {10, 11, 12}, {13, 14, 15}, }, - ans: [][]float64{ + want: [][]float64{ {435, -183, 36}, {543, -228, 45}, }, @@ -720,7 +722,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {16, 17, 18}, {19, 20, 21}, }, - ans: [][]float64{ + want: [][]float64{ {15, -2.4, -1.2}, {19.5, -3.3, -1.65}, {24, -4.2, -2.1}, @@ -746,7 +748,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {16, 17, 18}, {19, 20, 21}, }, - ans: [][]float64{ + want: [][]float64{ {30, -57, 258}, {39, -75, 339}, {48, -93, 420}, @@ -770,7 +772,7 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {10, 11, 12}, {13, 14, 15}, }, - ans: [][]float64{ + want: [][]float64{ {15, -2.4, -1.2}, {19.5, -3.3, -1.65}, }, @@ -792,24 +794,75 @@ func DtrsmTest(t *testing.T, blasser Dtrsmer) { {10, 11, 12}, {13, 14, 15}, }, - ans: [][]float64{ + want: [][]float64{ {30, -57, 258}, {39, -75, 339}, }, }, + { + s: blas.Right, + ul: blas.Lower, + tA: blas.Trans, + d: blas.Unit, + m: 2, + n: 3, + alpha: 0, + a: [][]float64{ + {2, 0, 0}, + {3, 5, 0}, + {4, 6, 8}, + }, + b: [][]float64{ + {10, 11, 12}, + {13, 14, 15}, + }, + want: [][]float64{ + {0, 0, 0}, + {0, 0, 0}, + }, + }, } { - aFlat := flatten(test.a) - bFlat := flatten(test.b) - ansFlat := flatten(test.ans) - var lda int - if test.s == blas.Left { - lda = test.m - } else { - lda = test.n + m := test.m + n := test.n + na := m + if test.s == blas.Right { + na = n } - blasser.Dtrsm(test.s, test.ul, test.tA, test.d, test.m, test.n, test.alpha, aFlat, lda, bFlat, test.n) - if !floats.EqualApprox(ansFlat, bFlat, 1e-13) { - t.Errorf("Case %v: Want %v, got %v.", i, ansFlat, bFlat) + for _, lda := range []int{na, na + 3} { + for _, ldb := range []int{n, n + 5} { + a := make([]float64, na*lda) + for i := range a { + a[i] = rnd.NormFloat64() + } + for i := 0; i < na; i++ { + for j := 0; j < na; j++ { + a[i*lda+j] = test.a[i][j] + } + } + + b := make([]float64, m*ldb) + for i := range b { + b[i] = rnd.NormFloat64() + } + for i := 0; i < m; i++ { + for j := 0; j < n; j++ { + b[i*ldb+j] = test.b[i][j] + } + } + + impl.Dtrsm(test.s, test.ul, test.tA, test.d, test.m, test.n, test.alpha, a, lda, b, ldb) + + want := make([]float64, len(b)) + copy(want, b) + for i := 0; i < m; i++ { + for j := 0; j < n; j++ { + want[i*ldb+j] = test.want[i][j] + } + } + if !floats.EqualApprox(want, b, 1e-13) { + t.Errorf("Case %v: Want %v, got %v.", i, want, b) + } + } } } }