In [None]:
from tf2lwnn import *
import sys
import os
sys.path.append(os.path.abspath('../../../tools'))
import numpy as np
from verifyoutput import *
from lwnn2torch import *
import onnx
import onnxruntime
%matplotlib inline

In [None]:
# Frozen TensorFlow graph of the keyword-spotting model under test.
DEMO_MODEL = 'Basic_LSTM_S.pb'
# Sample utterance fed to the model (absolute path on the author's machine).
DEMO_INPUT = '/mnt/d/tmp/speech_dataset/yes/0a9f9af7_nohash_1.wav'
# Class names; presumably index-aligned with the model's softmax output -- TODO confirm.
LABELS = [
    '_silence_', '_unknown_',
    'yes', 'no',
    'up', 'down',
    'left', 'right',
    'on', 'off',
    'stop', 'go',
]

In [None]:
# Read the WAV file verbatim as a flat int8 buffer and give it a leading axis of 1;
# the buffer is later turned back into bytes (.tobytes()) when fed to the graph.
raw_wav_bytes = np.fromfile(DEMO_INPUT, dtype=np.int8)
wav_data = raw_wav_bytes.reshape(1, -1)

# Convert the TensorFlow graph (through tf2onnx) and keep the resulting lwnn model.
converter = TfConverter(DEMO_MODEL, 'KWS', use_tf2onnx=True)
lwnn_model = converter.model

In [None]:
# Print every converted layer and remember the (last) one whose op is 'LSTM'.
for layer in lwnn_model:
    print(layer)
    is_lstm = (layer.op == 'LSTM')
    if is_lstm:
        lstm_layer = layer

In [None]:
# Look up the graph tensors of interest by name: model input, final softmax,
# and the tensors used here as the LSTM's input and output.
_TENSOR_NAMES = (
    'wav_data',
    'labels_softmax',
    'LSTM-Layer/lstm/transpose',
    'LSTM-Layer/lstm/rnn/while/Exit_3',
)
itensor, otensor, lstm_i, lstm_o = [converter.get_tensor(tensor_name) for tensor_name in _TENSOR_NAMES]

In [None]:
# One pass through the converter's session: fetch the softmax output together with
# the LSTM input/output tensors, feeding the raw WAV bytes to the input tensor.
tf_fetches = [otensor, lstm_i, lstm_o]
tf_feed = {itensor: wav_data.tobytes()}
result, lstm_i_result, lstm_o_result = converter.sess.run(tf_fetches, tf_feed)

In [None]:
# Display the shapes of the captured LSTM input and output arrays.
lstm_i_result.shape, lstm_o_result.shape

In [None]:
# Scan the ONNX graph and keep the last LSTM node and the last Squeeze node seen.
for node in converter.onnx_model.graph.node:
    op_kind = node.op_type
    if op_kind == 'LSTM':
        lstm_node = node
    elif op_kind == 'Squeeze':
        sq_node = node
# Show both nodes as the cell output.
lstm_node, sq_node

In [None]:
# Build a stand-alone ONNX model that contains only the LSTM node (re-created with
# its first four inputs) followed by the Squeeze node found above, and save it to
# '.tmp.onnx' so that it can be executed in isolation.

# Graph input: the LSTM's first input, typed float32 with the shape captured from TF.
x = onnx.helper.make_tensor_value_info(lstm_node.input[0], onnx.TensorProto.FLOAT, lstm_i_result.shape)

# Carry over every attribute of the original LSTM node unchanged.
attrs = {attr.name: onnx.helper.get_attribute_value(attr) for attr in lstm_node.attribute}

# Only the first four inputs are kept; any further optional inputs of the
# original node are dropped.
node = onnx.helper.make_node(
            'LSTM',
            name = lstm_node.name,
            inputs=lstm_node.input[:4],
            outputs=lstm_node.output,
            **attrs)

# Graph outputs: the Squeeze outputs first, then every LSTM output (shapes left unspecified).
outputs = [onnx.helper.make_tensor_value_info(o, onnx.TensorProto.FLOAT, None) for o in sq_node.output]
outputs.extend([onnx.helper.make_tensor_value_info(o, onnx.TensorProto.FLOAT, None) for o in node.output])

graph = onnx.helper.make_graph(
            nodes = [node, sq_node],
            name = 'LSTM',
            inputs = [x],
            outputs = outputs,
            value_info = [],
            initializer = converter.onnx_model.graph.initializer)

# The nodes are copied from converter.onnx_model, so the extracted model must
# declare the same operator-set versions. Without opset_imports, make_model stamps
# the model with the newest opset known to the installed onnx package, which can
# mis-declare the copied nodes (e.g. Squeeze changed its signature between opsets).
model = onnx.helper.make_model(
            graph,
            producer_name='lwnn-nhwc',
            opset_imports=converter.onnx_model.opset_import)
onnx.save(model, '.tmp.onnx')

In [None]:
# Load the extracted LSTM-only model that was saved to '.tmp.onnx' into onnxruntime.
sess = onnxruntime.InferenceSession('.tmp.onnx')

In [None]:
# Run the extracted ONNX model on the LSTM input captured from TensorFlow;
# passing None as the output list requests every graph output.
onnx_feed = {lstm_node.input[0]: lstm_i_result}
rs = sess.run(None, onnx_feed)

In [None]:
# Shape of the first output returned by the onnxruntime session.
rs[0].shape

In [None]:
# Compare the TensorFlow LSTM output with the first output of the extracted ONNX model.
compare(lstm_o_result, rs[0])

In [None]:
# Show the lwnn LSTM layer that was picked while listing the model's layers.
print(lstm_layer)

In [None]:
# Split the packed LSTM parameters of the lwnn layer into per-gate arrays.
# Gates are unpacked in the order i, o, f, c (input, output, forget, cell).
W = lstm_layer.W
R = lstm_layer.R
B = lstm_layer.B
print(W.shape, R.shape, B.shape)

# Input weights: 4 gate blocks, last axis of size 10.
input_gate_blocks = W.reshape(4, -1, 10)
Wi, Wo, Wf, Wc = input_gate_blocks
print(*[gate.shape for gate in (Wi, Wo, Wf, Wc)])

# Recurrent weights: 4 gate blocks, last axis of size 118.
recurrent_gate_blocks = R.reshape(4, -1, 118)
Ri, Ro, Rf, Rc = recurrent_gate_blocks
print(*[gate.shape for gate in (Ri, Ro, Rf, Rc)])

# Biases: 4 input-side blocks followed by 4 recurrent-side blocks.
bias_blocks = B.reshape(8, -1)
Wbi, Wbo, Wbf, Wbc, Rbi, Rbo, Rbf, Rbc = bias_blocks
print(*[blk.shape for blk in (Wbi, Wbo, Wbf, Wbc, Rbi, Rbo, Rbf, Rbc)])

In [None]:
# Fetch the TensorFlow BasicLSTMCell kernel/bias and split them per gate so they can
# be compared with the per-gate arrays taken from the lwnn layer.
kernel_tensor = converter.get_tensor('lstm/rnn/basic_lstm_cell/kernel/read')
bias_tensor = converter.get_tensor('lstm/rnn/basic_lstm_cell/bias/read')
print(kernel_tensor.shape, bias_tensor.shape)

kernel_value, b = converter.sess.run((kernel_tensor, bias_tensor))

# The first 10 kernel rows are taken as input weights, the remaining rows as
# recurrent weights; both are transposed so the gate axis comes first.
w = np.transpose(kernel_value[:10, :], (1, 0))
r = np.transpose(kernel_value[10:, :], (1, 0))
print(w.shape, r.shape, b.shape)

# Gates are unpacked here in the order i, c, f, o.
wi, wc, wf, wo = w.reshape(4, -1, 10)
print(*[gate.shape for gate in (wi, wo, wf, wc)])

ri, rc, rf, ro = r.reshape(4, -1, 118)
print(*[gate.shape for gate in (ri, ro, rf, rc)])

wbi, wbc, wbf, wbo = b.reshape(4, -1)
print(*[blk.shape for blk in (wbi, wbo, wbf, wbc)])

In [None]:
# Per-gate check of the input weights: lwnn/ONNX side (upper-case) vs TensorFlow side (lower-case).
compare(Wi, wi, 'wi')
compare(Wf, wf, 'wf')
compare(Wo, wo, 'wo')
compare(Wc, wc, 'wc')

In [None]:
# Per-gate check of the recurrent weights: lwnn/ONNX side (upper-case) vs TensorFlow side (lower-case).
compare(Ri, ri, 'ri')
compare(Rf, rf, 'rf')
compare(Ro, ro, 'ro')
compare(Rc, rc, 'rc')

In [None]:
# Per-gate check of the input-side biases against the TensorFlow bias blocks.
# The recurrent-side bias blocks (Rb*) are not compared in this cell.
compare(Wbi, wbi, 'wbi')
compare(Wbf, wbf, 'wbf')
compare(Wbo, wbo, 'wbo')
compare(Wbc, wbc, 'wbc')

https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM

http://colah.github.io/posts/2015-08-Understanding-LSTMs/

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/rnn/lstm_ops.cc

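ONNX LSTM equations (defaults: f = Sigmoid, g = Tanh, h = Tanh). `(.)` denotes element-wise multiplication; the peephole terms `Pi`, `Pf`, `Po` are optional and are not used by the NumPy re-implementation in the next cell: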

```python
 it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
 ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
 ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
 Ct = ft (.) Ct-1 + it (.) ct
 ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
 Ht = ot (.) h(Ct)
```

In [None]:
def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), applied element-wise."""
    return 1 / (1 + np.exp(-x))

# Reference LSTM in plain NumPy, without peephole terms, using the per-gate
# weights split out above. Both states start at zero with shape (1, 118).
c = np.zeros((1,118))
h = np.zeros((1,118))

# Transposed views of the weights, created once instead of on every step.
Wi_T, Wf_T, Wc_T, Wo_T = Wi.transpose(), Wf.transpose(), Wc.transpose(), Wo.transpose()
Ri_T, Rf_T, Rc_T, Ro_T = Ri.transpose(), Rf.transpose(), Rc.transpose(), Ro.transpose()

# One slice per index along axis 0 of the captured LSTM input.
time_steps = np.split(lstm_i_result, lstm_i_result.shape[0], axis=0)
for x in time_steps:
    it = sigmoid(np.dot(x, Wi_T) + np.dot(h, Ri_T) + Wbi + Rbi)   # input gate
    ft = sigmoid(np.dot(x, Wf_T) + np.dot(h, Rf_T) + Wbf + Rbf)   # forget gate
    ct = np.tanh(np.dot(x, Wc_T) + np.dot(h, Rc_T) + Wbc + Rbc)   # candidate cell value
    Ct = ft*c + it*ct                                             # new cell state
    ot = sigmoid(np.dot(x, Wo_T) + np.dot(h, Ro_T) + Wbo + Rbo)   # output gate
    Ht = ot*np.tanh(Ct)                                           # new hidden state
    # Carry both states over to the next step.
    c, h = Ct, Ht

In [None]:
# Compare the TensorFlow LSTM output with the final hidden state of the NumPy loop.
compare(lstm_o_result, Ht)

In [None]:
# Cross-check with the LSTM helper from ONNX's own backend test cases,
# fed with the captured LSTM input and the packed W/R/B of the lwnn layer.
from onnx.backend.test.case.node.lstm import LSTM_Helper
lstm = LSTM_Helper(X=lstm_i_result, W=W, R=R, B=B)
# step() is unpacked here as (Y, Y_h); only Y_h is used afterwards.
Y, Y_h = lstm.step()

In [None]:
# Compare the TensorFlow LSTM output with the Y_h result of the ONNX LSTM helper.
compare(lstm_o_result, Y_h)

In [None]:
# Dump reference tensors as raw binary files under ./tmp: the TensorFlow LSTM
# output, plus the output of every converter tensor named 'Mfcc'.

# ndarray.tofile() does not create missing directories, so make sure the target
# folder exists; otherwise the cell fails on a fresh checkout.
os.makedirs('tmp', exist_ok=True)

lstm_o_result.tofile('tmp/lstm.raw')
for name, tensor in converter.tensors.items():
    if(name in ['Mfcc']):
        # Re-run the session for this tensor with the same WAV bytes as input.
        o = converter.sess.run(tensor, {itensor: wav_data.tobytes()})
        o.tofile('tmp/%s.raw'%(name))