diff --git a/third_party/nvfuser/csrc/scheduler/matmul.cpp b/third_party/nvfuser/csrc/scheduler/matmul.cpp
index 7d85ddd5e953..fc1d45e48010 100644
--- a/third_party/nvfuser/csrc/scheduler/matmul.cpp
+++ b/third_party/nvfuser/csrc/scheduler/matmul.cpp
@@ -12,6 +12,12 @@ namespace nvfuser {
 
 namespace {
+
+// Returns true if the given number is a power of 2 (excluding 1)
+bool isPowOf2(int x) {
+  return x > 1 && (x & (x - 1)) == 0;
+}
+
 // Move the broadcast axes to the left on the specified number of inner
 // dimensions e.g. (when number_of_inner_pos == 3):
 //  [... I0, B, I1] -> [... B, I0, I1]
@@ -44,6 +50,200 @@ void moveInnerBroadcastLeft(TensorView* tv, int number_of_inner_pos = 3) {
   tv->reorder(order_map);
 }
 
+//! Automatically generates the shared memory swizzled data layout
+//!  for the matmul mainloop.
+//! The shared mem data layout is always 2D currently, and this utility
+//!  function assumes that the innermost 2 dimensions on shared_mem_tv
+//!  are the ones being swizzled.
+void prologSwizzle(TensorView* shared_mem_tv, const MatmulParam& params) {
+  // Check that the innermost 2 dimensions are concrete and statically
+  //  sized so that the swizzle function can be defined.
+
+  // Utility to check concrete static size:
+  auto check_concrete_static_dim = [](IterDomain* id) {
+    TORCH_INTERNAL_ASSERT(
+        !id->isBroadcast() && !id->isReduction(),
+        "no support for reduction or broadcast dims, but got ",
+        id->toString());
+    TORCH_INTERNAL_ASSERT(
+        id->extent()->isConstInt(),
+        "swizzled dimensions need to be statically sized, but got ",
+        id->toString());
+  };
+
+  TORCH_INTERNAL_ASSERT(
+      shared_mem_tv->nDims() >= 2,
+      "At least a 2D input is needed for swizzling, but got ",
+      shared_mem_tv->toString());
+  check_concrete_static_dim(shared_mem_tv->axis(-2));
+  check_concrete_static_dim(shared_mem_tv->axis(-1));
+
+  auto mma_config = params.mma_builder.build();
+
+  // Extract the constant sizes of the swizzled tile
+  const auto tile_size_x = shared_mem_tv->axis(-2)->extent()->evaluateInt();
+  const auto tile_size_y = shared_mem_tv->axis(-1)->extent()->evaluateInt();
+
+  // TODO: add support for tf32 (different macro) and fp32 (ffma)
+  if (isTuring(mma_config.macro) || isAmpere(mma_config.macro)) {
+    // Dimension of each inner unit of swizzled indices.
+    // Turing and Ampere case, ldmatrix access assumed (see TODO above).
+    // Each ldmatrix access is 8x8.
+    int row_unit = 8;
+    int col_unit = 8;
+
+    // The column size of the tile needs to be a multiple of 8 for
+    //  ldmatrix to work.
+    TORCH_INTERNAL_ASSERT(
+        tile_size_x >= row_unit && tile_size_x % row_unit == 0 &&
+            tile_size_y >= col_unit && tile_size_y % col_unit == 0,
+        "Prolog swizzle for ldmatrix: illegal tile size ",
+        tile_size_x,
+        "x",
+        tile_size_y);
+
+    int units_per_row = tile_size_y / col_unit;
+
+    // Number of column units that can fit in a conflict-free shared mem
+    //  wave, with a memory width of 128 bytes assumed.
+    const int units_per_memory_row =
+        128 / dataTypeSize(DataType::Half) / col_unit;
+
+    // Calculate the swizzle period:
+    int residue_unit_count = units_per_row % units_per_memory_row;
+
+    // In the case where the tile row is a multiple of the memory row, the
+    //  whole memory row is the repeated pattern of the swizzle. In the
+    //  case where the tile row is not divisible, the residual part is the
+    //  repeated pattern.
+    int repeated_pattern_size_in_units =
+        residue_unit_count == 0 ? units_per_memory_row : residue_unit_count;
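+
+    // For example (hypothetical sizes, for illustration only), in half
+    //  precision units_per_memory_row = 128 / 2 / 8 = 8: a 64-wide tile
+    //  row has units_per_row = 8 and residue_unit_count = 0, so the
+    //  repeated pattern is the whole memory row; a 32-wide tile row has
+    //  units_per_row = 4 and residue_unit_count = 4, so the repeated
+    //  pattern is the 4-unit residue.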
+
+    // Calculate the row multiplier, which is defined as the minimum
+    //  number of rows to look down from an element until the same bank
+    //  index is observed.
+    c10::optional<int> maybe_row_multiplier = c10::nullopt;
+
+    if (units_per_memory_row % repeated_pattern_size_in_units == 0) {
+      maybe_row_multiplier =
+          units_per_memory_row / repeated_pattern_size_in_units;
+    } else if (
+        units_per_memory_row > repeated_pattern_size_in_units &&
+        units_per_memory_row %
+                (units_per_memory_row - repeated_pattern_size_in_units) ==
+            0) {
+      maybe_row_multiplier = units_per_memory_row /
+          (units_per_memory_row - repeated_pattern_size_in_units);
+    }
+
+    // The case where the row multiplier cannot be an integer would be
+    //  where fractional tiling support is needed. Support for that will
+    //  be built out gradually.
+    if (!maybe_row_multiplier.has_value()) {
+      // Calculate the effective row period:
+      //   row_period = lcm(units_per_memory_row, repeated_pattern_size) /
+      //                  repeated_pattern_size
+      //  which is the same as below:
+      int row_period = units_per_memory_row /
+          std::gcd(units_per_memory_row, repeated_pattern_size_in_units);
+
+      if (row_period < row_unit) {
+        TORCH_WARN_ONCE(
+            "Fractional pattern not yet implemented for swizzling memory row of size: ",
+            units_per_memory_row,
+            " and tile row of size: ",
+            repeated_pattern_size_in_units);
+        // This would not lead to a functional issue, just a perf
+        //  regression, so just do not swizzle anything yet.
+        // TODO: add support for swizzles with different row and col
+        //  periods to enable this case.
+        return;
+      } else {
+        // This case would not need swizzling at all, as the period of the
+        //  memory bank index over the row is wider than the access window.
+        return;
+      }
+    } else if (maybe_row_multiplier.value() >= row_unit) {
+      // No need to swizzle in this case.
+      return;
+    }
+
+    // Calculate the swizzle period; only equal row/col periods are
+    //  supported at the moment:
+    // TODO: aperiodic swizzle could also be supported in a follow-up.
+    int max_swizzle_period = repeated_pattern_size_in_units;
+
+    int swizzle_period = max_swizzle_period;
+
+    // We do not have to use the max swizzle period if we already have
+    //  enough swizzle to permute a row_unit. This encourages the use of
+    //  power-of-2 swizzle periods.
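+    // Continuing the 32-wide half-precision example above: the 8-unit
+    //  memory row is divisible by the 4-unit repeated pattern, so the
+    //  bank pattern repeats every 8 / 4 = 2 rows (row multiplier = 2),
+    //  and the period below works out to min(4, 8 / 2) = 4.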
+    if (row_unit % maybe_row_multiplier.value() == 0) {
+      swizzle_period =
+          std::min(swizzle_period, row_unit / maybe_row_multiplier.value());
+    }
+
+    int row_multiplier = maybe_row_multiplier.value();
+
+    TORCH_INTERNAL_ASSERT(
+        tile_size_x % (swizzle_period * row_multiplier) == 0 &&
+            tile_size_y % (swizzle_period * col_unit) == 0,
+        "need aperiodic swizzle config for tile size ",
+        tile_size_x,
+        "x",
+        tile_size_y,
+        " with units ",
+        row_unit,
+        "x",
+        col_unit);
+
+    // Add the swizzle op:
+    shared_mem_tv->split(-2, row_multiplier * swizzle_period);
+    shared_mem_tv->split(-2, row_multiplier);
+
+    shared_mem_tv->split(-1, col_unit * swizzle_period);
+    shared_mem_tv->split(-1, col_unit);
+
+    //        -6           -5              -4         -3       -2
+    // [..., Irow_o, Irow_period, Irow_multiplier, Icol_o, Icol_period,
+    //  Icol_unit]  <- -1
+    if (isPowOf2(swizzle_period)) {
+      shared_mem_tv->swizzle(Swizzle2DType::XOR, -5, -2);
+    } else {
+      shared_mem_tv->swizzle(Swizzle2DType::CyclicShift, -5, -2);
+    }
+
+    // Merge the tile back for subsequent vectorization scheduling.
+    // TODO: could potentially simplify away the merges
+    shared_mem_tv->merge(-6);
+    shared_mem_tv->merge(-5);
+    shared_mem_tv->merge(-3);
+    shared_mem_tv->merge(-2);
+  } else if (isVolta(mma_config.macro)) {
+    // TODO: Volta is slightly more complex, and a fixed recipe would
+    //  not scale. In a follow-up this would be inferred from the mma
+    //  macro layouts themselves, as we already have them registered in
+    //  the utils.
+    return;
+  } else {
+    TORCH_INTERNAL_ASSERT(false, "Prolog swizzle: unsupported mma macro");
+  }
+}
+
+//! Generates the prolog schedule on the shared memory buffer
+//!  tensor. The scheduling performs two steps:
+//!
+//! 1. Swizzle the shared mem data layout.
+//! 2. Coalesce and vectorize the read/write schedule.
+void scheduleProlog(TensorView* shared_mem_tv, const MatmulParam& params) {
+  // Swizzle the shared memory data layout
+  prologSwizzle(shared_mem_tv, params);
+
+  // Assuming we are always vectorizing the smem write by 128b at the
+  //  moment:
+  // TODO: a data-type- and alignment-dependent interface would be needed
+  //  to support non-vectorizable shapes.
+  // The vectorizable-width logic will be in a separate PR, as the current
+  //  effort tries to focus on generating swizzles.
+  shared_mem_tv->merge(-2);
+  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
+      shared_mem_tv, params.tile_sizes, 8, false);
+}
+
 } // namespace
 
 void scheduleMatmul(
@@ -198,15 +398,11 @@ void scheduleMatmul(
   //  ------------------------------------------------------------------
   scheduler_utils::matmul_utils::orderTiledConcreteIdAsRoot(acw_smem);
   // [... M, K]
-  acw_smem->merge(-2);
-  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
-      acw_smem, gemm_tile, 8, false);
+  scheduleProlog(acw_smem, params);
 
-  // [... N, K]
   scheduler_utils::matmul_utils::orderTiledConcreteIdAsRoot(bcw_smem);
-  bcw_smem->merge(-2);
-  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
-      bcw_smem, gemm_tile, 8, false);
+  // [... N, K]
+  scheduleProlog(bcw_smem, params);
 
   // Propagate prolog tensors
   //  propagate up the DAG, and propagate parallel type.
@@ -230,7 +426,7 @@ void scheduleMatmul(
   // CTA tile:
 
   // Swizzle block tiles:
-  c->swizzle(Swizzle2DType::ZShape, 0, 1, SwizzleMode::Loop);
+  // c->swizzle(Swizzle2DType::ZShape, 0, 1, SwizzleMode::Loop);
 
   a->computeAt(c, 2);
   b->computeAt(c, 2);
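
Side note on the swizzle op itself: below is a minimal standalone sketch (illustration only, not part of this patch, and assuming Swizzle2DType::XOR maps (x, y) to (x, y ^ x) within the period, as in nvfuser's 2D swizzle) of how an XOR swizzle permutes the column units across rows so that each row of the period touches a distinct set of banks:

#include <cstdio>

// Toy model of an XOR swizzle over a period x period tile of column units.
// x and y stand for the Irow_period / Icol_period coordinates created by
// the splits in prologSwizzle above.
int xorSwizzleCol(int x, int y, int period) {
  // Each row sees a distinct permutation of the column units.
  return (y ^ x) % period;
}

int main() {
  const int period = 4; // e.g. the 32-wide half-precision example above
  for (int x = 0; x < period; ++x) {
    for (int y = 0; y < period; ++y) {
      std::printf("%d ", xorSwizzleCol(x, y, period));
    }
    std::printf("\n");
  }
  return 0;
}

With period = 4 this prints the rows 0 1 2 3 / 1 0 3 2 / 2 3 0 1 / 3 2 1 0: every row is a distinct permutation of the column units, so consecutive tile rows no longer hit the same banks, which is what makes the ldmatrix accesses conflict-free.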