Skip to content

Commit

Permalink
flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
mn5k committed Sep 21, 2018
1 parent d46949a commit 658269a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
8 changes: 8 additions & 0 deletions src/asr/asr_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,14 @@ def train(args):
e2e = E2E(idim, odim, args)
model = Loss(e2e, args.mtlalpha)

if args.rnnlm is not None:
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(
len(args.char_list), rnnlm_args.layer, rnnlm_args.unit))
torch.load(args.rnnlm, rnnlm)

This comment has been minimized.

Copy link
@JaejinCho

JaejinCho Jan 21, 2019

Contributor

I think torch.load should be torch_load.

This comment has been minimized.

Copy link
@sw005320

sw005320 Jan 21, 2019

Contributor

@JaejinCho, thanks!
@mn5k, could you test it and fix it?

This comment has been minimized.

Copy link
@mn5k

mn5k Jan 22, 2019

Author Contributor

Thank you. I'll fix it.

e2e.rnnlm = rnnlm

# write model config
if not os.path.exists(args.outdir):
os.makedirs(args.outdir)
Expand Down
14 changes: 4 additions & 10 deletions src/nets/e2e_asr_th.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import math
import sys

import editdistance
from argparse import Namespace
import editdistance

import chainer
import numpy as np
Expand All @@ -25,6 +25,7 @@
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence

from ctc_prefix_score import CTCPrefixScore
from ctc_prefix_score import CTCPrefixScoreTH
from e2e_asr_common import end_detect
from e2e_asr_common import get_vgg2l_odim
Expand Down Expand Up @@ -269,13 +270,6 @@ def __init__(self, idim, odim, args):
'ctc_weight': args.ctc_weight, 'maxlenratio': args.maxlenratio,
'minlenratio': args.minlenratio, 'lm_weight': args.lm_weight,
'rnnlm': args.rnnlm, 'nbest': args.nbest}
self.rnnlm = None
if self.rnnlm is not None:
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
self.rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(
len(self.char_list), rnnlm_args.layer, rnnlm_args.unit))
torch.load(args.rnnlm, self.rnnlm)
self.recog_args = argparse.Namespace(**recog_args)
self.report_cer = args.report_cer
self.report_wer = args.report_wer
Expand Down Expand Up @@ -435,7 +429,7 @@ def recognize(self, x, recog_args, char_list, rnnlm=None):
if prev:
self.train()
return y

def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
'''E2E beam search
Expand Down Expand Up @@ -2175,7 +2169,7 @@ def recognize_beam_batch(self, h, hlens, lpz, recog_args, char_list, rnnlm=None,
_best_score = local_best_scores.view(-1).cpu().numpy()
local_scores[_best_odims] = _best_score
local_scores = to_cuda(self, torch.from_numpy(local_scores).float()).view(batch, beam, self.odim)

# (or indexing)
# local_scores = to_cuda(self, torch.full((batch, beam, self.odim), self.logzero))
# _best_odims = local_best_odims
Expand Down

0 comments on commit 658269a

Please sign in to comment.