Why implement an LSTM in 2025? Mostly as an exercise to improve my GPU programming and algorithmic reasoning skills, given the non-trivial dependency structure within an LSTM. There also wasn't, to the best of my knowledge, a fast open-source LSTM implementation.
This repo attempts to provide a reasonably performant LSTM implementation using PyTorch and Triton. The next section contains a speed comparison with `nn.LSTM`. After that, we go into the implementation details, beginning with the relevant equations and followed by comments on the high-level design choices.
The tables below compare median runtimes, computed with `triton.testing.do_bench` on an RTX 2000 Ada GPU, between FastRNN and `nn.LSTM` in single precision.
More figures and the raw data can be found in ./figures.
The numbers here are not perfectly correlated with the iterations/second produced in scripts/overfit.py.
When re-running the benchmark without re-tuning the kernels or the graph/persistent-kernel boundary, I find:
- In the forward pass, there are up to 30% gains when using bf16 or fp16 in FastLSTM
- Combining forward and backward pass, the gains are on the order of 40%
- `nn.LSTM` in full bfloat16 seems broken: in most configurations it runs a few times slower than the single-precision version. The fp16 version runs as expected. Using `torch.autocast` alleviates the issue (see the snippet below)
- `nn.LSTM` gains slightly more from fp16 than `FastLSTM`, but that might be down to the lack of re-tuning
I have not investigated lower-precision dtypes such as fp8 or below.
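For reference, the `torch.autocast` workaround mentioned above amounts to something like this (a minimal sketch; the sizes are arbitrary):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=256, hidden_size=512, num_layers=1).cuda()
x = torch.randn(128, 16, 256, device="cuda")  # (seq_len, batch, input_size)

# Running the single-precision module under autocast, instead of converting
# the weights to bfloat16, avoids the slow full-bf16 path.
with torch.autocast("cuda", dtype=torch.bfloat16):
    out, (h, c) = lstm(x)
```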
An LSTM is a recurrent neural net that can't be (fully) parallelised across the time dimension. The forward pass is given by:
$i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \equiv \sigma(pi_t)$

$f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \equiv \sigma(pf_t)$

$g_t = \text{tanh}(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \equiv \text{tanh}(pg_t)$

$o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \equiv \sigma(po_t)$

$c_t = f_t \cdot c_{t-1} + i_t \cdot g_t$

$h_t = o_t \cdot \text{tanh}(c_t)$
Where:
- $x_t$ is the input at time $t$, $h_t$ the hidden state and $c_t$ the cell state
- $\sigma$ is the sigmoid function
- the $W$ matrices and $b$ vectors are the learnable weights and biases

To simplify things, we'll define the pre-activations $pi_t, pf_t, pg_t, po_t$ as above and write $\text{ifgo}_t = [pi_t, pf_t, pg_t, po_t]$ for their concatenation.
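For reference, a naive PyTorch version of this recurrence might look as follows. This is only a sketch: the gate order i, f, g, o and the weight shapes follow `nn.LSTM`'s convention, but the function name and signature are illustrative, not the repo's API.

```python
import torch

def lstm_forward_reference(x, h0, c0, Wx, Wh, bx, bh):
    """Naive per-time-step LSTM forward, mirroring the equations above.

    x:  (seq_len, batch, input_size)
    Wx: (4*hidden, input_size), Wh: (4*hidden, hidden), gate order i, f, g, o
    bx, bh: (4*hidden,)
    """
    hidden = h0.shape[-1]
    h, c = h0, c0
    outputs = []
    for x_t in x:
        # pre-activations ifgo_t = [pi_t, pf_t, pg_t, po_t]
        ifgo = x_t @ Wx.T + bx + h @ Wh.T + bh
        pi, pf, pg, po = ifgo.split(hidden, dim=-1)
        i, f, g, o = pi.sigmoid(), pf.sigmoid(), pg.tanh(), po.sigmoid()
        c = f * c + i * g
        h = o * c.tanh()
        outputs.append(h)
    return torch.stack(outputs), (h, c)
```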
For the backward pass, the derivatives are:
- $d h_t = d h_t + d h_{t+1}$: the hidden state at a given time step gets gradient contributions from the next layer as well as from the next time step
- $d c_t = d c_t + d h_t \, \sigma(po_t) (1 - \text{tanh}^2(c_t))$
- $d po_t = d h_t \, \text{tanh}(c_t) \, \sigma(po_t) (1 - \sigma(po_t))$
- $d c_{t-1} = d c_t \, \sigma(pf_t)$
- $d pi_t = d c_t \, \text{tanh}(pg_t) \, \sigma(pi_t) (1 - \sigma(pi_t))$
- $d pf_t = d c_t \, c_{t-1} \, \sigma(pf_t) (1 - \sigma(pf_t))$
- $d pg_t = d c_t \, \sigma(pi_t) (1 - \text{tanh}^2(pg_t))$
- $d \text{ifgo}_t = [d pi_t, d pf_t, d pg_t, d po_t]$
- $d h_{t-1} = d \text{ifgo}_t \, W_h$
- $d x_t = d \text{ifgo}_t \, W_x$
- $d W_x \mathrel{+}= d \text{ifgo}_t^T x_t$
- $d W_h \mathrel{+}= d \text{ifgo}_t^T h_{t-1}$
- $d b_{x|h} \mathrel{+}= d \text{ifgo}_t.\text{sum}(\text{batch})$
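When wiring these equations into a custom backward, one way to sanity-check them is a finite-difference comparison, e.g. `torch.autograd.gradcheck` applied to the naive reference above (or to the custom op once it exists). A sketch; double precision and tiny sizes keep the check stable:

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 2, 3, dtype=torch.double, requires_grad=True)
h0 = torch.zeros(2, 4, dtype=torch.double)
c0 = torch.zeros(2, 4, dtype=torch.double)
Wx = torch.randn(16, 3, dtype=torch.double, requires_grad=True)
Wh = torch.randn(16, 4, dtype=torch.double, requires_grad=True)
bx = torch.randn(16, dtype=torch.double, requires_grad=True)
bh = torch.randn(16, dtype=torch.double, requires_grad=True)

# Compares analytical gradients against finite differences for all inputs
# that require grad; here applied to the reference implementation above.
torch.autograd.gradcheck(
    lambda *args: lstm_forward_reference(*args)[0],
    (x, h0, c0, Wx, Wh, bx, bh),
)
```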
For the implementation of the forward pass, we begin by computing the parallelisable part: the input-to-hidden matmuls $W_x x_t$ for all time steps in one large matmul. One caveat regarding precision: cuDNN's LSTM uses TF32 by default in single precision (controlled via `torch.backends.cudnn.rnn.fp32_precision = "tf32"`), whereas you need to turn TF32 on explicitly for matmuls in order to be competitive.
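As a sketch of both points (with illustrative names; the actual weight layout in the repo may differ):

```python
import torch

# TF32 must be enabled explicitly for plain matmuls; cuDNN's RNN path
# already uses it by default in fp32.
torch.backends.cuda.matmul.allow_tf32 = True  # or torch.set_float32_matmul_precision("high")

def input_projection(x, Wx, bx):
    # x: (seq_len, batch, input_size) -> (seq_len, batch, 4*hidden)
    # No recurrent dependency here, so all time steps go through one big matmul.
    seq_len, batch, _ = x.shape
    return torch.addmm(bx, x.reshape(seq_len * batch, -1), Wx.T).view(seq_len, batch, -1)
```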
Next, we need to work on the second, non-parallelisable component of the forward pass: the recurrent channel mixing $W_h h_{t-1}$ and the gate computations, which have to be evaluated one time step at a time. There are two ways to handle the sequential dependency:
- using a one-time-step kernel that gets called n times; synchronization occurs naturally between kernel launches
- using a persistent kernel approach where the seq_len loop is within the kernel and the channel sync is implemented by atomics
Trade-offs:
- with the one-step kernel, we pay n launch overheads. This can be mitigated with CUDA graphs (though you still pay the capture cost once). The advantage is that each such kernel is close to a standard matmul and thus efficient
- with the persistent kernel, we pay the launch overhead only once and don't need to capture a graph either. The downside is that you might not use all resources efficiently (e.g. how do you tile your (batch-size, hidden-size) problem to fully use the 22 SMs of an RTX 2000 Ada?). In addition, for large hidden-sizes you might need to work with suboptimal tiles (long hidden and short batch chunks) in order to prevent deadlocks
- In practice, the persistent kernel does well for smaller problem sizes and the graph for larger ones, so you want to combine both (a minimal capture sketch for the one-step variant follows below)
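For the one-step-kernel path, capturing the per-step launches into a CUDA graph could look roughly like this. A sketch: `step_fn` stands in for the one-time-step kernel and is assumed to update `h` and `c` in place on static buffers.

```python
import torch

def capture_step_loop(step_fn, h, c, ifgo_all):
    # ifgo_all, h and c must be static buffers: replaying the graph reuses
    # their addresses, so they are refilled in place before each replay.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        step_fn(ifgo_all[0], h, c)   # warm-up so lazy init is not captured
    torch.cuda.current_stream().wait_stream(side)
    # (in practice, h and c would be reset to their initial values here)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for ifgo_t in ifgo_all:      # n launches recorded, replayed as one
            step_fn(ifgo_t, h, c)
    return graph                     # later: graph.replay()
```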
To finish the forward pass, we need to fuse all non-linearities and pointwise operations into the aforementioned kernels.
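As an illustration of the pointwise part in isolation, a standalone Triton sketch is shown below; in the actual implementation these operations are fused into the matmul / persistent kernels rather than launched separately, and the names here are illustrative.

```python
import triton
import triton.language as tl

@triton.jit
def _tanh(x):
    # tanh via the sigmoid identity, avoiding libdevice wrappers
    return 2 * tl.sigmoid(2 * x) - 1

@triton.jit
def lstm_pointwise(ifgo_ptr, c_prev_ptr, c_out_ptr, h_out_ptr,
                   hidden_size, n_elem, BLOCK: tl.constexpr):
    # One program handles BLOCK elements of the flattened (batch, hidden) output.
    # Launch: lstm_pointwise[(triton.cdiv(n_elem, BLOCK),)](...)
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elem
    row = offs // hidden_size                 # batch index
    col = offs % hidden_size                  # channel index
    base = row * 4 * hidden_size + col        # i-gate slot in the (batch, 4*hidden) layout
    pi = tl.load(ifgo_ptr + base, mask=mask)
    pf = tl.load(ifgo_ptr + base + hidden_size, mask=mask)
    pg = tl.load(ifgo_ptr + base + 2 * hidden_size, mask=mask)
    po = tl.load(ifgo_ptr + base + 3 * hidden_size, mask=mask)
    c_prev = tl.load(c_prev_ptr + offs, mask=mask)
    c = tl.sigmoid(pf) * c_prev + tl.sigmoid(pi) * _tanh(pg)
    h = tl.sigmoid(po) * _tanh(c)
    tl.store(c_out_ptr + offs, c, mask=mask)
    tl.store(h_out_ptr + offs, h, mask=mask)
```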
To avoid recomputation, we first need to decide which parts of the forward pass we want to reuse. Looking at the derivatives points to caching the gate pre-activations $\text{ifgo}_t$ and the cell states $c_t$, since both appear throughout the backward equations.
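One way to thread these cached tensors through to the backward is a `torch.autograd.Function`; in the sketch below the kernel wrappers `_fused_forward` and `_fused_backward` are hypothetical placeholders, not the repo's actual entry points.

```python
import torch

class FusedLSTMFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, h0, c0, Wx, Wh, bx, bh):
        # ifgo: pre-activations for every time step, c_all: all cell states;
        # both come out of the forward kernels and are cached for the backward.
        out, h_n, c_n, ifgo, c_all = _fused_forward(x, h0, c0, Wx, Wh, bx, bh)  # hypothetical
        ctx.save_for_backward(x, h0, c0, Wx, Wh, ifgo, c_all)
        return out, h_n, c_n

    @staticmethod
    def backward(ctx, d_out, d_hn, d_cn):
        x, h0, c0, Wx, Wh, ifgo, c_all = ctx.saved_tensors
        # Must return one gradient per forward input:
        # (d_x, d_h0, d_c0, d_Wx, d_Wh, d_bx, d_bh)
        return _fused_backward(d_out, d_hn, d_cn, x, h0, c0, Wx, Wh, ifgo, c_all)  # hypothetical
```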
Let's start with the recurrent part of the backward pass. To compute $d h_{t-1}$, the channel mixing goes from (batch-size, 4·hidden-size) to (batch-size, hidden-size), i.e. there are fewer parallelisation options across SMs than in the forward pass. Regarding the one-step kernel, there is a small annoyance: in the forward pass we had channel mixing followed by non-linearities and pointwise operations, and in this order it's easy to fuse the two steps. In the backward pass it is naturally the other way around, which requires either an extra sync, done by breaking the one-step kernel into two pieces (h-grad and then ifgo-grad), or you 'shift the kernel by half a time step' and add some extra code for the boundary conditions at the beginning and end of the sequence. The latter option runs slightly faster than the two-kernel approach.
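Once $d\,\text{ifgo}_t$ is available for every time step, the weight, bias and input gradients reduce to large batched contractions that need no sequential kernel. A sketch with assumed shapes (names are illustrative):

```python
import torch

def weight_and_input_grads(d_ifgo, x, h_prev, Wx):
    # d_ifgo: (seq_len, batch, 4*hidden); x: (seq_len, batch, input_size)
    # h_prev: (seq_len, batch, hidden) holding h_{t-1} for each step t
    d_Wx = torch.einsum("tbg,tbi->gi", d_ifgo, x)       # dWx += d ifgo_t^T x_t
    d_Wh = torch.einsum("tbg,tbh->gh", d_ifgo, h_prev)  # dWh += d ifgo_t^T h_{t-1}
    d_b = d_ifgo.sum(dim=(0, 1))                        # shared by b_x and b_h
    d_x = torch.einsum("tbg,gi->tbi", d_ifgo, Wx)       # dx_t = d ifgo_t Wx
    return d_Wx, d_Wh, d_b, d_x
```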
- to use autotune, your kernels must be idempotent!
- use atomics to implement cross-program synchronization (but be careful with deadlocks)
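For the atomics point, the basic signal/wait pattern might look like this. A minimal sketch only: the hard part, not shown here, is choosing a tile assignment such that every producer a consumer waits on is actually resident, otherwise the spin loop deadlocks.

```python
import triton
import triton.language as tl

@triton.jit
def _signal(flag_ptr, step):
    # Producer: publish that time step `step` has been fully written.
    tl.atomic_max(flag_ptr, step)

@triton.jit
def _wait(flag_ptr, step):
    # Consumer: spin until the producer has published at least `step`.
    # atomic_max with 0 is used as an atomic read of the flag.
    seen = tl.atomic_max(flag_ptr, 0)
    while seen < step:
        seen = tl.atomic_max(flag_ptr, 0)
```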
- only tuned for RTX 2000 Ada
- `nn.LSTM` has extra options `bias`, `batch_first`, `dropout`, `bidirectional`, `proj_size` that `FastLSTM` doesn't support yet
- can't operate on `PackedSequence` input
- accuracy of half precision might be suboptimal: I did not think carefully about whether some components need to remain in single precision
- doesn't exploit the `num_layers` parallelisation dimension (which `nn.LSTM` does!)
- not battle-tested
FlashRNN attempts to provide, among other things, a fast LSTM implementation. However, as their figure 2 demonstrates, the implementation is much slower than `nn.LSTM`, and only by introducing a lot of sparsity can they match its runtime. E.g. for batch-size 16, seq-len 1024, hidden-size 768 on an H100, the parameter count needs to be reduced by a factor of 24 to match `nn.LSTM`'s runtime.

