In [2]:
import sys
sys.path.insert(0,'/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_proxy_all_layers')
sys.path.insert(0,'/exp/rhuang/meta/k2/k2/python')
sys.path.insert(0,'/exp/rhuang/meta/k2/temp.linux-x86_64-cpython-310/lib')
sys.path.insert(0,'/exp/rhuang/meta/icefall')

In [3]:
import torch
from torch.nn.functional import ctc_loss, log_softmax
from torch.autograd import Variable
import k2

In [None]:
def encode_supervisions(
    targets, target_lengths, input_lengths
):
    """
    Encodes Lhotse's ``batch["supervisions"]`` dict into
    a pair of torch Tensor, and a list of transcription strings or token indexes

    The supervision tensor has shape ``(batch_size, 3)``.
    Its second dimension contains information about sequence index [0],
    start frames [1] and num frames [2].

    The batch items might become re-ordered during this operation -- the
    returned tensor and list of strings are guaranteed to be consistent with
    each other.
    """
    # batch_size = targets.size(0)
    batch_size = len(targets)
    supervision_segments = torch.stack(
        (
            torch.arange(batch_size),
            torch.zeros(batch_size),
            input_lengths.cpu(),
        ),
        1,
    ).to(torch.int32)

    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]
    # import pdb; pdb.set_trace()

    # res = targets[indices].tolist()
    # res_lengths = target_lengths[indices].tolist()
    # res = [[i + 1 for i in l[:l_len]] for l, l_len in zip(res, res_lengths)]  # hard-coded for torchaudio
    res = [targets[i] for i in indices]

    return supervision_segments, res, indices

def k2_ctc_loss(log_prob, targets, input_lengths, target_lengths, reduction='none'):
    # `targets` is a list of lists

    decoding_graph = k2.ctc_graph(targets, modified=False)
    supervision_segments, texts, indices = encode_supervisions(targets, target_lengths, input_lengths)
    # print(supervision_segments, texts, indices)
    dense_fsa_vec = k2.DenseFsaVec(
        log_prob.permute(1,0,2),  # TNC->NTC
        supervision_segments,
    )

    if True:
        loss = k2.ctc_loss(
            decoding_graph=decoding_graph,
            dense_fsa_vec=dense_fsa_vec,
            reduction=reduction,
            target_lengths=target_lengths,
        )

    if False:  # Also checkout: /exp/rhuang/meta/k2/k2/python/k2/ctc_loss.py and https://github.com/k2-fsa/icefall/blob/master/icefall/decode.py
        
        # lattice = k2.intersect_dense(
        #     decoding_graph,
        #     dense_fsa_vec,
        #     10,
        # )

        lattice = k2.intersect_dense_pruned(
            decoding_graph,
            dense_fsa_vec,
            search_beam=20,  # 15
            output_beam=8,  # 6
            min_active_states=30,
            max_active_states=10000,
        )

        best_path = k2.shortest_path(lattice, use_double_scores=False)
        forward_scores = best_path.get_tot_scores(use_double_scores=False, log_semiring=True)

        loss = -forward_scores
        if reduction == "none":
            pass
        elif reduction == "sum":
            loss = loss.sum()
        else:
            assert reduction == "mean"
            loss /= target_lengths
            loss = loss.mean()

    return loss

In [21]:
# Create some example data
output_dim = 5
batch_size = 1
sequence_length = 4
log_probs = torch.randn(batch_size, sequence_length*2, output_dim).log_softmax(2).detach().requires_grad_()
targets = torch.randint(1, output_dim, (batch_size, sequence_length), dtype=torch.long)
input_lengths = torch.full((batch_size,), sequence_length*2, dtype=torch.long)
target_lengths = torch.randint(1, sequence_length-1, (batch_size,), dtype=torch.long)

