-
Notifications
You must be signed in to change notification settings - Fork 58
/
models.py
391 lines (318 loc) · 12.6 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
"""This is a pytorch implementation mirroring:
https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py.
"""
from dataclasses import dataclass
import os
from typing import Optional, Tuple
import torch
from torch import nn
import torch.distributed.nn as dist_nn
import torch.nn.functional as F
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch import \
preprocessor
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \
SpecAug
# True when a LOCAL_RANK environment variable is present (presumably set by a
# distributed launcher); enables the cross-process all-reduce in BatchNorm.
USE_PYTORCH_DDP = os.environ.get('LOCAL_RANK') is not None
@dataclass
class DeepspeechConfig:
  """Global hyperparameters used to minimize obnoxious kwarg plumbing.

  One instance is handed to the model-building classes in this file
  (Subsample, FeedForwardModule, BatchRNN, DeepspeechEncoderDecoder), each of
  which reads only the fields it needs.
  """
  # Output width of the final linear layer.
  vocab_size: int = 1024
  # Feature width used throughout the encoder: subsampling projection, LSTM
  # output width and feed-forward layers.
  encoder_dim: int = 512
  # Number of stacked BatchRNN (LSTM) blocks.
  num_lstm_layers: int = 6
  # Number of stacked FeedForwardModule blocks run after the LSTMs.
  num_ffn_layers: int = 3
  # NOTE(review): neither conv_subsampling_* field is read in this file;
  # Subsample always builds two convolutions with their default (2, 2) stride.
  conv_subsampling_factor: int = 2
  conv_subsampling_layers: int = 2
  # SpecAugment: applied only in training mode when use_specaug is True; the
  # remaining fields below are forwarded unchanged to SpecAug.
  use_specaug: bool = True
  freq_mask_count: int = 2
  freq_mask_max_bins: int = 27
  time_mask_count: int = 10
  time_mask_max_frames: int = 40
  time_mask_max_ratio: float = 0.05
  time_masks_per_frame: float = 0.0
  use_dynamic_time_mask_max_frames: bool = True
  # Momentum / epsilon of the padding-aware BatchNorm layers.
  batch_norm_momentum: float = 0.999
  batch_norm_epsilon: float = 0.001
  # Dropout after the subsampling projection. If None, defaults to 0.1.
  input_dropout_rate: Optional[float] = 0.1
  # Dropout at the end of each feed-forward block. If None, defaults to 0.1.
  feed_forward_dropout_rate: Optional[float] = 0.1
  # Add each LSTM / FFN block's output to its input instead of replacing it.
  enable_residual_connections: bool = True
  # Apply a LayerNorm right before the final linear layer.
  enable_decoder_layer_norm: bool = True
  # Bidirectional LSTMs with encoder_dim // 2 hidden units per direction.
  bidirectional: bool = True
  # Use tanh instead of ReLU in the subsampling convs and FFN blocks.
  use_tanh: bool = False
  # Use LayerNorm instead of BatchNorm inside the FFN and LSTM blocks.
  layernorm_everywhere: bool = False
class LayerNorm(nn.Module):
  """Layer normalization over the last axis with a zero-initialized scale.

  The learned multiplier is applied as ``1 + scale``, so a freshly built layer
  (scale == 0, bias == 0) just standardizes each feature vector.

  Args:
    dim: size of the normalized (last) axis.
    epsilon: added to the variance before the reciprocal square root.
  """

  def __init__(self, dim, epsilon=1e-6):
    super().__init__()
    self.dim = dim
    self.scale = nn.Parameter(torch.zeros(dim))
    self.bias = nn.Parameter(torch.zeros(dim))
    self.epsilon = epsilon

  def forward(self, x):
    """Return ``x`` normalized over its last axis, then scaled and shifted."""
    centered = x - x.mean(dim=-1, keepdim=True)
    # Biased (population) variance, matching the usual LayerNorm definition.
    inv_std = torch.rsqrt(
        x.var(dim=-1, unbiased=False, keepdim=True) + self.epsilon)
    return centered * inv_std * (1 + self.scale) + self.bias
