In [1]:
import torch
from torch import nn
import math
import numpy as np

# Don't track gradients
# torch.set_grad_enabled(False)

import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


## Node-to-node Version

In [5]:
torch.manual_seed(0);

In [6]:
torch.cuda.reset_max_memory_allocated()

In [7]:
num_nodes = 100000
num_hidden = 256
batch_size = 1
attention_heads = 1
h = torch.randn(batch_size, attention_heads, num_nodes, num_hidden).to("cuda")

In [8]:
linear_q = nn.Linear(num_hidden, num_hidden).to("cuda")
linear_k = nn.Linear(num_hidden, num_hidden).to("cuda")
linear_v = nn.Linear(num_hidden, num_hidden).to("cuda")
softmax = nn.Softmax(dim=-1).to("cuda")
linear_out = nn.Linear(num_hidden, num_hidden).to("cuda")

# apply linear layers to compute query, key, and value
q = linear_q(h)
k = linear_k(h)
v = linear_v(h)

In [9]:
print("GPU memory used: ", torch.cuda.max_memory_allocated()/1024**2, "MB")

GPU memory used:  393.00390625 MB


In [10]:
%%time
# compute dot product of query and key, and apply softmax
dot_product = torch.matmul(q, k.transpose(2, 3))

CPU times: user 8.81 ms, sys: 16 ms, total: 24.8 ms
Wall time: 23.7 ms


In [11]:
print("GPU memory used: ", torch.cuda.max_memory_allocated()/1024**2, "MB")

GPU memory used:  38539.9765625 MB


In [12]:
%%time
# apply dot product of attention weights and value
attn_output = torch.matmul(dot_product, v)

CPU times: user 1.39 ms, sys: 0 ns, total: 1.39 ms
Wall time: 634 Âµs


In [13]:
print("GPU memory used: ", torch.cuda.max_memory_allocated()/1024**2, "MB")

GPU memory used:  38637.9765625 MB


In [14]:
attn_output[0][0][0][1]

tensor(-19410.9531, device='cuda:0', grad_fn=<SelectBackward0>)

## Feature-to-feature Version

In [2]:
torch.manual_seed(0);

In [3]:
num_nodes = 100000
num_hidden = 256
batch_size = 1
attention_heads = 1
h = torch.randn(batch_size, attention_heads, num_nodes, num_hidden).to("cuda")

In [4]:
linear_q = nn.Linear(num_hidden, num_hidden).to("cuda")
linear_k = nn.Linear(num_hidden, num_hidden).to("cuda")
linear_v = nn.Linear(num_hidden, num_hidden).to("cuda")
softmax = nn.Softmax(dim=-1).to("cuda")
linear_out = nn.Linear(num_hidden, num_hidden).to("cuda")

# apply linear layers to compute query, key, and value
q = linear_q(h)
k = linear_k(h)
v = linear_v(h)

In [5]:
print("GPU memory used: ", torch.cuda.max_memory_allocated()/1024**2, "MB")

GPU memory used:  393.00390625 MB


In [11]:
%%time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
# compute dot product of query and key, and apply softmax
dot_product = torch.matmul(k.transpose(2, 3), v)
end.record()
torch.cuda.synchronize()
print("GPU memory used: ", torch.cuda.max_memory_allocated()/1024**2, "MB")
print("Time taken: ", start.elapsed_time(end), "ms")

GPU memory used:  423.771484375 MB
Time taken:  1.656831979751587 ms
CPU times: user 2.53 ms, sys: 0 ns, total: 2.53 ms
Wall time: 1.99 ms


In [11]:
%%time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
# apply dot product of attention weights and value
attn_output = torch.matmul(q, dot_product)
end.record()
torch.cuda.synchronize()
print("GPU memory used: ", torch.cuda.max_memory_allocated()/1024**2, "MB")
print("Time taken: ", start.elapsed_time(end), "ms")

GPU memory used:  589.25390625 MB
Time taken:  1.7776639461517334 ms
CPU times: user 2.69 ms, sys: 0 ns, total: 2.69 ms
Wall time: 2.09 ms


In [6]:
%%time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
# apply dot product of attention weights and value
dot_product = torch.matmul(k.transpose(2, 3), v)
attn_output = torch.matmul(q, dot_product)
end.record()
torch.cuda.synchronize()
print("GPU memory used: ", torch.cuda.max_memory_allocated()/1024**2, "MB")
print("Time taken: ", start.elapsed_time(end), "ms")

GPU memory used:  491.25390625 MB
Time taken:  3.200000047683716 ms
CPU times: user 4.3 ms, sys: 0 ns, total: 4.3 ms
Wall time: 3.48 ms


In [7]:
attn_output[0][0][0][1]

tensor(-19410.9375, device='cuda:0', grad_fn=<SelectBackward0>)

## End-to-End Version

In [2]:
torch.manual_seed(0);

In [3]:
num_nodes = 300000
num_hidden = 256
batch_size = 1
attention_heads = 1
h = torch.randn(batch_size, attention_heads, num_nodes, num_hidden).to("cuda")

