Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include <algorithm>
#include <cfloat>
#include <cuda.h>
#include <cuda.h> // for CUDA_VERSION

Check warning on line 21 in onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C system header after C++ system header. Should be: moe_kernel.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu:21: Found C system header after C++ system header. Should be: moe_kernel.h, c system, c++ system, other. [build/include_order] [4]
#include <cuda_fp16.h>
#include <math.h>
#include <sstream>
Expand All @@ -38,19 +38,12 @@

#include "moe_kernel.h"

#if CUDA_VERSION >= 11000
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <cub/util_type.cuh>
#else
#include "cub/cub.cuh"
#include "cub/device/device_radix_sort.cuh"
#include "cub/util_type.cuh"
#endif

namespace ort_fastertransformer {
static constexpr int WARP_SIZE = 32;

// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
Expand All @@ -65,13 +58,6 @@

const int thread_row_offset = blockIdx.x * num_cols;

#if CUDA_VERSION >= 12090
::cuda::std::plus sum;
#else
// Deprecated on CUDA 12.9
cub::Sum sum;
#endif

float threadData(-FLT_MAX);

// Don't touch finished rows.
Expand All @@ -84,7 +70,12 @@
threadData = max(static_cast<float>(input[idx]), threadData);
}

#if defined(CUDA_VERSION) && CUDA_VERSION >= 12090
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, ::cuda::maximum());
#else
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
#endif

if (threadIdx.x == 0) {
float_max = maxElem;
}
Expand All @@ -97,7 +88,12 @@
threadData += exp((static_cast<float>(input[idx]) - float_max));
}

const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12090
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, ::cuda::std::plus());
#else
// Deprecated on CUDA 12.9
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, cub::Sum());
#endif

if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;
Expand Down Expand Up @@ -993,6 +989,7 @@
if (experts_start_index > 0) {
total_past_rows = total_rows_before_expert_host_[experts_start_index - 1];
}

total_covered_rows = total_rows_before_expert_host_[experts_end_index] - total_past_rows;
}

Expand Down
Loading