In [2]:
import sys
import time

import torch

sys.path.append('../cs336-basics')
from cs336_basics.model import scaled_dot_product_attention

In [3]:
def create_qkv(batch_size: int, seq_length: int, d_model: int, device: str):
    """Build random query/key/value tensors for the attention benchmark.

    Each tensor has shape (batch_size, seq_length, d_model), is drawn from
    ``torch.randn`` on ``device`` and is created with ``requires_grad=True`` so
    the backward pass can be exercised. The three draws happen in Q, K, V order.

    Returns:
        The tuple ``(Q, K, V)``.
    """
    shape = (batch_size, seq_length, d_model)
    return tuple(
        torch.randn(*shape, requires_grad=True, device=device)
        for _ in range(3)  # one draw each for Q, K, V
    )

In [4]:
def create_mask(seq_length: int, device: str):
    """Build a causal attention mask.

    Entry ``[0, i, j]`` is True exactly when key position ``j <= i``, i.e. query
    ``i`` may attend to key ``j``. The result is a bool tensor of shape
    (1, seq_length, seq_length); the leading singleton dim broadcasts over batch.
    """
    positions = torch.arange(seq_length, device=device)
    allowed = positions.unsqueeze(1) >= positions.unsqueeze(0)  # (query, key)
    return allowed.unsqueeze(0)

In [5]:
def get_current_allocated_memory(is_gpu: bool):
    """Return the bytes currently occupied by tensors on the active accelerator.

    ``is_gpu=True`` queries ``torch.cuda.memory_allocated``; otherwise the Apple
    MPS backend is assumed and ``torch.mps.current_allocated_memory`` is queried.
    """
    query = torch.cuda.memory_allocated if is_gpu else torch.mps.current_allocated_memory
    return query()

In [None]:
def empty_cache(is_gpu: bool) -> None:
    """Release unused cached allocator memory on CUDA (``is_gpu=True``) or MPS."""
    backend = torch.cuda if is_gpu else torch.mps
    backend.empty_cache()

In [None]:
# Benchmark the reference attention implementation: forward/backward wall time
# and allocated memory over a grid of model widths and sequence lengths.
# (`torch` and `time` are imported in the import cell at the top.)
d_model_list = [16, 32, 64, 128]
seq_length_list = [256, 1024, 4096, 8192, 16384]
batch_size = 8
n_warmup_steps = 10    # untimed iterations before measurement
n_timed_steps = 100    # timed iterations per configuration
is_gpu = torch.cuda.is_available()
device = "cuda" if is_gpu else "mps"  # non-CUDA runs assume Apple MPS, as the helpers do


def _synchronize() -> None:
    """Block until all queued work on the active device (CUDA or MPS) has finished.

    Both backends execute asynchronously, so wall-clock timings are only
    meaningful after an explicit synchronize.
    """
    if is_gpu:
        torch.cuda.synchronize()
    else:
        torch.mps.synchronize()


for d_model in d_model_list:
    for seq_length in seq_length_list:
        # Drop the previous configuration's tensors (and any autograd graph still
        # referenced after an OOM) *before* emptying the cache; otherwise that
        # memory cannot be released and leaks into this configuration.
        Q = K = V = mask = out = loss = None
        empty_cache(is_gpu)

        forward_total_time = 0.0
        backward_total_time = 0.0
        forward_max_memory = 0
        backward_max_memory = 0

        try:
            # Allocation happens inside the try so an OOM here is reported, not fatal.
            Q, K, V = create_qkv(batch_size, seq_length, d_model, device)
            mask = create_mask(seq_length, device)

            for _ in range(n_warmup_steps):  # warm-up only; not timed
                out = scaled_dot_product_attention(Q, K, V, mask)
                out.sum().backward()
            _synchronize()  # keep queued warm-up work out of the first timed step

            for _ in range(n_timed_steps):  # timing loop
                start = time.perf_counter()
                out = scaled_dot_product_attention(Q, K, V, mask)  # forward
                _synchronize()
                forward_total_time += time.perf_counter() - start
                # Memory held here includes the activations saved for backward.
                forward_max_memory = max(forward_max_memory, get_current_allocated_memory(is_gpu))
                loss = out.sum()

                start = time.perf_counter()
                loss.backward()  # backward
                _synchronize()
                backward_total_time += time.perf_counter() - start
                # Sampled after backward has freed the graph: this is the memory that
                # remains (inputs, grads, output), not the peak during backward.
                backward_max_memory = max(backward_max_memory, get_current_allocated_memory(is_gpu))

            print(f"d_model: {d_model}, seq_length: {seq_length}, forward_total_time: {forward_total_time:.3f} sec, backward_total_time: {backward_total_time:.3f} sec, forward_max_memory: {forward_max_memory/1e6:.1f} MB, backward_max_memory: {backward_max_memory/1e6:.1f} MB")
        except RuntimeError as e:
            # torch.cuda.OutOfMemoryError subclasses RuntimeError, and MPS reports OOM
            # as a plain RuntimeError ("MPS backend out of memory"). Any other
            # error is a real bug and must not be mislabelled as OOM.
            if "out of memory" not in str(e):
                raise
            print(f"d_model: {d_model}, seq_length: {seq_length}, exception: {device.upper()} out of memory")

d_model: 16, seq_length: 256, forward_total_time: 0.034 sec, backward_total_time: 0.080 sec, forward_max_memory: 5.2 MB, backward_max_memory: 1.0 MB
d_model: 16, seq_length: 1024, forward_total_time: 0.121 sec, backward_total_time: 0.338 sec, forward_max_memory: 72.0 MB, backward_max_memory: 4.7 MB
d_model: 16, seq_length: 4096, forward_total_time: 1.890 sec, backward_total_time: 6.618 sec, forward_max_memory: 1185.6 MB, backward_max_memory: 111.2 MB
