cutlass calculate matrix size problem #75

Closed
wanghr323 opened this issue Dec 23, 2019 · 5 comments

@wanghr323

I wrote a function using CUTLASS to test its GEMM performance (int8 × int8 → int32), but I have run into a problem: M, N, and K cannot be chosen freely. N and K must be multiples of 16; other values cause an error. Is there something wrong with how I wrote this function?

int Int8Operator::cutlass_gemm32I_tensorop(const CBLAS_TRANSPOSE TransA,
    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
    const void *alpha, const void *A, const void *B, const void *beta,
    void *C, cublasGemmAlgo_t algo /* unused */)
{
    using A_Major = cutlass::layout::ColumnMajor;
    using B_Major = cutlass::layout::ColumnMajor;
    using ElementOutput = int32_t;
    using ElementAccumulator = int32_t;

    int lda = (TransA == CblasNoTrans) ? K : M;
    int ldb = (TransB == CblasNoTrans) ? N : K;
    int ldc = N;

    using Gemm = cutlass::gemm::device::Gemm<
        int8_t,                                  // element type of A
        A_Major,                                 // layout of A
        int8_t,                                  // element type of B
        B_Major,                                 // layout of B
        ElementOutput,                           // element type of C
        cutlass::layout::RowMajor,               // layout of C
        ElementAccumulator,
        cutlass::arch::OpClassWmmaTensorOp,
        cutlass::arch::Sm75,
        cutlass::gemm::GemmShape<128, 128, 32>,  // threadblock tile
        cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile
        cutlass::gemm::GemmShape<16, 16, 16>,    // WMMA instruction shape
        cutlass::epilogue::thread::LinearCombination<
            ElementOutput,
            128 / cutlass::sizeof_bits<ElementOutput>::value,
            ElementAccumulator,
            ElementAccumulator
        >,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
        2
    >;

    Gemm gemm_op;
    int alpha_ = *(static_cast<const int *>(alpha));
    int beta_  = *(static_cast<const int *>(beta));

    cutlass::Status status = gemm_op({
        {M, N, K},
        {static_cast<const int8_t *>(A), lda},
        {static_cast<const int8_t *>(B), ldb},
        {static_cast<int *>(C), ldc},
        {static_cast<int *>(C), ldc},
        {alpha_, beta_}
    });

    if (status != cutlass::Status::kSuccess) {
        return cudaErrorUnknown;
    }
    return cudaSuccess;
}
      
@kerrmudgeon
Collaborator

kerrmudgeon commented Dec 23, 2019 via email

@wanghr323
Author

OK, so at least two of M, N, and K should be multiples of 16.
Thank you, I will close the issue.

@wanghr323
Author

Thank you for your reply, Kerr. I now have a requirement in my work: compute C (int32) = A (int8) × B (int8), where A, B, and C are all row-major, A is M × K, B is K × N, and C is M × N.
I can guarantee that K is a multiple of 16, and M can be padded to a multiple of 16 (arbitrary M would be best, but padding is acceptable), but N must be allowed to be arbitrary.
How do I achieve this with CUTLASS?
I tested the available combinations in CUTLASS. If A, B, and C are all row-major, then N and K must be multiples of 16. If I instead compute A × B as B^T × A^T (choosing column-major for A, B, and C and swapping the operands), then M and N trade places: N can be chosen freely, but M must be a multiple of 16, so this still does not solve my problem.
Can this problem be solved with CUTLASS? It is fine if the solution does not use tensor core ops, or even WMMA.

wanghr323 reopened this Dec 24, 2019
@kerrmudgeon
Collaborator

Here are three possible recourses:

1.) Padding.

Size the matrices such that their dimensions are divisible by 16 and initialize the extra elements with zero.
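
For instance, a padded setup for A might look like the sketch below (my own illustration, not from this thread; round_up_16 is a hypothetical helper, and A is assumed to already be a row-major device pointer):

  inline int round_up_16(int x) { return ((x + 15) / 16) * 16; }

  int M_pad = round_up_16(M);
  int N_pad = round_up_16(N);
  int K_pad = round_up_16(K);

  // Allocate a zero-filled padded copy of row-major A (M x K -> M_pad x K_pad)
  // and copy the real data into its top-left corner. B and C are handled the
  // same way with extents K_pad x N_pad and M_pad x N_pad.
  int8_t *A_pad = nullptr;
  cudaMalloc(&A_pad, sizeof(int8_t) * M_pad * K_pad);
  cudaMemset(A_pad, 0, sizeof(int8_t) * M_pad * K_pad);
  cudaMemcpy2D(A_pad, K_pad * sizeof(int8_t),   // dst, dst pitch
               A,     K     * sizeof(int8_t),   // src, src pitch
               K * sizeof(int8_t), M,           // row width in bytes, number of rows
               cudaMemcpyDeviceToDevice);

  // Run the GEMM on (M_pad, N_pad, K_pad); because the padded rows and columns
  // of A and B are zero, the M x N block of C is unchanged and the extra rows
  // and columns can simply be ignored when reading the result back.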

2.) Reduce the alignment requirement at the expense of performance.

The device-level GEMM API accepts an admittedly long list of template arguments including the alignment constraints.

https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/gemm/device/gemm.h#L201

using Gemm = cutlass::gemm::device::Gemm<
      int8_t,
      cutlass::layout::RowMajor,       // layout of A
      int8_t,
      cutlass::layout::ColumnMajor,    // layout of B
      ElementOutput,
      cutlass::layout::RowMajor,       // layout of C
      ElementAccumulator,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm75,
      cutlass::gemm::GemmShape<128, 128, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>,
      cutlass::gemm::GemmShape<8, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
        ElementOutput,
        1,     // alignment of C in units of elements
        ElementAccumulator,
        ElementAccumulator
      >,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
      2,
      1,   // alignment of A in units of elements
      1    // alignment of B in units of elements
  >;
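
As a rough usage sketch (not taken verbatim from this thread), invoking this reduced-alignment kernel for row-major A (M x K), column-major B (K x N), and row-major C (M x N) could look like:

  Gemm gemm_op;

  cutlass::Status status = gemm_op({
      {M, N, K},
      {static_cast<const int8_t *>(A), K},   // lda = K for row-major A
      {static_cast<const int8_t *>(B), K},   // ldb = K for column-major B
      {static_cast<int32_t *>(C), N},        // ldc = N for row-major C
      {static_cast<int32_t *>(C), N},
      {alpha, beta}                          // int32 scalars with this epilogue
  });

  if (status != cutlass::Status::kSuccess) {
      // fall back or report the error
  }

With the A/B/C alignments set to 1, N no longer has to be a multiple of 16, at the cost of lower throughput than the fully aligned path.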

3.) Use the integer-valued SIMT kernels.

You may consider using a kernel targeting the integer dot product "dp4a" instruction, available on the Pascal microarchitecture and later.

Here is the definition syntax, visible in unit tests for these kernels.
https://github.com/NVIDIA/cutlass/blob/master/test/unit/gemm/device/simt_int8_igemm_sm61.cu

  // Output data type - may be int8_t or int32_t
  using ElementOutput = int8_t;

  // Accumulator data type
  using ElementAccumulator = int32_t;

  // Scalar data type (used by the epilogue in the full unit test)
  using ElementCompute = float;

  // Instruction shape - describes a 1x1x4 dot product computed by
  // the "dp4a" instruction.
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>;

  // Threadblock and warp tile shapes (example values; see the unit test
  // linked above for the shapes actually used there)
  using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 32>;
  using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>;

  using Gemm = cutlass::gemm::device::Gemm<
    int8_t,
    cutlass::layout::ColumnMajor,
    int8_t,
    cutlass::layout::ColumnMajor,
    ElementOutput,
    cutlass::layout::RowMajor,
    int32_t,
    cutlass::arch::OpClassSimt,
    cutlass::arch::Sm61,
    ThreadBlockShape,
    WarpShape,
    InstructionShape
  >;

There is no restriction on M, N, or K, but the matrices themselves must be 32b aligned. That is, pointers and leading dimensions must be divisible by 4 bytes.
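
A quick way to sanity-check that requirement before launching (my own sketch, not part of CUTLASS):

  #include <cstdint>

  // Returns true if a matrix satisfies the 32-bit alignment requirement:
  // the base pointer and the leading dimension must be multiples of
  // 4 bytes, i.e. 4 int8 elements.
  inline bool is_32b_aligned(const void *ptr, int ld_in_elements) {
      return reinterpret_cast<std::uintptr_t>(ptr) % 4 == 0 &&
             ld_in_elements % 4 == 0;
  }

  // e.g. inside the wrapper function from earlier in this thread:
  //   assert(is_32b_aligned(A, lda) && is_32b_aligned(B, ldb) && is_32b_aligned(C, ldc));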

@wanghr323
Author

Thank you for your help. I will close the issue.
