Showing 4 changed files with 125 additions and 25 deletions.
@@ -0,0 +1,38 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import pytest
import torch

from xformers.helpers.test_utils import assert_eq, bf16_cuda
from xformers.triton.garbage_pad_ragged_acts import RaggedActivations


def _make_seq(n_ctx: int, value: int, d_model: int):
    return torch.full([n_ctx, d_model], value, **bf16_cuda())


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="This test requires a CUDA device"
)
def test_garbage_pad_active_queries_correctness():
    d_model = 6
    seqs = [
        _make_seq(n_ctx=1, value=33, d_model=d_model),
        _make_seq(n_ctx=3, value=42, d_model=d_model),
        _make_seq(n_ctx=7, value=55, d_model=d_model),
    ]
    active_queries = RaggedActivations.from_list(seqs)
    padded_queries = active_queries.to_garbage_padded()

    # Check that the non-garbage portion of each sequence is correct
    assert_eq(padded_queries[0, :1, :], seqs[0])
    assert_eq(padded_queries[1, :3, :], seqs[1])
    assert_eq(padded_queries[2, :7, :], seqs[2])


def test_add_kernel():
    pass
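
The `test_add_kernel` stub is left empty in this commit. For reference, a minimal sketch of what it could check, using the standard Triton launch-grid idiom for kernels that take meta-parameters via `**meta`; the `BLOCK_SIZE` value and input size here are illustrative choices, not part of the commit:

import pytest
import torch
import triton

from xformers.triton.garbage_pad_ragged_acts import add_kernel


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="This test requires a CUDA device"
)
def test_add_kernel():
    n_elements = 1000
    x = torch.rand(n_elements, device="cuda")
    y = torch.rand(n_elements, device="cuda")
    output = torch.empty_like(x)
    # Launch one program instance per BLOCK_SIZE-sized chunk of the input;
    # BLOCK_SIZE=128 is an assumed value, not taken from the commit
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=128)
    assert torch.allclose(output, x + y)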
@@ -0,0 +1,81 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import List

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector
    y_ptr,  # *Pointer* to second input vector
    output_ptr,  # *Pointer* to output vector
    n_elements,  # Size of the vector
    **meta,  # Optional meta-parameters for the kernel
):
    BLOCK_SIZE = meta["BLOCK_SIZE"]  # How many inputs each program should process
    # There are multiple 'programs' processing different data. We identify
    # which program we are here
    pid = tl.program_id(axis=0)  # We use a 1D launch grid, so the axis is 0
    # This program will process inputs that are offset from the initial data.
    # For instance, for a vector of length 256 and a block size of 64, the
    # programs would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input
    # is not a multiple of the block size
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM
    tl.store(output_ptr + offsets, output, mask=mask)

class RaggedActivations:
    def __init__(self, raw_tensor: torch.Tensor, n_ctx_per_seq: List[int]):
        self.raw_tensor = raw_tensor
        self.n_ctx_per_seq = n_ctx_per_seq

    @classmethod
    def from_list(cls, tensors: List[torch.Tensor]):
        """Tensors must all be of shape [n_ctx, d_model]."""
        return cls(
            raw_tensor=torch.cat(tensors),
            n_ctx_per_seq=[tensor.shape[0] for tensor in tensors],
        )

    def iter_full_tensors(self):
        idx_so_far = 0
        for n_ctx_in_this_seq in self.n_ctx_per_seq:
            yield self.raw_tensor[idx_so_far : idx_so_far + n_ctx_in_this_seq]
            idx_so_far += n_ctx_in_this_seq

    def to_garbage_padded(self) -> torch.Tensor:
        """
        Create a tensor of shape (n_seqs, n_ctx_max, d_model) where the
        sequences are right-padded with garbage data.
        """
        n_seqs = len(self.n_ctx_per_seq)
        n_ctx_max = max(self.n_ctx_per_seq)

        n_dim = self.raw_tensor.shape[-1]
        # torch.empty leaves the padding rows uninitialized; that uninitialized
        # memory is the "garbage", which avoids the cost of zero-filling
        padded_acts = torch.empty(
            n_seqs, n_ctx_max, n_dim, dtype=self.raw_tensor.dtype, device="cuda"
        )

        idx_so_far = 0
        for seq_idx, n_ctx_in_this_seq in enumerate(self.n_ctx_per_seq):
            this_seq = self.raw_tensor[idx_so_far : idx_so_far + n_ctx_in_this_seq]
            padded_acts[seq_idx, :n_ctx_in_this_seq, :] = this_seq
            idx_so_far += n_ctx_in_this_seq

        return padded_acts
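
To make the data layout concrete, a small usage sketch of `RaggedActivations` (the shapes and values are hypothetical, and a CUDA device is required because `to_garbage_padded` allocates on `cuda`):

import torch

from xformers.triton.garbage_pad_ragged_acts import RaggedActivations

# Two sequences of 2 and 3 tokens, d_model = 4; values are arbitrary
acts = RaggedActivations.from_list(
    [torch.ones(2, 4, device="cuda"), torch.zeros(3, 4, device="cuda")]
)
assert acts.raw_tensor.shape == (5, 4)  # sequences concatenated along dim 0

# iter_full_tensors yields each original sequence as a view into raw_tensor
assert [t.shape[0] for t in acts.iter_full_tensors()] == [2, 3]

# Padded shape is (n_seqs, n_ctx_max, d_model); rows past each sequence's
# n_ctx hold uninitialized ("garbage") values
padded = acts.to_garbage_padded()
assert padded.shape == (2, 3, 4)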