In [63]:
# log_probs = [
# [[-1.6678, -1.7028, -0.9742, -1.4975, -3.5766],
# [-0.8538, -1.6253, -1.4779, -2.8868, -2.3700],
# [-2.5472, -1.2712, -1.1831, -2.2954, -1.4518],
# [-4.6883, -3.7046, -0.5502, -0.9947, -3.9358],
# [-1.3619, -1.8677, -0.6211, -3.3785, -4.0219],
# [-1.3728, -0.6234, -1.8358, -3.4717, -3.9147],
# [-1.9166, -2.9598, -0.3028, -2.9376, -4.6766],
# [-1.9206, -2.3894, -3.3004, -0.9705, -1.0613]],
# ]
# log_probs = torch.tensor(log_probs)
# log_probs = log_probs.log_softmax(2).detach().requires_grad_()

targets = torch.tensor([[2,4]], dtype=torch.long)
target_lengths = torch.tensor([2,], dtype=torch.long)

log_probs = [
[[0.9, 0.02, 0.02, 0.02, 0.02],
[0.9, 0.02, 0.02, 0.02, 0.02],
[0.8, 0.02, 0.2, 0.02, 0.02],
[0.1, 0.02, 0.7, 0.02, 0.02],
[0.9, 0.02, 0.5, 0.02, 0.02],
[0.5, 0.02, 0.02, 0.02, 0.3],
[0.9, 0.02, 0.02, 0.02, 0.02],
[0.9, 0.02, 0.02, 0.02, 0.02],
[0.9, 0.02, 0.02, 0.02, 0.02],
[0.9, 0.02, 0.02, 0.02, 0.02],
[0.9, 0.02, 0.02, 0.02, 0.02]],
]
log_probs = torch.tensor(log_probs).log()
log_probs = log_probs - torch.tensor([1.0, 0, 0, 0, 0]).view(1,1,-1)
log_probs = log_probs.log_softmax(2).detach().requires_grad_()
log_probs.exp()

tensor([[[0.8054, 0.0487, 0.0487, 0.0487, 0.0487],
         [0.8054, 0.0487, 0.0487, 0.0487, 0.0487],
         [0.5309, 0.0361, 0.3608, 0.0361, 0.0361],
         [0.0462, 0.0251, 0.8785, 0.0251, 0.0251],
         [0.3716, 0.0224, 0.5611, 0.0224, 0.0224],
         [0.3382, 0.0368, 0.0368, 0.0368, 0.5515],
         [0.8054, 0.0487, 0.0487, 0.0487, 0.0487],
         [0.8054, 0.0487, 0.0487, 0.0487, 0.0487],
         [0.8054, 0.0487, 0.0487, 0.0487, 0.0487],
         [0.8054, 0.0487, 0.0487, 0.0487, 0.0487],
         [0.8054, 0.0487, 0.0487, 0.0487, 0.0487]]], grad_fn=<ExpBackward0>)

In [37]:
print(log_probs.shape)
print(input_lengths)
print(targets)
print(target_lengths)

torch.Size([1, 8, 5])
tensor([8])
tensor([[2, 4]])
tensor([2])


In [25]:
log_probs1 = log_probs.clone()
log_probs1.retain_grad()  # Preserve the gradients for log_probs

# Compute the CTC loss
loss = ctc_loss(log_probs1.permute(1,0,2), targets, input_lengths, target_lengths)

# Backpropagate the gradients
loss.backward()

# Print the gradients with respect to log_probs
print(f"Gradients of log_probs: {log_probs1.grad}")

# # Continue to backpropagate the gradients
# # Check the gradients
# for name, param in model.named_parameters():
#     print(f"Gradients of {name}: {param.grad}")

Gradients of log_probs: tensor([[[-0.1412,  0.0607, -0.0034,  0.0746,  0.0093],
         [-0.0042,  0.0641, -0.1096,  0.0186,  0.0312],
         [-0.0046,  0.0870, -0.1938,  0.0336,  0.0779],
         [-0.0007, -0.0017, -0.1267,  0.1233,  0.0059],
         [-0.0008, -0.0159, -0.0005,  0.0114,  0.0058],
         [ 0.0366, -0.0950,  0.0422,  0.0104,  0.0059],
         [-0.1616, -0.0865,  0.2463,  0.0177, -0.0159],
         [ 0.0418,  0.0306,  0.0123,  0.1263, -0.2109]]])


In [64]:
log_probs2 = log_probs.clone()
log_probs2.retain_grad()  # Preserve the gradients for log_probs

