In [1]:
import torch
import fast_hadamard_transform.kernel as fast_hadamard_transform
import lsh_cumulation.kernel as lsh_cumulation
import weighted_lsh_cumulation.count_sort.kernel as count_sort
import weighted_lsh_cumulation.kernel as weighted_lsh_cumulation
import time
import math

Using /root/.cache/torch_extensions as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/fast_hadamard_transform_kernel/build.ninja...
Building extension module fast_hadamard_transform_kernel...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module fast_hadamard_transform_kernel...
Using /root/.cache/torch_extensions as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/lsh_cumulation_kernel/build.ninja...
Building extension module lsh_cumulation_kernel...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module lsh_cumulation_kernel...
Using /root/.cache/torch_extensions as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_

In [3]:
for b in [1, 2, 4, 8, 16, 32]:
    B = b * 12
    S = 4096
    H = 128
    D = 64
    num_buckets = S // 2
    hashcode_len = int(math.log2(num_buckets))
    num_part = int(H / (D / hashcode_len))

    Q = torch.randn(B, S, D).cuda().float()
    K = torch.randn(B, S, D).cuda().float()
    V = torch.randn(B, S, D).cuda().float()
    mask = torch.ones(B, S, dtype = torch.int32).cuda()
    Dmat = fast_hadamard_transform.generate_Dmat(B, D, H, hashcode_len, device = Q.device)

    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(100):
        Q_hash = fast_hadamard_transform.fast_hash(mask, Q, Dmat, H, hashcode_len)
        K_hash = fast_hadamard_transform.fast_hash(mask, K, Dmat, H, hashcode_len)
        hashtable = lsh_cumulation.lsh_cumulation(mask, K_hash, V, hashcode_len)
        result = lsh_cumulation.lsh_query(mask, Q_hash, hashtable)
    torch.cuda.synchronize()
    t1 = time.time()

    flops_hashcode = 2 * num_part * (3 * B * S * D + 3 * B * S * D * math.log2(D))
    flops_hashtable = B * S * H * D
    flops_query = B * S * H * D + B * S * D
    flops = flops_hashcode + flops_hashtable + flops_query

    latency = (t1 - t0) / 100
    
    print(round(latency * 1000, 4))

18.3423
36.3394
72.9286


RuntimeError: CUDA out of memory. Tried to allocate 6.00 GiB (GPU 0; 10.76 GiB total capacity; 3.85 GiB already allocated; 5.45 GiB free; 4.50 GiB reserved in total by PyTorch)

In [5]:
for b in [1, 2, 4, 8, 16, 32, 64, 128]:
    B = b * 12
    S = 512
    H = 16
    D = 64
    num_buckets = 256
    hashcode_len = int(math.log2(num_buckets))
    num_part = int(H / (D / hashcode_len))

    flops_hashcode = 2 * num_part * (3 * B * S * D + 3 * B * S * D * math.log2(D))
    flops_hashtable = B * S * H * D
    flops_query = B * S * H * D + B * S * D
    flops = flops_hashcode + flops_hashtable + flops_query

    print(round(flops / 1000000, 4))

46.0063
92.0125
184.0251
368.0502
736.1004
1472.2007
2944.4014
5888.8028


In [4]:
Q_hash = torch.randint(0, num_buckets, size = (B, S, H), dtype = torch.int32).cuda()
K_hash = torch.randint(0, num_buckets, size = (B, S, H), dtype = torch.int32).cuda()
V = torch.rand(B, S, D).cuda()
Q_mask = torch.ones(B, S, dtype = torch.int32).cuda()
K_mask = torch.ones(B, S, dtype = torch.int32).cuda()

torch.cuda.synchronize()
t0 = time.time()
for _ in range(100):
    
torch.cuda.synchronize()
t1 = time.time()


latency = (t1 - t0) / 100
(flops / latency) / (13.45 * (10 ** 12))

0.011042825760408922

In [5]:
torch.cuda.synchronize()
t0 = time.time()
for _ in range(100):
    result = lsh_cumulation.lsh_query(Q_mask, Q_hash, hashtable)
torch.cuda.synchronize()
t1 = time.time()

flops = B * S * H * D + B * S * D
latency = (t1 - t0) / 100
(flops / latency) / (13.45 * (10 ** 12))

0.01388575504705759

In [6]:
Q0 = torch.rand(B, S, D).cuda()
K0 = torch.rand(B, S, D).cuda()
Q1 = torch.rand(B, S, D).cuda()
K1 = torch.rand(B, S, D).cuda()
Q_hash = torch.randint(0, num_buckets, size = (B, S, H), dtype = torch.int32).cuda()
K_hash = torch.randint(0, num_buckets, size = (B, S, H), dtype = torch.int32).cuda()
V = torch.rand(B, S, D).cuda()
Q_mask = torch.ones(B, S, dtype = torch.int32).cuda()
K_mask = torch.ones(B, S, dtype = torch.int32).cuda()

Q1 = torch.nn.functional.normalize(Q1, p = 2, dim = -1)
K1 = torch.nn.functional.normalize(K1, p = 2, dim = -1)

K_sort_info, K_sorted_idxes = count_sort.count_sort(K_mask, K_hash, num_buckets)

nnz = (1 / num_buckets) * 1024 * 1024

torch.cuda.synchronize()
t0 = time.time()
for _ in range(100):
    result = weighted_lsh_cumulation.weighted_lsh_cumulation_sorted_key(
        Q_mask, Q_hash, K_sort_info, K_sorted_idxes, Q0, K0, Q1, K1, V, 1024, 8)
torch.cuda.synchronize()
t1 = time.time()

flops = B * nnz * D + B * nnz * (D - 1) + B * nnz + B * nnz + B * nnz * D + B * nnz * D
latency = (t1 - t0) / 100
(flops / latency) / (13.45 * (10 ** 12))

0.0003793339733467389