@zhu-han this repo contains the sampling code I mentioned to you. The "iterative" aspect of it is not needed here, we just treat it as a simple way to sample from a distribution.
Below is some icefall code called unsupervised.py that I was going to use to sample CTC transcripts for use in unsupervised training. I believe sampling is more correct than taking the top one, and will avoid the model collapsing to blank.
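To illustrate the collapse-to-blank concern: when blank narrowly dominates every frame's posterior, the argmax path is all blanks, while a sampled path still emits non-blank symbols in proportion to their probabilities. A minimal standalone sketch (toy distribution, not from the code below; `torch.multinomial` stands in for the sampler):

```python
import torch

torch.manual_seed(0)
# Toy frame posteriors where blank (index 0) narrowly dominates every frame.
probs = torch.tensor([[0.4, 0.35, 0.25]]).repeat(10, 1)  # (T, C)

argmax_path = probs.argmax(dim=-1)                     # all blanks
sampled_path = torch.multinomial(probs, 1).squeeze(1)  # blank on only ~40% of frames
print(argmax_path.tolist())
print(sampled_path.tolist())
```

The argmax path carries no training signal at all here, which is exactly why the code below samples instead.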
# Some utilities for unsupervised training
import random
from typing import Optional, Sequence, Tuple, TypeVar, Union, Dict, List
import math
import torch
from k2 import RaggedTensor, Fsa
from torch import nn
from torch import Tensor
import torch_iterative_sampling
Supervisions = Dict[str, torch.Tensor]
def sample_ctc_transcripts_ragged(
        ctc_output: Tensor,
        paths_per_sequence: int,
        modified_topo: bool) -> RaggedTensor:
    """
    Args:
      ctc_output: a Tensor of shape (N, T, C), i.e. (batch, time, num_symbols),
        containing normalized log-probs.
      paths_per_sequence: The number of separately sampled paths that are
        requested per sequence.
      modified_topo: True if the system is using the modified CTC topology,
        where two consecutive instances of a nonzero symbol can mean either
        one or two copies of the original symbol.
    Returns a RaggedTensor, on the same device as `ctc_output`, with shape
    (N * paths_per_sequence, None), where the 1st index is
    batch_idx * paths_per_sequence + path_idx and the 2nd index is the
    position in the token sequence.  The returned RaggedTensor will contain
    no 0's; those will have been removed, as blanks.
    """
    (N, T, C) = ctc_output.shape
    # The `seq_len` arg below is something specific to the "iterative" part of
    # torch_iterative_sampling, which has to do with sampling without
    # replacement; here we don't really want to do iterative sampling, we just
    # want to sample from the distribution once.
    probs = ctc_output.exp()
    sampled = torch_iterative_sampling.iterative_sample(
        probs, num_seqs=paths_per_sequence, seq_len=1).to(dtype=torch.int32)
    # `sampled` now has shape (N, T, paths_per_sequence, 1).
    sampled = sampled.squeeze(3).transpose(1, 2)
    # `sampled` now has shape (N, paths_per_sequence, T).
    # `identical_mask` is of shape (N, paths_per_sequence, T-1), and contains
    # True at each position where sampled[n,s,t] == sampled[n,s,t+1].
    identical_mask = sampled[:, :, 1:] == sampled[:, :, :-1]
    if modified_topo:
        # If we are using the modified/simplified CTC topology, it is possible
        # for two consecutive instances of a nonzero symbol to represent
        # either one symbol or two.  We choose either interpretation with
        # probability 0.5.  (Note: `and` does not work elementwise on tensors,
        # so we use `&`; and we use torch.rand, uniform on [0,1), rather than
        # torch.randn, so the comparison really does fire with probability
        # 0.5.)  I think this is correct, perhaps should check though.
        identical_mask = identical_mask & (
            torch.rand(*identical_mask.shape,
                       device=identical_mask.device) < 0.5)
    # The following statement replaces repeats of nonzero symbols with 0, so
    # only the final symbol in a chain of identical, consecutive symbols will
    # retain its nonzero value.
    sampled[:, :, :-1].masked_fill_(identical_mask, 0)
    sampled = sampled.reshape(N * paths_per_sequence, T)
    # The shape of `ragged_sampled` is the same as `sampled`: it's regular.
    # If you query it, though, it will come up as (N * paths_per_sequence, None).
    ragged_sampled = RaggedTensor(sampled)
    # Remove 0's from the ragged tensor, keeping only "real" (non-blank) symbols.
    ragged_sampled = ragged_sampled.remove_values_leq(0)
    # Note: you can create the CTC graphs with
    # k2.ctc_graph(ragged_sampled, modified={True,False}).
    # You can turn this into a List[List[int]] with ragged_sampled.tolist().
    return ragged_sampled
def _test_sample_ctc_transcripts_ragged():
    for device in ['cpu', 'cuda']:
        # Simple case: N == 2, T == 2, C == 3.
        ctc_output = torch.Tensor([[[0., 1., 0.], [1., 0., 0.]],
                                   [[1., 0., 0.], [0., 0., 1.]]]).to(device=device).log()
        r = sample_ctc_transcripts_ragged(ctc_output, paths_per_sequence=1,
                                          modified_topo=False)
        print("r = ", r)
        assert r == RaggedTensor('[[1], [2]]', dtype=torch.int32,
                                 device=device)

    for device in ['cpu', 'cuda']:
        # Simple case: N == 2, T == 3, C == 3, with repeats.
        # We use modified_topo == False, so the repeats should be collapsed.
        ctc_output = torch.Tensor([[[0., 1., 0.], [0., 1., 0.], [1., 0., 0.]],
                                   [[1., 0., 0.], [0., 0., 1.], [0., 0., 1.]]]).to(device=device).log()
        r = sample_ctc_transcripts_ragged(ctc_output, paths_per_sequence=1,
                                          modified_topo=False)
        print("r = ", r)
        assert r == RaggedTensor('[[1], [2]]', dtype=torch.int32,
                                 device=device)


if __name__ == "__main__":
    _test_sample_ctc_transcripts_ragged()
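The repeat-collapsing and blank-removal logic can also be checked without k2 or torch_iterative_sampling. Here is a standalone sketch of just the masking trick on already-sampled paths (`collapse_and_strip` is a hypothetical helper name, not part of the code above):

```python
import torch

def collapse_and_strip(sampled: torch.Tensor) -> list:
    # sampled: (N, S, T) int tensor of symbol ids, with 0 meaning blank.
    # Zero out every symbol that equals its right neighbor, so only the last
    # element of each run of identical symbols survives, then strip the 0's.
    identical = sampled[:, :, 1:] == sampled[:, :, :-1]
    out = sampled.clone()
    out[:, :, :-1].masked_fill_(identical, 0)
    N, S, T = out.shape
    flat = out.reshape(N * S, T)
    return [[int(x) for x in row if x != 0] for row in flat]

paths = torch.tensor([[[1, 1, 0, 2, 2]],
                      [[0, 3, 3, 3, 0]]])
print(collapse_and_strip(paths))  # -> [[1, 2], [3]]
```

This mirrors the non-modified-topology path of the function above: runs like `1 1` and `3 3 3` each collapse to a single symbol, and blanks disappear.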
Cool! Is there any specific reason to use this repo for the sampling? I find there is a similar function in PyTorch: torch.distributions.categorical.Categorical
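For reference, the one-shot sampling step (though not the sampling-without-replacement machinery the repo provides) could indeed be expressed with `Categorical`. A minimal sketch, with the shapes chosen to match the `(N, paths_per_sequence, T)` layout used in the code above (`S`, `log_probs` are illustrative names):

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)
N, T, C, S = 2, 4, 3, 5  # batch, time, num_symbols, paths per sequence
log_probs = torch.randn(N, T, C).log_softmax(dim=-1)

# Categorical with batch shape (N, T) samples S independent paths at once
# via sample((S,)); the result has shape (S, N, T), which we permute to
# (N, S, T).
sampled = Categorical(logits=log_probs).sample((S,)).permute(1, 0, 2)
assert sampled.shape == (N, S, T)
```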