Merge branch 'main' into moe
blefaudeux committed Jan 19, 2022
2 parents 33870f5 + c16078b commit eaaa080
Showing 4 changed files with 125 additions and 25 deletions.
38 changes: 38 additions & 0 deletions tests/ragged_inference/test_garbage_pad_ragged_acts.py
@@ -0,0 +1,38 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import pytest
import torch

from xformers.helpers.test_utils import assert_eq, bf16_cuda
from xformers.triton.garbage_pad_ragged_acts import RaggedActivations


def _make_seq(n_ctx: int, value: int, d_model: int):
    return torch.full([n_ctx, d_model], value, **bf16_cuda())


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="This test requires a CUDA device"
)
def test_garbage_pad_active_queries_correctness():
    d_model = 6
    seqs = [
        _make_seq(n_ctx=1, value=33, d_model=d_model),
        _make_seq(n_ctx=3, value=42, d_model=d_model),
        _make_seq(n_ctx=7, value=55, d_model=d_model),
    ]
    active_queries = RaggedActivations.from_list(seqs)
    padded_queries = active_queries.to_garbage_padded()

    # Check that the non-garbage portion of each is correct
    assert_eq(padded_queries[0, :1, :], seqs[0])
    assert_eq(padded_queries[1, :3, :], seqs[1])
    assert_eq(padded_queries[2, :7, :], seqs[2])


def test_add_kernel():
    pass
27 changes: 2 additions & 25 deletions tests/ragged_inference/test_seq_kv_cache.py
@@ -11,31 +11,8 @@
import pytest
import torch

-from xformers.helpers.test_utils import assert_eq
-
-
-def bf16_cuda():
-    return dict(device="cuda", dtype=torch.bfloat16)
-
-
-class RaggedActivations:
-    def __init__(self, raw_tensor: torch.Tensor, n_ctx_per_seq: List[int]):
-        self.raw_tensor = raw_tensor
-        self.n_ctx_per_seq = n_ctx_per_seq
-
-    @classmethod
-    def from_list(cls, tensors: List[torch.Tensor]):
-        """Tensors must all be of shape [n_ctx, d_model]."""
-        return cls(
-            raw_tensor=torch.cat(tensors),
-            n_ctx_per_seq=[tensor.shape[0] for tensor in tensors],
-        )
-
-    def iter_full_tensors(self):
-        idx_so_far = 0
-        for n_ctx_in_this_seq in self.n_ctx_per_seq:
-            yield self.raw_tensor[idx_so_far : idx_so_far + n_ctx_in_this_seq]
-            idx_so_far += n_ctx_in_this_seq
+from xformers.helpers.test_utils import assert_eq, bf16_cuda
+from xformers.triton.garbage_pad_ragged_acts import RaggedActivations


class SingleSeqKVCache:
4 changes: 4 additions & 0 deletions xformers/helpers/test_utils.py
@@ -97,3 +97,7 @@ def init_torch_distributed_local():
        world_size=1,
        init_method=init_url,
    )


def bf16_cuda():
    return dict(device="cuda", dtype=torch.bfloat16)
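
A quick illustration of how this helper is meant to be used: the returned dict is splatted into a tensor constructor. This snippet is illustrative, not part of the commit, and assumes a CUDA device:

import torch
from xformers.helpers.test_utils import bf16_cuda

x = torch.ones(2, 3, **bf16_cuda())  # a bfloat16 tensor allocated on the GPU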
81 changes: 81 additions & 0 deletions xformers/triton/garbage_pad_ragged_acts.py
@@ -0,0 +1,81 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import List

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector
    y_ptr,  # *Pointer* to second input vector
    output_ptr,  # *Pointer* to output vector
    n_elements,  # Size of the vector
    **meta,  # Optional meta-parameters for the kernel
):
    BLOCK_SIZE = meta["BLOCK_SIZE"]  # How many inputs each program should process
    # There are multiple 'programs' processing different data. We identify which
    # program we are here
    pid = tl.program_id(axis=0)  # We use a 1D launch grid, so the axis is 0
    # This program will process inputs that are offset from the initial data.
    # For instance, with a vector of length 256 and a block size of 64, the
    # programs would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input
    # is not a multiple of the block size
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM
    tl.store(output_ptr + offsets, output, mask=mask)


class RaggedActivations:
    def __init__(self, raw_tensor: torch.Tensor, n_ctx_per_seq: List[int]):
        self.raw_tensor = raw_tensor
        self.n_ctx_per_seq = n_ctx_per_seq

    @classmethod
    def from_list(cls, tensors: List[torch.Tensor]):
        """Tensors must all be of shape [n_ctx, d_model]."""
        return cls(
            raw_tensor=torch.cat(tensors),
            n_ctx_per_seq=[tensor.shape[0] for tensor in tensors],
        )

    def iter_full_tensors(self):
        idx_so_far = 0
        for n_ctx_in_this_seq in self.n_ctx_per_seq:
            yield self.raw_tensor[idx_so_far : idx_so_far + n_ctx_in_this_seq]
            idx_so_far += n_ctx_in_this_seq

    def to_garbage_padded(self) -> torch.Tensor:
        """
        Create a tensor of shape (n_seqs, n_ctx_max, d_model) where the
        sequences are right-padded with garbage data
        """
        n_seqs = len(self.n_ctx_per_seq)
        n_ctx_max = max(self.n_ctx_per_seq)

        n_dim = self.raw_tensor.shape[-1]
        padded_acts = torch.empty(
            n_seqs, n_ctx_max, n_dim, dtype=self.raw_tensor.dtype, device="cuda"
        )

        idx_so_far = 0
        for seq_idx, n_ctx_in_this_seq in enumerate(self.n_ctx_per_seq):
            this_seq = self.raw_tensor[idx_so_far : idx_so_far + n_ctx_in_this_seq]
            padded_acts[seq_idx, :n_ctx_in_this_seq, :] = this_seq
            idx_so_far += n_ctx_in_this_seq

        return padded_acts
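
For reference, here is a minimal sketch of how a Triton kernel with the **meta signature above is typically launched: meta-parameters such as BLOCK_SIZE are passed as keyword arguments at the call site, and the grid callable reads them to size the 1D launch. The add wrapper and the BLOCK_SIZE value are illustrative assumptions, not part of this commit:

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = output.numel()
    # One program instance per BLOCK_SIZE-sized chunk of the input
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output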

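Likewise, a small usage sketch for RaggedActivations, mirroring the shapes exercised by the new test (n_ctx_per_seq = [1, 3, 7], d_model = 6):

seqs = [
    torch.full([n_ctx, 6], value, device="cuda", dtype=torch.bfloat16)
    for n_ctx, value in [(1, 33), (3, 42), (7, 55)]
]
ragged = RaggedActivations.from_list(seqs)  # raw_tensor has shape (11, 6)
padded = ragged.to_garbage_padded()  # shape (3, 7, 6)
# Only padded[i, :n_ctx, :] is meaningful; trailing rows hold uninitialized garbage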