# Matrix Multiplication Speed Comparison

## torch.bmm

https://discuss.pytorch.org/t/difference-between-matmul-broadcast-and-bmm-on-computational-graph/22674

>These two should be equivalent even if they define different computational graphs (potentially doing broadcasting in a different way). That means that, during training, because of floating point precision, they can end up giving noticeably different results as errors are amplified by the training...Both will be correct. You can see changing from one to the other having the same effect as changing the random seed that you set at the beginning of your script: all the numbers you will get will be different but if your model is robust, both should converge to a similar solution in terms of performance.

**Example**

If input is a (b×n×m) tensor and mat2 is a (b×m×p) tensor, then out is a (b×n×p) tensor.

input = torch.randn(10, 3, 4)

mat2 = torch.randn(10, 4, 5)

res = torch.bmm(input, mat2)

res.size()

torch.Size([10, 3, 5])
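The forum quote above says that matmul with broadcasting and bmm should be numerically equivalent. A minimal sketch of that claim follows; the tensor sizes and the names `batch` and `shared` are illustrative assumptions, not taken from the thread.

```python
import torch

torch.manual_seed(0)
batch = torch.randn(10, 3, 4)   # (b, n, m)
shared = torch.randn(4, 5)      # (m, p): one matrix shared by every batch item

# torch.matmul (and the @ operator) broadcast the 2-D operand over the batch dimension.
out_matmul = torch.matmul(batch, shared)

# torch.bmm needs two 3-D tensors with the same batch size, so expand explicitly.
out_bmm = torch.bmm(batch, shared.unsqueeze(0).expand(10, -1, -1))

print(out_matmul.shape, out_bmm.shape)
print(torch.allclose(out_matmul, out_bmm, atol=1e-5))
```

The two results can differ in the last floating-point bits, which is why the comparison uses `torch.allclose` rather than exact equality.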

In [36]:
import torch
 
tgt_len = src_len = 5
bsz = 4
n_heads = 2
head_dim = 12
emb_sz = n_heads * head_dim

q = torch.empty(tgt_len, bsz, emb_sz).fill_(2.)
k = torch.empty(tgt_len, bsz, emb_sz).fill_(3.)
q.size()

torch.Size([5, 4, 24])

In [37]:
q_mod = q.contiguous().view(tgt_len, bsz * n_heads, head_dim).transpose(0, 1)
k_mod = k.contiguous().view(-1, bsz * n_heads, head_dim).transpose(0, 1)
q_mod.size(), k_mod.size()

(torch.Size([8, 5, 12]), torch.Size([8, 5, 12]))

In [53]:
attn_output_weights_mod = torch.bmm(q_mod, k_mod.transpose(1, 2))

assert list(attn_output_weights_mod.size()) == [bsz * n_heads, tgt_len, src_len]

attn_output_weights_mod.size(), (q_mod @ k_mod.transpose(1, 2)).size(), torch.matmul(q_mod, k_mod.transpose(1, 2)).size()

(torch.Size([8, 5, 5]), torch.Size([8, 5, 5]), torch.Size([8, 5, 5]))
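To make the view/transpose step above easier to follow, here is a small sketch of where each head ends up. It reuses the small sizes from the cells above but fills `q` with random values (an assumption, so that different slices can be told apart); the indices `b` and `h` are arbitrary picks.

```python
import torch

tgt_len, bsz, n_heads, head_dim = 5, 4, 2, 12
emb_sz = n_heads * head_dim

q = torch.randn(tgt_len, bsz, emb_sz)
q_mod = q.contiguous().view(tgt_len, bsz * n_heads, head_dim).transpose(0, 1)

# Row b * n_heads + h of q_mod holds head h of batch item b.
b, h = 3, 1
expected = q[:, b, h * head_dim:(h + 1) * head_dim]   # (tgt_len, head_dim)
same = torch.equal(q_mod[b * n_heads + h], expected)
print(q_mod.shape, same)
```

So the leading dimension of `q_mod` enumerates (batch item, head) pairs, and `torch.bmm` computes one `tgt_len × src_len` score matrix per pair.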

In [1]:
import torch
 
tgt_len = src_len = 5
bsz = 1024
n_heads = 24
emb_sz = 48000
head_dim = int(emb_sz / n_heads)
# emb_sz = n_heads * head_dim

q = torch.empty(tgt_len, bsz, emb_sz).fill_(2.)
k = torch.empty(tgt_len, bsz, emb_sz).fill_(3.)

q_mod = q.contiguous().view(tgt_len, bsz * n_heads, head_dim).transpose(0, 1)
k_mod = k.contiguous().view(-1, bsz * n_heads, head_dim).transpose(0, 1)

q_mod.size(), k_mod.size()

(torch.Size([24576, 5, 2000]), torch.Size([24576, 5, 2000]))

# Matrix Multiplication Performance

## With View

**torch.bmm**

In [2]:
%%timeit -n 100 -r 5
attn_output_weights_mod = torch.bmm(q_mod.cuda(), k_mod.transpose(1, 2).cuda())
torch.cuda.synchronize()

100 loops, best of 5: 234 ms per loop


**@ operator**

In [3]:
%%timeit -n 100 -r 5
attn_output_weights_mod = q_mod.cuda() @ k_mod.transpose(1, 2).cuda()
torch.cuda.synchronize()

100 loops, best of 5: 245 ms per loop


**torch.matmul**

In [4]:
%%timeit -n 100 -r 5
attn_output_weights_mod = torch.matmul(q_mod.cuda(), k_mod.transpose(1, 2).cuda())
torch.cuda.synchronize()

100 loops, best of 5: 244 ms per loop


**torch.einsum**

In [2]:
%%timeit -n 100 -r 5
attn_output_weights_mod = torch.einsum('bnm,bmp->bnp', q_mod.cuda(), k_mod.transpose(1, 2).cuda())
torch.cuda.synchronize()

100 loops, best of 5: 229 ms per loop
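Note that the timed cells above call `.cuda()` on CPU tensors inside the timed statement, so each measurement likely includes the host-to-device copy as well as the multiplication. A hedged sketch of one way to time only the multiplication uses `torch.utils.benchmark.Timer`, which handles CUDA synchronization; the small sizes and the names `q_dev`, `k_t` and `candidates` are assumptions for illustration, and the code falls back to the CPU when no GPU is present.

```python
import torch
from torch.utils import benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"

# Small illustrative sizes (assumption); the notebook uses far larger tensors.
bsz_heads, tgt_len, head_dim = 64, 5, 32
q_dev = torch.randn(bsz_heads, tgt_len, head_dim, device=device)
k_dev = torch.randn(bsz_heads, tgt_len, head_dim, device=device)
k_t = k_dev.transpose(1, 2)   # created on the device once, outside the timed statement

candidates = {
    "bmm": "torch.bmm(q, k_t)",
    "matmul": "torch.matmul(q, k_t)",
    "einsum": "torch.einsum('bnm,bmp->bnp', q, k_t)",
}

results = {}
for name, stmt in candidates.items():
    timer = benchmark.Timer(stmt=stmt, globals={"torch": torch, "q": q_dev, "k_t": k_t})
    results[name] = timer.timeit(20)   # returns a Measurement object
    print(name, results[name].median)  # median seconds per run
```

No timings are quoted here because they depend on the hardware; the point is only that the tensors already live on the device before timing starts.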


## Without View

In [15]:
q = q.contiguous().transpose(0, 1)
k = k.contiguous().transpose(0, 1)

q.size(), k.size()

(torch.Size([1024, 5, 48000]), torch.Size([1024, 5, 48000]))

**torch.bmm**

In [16]:
%%timeit -n 100 -r 5
attn_output_weights_mod = torch.bmm(q.cuda(), k.transpose(1, 2).cuda())
torch.cuda.synchronize()

100 loops, best of 5: 233 ms per loop


**@ operator**

In [17]:
%%timeit -n 100 -r 5
attn_output_weights_mod = q.cuda() @ k.transpose(1, 2).cuda()
torch.cuda.synchronize()

100 loops, best of 5: 251 ms per loop


**torch.matmul**

In [18]:
%%timeit -n 100 -r 5
attn_output_weights_mod = torch.matmul(q.cuda(), k.transpose(1, 2).cuda())
torch.cuda.synchronize()

100 loops, best of 5: 250 ms per loop


**torch.einsum**

In [20]:
%%timeit -n 100 -r 5
attn_output_weights_mod = torch.einsum('bnm,bmp->bnp', q.cuda(), k.transpose(1, 2).cuda())
torch.cuda.synchronize()

100 loops, best of 5: 230 ms per loop


## Check Equivalent Calculation

In [24]:
q = torch.empty(tgt_len, bsz, emb_sz).fill_(2.)
k = torch.empty(tgt_len, bsz, emb_sz).fill_(3.)

In [25]:
q_mod = q.contiguous().view(tgt_len, bsz * n_heads, head_dim).transpose(0, 1)
k_mod = k.contiguous().view(-1, bsz * n_heads, head_dim).transpose(0, 1)

torch.einsum('bnm,bmp->bnp', q_mod.cuda(), k_mod.transpose(1, 2).cuda()).sum()

tensor(7.3728e+09, device='cuda:0')

In [26]:
q = q.contiguous().transpose(0, 1)
k = k.contiguous().transpose(0, 1)

torch.einsum('bnm,bmp->bnp', q.cuda(), k.transpose(1, 2).cuda()).sum()

tensor(7.3728e+09, device='cuda:0')
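The two sums match, but the two results have different shapes: (bsz * n_heads, 5, 5) with the view and (bsz, 5, 5) without it. A sketch of a stricter element-wise check is below, using small random tensors (an assumption; the cells above use constant fills and much larger sizes). It shows that summing the per-head scores over heads reproduces the full-embedding scores.

```python
import torch

torch.manual_seed(0)
tgt_len, bsz, n_heads, head_dim = 5, 4, 2, 12
emb_sz = n_heads * head_dim
q = torch.randn(tgt_len, bsz, emb_sz)
k = torch.randn(tgt_len, bsz, emb_sz)

# With view: one score matrix per (batch item, head) pair.
q_mod = q.view(tgt_len, bsz * n_heads, head_dim).transpose(0, 1)
k_mod = k.view(tgt_len, bsz * n_heads, head_dim).transpose(0, 1)
per_head = torch.bmm(q_mod, k_mod.transpose(1, 2))   # (bsz * n_heads, tgt_len, tgt_len)

# Without view: one score matrix per batch item over the full embedding.
q_full = q.transpose(0, 1)                            # (bsz, tgt_len, emb_sz)
k_full = k.transpose(0, 1)
full = torch.bmm(q_full, k_full.transpose(1, 2))      # (bsz, tgt_len, tgt_len)

# Summing the per-head scores over the head axis recovers the full scores.
summed = per_head.view(bsz, n_heads, tgt_len, tgt_len).sum(dim=1)
print(per_head.shape, full.shape)
print(torch.allclose(summed, full, atol=1e-4))
```

In other words, the two variants agree only after reducing over heads; the individual per-head matrices are not the same quantity as the single-head result.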

## torch.einsum speed with 4 dimensions

In [1]:
import torch
from einops import rearrange

tgt_len = src_len = 5
bsz = 1024
n_heads = 24
emb_sz = 48000
head_dim = int(emb_sz / n_heads)

q = torch.empty(tgt_len, bsz, emb_sz).fill_(2.)
k = torch.empty(tgt_len, bsz, emb_sz).fill_(3.)

h = n_heads
q.size()

torch.Size([5, 1024, 48000])

In [2]:
rearrange(q, 'b n (h d) -> b h n d', h = h).size()

torch.Size([5, 24, 1024, 2000])

In [3]:
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
k = rearrange(k, 'b n (h d) -> b h n d', h = h)

In [4]:
%%timeit -n 100 -r 5
attn_output_weights_mod = torch.einsum('b h i d, b h j d-> b h i j', q.cuda(), k.cuda())
torch.cuda.synchronize()

100 loops, best of 5: 270 ms per loop
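One caveat on this comparison: `q` has shape (tgt_len, bsz, emb_sz), so the pattern 'b n (h d) -> b h n d' binds `b` to tgt_len (5) and `n` to bsz (1024), as the printed size [5, 24, 1024, 2000] shows. The einsum above therefore appears to build 1024×1024 score matrices rather than the 5×5 ones timed earlier, so the 270 ms figure is not directly comparable. A sketch of a layout that does match the bmm result is below; it uses plain `view`/`permute` with small random tensors (assumptions), and the einops pattern named in the comment is the presumed equivalent.

```python
import torch

torch.manual_seed(0)
tgt_len, bsz, n_heads, head_dim = 5, 4, 2, 12
q = torch.randn(tgt_len, bsz, n_heads * head_dim)
k = torch.randn(tgt_len, bsz, n_heads * head_dim)

# (tgt_len, bsz, emb_sz) -> (bsz, n_heads, tgt_len, head_dim).
# Presumed einops equivalent: rearrange(q, 't b (h d) -> b h t d', h=n_heads).
q4 = q.view(tgt_len, bsz, n_heads, head_dim).permute(1, 2, 0, 3)
k4 = k.view(tgt_len, bsz, n_heads, head_dim).permute(1, 2, 0, 3)
scores4 = torch.einsum('bhid,bhjd->bhij', q4, k4)     # (bsz, n_heads, tgt_len, tgt_len)

# Reference: the 3-D layout used in the bmm cells earlier in the notebook.
q_mod = q.view(tgt_len, bsz * n_heads, head_dim).transpose(0, 1)
k_mod = k.view(tgt_len, bsz * n_heads, head_dim).transpose(0, 1)
scores3 = torch.bmm(q_mod, k_mod.transpose(1, 2))     # (bsz * n_heads, tgt_len, tgt_len)

matches = torch.allclose(
    scores4.reshape(bsz * n_heads, tgt_len, tgt_len), scores3, atol=1e-5
)
print(scores4.shape, matches)
```

With this layout the 4-D einsum and the 3-D bmm compute the same per-head scores, which makes a timing comparison between them meaningful.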
