In [1]:
import torch, os, sys
import numpy as np
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
import inspect
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import pandas as pd
from fast_transformers.builders import TransformerEncoderBuilder
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from fast_transformers.masking import FullMask, LengthMask

root_dir = os.path.dirname(os.getcwd())
sys.path.append(root_dir)
from benchmark import get_csv, benchmark

In [2]:
n_layers = 2
num_heads = 2
embed_dim = 32
n_hid = 49

In [3]:
# compute number of parameters
def get_num_params(model):
    return sum(p.numel() for p in model.parameters())

In [4]:
torch.manual_seed(0)
# Create the builder for our transformers
builder = TransformerEncoderBuilder.from_kwargs(
    n_layers=n_layers,
    n_heads=num_heads,
    query_dimensions=embed_dim // num_heads,
    value_dimensions=embed_dim // num_heads,
    feed_forward_dimensions=n_hid
)

# Build a transformer with softmax attention
builder.attention_type = "full"
softmax_model = builder.get().to('cuda')

# Build a transformer with linear attention
builder.attention_type = "linear"
linear_model = builder.get().to('cuda')


In [5]:
model = TransformerEncoder(
    TransformerEncoderLayer(
        d_model = embed_dim,
        nhead = num_heads,
        dim_feedforward = n_hid,
        batch_first = True,
    ),
    num_layers = n_layers,
)

# Transformer Layer Comparison

In [6]:
print('Model Size')
print("| Standard Attention |  Their Implementation  |: ", get_num_params(softmax_model))
print("|  Linear Attention  |  Their Implementation  |: ", get_num_params(linear_model))
print("| Standard Attention | Pytorch Implementation |: ", get_num_params(model))

Model Size
| Standard Attention |  Their Implementation  |:  15202
|  Linear Attention  |  Their Implementation  |:  15202
| Standard Attention | Pytorch Implementation |:  15138


In [7]:
print('Inference Time')

batch_size = 10
seq_len = int(1e3)
num_simulations = 1

time_elaspsed = {'softmax':[], 'linear':[], 'pytorch':[]}
for _ in range(num_simulations):
    
    # Construct the dummy input
    X = torch.rand(batch_size, seq_len, embed_dim)

    # Prepare everythin for CUDA
    X = X.cuda()

    model.cuda()
    model.eval()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    with torch.no_grad():
        start.record()
        y = model(X)
        end.record()
        torch.cuda.synchronize()
        time_elaspsed['pytorch'].append(start.elapsed_time(end))

print("| Standard Attention | Pytorch Implementation |: ", f"{np.mean(time_elaspsed['pytorch']):.2f}", f"({np.std(time_elaspsed['pytorch']):.2f})", "ms")

for _ in range(num_simulations):

    softmax_model.cuda()
    softmax_model.eval()
    linear_model.cuda()
    linear_model.eval()

    # Warmup the GPU
    with torch.no_grad():
        # softmax_model(X)
        linear_model(X)
    torch.cuda.synchronize()

    # Measure the execution time
    softmax_start = torch.cuda.Event(enable_timing=True)
    softmax_end = torch.cuda.Event(enable_timing=True)
    linear_start = torch.cuda.Event(enable_timing=True)
    linear_end = torch.cuda.Event(enable_timing=True)


    # with torch.no_grad():
    #     softmax_start.record()
    #     y = softmax_model(X)
    #     softmax_end.record()
    #     torch.cuda.synchronize()
    #     time_elaspsed['softmax'].append(softmax_start.elapsed_time(softmax_end))
        
    #     # Softmax: 144 ms (on a GTX1080Ti)

    with torch.no_grad():
        linear_start.record()
        y = linear_model(X)
        linear_end.record()
        torch.cuda.synchronize()
        time_elaspsed['linear'].append(linear_start.elapsed_time(linear_end))
        
        # Linear: 68 ms (on a GTX1080Ti)

print("| Standard Attention |  Their Implementation  |: ", f"{np.mean(time_elaspsed['softmax']):.2f}", f"({np.std(time_elaspsed['softmax']):.2f})", "ms")
print("|  Linear Attention  |  Their Implementation  |: ", f"{np.mean(time_elaspsed['linear']):.2f}", f"({np.std(time_elaspsed['linear']):.2f})", "ms")

Inference Time
| Standard Attention | Pytorch Implementation |:  36.94 (0.00) ms
| Standard Attention |  Their Implementation  |:  nan (nan) ms
|  Linear Attention  |  Their Implementation  |:  1.82 (0.00) ms


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


# Benchmark

In [6]:
def linear_attention(q = None, k = None, v = None, is_causal = False, return_model = False, **kwargs):
    
    # q, k, v: [batch_size*num_heads, seq_len, embed_dim]
    if is_causal: raise NotImplementedError("Causal attention is not implemented for linear attention")
    
    linear_attn = linear_model.layers[0].attention
    if return_model: return linear_attn
    
    return linear_attn(
        q,k,v,
        attn_mask=None,
        query_lengths=None,
        key_lengths=None,
    )

