In [1]:
import torch
from mothernet.models.flash_linear_attention import FlashLinearAttention

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Problem dimensions shared by the cells below.
batch_size, num_heads, seq_len, hidden_size = 32, 3, 2048, 9

# Prefer the first CUDA device; fall back to CPU so this setup cell still runs
# on a machine without a GPU.
# NOTE(review): the attention layer itself may still require CUDA — confirm.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16

# Seed the RNG so that Restart & Run All reproduces the same input tensor.
SEED = 0
torch.manual_seed(SEED)

# Allocate the random input directly with the target device and dtype instead
# of building a default-dtype tensor first and converting it afterwards.
x = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype)

In [3]:
# Construct the attention layer from the dimensions chosen earlier in the
# notebook, then place it on the same device/dtype as the input tensor.
fla = FlashLinearAttention(
    mode='chunk',
    hidden_size=hidden_size,
    expand_k=1,
    expand_v=1,
    num_heads=num_heads,
)
fla = fla.to(device=device, dtype=dtype)

In [5]:
def get_num_params(model, trainable_only=False):
    """Return the total number of scalar elements in a model's parameters.

    Args:
        model: Object exposing ``parameters()``, which yields tensors that
            provide ``numel()`` (and ``requires_grad`` when filtering).
        trainable_only: When True, count only parameters whose
            ``requires_grad`` flag is set. The default (False) counts every
            parameter, matching the original behaviour.

    Returns:
        int: Sum of ``numel()`` over the selected parameters; 0 if there are
        none.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        # Short-circuit keeps `requires_grad` untouched in the default mode.
        if not trainable_only or p.requires_grad
    )

In [6]:
# Display the total parameter count of the `fla` layer (summed over every
# tensor returned by `fla.parameters()`).
get_num_params(fla)

687

In [11]:
# Per-parameter breakdown: one pipe-delimited row with the parameter's name,
# shape and element count.
for name, param in fla.named_parameters():
    row = f"| {name} | shape {param.shape} | number of params {param.numel()} |"
    print(row)

| feature_map_q.layer1.weight | shape torch.Size([9, 9]) | number of params 81 |
| feature_map_q.layer1.bias | shape torch.Size([9]) | number of params 9 |
| feature_map_q.layer2.weight | shape torch.Size([9, 9]) | number of params 81 |
| feature_map_q.layer2.bias | shape torch.Size([9]) | number of params 9 |
| feature_map_k.layer1.weight | shape torch.Size([9, 9]) | number of params 81 |
| feature_map_k.layer1.bias | shape torch.Size([9]) | number of params 9 |
| feature_map_k.layer2.weight | shape torch.Size([9, 9]) | number of params 81 |
| feature_map_k.layer2.bias | shape torch.Size([9]) | number of params 9 |
| q_proj.weight | shape torch.Size([9, 9]) | number of params 81 |
| k_proj.weight | shape torch.Size([9, 9]) | number of params 81 |
| v_proj.weight | shape torch.Size([9, 9]) | number of params 81 |
| norm.weight | shape torch.Size([3]) | number of params 3 |
| o_proj.weight | shape torch.Size([9, 9]) | number of params 81 |


In [5]:
# Call the layer with the first `seq_len // 2` positions of `x` as the first
# argument and the full `x` as the second and third arguments.
# NOTE(review): presumably the positional arguments are (query, key, value) —
# confirm against FlashLinearAttention.forward.
# NOTE(review): the bare expression displays the raw result (a tuple whose first
# element is a tensor); consider showing only shapes to keep the output short.
fla(x[:,:(seq_len//2),:], x, x)

(tensor([[[ 1.8030e-06, -1.9073e-06, -2.8685e-07,  ..., -1.8552e-06,
            4.3213e-07, -9.9838e-07],
          [ 1.3597e-07, -4.0419e-07,  9.0804e-08,  ..., -6.4448e-07,
           -1.5348e-06, -7.2643e-08],
          [-1.5140e-05,  1.1362e-07,  5.3644e-07,  ...,  3.0249e-06,
           -8.0466e-07, -4.5598e-06],
          ...,
          [ 3.6812e-04,  8.7357e-04,  5.9891e-04,  ...,  4.5013e-04,
           -1.4973e-04,  5.1117e-04],
          [-3.7193e-05,  1.9646e-04, -2.6131e-04,  ..., -4.0245e-04,
           -5.1737e-05, -1.7881e-05],
          [-1.5163e-04,  9.5367e-05,  1.7643e-04,  ...,  8.4400e-05,
            2.8014e-05,  5.6839e-04]],
 
         [[ 4.1389e-04,  2.6703e-04, -7.0315e-08,  ...,  2.4033e-04,
            4.1485e-05,  2.5392e-05],
          [-2.5749e-04, -9.1171e-04,  3.5667e-04,  ...,  2.5177e-04,
           -7.2956e-05, -4.1771e-04],
          [ 8.8215e-05, -5.4932e-04, -5.4550e-04,  ..., -8.5831e-05,
           -2.1172e-04, -3.2234e-04],
          ...,
    