Tras haber visto cómo estabilizar la función de atención, voy a comprobar si los einsums la están haciendo mucho más lenta que las operaciones normales de suma y multiplicación (`torch.sum` / `torch.matmul`).

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
# Numerically stable (max-subtracted) attention pooling, einsum implementation.
class MultiHeadAttention(nn.Module):
  """Attention pooling that maps a sequence [B, L, H] to one vector per example [B, H].

  A single linear projection ``w_u`` maps every position to ``num_heads`` vectors of
  size ``head_size``. The same projected vectors serve as both sides of the dot product,
  giving per-head scores of shape [B, L, n, L]. Scores are not scaled by
  1/sqrt(head_size).

  All L*n*L scores of an example are exponentiated jointly, after subtracting that
  example's maximum score to avoid overflow. Summing over the head and key axes and
  normalizing by the example's total gives one weight per position, ``alpha`` [B, L].
  The output is the ``alpha``-weighted sum of the *unprojected* input rows.

  Args:
    hidden_size: feature size H of the input ``x`` (input size of ``w_u``).
    head_size: size Dh of each head's projected vector.
    num_heads: number of heads n.
  """

  def __init__(self, hidden_size, head_size, num_heads):
    super().__init__()
    self.hidden_size, self.head_size, self.num_heads = hidden_size, head_size, num_heads
    # One shared projection producing all heads at once.
    self.w_u = nn.Linear(hidden_size, head_size * num_heads)

  def forward(self, x):
    """Return attention-pooled features [B, H] for ``x`` of shape [B, L, H]."""
    batch, seq_len, _ = x.size()
    n_heads = self.num_heads

    # Project and split into heads: [B, L, n, Dh].
    proj = self.w_u(x).view(batch, seq_len, n_heads, self.head_size)

    # Per-head dot products between every pair of positions: [B, L, n, L].
    # The einsum result may be non-contiguous, so make it contiguous before .view().
    scores = torch.einsum("blnd,bknd->blnk", proj, proj).contiguous()

    # Subtract each example's global maximum before exp() so no term can overflow.
    peak = scores.view(batch, -1).max(dim=1, keepdim=True).values  # [B, 1]
    weights = torch.exp(scores.view(batch, seq_len, n_heads * seq_len) - peak.unsqueeze(-1))

    # Mass of each position (summed over heads and keys) over the example's total mass.
    row_mass = torch.einsum("bln->bl", weights)
    total_mass = torch.einsum("bln->b", weights)
    alpha = row_mass / total_mass.unsqueeze(-1)  # [B, L]; each row sums to 1

    # Weighted sum of the input positions: [B, H].
    return torch.einsum("bl,blh->bh", alpha, x)

# Matmul/sum counterpart of the einsum-based MultiHeadAttention (same stable pooling).
class FastMultiHeadAttention(nn.Module):
  """Attention pooling [B, L, H] -> [B, H] written with matmul/sum instead of einsum.

  Each position is projected by ``w_u`` into ``num_heads`` vectors of size
  ``head_size``. The per-head Gram matrices of those vectors give scores [B, L, n, L],
  with no 1/sqrt(head_size) scaling.

  The scores of each example are exponentiated after subtracting that example's
  maximum. Summing over heads and keys and normalizing by the total gives ``alpha``
  [B, L]. The output is the ``alpha``-weighted sum of the unprojected input rows.

  Args:
    hidden_size: feature size H of the input ``x`` (input size of ``w_u``).
    head_size: size Dh of each head's projected vector.
    num_heads: number of heads n.
  """

  def __init__(self, hidden_size, head_size, num_heads):
    super().__init__()
    self.hidden_size, self.head_size, self.num_heads = hidden_size, head_size, num_heads
    # One shared projection producing all heads at once.
    self.w_u = nn.Linear(hidden_size, head_size * num_heads)

  def forward(self, x):
    """Return attention-pooled features [B, H] for ``x`` of shape [B, L, H]."""
    batch, seq_len, _ = x.size()
    n_heads = self.num_heads

    # [B, L, n, Dh] -> [B, n, L, Dh] so the head axis acts as a batch axis for matmul.
    per_head = self.w_u(x).view(batch, seq_len, n_heads, self.head_size).transpose(1, 2)

    # Per-head Gram matrices [B, n, L, L], moved back to [B, L, n, L]
    # (equivalent to einsum("blnd,bknd->blnk", u, u)). Made contiguous for .view().
    gram = per_head @ per_head.transpose(-2, -1)
    scores = gram.transpose(1, 2).contiguous()

    # Stabilize exp() by subtracting each example's global maximum score.
    peak = scores.view(batch, -1).max(dim=1, keepdim=True).values.unsqueeze(-1)  # [B, 1, 1]
    weights = (scores.view(batch, seq_len, n_heads * seq_len) - peak).exp()

    # alpha[b, l] = mass of position l (over heads and keys) / total mass of example b.
    alpha = weights.sum(dim=-1) / weights.view(batch, -1).sum(dim=-1, keepdim=True)

    # [B, 1, L] @ [B, L, H] -> [B, 1, H] -> [B, H]
    return (alpha.unsqueeze(1) @ x).squeeze(1)

