In [3]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt


- $D$: `dim`
- $D_{kv}$: `kv_dim`
- $D_h$: `hidden_dim`
- $B$: `batch_size`
- $V$: `vocab_size`
- $N_{h}$: `num_heads`
- $L_{max}$: `seq_len` (max)
- $L$: `seq_len` (current)

## Pre-attention

Operation | Parameters | Inputs | Outputs | Complexity (FLOPs) | Complexity (Memory)
--- | --- | --- | --- | --- | ---
`rmsnorm` | `rms_att_weight` ($D \times 1$) | `state->x` ($B \times D$) | `state->x` ($B \times D$) | $\mathcal{O}(BD)$ | $\mathcal{O}(BD)$
`matmul` | `Wq` ($D \times D$) | `state->x` ($B \times D$) | `state->q` ($B \times D$) | $\mathcal{O}(BD^2)$ | $\mathcal{O}(D^2)$
`matmul` | `Wkv` ($D \times 2D_{kv}$) | `state->x` ($B \times D$) | `state->kv` ($B \times 2D_{kv}$) | $\mathcal{O}(BDD_{kv})$ | $\mathcal{O}(DD_{kv})$
Total | - | - | - | $\mathcal{O}(BD^2)$ | $\mathcal{O}(D^2)$
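
The shapes in the table above can be checked with a minimal NumPy sketch. It is an illustration only, not the actual implementation: the function names `rmsnorm` and `pre_attention`, the `eps` value, the use of a 1-D weight of shape `(D,)` in place of $D \times 1$, and the small sizes are all assumptions made for the example.

```python
import numpy as np

def rmsnorm(x, weight, eps=1e-5):
    """Divide each row of x (B, D) by its root-mean-square, then scale by weight (D,)."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight

def pre_attention(x, rms_att_weight, Wq, Wkv):
    """Map token states x (B, D) to q (B, D) and kv (B, 2 * D_kv)."""
    xn = rmsnorm(x, rms_att_weight)  # O(BD) FLOPs
    q = xn @ Wq                      # O(B D^2) FLOPs, reads D^2 weights
    kv = xn @ Wkv                    # O(B D D_kv) FLOPs, reads 2 D D_kv weights
    return q, kv

# Hypothetical sizes, chosen only for illustration.
B, D, D_kv = 2, 8, 4
rng = np.random.default_rng(0)
x = rng.standard_normal((B, D))
Wq = rng.standard_normal((D, D))
Wkv = rng.standard_normal((D, 2 * D_kv))
q, kv = pre_attention(x, np.ones(D), Wq, Wkv)
print(q.shape, kv.shape)  # (2, 8) (2, 8)
```

The weight matrices are read once regardless of $B$, which is why the memory column depends on $D$ and $D_{kv}$ while the FLOPs column also scales with $B$.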

## Attention

Operation | Parameters | Inputs | Outputs | Complexity (FLOPs) | Complexity (Memory)
--- | --- | --- | --- | --- | ---
`rope` | `freq_cis_real` ($L_{max} \times \frac{D}{2}$)<br>`freq_cis_imag` ($L_{max} \times \frac{D}{2}$) | ?? | ?? | ?? | ??
`matmul` | - | `state->q` ($B \times D$)<br>`context` ($B \times L \times 2D_{kv}$) | `state->att` ($B \times N_{h} \times L$) | $\mathcal{O}(BDL)$ | $\mathcal{O}(BDL)$
`softmax` | - | `state->att` ($B \times N_{h} \times L$) | `state->att` ($B \times N_{h} \times L$) | $\mathcal{O}(BN_{h}L)$ | $\mathcal{O}(BN_{h}L)$
`matmul` | - | `state->att` ($B \times N_{h} \times L$)<br> `context` ($B \times L \times 2D_{kv}$) | `state->x` ($B \times D$) | $\mathcal{O}(BDL)$ | $\mathcal{O}(BDL)$
Total | - | - | - | $\mathcal{O}(BDL)$ | $\mathcal{O}(BDL)$
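
As a rough sketch of the two `matmul` rows and the `softmax` row above, the following NumPy code computes attention for one new token per batch element against $L$ cached positions. Several things here are assumptions made for illustration: the `rope` step is omitted, the last axis of `context` is taken to hold the $D_{kv}$ keys followed by the $D_{kv}$ values, scores are scaled by the square root of the head size, and key/value heads are shared across query heads when $D_{kv} < D$. The function names are hypothetical.

```python
import numpy as np

def softmax(a, axis=-1):
    """Numerically stable softmax along the given axis."""
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, context, num_heads):
    """q: (B, D); context: (B, L, 2 * D_kv), keys then values (assumed layout)."""
    B, D = q.shape
    L = context.shape[1]
    D_kv = context.shape[2] // 2
    head = D // num_heads            # per-head width
    kv_heads = D_kv // head          # number of key/value heads
    group = num_heads // kv_heads    # query heads sharing one key/value head
    k = context[..., :D_kv].reshape(B, L, kv_heads, head)
    v = context[..., D_kv:].reshape(B, L, kv_heads, head)
    k = np.repeat(k, group, axis=2)  # (B, L, N_h, head)
    v = np.repeat(v, group, axis=2)
    qh = q.reshape(B, num_heads, head)
    # Scores: one dot product of width `head` per (batch, head, position): O(BDL)
    att = np.einsum("bhd,blhd->bhl", qh, k) / np.sqrt(head)
    att = softmax(att, axis=-1)      # (B, N_h, L), O(B N_h L)
    # Weighted sum of values: O(BDL)
    out = np.einsum("bhl,blhd->bhd", att, v)
    return out.reshape(B, D), att

# Hypothetical sizes, chosen only for illustration.
B, D, D_kv, N_h, L = 2, 8, 4, 4, 5
rng = np.random.default_rng(0)
q = rng.standard_normal((B, D))
context = rng.standard_normal((B, L, 2 * D_kv))
out, att = attention(q, context, N_h)
print(out.shape, att.shape)  # (2, 8) (2, 4, 5)
```

Both `einsum` calls touch every cached position once per head, so cost grows linearly with the current sequence length $L$.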

## Post-attention

Operation | Parameters | Inputs | Outputs | Complexity (FLOPs) | Complexity (Memory)
--- | --- | --- | --- | --- | ---
`matmul` | `Wo` ($D \times D$) | `state->x` ($B \times D$) | `state->x` ($B \times D$) | $\mathcal{O}(BD^2)$ | $\mathcal{O}(D^2)$
`accum` | - | `state->x` ($B \times D$) | `state->x` ($B \times D$) | $\mathcal{O}(BD)$ | $\mathcal{O}(BD)$
`rmsnorm` | `rms_ffn_weight` ($D \times 1$) | `state->x` ($B \times D$) | `state->x` ($B \times D$) | $\mathcal{O}(BD)$ | $\mathcal{O}(BD)$
`matmul` | `W1` ($D \times D_h$) | `state->x` ($B \times D$) | `state->h` ($B \times D_h$) | $\mathcal{O}(BD_hD)$ | $\mathcal{O}(DD_h)$
`matmul` | `W3` ($D \times D_h$) | `state->x` ($B \times D$) | `state->h` ($B \times D_h$) | $\mathcal{O}(BD_hD)$ | $\mathcal{O}(DD_h)$
`silu` | - | `state->h` ($B \times D_h$) | `state->h` ($B \times D_h$) | $\mathcal{O}(BD_h)$ | $\mathcal{O}(BD_h)$
`matmul` | `W2` ($D_h \times D$) | `state->h` ($B \times D_h$) | `state->x` ($B \times D$) | $\mathcal{O}(BD_hD)$ | $\mathcal{O}(DD_h)$
`accum` | - | `state->x` ($B \times D$) | `state->x` ($B \times D$) | $\mathcal{O}(BD)$ | $\mathcal{O}(BD)$
Total | - | - | - | $\mathcal{O}(BD(D + D_h))$ | $\mathcal{O}(D(D + D_h))$
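
The rows above can be sketched in NumPy as follows. This is a hedged illustration rather than the actual implementation: the function names, the `eps` value, and the sizes are assumptions, and the elementwise product of the `W1` and `W3` branches (a SwiGLU-style gate, as used in LLaMA-family models) is assumed because the table does not list that step explicitly.

```python
import numpy as np

def rmsnorm(x, weight, eps=1e-5):
    """Divide each row of x (B, D) by its root-mean-square, then scale by weight (D,)."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight

def silu(h):
    """SiLU activation: h * sigmoid(h)."""
    return h / (1.0 + np.exp(-h))

def post_attention(x_res, att_out, Wo, rms_ffn_weight, W1, W2, W3):
    """x_res: residual stream (B, D); att_out: attention output (B, D)."""
    x = x_res + att_out @ Wo         # matmul O(B D^2), accum O(BD)
    xn = rmsnorm(x, rms_ffn_weight)  # O(BD)
    h = silu(xn @ W1) * (xn @ W3)    # two matmuls O(B D D_h), gate O(B D_h) (assumed)
    return x + h @ W2                # matmul O(B D_h D), accum O(BD)

# Hypothetical sizes, chosen only for illustration.
B, D, D_h = 2, 8, 16
rng = np.random.default_rng(0)
x_res = rng.standard_normal((B, D))
att_out = rng.standard_normal((B, D))
Wo = rng.standard_normal((D, D))
W1 = rng.standard_normal((D, D_h))
W3 = rng.standard_normal((D, D_h))
W2 = rng.standard_normal((D_h, D))
y = post_attention(x_res, att_out, Wo, np.ones(D), W1, W2, W3)
print(y.shape)  # (2, 8)
```

The three feed-forward matrices each hold $D \times D_h$ weights, so they dominate this stage whenever $D_h > D$.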

## Classify

Operation | Parameters | Inputs | Outputs | Complexity (FLOPs) | Complexity (Memory)
--- | --- | --- | --- | --- | ---
`rmsnorm` | `rms_final_weight` ($D \times 1$) | `state->x` ($B \times D$) | `state->x` ($B \times D$) | $\mathcal{O}(BD)$ | $\mathcal{O}(BD)$
`matmul` | `wcls` ($D \times V$) | `state->x` ($B \times D$) | `state->logits` ($B \times V$) | $\mathcal{O}(BDV)$ | $\mathcal{O}(DV)$
Total | - | - | - | $\mathcal{O}(BDV)$ | $\mathcal{O}(DV)$
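
A minimal NumPy sketch of the classifier stage, under the same caveats as before: `rmsnorm` and `classify` are hypothetical names, `eps` and the sizes are assumptions, and the final loop only tallies multiply-adds and weight elements to show how the two columns of the table scale.

```python
import numpy as np

def rmsnorm(x, weight, eps=1e-5):
    """Divide each row of x (B, D) by its root-mean-square, then scale by weight (D,)."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight

def classify(x, rms_final_weight, wcls):
    """Map final token states x (B, D) to logits (B, V)."""
    return rmsnorm(x, rms_final_weight) @ wcls  # O(BDV) FLOPs, reads D*V weights

# Hypothetical sizes, chosen only for illustration.
B, D, V = 2, 8, 32
rng = np.random.default_rng(0)
x = rng.standard_normal((B, D))
wcls = rng.standard_normal((D, V))
logits = classify(x, np.ones(D), wcls)
print(logits.shape)  # (2, 32)

# Multiply-adds grow with the batch size; the weight matrix does not.
for b in (1, 8):
    print(f"B={b}: multiply-adds={b * D * V}, weight elements={D * V}")
```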