loss_k2 = k2_ctc_loss(
    log_probs2.permute(1,0,2), 
    [t[:lt].tolist() for t, lt in zip(targets, target_lengths)],  
    input_lengths, 
    target_lengths,
    reduction="mean",
)
print(loss_k2, loss)

loss_k2.backward()

# Print the gradients with respect to log_probs
print(f"Gradients of log_probs: {log_probs2.grad}")


tensor(1.4006, grad_fn=<MeanBackward0>) tensor(2.0889, grad_fn=<MeanBackward0>)
Gradients of log_probs: tensor([[[-0.5000,  0.0000,  0.0000,  0.0000,  0.0000],
         [-0.5000,  0.0000,  0.0000,  0.0000,  0.0000],
         [-0.5000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.5000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.5000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000, -0.5000],
         [-0.5000,  0.0000,  0.0000,  0.0000,  0.0000],
         [-0.5000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]])


In [58]:
print(f"Gradients of log_probs: {log_probs2.grad}")

Gradients of log_probs: tensor([[[-0.4654,  0.0000, -0.0346,  0.0000,  0.0000],
         [-0.4441,  0.0000, -0.0550,  0.0000, -0.0009],
         [-0.1224,  0.0000, -0.3754,  0.0000, -0.0022],
         [-0.4340,  0.0000, -0.0519,  0.0000, -0.0141],
         [-0.4321,  0.0000, -0.0292,  0.0000, -0.0387],
         [-0.0595,  0.0000, -0.0035,  0.0000, -0.4370],
         [-0.4558,  0.0000, -0.0010,  0.0000, -0.0433],
         [-0.4806,  0.0000,  0.0000,  0.0000, -0.0194]]])


In [45]:
log_probs.exp(), targets

(tensor([[[0.9184, 0.0204, 0.0204, 0.0204, 0.0204],
          [0.9184, 0.0204, 0.0204, 0.0204, 0.0204],
          [0.0943, 0.0189, 0.8491, 0.0189, 0.0189],
          [0.9184, 0.0204, 0.0204, 0.0204, 0.0204],
          [0.9184, 0.0204, 0.0204, 0.0204, 0.0204],
          [0.5814, 0.0233, 0.0233, 0.0233, 0.3488],
          [0.9184, 0.0204, 0.0204, 0.0204, 0.0204],
          [0.9184, 0.0204, 0.0204, 0.0204, 0.0204]]], grad_fn=<ExpBackward0>),
 tensor([[2, 4]]))

In [54]:
loss_k2, loss

(tensor([6.9858, 6.8263], grad_fn=<ToCopyBackward0>),
 tensor(2.2232, grad_fn=<MeanBackward0>))

In [35]:
log_probs.exp().sum(dim=-1)

tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000]], grad_fn=<SumBackward1>)

In [18]:
log_probs

tensor([[[-1.6678, -1.7028, -0.9742, -1.4975, -3.5766],
         [-0.8538, -1.6253, -1.4779, -2.8868, -2.3700],
         [-2.5472, -1.2712, -1.1831, -2.2954, -1.4518],
         [-4.6883, -3.7046, -0.5502, -0.9947, -3.9358],
         [-1.3619, -1.8677, -0.6211, -3.3785, -4.0219],
         [-1.3728, -0.6234, -1.8358, -3.4717, -3.9147],
         [-1.9166, -2.9598, -0.3028, -2.9376, -4.6766],
         [-1.9206, -2.3894, -3.3004, -0.9705, -1.0613]],

        [[-1.3247, -1.6347, -1.5798, -2.3376, -1.4416],
         [-3.1774, -3.2598, -0.4258, -1.4641, -3.3417],
         [-0.8104, -2.6120, -2.0748, -2.4026, -1.3248],
         [-1.5996, -1.7569, -2.8319, -0.9441, -1.7286],
         [-1.3534, -2.1563, -2.6799, -1.1094, -1.4804],
         [-2.0603, -0.8716, -1.8151, -1.3472, -3.4582],
         [-1.5186, -1.5564, -1.3164, -2.1602, -1.6783],
         [-2.1493, -2.1961, -1.6528, -1.5585, -0.9936]]], requires_grad=True)