This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

CTC training speed question #220

Open
yuekaizhang opened this issue Jun 28, 2021 · 10 comments

Comments

@yuekaizhang

Hi, for my experiment, the built-in CTC loss (cudnn CTC) is about 2.5 times faster than k2-ctc. I was wondering whether this is normal and would like to make sure my program is correct.

I found that decoding_graph = k2.compose(self.ctc_topo, label_graph, treat_epsilons_specially=False) is the bottleneck, even with build_ctc_topo2 (#209). How about constructing the CTC loss graph directly from the text rather than composing the topo FST and the label FST? In my experiment, that gives a speed very similar to cudnn CTC.

Like the screenshot below:
[screenshot: Screenshot2021_06_28_100754]

@danpovey
Contributor

Thanks for doing the comparison, and sure, that's a good idea. Yes, we should introduce a special-purpose function that constructs a batch of CTC graphs from a ragged tensor consisting of the linear symbol sequences for each one. Perhaps @pkufool could work on that?
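
(For illustration, a minimal sketch of how such a batched helper might be called from the user's side; the name ctc_graph and its signature are assumptions here, not an existing k2 API at the time of this thread:)

    import k2

    # Hypothetical batched helper: build one CTC graph per utterance in a
    # single call, from the linear symbol sequence of each utterance.
    token_ids = [[23, 5, 17], [9, 42]]
    ctc_graphs = k2.ctc_graph(token_ids)  # assumed to return an FsaVec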

@csukuangfj
Collaborator

Shall we also consider the transition probability contained in the bigram P while constructing the graph for LF-MMI training?
(It's not an issue for CTC training.)

@danpovey
Contributor

danpovey commented Jun 28, 2021 via email

@yuekaizhang
Author

yuekaizhang commented Jun 28, 2021

Hi Dan, my comparison is based on this code:

def compile(self, texts: Iterable[str]) -> k2.Fsa:
    decoding_graphs = k2.create_fsa_vec(
        [self.compile_one_and_cache(text) for text in texts])
    # make sure the gradient is not accumulated
    decoding_graphs.requires_grad_(False)
    return decoding_graphs

@lru_cache(maxsize=100000)
def compile_one_and_cache(self, text: str) -> k2.Fsa:
    tokens = (token if token in self.words else self.oov
              for token in text.split(' '))
    word_ids = [self.words[token] for token in tokens]
    label_graph = k2.linear_fsa(word_ids)
    decoding_graph = k2.connect(k2.intersect(label_graph,
                                             self.L_inv)).invert_()
    decoding_graph = k2.arc_sort(decoding_graph)
    decoding_graph = k2.compose(self.ctc_topo, decoding_graph)
    decoding_graph = k2.connect(decoding_graph)
    return decoding_graph
I used the code below to avoid the compose operation. I was wondering whether there is any reference code to compile them as a batch. Thanks.

def _compile_one_and_cache_v2(self, text: torch.Tensor) -> k2.Fsa:
    # Build the standard CTC trellis for one utterance directly,
    # without composing a CTC topology with a linear FSA.
    text = text.tolist()
    blank_idx = 0
    num_tokens = len(text)
    S = 2 * num_tokens + 1  # interleaved states: blank t1 blank t2 ... tn blank
    final = S  # in k2 the final state must have the largest state number
    arcs = []
    for s in range(S):
        idx = (s - 1) // 2
        # Odd states emit tokens; even states emit blanks.
        word_id = text[idx] if s % 2 else blank_idx
        arcs.append([s, s, word_id, word_id, 0])  # self-loop
        if s > 0:
            arcs.append([s - 1, s, word_id, word_id, 0])  # forward arc
        # Skip arc over a blank, allowed only between different tokens.
        if s % 2 and s > 1 and word_id != text[idx - 1]:
            arcs.append([s - 2, s, word_id, word_id, 0])
    # Both the last token state and the last blank state connect to final.
    arcs.append([S - 2, final, -1, -1, 0])
    arcs.append([S - 1, final, -1, -1, 0])
    arcs = sorted(arcs, key=lambda arc: arc[0])
    # The final-state line must come last in the string form.
    arcs.append([final])
    arcs = '\n'.join(' '.join(str(i) for i in arc) for arc in arcs)
    ctc_graph = k2.Fsa.from_str(arcs, False)
    return k2.arc_sort(ctc_graph).to(self.device)
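
(Hypothetical usage, mirroring the compile wrapper above: build one graph per utterance with the v2 function and batch them with k2.create_fsa_vec:)

    # Sketch: batch the per-utterance graphs into a single FsaVec.
    graphs = [self._compile_one_and_cache_v2(t) for t in texts]
    decoding_graphs = k2.create_fsa_vec(graphs)
    # make sure the gradient is not accumulated (as in compile() above)
    decoding_graphs.requires_grad_(False)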

@csukuangfj
Collaborator

That code is doing composition on CPU.

Could you try

def build_num_graphs(self, texts: List[str]) -> k2.Fsa:
    '''Convert transcript to an Fsa with the help of lexicon
    and word symbol table.

    Args:
      texts:
        Each element is a transcript containing words separated by spaces.
        For instance, it may be 'HELLO SNOWFALL', which contains
        two words.

    Returns:
      Return an FST (FsaVec) corresponding to the transcript. Its `labels` are
      phone IDs and `aux_labels` are word IDs.
    '''
    word_ids_list = []
    for text in texts:
        word_ids = []
        for word in text.split(' '):
            if word in self.lexicon.words:
                word_ids.append(self.lexicon.words[word])
            else:
                word_ids.append(self.oov_id)
        word_ids_list.append(word_ids)

    fsa = k2.linear_fsa(word_ids_list, self.device)
    fsa = k2.add_epsilon_self_loops(fsa)
    assert fsa.device == self.device
    num_graphs = k2.intersect(self.L_inv,
                              fsa,
                              treat_epsilons_specially=False).invert_()
    num_graphs = k2.arc_sort(num_graphs)
    return num_graphs

which is run on GPU.
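
(A possible next step after build_num_graphs, assuming self.ctc_topo already lives on the same device; with treat_epsilons_specially=False, as in the compose call quoted at the top of this thread, the composition also stays on GPU:)

    num_graphs = self.build_num_graphs(texts)
    # Both inputs are arc-sorted and on the same device, so this runs on GPU.
    decoding_graphs = k2.compose(self.ctc_topo, num_graphs,
                                 treat_epsilons_specially=False)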

@yuekaizhang
Author

yuekaizhang commented Jun 28, 2021

Thanks. I could try composing that way. Actually, for my code, I followed this requirement from the docs: "When treat_epsilons_specially is True, this function works only on CPU. When treat_epsilons_specially is False and both a_fsa and b_fsa are on GPU, then this function works on GPU."

So I do k2.compose(self.ctc_topo.to("cuda"), decoding_graph.to("cuda"), treat_epsilons_specially=False). I think it's on GPU according to the doc?

@pkufool
Contributor

pkufool commented Jun 28, 2021

> Thanks for doing the comparison, and sure, that's a good idea. Yes, we should introduce a special-purpose function that constructs a batch of CTC graphs from a ragged tensor consisting of the linear symbol sequences for each one. Perhaps @pkufool could work on that?

Sure, I will.

@csukuangfj
Collaborator

> I was wondering whether there is any reference code to compile them as a batch.

I am afraid that has to be done in C++.

> So I do k2.compose(self.ctc_topo.to("cuda"), decoding_graph.to("cuda"), treat_epsilons_specially=False). I think it's on GPU according to the doc?

Yes, it is run on GPU. It would be more efficient if you:

  • (1) moved ctc_topo to GPU inside the constructor, i.e., in __init__, and
  • (2) constructed the decoding graph on GPU directly, rather than moving it there after construction,

as in the sketch below.
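
(A minimal sketch of that arrangement; the class and method names are illustrative, not snowfall/k2 API:)

    import k2
    import torch

    class CtcGraphCompiler:
        def __init__(self, ctc_topo: k2.Fsa, device: torch.device):
            self.device = device
            # (1) Move ctc_topo to GPU once, here in the constructor.
            self.ctc_topo = k2.arc_sort(ctc_topo.to(device))

        def compile(self, decoding_graph: k2.Fsa) -> k2.Fsa:
            # (2) The decoding graph is expected to be built on self.device
            # already, so no per-batch device transfer happens here.
            assert decoding_graph.device == self.device
            return k2.compose(self.ctc_topo, decoding_graph,
                              treat_epsilons_specially=False)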

@pkufool
Contributor

pkufool commented Jun 29, 2021

> Thanks for doing the comparison, and sure, that's a good idea. Yes, we should introduce a special-purpose function that constructs a batch of CTC graphs from a ragged tensor consisting of the linear symbol sequences for each one. Perhaps @pkufool could work on that?

@danpovey Do you mean constructing the decoding_graphs for the texts in a batch rather than calling compile_one_and_cache() several times?

@danpovey
Contributor

Yes, I'm talking about constructing it for a batch at a time; in general, all our FSA functions work on a batch (of course, people can use a batch of one if needed). This function will be very fast, so there is no problem re-doing the work on each minibatch.
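
(For instance, k2.linear_fsa already accepts a batch of symbol sequences, as build_num_graphs above shows:)

    import torch
    import k2

    # Sketch: most k2 operations act on an FsaVec (a batch of FSAs) at once.
    device = torch.device('cuda')  # illustrative
    fsas = k2.linear_fsa([[1, 2, 3], [4, 5]], device)  # FsaVec with 2 FSAs
    fsas = k2.add_epsilon_self_loops(fsas)  # applied to the whole batch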
