Skip to content

Commit

Permalink
fixup! [mlir][SVE] Add more e2e test for vector.contract
Browse files Browse the repository at this point in the history
Further generalization
  • Loading branch information
banach-space committed Oct 27, 2023
1 parent 2558a03 commit e80df20
Showing 1 changed file with 26 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,20 @@ func.func @dot_product_i32() {
%vector_b = arith.constant dense<314> : vector<[4]xi32>
%vector_c = arith.constant dense<0> : vector<[4]xi32>

// The result of this dot-product will depend
// on the vector length, so we are unable to verify it.
// DOT PRODUCT 1
%dp1 = vector.contract #dotp_trait %vector_a, %vector_b, %acc
: vector<[4]xi32>, vector<[4]xi32> into i32
// Dot product should be (123 * 314) * 4 * vscale, so ...
// Dot product should be:
// * val = (123 * 314) * 4 * vscale,
// so ...
%vscale = vector.vscale
%vscale_i32 = arith.index_cast %vscale : index to i32
%dp1_divvl = arith.divui %dp1, %vscale_i32 : i32
// ... %dp/%vscale = 123 * 314 * 4 = 154488
%dp1_div = arith.divui %dp1, %vscale_i32 : i32
// ... val / vscale = 123 * 314 * 4 = 154488
// DP: 154488
vector.print %dp1_divvl : i32
vector.print %dp1_div : i32

// DOT PRODUCT 2
// The result of this dot-product should be 0.
%dp2 = vector.contract #dotp_trait %vector_a, %vector_c, %acc
: vector<[4]xi32>, vector<[4]xi32> into i32
Expand All @@ -96,18 +98,27 @@ func.func @matvec_i32() {
%vector_b = arith.constant dense<314> : vector<[4]xi32>
%vector_c = arith.constant dense<0> : vector<[4]xi32>

// The result of this matvec will depend on the vector length, so we are
// unable to verify it.
%dp1 = vector.contract #matvec_trait %vector_a, %vector_b, %acc
// MATVEC 1
%mv1 = vector.contract #matvec_trait %vector_a, %vector_b, %acc
: vector<3x[4]xi32>, vector<[4]xi32> into vector<3xi32>
// MV: {{[0-9]*}}, {{[0-9]*}}, {{[0-9]*}}
vector.print %dp1 : vector<3xi32>

// The result of this matvc should be a vector of 0s.
%dp2 = vector.contract #matvec_trait %vector_a, %vector_c, %acc
// Every element in the output vector is a result of a dot product, for
// which:
// val = (123 * 314) * 4 * vscale
// so ...
%vscale = vector.vscale
%vscale_v = vector.splat %vscale : vector<3xindex>
%vscale_i32 = arith.index_cast %vscale_v : vector<3xindex> to vector<3xi32>
%mv1_div = arith.divui %mv1, %vscale_i32 : vector<3xi32>
// ... val / vscale = 123 * 314 * 4 = 154488
// MV: 154488, 154488, 154488
vector.print %mv1_div : vector<3xi32>

// MATVEC 2
// The result of this matvec should be a vector of 0s.
%mv2 = vector.contract #matvec_trait %vector_a, %vector_c, %acc
: vector<3x[4]xi32>, vector<[4]xi32> into vector<3xi32>
// MV: 0, 0, 0
vector.print %dp2 : vector<3xi32>
vector.print %mv2 : vector<3xi32>

// MV: SVE: END OF TEST OUTPUT
vector.print str "SVE: END OF TEST OUTPUT"
Expand Down

0 comments on commit e80df20

Please sign in to comment.