class Subsample(nn.Module):
  """Convolutional front-end: two Conv2dSubsampling layers plus a projection.

  The two convolutions use Conv2dSubsampling's default (2, 2) stride. Their
  output channels and frequency bins are flattened per frame and projected
  back to ``encoder_dim`` with a lazily-sized linear layer, then dropout.
  """

  def __init__(self, config: DeepspeechConfig):
    super().__init__()
    self.encoder_dim = config.encoder_dim
    self.conv1 = Conv2dSubsampling(
        input_channels=1,
        output_channels=self.encoder_dim,
        use_tanh=config.use_tanh)
    self.conv2 = Conv2dSubsampling(
        input_channels=self.encoder_dim,
        output_channels=self.encoder_dim,
        use_tanh=config.use_tanh)
    self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True)
    rate = config.input_dropout_rate
    # A rate of None falls back to 0.1.
    self.dropout = nn.Dropout(p=0.1 if rate is None else rate)

  def forward(self, inputs, input_paddings):
    """Subsample ``inputs`` and ``input_paddings`` through both convs.

    Returns:
      (outputs, paddings) where outputs has shape
      (batch, subsampled_length, encoder_dim).
    """
    # Insert a singleton channel axis for the 2-D convolutions.
    features = inputs[:, None, :, :]
    paddings = input_paddings
    for conv in (self.conv1, self.conv2):
      features, paddings = conv(features, paddings)
    # (batch, channels, length, dims) -> (batch, length, dims * channels)
    features = features.permute(0, 2, 3, 1).flatten(start_dim=2)
    return self.dropout(self.lin(features)), paddings
class Conv2dSubsampling(nn.Module):
  """Strided 3x3 convolution with 'SAME' padding that also subsamples paddings.

  Besides convolving ``inputs``, the layer downsamples the per-frame
  ``paddings`` along the time axis with the same stride and zeroes every
  output frame whose downsampled padding value is 1.
  """

  def __init__(self,
               input_channels: int,
               output_channels: int,
               filter_stride: Tuple[int, int] = (2, 2),
               padding: str = 'SAME',
               batch_norm_momentum: float = 0.999,
               batch_norm_epsilon: float = 0.001,
               use_tanh: bool = False):
    """Creates the layer.

    Args:
      input_channels: channels per group of the input.
      output_channels: number of convolution filters.
      filter_stride: (time, frequency) stride; the time stride is also used to
        subsample the paddings.
      padding: 'SAME' pads the input before the convolution; any other value
        convolves without padding.
      batch_norm_momentum: accepted for signature compatibility; unused.
      batch_norm_epsilon: accepted for signature compatibility; unused.
      use_tanh: use tanh instead of ReLU as the activation.
    """
    super().__init__()
    self.input_channels = input_channels
    self.output_channels = output_channels
    self.filter_stride = filter_stride
    self.padding = padding
    self.filter_shape = (output_channels, input_channels, 3, 3)
    self.kernel = nn.Parameter(
        nn.init.xavier_uniform_(torch.empty(*self.filter_shape)))
    self.bias = nn.Parameter(torch.zeros(output_channels))
    self.use_tanh = use_tanh

  @staticmethod
  def _same_pad_total(size, stride, kernel):
    """Total 'SAME' padding along one spatial axis (TensorFlow convention)."""
    if size % stride == 0:
      return max(kernel - stride, 0)
    return max(kernel - size % stride, 0)

  def get_same_padding(self, input_shape):
    """Return ``F.pad`` amounts (left, right, top, bottom) for 'SAME' padding.

    ``input_shape`` is the NCHW input shape; H and W are its last two entries.
    """
    in_height, in_width = input_shape[2:]
    stride_height, stride_width = self.filter_stride
    filter_height, filter_width = self.filter_shape[2:]
    pad_along_height = self._same_pad_total(in_height, stride_height,
                                            filter_height)
    pad_along_width = self._same_pad_total(in_width, stride_width,
                                           filter_width)
    # When the total is odd, the extra unit goes to the bottom / right edge.
    pad_top = pad_along_height // 2
    pad_left = pad_along_width // 2
    return (pad_left, pad_along_width - pad_left, pad_top,
            pad_along_height - pad_top)

  def forward(self, inputs, paddings):
    """Convolve ``inputs`` and subsample ``paddings`` along time.

    Args:
      inputs: NCHW tensor; H is assumed to be the time axis matching
        ``paddings.shape[1]`` (TODO confirm against callers).
      paddings: (batch, time) tensor, 1 for padded frames.

    Returns:
      (outputs, out_padding). out_padding has ceil(time / stride) frames and a
      floating dtype of at least float32; padded output frames are zeroed.
    """
    groups = inputs.shape[1] // self.input_channels
    if self.padding == 'SAME':
      in_ = F.pad(inputs, self.get_same_padding(inputs.shape))
    else:
      in_ = inputs
    outputs = F.conv2d(
        input=in_,
        weight=self.kernel,
        bias=self.bias,
        stride=self.filter_stride,
        dilation=(1, 1),
        groups=groups)
    outputs = torch.tanh(outputs) if self.use_tanh else F.relu(outputs)

    # Keep every `stride`-th frame: ceil(T / stride) entries, the same time
    # length the 'SAME' strided convolution produces. Promote to at least
    # float32 so the mask arithmetic works for int/bool paddings, and keep
    # float64 paddings in float64 instead of mixing dtypes.
    stride = self.filter_stride[0]
    out_padding = paddings[:, ::stride]
    out_padding = out_padding.to(
        torch.promote_types(out_padding.dtype, torch.float32)).contiguous()
    outputs = outputs * (1 - out_padding[:, None, :, None])
    return outputs, out_padding
