JAX Online Softmax

This is a JAX implementation of "Online normalizer calculation for softmax" (Milakov & Gimelshein, 2018). The idea here was to experiment with ways of making attention slightly faster. online_softmax computes the max value c = max(x) and the normalizer n = sum(exp(x - c)) in a single parallel reduction.
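For reference, the single-pass recurrence from the paper can be sketched with jax.lax.associative_scan. This is a minimal illustration of the technique, not necessarily how the repo's online_softmax is implemented:

```python
import jax
import jax.numpy as jnp

def _merge(a, b):
    # Merge two partial (running max, running normalizer) states.
    # Rescaling each partial sum by exp(old_max - new_max) re-expresses
    # it relative to the new shared maximum.
    m_a, d_a = a
    m_b, d_b = b
    m = jnp.maximum(m_a, m_b)
    d = d_a * jnp.exp(m_a - m) + d_b * jnp.exp(m_b - m)
    return m, d

@jax.jit
def online_softmax(x):
    # Each element i starts as the state (x_i, 1.0), since exp(x_i - x_i) = 1.
    # One associative reduction then yields c = max(x) and n = sum(exp(x - c)).
    m, d = jax.lax.associative_scan(_merge, (x, jnp.ones_like(x)))
    c, n = m[-1], d[-1]
    return jnp.exp(x - c) / n
```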

online_softmax_dot combines a softmax and a dot product (i.e. sum(softmax(x) * v)) into a single parallel reduction.
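A sketch of this fused reduction follows the same pattern, with a third accumulator carrying the (rescaled) weighted sum. Again, this illustrates the idea rather than reproducing the repo's exact code:

```python
import jax
import jax.numpy as jnp

def _merge(a, b):
    # Merge two partial (max, normalizer, unnormalized weighted sum) states,
    # rescaling both running sums to the new shared maximum.
    m_a, d_a, o_a = a
    m_b, d_b, o_b = b
    m = jnp.maximum(m_a, m_b)
    s_a = jnp.exp(m_a - m)
    s_b = jnp.exp(m_b - m)
    return m, d_a * s_a + d_b * s_b, o_a * s_a + o_b * s_b

@jax.jit
def online_softmax_dot(x, v):
    # Element i starts as (x_i, 1.0, v_i). A single reduction produces the
    # normalizer and the weighted sum together, so the full softmax(x)
    # vector is never materialized.
    m, d, o = jax.lax.associative_scan(_merge, (x, jnp.ones_like(x), v))
    return o[-1] / d[-1]

# Sanity check against the naive two-pass version:
# x, v = jnp.arange(8.0), jnp.ones(8)
# assert jnp.allclose(online_softmax_dot(x, v), jnp.dot(jax.nn.softmax(x), v))
```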

As it turns out, neither of these is reliably faster than the naive jitted JAX implementations; however, online_softmax_dot should be more memory-efficient, since it never materializes the full softmax output. Unfortunately, for reasons that are somewhat unclear to me, it also has numerical stability issues for large input sizes.

online_softmax_dot could potentially have a faster backward pass as well, but since the forward-pass results aren't entirely promising, I'm not going to implement the backward pass.

[Figure: Online Softmax benchmark plot]

[Figure: Online Softmax dot product benchmark plot]

These benchmarks show similar results on both TPUv2 and a T4 GPU.

Citations

@misc{https://doi.org/10.48550/arxiv.1805.02867,
  doi = {10.48550/ARXIV.1805.02867},
  url = {https://arxiv.org/abs/1805.02867},
  author = {Milakov, Maxim and Gimelshein, Natalia},
  keywords = {Performance (cs.PF), Artificial Intelligence (cs.AI), Computation and Language (cs.CL), FOS: Computer and information sciences},
  title = {Online normalizer calculation for softmax},
  publisher = {arXiv},
  year = {2018},
  copyright = {arXiv.org perpetual, non-exclusive license}
}
