In [330]:
import numpy as np
import torch
import matplotlib.pyplot as plt

In [383]:
class CTCLoss(torch.nn.Module):
    def __init__(self):
        super(CTCLoss, self).__init__()
        
    def show_alignment(self, log_alpha):
        plt.imshow(log_alpha.cpu().data.numpy()); plt.show()
    
    def compute_log_alpha(self, net_out, y, blank):
        net_out = torch.nn.functional.log_softmax(net_out, dim=1) # not needed if net_out already logsoftmax, but warp-ctc does this
        T = len(net_out)
        U = len(y)
        S = 2*U + 1
        y_prime = []
        for i in range(S):
            label = blank if (i+1) % 2 else y[int(i/2)].item()
            y_prime.append(label)
            
        log_alphas = []
        for t in range(T):
            log_alpha_t = torch.log(torch.zeros(S))
            
            if t == 0:
                log_alpha_t[0] = net_out[0,blank]
                log_alpha_t[1] = net_out[0,y_prime[1]]
            else:
                log_alpha_t_1 = log_alphas[-1]
                for s in range(S):
                    if s == 0:
                        log_alpha_t[s] = log_alpha_t_1[s] + net_out[t,y_prime[s]]
                        
                    if s == 1:
                        log_alpha_t[s] = torch.logsumexp(log_alpha_t_1[s-1:s+1], dim=0) + net_out[t,y_prime[s]]
                        
                    if s > 1:                    
                        if y_prime[s] == blank or y_prime[s-2] == y_prime[s]:
                            log_alpha_t[s] = torch.logsumexp(log_alpha_t_1[s-1:s+1], dim=0) + net_out[t,y_prime[s]]
                        else:
                            log_alpha_t[s] = torch.logsumexp(log_alpha_t_1[s-2:s+1], dim=0) + net_out[t,y_prime[s]]

            log_alphas.append(log_alpha_t)
            
        log_alpha = torch.stack(log_alphas)
        return log_alpha
    
    def forward(self,log_probs,targets,input_lengths,target_lengths,reduction="none",blank=0):
        """
        log_probs: FloatTensor (max(input_lengths), N, #labels)
        targets: LongTensor (N, max(target_lengths))
        input_lengths: LongTensor (N)
        target_lengths: LongTensor (N)
        reduction: "none", "avg"
        blank: int
        """
        batch_size = len(input_lengths)
        losses = []
        for i in range(0, batch_size):
            net_out = log_probs[:input_lengths[i], i, :]
            y = targets[i, :target_lengths[i]]
            log_alpha = self.compute_log_alpha(net_out, y, blank) # shape (T, 2U + 1)
            loss = -torch.logsumexp(log_alpha[-1, -2:], dim=0) 
            losses.append(loss)
        losses = torch.stack(losses)
        if reduction=="none": return losses
        if reduction=="avg": return losses.mean()
        

In [386]:
num_labels = 5
blank_index = num_labels-1 # last output = blank
batch_size = 1
pad = -1
T = torch.LongTensor([20])
U = torch.LongTensor([8])
y = torch.randint(low=0,high=num_labels-1,size=(U[0],)).unsqueeze(0).long()
print(y)

net_out = torch.randn(max(T), batch_size, num_labels).log_softmax(2).detach().requires_grad_()

tensor([[1, 0, 1, 3, 1, 1, 0, 3]])


In [387]:
ctc_loss = torch.nn.functional.ctc_loss
loss = ctc_loss(log_probs=net_out,targets=y,input_lengths=T,target_lengths=U,reduction="none",blank=blank_index)

ctc_loss = CTCLoss()
loss_ = ctc_loss(log_probs=net_out,targets=y,input_lengths=T,target_lengths=U,reduction="none",blank=blank_index)

print("my implementation:", loss_)
print("warp-ctc:", loss)

loss_.mean().backward()
print("my grad:")
# print(net_out.grad[:,1,:])
# print(net_out.grad[:,1,:].sum(1))
print(net_out.grad)
del net_out.grad

loss.mean().backward()
print("warp-ctc grad:")
print(net_out.grad)
# print(net_out.grad[:,1,:])
# print(net_out.grad[:,1,:].sum(1))
del net_out.grad

my implementation: tensor([14.5260], grad_fn=<StackBackward>)
warp-ctc: tensor([14.5260], grad_fn=<CtcLossBackward>)
my grad:
tensor([[[ 0.5483, -0.1575,  0.0575,  0.2400, -0.6883]],

        [[    nan,     nan,     nan,     nan,     nan]],

        [[    nan,     nan,     nan,     nan,     nan]],

        [[    nan,     nan,     nan,     nan,     nan]],

        [[    nan,     nan,     nan,     nan,     nan]],

        [[    nan,     nan,     nan,     nan,     nan]],

        [[    nan,     nan,     nan,     nan,     nan]],

        [[    nan,     nan,     nan,     nan,     nan]],

        [[ 0.1933, -0.1128,  0.0705, -0.1231, -0.0279]],

        [[ 0.0573, -0.0851,  0.1696, -0.0283, -0.1135]],

        [[ 0.1816, -0.0052,  0.0126, -0.0033, -0.1858]],

        [[ 0.0741, -0.0181,  0.0939,  0.0145, -0.1644]],

        [[ 0.0278, -0.4771,  0.4687,  0.1199, -0.1394]],

        [[ 0.0080, -0.0842,  0.0300,  0.1084, -0.0622]],

        [[ 0.0185, -0.2975,  0.1823,  0.4725, -0.3758]],

    