Insert PTX inline assembly for vectorization.
PiperOrigin-RevId: 525718007
Aliia Khasanova authored and Copybara-Service committed Apr 20, 2023
1 parent 5e9c0e8 commit 76ebd44
Showing 1 changed file with 59 additions and 7 deletions.
xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.cu.cc
@@ -20,24 +20,76 @@ limitations under the License.
 
 namespace experiments {
 namespace benchmark {
+#define DFUNC __forceinline__ __device__
+#define HDFUNC DFUNC __host__
+
+template <typename ET, size_t S>
+class Vec {
+ public:
+  using ElementType = ET;
+  constexpr static size_t Size = S;
+
+  template <typename... Ts>
+  HDFUNC Vec(Ts... elements) : data_() {
+    InsertElements(0, elements...);
+  }
+
+  HDFUNC ElementType& operator[](size_t idx) { return data_[idx]; }
+  HDFUNC const ElementType& operator[](size_t idx) const { return data_[idx]; }
+
+ private:
+  template <typename T, typename... Ts>
+  HDFUNC void InsertElements(size_t idx, T element, Ts... rest) {
+    data_[idx] = element;
+    InsertElements(idx + 1, rest...);
+  }
+  HDFUNC void InsertElements(size_t idx) {}
+
+  ElementType data_[Size];
+};
+
+template <typename VectorType, typename T>
+DFUNC void Store(VectorType vx, T* __restrict__ x, size_t id) {
+  reinterpret_cast<VectorType* __restrict__>(x)[id] = vx;
+}
+template <>
+DFUNC void Store(Vec<float, 4> vx, float* __restrict__ x, size_t id) {
+  asm("st.global.v4.f32 [%0], {%1, %2, %3, %4};"
+      :
+      : "l"(x + 4 * id), "f"(vx[0]), "f"(vx[1]), "f"(vx[2]), "f"(vx[3]));
+}
+
+template <typename VectorType, typename T>
+DFUNC void LoadNc(VectorType& vx, const T* __restrict__ x, size_t id) {
+  vx = reinterpret_cast<const VectorType* __restrict__>(x)[id];
+}
+
+template <>
+DFUNC void LoadNc(Vec<float, 4>& vx, const float* __restrict__ x, size_t id) {
+  asm("ld.global.nc.v4.f32 {%0, %1, %2, %3}, [%4];"
+      : "=f"(vx[0]), "=f"(vx[1]), "=f"(vx[2]), "=f"(vx[3])
+      : "l"(x + 4 * id));
+}
+
 template <int chunks>
 __global__ void BenchmarkDeviceCopyKernel(const float* __restrict__ in,
                                           float* __restrict__ out,
-                                          const int64_t size) {
+                                          int64_t size) {
   const int64_t lines = size / (blockDim.x * chunks);
   const int64_t start_line = lines * blockIdx.x / gridDim.x;
   const int64_t end_line = lines * (blockIdx.x + 1) / gridDim.x;
-  const int64_t start_offset = start_line * blockDim.x * chunks + threadIdx.x;
+  const int64_t start_offset =
+      start_line * blockDim.x * chunks + 4 * threadIdx.x;
   const int64_t end_offset = end_line * blockDim.x * chunks;
-  float buffer[chunks];
+  Vec<float, 4> buffer[chunks / 4];
   for (int64_t i = start_offset; i < end_offset; i += blockDim.x * chunks) {
 #pragma unroll
-    for (int j = 0; j < chunks; j++) {
-      buffer[j] = in[i + blockDim.x * j];
+    for (int j = 0; j < chunks; j += 4) {
+      LoadNc(buffer[j / 4], in + i + blockDim.x * j, 0);
     }
 #pragma unroll
-    for (int j = 0; j < chunks; j++) {
-      out[i + blockDim.x * j] = buffer[j];
+    for (int j = 0; j < chunks; j += 4) {
+      Store(buffer[j / 4], out + i + blockDim.x * j, 0);
     }
   }
 }
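
The vectorized path changes the kernel's contract slightly: chunks must now be a multiple of 4, and each thread moves 16 bytes per LoadNc/Store pair. Below is a minimal host-side sketch of how the kernel might be launched; the launch configuration, the RunCopyBenchmark helper, and the buffer-size assumptions are illustrative only and are not part of this commit.

// Illustrative launch sketch (assumed parameters; not part of this commit).
// Assumes the BenchmarkDeviceCopyKernel template above is visible in this
// translation unit and that in/out were allocated with cudaMalloc, so they
// satisfy the 16-byte alignment required by the v4.f32 loads/stores.
#include <cstdint>
#include <cuda_runtime.h>

void RunCopyBenchmark(const float* d_in, float* d_out, int64_t size) {
  constexpr int kChunks = 16;    // assumed; must be a multiple of 4 for Vec<float, 4>
  constexpr int kThreads = 128;  // assumed threads per block
  constexpr int kBlocks = 108;   // assumed grid size, e.g. one block per SM
  // The kernel copies size rounded down to a multiple of kThreads * kChunks
  // elements; any remainder is left untouched by this benchmark.
  experiments::benchmark::BenchmarkDeviceCopyKernel<kChunks>
      <<<kBlocks, kThreads>>>(d_in, d_out, size);
  cudaDeviceSynchronize();
}

Note on the design: the generic Store/LoadNc overloads fall back to a reinterpret_cast, which may or may not compile down to a single 128-bit access; the Vec<float, 4> specializations pin the accesses to st.global.v4.f32 and ld.global.nc.v4.f32 explicitly, independent of compiler heuristics.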