In [7]:
def linear_tf(x, src_mask = None, is_causal = False):
    if is_causal: raise NotImplementedError("Causal attention is not implemented for linear transformer") 
    if src_mask: raise NotImplementedError("Masking is not implemented for linear transformer")
    
    linear_transformer_layer = linear_model.layers[0]

    if src_mask is None:
        return linear_transformer_layer(x)

In [8]:
def get_model(model_name, **kwargs):
    if model_name == 'flash_attention':
        model = fa
    elif model_name == 'flash_linear_attention':
        model = fla
    elif model_name == 'flash_linear_tf':
        model = fla_tf
    elif model_name == 'flash_tf':
        model = fa_tf
    elif model_name == 'linear_attention':
        model = linear_attention
    elif model_name == 'linear_tf':
        model = linear_tf
    elif model_name == 'simplified_linear_attention':
        from linear_attn_forward import linear_attention as simplified_linear_attention
        model = simplified_linear_attention(default_mask = True, event_dispatcher = True, **kwargs).to('cuda')
    else:
        raise ValueError(f"model_name {model_name} not supported")
    return model

In [9]:
overwrite = True

In [10]:
# self-attention
benchmark('linear_attention', get_model, is_causal = False, self_attn = True, overwrite = overwrite)

In [11]:
# self-attention
benchmark('simplified_linear_attention', get_model, is_causal = False, self_attn = True, overwrite = overwrite, max_len_power = 21)

In [12]:
benchmark('linear_tf', get_model, is_causal = False, self_attn = True, overwrite = overwrite)

In [9]:
# cross-attention
benchmark('linear_attention', get_model, is_causal = False, self_attn = False, overwrite = False)
benchmark('linear_tf', get_model, is_causal = False, self_attn = False, overwrite = True)

Skip head dim 32, 2 heads, length-16.
Skip head dim 32, 2 heads, length-32.
Skip head dim 32, 2 heads, length-64.
Skip head dim 32, 2 heads, length-128.
Skip head dim 32, 2 heads, length-256.
Skip head dim 32, 2 heads, length-512.
Skip head dim 32, 2 heads, length-1024.
Skip head dim 32, 2 heads, length-2048.
Skip head dim 32, 2 heads, length-4096.
Skip head dim 32, 2 heads, length-8192.
Skip head dim 32, 2 heads, length-16384.
Skip head dim 32, 2 heads, length-32768.


TypeError: TransformerEncoderLayer.forward() got an unexpected keyword argument 'src_mask'

# Implementation Check

In [28]:
print(inspect.getsource(linear_model.layers[0].attention.event_dispatcher.dispatch))

    def dispatch(self, event):
        """Dispatch an event to the listeners.

        Arguments
        ---------
            event: Event instance
        """
        for event_handler, event_filter in self._listeners.items():
            if event_filter(event):
                event_handler(event)



In [27]:
# get all the function of a class

dir(linear_model.layers[0].attention.event_dispatcher)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_dispatchers',
 '_listeners',
 'clear',
 'dispatch',
 'get',
 'listen',
 'remove']

In [17]:
print(inspect.getsource(linear_model.layers[0].attention.inner_attention.forward))

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # Apply the feature map to the queries and keys
        self.feature_map.new_feature_map(queries.device)
        Q = self.feature_map.forward_queries(queries)
        K = self.feature_map.forward_keys(keys)

        # Apply the key padding mask and make sure that the attn_mask is
        # all_ones
        if not attn_mask.all_ones:
            raise RuntimeError(("LinearAttention does not support arbitrary "
                                "attention masks"))
        K = K * key_lengths.float_matrix[:, :, None, None]

        # Compute the KV matrix, namely the dot product of keys and values so
        # that we never explicitly compute the attention matrix and thus
        # decrease the complexity
        KV = torch.einsum("nshd,nshm->nhmd", K, values)

        # Compute the normalizer
        Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps)

        # Finally 

In [13]:
print(inspect.getsource(linear_model.layers[0].attention.inner_attention.forward))

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # Apply the feature map to the queries and keys
        self.feature_map.new_feature_map(queries.device)
        Q = self.feature_map.forward_queries(queries)
        K = self.feature_map.forward_keys(keys)

        # Apply the key padding mask and make sure that the attn_mask is
        # all_ones
        if not attn_mask.all_ones:
            raise RuntimeError(("LinearAttention does not support arbitrary "
                                "attention masks"))
        K = K * key_lengths.float_matrix[:, :, None, None]

        # Compute the KV matrix, namely the dot product of keys and values so
        # that we never explicitly compute the attention matrix and thus
        # decrease the complexity
        KV = torch.einsum("nshd,nshm->nhmd", K, values)

        # Compute the normalizer
        Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps)

        # Finally 