In [12]:
# Benchmark setup: torch.compile'd versions of both attention modules on the GPU.
# 'high' lets float32 matmuls use faster, lower-precision internal math (e.g. TF32) when
# available. This is a global setting and stays in effect for later cells.
torch.set_float32_matmul_precision('high')
slowma = torch.compile(MultiHeadAttention(1024, 32, 2).to("cuda"))
fastma = torch.compile(FastMultiHeadAttention(1024, 32, 2).to("cuda"))
x = torch.randn(32, 1000, 1024).to("cuda")

# torch.compile compiles lazily on the first call. Warm both modules up here, under
# no_grad to match the timed loops, so compilation time is not counted by %%timercell.
with torch.no_grad():
    slowma(x)
    fastma(x)
torch.cuda.synchronize()

In [4]:
from IPython.core.magic import register_cell_magic
import time

@register_cell_magic
def timercell(line, cell):
    """Cell magic: run the cell body and print how long it took, timed with CUDA events.

    ``line`` (the text after ``%%timercell``) is ignored. The cell is executed with
    ``exec`` in this function's global namespace, so names it binds remain visible
    afterwards.
    """
    t0 = torch.cuda.Event(enable_timing=True)
    t1 = torch.cuda.Event(enable_timing=True)

    t0.record()
    exec(cell, globals())
    t1.record()

    # Events are recorded asynchronously on the GPU stream; wait until the end event
    # has been reached before asking for the elapsed time.
    torch.cuda.synchronize()
    seconds = t0.elapsed_time(t1) / 1000  # elapsed_time() reports milliseconds
    print(f"⏱️ Elapsed: {seconds} s")

In [13]:
%%timercell
# Time 10,000 forward passes of the einsum-based module (slowma) with autograd
# disabled; outputs are discarded.
with torch.no_grad():
    for i in range(10000):
        slowma(x)

⏱️ Elapsed: 13.7028505859375 s


In [14]:
%%timercell
# Time 10,000 forward passes of the matmul/sum-based module (fastma) with autograd
# disabled; outputs are discarded.
with torch.no_grad():
    for i in range(10000):
        fastma(x)

⏱️ Elapsed: 13.560544921875 s


There seems to be no significant difference between the two approaches, so I will stick with einsum.

In [9]:
# Eager (uncompiled) versions of both attention modules for comparison.
# NOTE(review): these run under whatever float32 matmul precision is set globally.
# torch.set_float32_matmul_precision('high') persists once any cell has run it, so the
# compiled and eager timings are only comparable if both ran under the same setting.
slowma = MultiHeadAttention(1024, 32, 2).to("cuda")
fastma = FastMultiHeadAttention(1024, 32, 2).to("cuda")
x = torch.randn(32, 1000, 1024).to("cuda")

# Warm up once so any one-time lazy initialization on the first call is not charged to
# the first timed cell (slowma).
with torch.no_grad():
    slowma(x)
    fastma(x)
torch.cuda.synchronize()

In [10]:
%%timercell
# Same timing loop as before, now for the uncompiled einsum-based module (slowma).
with torch.no_grad():
    for i in range(10000):
        slowma(x)

⏱️ Elapsed: 26.742609375 s


In [11]:
%%timercell
# Same timing loop as before, now for the uncompiled matmul/sum-based module (fastma).
with torch.no_grad():
    for i in range(10000):
        fastma(x)

⏱️ Elapsed: 26.7568828125 s


Nor is there a noticeable difference when the modules are not compiled (about 26.7 s each).

In [15]:
# This torch option seems to greatly improve performance when compiling:
#torch.set_float32_matmul_precision('high')