class FeedForwardModule(nn.Module):
  """Normalization, dense layer, activation, padding mask and dropout.

  Normalization is LayerNorm when ``config.layernorm_everywhere`` is set and
  the padding-aware BatchNorm otherwise. The two variants store the layer
  under different attribute names, so their state_dict keys differ.
  """

  def __init__(self, config: DeepspeechConfig):
    super().__init__()
    self.config = config
    if not config.layernorm_everywhere:
      self.bn_normalization_layer = BatchNorm(
          dim=config.encoder_dim,
          batch_norm_momentum=config.batch_norm_momentum,
          batch_norm_epsilon=config.batch_norm_epsilon)
    else:
      self.normalization_layer = LayerNorm(config.encoder_dim)
    self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True)
    rate = config.feed_forward_dropout_rate
    # A rate of None falls back to 0.1.
    self.dropout = nn.Dropout(p=0.1 if rate is None else rate)

  def forward(self, inputs, input_paddings):
    """Return the block output; padded frames are zero before dropout."""
    if self.config.layernorm_everywhere:
      normed = self.normalization_layer(inputs)
    else:
      normed = self.bn_normalization_layer(inputs, input_paddings)
    activation = torch.tanh if self.config.use_tanh else F.relu
    hidden = activation(self.lin(normed))
    keep = (1 - input_paddings)[:, :, None]
    return self.dropout(hidden * keep)
class BatchNorm(nn.Module):
  """Batch normalization over (batch, time) that ignores padded frames.

  Statistics are computed only over frames whose padding is 0; under DDP the
  sums and counts are all-reduced across processes. The learned multiplier is
  applied as ``1 + weight``. Padded frames are zero in the output.
  """

  def __init__(self, dim, batch_norm_momentum, batch_norm_epsilon):
    super().__init__()
    self.register_buffer('running_mean', torch.zeros(dim))
    self.register_buffer('running_var', torch.ones(dim))
    self.weight = nn.Parameter(torch.zeros(dim))
    self.bias = nn.Parameter(torch.zeros(dim))
    self.momentum = batch_norm_momentum
    self.epsilon = batch_norm_epsilon
    self.dim = dim

  def _batch_statistics(self, inputs, valid):
    """Return per-feature (mean, variance) over the non-padded frames."""
    num_valid = valid.sum()
    zeroed = inputs.masked_fill(valid == 0, 0)
    total = zeroed.sum(dim=(0, 1))
    if USE_PYTORCH_DDP:
      total = dist_nn.all_reduce(total)
      num_valid = dist_nn.all_reduce(num_valid)
    mean = total / num_valid
    squared_total = (torch.square(zeroed - mean) * valid).sum(dim=(0, 1))
    if USE_PYTORCH_DDP:
      squared_total = dist_nn.all_reduce(squared_total)
    return mean, squared_total / num_valid

  def forward(self, inputs, input_paddings):
    """Normalize ``inputs`` (N, T, D) using ``input_paddings`` (N, T)."""
    valid = 1 - input_paddings[:, :, None]
    if not self.training:
      mean, var = self.running_mean, self.running_var
    else:
      mean, var = self._batch_statistics(inputs, valid)
      decay = self.momentum
      self.running_mean = decay * self.running_mean + (
          1 - decay) * mean.detach()
      self.running_var = decay * self.running_var + (
          1 - decay) * var.detach()
    scale = (1 + self.weight) * torch.rsqrt(var + self.epsilon)
    normalized = (inputs - mean) * scale + self.bias
    return normalized.masked_fill(valid == 0, 0)
