Test and Travis-CI #20
Conversation
@sw005320 Could you merge this if it is OK?
Just a moment; I found that Travis-CI is failing: https://travis-ci.org/ShigekiKarita/espnet
Now Travis is OK. Many tests are skipped because torch is not installed, but we can wait for it: pytorch/pytorch#4178 (comment)
src/nets/e2e_asr_attctc_th.py
Outdated
acc = 0
pad_pred = y_all.data.view(pad_target.size(0), pad_target.size(1), y_all.size(1)).max(2)[1]
mask = pad_target.data != ignore_label
return torch.sum(pad_pred.masked_select(mask) == pad_target.data.masked_select(mask)) / torch.sum(mask)
The final calculation, torch.sum(pad_pred.masked_select(mask) == pad_target.data.masked_select(mask)) / torch.sum(mask), is int / int; therefore, in the case of Python 2, acc becomes 0 (I confirmed this).
We should cast to float.
Here is my modified version:
def th_accuracy(y_all, pad_target, ignore_label):
pad_pred = y_all.data.view(pad_target.size(0), pad_target.size(1), y_all.size(1)).max(2)[1]
mask = pad_target.data != ignore_label
numerator = torch.sum(pad_pred.masked_select(mask) == pad_target.data.masked_select(mask))
denominator = torch.sum(mask)
return float(numerator) / float(denominator)
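The Python 2 truncation described above can be reproduced without Python 2: `/` on two ints in Python 2 behaves like Python 3's floor division `//`. A minimal sketch of why the float cast is needed:

```python
# In Python 2, "/" on two ints is floor division, so any accuracy below
# 1.0 truncates to 0; Python 3's "//" reproduces that behavior.
correct, total = 7, 10
truncated = correct // total            # Python-2-style int / int -> 0
fixed = float(correct) / float(total)   # the cast used in th_accuracy -> 0.7
```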
Thanks!
-            for i in range(len(ys)):
-                acc += torch.sum(pred_pad[i, :ys[i].size(0)] == ys[i].data)
-            acc /= sum(map(len, ys))
+            acc = th_accuracy(y_all, pad_ys_out, ignore_label=self.ignore_id)
             logging.info('att loss:' + str(self.loss.data))

             # show predicted character sequence for debug
In the case of Chainer, the ignore label is -1.
Therefore, we should change it as follows.
# now
idx_hat = np.argmax(y_hat_[y_true_ != -1], axis=1)
idx_true = y_true_[y_true_ != -1]
# proposed
idx_hat = np.argmax(y_hat_[y_true_ != self.ignore_id], axis=1)
idx_true = y_true_[y_true_ != self.ignore_id]
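As a small self-contained illustration of the proposed masking (the posteriors here are toy data, not espnet output), filtering by the ignore id drops the padded positions before the argmax:

```python
import numpy as np

ignore_id = -1                            # Chainer's default ignore label
y_true_ = np.array([3, 1, ignore_id, 2])  # one padded position
y_hat_ = np.eye(5)[[3, 1, 0, 4]]          # toy (T, odim) posteriors

# Same pattern as the proposed change: mask out ignored frames first
idx_hat = np.argmax(y_hat_[y_true_ != ignore_id], axis=1)
idx_true = y_true_[y_true_ != ignore_id]
```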
OK. I think you are talking about the "show predicted character sequence for debug" part in Decoder.forward.
Yes, that's right.
Sorry, but I noticed that in the current version self.ignore_id = 0. Is 0 for CTC?
Are you talking about this?
espnet/src/nets/e2e_asr_attctc_th.py
Line 1633 in c044f7a
self.ignore_id = 0  # NOTE: 0 for CTC?
This may have some problems.
@ShigekiKarita, can you tell me why you set 0 here?
In the current implementation, the character indices for the decoder run from 1 to odim - 1.
Therefore, 0 is not used in the decoder, so I think it has no effect even if we use 0 as the ignore id.
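To sketch this claim with a NumPy analogue of th_accuracy (the function name and data here are made up for illustration): as long as real labels never use the pad value, masking with 0 or -1 yields the same accuracy.

```python
import numpy as np

def masked_accuracy(pred, target, ignore_label):
    # NumPy analogue of th_accuracy: count matches only where target is real
    mask = target != ignore_label
    return float((pred[mask] == target[mask]).sum()) / float(mask.sum())

pred   = np.array([3, 1, 4, 9, 9])    # last two positions are padding
t_zero = np.array([3, 1, 2, 0, 0])    # padded with 0 (real ids are 1..odim-1)
t_neg  = np.array([3, 1, 2, -1, -1])  # padded with -1
```

Since the real labels are all >= 1, both pad conventions mask out exactly the same positions.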
But to make it clearer, I agree with changing it to -1.
@kan-bayashi Many thanks for your reply!
@ShigekiKarita Hi Shigeki. I added some comments.
src/nets/e2e_asr_attctc_th.py
Outdated
         self.h_length = self.enc_h.shape[1]
         # utt x frame x att_dim
-        self.pre_compute_enc_h = F.tanh(linear_tensor(self.mlp_enc, self.enc_h))
+        self.pre_compute_enc_h = linear_tensor(self.mlp_enc, self.enc_h)
Does it need to perform tanh?
Hmm, I'm not sure why I deleted it. However, I will follow the reference implementation in Chainer.
In eq. 9 of https://arxiv.org/pdf/1609.06773.pdf, pre_compute_h corresponds to V h_l, which does not have tanh. Anyway, it might not be a big deal, because we cannot see a big difference in #9 (comment).
Thank you for the many comments.
src/nets/e2e_asr_attctc_th.py
Outdated
@@ -469,7 +469,7 @@ def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
         self.enc_h = enc_hs_pad  # utt x frame x hdim
         self.h_length = self.enc_h.shape[1]
         # utt x frame x att_dim
-        self.pre_compute_enc_h = linear_tensor(self.mlp_enc, self.enc_h)
+        self.pre_compute_enc_h = torch.tanh(linear_tensor(self.mlp_enc, self.enc_h))
The tanh op is performed only for AttDot in Chainer.
I think it is not needed for AttLoc.
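As a rough sketch of this distinction (random toy tensors and hypothetical names, not the actual espnet code): in dot-product attention the tanh bounds both operands of the dot product, while location-aware attention feeds the raw projection into an additive MLP energy, so no tanh is applied there.

```python
import numpy as np

rng = np.random.default_rng(0)
enc_h = rng.standard_normal((5, 4))   # frames x hidden (toy encoder states)
w_enc = rng.standard_normal((4, 3))   # hidden -> att_dim projection
dec_z = rng.standard_normal(3)        # projected decoder state

# AttDot style: tanh on the precomputed encoder projection before the dot
e_dot = np.tanh(enc_h @ w_enc) @ np.tanh(dec_z)   # one energy per frame

# AttLoc style: the raw projection goes into an additive energy, no tanh here
pre_loc = enc_h @ w_enc
```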
Oh, that's right.
Does everything look good? @kan-bayashi @sw005320
Can you include pytorch and the others in the all target of the Makefile?
Done.
I'm working on Travis-CI #18