Commit

Implement offset mode for linearize_cache_indices CUDA kernel (pytorch#2554)

Summary:

This patch adds `base_offset` as a parameter so that all values in `offsets` are effectively decreased by that amount, without making any copy of `offsets`.

The ultimate goal is to enable multipass prefetch, which requires calling this kernel on a segment of `indices` rather than on the whole tensor. See the unit test for its usage.
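
As a usage illustration (not part of the commit), here is a minimal sketch of the intended multipass call pattern, assuming an fbgemm_gpu build that exposes the operator registered below; the tensor shapes and values are made up for the example (2 tables, batch size 2, 8 indices):

import torch
import fbgemm_gpu  # noqa: F401  (registers the torch.ops.fbgemm.* operators)

# Illustrative inputs (made up for this sketch): 2 tables with hash sizes
# 12 and 13, batch size 2, 8 indices total.
cache_hash_size_cumsum = torch.tensor([0, 12, 25], device="cuda")
indices = torch.tensor([1, 3, 5, 0, 2, 4, 7, 1], dtype=torch.int, device="cuda")
offsets = torch.tensor([0, 2, 4, 6, 8], dtype=torch.int, device="cuda")

# Single pass over the whole batch (indices_base_offset defaults to 0).
full = torch.ops.fbgemm.linearize_cache_indices(
    cache_hash_size_cumsum, indices, offsets
)

# Multipass: linearize one segment at a time, passing the segment's start
# position as indices_base_offset so the kernel looks up the right table
# without materializing a shifted copy of `offsets`.
chunks = []
for start in range(0, indices.numel(), 3):
    end = min(start + 3, indices.numel())
    chunks.append(
        torch.ops.fbgemm.linearize_cache_indices(
            cache_hash_size_cumsum,
            indices[start:end],
            offsets,
            None,   # B_offsets
            -1,     # max_B
            start,  # indices_base_offset
        )
    )
assert torch.equal(torch.cat(chunks), full)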

Reviewed By: SherlockNoMad, henryoier

Differential Revision: D56863774
levythu authored and facebook-github-bot committed May 7, 2024
1 parent f2b1b50 commit a410601
Showing 7 changed files with 40 additions and 9 deletions.
@@ -436,7 +436,8 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
         indices,
         offsets,
         /*B_offsets=*/c10::optional<Tensor>(),
-        /*max_B=*/-1);
+        /*max_B=*/-1,
+        /*indices_base_offset=*/0);

   bool gather_uvm_stats = false;
   // populate_uvm_stats indicates whether to calculate cache related ratios,
@@ -60,7 +60,8 @@ at::Tensor linearize_cache_indices_cuda(
     const at::Tensor& indices,
     const at::Tensor& offsets,
     const c10::optional<at::Tensor>& B_offsets,
-    const int64_t max_B);
+    const int64_t max_B,
+    const int64_t indices_base_offset);

 ///@ingroup table-batched-embed-cuda
 /// Linearize the indices of all tables to make it be unique.
3 changes: 2 additions & 1 deletion fbgemm_gpu/src/split_embeddings_cache/common.h
@@ -30,7 +30,8 @@ Tensor linearize_cache_indices_cpu(
     const Tensor& indices,
     const Tensor& offsets,
     const c10::optional<Tensor>& B_offsets,
-    const int64_t max_B);
+    const int64_t max_B,
+    const int64_t indices_base_offset);

 Tensor linearize_cache_indices_from_row_idx_cpu(
     Tensor cache_hash_size_cumsum,
@@ -17,7 +17,8 @@ DLL_PUBLIC Tensor linearize_cache_indices_cpu(
     const Tensor& indices,
     const Tensor& /*offsets*/,
     const c10::optional<Tensor>& /*B_offsets*/,
-    const int64_t /*max_B*/) {
+    const int64_t /*max_B*/,
+    const int64_t /*indices_base_offset*/) {
   return at::empty_like(indices);
 }

12 changes: 8 additions & 4 deletions fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu
@@ -22,7 +22,8 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_cache_indices_kernel(
     const pta::PackedTensorAccessor32<offset_t, 1, at::RestrictPtrTraits>
         table_offsets,
     pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
-        linear_cache_indices) {
+        linear_cache_indices,
+    const int64_t indices_base_offset) {
   const index_t index = blockIdx.x * blockDim.x + threadIdx.x;
   if (index >= indices.size(0)) {
     return;
@@ -31,10 +31,11 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_cache_indices_kernel(
   // Perform binary search.
   int left = 0;
   int right = table_offsets.size(0);
+  const auto index_with_offset = index + indices_base_offset;
   while (left != right) {
     const int middle =
         left + (right - left) / 2; // Avoid overflow in midpoint calculation
-    if (table_offsets[middle] <= index) {
+    if (table_offsets[middle] <= index_with_offset) {
       left = middle + 1;
     } else {
       right = middle;
@@ -61,7 +63,8 @@ DLL_PUBLIC Tensor linearize_cache_indices_cuda(
     const Tensor& indices,
     const Tensor& offsets,
     const c10::optional<Tensor>& B_offsets,
-    const int64_t max_B) {
+    const int64_t max_B,
+    const int64_t indices_base_offset) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
       cache_hash_size_cumsum, indices, offsets);

@@ -115,7 +118,8 @@ DLL_PUBLIC Tensor linearize_cache_indices_cuda(
         MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
         MAKE_PTA_WITH_NAME(func_name, table_offsets, offset_t, 1, 32),
         MAKE_PTA_WITH_NAME(
-            func_name, linear_cache_indices, int64_t, 1, 32));
+            func_name, linear_cache_indices, int64_t, 1, 32),
+        indices_base_offset);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   });
 });
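
To make the kernel change above concrete, here is a pure-Python sketch (not the CUDA code) of the table lookup the kernel performs; `table_offsets` is assumed, for illustration only, to hold the position in the full flattened batch at which each table's slice ends:

# Pure-Python mirror of the kernel's binary search: find the table that owns
# a position in the full flattened batch, where each thread's index is
# shifted by indices_base_offset before the lookup.
def find_table(table_offsets, index, indices_base_offset):
    index_with_offset = index + indices_base_offset
    left, right = 0, len(table_offsets)
    while left != right:
        middle = left + (right - left) // 2  # avoid overflow, as in the kernel
        if table_offsets[middle] <= index_with_offset:
            left = middle + 1
        else:
            right = middle
    return left

# Element 0 of a segment starting at position 5 of the full batch resolves to
# the same table as element 5 of an un-segmented call.
table_offsets = [2, 4, 6, 8]  # assumed layout, for illustration only
assert find_table(table_offsets, 0, indices_base_offset=5) == find_table(
    table_offsets, 5, indices_base_offset=0
)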
@@ -12,7 +12,7 @@ namespace {

 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
-      "linearize_cache_indices(Tensor cache_hash_size_cumsum, Tensor indices, Tensor offsets, Tensor? B_offsets=None, int max_B=-1) -> Tensor");
+      "linearize_cache_indices(Tensor cache_hash_size_cumsum, Tensor indices, Tensor offsets, Tensor? B_offsets=None, int max_B=-1, int indices_base_offset=0) -> Tensor");
   m.def(
       "linearize_cache_indices_from_row_idx(Tensor cache_hash_size_cumsum, Tensor update_table_indices, Tensor update_row_indices) -> Tensor");
   m.def(
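
Because the schema above gives `indices_base_offset` a default of 0, call sites written before this patch keep working unchanged; a brief sketch, reusing the placeholder tensors from the sketch under the summary:

# Pre-existing three-argument call: indices_base_offset falls back to 0,
# so the op behaves exactly as it did before this patch.
out = torch.ops.fbgemm.linearize_cache_indices(
    cache_hash_size_cumsum, indices, offsets
)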
23 changes: 23 additions & 0 deletions fbgemm_gpu/test/tbe/cache/linearize_cache_indices_test.py
@@ -84,6 +84,7 @@ def test_linearize_cache_indices(self) -> None:
             dtype=torch.int,
             device="cuda",
         )
+        N = indices.numel()
         pruned_indices = torch.tensor(
             [10, -1, 3, 7, 1, 4, -1, 9, 2, -1, 6, 8, 5, 1, -1, 4],
             dtype=torch.int,
@@ -141,6 +142,28 @@ def test_linearize_cache_indices(self) -> None:
         output_ref = self.execute_linearize_cache_indices_ref(*args)
         self.assertTrue(torch.equal(output_test, output_ref))

+        for partial_start, partial_end in [
+            (0, N),
+            (2, 2),
+            (3, N),
+            (N - 2, N),
+            (1, N - 1),
+        ]:
+            args = (
+                hash_size_cumsum,
+                indices[partial_start:partial_end],
+                offsets,
+                B_offsets,
+                max_B,
+                partial_start,
+            )
+            partial_output = torch.ops.fbgemm.linearize_cache_indices(*args)
+            self.assertTrue(
+                torch.equal(
+                    partial_output, output_ref[partial_start:partial_end]
+                )
+            )
+
     @unittest.skipIf(*gpu_unavailable)
     def test_linearize_cache_indices_from_row_idx(self) -> None:
         update_row_indices = torch.tensor(