class BatchRNN(nn.Module):
  """Normalization followed by an LSTM run only over each example's frames.

  The per-example length is the number of non-padded frames; the output has
  the same batch and time sizes as the input and is zero past each length.
  """

  def __init__(self, config: DeepspeechConfig):
    super().__init__()
    self.config = config
    hidden_size = config.encoder_dim
    input_size = config.encoder_dim
    bidirectional = config.bidirectional
    self.bidirectional = bidirectional
    if config.layernorm_everywhere:
      self.normalization_layer = LayerNorm(config.encoder_dim)
    else:
      self.bn_normalization_layer = BatchNorm(config.encoder_dim,
                                              config.batch_norm_momentum,
                                              config.batch_norm_epsilon)
    if bidirectional:
      # Each direction gets half the width so the concatenated output is
      # encoder_dim wide (assuming encoder_dim is even).
      self.lstm = nn.LSTM(
          input_size=input_size,
          hidden_size=hidden_size // 2,
          bidirectional=True,
          batch_first=True)
    else:
      self.lstm = nn.LSTM(
          input_size=input_size, hidden_size=hidden_size, batch_first=True)

  def forward(self, inputs, input_paddings):
    """Normalize ``inputs`` and run the LSTM over the valid frames.

    Args:
      inputs: batch-first tensor, assumed (batch, time, encoder_dim).
      input_paddings: (batch, time) tensor, 1 for padded frames.

    Returns:
      Tensor shaped (batch, time, encoder_dim) with the dtype of the LSTM
      output; frames past each example's length are zero.
    """
    if self.config.layernorm_everywhere:
      inputs = self.normalization_layer(inputs)
    else:
      inputs = self.bn_normalization_layer(inputs, input_paddings)
    # pack_padded_sequence expects int64 lengths on the CPU.
    lengths = torch.sum(1 - input_paddings, dim=1).detach().to(
        device='cpu', dtype=torch.int64)
    # pack_padded_sequence rejects zero lengths, so fully padded examples are
    # run on a single frame and their outputs are zeroed afterwards.
    empty = lengths == 0
    packed_inputs = torch.nn.utils.rnn.pack_padded_sequence(
        inputs, lengths.clamp(min=1), batch_first=True, enforce_sorted=False)
    packed_outputs, _ = self.lstm(packed_inputs)
    # total_length pads back to the input time length in the output's own
    # dtype (a manual zeros-concat would promote half outputs to float32).
    outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(
        packed_outputs, batch_first=True, total_length=inputs.shape[1])
    if empty.any():
      outputs = outputs.masked_fill(
          empty.to(outputs.device)[:, None, None], 0)
    return outputs
class DeepspeechEncoderDecoder(nn.Module):
  """DeepSpeech-style model assembled from the blocks in this file.

  Pipeline: mel-filterbank frontend -> SpecAug (training only, if enabled) ->
  convolutional subsampling -> LSTM blocks -> feed-forward blocks ->
  optional LayerNorm -> linear projection to ``vocab_size``.
  """

  def __init__(self, config: DeepspeechConfig):
    super().__init__()
    self.config = config
    # Submodules are created in a fixed order; it determines parameter order
    # and state_dict layout.
    self.specaug = SpecAug(
        freq_mask_count=config.freq_mask_count,
        freq_mask_max_bins=config.freq_mask_max_bins,
        time_mask_count=config.time_mask_count,
        time_mask_max_frames=config.time_mask_max_frames,
        time_mask_max_ratio=config.time_mask_max_ratio,
        time_masks_per_frame=config.time_masks_per_frame,
        use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames
    )
    self.preprocessor = preprocessor.MelFilterbankFrontend(
        preprocessor.PreprocessorConfig(),
        per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
        per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)
    self.subsample = Subsample(config=config)
    self.lstms = nn.ModuleList(
        BatchRNN(config) for _ in range(config.num_lstm_layers))
    self.ffns = nn.ModuleList(
        FeedForwardModule(config) for _ in range(config.num_ffn_layers))
    self.ln = (LayerNorm(config.encoder_dim)
               if config.enable_decoder_layer_norm else nn.Identity())
    self.lin = nn.Linear(config.encoder_dim, config.vocab_size)

  def _run_stack(self, layers, outputs, paddings):
    """Apply ``layers`` in order, adding residuals when they are enabled."""
    residual = self.config.enable_residual_connections
    for layer in layers:
      update = layer(outputs, paddings)
      outputs = outputs + update if residual else update
    return outputs

  def forward(self, inputs, input_paddings):
    """Return (per-frame vocab_size outputs, subsampled paddings)."""
    outputs, output_paddings = self.preprocessor(inputs, input_paddings)
    if self.training and self.config.use_specaug:
      outputs, output_paddings = self.specaug(outputs, output_paddings)
    outputs, output_paddings = self.subsample(outputs, output_paddings)
    outputs = self._run_stack(self.lstms, outputs, output_paddings)
    outputs = self._run_stack(self.ffns, outputs, output_paddings)
    if self.config.enable_decoder_layer_norm:
      outputs = self.ln(outputs)
    return self.lin(outputs), output_paddings