In [1]:
import torch
import fast_hadamard_transform.kernel as fast_hadamard_transform
import lsh_cumulation.kernel as lsh_cumulation
import weighted_lsh_cumulation.count_sort.kernel as count_sort
import weighted_lsh_cumulation.kernel as weighted_lsh_cumulation
import time
import math

Using /root/.cache/torch_extensions as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/fast_hadamard_transform_kernel/build.ninja...
Building extension module fast_hadamard_transform_kernel...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module fast_hadamard_transform_kernel...
Using /root/.cache/torch_extensions as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/lsh_cumulation_kernel/build.ninja...
Building extension module lsh_cumulation_kernel...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module lsh_cumulation_kernel...
Using /root/.cache/torch_extensions as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_

In [97]:
# LSH-attention latency at S = 4096, printed as running totals (ms / iteration):
#   1. hashing Q and K, 2. LSH cumulation, 3. + dense output projection,
#   4. + Q/K/V input projections.
# Tensors carry B * 12 rows; 12 appears to be the head count (TODO confirm).
with torch.no_grad():
    B = 1
    S = 4096
    H = 16
    D = 64
    num_buckets = S // 2
    hashcode_len = int(math.log2(num_buckets))
    num_part = int(H / (D / hashcode_len))  # unused in this cell
    N_ITERS = 100
    N_WARMUP = 5

    def time_cuda_ms(step, iters = N_ITERS, warmup = N_WARMUP):
        """Call ``step()`` ``warmup`` times untimed, then ``iters`` times timed.

        Returns ``(mean latency in ms, output of the last call)``. The warm-up
        keeps one-time costs (e.g. library handle creation, caching-allocator
        growth) out of the mean. The device is synchronised around the timed
        loop because CUDA kernel launches return before the GPU work finishes.
        """
        out = None
        for _ in range(warmup):
            out = step()
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters):
            out = step()
        torch.cuda.synchronize()
        return (time.perf_counter() - t0) * 1000 / iters, out

    Q = torch.randn(B * 12, S, D).cuda().float()
    K = torch.randn(B * 12, S, D).cuda().float()
    V = torch.randn(B * 12, S, D).cuda().float()
    mask = torch.ones(B * 12, S, dtype = torch.int32).cuda()

    def hash_qk():
        # Dmat is regenerated on every call, so its cost counts toward hashing latency.
        Dmat = fast_hadamard_transform.generate_Dmat(B, D, H, hashcode_len, device = Q.device)
        Q_hash = fast_hadamard_transform.fast_hash(mask, Q, Dmat, H, hashcode_len)
        K_hash = fast_hadamard_transform.fast_hash(mask, K, Dmat, H, hashcode_len)
        return Q_hash, K_hash

    latency_hash, (Q_hash, K_hash) = time_cuda_ms(hash_qk)
    print(round(latency_hash, 2))

    latency_table, result = time_cuda_ms(
        lambda: lsh_cumulation.lsh_cumulation_query(mask, Q_hash, mask, K_hash, V, hashcode_len))
    print(round(latency_table, 2))

    latency_attn = latency_hash + latency_table

    ff = torch.nn.Linear(D * 12, D * 12).cuda()
    X = torch.randn(B, 12, S, D).cuda().float()
    # Merge (B, 12, S, D) -> (B, S, 12 * D), then project. X is never rebound,
    # so every timed iteration sees the same head-split input layout.
    latency_dense, X_dense = time_cuda_ms(lambda: ff(X.transpose(1, 2).reshape(B, S, D * 12)))
    print(round(latency_attn + latency_dense, 2))

    W_q = torch.nn.Linear(D * 12, D * 12).cuda()
    W_k = torch.nn.Linear(D * 12, D * 12).cuda()
    W_v = torch.nn.Linear(D * 12, D * 12).cuda()
    X = torch.randn(B, S, D * 12).cuda().float()

    def project_qkv():
        # Project, then split: (B, S, 12 * D) -> (B, 12, S, D) for each of Q, K, V.
        return tuple(W(X).reshape(B, S, 12, D).transpose(1, 2) for W in (W_q, W_k, W_v))

    latency_prep, (Q, K, V) = time_cuda_ms(project_qkv)
    print(round(latency_prep + latency_attn + latency_dense, 2))

0.9
0.88
2.29
3.62


