
Commit

Minor blocksparse refactoring, update block size restrictions, relax power of two constraint (#277)

* Relax device size restrictions

* Refactor device creation and run all tests

* linting

Co-authored-by: Cole Hawkins <colehawk@amazon.com>
colehawkins and amazon-colehawk committed Apr 20, 2022
1 parent c4c7b5f commit b212063
Showing 20 changed files with 64 additions and 52 deletions.
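
In practice, the relaxed restrictions mean that BlockSparseAttention now accepts a block size of 128 and any sequence length that is a multiple of the block size, instead of requiring a power of two. Below is a minimal usage sketch (assuming a CUDA GPU with Triton available; the shapes, head dimension, and the 3-block causal layout are illustrative and not taken from the diff):

import torch

from xformers.components.attention import BlockSparseAttention

block_size = 128
seq_len = 384  # 3 * 128: a multiple of block_size, but not a power of two
n_heads, head_dim = 2, 64

# The layout is defined at block granularity: one 0/1 entry per (head, block, block).
n_blocks = seq_len // block_size
layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))

attention = BlockSparseAttention(layout=layout, block_size=block_size)

# Inputs are (batch, heads, seq, head_dim); the blocksparse kernels run in fp16 on CUDA.
q = k = v = torch.randn(
    1, n_heads, seq_len, head_dim, dtype=torch.float16, device="cuda"
)
out = attention(q, k, v)  # Triton kernels are built lazily on this first call
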
2 changes: 1 addition & 1 deletion examples/microGPT.py
@@ -21,7 +21,7 @@


 class GPT(pl.LightningModule):
-    """ the full GPT language model, with a context size of block_size """
+    """the full GPT language model, with a context size of block_size"""

     def __init__(
         self,
1 change: 1 addition & 0 deletions stubs/torch/fft/__init__.pyi
@@ -11,6 +11,7 @@ from torch import Tensor, complex64
 DType = TypeVar("DType")

 Ts = TypeVarTuple("Ts")
+
 @overload
 def fft(
     input: Tensor[DType, Unpack[Ts]],
1 change: 1 addition & 0 deletions stubs/torch/linalg/__init__.pyi
@@ -15,6 +15,7 @@ M = TypeVar("M", bound=int)
 N = TypeVar("N", bound=int)

 Ts = TypeVarTuple("Ts")
+
 @overload
 def pinv(
     input: Tensor[FloatOrDouble, Unpack[Ts], N1, N1],
1 change: 1 addition & 0 deletions stubs/torch/nn/functional.pyi
@@ -21,6 +21,7 @@ N1 = TypeVar("N1", bound=int)
 N2 = TypeVar("N2", bound=int)
 N3 = TypeVar("N3", bound=int)
 N4 = TypeVar("N4", bound=int)
+
 @overload
 def pad(
     input: Tensor[DType, Unpack[Ts], N],
1 change: 1 addition & 0 deletions stubs/torch/nn/functional/__init__.pyi
@@ -21,6 +21,7 @@ N1 = TypeVar("N1", bound=int)
 N2 = TypeVar("N2", bound=int)
 N3 = TypeVar("N3", bound=int)
 N4 = TypeVar("N4", bound=int)
+
 @overload
 def pad(
     input: Tensor[DType, Unpack[Ts], N],
1 change: 1 addition & 0 deletions stubs/torch/sparse/__init__.pyi
@@ -15,6 +15,7 @@ DType = TypeVar("DType")
 DType2 = TypeVar("DType2")

 Ts = TypeVarTuple("Ts")
+
 @overload
 def softmax(
     input: Tensor[DType, Unpack[Ts]], dim: int, dtype: Optional[DType2]
4 changes: 2 additions & 2 deletions stubs/torch_stub_tests.py
@@ -1695,9 +1695,9 @@ def test__lt__() -> None:
 def test_pow() -> None:
     x: Tensor[torch.float32, L[2], L[3], L[4]]

-    y: Tensor[torch.float32, L[2], L[3], L[4]] = x ** 4
+    y: Tensor[torch.float32, L[2], L[3], L[4]] = x**4
     # pyre-fixme[9]: Expected error.
-    y_error: Tensor[torch.float32, L[2], L[3], L[99]] = x ** 4
+    y_error: Tensor[torch.float32, L[2], L[3], L[99]] = x**4


 def test_item() -> None:
14 changes: 8 additions & 6 deletions tests/test_triton_blocksparse.py
@@ -98,7 +98,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=32, H=2, M=512, N=384, K


 @pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu")
-@pytest.mark.parametrize("BLOCK", [32])
+@pytest.mark.parametrize("BLOCK", [32, 128])
 @pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792])
 @pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
 def test_softmax(BLOCK, WIDTH, DTYPE):
@@ -127,12 +127,12 @@ def test_softmax(BLOCK, WIDTH, DTYPE):


 @pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu")
-@pytest.mark.parametrize("block", [32, 43])  # 16, 32,
+@pytest.mark.parametrize("block", [32, 43, 128])  # 16, 32,
 def test_attention_fwd_bwd(
     block,
     input_scale=1.0,
     scale=1 / 8.0,
-    n_ctx=256,
+    n_ctx=384,
     dtype=torch.float16,
     batch_size=2,
     n_heads=2,
@@ -148,16 +148,18 @@ def test_attention_fwd_bwd(
     ]

     def loss_fn(x):
-        return (x ** 2).mean()
+        return (x**2).mean()

     # Triton:
     n_blocks = n_ctx // block
-    layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))
+    layout = torch.tril(
+        torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long), diagonal=-1
+    )
     query, key, value = [x.clone() for x in qkvs]
     query.retain_grad()
     key.retain_grad()
     value.retain_grad()
-    if block not in [16, 32, 64]:
+    if block not in [16, 32, 64, 128]:
         # Check that unsupported dimensions are caught
         with pytest.raises(AssertionError):
             _ = BlockSparseAttention(layout, block)
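
For reference, the layout in the updated test is defined at block granularity, which is why n_ctx only needs to be divisible by the block size, and diagonal=-1 now drops the diagonal blocks so that only strictly lower-triangular blocks remain. A small illustration using the new n_ctx default of 384 with block=128 (one of the parametrized block sizes), reduced to a single head for brevity:

import torch

n_heads, n_ctx, block = 1, 384, 128
n_blocks = n_ctx // block  # 384 // 128 = 3
layout = torch.tril(
    torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long), diagonal=-1
)
print(layout[0])
# tensor([[0, 0, 0],
#         [1, 0, 0],
#         [1, 1, 0]])
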
2 changes: 1 addition & 1 deletion tests/test_triton_fused_linear.py
@@ -38,7 +38,7 @@
     "dtype", [torch.float32]
 )  # Triton use tensor cores, which return slightly different results to pytorch mm
 def test_fused_matmul(shape, dtype):
-    """ Check that the matrix multiply kernel and Pytorch's give the same results"""
+    """Check that the matrix multiply kernel and Pytorch's give the same results"""
     torch.random.manual_seed(0)

     # Raw fused matrix multiply first, to catch gross errors
2 changes: 1 addition & 1 deletion xformers/benchmarks/LRA/run_tasks.py
@@ -235,7 +235,7 @@ def replace(config_dict, k, v):

 def seed_worker(_: int):
     # Make sure that non-pytorch random generators are properly set
-    worker_seed = torch.initial_seed() % 2 ** 32
+    worker_seed = torch.initial_seed() % 2**32
     np.random.seed(worker_seed)
     random.seed(worker_seed)

2 changes: 1 addition & 1 deletion xformers/benchmarks/benchmark_encoder.py
@@ -118,7 +118,7 @@ def _train_for_several_steps(

     if _use_cuda:
         torch.cuda.synchronize()
-        max_memory = torch.cuda.max_memory_allocated() / 2 ** 20
+        max_memory = torch.cuda.max_memory_allocated() / 2**20
     else:
         max_memory = -1
     run_time = time.time() - start_time
4 changes: 2 additions & 2 deletions xformers/benchmarks/benchmark_mem_eff_attention.py
@@ -60,7 +60,7 @@ def ref_attention(q, k, v):
             ).blocked_autorange(min_run_time=min_run_time)
         )
         torch.cuda.synchronize()
-        memory = torch.cuda.max_memory_allocated() / 2 ** 20
+        memory = torch.cuda.max_memory_allocated() / 2**20
         mem_use["optimized"][sub_label] = memory
         memory_str = f"Memory used: {memory} MB"

@@ -83,7 +83,7 @@ def ref_attention(q, k, v):
         )

         torch.cuda.synchronize()
-        memory = torch.cuda.max_memory_allocated() / 2 ** 20
+        memory = torch.cuda.max_memory_allocated() / 2**20
         mem_use["vanilla"][sub_label] = memory
         memory_str = f"Memory used: {memory} MB"
         print("Vanilla", memory_str)
2 changes: 1 addition & 1 deletion xformers/benchmarks/benchmark_nystrom_utils.py
@@ -69,7 +69,7 @@ def iterative_pinv_analysis(
 ):

     for i in range(1, 10):
-        B, M = 1, 2 ** i
+        B, M = 1, 2**i
         a = torch.rand(B, M, M)
         a = torch.softmax(a, dim=-1)

2 changes: 1 addition & 1 deletion xformers/benchmarks/benchmark_pytorch_transformer.py
@@ -170,7 +170,7 @@ def run_training(model, optimizer, label):
         torch.cuda.synchronize()

         train(model, optimizer, label, steps, batch, seq, emb, device)
-        max_memory = torch.cuda.max_memory_allocated() // 2 ** 20
+        max_memory = torch.cuda.max_memory_allocated() // 2**20
         print(f"Peak memory use: {max_memory}MB")

         eval_stop = evaluate(model, batch, seq, emb, device)
2 changes: 1 addition & 1 deletion xformers/benchmarks/utils.py
@@ -26,7 +26,7 @@


 def pretty_print(results, title, units):
-    """ Printout the contents of a dict as a human-readable and Markdown compatible array"""
+    """Printout the contents of a dict as a human-readable and Markdown compatible array"""
     print(title)
     header = " Units: {:<45}".format(units)
     print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))
4 changes: 2 additions & 2 deletions xformers/components/attention/attention_patterns.py
@@ -53,7 +53,7 @@ def random_pattern_from_probability_matrix(dist_matrix, nnz):
     # should work fine if double tensor is passed on CPU. This is a bug that was introduced
     # in https://github.com/pytorch/pytorch/commit/bf04c2ca2f591d98ce57816f0ef0cd20a21bbf66
     # when unifying the checks between CPU and CUDA. For now, just fall-back to numpy
-    if dist_matrix.numel() > 2 ** 24:
+    if dist_matrix.numel() > 2**24:
         dist_matrix = dist_matrix.double()
         dist_matrix /= dist_matrix.sum()
         idxs = np.random.choice(
@@ -227,7 +227,7 @@ def get_slopes(n: int):
     def get_slopes_power_of_2(n: int) -> List[float]:
         start = 2 ** (-(2 ** -(math.log2(n) - 3)))
         ratio = start
-        return [start * ratio ** i for i in range(n)]
+        return [start * ratio**i for i in range(n)]

     # In the paper, we only train models that have 2^a heads for some a. This function has
     # some good properties that only occur when the input is a power of 2. To maintain that even
63 changes: 34 additions & 29 deletions xformers/components/attention/blocksparse.py
@@ -82,7 +82,8 @@ def __init__(
             16,
             32,
             64,
-        ), "Only block sizes in [16, 32, 64] are supported"
+            128,
+        ), "Only block sizes in [16, 32, 64, 128] are supported"

         super().__init__()

@@ -112,6 +113,32 @@ def update_mask_type(self, mask: torch.Tensor):
             )
         mask = bool_mask_to_additive(mask)

+    def create_triton_kernels(self, device):
+        # blocksparse operators
+        self.sparse_dot_sdd = blocksparse_matmul(
+            self.layout,
+            self.block_size,
+            "sdd",
+            trans_a=False,
+            trans_b=True,
+            device=device,
+        )
+
+        self.sparse_dot_dsd = blocksparse_matmul(
+            self.layout,
+            self.block_size,
+            "dsd",
+            trans_a=False,
+            trans_b=False,
+            device=device,
+        )
+
+        self.sparse_softmax = blocksparse_softmax(
+            self.layout,
+            self.block_size,
+            device=device,
+        )
+
     def forward(
         self,
         q: torch.Tensor,
@@ -132,31 +159,9 @@ def forward(
         """

         # Delayed triton init, to make sure that we get the right device
+        # Infer device from query
         if not hasattr(self, "sparse_dot_sdd"):
-            # blocksparse operators
-            self.sparse_dot_sdd = blocksparse_matmul(
-                self.layout,
-                self.block_size,
-                "sdd",
-                trans_a=False,
-                trans_b=True,
-                device=q.device,
-            )
-
-            self.sparse_dot_dsd = blocksparse_matmul(
-                self.layout,
-                self.block_size,
-                "dsd",
-                trans_a=False,
-                trans_b=False,
-                device=q.device,
-            )
-
-            self.sparse_softmax = blocksparse_softmax(
-                self.layout,
-                self.block_size,
-                device=q.device,
-            )
+            self.create_triton_kernels(q.device)

         assert (
             q.shape[-2] == k.shape[-2]
@@ -169,10 +174,10 @@ def forward(
             k.shape[-2] == self.layout.shape[-2] * self.block_size
         ), "Actual sequence size and layout are inconsistent"

-        assert math.log(
-            q.shape[-2], 2
-        ).is_integer(), (
-            "For now blocksparse only works on power-of-two sequence lengths"
+        assert (
+            q.shape[-2] % self.block_size
+        ) == 0, "Sequence length {} must be a multiple of block size {}".format(
+            q.shape[-2], self.block_size
         )

         # Blocksparse only works on fp16
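
The relaxed sequence-length check can be sanity-checked without a GPU; the two expressions below mirror the old and new assertions from the hunk above:

import math

block_size, seq_len = 128, 384

# Old constraint: the sequence length had to be a power of two, so 384 was rejected.
old_ok = math.log(seq_len, 2).is_integer()  # False
# New constraint: the sequence length must be a multiple of the block size, so 384 is accepted.
new_ok = (seq_len % block_size) == 0  # True
print(old_ok, new_ok)
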
4 changes: 2 additions & 2 deletions xformers/components/attention/compositional.py
@@ -152,8 +152,8 @@ def __init__(
             self.value_dim * num_rules == dim_attn
         ), "value_dim must be divisible by num_rules"

-        self.scaling = self.dim_head ** -0.5
-        self.scaling_values = self.dim_selection ** -0.5
+        self.scaling = self.dim_head**-0.5
+        self.scaling_values = self.dim_selection**-0.5

         self.out_proj = nn.Linear(self.num_heads * self.value_dim, dim_model, bias=bias)

2 changes: 1 addition & 1 deletion xformers/components/residual.py
@@ -154,7 +154,7 @@ def get_deepnorm_coefficients(
     else:
         # Encoder/decoder
         encoder_coeffs = DeepNormCoefficients(
-            alpha=0.81 * ((N ** 4) * M) ** 0.0625, beta=0.87 * ((N ** 4) * M) ** -0.0625
+            alpha=0.81 * ((N**4) * M) ** 0.0625, beta=0.87 * ((N**4) * M) ** -0.0625
         )

         decoder_coeffs = DeepNormCoefficients(
2 changes: 1 addition & 1 deletion xformers/utils.py
@@ -91,7 +91,7 @@ def rmf(filename: str) -> None:

 @contextlib.contextmanager
 def temp_files_ctx(num: int) -> Generator:
-    """ A context to get tempfiles and ensure they are cleaned up. """
+    """A context to get tempfiles and ensure they are cleaned up."""
     files = [tempfile.mkstemp()[1] for _ in range(num)]

     yield tuple(files)
