huntzhan/pytorch-stateful-lstm

pytorch-stateful-lstm

  • Free software: MIT license

Features

A PyTorch LSTM implementation powered by LibTorch, with support for:

  • Hidden/Cell State Clipping.
  • Skip Connections.
  • Variational Dropout & DropConnect.
  • Managed Initial State.
  • Built-in Truncated Backpropagation Through Time (TBPTT).
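
The list names the techniques without showing them. As a rough, hypothetical illustration only, the following sketch implements three of them in plain PyTorch with torch.nn.LSTMCell: hidden/cell state clipping, variational dropout on the recurrent hidden state (one mask reused at every time step), and TBPTT with the state carried across chunks. It is not this package's implementation or API; the helper names run_chunk and train_tbptt, the clip value 3.0, the keep probability, and the squared-error loss are all assumptions chosen for the example.

```python
import torch
import torch.nn as nn


def run_chunk(cell, inputs, state, mask, cell_clip=3.0, hidden_clip=3.0):
    """Run an nn.LSTMCell over a chunk of shape (time, batch, input_size).

    `mask` is a variational dropout mask of shape (batch, hidden_size):
    it is sampled once per sequence and reused at every time step.
    (Hypothetical helper, not part of pytorch-stateful-lstm.)
    """
    h, c = state
    outputs = []
    for x_t in inputs:  # iterate over the time dimension
        # Variational dropout on the recurrent input: same mask every step.
        h, c = cell(x_t, (h * mask, c))
        # Clip cell and hidden states to a fixed range.
        c = c.clamp(-cell_clip, cell_clip)
        h = h.clamp(-hidden_clip, hidden_clip)
        outputs.append(h)
    return torch.stack(outputs), (h, c)


def train_tbptt(cell, sequence, targets, chunk_len, keep_prob=0.5, lr=0.1):
    """Truncated BPTT: one backward/update per chunk, state carried over."""
    _, batch, _ = sequence.shape
    h = sequence.new_zeros((batch, cell.hidden_size))
    c = sequence.new_zeros((batch, cell.hidden_size))
    # Sample the variational dropout mask once (inverted-dropout scaling).
    mask = torch.bernoulli(sequence.new_full(h.shape, keep_prob)) / keep_prob
    optimizer = torch.optim.SGD(cell.parameters(), lr=lr)
    losses = []
    for start in range(0, sequence.size(0), chunk_len):
        chunk = sequence[start:start + chunk_len]
        target = targets[start:start + chunk_len]
        out, (h, c) = run_chunk(cell, chunk, (h, c), mask)
        loss = ((out - target) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        # Keep the state values for the next chunk but cut the autograd
        # graph, so gradients never flow across chunk boundaries.
        h, c = h.detach(), c.detach()
    return losses, (h, c)


torch.manual_seed(0)
cell = nn.LSTMCell(input_size=3, hidden_size=5)
sequence = torch.rand(20, 4, 3)  # (time, batch, features)
targets = torch.rand(20, 4, 5)   # (time, batch, hidden_size)
losses, (h, c) = train_tbptt(cell, sequence, targets, chunk_len=5)
print(len(losses), tuple(h.shape))
```

Detaching h and c at each chunk boundary is what makes the backpropagation "truncated"; carrying their values forward is what makes the LSTM stateful across chunks.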

Benchmark: https://github.com/cnt-dev/pytorch-stateful-lstm/tree/master/benchmark

Install

Prerequisites: torch>=1.0.0 and a C++11-compatible compiler. To install through pip:

pip install pytorch-stateful-lstm

Usage

Example:

import torch
from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence
from pytorch_stateful_lstm import StatefulUnidirectionalLstm

lstm = StatefulUnidirectionalLstm(
        num_layers=2,
        input_size=3,
        hidden_size=5,
        cell_size=7,
)

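# A batch of 4 sequences (max length 5, 3 features each), packed by length.
# By default, pack_padded_sequence expects lengths in decreasing order.
inputs = pack_padded_sequence(torch.rand(4, 5, 3), [5, 4, 2, 1], batch_first=True)
# The LSTM takes the packed data and the per-time-step batch sizes directly,
# and returns the raw packed outputs plus an LSTM state.
raw_packed_outputs, lstm_state = lstm(
        inputs.data,
        inputs.batch_sizes
)
# Re-wrap the raw outputs so they work with the torch.nn.utils.rnn helpers.
outputs = PackedSequence(raw_packed_outputs, inputs.batch_sizes)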

For parameter definitions, see https://github.com/cnt-dev/pytorch-stateful-lstm/tree/master/extension.
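
The usage example hands the LSTM inputs.data and inputs.batch_sizes rather than the PackedSequence itself. The following torch-only sketch shows what those two tensors hold for the example input, and how a per-time-step output can be re-wrapped and padded back with pad_packed_sequence. The nn.Linear layer named stand_in is a hypothetical substitute for the LSTM (so the snippet runs without this package), and it assumes the raw output has one row per packed time step with hidden_size = 5 columns.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
padded = torch.rand(4, 5, 3)  # (batch, max_len, features)
lengths = [5, 4, 2, 1]        # sorted in decreasing order
inputs = pack_padded_sequence(padded, lengths, batch_first=True)

# `data` stacks the valid time steps time-major: all sequences at t=0,
# then those still active at t=1, and so on. `batch_sizes` counts them.
print(inputs.data.shape)            # torch.Size([12, 3])
print(inputs.batch_sizes.tolist())  # [4, 3, 2, 2, 1]

# Hypothetical stand-in for the LSTM: maps each packed row to 5 features.
stand_in = nn.Linear(3, 5)
raw_packed_outputs = stand_in(inputs.data)

# Re-wrap and pad back to (batch, max_len, hidden_size); padding is zero-filled.
outputs = PackedSequence(raw_packed_outputs, inputs.batch_sizes)
unpacked, out_lengths = pad_packed_sequence(outputs, batch_first=True)
print(unpacked.shape, out_lengths.tolist())  # torch.Size([4, 5, 5]) [5, 4, 2, 1]
```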

Credits

This package was created with Cookiecutter and the audreyr/cookiecutter-pypackage project template.