In [45]:
# Softmax-attention baseline at S = 4096, printed as running totals (ms / iteration):
#   1. attention, 2. + dense output projection, 3. + Q/K/V input projections.
with torch.no_grad():
    B = 1
    S = 4096
    D = 64
    N_ITERS = 100
    N_WARMUP = 5

    def time_cuda_ms(step, iters = N_ITERS, warmup = N_WARMUP):
        """Call ``step()`` ``warmup`` times untimed, then ``iters`` times timed.

        Returns ``(mean latency in ms, output of the last call)``. The warm-up
        keeps one-time costs (e.g. library handle creation, caching-allocator
        growth) out of the mean. The device is synchronised around the timed
        loop because CUDA kernel launches return before the GPU work finishes.
        """
        out = None
        for _ in range(warmup):
            out = step()
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters):
            out = step()
        torch.cuda.synchronize()
        return (time.perf_counter() - t0) * 1000 / iters, out

    Q = torch.randn(B * 12, S, D).cuda().float()
    K = torch.randn(B * 12, S, D).cuda().float()
    V = torch.randn(B * 12, S, D).cuda().float()
    mask = torch.ones(B * 12, S, dtype = torch.int32).cuda()

    def attention_step():
        dot = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(D)
        # Assumes mask == 1 marks positions to keep (as with the all-ones mask fed
        # to the LSH kernels), so only mask == 0 keys get the -1e6 bias. Biasing
        # every logit would also lose float32 precision: floats near 1e6 are
        # spaced 0.0625 apart.
        dot = dot - 1e6 * (1 - mask[:, None, :])
        softmax = torch.nn.functional.softmax(dot, dim = -1)
        return torch.matmul(softmax, V)

    latency_attn, out = time_cuda_ms(attention_step)
    print(round(latency_attn, 2))

    ff = torch.nn.Linear(D * 12, D * 12).cuda()
    X = torch.randn(B, 12, S, D).cuda().float()
    # Merge (B, 12, S, D) -> (B, S, 12 * D), then project. X is never rebound,
    # so every timed iteration sees the same head-split input layout.
    latency_dense, X_dense = time_cuda_ms(lambda: ff(X.transpose(1, 2).reshape(B, S, D * 12)))
    print(round(latency_attn + latency_dense, 2))

    W_q = torch.nn.Linear(D * 12, D * 12).cuda()
    W_k = torch.nn.Linear(D * 12, D * 12).cuda()
    W_v = torch.nn.Linear(D * 12, D * 12).cuda()
    X = torch.randn(B, S, D * 12).cuda().float()

    def project_qkv():
        # Project, then split: (B, S, 12 * D) -> (B, 12, S, D) for each of Q, K, V.
        return tuple(W(X).reshape(B, S, 12, D).transpose(1, 2) for W in (W_q, W_k, W_v))

    latency_prep, (Q, K, V) = time_cuda_ms(project_qkv)
    print(round(latency_prep + latency_attn + latency_dense, 2))

14.52
15.04
16.4


In [66]:
# Latency of the dense output projection alone at S = 4096, for batch sizes 1 and 2
# (ms / iteration).
with torch.no_grad():
    N_ITERS = 100
    N_WARMUP = 5
    n_heads, d_head = 12, 64
    for b in [1, 2]:
        B = b
        S = 4096
        D = n_heads * d_head  # 768

        ff = torch.nn.Linear(D, D).cuda()
        X = torch.randn(B, n_heads, S, d_head).cuda().float()

        # Merge (B, 12, S, 64) -> (B, S, 768), then project. The result goes to Y,
        # not back into X, so every iteration sees the same input layout.
        # Warm-up calls keep one-time costs out of the measured mean.
        for _ in range(N_WARMUP):
            Y = ff(X.transpose(1, 2).reshape(B, S, D))
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(N_ITERS):
            Y = ff(X.transpose(1, 2).reshape(B, S, D))
        # CUDA launches are asynchronous; wait for the GPU before stopping the clock.
        torch.cuda.synchronize()
        t1 = time.perf_counter()

        print(round((t1 - t0) * 1000 / N_ITERS, 2))

0.62
1.22


In [55]:
# Running totals of three hand-copied latencies (ms). Presumably attention, then
# + dense projection, then + Q/K/V projections, as in the benchmark cells; confirm.
(lambda first, second, third: (first, first + second, first + second + third))(0.21, 0.4, 0.12)

(0.21, 0.61, 0.73)

In [63]:
# Softmax-attention latency at S = 512 for batch sizes 1, 2 and 32 (ms / iteration).
with torch.no_grad():
    N_ITERS = 100
    N_WARMUP = 5

    def masked_softmax_attention(Q, K, V, mask):
        """Scaled dot-product attention on (B, 12, S, D) inputs with a (B, S) key mask.

        Assumes mask == 1 marks positions to keep, so only mask == 0 keys get the
        -1e6 bias. Biasing every logit would also lose float32 precision, because
        floats near 1e6 are spaced 0.0625 apart.
        """
        dot = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(Q.shape[-1])
        dot = dot - 1e6 * (1 - mask[:, None, None, :])
        softmax = torch.nn.functional.softmax(dot, dim = -1)
        return torch.matmul(softmax, V)

    for b in [1, 2, 32]:
        B = b
        S = 512
        D = 64

        Q = torch.randn(B, 12, S, D).cuda().float()
        K = torch.randn(B, 12, S, D).cuda().float()
        V = torch.randn(B, 12, S, D).cuda().float()
        mask = torch.ones(B, S, dtype = torch.int32).cuda()

        # Warm-up calls keep one-time costs out of the measured mean.
        for _ in range(N_WARMUP):
            out = masked_softmax_attention(Q, K, V, mask)
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(N_ITERS):
            out = masked_softmax_attention(Q, K, V, mask)
        # CUDA launches are asynchronous; wait for the GPU before stopping the clock.
        torch.cuda.synchronize()
        t1 = time.perf_counter()

        print(round((t1 - t0) * 1000 / N_ITERS, 2))

0.29
0.49
7.74


In [51]:
# Running totals divided by 32. This presumably amortises batch-32 latencies to a
# per-sample cost; the B = 32, S = 512 softmax run above printed 7.74. Confirm.
(lambda first, second, third, batch: (first / batch, (first + second) / batch, (first + second + third) / batch))(7.73, 5.64, 2.09, 32)

(0.2415625, 0.41781250000000003, 0.483125)