In [1]:
import torch

# Seed PyTorch's RNG so the random embeddings below are reproducible.
torch.manual_seed(123)

# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"PyTorch version: {torch.__version__}")
print(f"Running on {device}")

# Dummy input batch of random values with shape
# (batch_size, context_len, embed_dim), created directly on `device`.
batch_size, context_len, embed_dim = 8, 1024, 768
embeddings = torch.randn(batch_size, context_len, embed_dim, device=device)

PyTorch version: 2.2.2
Running on cpu


In [4]:
from attentions import MultiHeadAttentionWrapper as Ch03_MHA_Wrapper

# Constructor arguments for the wrapper variant. Note that d_out is
# embed_dim // 12 while num_heads is 12; the forward pass below prints a
# last dimension equal to embed_dim, so d_out is presumably a per-head
# width -- confirm against attentions.py.
wrapper_kwargs = dict(
    d_in=embed_dim,
    d_out=embed_dim // 12,
    context_length=context_len,
    dropout=0.0,
    num_heads=12,
    qkv_bias=False,
)

mha_ch03_wrapper = Ch03_MHA_Wrapper(**wrapper_kwargs)
mha_ch03_wrapper = mha_ch03_wrapper.to(device)

# One forward pass as a smoke test; show the output shape.
out = mha_ch03_wrapper(embeddings)
print(out.shape)

torch.Size([8, 1024, 768])


In [5]:
from attentions import MultiHeadAttention as Ch03_MHA

# Constructor arguments for the single-module variant. Here d_out is the
# full embed_dim (not embed_dim // 12 as in the wrapper cell); how it is
# divided among the 12 heads is up to the class -- see attentions.py.
mha_kwargs = dict(
    d_in=embed_dim,
    d_out=embed_dim,
    context_length=context_len,
    dropout=0.0,
    num_heads=12,
    qkv_bias=False,
)

mha_ch03 = Ch03_MHA(**mha_kwargs)
mha_ch03 = mha_ch03.to(device)

# One forward pass as a smoke test; show the output shape.
out = mha_ch03(embeddings)
print(out.shape)

torch.Size([8, 1024, 768])


### Quick speed comparison

In [6]:
# Time a forward pass of the wrapper variant with IPython's %timeit magic.
%timeit mha_ch03_wrapper(embeddings)

507 ms ± 25.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
# Time a forward pass of the single-module variant with IPython's %timeit magic.
%timeit mha_ch03(embeddings)

527 ms ± 34.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


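- The second implementation (`MultiHeadAttention`) should be faster on a GPU; in the CPU run above, the two timings (507 ms vs. 527 ms) are within each other's standard deviation.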

### Speed comparison

In [11]:
def time_pytorch_function(func, *input, num_repeats = 1_000):
    """Return the mean duration of one ``func(*input)`` call in milliseconds.

    ``func`` is called 5 times as a warmup and then ``num_repeats`` times
    inside the timed region.

    When CUDA is available the timed region is bracketed by two
    ``torch.cuda.Event`` objects, because CUDA kernels are launched
    asynchronously and a host-side clock alone would not capture their
    run time. When CUDA is not available the function falls back to
    ``time.perf_counter`` so it can also be used on CPU-only machines.

    Args:
        func: Callable to benchmark.
        *input: Positional arguments forwarded to ``func`` on every call.
        num_repeats: Number of timed calls; must be at least 1.

    Returns:
        Average time per call, in milliseconds (float).

    Raises:
        ValueError: If ``num_repeats`` is smaller than 1.
    """
    if num_repeats < 1:
        # Fail before doing any work instead of dividing by zero at the end.
        raise ValueError(f"num_repeats must be at least 1, got {num_repeats}")

    # Warmup
    for _ in range(5):
        func(*input)

    if not torch.cuda.is_available():
        # No GPU: calls run synchronously, so a host clock is accurate.
        import time

        t0 = time.perf_counter()
        for _ in range(num_repeats):
            func(*input)
        return (time.perf_counter() - t0) * 1_000 / num_repeats

    # CUDA IS ASYNC so can't use python time module
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Drain the warmup work so it is not counted in the timed region.
    torch.cuda.synchronize()

    start.record()
    for _ in range(num_repeats):
        # No per-iteration synchronize: it would stall the host on every
        # call and add idle gaps to the interval between the two events.
        func(*input)
    end.record()
    # Wait until `end` has actually been reached before reading the timer.
    torch.cuda.synchronize()
    return start.elapsed_time(end) / num_repeats

In [None]:
import matplotlib.pyplot as plt


# Keep the input on the same device the modules were moved to (`device`)
# instead of hard-coding "cuda". On a GPU machine this is the same tensor as
# before. The `embeddings_cuda` name is kept so any cell that refers to it
# keeps working.
embeddings_cuda = embeddings.to(device)

functions = {
    "1) MHA wrapper class": mha_ch03_wrapper,
    "2) MHA Ch03": mha_ch03,
}

# 1,000 repeats is fine on a GPU; off-GPU a single forward pass took about
# half a second in the %timeit cells, so use far fewer repeats there.
timing_repeats = 1_000 if device.type == "cuda" else 10
execution_times = [
    time_pytorch_function(fn, embeddings_cuda, num_repeats=timing_repeats)
    for fn in functions.values()
]


# Plotting

# Dark-mode styling. This updates the global rcParams, so it also applies to
# figures created after this cell.
plt.rcParams.update({
    'figure.facecolor': '#121212',  # Dark figure background
    'axes.facecolor': '#121212',    # Dark axes background
    'axes.edgecolor': 'white',      # White axes border
    'axes.labelcolor': 'white',     # White labels
    'text.color': 'white',          # White text
    'xtick.color': 'white',         # White x ticks
    'ytick.color': 'white',         # White y ticks
    'grid.color': '#444444',        # Lighter grid lines for contrast
    'lines.linewidth': 2,           # Thicker plot lines for visibility
    'lines.markersize': 8,          # Larger markers for visibility
})

# Draw everything through the explicit Axes object rather than the implicit
# pyplot "current axes".
fig, ax = plt.subplots()
bars = ax.bar(list(functions), execution_times)

ax.set_ylabel('Execution time (ms)')
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

# Leave 20% headroom above the tallest bar so its label fits inside the axes.
max_execution_time = max(execution_times)
upper_ylim = max_execution_time + 0.2 * max_execution_time

ax.set_ylim(0, upper_ylim)

# Annotate bars with execution times
for bar in bars:
    yval = bar.get_height()
    ax.text(
        bar.get_x() + bar.get_width() / 2,  # horizontally centred on the bar
        yval + (0.05 * upper_ylim),         # slightly above the bar top
        round(yval, 2),
        ha='center',
        va='bottom',
    )


fig.tight_layout()
#plt.savefig("comparison.pdf")
plt.show()