In [9]:
def run_model(model, model_name, q, k, v, is_causal, self_attn = False):
    if 'tf' in model_name:
        if self_attn: 
            return model(q, is_causal = is_causal)
        else:
            x = torch.concatenate([k, q], dim = 1)
            return model(x, src_mask = k.shape[0], is_causal = is_causal)
    else:
        return model(q, k, v, is_causal = is_causal, need_weights = False)

In [10]:
torch.random.manual_seed(0)
q = torch.randn(4, int(1e5), embed_dim, device = "cuda")
k = q
v = q

In [15]:
model_name = 'linear_attention'
model1 = get_model(model_name, d_model = embed_dim, n_heads = num_heads)
memory_before = torch.cuda.memory_allocated(device="cuda")
torch.cuda.reset_peak_memory_stats(device='cuda')
out1 = run_model(model1, model_name, q, k, v, is_causal = False, self_attn = True)
memory_after = torch.cuda.max_memory_allocated(device="cuda")
print((memory_after - memory_before)/1024/1024, 'MB')

zero 948 MB
1 948 MB
2 10485 MB
begin 20022 MB
init 20022 MB
146 MB
98 MB
301 MB
301 MB
19667.7890625 MB


In [35]:
model_name = 'simplified_linear_attention'
model2 = get_model(model_name, d_model = embed_dim, n_heads = num_heads)
memory_before = torch.cuda.memory_allocated(device="cuda")
torch.cuda.reset_peak_memory_stats(device='cuda')
out2 = run_model(model2, model_name, q, k, v, is_causal = False, self_attn = True)
memory_after = torch.cuda.max_memory_allocated(device="cuda")
print((memory_after - memory_before)/1024/1024, 'MB')

init 1943 MB
146 MB
98 MB
301 MB
301 MB
593.95849609375 MB


In [36]:
get_num_params(model1(return_model=True)), get_num_params(model2)

zero 1943 MB


(4224, 4224)

In [16]:
out1

tensor([[[ 0.1336, -0.0488,  0.0153,  ..., -0.0899, -0.0093,  0.0702],
         [ 0.1328, -0.0508,  0.0197,  ..., -0.0829, -0.0143,  0.0658],
         [ 0.1335, -0.0461,  0.0136,  ..., -0.0863, -0.0101,  0.0698],
         ...,
         [ 0.1402, -0.0447,  0.0101,  ..., -0.0888, -0.0085,  0.0757],
         [ 0.1318, -0.0488,  0.0215,  ..., -0.0851, -0.0136,  0.0669],
         [ 0.1386, -0.0476,  0.0129,  ..., -0.0905, -0.0051,  0.0717]],

        [[ 0.1363, -0.0445,  0.0179,  ..., -0.0867, -0.0094,  0.0728],
         [ 0.1417, -0.0490,  0.0140,  ..., -0.0859, -0.0093,  0.0693],
         [ 0.1401, -0.0488,  0.0130,  ..., -0.0826, -0.0119,  0.0694],
         ...,
         [ 0.1387, -0.0467,  0.0126,  ..., -0.0902, -0.0058,  0.0708],
         [ 0.1370, -0.0471,  0.0149,  ..., -0.0863, -0.0128,  0.0706],
         [ 0.1357, -0.0453,  0.0147,  ..., -0.0839, -0.0108,  0.0720]],

        [[ 0.1377, -0.0540,  0.0188,  ..., -0.0813, -0.0166,  0.0633],
         [ 0.1358, -0.0501,  0.0162,  ..., -0

In [46]:
out2

tensor([[[-0.0769, -0.2752,  0.0800,  ...,  0.0190,  0.1057, -0.0327],
         [-0.0724, -0.2785,  0.0758,  ...,  0.0153,  0.1082, -0.0369],
         [-0.0728, -0.2790,  0.0752,  ...,  0.0193,  0.1057, -0.0334],
         ...,
         [-0.0785, -0.2749,  0.0829,  ...,  0.0257,  0.1028, -0.0269],
         [-0.0739, -0.2791,  0.0716,  ...,  0.0202,  0.1076, -0.0367],
         [-0.0730, -0.2789,  0.0733,  ...,  0.0174,  0.1085, -0.0368]],

        [[-0.0721, -0.2755,  0.0802,  ...,  0.0214,  0.1057, -0.0294],
         [-0.0742, -0.2782,  0.0734,  ...,  0.0203,  0.1065, -0.0372],
         [-0.0752, -0.2760,  0.0787,  ...,  0.0213,  0.1037, -0.0304],
         ...,
         [-0.0727, -0.2758,  0.0814,  ...,  0.0247,  0.1030, -0.0257],
         [-0.0761, -0.2781,  0.0779,  ...,  0.0280,  0.1048, -0.0293],
         [-0.0760, -0.2721,  0.0851,  ...,  0.0336,  0.1062, -0.0236]],

        [[-0.0719, -0.2774,  0.0760,  ...,  0.0216,  0.1070, -0.0319],
         [-0.0759, -0.2768,  0.0812,  ...,  0

In [39]:
FullMask(10, device=q.device).float_matrix

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')