In [4]:
linear_q = nn.Linear(num_hidden, num_hidden).to("cuda")
linear_k = nn.Linear(num_hidden, num_hidden).to("cuda")
linear_v = nn.Linear(num_hidden, num_hidden).to("cuda")
softmax = nn.Softmax(dim=-1).to("cuda")
linear_out = nn.Linear(num_hidden, num_hidden).to("cuda")

# apply linear layers to compute query, key, and value
q = linear_q(h)
k = linear_k(h)
v = linear_v(h)

In [5]:
print("GPU memory used: ", torch.cuda.max_memory_allocated()/1024**2, "MB")

GPU memory used:  1172.87890625 MB


In [11]:
%%time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
# compute dot product of query and key, and apply softmax
attn_output = torch.matmul(torch.matmul(q, k.transpose(2, 3)), v)
end.record()
torch.cuda.synchronize()
print("GPU memory used: ", torch.cuda.max_memory_allocated()/1024**2, "MB")
print("Time taken: ", start.elapsed_time(end), "ms")

GPU memory used:  39070.4951171875 MB
Time taken:  708.6981201171875 ms
CPU times: user 699 ms, sys: 11.4 ms, total: 710 ms
Wall time: 709 ms


In [6]:
%%time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
# apply dot product of attention weights and value
attn_output = torch.matmul(q, torch.matmul(k.transpose(2, 3), v))
end.record()
torch.cuda.synchronize()
print("GPU memory used: ", torch.cuda.max_memory_allocated()/1024**2, "MB")
print("Time taken: ", start.elapsed_time(end), "ms")

GPU memory used:  491.25390625 MB
Time taken:  3.3904640674591064 ms
CPU times: user 3.94 ms, sys: 0 ns, total: 3.94 ms
Wall time: 3.69 ms


In [6]:
%%time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

with torch.no_grad():
    start.record()
    # apply dot product of attention weights and value
    attn_output = torch.matmul(q, torch.matmul(k.transpose(2, 3), v))
    end.record()
    torch.cuda.synchronize()
print("GPU memory used: ", torch.cuda.max_memory_allocated()/1024**2, "MB")
print("Time taken: ", start.elapsed_time(end), "ms")

GPU memory used:  1466.09765625 MB
Time taken:  8.452095985412598 ms
CPU times: user 7.33 ms, sys: 2.98 ms, total: 10.3 ms
Wall time: 8.84 ms


In [7]:
attn_output[0][0][0][1]

tensor(-19410.9375, device='cuda:0', grad_fn=<SelectBackward0>)

## Sparse Version

In [2]:
torch.manual_seed(0);

In [3]:
num_nodes = 100000
edge_density = 10
num_edges = num_nodes * edge_density
num_hidden = 256
batch_size = 1
attention_heads = 1
h = torch.randn(batch_size, attention_heads, num_nodes, num_hidden).to("cuda")

In [4]:
linear_q = nn.Linear(num_hidden, num_hidden).to("cuda")
linear_k = nn.Linear(num_hidden, num_hidden).to("cuda")
linear_v = nn.Linear(num_hidden, num_hidden).to("cuda")
softmax = nn.Softmax(dim=-1).to("cuda")
linear_out = nn.Linear(num_hidden, num_hidden).to("cuda")

# apply linear layers to compute query, key, and value
q = linear_q(h)
k = linear_k(h)
v = linear_v(h)

In [5]:
print("GPU memory used: ", torch.cuda.max_memory_allocated()/1024**2, "MB")

GPU memory used:  393.00390625 MB


In [6]:
# Sparse version of the attention mechanism
adjacency = torch.randint(0, num_nodes, (2, num_edges))

# Convert to sparse torch tensor
sp_adjacency = torch.sparse_coo_tensor(adjacency, torch.ones(num_edges), (num_nodes, num_nodes)).to("cuda")

In [7]:
dot_product = torch.sparse.mm(sp_adjacency, k.transpose(2, 3).squeeze())

RuntimeError: CUDA out of memory. Tried to allocate 37.25 GiB (GPU 0; 39.45 GiB total capacity; 413.00 MiB already allocated; 4.41 GiB free; 414.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [29]:
sp_adjacency.shape

torch.Size([10000, 10000])

In [30]:
q.shape

torch.Size([1, 1, 10000, 10000])

In [28]:
torch.sparse.mm(sp_adjacency, q)

RuntimeError: The expanded size of the tensor (1) must match the existing size (10000) at non-singleton dimension 1.  Target sizes: [10000, 1].  Tensor sizes: [10000, 10000]

In [6]:
%%time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

with torch.no_grad():
    start.record()
    dot_product = torch.matmul(k.transpose(2, 3), v)
    attn_output = torch.matmul(q, dot_product)
    end.record()
    torch.cuda.synchronize()
print("GPU memory used: ", torch.cuda.max_memory_allocated()/1024**2, "MB")
print("Time taken: ", start.elapsed_time(end), "ms")

GPU memory used:  3820.154296875 MB
Time taken:  233.67987060546875 ms
CPU times: user 236 ms, sys: 0 ns, total: 236 ms
Wall time: 234 ms
