/
cnn.py
172 lines (154 loc) · 5.43 KB
/
cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CNN model with Mel spectrum."""
from kws_streaming.layers import modes
from kws_streaming.layers import quantize
from kws_streaming.layers import speech_features
from kws_streaming.layers import stream
from kws_streaming.layers.compat import tf
import kws_streaming.models.model_utils as utils
from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
from tensorflow_model_optimization.python.core.quantization.keras.quantizers import AllValuesQuantizer
def model_parameters(parser_nn):
  """Register the Convolutional Neural Network (CNN) model parameters.

  Every list-valued option is passed as a comma-separated string with one
  entry per layer; `model` decodes these strings with `utils.parse`.

  Args:
    parser_nn: argparse-style parser (anything exposing `add_argument`) that
      the CNN options are added to. It is modified in place.
  """
  # Options of the convolutional stack: one entry per Conv2D layer.
  parser_nn.add_argument(
      '--cnn_filters',
      type=str,
      default='64,64,64,64,128,64,128',
      help='Number of output filters in the convolution layers',
  )
  parser_nn.add_argument(
      '--cnn_kernel_size',
      type=str,
      default='(3,3),(5,3),(5,3),(5,3),(5,2),(5,1),(10,1)',
      help='Heights and widths of the 2D convolution window',
  )
  parser_nn.add_argument(
      '--cnn_act',
      type=str,
      default="'relu','relu','relu','relu','relu','relu','relu'",
      help='Activation function in the convolution layers',
  )
  parser_nn.add_argument(
      '--cnn_dilation_rate',
      type=str,
      default='(1,1),(1,1),(2,1),(1,1),(2,1),(1,1),(2,1)',
      help='Dilation rate to use for dilated convolutions',
  )
  parser_nn.add_argument(
      '--cnn_strides',
      type=str,
      default='(1,1),(1,1),(1,1),(1,1),(1,1),(1,1),(1,1)',
      help='Strides of the convolution layers along the height and width',
  )

  # Dropout rate used by `model` between the flattened conv output and the
  # dense layers.
  parser_nn.add_argument(
      '--dropout1',
      type=float,
      default=0.5,
      help='Percentage of data dropped',
  )

  # Options of the dense stack: one entry per hidden Dense layer.
  parser_nn.add_argument(
      '--units2',
      type=str,
      default='128,256',
      help='Number of units in the last set of hidden layers',
  )
  parser_nn.add_argument(
      '--act2',
      type=str,
      default="'linear','relu'",
      help='Activation function of the last set of hidden layers',
  )
def model(flags):
  """Build the CNN model.

  It is based on the paper:
  Convolutional Neural Networks for Small-footprint Keyword Spotting
  http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf
  Model topology is similar to "Hello Edge: Keyword Spotting on
  Microcontrollers" https://arxiv.org/pdf/1711.07128.pdf

  Supports quantization aware training with TF Model Optimization Toolkit
  including the experimental n-bit scheme.

  Layer order: input -> optional SpeechFeatures (when
  `flags.preprocess == 'raw'`) -> optional 8-bit QuantizeLayer (when
  `flags.quantize`) -> repeated (Conv2D, BatchNormalization, Activation)
  -> Flatten -> Dropout -> repeated Dense -> Dense(`flags.label_count`)
  -> optional softmax (when `flags.return_softmax`).

  Args:
    flags: data/model parameters. This function reads `batch_size`,
      `preprocess`, `quantize`, `cnn_filters`, `cnn_kernel_size`, `cnn_act`,
      `cnn_dilation_rate`, `cnn_strides`, `dropout1`, `units2`, `act2`,
      `label_count` and `return_softmax`; `flags` is also forwarded to the
      input-shape, feature-extraction and quantization helpers.

  Returns:
    Keras model for training
  """
  input_audio = tf.keras.layers.Input(
      shape=modes.get_input_data_shape(flags, modes.Modes.TRAINING),
      batch_size=flags.batch_size)
  net = input_audio

  if flags.preprocess == 'raw':
    # It is a self-contained model: the user needs to feed raw audio only.
    net = speech_features.SpeechFeatures(
        speech_features.SpeechFeatures.get_params(flags))(net)

  if flags.quantize:
    # Quantize the features feeding the first convolution.
    net = quantize_layer.QuantizeLayer(
        AllValuesQuantizer(
            num_bits=8, per_axis=False, symmetric=False,
            narrow_range=False))(net)

  # Append a trailing axis (expand_dims default) - presumably the channel
  # dimension expected by Conv2D; confirm against the feature shape.
  net = tf.keras.backend.expand_dims(net)

  # The per-layer option lists are combined with zip, so the number of
  # convolution blocks equals the length of the shortest list.
  for filters, kernel_size, activation, dilation_rate, strides in zip(
      utils.parse(flags.cnn_filters), utils.parse(flags.cnn_kernel_size),
      utils.parse(flags.cnn_act), utils.parse(flags.cnn_dilation_rate),
      utils.parse(flags.cnn_strides)):
    # The convolution itself stays linear: the configured activation is
    # applied as a separate layer after batch normalization.
    net = stream.Stream(
        cell=quantize.quantize_layer(
            tf.keras.layers.Conv2D(
                filters=filters,
                kernel_size=kernel_size,
                dilation_rate=dilation_rate,
                activation='linear',
                strides=strides),
            flags=flags,
            quantize_config=quantize.get_conv_bn_quantize_config(flags=flags)),
        pad_time_dim='causal',
        use_one_step=False)(net)
    net = quantize.quantize_layer(
        tf.keras.layers.BatchNormalization(),
        flags=flags,
        quantize_config=quantize.get_no_op_quantize_config(flags=flags),
    )(net)
    net = quantize.quantize_layer(
        tf.keras.layers.Activation(activation),
        flags=flags,
    )(net)

  net = stream.Stream(
      cell=quantize.quantize_layer(
          tf.keras.layers.Flatten(),
          flags=flags,
      ))(net)

  # Dropout is the only layer here that is not wrapped by quantize_layer.
  net = tf.keras.layers.Dropout(rate=flags.dropout1)(net)

  # Hidden dense layers; as above, zip stops at the shorter of the two lists.
  for units, activation in zip(
      utils.parse(flags.units2), utils.parse(flags.act2)):
    net = quantize.quantize_layer(
        tf.keras.layers.Dense(units=units, activation=activation),
        flags=flags,
    )(net)

  # Output layer: one unit per label, no activation unless softmax is
  # requested below.
  net = quantize.quantize_layer(
      tf.keras.layers.Dense(units=flags.label_count),
      flags=flags,
  )(net)
  if flags.return_softmax:
    net = quantize.quantize_layer(
        tf.keras.layers.Activation('softmax'),
        flags=flags,
    )(net)
  return tf.keras.Model(input_audio, net)