Skip to content

Commit

Permalink
got rid of iterateCols
Browse files Browse the repository at this point in the history
[-------------------------------------------- attn --------------------------------------------]
                                                                        |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)     |     12.5    |     7.6
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)     |     15.4    |     9.1
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)      |     12.6    |     7.7
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)      |     15.5    |     9.2
      (8, 512, 64, 128, torch.float16, None, False, 0.0)                |     10.3    |     6.0
      (8, 512, 64, 128, torch.float16, None, False, 0.5)                |     12.9    |     7.6
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |     45.1    |    29.2
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |     55.7    |    35.2
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |     45.6    |    29.3
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |     56.1    |    35.1
      (8, 1024, 64, 128, torch.float16, None, False, 0.0)               |     38.7    |    22.7
      (8, 1024, 64, 128, torch.float16, None, False, 0.5)               |     46.8    |    29.0

Times are in milliseconds (ms).

[------------------------------------------ attn-bwd ------------------------------------------]
                                                                        |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)     |     19.3    |    24.1
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)     |     19.4    |    24.6
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)      |     22.3    |    28.7
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)      |     22.3    |    29.1
      (8, 512, 64, 128, torch.float16, None, False, 0.0)                |     19.4    |    22.7
      (8, 512, 64, 128, torch.float16, None, False, 0.5)                |     19.5    |    23.4
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |     63.8    |    91.3
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |     63.9    |    94.1
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |     75.4    |   109.7
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |     75.6    |   111.2
      (8, 1024, 64, 128, torch.float16, None, False, 0.0)               |     63.9    |    85.8
      (8, 1024, 64, 128, torch.float16, None, False, 0.5)               |     64.3    |    90.3
  • Loading branch information
jfc4050 committed Dec 15, 2022
1 parent d9c21dc commit 0062b22
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 76 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/functional.h"
#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h"
Expand Down Expand Up @@ -195,43 +193,6 @@ struct AttentionScalingCoefsUpdaterSm80
}
}

// Visits every accumulator element held by this thread, grouped by logical
// column of the accumulator tile. For each column value `accum_n` it calls
// onBeginCol(accum_n) once, then onVisit(accum_m, accum_n, accum_idx) for each
// element in that column, then onEndCol(accum_n).
//
// accum_m / accum_n are logical (row, column) coordinates in the warp's
// accumulator tile; accum_idx is the element's position in the thread-local
// fragment. The formulas below encode the Sm80 tensor-op fragment layout
// (Policy::MmaIterations, InstructionShape, OpDelta, kAccumulatorRows,
// kElementsPerAccess, kRowsPerTile are members of the enclosing struct).
// NOTE(review): lane_offset is presumably this lane's (row, col) offset from
// get_lane_offset() — confirm against callers; it is read, not modified,
// despite being taken by non-const reference.
template <typename BeginColFn, typename VisitFn, typename EndColFn>
CUTLASS_DEVICE static void iterateCols(
cutlass::MatrixCoord& lane_offset,
BeginColFn onBeginCol,
VisitFn onVisit,
EndColFn onEndCol
) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {

CUTLASS_PRAGMA_UNROLL
for (int col = 0; col < kElementsPerAccess; ++col) {
// Logical column index of this group of elements.
const int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn +
col + lane_offset.column();

onBeginCol(accum_n);

CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
// Offset of this MMA tile's elements within the thread's fragment.
const int mma_accum_start = kAccumulatorRows * kElementsPerAccess *
(mma_n * Policy::MmaIterations::kRow + mma_m);

CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < kAccumulatorRows; ++row) {
// Logical row index and flat fragment index for this element.
const int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow +
row * kRowsPerTile + lane_offset.row();
const int accum_idx = mma_accum_start + row * kElementsPerAccess + col;

onVisit(accum_m, accum_n, accum_idx);
}
}

onEndCol(accum_n);
}
}
}

template <typename DT, typename F>
CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) {
// In each warp, 4 threads will work on the same row
Expand Down Expand Up @@ -423,42 +384,6 @@ struct AttentionScalingCoefsUpdaterSimt
}
}

// SIMT-policy counterpart of the tensor-op iterateCols above: visits every
// accumulator element held by this thread, column-group by column-group.
// For each column group it calls on_begin_col_fn(accum_n) once, then
// on_visit_element_fn(accum_m, accum_n + n, idx) for each element, then
// on_end_col_fn(accum_n).
//
// Note: begin/end callbacks here fire once per mma_n column *group* (the
// group spans Policy::LaneMmaShape::kN logical columns, visited as
// accum_n + n), whereas the Sm80 variant fires them once per single logical
// column. `idx` is the element's flat position in the thread-local fragment;
// the formula encodes the SIMT lane fragment layout (Iterations, Policy,
// Delta are members of the enclosing struct).
// NOTE(review): lane_offset is presumably this lane's (row, col) offset from
// get_lane_offset() — confirm against callers; it is read, not modified,
// despite being taken by non-const reference.
template <typename BeginColFn, typename VisitFn, typename EndColFn>
CUTLASS_DEVICE static void iterateCols(
cutlass::MatrixCoord& lane_offset,
BeginColFn on_begin_col_fn,
VisitFn on_visit_element_fn,
EndColFn on_end_col_fn
) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
// Base logical column of this column group.
const int accum_n = mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN +
lane_offset.column();

on_begin_col_fn(accum_n);

CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {

CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {

CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
// Logical row index of this element.
const int accum_m = mma_m * Delta::kRow + m + lane_offset.row();

// Flat fragment index: n fastest, then mma_n, then m, then mma_m.
int idx = n + Policy::LaneMmaShape::kN *
(mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM));

on_visit_element_fn(accum_m, accum_n + n, idx);
}
}
}

on_end_col_fn(accum_n);
}
}

static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset(
int8_t lane_id,
int8_t warp_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1255,7 +1255,7 @@ struct AttentionBackwardKernel {
// Pij += Bij, where Pij is in register fragment and Bij is in shared memory
auto lane_offset = MatmulQK::ScalingCoefsUpdater::get_lane_offset(
lane_id, warp_id, output_tile_coords);
MatmulQK::ScalingCoefsUpdater::iterateCols(
MatmulQK::ScalingCoefsUpdater::iterateRows(
lane_offset,
[&](int accum_n) {},
[&](int accum_m, int accum_n, int idx) {
Expand Down

0 comments on commit 0062b22

Please sign in to comment.