In [3]:
import numpy as np
import torch
import matplotlib.pyplot as plt

In [42]:
class CTCLoss(torch.nn.Module):
    """Pure-PyTorch reference implementation of the CTC loss.

    The forward (alpha) recursion is evaluated in log space with explicit
    Python loops over time steps and extended-label positions; gradients are
    obtained through autograd.  Intended for checking and teaching rather
    than for speed.
    """

    def __init__(self):
        super(CTCLoss, self).__init__()

    def show_alignment(self, log_alpha):
        """Show a log-alpha lattice (as returned by compute_log_alpha) as an image."""
        plt.imshow(log_alpha.detach().cpu().numpy()); plt.show()

    def compute_log_alpha(self, net_out, y, blank):
        """Run the CTC forward recursion for one sequence.

        net_out: FloatTensor (T, #labels); normalised here with log_softmax
                 over the label axis (a no-op if it is already log-softmax)
        y: LongTensor (U) of target labels; may be empty
        blank: int, index of the blank label

        Returns a tensor of shape (T, 2U + 1) whose entry [t, s] is the log
        forward variable for position s of the blank-extended target at
        time t.  The result has the dtype and device of net_out.
        """
        net_out = torch.nn.functional.log_softmax(net_out, dim=1) # not needed if net_out already logsoftmax, but warp-ctc does this
        T = len(net_out)
        U = len(y)
        S = 2*U + 1

        # Blank-extended target: [_, y1, _, y2, _, ..., yU, _]
        y_prime = [blank if s % 2 == 0 else int(y[s // 2]) for s in range(S)]

        # First predecessor state for every position s (the last one is always
        # s itself).  Skipping over s-1 is only allowed when s holds a
        # non-blank label that differs from the label at s-2.  This does not
        # depend on t, so it is computed once.
        first_prev = []
        for s in range(S):
            can_skip = s >= 2 and y_prime[s] != blank and y_prime[s] != y_prime[s-2]
            first_prev.append(s - 2 if can_skip else max(s - 1, 0))

        # Floor value standing in for log(0); w/o eps, gradients will be nan.
        # Built in float32 like before, then moved to the input's dtype/device
        # so the lattice does not silently downcast or hop devices.
        eps = 1e-30
        floor = torch.log(torch.zeros(S) + eps).to(dtype=net_out.dtype, device=net_out.device)

        log_alphas = []
        for t in range(T):
            log_alpha_t = floor.clone()

            if t == 0:
                # A path can only start in the leading blank or the first label.
                log_alpha_t[0] = net_out[0, blank]
                if S > 1:  # no first label exists when the target is empty
                    log_alpha_t[1] = net_out[0, y_prime[1]]
            else:
                log_alpha_t_1 = log_alphas[-1]
                for s in range(S):
                    incoming = torch.logsumexp(log_alpha_t_1[first_prev[s]:s+1], dim=0)
                    log_alpha_t[s] = incoming + net_out[t, y_prime[s]]

            log_alphas.append(log_alpha_t)

        log_alpha = torch.stack(log_alphas)
        return log_alpha

    def forward(self,log_probs,targets,input_lengths,target_lengths,reduction="none",blank=0):
        """
        log_probs: FloatTensor (max(input_lengths), N, #labels)
        targets: LongTensor (N, max(target_lengths))
        input_lengths: LongTensor (N)
        target_lengths: LongTensor (N)
        reduction: "none", "avg" (alias: "mean") or "sum"
        blank: int

        Returns the per-sequence negative log-likelihoods, shape (N), for
        reduction "none"; otherwise their mean or sum as a scalar tensor.
        Raises ValueError for any other reduction.
        """
        if reduction not in ("none", "avg", "mean", "sum"):
            raise ValueError("unsupported reduction: %r" % (reduction,))

        batch_size = len(input_lengths)
        losses = []
        for i in range(batch_size):
            # Strip the padding of sequence i along time and target axes.
            net_out = log_probs[:input_lengths[i], i, :]
            y = targets[i, :target_lengths[i]]
            log_alpha = self.compute_log_alpha(net_out, y, blank) # shape (T, 2U + 1)
            # Valid paths end in the last label or the trailing blank.
            loss = -torch.logsumexp(log_alpha[-1, -2:], dim=0)
            losses.append(loss)
        losses = torch.stack(losses)

        if reduction == "none":
            return losses
        if reduction == "sum":
            return losses.sum()
        return losses.mean()
        

In [43]:
# Dimensions of a single-utterance sanity-check problem.
num_labels = 5
blank_index = num_labels - 1  # the blank is the final output unit
batch_size = 1
pad = -1
T = torch.tensor([20], dtype=torch.long)  # input length per batch item
U = torch.tensor([8], dtype=torch.long)   # target length per batch item

# Random target drawn from the non-blank labels, shaped (1, U).
y = torch.randint(0, blank_index, (int(U[0]),)).long().unsqueeze(0)
print(y)

# Random log-probabilities of shape (max T, batch, labels), detached into a
# fresh leaf tensor so gradients w.r.t. it can be inspected.
net_out = torch.log_softmax(
    torch.randn(int(T.max()), batch_size, num_labels), dim=2
).detach().requires_grad_()


tensor([[0, 0, 2, 2, 1, 3, 1, 1]])


In [44]:
# Reference: PyTorch's built-in CTC loss (the printouts below label it "warp-ctc").
# It gets its own name so that `ctc_loss` is not bound to a function first and
# a module instance afterwards.
reference_ctc_loss = torch.nn.functional.ctc_loss
loss = reference_ctc_loss(log_probs=net_out,targets=y,input_lengths=T,target_lengths=U,reduction="none",blank=blank_index)

# Implementation under test, called on exactly the same inputs.
ctc_loss = CTCLoss()
loss_ = ctc_loss(log_probs=net_out,targets=y,input_lengths=T,target_lengths=U,reduction="none",blank=blank_index)

print("my implementation:", loss_)
print("warp-ctc:", loss)
# Fail loudly if the two losses disagree instead of relying on eyeballing the printout.
assert torch.allclose(loss_, loss, atol=1e-4), "CTCLoss disagrees with torch.nn.functional.ctc_loss"

# backward() accumulates into net_out.grad, so reset it around each pass to
# see each implementation's gradient on its own (also when re-running the cell).
net_out.grad = None
loss_.mean().backward()
print("my grad:")
print(net_out.grad)
net_out.grad = None

loss.mean().backward()
print("warp-ctc grad:")
print(net_out.grad)
net_out.grad = None

my implementation: tensor([18.9827], grad_fn=<StackBackward>)
warp-ctc: tensor([18.9827], grad_fn=<CtcLossBackward>)
my grad:
tensor([[[-0.5733,  0.0413,  0.3169,  0.4162, -0.2011]],

        [[-0.0628,  0.0674,  0.0968,  0.5738, -0.6752]],

        [[-0.3125,  0.1726,  0.1424,  0.0644, -0.0669]],

        [[-0.1248,  0.0231,  0.0550,  0.1372, -0.0904]],

        [[-0.1175,  0.2925,  0.0111,  0.0928, -0.2790]],

        [[-0.2531,  0.4388, -0.3150,  0.3021, -0.1728]],

        [[ 0.1731,  0.0590, -0.1972,  0.0651, -0.1001]],

        [[ 0.2604,  0.1783, -0.1815,  0.0610, -0.3181]],

        [[ 0.2900,  0.0684, -0.3567,  0.0903, -0.0921]],

        [[ 0.4014,  0.0370, -0.4168,  0.1980, -0.2196]],

        [[ 0.0517,  0.0293, -0.0518,  0.0275, -0.0567]],

        [[ 0.0373, -0.0764, -0.1070,  0.1778, -0.0317]],

        [[ 0.0502, -0.1844,  0.0083,  0.1162,  0.0098]],

        [[ 0.2829, -0.1472,  0.1663, -0.2711, -0.0308]],

        [[ 0.0479, -0.0038,  0.0854, -0.1542,  0.0247]],

    