
Overview of the fastest CPU RNN implementations #228

@mratsim

RNNs, and particularly LSTMs and GRUs, have made a significant contribution to deep learning applications.

They are the default tool for natural language processing, and they are heavily explored in reinforcement learning, in many combined visual and text tasks, and in time-series prediction (though there they compete with WaveNets).

The CuDNN implementation is already heavily optimized; however, the CPU implementation should be as fast as possible too.
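
One widely used way to speed up a recurrent layer is to notice that the input projections (the W_{ir} x, W_{iz} x and W_{in} x terms) do not depend on the hidden state. For a whole sequence they can therefore be computed up front as one matrix-matrix product, leaving only the recurrent products W_{h·} h inside the sequential time loop. The sketch below is a minimal illustration under stated assumptions: it uses a plain tanh RNN cell instead of a GRU for brevity, pure Python lists instead of a BLAS GEMM (so it demonstrates equivalence, not speed), and the function names are hypothetical.

```python
import math

def matvec(W, v):
    # Dense matrix (list of rows) times vector.
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def matmul(A, B):
    # A: m x k, B: k x n, both as lists of rows.
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]

def rnn_naive(W, U, b, xs, h0):
    # Assumption: plain tanh RNN, h_t = tanh(W x_t + U h_{t-1} + b), not a GRU.
    h, out = h0, []
    for x in xs:
        # Input projection recomputed inside the sequential loop.
        h = [math.tanh(a + c + bi) for a, c, bi in zip(matvec(W, x), matvec(U, h), b)]
        out.append(h)
    return out

def rnn_hoisted(W, U, b, xs, h0):
    # All input projections at once: (T x in) @ (in x hidden) -> T x hidden.
    W_T = [list(col) for col in zip(*W)]
    proj = matmul(xs, W_T)
    h, out = h0, []
    for p in proj:
        # Only the recurrent product U h_{t-1} stays in the time loop.
        h = [math.tanh(a + c + bi) for a, c, bi in zip(p, matvec(U, h), b)]
        out.append(h)
    return out

# Hypothetical weights: 3 inputs, 2 hidden units, 3 timesteps.
W = [[0.5, -0.2, 0.1], [0.3, 0.0, 0.7]]
U = [[0.9, -0.4], [0.2, 0.6]]
b = [0.1, -0.1]
xs = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0], [0.0, 1.0, 1.0]]
h0 = [0.0, 0.0]
print(rnn_naive(W, U, b, xs, h0))
print(rnn_hoisted(W, U, b, xs, h0))
```

For a GRU, the hoisted product would typically cover all three gates at once by stacking W_{ir}, W_{iz} and W_{in} into a single matrix.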

General overview

  • GRU Paper
  • CS231n 2017 - lecture 10
  • Colah tutorial
  • Towards Data Science
  • TensorFlow vs PyTorch/CuDNN
    TensorFlow equations
    r = sigmoid(W_{ir} x + b_{ir} + W_{hr} h + b_{hr})
    z = sigmoid(W_{iz} x + b_{iz} + W_{hz} h + b_{hz})
    n = tanh(W_{in} x + b_{in} + W_{hn} (r * h) + b_{hn})
    h' = (1 - z) * n + z * h
    
    PyTorch equations
    r = sigmoid(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) 
    z = sigmoid(W_{iz} x + b_{iz} + W_{hz} h + b_{hz})
    n = tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn}))
    h' = (1 - z) * n + z * h
    
    Note that in the paper the equations are:
    r = sigmoid(W_{ir} x + b_{ir} + W_{hr} h + b_{hr})
    z = sigmoid(W_{iz} x + b_{iz} + W_{hz} h + b_{hz})
    n = tanh(W_{in} x + b_{in} + W_{hn} (r * h) + b_{hn})
    h' = (1 - z) * h + z * n
    
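    And the CuDNN equations (i_t plays the role of z and h'_t the role of n above):
    i_t = σ(W_i x_t + R_i h_{t-1} + b_{Wi} + b_{Ri})
    r_t = σ(W_r x_t + R_r h_{t-1} + b_{Wr} + b_{Rr})
    h'_t = tanh(W_h x_t + r_t ◦ (R_h h_{t-1} + b_{Rh}) + b_{Wh})
    h_t = (1 - i_t) ◦ h'_t + i_t ◦ h_{t-1}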
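
The practical difference between these variants is where the reset gate is applied. The following minimal pure-Python sketch of one GRU step uses the symbols above and implements both the TensorFlow candidate (reset applied to h before W_{hn}) and the PyTorch/CuDNN candidate (reset applied to W_{hn} h + b_{hn}). The function name gru_step, the dict-of-weights layout and the example weights are hypothetical.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def matvec(W, v):
    # Dense matrix (list of rows) times vector.
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def add(*vs):
    # Element-wise sum of equally sized vectors.
    return [sum(t) for t in zip(*vs)]

def gru_step(x, h, p, reset_after):
    """One GRU step with h' = (1 - z) * n + z * h.
    reset_after=True:  PyTorch/CuDNN candidate, r * (W_hn h + b_hn)
    reset_after=False: TensorFlow candidate,    W_hn (r * h) + b_hn
    """
    r = [sigmoid(a) for a in add(matvec(p["W_ir"], x), p["b_ir"], matvec(p["W_hr"], h), p["b_hr"])]
    z = [sigmoid(a) for a in add(matvec(p["W_iz"], x), p["b_iz"], matvec(p["W_hz"], h), p["b_hz"])]
    if reset_after:
        rec = [ri * a for ri, a in zip(r, add(matvec(p["W_hn"], h), p["b_hn"]))]
    else:
        rec = add(matvec(p["W_hn"], [ri * hi for ri, hi in zip(r, h)]), p["b_hn"])
    n = [math.tanh(a) for a in add(matvec(p["W_in"], x), p["b_in"], rec)]
    return [(1 - zi) * ni + zi * hi for zi, ni, hi in zip(z, n, h)]

# Hypothetical weights for a 1-input, 1-unit cell, chosen only to exercise the code.
PARAMS = {
    "W_ir": [[0.5]], "W_hr": [[-0.3]], "b_ir": [0.1], "b_hr": [0.2],
    "W_iz": [[0.4]], "W_hz": [[0.7]], "b_iz": [0.0], "b_hz": [-0.1],
    "W_in": [[0.9]], "W_hn": [[-0.6]], "b_in": [0.05], "b_hn": [0.5],
}

x, h = [1.0], [0.5]
print(gru_step(x, h, PARAMS, reset_after=True))
print(gru_step(x, h, PARAMS, reset_after=False))
```

With a single hidden unit there is no matrix mixing, so setting b_{hn} to zero makes the two variants agree, while b_{hn} = 0.5 makes them differ, because only the PyTorch/CuDNN form scales b_{hn} by r. The paper's update h' = (1 - z) * h + z * n is the same model with the update gate's weights and biases negated, since 1 - sigmoid(a) = sigmoid(-a).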
    

Readable implementation

"Unreadable" C++ implementations (static graphs)

Benchmarks

Unfortunately, only GPU benchmarks are available:

Optimized implementations

Note on biases and equations

The various implementations disagree on the number of biases and on the choice of equations.

  • WildML, Keras and Neon use 1 bias per equation.
  • Chainer, Torch and CuDNN use 2 biases per equation (one for the input term, one for the recurrent term).
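
The bias disagreement is harmless for the r and z gates, where the two biases only ever appear as a sum, so b_{ir} + b_{hr} can be folded into a single bias. It does matter for the candidate n in the PyTorch/CuDNN equations, where b_{hn} sits inside the product with r. A minimal scalar sketch for one hidden unit, where wx and wh are hypothetical names for the already computed W x and W h pre-activations:

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def gate_two_bias(wx, wh, b_i, b_h):
    # Chainer/Torch/CuDNN style: separate input and recurrent biases.
    return sigmoid(wx + b_i + wh + b_h)

def gate_one_bias(wx, wh, b):
    # WildML/Keras/Neon style: one bias per equation.
    return sigmoid(wx + wh + b)

def candidate_two_bias(wx, wh, r, b_in, b_hn):
    # PyTorch/CuDNN candidate: b_hn is scaled by the reset gate r.
    return math.tanh(wx + b_in + r * (wh + b_hn))

def candidate_folded(wx, wh, r, b_n):
    # Naive fold: b_n = b_in + b_hn placed outside the reset product.
    return math.tanh(wx + b_n + r * wh)

# Folding is exact for the r and z gates...
print(gate_two_bias(0.3, -0.2, 0.1, 0.4), gate_one_bias(0.3, -0.2, 0.1 + 0.4))
# ...but not for the candidate n unless r == 1.
print(candidate_two_bias(0.3, -0.2, 0.6, 0.1, 0.4), candidate_folded(0.3, -0.2, 0.6, 0.1 + 0.4))
```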

To allow the same weights to be loaded on both CPU and GPU, it would be best to use the same equations (and the same 2-bias layout) as CuDNN.
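
Using CuDNN's 2-bias form also makes loading 1-bias weights straightforward for the gates: keep the single bias as b_W and zero-fill b_R. What bias mapping cannot fix in general is a difference in where the reset gate is applied, because W_{hn} (r * h) and r * (W_{hn} h) differ once W_{hn} mixes hidden units. A minimal pure-Python sketch; the helper names to_two_bias, gate_one_bias, gate_two_bias and recurrent_term are hypothetical.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def matvec(W, v):
    # Dense matrix (list of rows) times vector.
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def to_two_bias(b_one):
    # Hypothetical loader step: keep the single bias as the input bias b_W
    # and zero-fill the recurrent bias b_R (CuDNN-style pair).
    return list(b_one), [0.0] * len(b_one)

def gate_one_bias(W_i, W_h, b, x, h):
    # 1-bias form: sigmoid(W_i x + W_h h + b)
    return [sigmoid(a + c + bi) for a, c, bi in zip(matvec(W_i, x), matvec(W_h, h), b)]

def gate_two_bias(W_i, W_h, b_W, b_R, x, h):
    # CuDNN form: sigmoid(W_i x + W_h h + b_W + b_R)
    return [sigmoid(a + c + bw + br)
            for a, c, bw, br in zip(matvec(W_i, x), matvec(W_h, h), b_W, b_R)]

def recurrent_term(W_hn, r, h, reset_after):
    # reset_after=True:  r * (W_hn h)  (PyTorch/CuDNN candidate, bias omitted)
    # reset_after=False: W_hn (r * h)  (TensorFlow/paper candidate, bias omitted)
    if reset_after:
        return [ri * a for ri, a in zip(r, matvec(W_hn, h))]
    return matvec(W_hn, [ri * hi for ri, hi in zip(r, h)])

# Hypothetical 2-unit example.
W_i = [[0.5, -0.2], [0.1, 0.3]]
W_h = [[0.9, -0.4], [0.2, 0.6]]
b = [0.1, -0.3]
x, h = [1.0, 2.0], [0.5, -0.5]
b_W, b_R = to_two_bias(b)
print(gate_one_bias(W_i, W_h, b, x, h), gate_two_bias(W_i, W_h, b_W, b_R, x, h))

swap = [[0.0, 1.0], [1.0, 0.0]]  # W_hn that mixes the two hidden units
print(recurrent_term(swap, [0.2, 0.9], [1.0, 2.0], True),
      recurrent_term(swap, [0.2, 0.9], [1.0, 2.0], False))
```

The first comparison shows why zero-filling b_R is safe for the gates; the second shows why weights trained with the reset applied before W_{hn} cannot be loaded into the CuDNN equations by reshuffling biases alone.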

List of relevant issues:
