In [211]:
from typing import Sequence

import k2
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Decoding multiple test phrases for a single keyword

In [212]:
batch_size = 6
padded_seq_len = 20
actual_seq_lens = [20, 15, 15, 20, 10, 13]
keyword_len = 5

# probs = torch.randint(1, 20, (batch_size, padded_seq_len, keyword_len))
probs = torch.randn((batch_size, padded_seq_len, keyword_len))
for i, actual_seq_len in enumerate(actual_seq_lens):
    probs[i, actual_seq_len:] = 0
# probs = probs.log()
probs=probs.to(0)
print(probs.shape, probs.device)

torch.Size([6, 20, 5]) cuda:0


In [213]:
def make_dense_fsa(log_scores, seq_lens):
    N = log_scores.shape[0]

    seq_idcs = torch.arange(N, dtype=torch.int32)
    start_frames = torch.zeros(N, dtype=torch.int32)
    durations = torch.tensor(seq_lens, dtype=torch.int32)

    durations_sorted, indices_sorted = torch.sort(durations, descending=True)
    seq_idcs_sorted = seq_idcs[indices_sorted]
    start_frames_sorted = start_frames[indices_sorted]
    log_scores_sorted = log_scores[indices_sorted]

    supervision_segments = torch.stack([seq_idcs_sorted, start_frames_sorted, durations_sorted], dim=1)

    fsa = k2.DenseFsaVec(log_scores_sorted, supervision_segments)
    return fsa

In [214]:
dense_fsa = make_dense_fsa(probs, actual_seq_lens)
dense_fsa.dim0()

6

In [215]:
def get_query_fsa_str(keyword_len):
    if type(keyword_len) is torch.Tensor:
        keyword_len = keyword_len.item()
    query_fsa_str = ""
    for i in range(keyword_len-1):
        curr = str(i)+" "
        nxt = str(i+1)+" "
        #           src_state   dest_state  label   score
        self_arc=   curr +      curr +      curr +  "0.0\n"
        arc=        curr +      nxt +       curr +  "0.0\n"
        query_fsa_str+=self_arc
        query_fsa_str+=arc

    penult = str(keyword_len-1)+" "
    final = str(keyword_len)+" "
    #                   src_state   dest_state  label   score
    final_arc=          penult +    final +     "-1 " + "0.0\n"
    query_fsa_str +=  final_arc
    query_fsa_str += final
    return query_fsa_str
query_fsa_str = get_query_fsa_str(keyword_len)
print(query_fsa_str)

0 0 0 0.0
0 1 0 0.0
1 1 1 0.0
1 2 1 0.0
2 2 2 0.0
2 3 2 0.0
3 3 3 0.0
3 4 3 0.0
4 5 -1 0.0
5 


In [216]:
def make_query_fsa(keyword_len, batch_size=None):
    query_fsa_str = get_query_fsa_str(keyword_len)
    fsa = k2.Fsa.from_str(query_fsa_str).to(device)
    labels = list(range(keyword_len))
    # fsa = k2.linear_fsa(labels, device)
    # fsa = k2.add_epsilon_self_loops(fsa)

    if batch_size is not None:
        fsa = k2.create_fsa_vec([fsa]*batch_size)
    return fsa

query_fsa = make_query_fsa(keyword_len, batch_size)
query_fsa.shape

(6, None, None)

In [217]:
def decode_single_keyword(probs, seq_lens):
    dense_fsa = make_dense_fsa(probs, seq_lens)
    keyword_len = probs.shape[-1]
    query_fsa = make_query_fsa(keyword_len, batch_size)

    lattice = k2.intersect_dense(query_fsa, dense_fsa, output_beam=10.0)
    best_path = k2.shortest_path(lattice, use_double_scores=True)
    score = best_path.get_tot_scores(use_double_scores=True, log_semiring=True)
    labels = best_path.labels
    return score, labels

decode_single_keyword(probs, actual_seq_lens)

(tensor([0.1785, 6.9284, 5.5624, 7.6326, 3.0374, 5.2211], device='cuda:0',
        dtype=torch.float64),
 tensor([ 0,  1,  1,  1,  2,  2,  2,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
          3,  3, -1,  0,  1,  2,  2,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
          3,  3,  3,  3,  3, -1,  0,  0,  0,  0,  0,  1,  2,  2,  2,  2,  2,  2,
          2,  2,  3, -1,  0,  1,  2,  2,  2,  2,  3,  3,  3,  3,  3,  3,  3,  3,
          3, -1,  0,  0,  0,  0,  0,  1,  1,  1,  2,  3,  3,  3,  3, -1,  0,  0,
          0,  1,  1,  1,  1,  2,  2,  3, -1], device='cuda:0',
        dtype=torch.int32))

(tensor([ 7.8563, 20.5448,  4.5942,  9.5775,  3.4156,  1.4157], device='cuda:0',
        dtype=torch.float64),
 tensor([ 0,  0,  0,  0,  0,  1,  2,  3,  0,  0,  0,  4,  0,  0,  0,  0,  0,  0,
          0,  0, -1,  0,  0,  1,  0,  0,  0,  0,  0,  2,  3,  4,  0,  0,  0,  0,
          0,  0,  0,  0,  0, -1,  0,  1,  2,  3,  4,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0, -1,  0,  0,  0,  0,  0,  0,  1,  0,  0,  2,  0,  3,  0,  4,
          0, -1,  0,  1,  0,  0,  2,  0,  3,  0,  0,  4,  0,  0,  0, -1,  0,  0,
          1,  0,  2,  3,  0,  4,  0,  0, -1], device='cuda:0',
        dtype=torch.int32))

## Batched keyword decoding

In [228]:
def prepare_query_graph(keyword_lens, batch_size):
    """
    Creates an FSA Vector with `batch_size` repetitions
    of each query graph.

    Arguments:
        keyword_lens: list of lengths of each query graph
        batch_size: number of test phrases in batch
    Returns:
        FSA vector of queries
    """
    expanded_lens = torch.tensor(keyword_lens)\
        .unsqueeze(1)\
        .repeat(1, batch_size)\
        .view(-1)\
        .to(torch.int32)
    fsa_list = []
    for keyword_len in expanded_lens:
        fsa_list.append(make_query_fsa(keyword_len))
    query_graph = k2.create_fsa_vec(fsa_list)
    return query_graph

keyword_lens = [10, 5, 9]
padded_keyword_len = max(keyword_lens)
num_keywords = len(keyword_lens)

query_graph = prepare_query_graph(keyword_lens, batch_size)
query_graph.shape

(18, None, None)

In [231]:
def prepare_dense_fsa_batch(prob_matrices: torch.Tensor, seq_lens: Sequence[int]):
    """
    Arguments:
        prob_matrices: tensor of shape K*T*W_k*W_t, where K is the number of queries,
            T the number of test phrases, W_k the number of padded windows in each
            query and W_t the number of padded windows in each test phrase
    Returns:
        Dense FSA of query probabilities
    """
    # prob_matrices needs to be reshaped from K*T*W_k*W_t
    # to (K*T)*W_t*W_k
    probs_flattened = prob_matrices.flatten(0,1)
    probs_transposed = probs_flattened.transpose(1,2)
    probs_transposed = probs_transposed.to(device)
    if type(seq_lens) is torch.Tensor:
        seq_lens = seq_lens.tolist()
    new_seq_lens=seq_lens*num_keywords
    return make_dense_fsa(probs_transposed, new_seq_lens)


prob_matrices = torch.randn((num_keywords, batch_size, padded_keyword_len, padded_seq_len))
dense_fsa = prepare_dense_fsa_batch(prob_matrices, actual_seq_lens)
dense_fsa.dim0(), dense_fsa.device

(18, device(type='cuda', index=0))

In [232]:
def decode_keyword_batch(prob_matrices, keyword_lens, seq_lens):
    batch_size = prob_matrices.shape[1]
    query_fsa = prepare_query_graph(keyword_lens, batch_size)
    dense_fsa = prepare_dense_fsa_batch(prob_matrices, seq_lens)

    lattice = k2.intersect_dense(query_fsa, dense_fsa, output_beam=10.0)
    best_path = k2.shortest_path(lattice, use_double_scores=True)
    score = best_path.get_tot_scores(use_double_scores=True, log_semiring=True)
    labels = best_path.labels
    return score, labels

In [234]:
scores, labels = decode_keyword_batch(prob_matrices, keyword_lens, actual_seq_lens)
scores.shape, labels.shape

(torch.Size([18]), torch.Size([297]))