Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions csrc/src/rotary.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/

#pragma once

#include <cute/tensor.hpp>

#include "namespace_config.h"
#include "utils.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace FLASH_NAMESPACE {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_K=true, bool Clear_OOB_K=true,
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &Cos,
Tensor<Engine2, Layout2> const &Sin,
Tensor<Engine3, Layout3> const &identity_MN,
const int max_MN, const int min_MN,
const int dim, const int rotary_dim) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K
static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
Tensor rCos = make_fragment_like(Cos);
Tensor rSin = make_fragment_like(Sin);
Tensor rS = make_fragment_like(S);
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
cute::copy(S(_, m, k), rS(_, m, k));
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
cute::copy(Cos(_, m, k), rCos(_, m, k));
cute::copy(Sin(_, m, k), rSin(_, m, k));
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
#pragma unroll
for (int i = 0; i < size<0>(rS) / 2; ++i) {
float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
S_fp32(2 * i) = real;
S_fp32(2 * i + 1) = imag;
}
// Idk but I need to copy for the convert_type to work
Tensor S_fp32_copy = make_fragment_like(S_fp32);
cute::copy(S_fp32, S_fp32_copy);
using T = typename Engine0::value_type;
Tensor S_og_type = convert_type<T>(S_fp32_copy);
cute::copy(S_og_type, rS(_, m, k));
}
cute::copy(rS(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
}
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_K=true, bool Clear_OOB_K=true,
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &Cos,
Tensor<Engine2, Layout2> const &Sin,
Tensor<Engine3, Layout3> const &identity_MN,
const int max_MN, const int min_MN,
const int dim, const int rotary_dim) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
Tensor rCos = make_fragment_like(Cos);
Tensor rSin = make_fragment_like(Sin);
Tensor rS = make_fragment_like(S);
Tensor rS_other = make_fragment_like(rS(_, 0, 0));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
cute::copy(S(_, m, k), rS(_, m, k));
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
Copy link

Copilot AI Jun 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use of negative pointer offset here may raise safety concerns; please add a comment or safeguard to clarify that accessing memory with a negative offset is valid and correctly bounded.

Copilot uses AI. Check for mistakes.
cute::copy(gS_other, rS_other);
// if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
cute::copy(gCos, rCos(_, m, k));
cute::copy(gSin, rSin(_, m, k));
// if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
Tensor S_other_fp32 = convert_type<float>(rS_other);
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
#pragma unroll
for (int i = 0; i < size<0>(rS); ++i) {
S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
}
// Idk but I need to copy for the convert_type to work
Tensor S_fp32_copy = make_fragment_like(S_fp32);
cute::copy(S_fp32, S_fp32_copy);
using T = typename Engine0::value_type;
Tensor S_og_type = convert_type<T>(S_fp32_copy);
cute::copy(S_og_type, rS(_, m, k));
// if (cute::thread0()) { print_tensor(rS(_, m, k)); }
}
cute::copy(rS(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
}
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace FLASH_NAMESPACE