Let's look at computational time complexity via a GPU.

In [1]:
# Pytorch version (to take advantage of the GPU)
import torch
import time

# same parameters as for the CPU
n = 512
d = 512

# use the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

#define the inputs
input_seq = torch.rand(n,d,device=device)

# run the same simulation, but this on the GPU
start_time = time.time()
_ = torch.mm(input_seq, input_seq.t()) # assuming t() is the transpose

at = time.time() - start_time
print(f"Self-attention computation time: {at} seconds")


cuda
Self-attention computation time: 0.20593857765197754 seconds


Now, let's look at the recurrent layer

In [2]:
# simulation of the recurrent layer
start_time = time.time()
hidden_state = torch.zeros(d,device=device)
for i in range(n):
  for j in range(n):
    for k in range(d):
      hidden_state[j] += input_seq[i,j] * hidden_state[k]
      ct = time.time() - start_time
      if ct > at * 10:
        break
rt = time.time() - start_time
print(f"Recurrent layer computation time: {rt} seconds")

Recurrent layer computation time: 10.924416303634644 seconds


Let's calculate the percentage of attention layer in the total time.

In [4]:
# calculate the total
total = at + rt

# calculate the percentage of at
percentage_at = round( (at/total) * 100, 2)

print(f"Percentage of self-attention layer in the sum of self-attention and recurrent is : {percentage_at}")

Percentage of self-attention layer in the sum of self-attention and recurrent is : 1.85


In [5]:
!nvidia-smi

Tue May 21 14:08:58 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   50C    P0              28W /  70W |    153MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    