<a href="https://colab.research.google.com/github/morganmcg1/reformer-fastai/blob/main/exploration/lm_generation_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# import sys
# if 'google.colab' in sys.modules:
#     !pip install -Uqq einops

In [1]:
import sys
sys.path.append("..")

import torch
import torch.autograd.profiler as profiler

# Timing

In [2]:
def do_cuda_timing(f, inp, context=None, n_loops=100):
    f.cuda()
    inp = inp.cuda()
    if context is not None: context = context.cuda()
    with profiler.profile(record_shapes=False, use_cuda=True) as prof:
        with profiler.record_function("model_inference"):
            with torch.no_grad():
                for _ in range(n_loops):
                    if context is None: f(inp)
                    else: f(inp, context)
                    torch.cuda.synchronize()
                    
    res = round((prof.key_averages().self_cpu_time_total / 1000) / n_loops, 3)
    print(f'{res}ms')
    return res

In [3]:
# print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

In [4]:
# With memory
#with profiler.profile(profile_memory=True, record_shapes=True) as prof:
#     arto_encoder_layer(src)

# print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))

In [5]:
comparison = {}

In [6]:
# Input 
pt_src = torch.rand(128, 64, 512).cuda()
arto_src = torch.rand(64, 128, 512).cuda()

pt_context = torch.rand(128, 64, 512).cuda()
arto_context = torch.rand(64, 128, 512).cuda()

# Encoder 

### Encoder Layer

In [28]:
from torch.nn import TransformerEncoderLayer
from basic_transformer import TransformerEncoderBlock as artoTransformerEncoderBlock

# Pytorch
pt_encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048, 
                                              dropout=0.1, activation="gelu")

# ARTO
arto_encoder_layer = artoTransformerEncoderBlock(dim=512, heads=8)

Do timing

In [29]:
pt = do_cuda_timing(pt_encoder_layer, pt_src)
t = do_cuda_timing(arto_encoder_layer, arto_src)

comparison['EncoderLayer'] = {'Arto': t, 'pt': pt}

7.508ms
6.885ms


### Encoder

6-layer Encoder test

In [30]:
from torch.nn import TransformerEncoder
from basic_transformer import TransformerEncoder as artoTransformerEncoder

# Pytorch
pt_encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048, 
                                              dropout=0.1, activation="gelu")
pt_encoder = TransformerEncoder(pt_encoder_layer, num_layers=6)


# Arto
arto_encoder = artoTransformerEncoder(dim=512, depth=6, heads=8,
                                      causal=False, attn_dropout=0.1,
                                      ff_dropout=0.1)

In [31]:
pt = do_cuda_timing(pt_encoder, pt_src)
t = do_cuda_timing(arto_encoder, arto_src)

comparison['Encoder'] = {'Arto': t, 'pt': pt}

43.618ms
41.954ms


# Decoder

### DecoderLayer

In [32]:
from torch.nn import TransformerDecoderLayer
from basic_transformer import TransformerDecoderBlock as artoTransformerDecoderBlock

# Pytorch
pt_decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8, dim_feedforward=2048, 
                                              dropout=0.1, activation="gelu")

# ARTO
arto_decoder_layer = artoTransformerDecoderBlock(dim=512, heads=8)

Do timing

In [33]:
pt = do_cuda_timing(pt_decoder_layer, pt_src, context=pt_context)

t = do_cuda_timing(arto_decoder_layer, arto_src, context=arto_context)

comparison['DecoderLayer'] = {'pt': pt, 'Arto': t}

11.097ms
9.944ms


### Decoder

6-layer Decoder test

In [34]:
from torch.nn import TransformerDecoder
from basic_transformer import TransformerDecoder as artoTransformerDecoder

# Pytorch
pt_decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8, dim_feedforward=2048, 
                                              dropout=0.1, activation="gelu")
pt_decoder = TransformerDecoder(pt_decoder_layer, num_layers=6)

# Arto
arto_decoder = artoTransformerDecoder(dim=512, depth=6, heads=8,
                                      attn_dropout=0.1,
                                      ff_dropout=0.1)

In [35]:
pt = do_cuda_timing(pt_decoder, pt_src, context=pt_context)

t = do_cuda_timing(arto_decoder, arto_src, context=arto_context)

comparison['Decoder'] = {'pt': pt, 'Arto':t}

64.522ms
60.558ms


In [36]:
for k in comparison.keys():
    print(k)
    for kk in comparison[k].keys():
        print(f'{kk} : {comparison[k][kk]}ms')
    print()

EncoderLayer
Arto : 6.885ms
pt : 7.508ms

Encoder
Arto : 41.954ms
pt : 43.618ms

DecoderLayer
pt : 11.097ms
Arto : 9.944ms

Decoder
pt : 64.522ms
Arto : 60.558ms

