FastCell pytorch working version on CPU
adityakusupati committed Jun 21, 2019
1 parent 67d2e39 commit e77b36c
Showing 8 changed files with 2,164 additions and 0 deletions.
81 changes: 81 additions & 0 deletions pytorch/examples/FastCells/README.md
@@ -0,0 +1,81 @@
# EdgeML FastCells on a sample public dataset

This directory includes an example notebook and a general execution script for
FastCells (FastRNN & FastGRNN) developed as part of EdgeML, along with modified
UGRNN, GRU and LSTM cells that support the LSQ training routine.
We also include a sample clean-up script and use-case for the USPS10 public dataset.

`pytorch_edgeml.graph.rnn` implements the custom RNN cells **FastRNN** ([`FastRNNCell`](../../edgeml/graph/rnn.py#L215)) and **FastGRNN** ([`FastGRNNCell`](../../edgeml/graph/rnn.py#L40)), with
additional features such as low-rank parameterisation and custom
non-linearities. As with Bonsai and ProtoNN, the three-phase training
routine for FastRNN and FastGRNN is decoupled from the custom cells to
facilitate plug-and-play use of the custom RNN cells in other
architectures (NMT, Encoder-Decoder etc.) in place of the inbuilt `RNNCell`, `GRUCell`, `LSTMCell` etc.
`pytorch_edgeml.graph.rnn` also contains low-rank variants of **UGRNN** ([`UGRNNLRCell`](../../edgeml/graph/rnn.py#L862)),
**GRU** ([`GRULRCell`](../../edgeml/graph/rnn.py#L635)) and **LSTM** ([`LSTMLRCell`](../../edgeml/graph/rnn.py#L376)). These cells can also be substituted for FastCells wherever feasible.
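
As a quick illustration, here is a minimal sketch of instantiating one of these cells directly. The constructor arguments mirror those used in `fastcell_example.py` below; the single-timestep call `cell(x_t, h_prev)` and its return value are assumptions about the cell interface rather than documented API:

```python
import torch
from pytorch_edgeml.graph.rnn import FastGRNNCell

inputDims, hiddenDims = 16, 32

# Low-rank FastGRNN cell; wRank / uRank control the factorisation of W and U.
cell = FastGRNNCell(inputDims, hiddenDims,
                    gate_non_linearity="sigmoid",
                    update_non_linearity="tanh",
                    wRank=8, uRank=8)

x_t = torch.randn(1, inputDims)      # one batch element, one timestep
h_prev = torch.zeros(1, hiddenDims)
h_t = cell(x_t, h_prev)              # assumed: returns the next hidden state
print(h_t.shape)
```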

For training FastCells, `pytorch_edgeml.trainer.fastTrainer` implements the three-phase
FastCell training routine in PyTorch. A simple example,
`examples/fastcell_example.py`, is provided to illustrate its usage.

Note that `fastcell_example.py` assumes the data is in a specific format. The
train and test data are expected in two files, `train.npy` and `test.npy`, each
containing a 2D numpy array of dimension `[numberOfExamples, numberOfFeatures]`.
Here `numberOfFeatures` is `timesteps x inputDims`, flattened across the
timestep dimension, i.e. the input for the first timestep, followed by the
second, and so on. For an N-class problem, the labels are assumed to be
integers from 0 through N-1. Lastly, the training data, `train.npy`, is assumed
to be well shuffled, as the training routine does not shuffle it internally.
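
For illustration, here is a minimal, hedged sketch of flattening a batch of sequences into this 2D layout with numpy (where and how the label column is stored is handled by `process_usps.py` and `helpermethods.preProcessData`, and is not shown here):

```python
import numpy as np

numberOfExamples, timesteps, inputDims = 100, 16, 16

# Per-example sequences: [numberOfExamples, timesteps, inputDims]
sequences = np.random.randn(numberOfExamples, timesteps, inputDims)

# Flatten across the timestep dimension: timestep 1 features, then timestep 2, ...
flat = sequences.reshape(numberOfExamples, timesteps * inputDims)
print(flat.shape)  # (100, 256) == [numberOfExamples, numberOfFeatures]
```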

**Tested With:** PyTorch with Python 2 and Python 3

## Download and clean up sample dataset

We validate the code using the USPS dataset. Downloading the dataset and
cleaning it up to match the above-mentioned format is done by the scripts
[fetch_usps.py](fetch_usps.py) and [process_usps.py](process_usps.py):

```
python fetch_usps.py
python process_usps.py
```


## Sample command for FastCells on USPS10
The following sample run on usps10 should validate your library:

Note: Even though usps10 is not a time-series dataset, it can be treated as one in which each row of the 16 x 16 image arrives at a single timestep.
So the number of timesteps = 16 and inputDims = 16.
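
Before running the command below, here is a tiny sketch of the check `fastcell_example.py` performs (mirroring its `dataDimension % inputDims` assertion) for USPS10:

```python
# USPS10: each example is a flattened 16 x 16 image -> 256 features.
dataDimension, inputDims = 256, 16          # inputDims comes from the -id flag

assert dataDimension % inputDims == 0, "Timesteps have to be integer"
timesteps = dataDimension // inputDims
print(timesteps)  # 16 timesteps, one per image row
```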

```bash
python fastcell_example.py -dir usps10/ -id 16 -hd 32
```
This command should produce a final output that reads roughly as follows (the exact numbers may differ slightly across versions):

```
Maximum Test accuracy at compressed model size(including early stopping): 0.9407075 at Epoch: 262
Final Test Accuracy: 0.93721974
Non-Zeros: 1932 Model Size: 7.546875 KB hasSparse: False
```
The `usps10/` directory will now have a consolidated results file called `FastRNNResults.txt` or `FastGRNNResults.txt`, depending on the choice of RNN cell,
along with a directory `FastRNNResults` or `FastGRNNResults` containing the corresponding models from each run of the code on the `usps10` dataset.

Note that the stored scalars `alpha`, `beta`, `zeta` and `nu` are the raw values before the sigmoid function is applied to them.
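
For example, a minimal sketch of recovering the effective gate scalars from the stored (pre-sigmoid) values, assuming they have been loaded as tensors named `raw_alpha`, `raw_beta`, `raw_zeta` and `raw_nu` (the values below are illustrative only):

```python
import torch

# Assumed: raw (pre-sigmoid) scalars loaded from a saved FastRNN / FastGRNN model.
raw_alpha, raw_beta = torch.tensor(-3.0), torch.tensor(3.0)   # FastRNN
raw_zeta, raw_nu = torch.tensor(3.0), torch.tensor(-3.0)      # FastGRNN

# The cell updates use the sigmoid-ed values:
#   FastRNN : h_t = alpha * h_tilde + beta * h_{t-1}
#   FastGRNN: h_t = (zeta * (1 - z_t) + nu) * h_tilde + z_t * h_{t-1}
alpha, beta = torch.sigmoid(raw_alpha), torch.sigmoid(raw_beta)
zeta, nu = torch.sigmoid(raw_zeta), torch.sigmoid(raw_nu)
print(float(alpha), float(beta), float(zeta), float(nu))
```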

## Byte Quantization (Q) for model compression
If you wish to quantize the generated model to use byte-quantized integers, use `quantizeFastModels.py`. Usage instructions:

```
python quantizeFastModels.py -h
```

This will generate quantized models, with a `q` prefix added to every parameter, stored in a new directory `QuantizedFastModel` inside the model directory.
These models can then be used on edge devices.

Note that the scalars `qalpha`, `qbeta`, `qzeta` and `qnu` are stored after both the sigmoid function and quantization have been applied, so they can be plugged directly into the inference pipelines.
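
As an illustration, here is a minimal float-only sketch of one FastRNN timestep that uses such pre-activated scalars directly (no sigmoid at inference). All names (`W`, `U`, `b`, `qalpha`, `qbeta`) and values are placeholders, and the integer arithmetic and scaling used by the actual quantized models are not shown:

```python
import numpy as np

hiddenDims, inputDims = 32, 16

# Placeholder (dequantized) parameters; real values would come from QuantizedFastModel.
W = np.random.randn(hiddenDims, inputDims)
U = np.random.randn(hiddenDims, hiddenDims)
b = np.zeros(hiddenDims)
qalpha, qbeta = 0.2, 0.8  # already sigmoid-ed, used as-is


def fastrnn_step(x_t, h_prev):
    # FastRNN update with the default tanh update non-linearity:
    #   h_t = qalpha * tanh(W x_t + U h_prev + b) + qbeta * h_prev
    h_tilde = np.tanh(W @ x_t + U @ h_prev + b)
    return qalpha * h_tilde + qbeta * h_prev


h = np.zeros(hiddenDims)
for t in range(16):  # 16 timesteps, one per USPS image row
    h = fastrnn_step(np.random.randn(inputDims), h)
print(h.shape)
```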

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT license.
89 changes: 89 additions & 0 deletions pytorch/examples/FastCells/fastcell_example.py
@@ -0,0 +1,89 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import helpermethods
import torch
import numpy as np
import sys
from pytorch_edgeml.graph.rnn import *
from pytorch_edgeml.trainer.fastTrainer import FastTrainer


def main():
    # Fixing seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # Hyper Param pre-processing
    args = helpermethods.getArgs()

    dataDir = args.data_dir
    cell = args.cell
    inputDims = args.input_dim
    hiddenDims = args.hidden_dim

    totalEpochs = args.epochs
    learningRate = args.learning_rate
    outFile = args.output_file
    batchSize = args.batch_size
    decayStep = args.decay_step
    decayRate = args.decay_rate

    wRank = args.wRank
    uRank = args.uRank

    sW = args.sW
    sU = args.sU

    update_non_linearity = args.update_nl
    gate_non_linearity = args.gate_nl

    (dataDimension, numClasses, Xtrain, Ytrain, Xtest, Ytest,
     mean, std) = helpermethods.preProcessData(dataDir)

    assert dataDimension % inputDims == 0, "Infeasible per step input, " + \
        "Timesteps have to be integer"

    currDir = helpermethods.createTimeStampDir(dataDir, cell)

    helpermethods.dumpCommand(sys.argv, currDir)
    helpermethods.saveMeanStd(mean, std, currDir)

    if cell == "FastGRNN":
        FastCell = FastGRNNCell(inputDims, hiddenDims,
                                gate_non_linearity=gate_non_linearity,
                                update_non_linearity=update_non_linearity,
                                wRank=wRank, uRank=uRank)
    elif cell == "FastRNN":
        FastCell = FastRNNCell(inputDims, hiddenDims,
                               update_non_linearity=update_non_linearity,
                               wRank=wRank, uRank=uRank)
    elif cell == "UGRNN":
        FastCell = UGRNNLRCell(inputDims, hiddenDims,
                               update_non_linearity=update_non_linearity,
                               wRank=wRank, uRank=uRank)
    elif cell == "GRU":
        FastCell = GRULRCell(inputDims, hiddenDims,
                             update_non_linearity=update_non_linearity,
                             wRank=wRank, uRank=uRank)
    elif cell == "LSTM":
        FastCell = LSTMLRCell(inputDims, hiddenDims,
                              update_non_linearity=update_non_linearity,
                              wRank=wRank, uRank=uRank)
    else:
        sys.exit('Exiting: No Such Cell as ' + cell)

    FastCellTrainer = FastTrainer(FastCell, numClasses, sW=sW, sU=sU,
                                  learningRate=learningRate, outFile=outFile)

    FastCellTrainer.train(batchSize, totalEpochs,
                          torch.from_numpy(Xtrain.astype(np.float32)),
                          torch.from_numpy(Xtest.astype(np.float32)),
                          torch.from_numpy(Ytrain.astype(np.float32)),
                          torch.from_numpy(Ytest.astype(np.float32)),
                          decayStep, decayRate, dataDir, currDir)


if __name__ == '__main__':
    main()
66 changes: 66 additions & 0 deletions pytorch/examples/FastCells/fetch_usps.py
@@ -0,0 +1,66 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
#
# Setting up the USPS Data.

import bz2
import os
import subprocess
import sys

import requests
import numpy as np
from sklearn.datasets import load_svmlight_file
from helpermethods import download_file, decompress



def downloadData(workingDir, downloadDir, linkTrain, linkTest):
    path = workingDir + '/' + downloadDir
    path = os.path.abspath(path)
    try:
        os.makedirs(path, exist_ok=True)
    except OSError:
        print("Could not create %s. Make sure the path does" % path)
        print("not already exist and you have permissions to create it.")
        return False

    training_data_bz2 = download_file(linkTrain, path)
    test_data_bz2 = download_file(linkTest, path)

    training_data = decompress(training_data_bz2)
    test_data = decompress(test_data_bz2)

    train = os.path.join(path, "train.txt")
    test = os.path.join(path, "test.txt")
    if os.path.isfile(train):
        os.remove(train)
    if os.path.isfile(test):
        os.remove(test)

    os.rename(training_data, train)
    os.rename(test_data, test)
    os.remove(training_data_bz2)
    os.remove(test_data_bz2)
    return True


if __name__ == '__main__':
    workingDir = './'
    downloadDir = 'usps10'
    linkTrain = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2'
    linkTest = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2'
    failureMsg = '''
Download Failed!
To manually perform the download
\t1. Create a new empty directory named `usps10`.
\t2. Download the data from the following links into the usps10 directory.
\t\tTest: %s
\t\tTrain: %s
\t3. Extract the downloaded files.
\t4. Rename `usps` to `train.txt`.
\t5. Rename `usps.t` to `test.txt`.
''' % (linkTrain, linkTest)

    if not downloadData(workingDir, downloadDir, linkTrain, linkTest):
        exit(failureMsg)
    print("Done: see ", downloadDir)
