
Commit

Resolved #38
upskyy committed Jul 21, 2021
1 parent 848a4f9 commit 3cdd9cb
Showing 5 changed files with 31 additions and 150 deletions.
5 changes: 0 additions & 5 deletions openspeech/decoders/rnn_transducer_decoder.py
@@ -117,11 +117,6 @@ def forward(
             * hidden_states (torch.FloatTensor): A hidden state of decoders. `FloatTensor` of size
                 ``(batch, seq_length, dimension)``
         """
-        batch_size, input_lengths = inputs.size(0), inputs.size(1)
-
-        if input_lengths != 1:
-            inputs = inputs[inputs != self.eos_id].view(batch_size, -1)
-
         embedded = self.embedding(inputs)

         if hidden_states is not None:
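For reference, a minimal runnable sketch of the decoder path that remains after this diff: label tokens are embedded directly and passed to the recurrent layer, with no eos filtering beforehand. The class and attribute names below (`TinyPredictionNet`, `hidden_dim`, the use of `nn.LSTM`) are illustrative assumptions, not the openspeech `RNNTransducerDecoder` itself.

import torch
import torch.nn as nn


class TinyPredictionNet(nn.Module):
    """Illustrative prediction network; not the openspeech class itself."""

    def __init__(self, num_classes: int, hidden_dim: int = 32) -> None:
        super().__init__()
        self.embedding = nn.Embedding(num_classes, hidden_dim)
        self.rnn = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)

    def forward(self, inputs: torch.Tensor, hidden_states=None):
        # As in the patched forward: embed the inputs as given, no eos masking.
        embedded = self.embedding(inputs)
        if hidden_states is not None:
            outputs, hidden_states = self.rnn(embedded, hidden_states)
        else:
            outputs, hidden_states = self.rnn(embedded)
        return outputs, hidden_states


net = TinyPredictionNet(num_classes=10)
tokens = torch.randint(0, 10, (2, 5))  # (batch, seq_length)
outputs, _ = net(tokens)
print(outputs.shape)  # torch.Size([2, 5, 32])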
8 changes: 4 additions & 4 deletions openspeech/decoders/transformer_decoder.py
@@ -75,10 +75,10 @@ def __init__(
     ) -> None:
         super(TransformerDecoderLayer, self).__init__()
         self.self_attention_prenorm = nn.LayerNorm(d_model)
-        self.encoder_attention_prenorm = nn.LayerNorm(d_model)
+        self.decoder_attention_prenorm = nn.LayerNorm(d_model)
         self.feed_forward_prenorm = nn.LayerNorm(d_model)
         self.self_attention = MultiHeadAttention(d_model, num_heads)
-        self.encoder_attention = MultiHeadAttention(d_model, num_heads)
+        self.decoder_attention = MultiHeadAttention(d_model, num_heads)
         self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout_p)

     def forward(
@@ -108,8 +108,8 @@ def forward(
         outputs += residual

         residual = outputs
-        outputs = self.encoder_attention_prenorm(outputs)
-        outputs, encoder_attn = self.encoder_attention(outputs, encoder_outputs, encoder_outputs, encoder_attn_mask)
+        outputs = self.decoder_attention_prenorm(outputs)
+        outputs, encoder_attn = self.decoder_attention(outputs, encoder_outputs, encoder_outputs, encoder_attn_mask)
         outputs += residual

         residual = outputs
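The renamed modules are the decoder's cross-attention over the encoder outputs, applied in pre-norm style with a residual connection, as the diff shows. Below is a self-contained sketch of that block; it uses torch's `nn.MultiheadAttention` as a stand-in for openspeech's `MultiHeadAttention`, and the tensor shapes are made up for illustration.

import torch
import torch.nn as nn

d_model, num_heads = 64, 4

# Stand-ins for the renamed attributes in TransformerDecoderLayer.
decoder_attention_prenorm = nn.LayerNorm(d_model)
decoder_attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

decoder_states = torch.randn(2, 7, d_model)    # (batch, target_length, d_model)
encoder_outputs = torch.randn(2, 11, d_model)  # (batch, source_length, d_model)

# Pre-norm cross-attention, mirroring the diff: normalize, attend with the
# decoder states as queries and the encoder outputs as keys/values, add residual.
residual = decoder_states
outputs = decoder_attention_prenorm(decoder_states)
outputs, encoder_attn = decoder_attention(outputs, encoder_outputs, encoder_outputs)
outputs = outputs + residual
print(outputs.shape)  # torch.Size([2, 7, 64])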
3 changes: 1 addition & 2 deletions openspeech/decoders/transformer_transducer_decoder.py
@@ -85,7 +85,7 @@ def __init__(
         self.pad_id = pad_id
         self.sos_id = sos_id
         self.eos_id = eos_id
-        self.encoder_layers = nn.ModuleList([
+        self.decoder_layers = nn.ModuleList([
             TransformerTransducerEncoderLayer(
                 model_dim,
                 d_ff,
@@ -124,7 +124,6 @@ def forward(
             )

         else: # train
-            inputs = inputs[inputs != self.eos_id].view(batch, -1)
             target_lengths = inputs.size(1)

             outputs = self.forward_step(
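For illustration, a small sketch of the renamed stack: the decoder now keeps its layers under `decoder_layers`, an `nn.ModuleList` iterated in order during the forward pass. The `DummyLayer` below is a placeholder, not `TransformerTransducerEncoderLayer`, which the real decoder reuses.

import torch
import torch.nn as nn


class DummyLayer(nn.Module):
    """Placeholder for the shared transformer layer used by the decoder."""

    def __init__(self, model_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(model_dim, model_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model_dim, num_layers = 32, 2
decoder_layers = nn.ModuleList([DummyLayer(model_dim) for _ in range(num_layers)])

outputs = torch.randn(2, 5, model_dim)
for layer in decoder_layers:  # each layer transforms the running outputs in turn
    outputs = layer(outputs)
print(outputs.shape)  # torch.Size([2, 5, 32])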
78 changes: 25 additions & 53 deletions openspeech/models/openspeech_transducer_model.py
@@ -88,53 +88,27 @@ def collect_outputs(
             input_lengths: torch.IntTensor,
             targets: torch.IntTensor,
             target_lengths: torch.IntTensor,
-            predictions: torch.Tensor = None,
     ) -> OrderedDict:
-        if predictions is None:
-            predictions = logits.max(-1)[1]
-            loss = self.criterion(
-                logits=logits,
-                targets=targets[:, 1:].contiguous().int(),
-                input_lengths=input_lengths.int(),
-                target_lengths=target_lengths.int(),
-            )
-
-            wer = self.wer_metric(targets[:, 1:], predictions)
-            cer = self.cer_metric(targets[:, 1:], predictions)
-
-            self.info({
-                f"{stage}_loss": loss,
-                f"{stage}_wer": wer,
-                f"{stage}_cer": cer,
-                "learning_rate": self.get_lr(),
-            })
-
-            return OrderedDict({
-                "loss": loss,
-                "wer": wer,
-                "cer": cer,
-                "predictions": predictions,
-                "targets": targets,
-                "logits": logits,
-            })
+        predictions = logits.max(-1)[1]

-        else:
-            wer = self.wer_metric(targets[:, 1:], predictions)
-            cer = self.cer_metric(targets[:, 1:], predictions)
-
-            self.info({
-                f"{stage}_wer": wer,
-                f"{stage}_cer": cer,
-            })
-
-            return OrderedDict({
-                "loss": None,
-                "wer": wer,
-                "cer": cer,
-                "predictions": predictions,
-                "targets": targets,
-                "logits": logits,
-            })
+        loss = self.criterion(
+            logits=logits,
+            targets=targets[:, 1:].contiguous().int(),
+            input_lengths=input_lengths.int(),
+            target_lengths=target_lengths.int(),
+        )
+
+        self.info({
+            f"{stage}_loss": loss,
+            "learning_rate": self.get_lr(),
+        })
+
+        return OrderedDict({
+            "loss": loss,
+            "predictions": predictions,
+            "targets": targets,
+            "logits": logits,
+        })

     def _expand_for_joint(self, encoder_outputs: Tensor, decoder_outputs: Tensor) -> Tuple[Tensor, Tensor]:
         input_length = encoder_outputs.size(1)
@@ -278,16 +252,15 @@ def validation_step(self, batch: tuple, batch_idx: int) -> OrderedDict:
         else:
             encoder_outputs, output_lengths = self.encoder(inputs, input_lengths)

-        max_length = encoder_outputs.size(1)
+        decoder_outputs, _ = self.decoder(targets, target_lengths)
+        logits = self.joint(encoder_outputs, decoder_outputs)

-        predictions = self.decode(encoder_outputs, max_length)
         return self.collect_outputs(
             'val',
-            logits=None,
+            logits=logits,
             input_lengths=output_lengths,
             targets=targets,
             target_lengths=target_lengths,
-            predictions=predictions,
         )

     def test_step(self, batch: tuple, batch_idx: int) -> OrderedDict:
@@ -308,14 +281,13 @@ def test_step(self, batch: tuple, batch_idx: int) -> OrderedDict:
         else:
             encoder_outputs, output_lengths = self.encoder(inputs, input_lengths)

-        max_length = encoder_outputs.size(1)
+        decoder_outputs, _ = self.decoder(targets, target_lengths)
+        logits = self.joint(encoder_outputs, decoder_outputs)

-        predictions = self.decode(encoder_outputs, max_length)
         return self.collect_outputs(
             'test',
-            logits=None,
+            logits=logits,
             input_lengths=output_lengths,
             targets=targets,
             target_lengths=target_lengths,
-            predictions=predictions,
         )
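After this change, validation and test follow the same path as training: the decoder consumes the targets, the joint network produces logits, and `collect_outputs` derives both the loss and the greedy predictions from those logits. Below is a minimal runnable sketch of that simplified flow; the `dummy_criterion`, the helper name, and the concrete tensor shapes are assumptions standing in for the model's real transducer criterion and data, with only the keyword API taken from the diff.

import torch
from collections import OrderedDict


def dummy_criterion(logits, targets, input_lengths, target_lengths):
    """Placeholder with the same keyword API as self.criterion in the diff."""
    return logits.float().mean()


def collect_outputs_sketch(stage, logits, input_lengths, targets, target_lengths):
    # Predictions now come straight from the joint-network logits instead of
    # being passed in as a separate argument.
    predictions = logits.max(-1)[1]
    loss = dummy_criterion(
        logits=logits,
        targets=targets[:, 1:].contiguous().int(),  # drop the leading sos token
        input_lengths=input_lengths.int(),
        target_lengths=target_lengths.int(),
    )
    return OrderedDict({"loss": loss, "predictions": predictions,
                        "targets": targets, "logits": logits})


# Assumed joint-network logits: (batch, input_length, target_length + 1, num_classes)
logits = torch.randn(2, 50, 13, 30)
targets = torch.randint(1, 30, (2, 13))  # includes a leading sos token
outputs = collect_outputs_sketch("val", logits, torch.tensor([50, 50]),
                                 targets, torch.tensor([12, 12]))
print(outputs["loss"].item(), outputs["predictions"].shape)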
87 changes: 1 addition & 86 deletions openspeech/models/transformer_transducer/model.py
@@ -117,89 +117,4 @@ def greedy_decode(self, encoder_outputs: Tensor, max_length: int) -> Tensor:

         pred_tokens = torch.stack(pred_tokens, dim=1)

-        return torch.LongTensor(pred_tokens)
-
-    def forward(self, inputs: Tensor, input_lengths: Tensor) -> Dict[str, Tensor]:
-        r"""
-        Decode `encoder_outputs`.
-        Args:
-            inputs (torch.FloatTensor): A input sequence passed to encoders. Typically for inputs this will be a padded `FloatTensor` of size ``(batch, seq_length, dimension)``.
-            input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
-        Returns:
-            outputs (dict): Result of model predictions.
-        """
-        encoder_outputs, _ = self.encoder(inputs, input_lengths)
-        max_length = encoder_outputs.size(1)
-
-        predictions = self.decode(encoder_outputs, max_length)
-        return {
-            "predictions": predictions,
-            "encoder_outputs": encoder_outputs,
-        }
-
-    def training_step(self, batch: tuple, batch_idx: int) -> OrderedDict:
-        r"""
-        Forward propagate a `inputs` and `targets` pair for training.
-        Inputs:
-            batch (tuple): A train batch contains `inputs`, `targets`, `input_lengths`, `target_lengths`
-            batch_idx (int): The index of batch
-        Returns:
-            loss (torch.Tensor): loss for training
-        """
-        return super(TransformerTransducerModel, self).training_step(batch, batch_idx)
-
-    def validation_step(self, batch: tuple, batch_idx: int) -> OrderedDict:
-        r"""
-        Forward propagate a `inputs` and `targets` pair for validation.
-        Inputs:
-            batch (tuple): A train batch contains `inputs`, `targets`, `input_lengths`, `target_lengths`
-            batch_idx (int): The index of batch
-        Returns:
-            loss (torch.Tensor): loss for training
-        """
-        inputs, targets, input_lengths, target_lengths = batch
-
-        encoder_outputs, _ = self.encoder(inputs, input_lengths)
-        max_length = encoder_outputs.size(1)
-
-        predictions = self.decode(encoder_outputs, max_length)
-        return self.collect_outputs(
-            'valid',
-            logits=None,
-            input_lengths=input_lengths,
-            targets=targets,
-            target_lengths=target_lengths,
-            predictions=predictions,
-        )
-
-    def test_step(self, batch: tuple, batch_idx: int) -> OrderedDict:
-        r"""
-        Forward propagate a `inputs` and `targets` pair for test.
-        Inputs:
-            batch (tuple): A train batch contains `inputs`, `targets`, `input_lengths`, `target_lengths`
-            batch_idx (int): The index of batch
-        Returns:
-            loss (torch.Tensor): loss for training
-        """
-        inputs, targets, input_lengths, target_lengths = batch
-
-        encoder_outputs, _ = self.encoder(inputs, input_lengths)
-        max_length = encoder_outputs.size(1)
-
-        predictions = self.decode(encoder_outputs, max_length)
-        return self.collect_outputs(
-            'valid',
-            logits=None,
-            input_lengths=input_lengths,
-            targets=targets,
-            target_lengths=target_lengths,
-            predictions=predictions,
-        )
+        return torch.LongTensor(pred_tokens)

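With its `forward`, `training_step`, `validation_step`, and `test_step` overrides removed, `TransformerTransducerModel` now relies on the shared step logic of its transducer base class (shown in openspeech_transducer_model.py above). The toy classes below only illustrate that inheritance pattern; they are not the openspeech classes and omit all Lightning machinery.

import torch
from collections import OrderedDict


class BaseTransducerSketch:
    """Toy stand-in for the shared transducer model with its step methods."""

    def validation_step(self, batch: tuple, batch_idx: int) -> OrderedDict:
        inputs, targets, input_lengths, target_lengths = batch
        # The real base class runs encoder -> decoder -> joint here and
        # returns collect_outputs(...); this sketch just echoes the batch.
        return OrderedDict({"targets": targets, "batch_idx": batch_idx})


class TransformerTransducerSketch(BaseTransducerSketch):
    # No validation_step / test_step overrides any more, so the base class
    # implementation above is what gets called.
    pass


model = TransformerTransducerSketch()
batch = (torch.randn(2, 50, 80), torch.randint(0, 30, (2, 12)),
         torch.tensor([50, 50]), torch.tensor([12, 12]))
print(model.validation_step(batch, 0)["targets"].shape)  # torch.Size([2, 12])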