Skip to content

Commit

Permalink
Merge pull request #218 from ShikharJ/Test
Browse files Browse the repository at this point in the history
* Refactor FastGRNN CUDA Setup

* Update c_reference README.md

* Update pytorch/requirements

* Add FastGRNNCUDA to FastCells and Fix torch.randn() Argument Errors

* Update README
  • Loading branch information
harsha-simhadri committed Dec 20, 2020
2 parents cbba9f8 + f10b009 commit 5f0b6e8
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 40 deletions.
2 changes: 1 addition & 1 deletion c_reference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ and is to be adapted as needed for other embedded platforms.
The `EdgeML/c_reference/` directory is broadly structured into the following sub-directories:

- **include/**: Contains the header files for various lower level operators and layers.
- **models/**: Contains the optimized source code and header files for various models built by stiching together different layers and operators. Also contains the layer weights and hyper-parameters for the corresponding models as well (stored using `Git LFS`).
- **models/**: Contains the optimized source code and header files for various models built by stiching together different layers and operators. Also contains the layer weights and hyper-parameters for the corresponding models as well (stored using `Git LFS`). (**Note:** Cloning the repo without installing `Git LFS` would fail to clone the actual headers. It's recommended to follow instructions on setting up `LFS` from [here](https://git-lfs.github.com/) before cloning.)
- **src/**: Contains the optimized source code files for various lower level operators and layers.
- **tests/**: Contains extensive test cases for individual operators and layers, as well as the implemented models. The executables are generated in the main directory itself, while the test scripts and their configurations can be accessed in the appropriate sub-directories.

Expand Down
10 changes: 5 additions & 5 deletions examples/pytorch/FastCells/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ features like low-rank parameterisation and custom non-linearities. Akin to
Bonsai and ProtoNN, the three-phase training routine for FastRNN and FastGRNN
is decoupled from the custom cells to facilitate a plug and play behaviour of
the custom RNN cells in other architectures (NMT, Encoder-Decoder etc.).
Additionally, numerically equivalent CUDA-based implementations FastRNNCuda
and FastGRNNCuda are provided for faster training.
Additionally, numerically equivalent CUDA-based implementations **FastRNNCUDA**
and **FastGRNNCUDA** are provided for faster training.
`edgeml_pytorch.graph.rnn` also contains modified RNN cells of **UGRNNCell**,
**GRUCell**, and **LSTMCell**, which can be substituted for Fast(G)RNN,
as well as untrolled RNNs which are equivalent to `nn.LSTM` and `nn.GRU`.
Expand Down Expand Up @@ -67,9 +67,9 @@ Final Test Accuracy: 0.9347
Non-Zeros: 1932 Model Size: 7.546875 KB hasSparse: False
```
`usps10/` directory will now have a consolidated results file called `FastRNNResults.txt` or
`FastGRNNResults.txt` depending on the choice of the RNN cell. A directory `FastRNNResults` or
`FastGRNNResults` with the corresponding models with each run of the code on the `usps10` dataset.
`usps10/` directory will now have a consolidated results file called `FastRNNResults.txt`,
`FastGRNNResults.txt` or `FastGRNNCUDAResults.txt` depending on the choice of the RNN cell. A directory `FastRNNResults`,
`FastGRNNResults` or `FastGRNNCUDAResults` with the corresponding models with each run of the code on the `usps10` dataset.

Note that the scalars like `alpha`, `beta`, `zeta` and `nu` correspond to the values before
the application of the sigmoid function.
Expand Down
5 changes: 5 additions & 0 deletions examples/pytorch/FastCells/fastcell_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def main():
gate_nonlinearity=gate_non_linearity,
update_nonlinearity=update_non_linearity,
wRank=wRank, uRank=uRank)
elif cell == "FastGRNNCUDA":
FastCell = FastGRNNCUDACell(inputDims, hiddenDims,
gate_nonlinearity=gate_non_linearity,
update_nonlinearity=update_non_linearity,
wRank=wRank, uRank=uRank)
elif cell == "FastRNN":
FastCell = FastRNNCell(inputDims, hiddenDims,
update_nonlinearity=update_non_linearity,
Expand Down
4 changes: 2 additions & 2 deletions examples/pytorch/FastCells/helpermethods.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def getArgs():
'train.npy and test.npy')

parser.add_argument('-c', '--cell', type=str, default="FastGRNN",
help='Choose between [FastGRNN, FastRNN, UGRNN' +
', GRU, LSTM], default: FastGRNN')
help='Choose between [FastGRNN, FastGRNNCUDA, FastRNN,' +
' UGRNN, GRU, LSTM], default: FastGRNN')

parser.add_argument('-id', '--input-dim', type=checkIntNneg, required=True,
help='Input Dimension of RNN, each timestep will ' +
Expand Down
1 change: 1 addition & 0 deletions pytorch/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ Install appropriate CUDA and cuDNN [Tested with >= CUDA 8.1 and cuDNN >= 6.1]
```
pip install -r requirements-gpu.txt
pip install -e .
pip install -e edgeml_pytorch/cuda/
```

**Note**: For using the optimized FastGRNNCUDA implementation, it is recommended to use CUDA v10.1, gcc 7.5 and cuDNN v7.6 and torch==1.4.0. Also, there are some known issues when compiling custom CUDA kernels on Windows [pytorch/#11004](https://github.com/pytorch/pytorch/issues/11004).
Expand Down
18 changes: 18 additions & 0 deletions pytorch/edgeml_pytorch/cuda/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import setuptools #enables develop
import os
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
from edgeml_pytorch.utils import findCUDA

if findCUDA() is not None:
setuptools.setup(
name='fastgrnn_cuda',
ext_modules=[
CUDAExtension('fastgrnn_cuda', [
'fastgrnn_cuda.cpp',
'fastgrnn_cuda_kernel.cu',
]),
],
cmdclass={
'build_ext': BuildExtension
}
)
28 changes: 16 additions & 12 deletions pytorch/edgeml_pytorch/graph/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@

import edgeml_pytorch.utils as utils

if utils.findCUDA() is not None:
import fastgrnn_cuda
try:
if utils.findCUDA() is not None:
import fastgrnn_cuda
except:
print("Running without FastGRNN CUDA")
pass


# All the matrix vector computations of the form Wx are done
Expand Down Expand Up @@ -351,29 +355,29 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
self._name = name

if wRank is None:
self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], self.device))
self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], device=self.device))
self.W1 = torch.empty(0)
self.W2 = torch.empty(0)
else:
self.W = torch.empty(0)
self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], self.device))
self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], self.device))
self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], device=self.device))
self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], device=self.device))

if uRank is None:
self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], self.device))
self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], device=self.device))
self.U1 = torch.empty(0)
self.U2 = torch.empty(0)
else:
self.U = torch.empty(0)
self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], self.device))
self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], self.device))
self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], device=self.device))
self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], device=self.device))

self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]

self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], self.device))
self.bias_update = nn.Parameter(torch.ones([1, hidden_size], self.device))
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], self.device))
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], self.device))
self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
self.bias_update = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], device=self.device))
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], device=self.device))

@property
def name(self):
Expand Down
4 changes: 2 additions & 2 deletions pytorch/requirements-cpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ numpy==1.16.4
pandas==0.23.4
scikit-learn==0.21.2
scipy==1.3.0
torch
torchvision
torch==1.4.0
torchvision==0.5.0
requests
4 changes: 2 additions & 2 deletions pytorch/requirements-gpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ numpy==1.16.4
pandas==0.23.4
scikit-learn==0.21.2
scipy==1.3.0
torch
torchvision
torch==1.4.0
torchvision==0.5.0
requests
16 changes: 0 additions & 16 deletions pytorch/setup.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,5 @@
import setuptools #enables develop
import os
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
from edgeml_pytorch.utils import findCUDA

if findCUDA() is not None:
setuptools.setup(
name='fastgrnn_cuda',
ext_modules=[
CUDAExtension('fastgrnn_cuda', [
'edgeml_pytorch/cuda/fastgrnn_cuda.cpp',
'edgeml_pytorch/cuda/fastgrnn_cuda_kernel.cu',
]),
],
cmdclass={
'build_ext': BuildExtension
}
)

setuptools.setup(
name='edgeml',
Expand Down

0 comments on commit 5f0b6e8

Please sign in to comment.