In [1]:
import torch
import time

In [9]:
model_size = 4096
seq_len = 10000
batch = 4

# Pure loading time

In [10]:

Key = torch.randn(batch,seq_len,model_size)
Value = torch.randn(batch,seq_len,model_size)
start_time = time.time()
Key = Key.to(0)
Value = Value.to(0)
torch.cuda.synchronize()
end_time = time.time()
print("Time taken to move to GPU: ", end_time - start_time)
Key = Key.to('cpu')
Value = Value.to('cpu')

Time taken to move to GPU:  0.22755217552185059


# Pure computing time

In [14]:
from transformers.pytorch_utils import Conv1D
input = torch.randn(batch,seq_len, model_size).to(0)
attn = Conv1D(3 * model_size, model_size).to(0)
W_q = torch.randn(model_size, model_size).to(0)
W_k = torch.randn(model_size, model_size).to(0)
W_v = torch.randn(model_size, model_size).to(0)
start_time = time.time()
q = torch.matmul(input[-1], W_q)
k = torch.matmul(input, W_k)
v = torch.matmul(input, W_v)
# q,k,v = attn(input).split(model_size, dim=2)

torch.cuda.synchronize()
end_time = time.time()
print("Time taken for matmul: ", end_time - start_time)

NameError: name 'W_q' is not defined

# Loading and recompute

In [None]:
model_size = 4096
compute_length = 5000
load_length = 5000
s1 = torch.cuda.Stream()
s2 = torch.cuda.Stream()
input = torch.randn(compute_length+load_length, model_size).to(0)
W_q = torch.randn(model_size, model_size).to(0)
W_k = torch.randn(model_size, model_size).to(0)
W_v = torch.randn(model_size, model_size).to(0)

K_cache = torch.randn(load_length, model_size)
V_cache = torch.randn(load_length, model_size)
start_time = time.time()
# compute
with torch.cuda.stream(s1):
    q = torch.matmul(input[-1], W_q)
    k = torch.matmul(input[:compute_length], W_k)
    v = torch.matmul(input[:compute_length], W_v)
with torch.cuda.stream(s2):
    K_cache = K_cache.to(0)
    V_cache = V_cache.to(0)
# torch.cuda.synchronize()
    K_cache = torch.cat([K_cache, k], dim=0)
    V_cache = torch.cat([V_cache, v], dim=0)
torch.cuda.synchronize()
end_time = time.time()
print("Time taken for combination: ", end_time - start_time)


# Generating graph

In [None]:

model_size_list = [2048,4096,8192,16384]
seq_len_list = [2000,4000,6000,8000]
batch_size = 4
time_list_compute = []
for model_size in model_size_list:
    time_list1 = []
    for seq_len in seq_len_list:
        
        input = torch.randn(batch_size,seq_len, model_size).to(0)
        W_q = torch.randn(model_size, model_size).to(0)
        W_k = torch.randn(model_size, model_size).to(0)
        W_v = torch.randn(model_size, model_size).to(0)
        start_time = time.time()
        q = torch.matmul(input[:,-1,:], W_q)
        k = torch.matmul(input, W_k)
        v = torch.matmul(input, W_v)
        torch.cuda.synchronize()
        end_time = time.time()
        time_list1.append(end_time - start_time)
    time_list_compute.append(time_list1)
print(time_list_compute)


In [None]:
import matplotlib.pyplot as plt
fig, ax = plt.subplots()

# Each model_size will be a separate line in the graph
for i, model_size in enumerate(model_size_list):
    ax.plot(seq_len_list, time_list_compute[i], label=f'Model size {model_size}')

# Label the axes and the plot
ax.set_xlabel('Sequence Length')
ax.set_ylabel('Time (seconds)')
ax.set_title('Pure computation of QKV on Titan RTX')
ax.legend()

# Show grid for better readability
plt.grid(True)

# Displaying the graph
plt.show()

In [None]:

model_size_list = [2048,4096,8192,16384]
seq_len_list = [2000,4000,6000,8000]
batch_size = 4
time_list_load = []
for model_size in model_size_list:
    time_list1 = []
    for seq_len in seq_len_list:
        
        Key = torch.randn(batch_size,seq_len,model_size)
        Value = torch.randn(batch_size,seq_len,model_size)
        start_time = time.time()
        Key = Key.to(0)
        Value = Value.to(0)
        torch.cuda.synchronize()
        end_time = time.time()
        time_list1.append(end_time - start_time)
    time_list_load.append(time_list1)
print(time_list_load)


In [None]:
import matplotlib.pyplot as plt
fig, ax = plt.subplots()

# Each model_size will be a separate line in the graph
for i, model_size in enumerate(model_size_list):
    ax.plot(seq_len_list, time_list_load[i], label=f'Model size {model_size}')

# Label the axes and the plot
ax.set_xlabel('Sequence Length')
ax.set_ylabel('Time (seconds)')
ax.set_title('Pure Loading of KV Cache on Titan RTX')
ax.legend()

# Show grid for better readability
plt.grid(True)

# Displaying the graph
plt.show()

In [None]:
difference = []

for i in range(len(model_size_list)):
    difference.append([x/y for x, y in zip(time_list_compute[i], time_list_load[i])])
print(difference)
fig, ax = plt.subplots()
for i, model_size in enumerate(model_size_list):
    ax.plot(seq_len_list, difference[i], label=f'Model size {model_size}')

# Label the axes and the plot
ax.set_xlabel('Sequence Length')
ax.set_ylabel('Compute time / Load time')
ax.set_title('Compute Loading rate on Titan RTX')
ax.legend()

# Show grid for better readability
plt.grid(True)

# Displaying the graph
plt.show()
