<a href="https://colab.research.google.com/github/morganmcg1/reformer-fastai/blob/main/exploration/einops_timing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Record which GPU / driver / CUDA version the timings below were measured on.
!nvidia-smi

Wed Nov 11 16:44:03 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P8    29W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
# einops provides `rearrange`, used by the multi-head `all_einops` variant.
!pip install -q einops

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd.profiler as profiler
from torch.cuda import amp

from einops import rearrange

In [8]:
def time_cuda(f, *args, **kwargs):
    """Call ``f(*args, **kwargs)`` and wait for all queued CUDA work to finish.

    CUDA kernels are launched asynchronously, so without the synchronize a
    timer such as ``%timeit`` would largely measure launch overhead rather
    than the computation. When CUDA is not available the synchronize is
    skipped, so the helper also works on CPU-only machines.

    Returns:
        Whatever ``f`` returns, so the helper can also be used for inspection.
    """
    result = f(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result

In [4]:
def dotprod_matmul(q, k, v):
    """Unscaled single-head dot-product attention built from plain matmuls.

    Computes softmax(q @ k^T) @ v, with the softmax over the last axis of the
    score matrix (the key axis). No 1/sqrt(d) scaling is applied.
    """
    attn = F.softmax(torch.matmul(q, k.transpose(-1, -2)), dim=-1)
    return torch.matmul(attn, v)

In [5]:
def dotprod_einops(q, k, v):
    """Unscaled single-head attention: einsum for the scores, `@` for the weighting.

    Despite the name this uses ``torch.einsum`` (not the einops package).
    The einsum spec fixes q and k as 3-D (batch, len, feature) tensors.
    """
    scores = torch.einsum('bid,bjd->bij', q, k)
    weights = scores.softmax(dim=-1)
    return weights @ v

In [6]:
def all_einops(q, k, v):
    """Unscaled single-head attention where both products are ``torch.einsum`` calls.

    The einsum specs fix q, k and v as 3-D (batch, len, feature) tensors; the
    softmax runs over the key axis of the (batch, i, j) score tensor.
    """
    attn = F.softmax(torch.einsum('bid,bjd->bij', q, k), dim=-1)
    return torch.einsum('bij,bjd->bid', attn, v)

In [7]:
# Benchmark configuration: batch size, sequence length and feature dim of each of q, k, v.
SEED = 0
torch.manual_seed(SEED)  # seed so the random q/k/v (and the allclose checks on them) are reproducible
bs = 8
sl = 512
d = 1024

In [9]:
# Draw one (bs, sl, 3*d) tensor and split its last dim into three (bs, sl, d) views.
# chunk() returns non-contiguous views, so every benchmark below sees strided inputs.
q, k, v = torch.chunk(torch.randn(bs, sl, 3 * d, device='cuda'), 3, dim=-1)

In [10]:
# Every variant must reproduce the plain-matmul reference before its timing is meaningful.
ref_out = dotprod_matmul(q, k, v)  # computed once instead of once per comparison
for variant in (dotprod_einops, all_einops):
    variant_out = variant(q, k, v)
    assert torch.allclose(ref_out, variant_out), (
        f'{variant.__name__} does not match dotprod_matmul '
        f'(max abs diff {(ref_out - variant_out).abs().max().item():.3e})')
# Drop the temporaries so they don't hold device memory during the timings.
del ref_out, variant_out, variant

In [11]:
# Warm-up call outside the timed region: a first call can include one-off CUDA setup cost.
time_cuda(dotprod_matmul, q,k,v)
%timeit time_cuda(dotprod_matmul, q,k,v)

100 loops, best of 3: 5.46 ms per loop


In [12]:
# Warm-up call outside the timed region, then time the einsum-scores variant.
time_cuda(dotprod_einops, q,k,v)
%timeit time_cuda(dotprod_einops, q,k,v)

100 loops, best of 3: 4.47 ms per loop


In [13]:
# Warm-up call outside the timed region, then time the all-einsum variant.
time_cuda(all_einops, q,k,v)
%timeit time_cuda(all_einops, q,k,v)

100 loops, best of 3: 4.29 ms per loop


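## Multi-head attention

The same three variants, with q, k and v each split into 8 heads along the feature dimension.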

In [14]:
def dotprod_matmul(q, k, v, n_heads=8):
    """Multi-head, unscaled dot-product attention written with `@`/matmul.

    Redefines the single-head ``dotprod_matmul``. The last dim of each input is
    split into ``n_heads`` equal heads. softmax(q @ k^T) @ v is computed per head
    (no 1/sqrt(d) scaling), and the heads are merged back into the last dim.

    Args:
        q: (bs, seq_len, d) queries.
        k, v: (bs, kv_len, d) keys / values; kv_len may differ from seq_len.
        n_heads: number of heads (default 8, the previously hard-coded value).
            The feature dim must be divisible by it, otherwise ``view`` raises.

    Returns:
        (bs, seq_len, d) tensor.
    """
    bs, seq_len = q.shape[:2]
    # (b, l, h*dh) -> (b, h, l, dh); each tensor keeps its own length, head dim is inferred.
    q, k, v = (t.view(t.size(0), t.size(1), n_heads, -1).transpose(1, 2) for t in (q, k, v))
    out = F.softmax(q @ k.transpose(-2, -1), -1) @ v
    # (b, h, seq_len, dh) -> (b, seq_len, h*dh)
    return out.transpose(1, 2).contiguous().view(bs, seq_len, -1)

In [15]:
def dotprod_einops(q, k, v, n_heads=8):
    """Multi-head, unscaled dot-product attention with ``torch.einsum`` products.

    Redefines the single-head ``dotprod_einops``. Heads are split and merged with
    ``view``/``transpose`` (as in ``dotprod_matmul``); both attention products
    use ``torch.einsum`` (not the einops package).

    Args:
        q: (bs, seq_len, d) queries.
        k, v: (bs, kv_len, d) keys / values; kv_len may differ from seq_len.
        n_heads: number of heads (default 8, the previously hard-coded value).
            The feature dim must be divisible by it, otherwise ``view`` raises.

    Returns:
        (bs, seq_len, d) tensor.
    """
    bs, seq_len = q.shape[:2]
    # (b, l, h*dh) -> (b, h, l, dh); each tensor keeps its own length, head dim is inferred.
    q, k, v = (t.view(t.size(0), t.size(1), n_heads, -1).transpose(1, 2) for t in (q, k, v))
    out = torch.einsum('bhij,bhjd->bhid', F.softmax(torch.einsum('bhid,bhjd->bhij', q, k), -1), v)
    # (b, h, seq_len, dh) -> (b, seq_len, h*dh)
    return out.transpose(1, 2).contiguous().view(bs, seq_len, -1)

In [16]:
def all_einops(q, k, v, n_heads=8):
    """Multi-head, unscaled dot-product attention using einops + ``torch.einsum`` only.

    Redefines the single-head ``all_einops``. ``rearrange`` splits the last dim
    into ``n_heads`` heads (inferring lengths and head dim from each tensor).
    Both attention products are einsums, and the heads are merged back at the end.

    Args:
        q, k, v: 3-D tensors whose last dim is heads * head_dim.
        n_heads: number of heads (default 8, the previously hard-coded value).

    Returns:
        Tensor with q's batch and length dims and the merged head features.
    """
    q, k, v = (rearrange(t, 'b l (h d) -> b h l d', h=n_heads) for t in (q, k, v))
    attn = F.softmax(torch.einsum('bhid,bhjd->bhij', q, k), -1)
    out = torch.einsum('bhij,bhjd->bhid', attn, v)
    return rearrange(out, 'b h n d -> b n (h d)')

In [17]:
# Every multi-head variant must reproduce the plain-matmul reference before its timing is meaningful.
ref_out = dotprod_matmul(q, k, v)  # computed once instead of once per comparison
for variant in (dotprod_einops, all_einops):
    variant_out = variant(q, k, v)
    assert torch.allclose(ref_out, variant_out), (
        f'{variant.__name__} does not match dotprod_matmul '
        f'(max abs diff {(ref_out - variant_out).abs().max().item():.3e})')
# Drop the temporaries so they don't hold device memory during the timings.
del ref_out, variant_out, variant

In [18]:
# Multi-head matmul variant: warm-up call outside the timed region, then time it.
time_cuda(dotprod_matmul, q,k,v)
%timeit time_cuda(dotprod_matmul, q,k,v)

100 loops, best of 3: 7.68 ms per loop


In [19]:
# Multi-head view/transpose + einsum variant: warm-up call, then time it.
time_cuda(dotprod_einops, q,k,v)
%timeit time_cuda(dotprod_einops, q,k,v)

100 loops, best of 3: 7.61 ms per loop


In [20]:
# Multi-head rearrange + einsum variant: warm-up call, then time it.
time_cuda(all_einops, q,k,v)
%timeit time_cuda(all_einops, q,k,v)

100 loops, best of 3: 7.75 ms per loop
