Memory usage: native PyTorch vs. "full"-Attention #68
Comments
Hi,

Could you provide a simple bench script? I am using the following, and on my RTX 2060 S the timings are practically identical, as is the memory.

```python
import torch

from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.masking import TriangularCausalMask

if __name__ == "__main__":
    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)

    # fast-transformers expects (batch, seq, feat); native PyTorch (seq, batch, feat)
    x1 = torch.rand(32, 512, 256).cuda()
    x2 = torch.rand(512, 32, 256).cuda()
    attn_mask1 = TriangularCausalMask(512, device="cuda")
    # diagonal=1 so each position can still attend to itself
    # (True entries are *disallowed* in torch.nn.Transformer masks)
    attn_mask2 = torch.triu(torch.ones(512, 512), diagonal=1).bool().cuda()

    # native PyTorch encoder
    transformer = torch.nn.TransformerEncoder(
        torch.nn.TransformerEncoderLayer(256, 4, dim_feedforward=1024),
        4,
        torch.nn.LayerNorm(256)
    ).cuda()
    transformer(x2, mask=attn_mask2).sum().backward()  # warm-up
    start.record()
    for i in range(10):
        transformer(x2, mask=attn_mask2).sum().backward()
    stop.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(stop))

    # fast-transformers encoder
    transformer = TransformerEncoderBuilder.from_kwargs(
        n_layers=4,
        n_heads=4,
        query_dimensions=64,
        value_dimensions=64,
        feed_forward_dimensions=1024
    ).get().cuda()
    transformer(x1, attn_mask=attn_mask1).sum().backward()  # warm-up
    start.record()
    for i in range(10):
        transformer(x1, attn_mask=attn_mask1).sum().backward()
    stop.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(stop))
```

The outcome is

Cheers,
Hey, I ran the code above on two setups, 4 times each. Let's call it

GPU: RTX 2070

GPU: RTX 2080 Ti
Python: 3.7.9
Sure, feel free to share a benchmark. Maybe a naive question: have you ensured that the ordering of the input sequence dimensions is correct? In short, the ordering for native PyTorch is (sequence, batch, features), while fast-transformers expects (batch, sequence, features).

Cheers,
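For illustration, a minimal sketch of the layout difference (the tensor names are mine):

```python
import torch

x_fast = torch.rand(32, 512, 256)    # fast-transformers layout: (batch, seq, feat)
x_native = x_fast.transpose(0, 1)    # nn.TransformerEncoder layout: (seq, batch, feat)
```

Newer PyTorch releases also accept `batch_first=True` in `nn.TransformerEncoderLayer`, which removes the need for the transpose.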
I suppose my dimensions are correct (it somehow feels like a strange decision by PyTorch to put the sequence dimension first). I did not write my own benchmark script; I just slightly modified yours to obtain behaviour similar to what I experience with my model. Here is the code. Let's name it
As before, I ran the script 4 times on two different devices.
We see that the memory allocation did not change between these devices, only the computation time, which remains proportional between them. Furthermore, neither the memory allocation nor the runtime varied much within the same configuration, so from now on I will report only one run, performed on the RTX 2070. For the next test, I use two different parameter configurations and vary the number of layers in each run. The first parameter set is exactly the one you proposed in
Results of
Results of
To me, it looks like the first configuration produces similar results for both implementations, but with the changed parameter set the memory consumption as well as the runtime diverge from each other. I suppose the key factor is the much longer sequence length. This seems like an interesting observation, since an efficient implementation will mostly be tested on longer sequences (2k, 4k, 8k, ...) rather than very short sequences, which a vanilla transformer can already handle.
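A back-of-the-envelope check of the sequence-length hypothesis (my own numbers): a full attention matrix holds L² entries per head, so in float32 a single head needs 512² × 4 B = 1 MiB at L = 512, but 4096² × 4 B = 64 MiB at L = 4096. Any per-matrix overhead is therefore 64× larger in the long-sequence configuration.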
Awesome, thanks! It appears that the culprit is batch size = 1. I will look into it; it shouldn't be too hard to equalize the performance. As an aside, you could always try attention_type="linear" or attention_type="improved-clustered" to get a really significant speed boost at those sequence lengths.

Cheers,
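For reference, swapping the attention type only touches the builder call; a sketch based on the builder arguments already used in this thread (if a causal mask is required, the docs also list a causal variant of linear attention):

```python
from fast_transformers.builders import TransformerEncoderBuilder

transformer = TransformerEncoderBuilder.from_kwargs(
    n_layers=4,
    n_heads=4,
    query_dimensions=64,
    value_dimensions=64,
    feed_forward_dimensions=1024,
    attention_type="linear",  # or "improved-clustered", "full", ...
).get()
```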
I am already testing some of the implementations in my use case (3D shape generation), but it seems like the "linear" attention lags behind the vanilla full scaled dot-product attention in terms of convergence speed and converged results. The "improved-clustered" attention is queued next. My main goal in using your library was to easily test different implementations and swap attentions without modifying my code base or adding too many dependencies, to keep the complexity of my project down. Thank you for the big and nice code base. I also saw that you often use einsum in your project; maybe the opt_einsum lib could provide some more speedup, as suggested by PyTorch.

P.S.
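To sketch the opt_einsum suggestion above (my own example; the contraction string mirrors a typical attention-score computation, not necessarily the library's exact one):

```python
import torch
from opt_einsum import contract  # pip install opt-einsum

# (N, L, H, E) queries x (N, S, H, E) keys -> (N, H, L, S) scores
q = torch.rand(2, 128, 4, 64)
k = torch.rand(2, 128, 4, 64)

scores_torch = torch.einsum("nlhe,nshe->nhls", q, k)
scores_opt = contract("nlhe,nshe->nhls", q, k)  # drop-in replacement

assert torch.allclose(scores_torch, scores_opt, atol=1e-5)
```

Worth noting that opt_einsum's main win is optimizing the contraction order for three or more operands; for the two-operand einsums typical of attention, the gain may be small.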
Hi,

Sorry for the late reply, I just found some time to spend on this. It is going to be funny, but what is going on is actually that the default PyTorch layer has a single dropout parameter, while we have different parameters for the transformer layers and the attention layer. Simply put, if you set attention_dropout=0.0, the memory usage should match. The extra memory given the attention dropout makes perfect sense, since we need to keep the attention matrix in memory twice. Let me know if you are still experiencing problems. I will also add the test script for completeness, and I will close the issue in a few days if there is no change.

```python
import argparse

import torch

from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.masking import TriangularCausalMask

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--length", type=int, default=4096)
    parser.add_argument("--n_layers", type=int, default=4)
    parser.add_argument("--n_heads", type=int, default=4)
    parser.add_argument("--dim_embedding", type=int, default=64)
    parser.add_argument("--fast", action="store_true")
    args = parser.parse_args()

    NUM_LAYERS = args.n_layers
    DIM_EMBEDDING = args.dim_embedding
    NUM_HEADS = args.n_heads
    DIM_FF = 4 * DIM_EMBEDDING
    LEN_INPUT = args.length
    BATCH_SIZE = args.batch_size
    FAST_TRANSFORMERS = args.fast

    if FAST_TRANSFORMERS:
        # fast-transformers
        x = torch.rand(BATCH_SIZE, LEN_INPUT, DIM_EMBEDDING).cuda()
        attn_mask = TriangularCausalMask(LEN_INPUT, device="cuda")
        transformer = TransformerEncoderBuilder.from_kwargs(
            n_layers=NUM_LAYERS,
            n_heads=NUM_HEADS,
            query_dimensions=DIM_EMBEDDING // NUM_HEADS,
            value_dimensions=DIM_EMBEDDING // NUM_HEADS,
            feed_forward_dimensions=DIM_FF,
            dropout=0.0,
            attention_dropout=0.0,  # this is the difference <-----------------------
            activation='relu',
        ).get().cuda()
    else:
        # native PyTorch
        x = torch.rand(LEN_INPUT, BATCH_SIZE, DIM_EMBEDDING).cuda()
        # diagonal=1 so each position can still attend to itself
        attn_mask = torch.triu(torch.ones(LEN_INPUT, LEN_INPUT), diagonal=1).bool().cuda()
        transformer = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(DIM_EMBEDDING, NUM_HEADS, DIM_FF, 0.0, 'relu'),
            NUM_LAYERS,
            torch.nn.LayerNorm(DIM_EMBEDDING),
        ).cuda()

    def step(x, attn_mask):
        if FAST_TRANSFORMERS:
            transformer(x, attn_mask=attn_mask).sum().backward()
        else:
            transformer(x, mask=attn_mask).sum().backward()

    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)

    torch.cuda.reset_peak_memory_stats()
    step(x, attn_mask)  # warm-up
    start.record()
    for i in range(10):
        step(x, attn_mask)
    stop.record()
    torch.cuda.synchronize()

    print("max_allocated:", torch.cuda.max_memory_allocated() / 1024**2)
    print("max_reserved:", torch.cuda.max_memory_reserved() / 1024**2)
    print("total_runtime:", start.elapsed_time(stop))
```
Hi Gregor, I am closing the issue. Let me know in case you are still experiencing any problems, or in case you find our transformer implementations slower in any way.

Cheers,
Hi, sorry for the late answer, I was quite busy with another project. (My previous answer was incorrect, as I ran a wrong script, which was quite similar but used different attentions.) I reran

Results of
Results of
Thanks for the clarification! I will need to update my implementation.

P.S.
Hello,

I wanted to leave some of my observations here regarding memory consumption (which is often a critical factor). It might be of some interest for others who want to benchmark their implementation.

The fast-transformers implementation of full self-attention uses around 35% more GPU memory and is slightly slower than the native PyTorch implementation. I would like to note that this holds for my specific setup, and that I ran only a limited number of test runs (4 each), which I report here. I only discovered this because my initial configuration/implementation in PyTorch did fit into memory.

Both models use some embedding beforehand and differ only in the TransformerEncoderLayer / TransformerEncoderBuilder. I did not construct a minimal example, just exchanged the modules in my workflow to test different implementations.
The following numbers belong to this specific configuration:
Architecture: encoder only
Attention mask: causal (upper triangle masked out)
Layer number: 8
Embedding dimension: 64
Number of heads: 4
Feed-forward dimension: 4 * 64
Max sequence length: 4096
Batch size: 1
GPU: single RTX 2080 Ti
Peak memory usage in each run:
native PyTorch: 6152 - 6200 MB
fast-transformers: 8312 - 8454 MB
Computation time per epoch in each run:
native PyTorch: 9min 9s - 9min 33s
fast-transformers: 10min 18s - 10min 48s
The same configuration with 16 layers still fits into the GPU (~11 GB) with native PyTorch, but throws an OOM error with fast-transformers.
I suppose this is not an important issue as long as both implementations provide similar results (I might test that on my specific setup in the next couple of days, too), since the focus of the library lies on the efficient implementations.