Sampling code #1

Open
danpovey opened this issue Nov 5, 2021 · 3 comments
@danpovey
Owner

danpovey commented Nov 5, 2021

@zhu-han this repo contains the sampling code I mentioned to you. The "iterative" aspect of it is not needed here; we just treat it as a simple way to sample from a distribution.

Below is some icefall code, called unsupervised.py, that I was going to use to sample CTC transcripts for use in unsupervised training. I believe sampling is more correct than taking the top one, and will avoid the output collapsing to blank.

# Some utilities for unsupervised training                                                                                                                                                                            

import random
from typing import Optional, Sequence, Tuple, TypeVar, Union, Dict, List

import math
import torch
from k2 import RaggedTensor, Fsa
from torch import nn
from torch import Tensor
import torch_iterative_sampling

Supervisions = Dict[str, torch.Tensor]


def sample_ctc_transcripts_ragged(
        ctc_output: Tensor,
        paths_per_sequence: int,
        modified_topo: bool) -> RaggedTensor:
    """                                                                                                                                                                                                               
      ctc_output: a Tensor of shape (N, T, C), i.e. (batch, time, num_symbols),                                                                                                                                       
                 containing normalized log-probs.                                                                                                                                                                     
      paths_per_sequence: The number of separately sampled paths that are requested                                                                                                                                   
                 per sequence                                                                                                                                                                                         
      modified_topo:  True if the system is using the modified CTC topology where two                                                                                                                                 
                consecutive instances of a nonzero symbol can mean either one or two                                                                                                                                  
                copies of the original symbol.                                                                                                                                                                        
                                                                                                                                                                                                                      
    Returns a RaggedTensor, on the same device as `ctc_output`, with shape (N *                                                                                                                                       
    paths_per_sequence, None), where 1st index is batch_idx * paths_per_sequence                                                                                                                                      
    + path_idx and 2nd idx is the position in the token sequence.  The returned                                                                                                                                       
    RaggedTensor will have no 0's... those will have been removed, as blanks.                                                                                                                                         
    """
    (N, T, C) = ctc_output.shape

    # The 'seq_len' arg below is something specific to the "iterative" part of                                                                                                                                        
    # torch_iterative_sampling, which has to do with "sampling without replacement";                                                                                                                                  
    # here, we don't really want to do "iterative sampling", we just want to                                                                                                                                          
    # sample from the distribution once.                                                                                                                                                                              

    probs = ctc_output.exp()
    sampled = torch_iterative_sampling.iterative_sample(probs,
                                                        num_seqs=paths_per_sequence,
                                                        seq_len=1).to(dtype=torch.int32)
    # `sampled` now has shape:                                                                                                                                                                                        
    # (N, T, paths_per_sequence, 1)                                                                                                                                                                                   
    sampled = sampled.squeeze(3).transpose(1, 2)
    # `sampled` now has shape (N, paths_per_sequence, T)                                                                                                                                                              

    # identical_mask is of shape (N, paths_per_sequence, T-1), and                                                                                                                                                    
    # contains True at each position if sampled[n,s,t] == sampled[n,s,t+1].                                                                                                                                           
    identical_mask = sampled[:,:,1:] == sampled[:,:,:-1]

    if modified_topo:
        # If we are using the modified/simplified CTC topology, it is possible for
        # two consecutive instances of a nonzero symbol to represent either
        # one symbol or two.  We choose either interpretation with probability 0.5.
        # I think this is correct, perhaps should check though.
        identical_mask = identical_mask & (torch.rand(*identical_mask.shape,
                                                      device=identical_mask.device) < 0.5)
    # The following statement replaces repeats of nonzero symbols with 0, so only the                                                                                                                                 
    # final symbol in a chain of identical, consecutive symbols will retain its                                                                                                                                       
    # nonzero value.                                                                                                                                                                                                  
    sampled[:,:,:-1].masked_fill_(identical_mask, 0)

    sampled = sampled.reshape(N * paths_per_sequence, T)

    # The shape of ragged_sampled is the same as `sampled`, since it's regular;
    # if you query it, though, it will come up as (N * paths_per_sequence, None).
    ragged_sampled = RaggedTensor(sampled)

    # Remove 0's from the ragged tensor, to keep only "real" (non-blank) symbols.                                                                                                                                     
    ragged_sampled = ragged_sampled.remove_values_leq(0)

    # Note: you can create the CTC graphs with k2.ctc_graph(ragged_sampled, modified={True,False}).
    # You can turn it into a List[List[int]] with ragged_sampled.tolist().
    return ragged_sampled
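
# Usage sketch, following the note in sample_ctc_transcripts_ragged() above
# (it assumes the k2.ctc_graph call described there; `ctc_output` is an
# (N, T, C) tensor of normalized log-probs, as in that function):
#
#   import k2
#   ragged = sample_ctc_transcripts_ragged(ctc_output, paths_per_sequence=4,
#                                          modified_topo=False)
#   ctc_graphs = k2.ctc_graph(ragged, modified=False)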



def _test_sample_ctc_transcripts_ragged():
    # Only test on CUDA if it is available.
    for device in ['cpu'] + (['cuda'] if torch.cuda.is_available() else []):
        # simple case: N == 2, T == 2, C == 3
        ctc_output = torch.Tensor( [[[ 0., 1., 0. ], [ 1., 0., 0. ] ],
                                    [[ 1., 0., 0. ], [ 0., 0., 1. ] ]]).to(device=device).log()
        r = sample_ctc_transcripts_ragged(ctc_output, paths_per_sequence=1, modified_topo=False)
        print("r = ", r)
        assert r == RaggedTensor('[[1], [2]]', dtype=torch.int32,
                                 device=device)



    for device in ['cpu'] + (['cuda'] if torch.cuda.is_available() else []):
        # simple case: N == 2, T == 3, C == 3, with repeats.
        # We use modified_topo == False, so the repeats should be removed.
        ctc_output = torch.Tensor( [[[ 0., 1., 0. ], [0., 1., 0.], [ 1., 0., 0. ] ],
                                    [[ 1., 0., 0. ], [0., 0., 1.], [ 0., 0., 1. ] ]]).to(device=device).log()
        r = sample_ctc_transcripts_ragged(ctc_output, paths_per_sequence=1, modified_topo=False)
        print("r = ", r)
        assert r == RaggedTensor('[[1], [2]]', dtype=torch.int32,
                                 device=device)



if __name__ == "__main__":
    _test_sample_ctc_transcripts_ragged()

@zhu-han

zhu-han commented Nov 5, 2021

Cool! Is there any specific reason to use this repo for the sampling? I found there is a similar function in PyTorch: torch.distributions.categorical.Categorical
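
For reference, here is a minimal sketch of how Categorical could replace the iterative_sample call above (the helper name sample_with_categorical is hypothetical; it assumes the same (N, T, C) log-prob layout and returns the (N, paths_per_sequence, T) shape that the function above works with):

import torch

def sample_with_categorical(ctc_output: torch.Tensor,
                            paths_per_sequence: int) -> torch.Tensor:
    # ctc_output: (N, T, C) normalized log-probs.  Categorical treats the
    # leading dimensions as batch dimensions, so a single sample() call draws
    # an independent symbol for every (n, t) position.
    dist = torch.distributions.Categorical(logits=ctc_output)
    # sample((paths_per_sequence,)) has shape (paths_per_sequence, N, T).
    sampled = dist.sample((paths_per_sequence,))
    # Rearrange to (N, paths_per_sequence, T), matching `sampled` above.
    return sampled.permute(1, 0, 2).to(dtype=torch.int32)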

@danpovey
Owner Author

danpovey commented Nov 5, 2021

Oh, I didn't know about that. Then that should be fine.

@zhu-han

zhu-han commented Nov 5, 2021

OK, thanks! I'll try to use this sampling idea in unsupervised training.
