In [1]:
import torch
from supervoice_valle import Attend
from torch.nn.attention import SDPBackend, sdpa_kernel
import time

In [2]:
a_pt = Attend(engine = "torch", heads = 32).to("cuda").eval()
a_nt = Attend(engine = "direct", heads = 32).to("cuda").eval()
a_xt = Attend(engine = "xformers", heads = 32).to("cuda").eval()
a_ft = Attend(engine = "flash", heads = 32).to("cuda").eval()
attentions = [a_pt, a_nt, a_xt, a_ft]

In [3]:
# Source
query = torch.rand(1, 32, 32, 16, dtype=torch.float16, device="cuda")
key = torch.rand(1, 32, 32, 16, dtype=torch.float16, device="cuda")
value = torch.rand(1, 32, 32, 16, dtype=torch.float16, device="cuda")
lengths = [4, 8, 8, 12]

print("Without padding")
source = a_pt(query, key, value)
for a in attentions:
    dest = a(query, key, value)
    print(a.engine, (dest - source).abs().max().item())

print("With padding")
source = a_pt(query, key, value, lenghts = lengths)
for a in attentions:
    dest = a(query, key, value, lenghts = lengths)
    print(a.engine, (dest - source).abs().max().item())

Without padding
torch 0.0
direct 0.00048828125
xformers 0.0
flash 0.0
With padding
torch 0.0
direct 0.00048828125
xformers 0.0
flash 0.0


In [4]:
# Benchmark
print("Without padding")
for a in attentions:
    start = time.time()
    for i in range(100000):
        a(query, key, value)
    print(a.engine, time.time() - start)

print("With padding")
for a in attentions:
    start = time.time()
    for i in range(100000):
        a(query, key, value, lenghts = lengths)
    print(a.engine, time.time() - start)

Without padding
torch 4.43998122215271
direct 10.454708576202393
xformers 8.981621742248535
flash 3.2445619106292725
With padding
torch 16.73130464553833
direct 25.77487087249756
xformers 16.095849990844727
flash 15.358261585235596
