
The mechanism of alignment between text encoder output and audio_seq2seq output #37

inconnu11 opened this issue May 11, 2020 · 1 comment


@inconnu11

Hi Zhang,
Could you please explain how the text encoder output and the recognition encoder output are aligned? Your paper states: "The recognition encoder Er is a seq2seq neural network which aligns the acoustic and phoneme sequences automatically." I couldn't figure out how the code works.
Thank you in advance!

@jxzhanggg
Owner

jxzhanggg commented May 11, 2020

Hi, by saying that, I mean the recognition encoder is a seq2seq model with an attention module: the attention weights it computes at each phoneme step form a soft alignment over the acoustic frames, and those weights are learned automatically during training. Its definition is here; a short sketch after the code shows how the attention weights can be read as an alignment.

# Imports needed by this excerpt. AudioEncoder, Attention, LinearNorm,
# get_mask_from_lengths, tile, Beam and GNMTGlobalScorer are defined
# elsewhere in the repository.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class AudioSeq2seq(nn.Module):
    '''
    Attention-based recognition encoder: an AudioEncoder (bidirectional LSTM)
    followed by an attention decoder that emits phoneme logits, hidden states
    and attention (alignment) weights.
    '''
    def __init__(self, hparams):
        super(AudioSeq2seq, self).__init__()
        self.encoder = AudioEncoder(hparams)

        self.decoder_rnn_dim = hparams.audio_encoder_hidden_dim
        self.attention_layer = Attention(
            self.decoder_rnn_dim, hparams.audio_encoder_hidden_dim,
            hparams.AE_attention_dim, hparams.AE_attention_location_n_filters,
            hparams.AE_attention_location_kernel_size)

        self.decoder_rnn = nn.LSTMCell(
            hparams.symbols_embedding_dim + hparams.audio_encoder_hidden_dim,
            self.decoder_rnn_dim)

        def _proj(activation):
            if activation is not None:
                return nn.Sequential(
                    LinearNorm(self.decoder_rnn_dim + hparams.audio_encoder_hidden_dim,
                               hparams.encoder_embedding_dim,
                               w_init_gain=hparams.hidden_activation),
                    activation)
            else:
                return LinearNorm(self.decoder_rnn_dim + hparams.audio_encoder_hidden_dim,
                                  hparams.encoder_embedding_dim,
                                  w_init_gain=hparams.hidden_activation)

        if hparams.hidden_activation == 'relu':
            self.project_to_hidden = _proj(nn.ReLU())
        elif hparams.hidden_activation == 'tanh':
            self.project_to_hidden = _proj(nn.Tanh())
        elif hparams.hidden_activation == 'linear':
            self.project_to_hidden = _proj(None)
        else:
            raise ValueError('hidden_activation must be relu, tanh or linear.')

        self.project_to_n_symbols = LinearNorm(hparams.encoder_embedding_dim,
                                               hparams.n_symbols + 1)  # plus the <eos>
        self.eos = hparams.n_symbols
        self.activation = hparams.hidden_activation
        self.max_len = 100

    def initialize_decoder_states(self, memory, mask):
        B = memory.size(0)
        MAX_TIME = memory.size(1)

        self.decoder_hidden = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())
        self.decoder_cell = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())

        self.attention_weigths = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_weigths_cum = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_context = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())

        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def map_states(self, fn):
        '''
        mapping the decoder states using fn
        '''
        self.decoder_hidden = fn(self.decoder_hidden, 0)
        self.decoder_cell = fn(self.decoder_cell, 0)
        self.attention_weigths = fn(self.attention_weigths, 0)
        self.attention_weigths_cum = fn(self.attention_weigths_cum, 0)
        self.attention_context = fn(self.attention_context, 0)

    def parse_decoder_outputs(self, hidden, logit, alignments):
        # [T_out + 1, B, max_time] -> [B, T_out + 1, max_time]
        alignments = torch.stack(alignments).transpose(0, 1)
        # [T_out + 1, B, n_symbols] -> [B, T_out + 1, n_symbols]
        logit = torch.stack(logit).transpose(0, 1).contiguous()
        hidden = torch.stack(hidden).transpose(0, 1).contiguous()
        return hidden, logit, alignments

    def decode(self, decoder_input):
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            cell_input,
            (self.decoder_hidden,
             self.decoder_cell))

        attention_weigths_cat = torch.cat(
            (self.attention_weigths.unsqueeze(1),
             self.attention_weigths_cum.unsqueeze(1)), dim=1)

        self.attention_context, self.attention_weigths = self.attention_layer(
            self.decoder_hidden,
            self.memory,
            self.processed_memory,
            attention_weigths_cat,
            self.mask)

        self.attention_weigths_cum += self.attention_weigths

        hidden_and_context = torch.cat((self.decoder_hidden, self.attention_context), -1)
        hidden = self.project_to_hidden(hidden_and_context)

        # dropout to increase generalization
        logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))
        return hidden, logit, self.attention_weigths

    def forward(self, mel, mel_lengths, decoder_inputs, start_embedding):
        '''
        decoder_inputs: [B, channel, T]
        start_embedding [B, channel]
        return
            hidden_outputs [B, T+1, channel]
            logit_outputs [B, T+1, n_symbols]
            alignments [B, T+1, max_time]
        '''
        memory, memory_lengths = self.encoder(mel, mel_lengths)

        decoder_inputs = decoder_inputs.permute(2, 0, 1)  # -> [T, B, channel]
        decoder_inputs = torch.cat((start_embedding.unsqueeze(0), decoder_inputs), dim=0)

        self.initialize_decoder_states(memory,
                                       mask=~get_mask_from_lengths(memory_lengths))

        hidden_outputs, logit_outputs, alignments = [], [], []
        while len(hidden_outputs) < decoder_inputs.size(0):
            decoder_input = decoder_inputs[len(hidden_outputs)]
            hidden, logit, attention_weights = self.decode(decoder_input)
            hidden_outputs += [hidden]
            logit_outputs += [logit]
            alignments += [attention_weights]

        hidden_outputs, logit_outputs, alignments = \
            self.parse_decoder_outputs(
                hidden_outputs, logit_outputs, alignments)

        return hidden_outputs, logit_outputs, alignments

    '''
    use beam search ?
    '''

    def inference_greed(self, x, start_embedding, embedding_table):
        '''
        Decoding the phone sequence using a greedy algorithm
        x [1, mel_bins, T]
        start_embedding [1, embedding_dim]
        embedding_table nn.Embedding class
        return
            hidden_outputs [1, ]
        '''
        MAX_LEN = 100

        decoder_input = start_embedding
        memory = self.encoder.inference(x)

        self.initialize_decoder_states(memory, mask=None)

        hidden_outputs, alignments, phone_ids = [], [], []

        while True:
            hidden, logit, attention_weights = self.decode(decoder_input)

            hidden_outputs += [hidden]
            alignments += [attention_weights]
            phone_id = torch.argmax(logit, dim=1)
            phone_ids += [phone_id]

            # stop if the <eos> symbol is reached
            if phone_id.squeeze().item() == self.eos:
                break
            if len(hidden_outputs) == self.max_len:
                print('Warning! The decoded text reaches the maximum length.')
                break

            # embed the phone_id as the next decoder input
            decoder_input = embedding_table(phone_id)  # -> [1, embedding_dim]

        hidden_outputs, phone_ids, alignments = \
            self.parse_decoder_outputs(hidden_outputs, phone_ids, alignments)

        return hidden_outputs, phone_ids, alignments

    def inference_beam(self, x, start_embedding, embedding_table,
                       beam_width=20):
        memory = self.encoder.inference(x).expand(beam_width, -1, -1)
        MAX_LEN = 100
        n_best = 5
        self.initialize_decoder_states(memory, mask=None)

        decoder_input = tile(start_embedding, beam_width)
        beam = Beam(beam_width, 0, self.eos, self.eos,
                    n_best=n_best, cuda=True, global_scorer=GNMTGlobalScorer())

        hidden_outputs, alignments, phone_ids = [], [], []

        for step in range(MAX_LEN):
            if beam.done():
                break
            hidden, logit, attention_weights = self.decode(decoder_input)
            logit = F.log_softmax(logit, dim=1)
            beam.advance(logit, attention_weights, hidden)

            select_indices = beam.get_current_origin()
            self.map_states(lambda state, dim: state.index_select(dim, select_indices))
            decoder_input = embedding_table(beam.get_current_state())

        scores, ks = beam.sort_finished(minimum=n_best)
        hyps, attn, hiddens = [], [], []
        for i, (times, k) in enumerate(ks[:n_best]):
            hyp, att, hid = beam.get_hyp(times, k)
            hyps.append(hyp)
            attn.append(att)
            hiddens.append(hid)

        return hiddens[0].unsqueeze(0), hyps[0].unsqueeze(0), attn[0].unsqueeze(0)
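
In other words, the "alignment" the paper mentions is the alignments tensor returned above: at every decoded phoneme step the attention layer produces a weight distribution over the encoded mel frames, and those weights are learned automatically during training rather than coming from an external forced aligner. Below is a minimal, self-contained sketch (not code from this repository; the tensor is a random stand-in) of how such an attention matrix can be read as an alignment:

import torch

# Stand-in for the `alignments` tensor returned by AudioSeq2seq.forward():
# shape [B, T_phoneme + 1, T_mel], one attention distribution per decoded phoneme.
B, n_phonemes, n_mel_frames = 1, 12, 80
alignments = torch.softmax(torch.randn(B, n_phonemes + 1, n_mel_frames), dim=-1)

# Soft alignment: row t tells how much each mel frame contributes to phoneme t.
# A hard alignment can be read off by picking the most-attended frame per step.
frame_per_phoneme = alignments.argmax(dim=-1)  # [B, n_phonemes + 1]
print(frame_per_phoneme)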
