<a href="https://colab.research.google.com/github/davidcpage/mctc/blob/master/notebooks/01_CTC_loss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# default_exp ctc_simple

# CTC loss simple

> A simplified CTC loss for decoding lattices with only two options stay/move. This can be used for decoding without collapsing of repeats.

In [None]:
#export
import numpy as np
import cupy as cp
import torch
import torch.nn as nn
from collections import namedtuple
from mctc.utils import *

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Preliminaries

Generate a test example:

In [None]:
#export
def generate_sample_inputs(T, N, L_min, L_max, device=device):
    """
    Build a random batch of stay/move scores and target lengths for testing.

    Args:
        T: number of time steps
        N: batch size
        L_min, L_max: inclusive bounds on the sampled target lengths
        device: torch device on which the tensors are created

    Returns:
        Tuple `(stay_scores, move_scores, target_lengths)` with shapes
        `(T, N, L_max)`, `(T, N, L_max - 1)` and `(N,)`. Both score tensors
        are uniform in [0, 1) and have `requires_grad=True`.
    """
    score_opts = dict(device=device, requires_grad=True)
    stay = torch.rand(T, N, L_max, **score_opts)
    move = torch.rand(T, N, L_max - 1, **score_opts)
    # randint's upper bound is exclusive, hence the +1 to make L_max reachable.
    lengths = torch.randint(L_min, L_max + 1, (N,), device=device)
    return stay, move, lengths

In [None]:
# One random batch, reused by every comparison/benchmark cell below.
sample_inputs = generate_sample_inputs(T=800, N=64, L_min=330, L_max=500)
stay_scores, move_scores, target_lengths = sample_inputs

## Loss implementations

### 1. Basic pytorch

Here's a straightforward implementation in PyTorch, working in log space.


In [None]:
#export
from torch.nn.functional import pad
from mctc.ctc import Log, semiring

def logZ_fwd(stay_scores, move_scores, target_lengths, S=Log):
    """
    Reference forward recursion over the stay/move lattice, relying on autograd for gradients.

    Args:
        stay_scores: tensor of shape (T, N, L)
        move_scores: tensor of shape (T, N, L - 1)
        target_lengths: integer tensor of shape (N,); the path must end at position
            `target_lengths - 1`
        S: semiring object providing `zero`, `one`, `sum` and `mul` (defaults to `Log`)

    Returns:
        `S.sum` over lattice positions of the final forward state combined with the
        end-state mask, reduced along dim 1 (one value per batch element).
    """
    T, N, L = stay_scores.shape

    # Initial state: only lattice position 0 is reachable.
    alpha_0 = stay_scores.new_full((N, L), S.zero)
    alpha_0[:, 0] = S.one
    # Final state: only position target_length - 1 is accepted.
    beta_T = stay_scores.new_full((N, L), S.zero)
    beta_T[torch.arange(N), target_lengths - 1] = S.one

    # Left-pad so that column l holds the score of moving *into* position l.
    padded_moves = pad(move_scores, (1, 0), value=S.zero)
    # Extra leading column lets `lattice[:, :-1]` act as "previous position".
    lattice = pad(alpha_0, (1, 0), value=S.zero)
    for t in range(T):
        stayed = S.mul(stay_scores[t], lattice[:, 1:])
        moved = S.mul(padded_moves[t], lattice[:, :-1])
        lattice[:, 1:] = S.sum(torch.stack([stayed, moved]), dim=0)
    return S.sum(S.mul(lattice[:, 1:], beta_T), dim=1)

In [None]:
# Run the reference implementation on the sample batch; the recorded result has one entry per batch element.
res = logZ_fwd(*sample_inputs)
res

tensor([970.6453, 971.6895, 982.3431, 972.2327, 976.6527, 977.2905, 979.2094,
        979.4549, 984.2569, 977.5510, 969.8797, 981.0777, 982.6025, 981.8260,
        975.4854, 980.4288, 983.0338, 984.4608, 975.9901, 973.8077, 983.5836,
        984.6028, 981.6602, 980.5414, 976.8457, 980.7314, 966.9128, 968.4935,
        982.5735, 983.4719, 982.3672, 981.2139, 968.8883, 981.0426, 976.7378,
        976.1961, 979.2068, 981.1317, 977.9713, 961.0184, 971.9141, 969.5038,
        979.3978, 978.5461, 982.9652, 964.3593, 980.6489, 984.6378, 979.3281,
        984.9939, 978.5788, 961.8087, 972.8093, 980.8213, 970.1132, 981.3785,
        975.8100, 964.0969, 982.7523, 972.8085, 976.3800, 984.9760, 970.7934,
        982.0917], device='cuda:0', grad_fn=<LogsumexpBackward>)

In [None]:
#report(benchmark_fwd_bwd((lambda *x: logZ_fwd(*x).sum()), *sample_inputs))

### 2. Pytorch with grad

In [None]:
#export
def _simple_lattice_fwd_bwd(alpha, beta_T, beta_stay, beta_move, stay_scores, move_scores, S=Log):
    """
    Fill `alpha`, `beta_stay` and `beta_move` in place with the forward and backward
    lattice recursions.

    Args:
        alpha: (T + 1, N, L) buffer; `alpha[0]` is the initial state, `alpha[1:]` is written
        beta_T: (N, L) final state (read only)
        beta_stay, beta_move: (T, N, L) buffers, written for every time step
        stay_scores: (T, N, L) scores
        move_scores: (T, N, L - 1) scores
        S: semiring object providing `zero`, `sum` and `mul` (defaults to `Log`)

    Returns:
        None; results are written into the provided buffers.
    """
    n_steps = alpha.size(0) - 1
    # Pad both ends once: dropping the last column gives "move into l",
    # dropping the first gives "move out of l".
    padded_moves = pad(move_scores, (1, 1), value=S.zero)
    moves_in = padded_moves[:, :, :-1]
    moves_out = padded_moves[:, :, 1:]

    # Forward pass.
    fwd = pad(alpha[0], (1, 0), value=S.zero)
    for t in range(n_steps):
        stayed = S.mul(stay_scores[t], fwd[:, 1:])
        moved = S.mul(moves_in[t], fwd[:, :-1])
        fwd[:, 1:] = S.sum(torch.stack([stayed, moved]), dim=0)
        alpha[t + 1] = fwd[:, 1:]

    # Backward pass, from the last time step down to the first.
    bwd = pad(beta_T, (0, 1), value=S.zero)
    for t in reversed(range(n_steps)):
        beta_stay[t] = S.mul(bwd[:, :-1], stay_scores[t])
        beta_move[t] = S.mul(bwd[:, 1:], moves_out[t])
        bwd[:, :-1] = S.sum(torch.stack([beta_stay[t], beta_move[t]]), dim=0)

def dot(x, y, S=Log, dim=-1):
    """Semiring inner product: combine `x` and `y` with `S.mul`, then reduce along `dim` with `S.sum`."""
    products = S.mul(x, y)
    return S.sum(products, dim=dim)

class LogZ(torch.autograd.Function):
    """
    Autograd function computing the lattice score with a hand-written backward.

    The forward pass runs a pluggable forward/backward lattice implementation and
    stores the per-step softmax over the stay/move terms; the backward pass just
    rescales those stored values by the incoming gradient.
    """

    @staticmethod
    def forward(ctx, stay_scores, move_scores, target_lengths, fwd_bwd_impl):
        """
        Args:
            stay_scores: (T, N, L) tensor
            move_scores: (T, N, L - 1) tensor
            target_lengths: integer tensor of shape (N,)
            fwd_bwd_impl: callable
                `(alpha, beta_T, beta_stay, beta_move, stay_scores, move_scores, S)`
                that fills the first four buffers in place
        """
        S = Log
        T, N, L = stay_scores.shape

        def filled(*shape):
            # Buffer of the requested shape, pre-filled with the semiring zero.
            return stay_scores.new_full(shape, S.zero)

        alpha = filled(T + 1, N, L)
        alpha[0, :, 0] = S.one

        beta_stay = filled(T, N, L)
        beta_move = filled(T, N, L)
        beta_T = filled(N, L)
        beta_T[torch.arange(N), target_lengths - 1] = S.one

        fwd_bwd_impl(alpha, beta_T, beta_stay, beta_move, stay_scores, move_scores, S)

        # TODO (from the original author): express softmax in terms of S?
        alpha_prev = alpha[:-1]
        joint = torch.cat([S.mul(alpha_prev, beta_stay), S.mul(alpha_prev, beta_move)], dim=2)
        posteriors = torch.softmax(joint, dim=2).reshape(T, N, 2, L)

        ctx.save_for_backward(posteriors)
        return dot(alpha[-1], beta_T, S)

    @staticmethod
    def backward(ctx, grad):
        """Scale the saved (T, N, 2, L) tensor by `grad` (shape (N,)) and split it into stay/move gradients."""
        (posteriors,) = ctx.saved_tensors
        scaled = posteriors * grad[None, :, None, None]
        stay_grad = scaled[:, :, 0]
        # The move half carries one padding column; drop it to get L - 1 entries.
        move_grad = scaled[:, :, 1, :-1]
        return stay_grad, move_grad, None, None

def logZ_py(stay_scores, move_scores, target_lengths):
    """Evaluate `LogZ` using the pure-PyTorch forward/backward lattice implementation."""
    fwd_bwd_impl = _simple_lattice_fwd_bwd
    return LogZ.apply(stay_scores, move_scores, target_lengths, fwd_bwd_impl)

In [None]:
#export
def mean(f):
    """Wrap `f` so that the returned callable evaluates `f(*xs)` and returns its `.mean()`."""
    def averaged(*xs):
        return f(*xs).mean()
    return averaged

In [None]:
# Check the custom autograd Function (logZ_py) against the plain-autograd reference (logZ_fwd).
# NOTE(review): `compare_fwd_bwd` and `float64` are not defined in this notebook (presumably
# star-imported from mctc.utils); they appear to run both callables in double precision and
# print the forward/backward differences recorded below — confirm.
fwds, bwds = compare_fwd_bwd(float64(mean(logZ_fwd)), float64(mean(logZ_py)), *sample_inputs)

fwd diff: 0.00e+00
bwd diff: 5.82e-11


### 3. Cupy

In [None]:
%%writefile cuda/ctc_simple.cu
// Larger of two values; yields `a1` whenever `a > a1` does not hold.
__device__ __forceinline__ FLOAT max2(FLOAT a, FLOAT a1) {
    if (a > a1) {
        return a;
    }
    return a1;
}

// log(exp(a) + exp(a1)), computed after subtracting the larger argument so that
// neither exponential can overflow.
__device__ __forceinline__ FLOAT logsumexp2(FLOAT a, FLOAT a1) {
    const FLOAT shift = max2(a, a1);
    const FLOAT shifted_sum = exp(a - shift) + exp(a1 - shift);
    return shift + log(shifted_sum);
}

// Elementary binary operations that can be substituted for the SUM / MUL macros.
// (The Log-semiring kernels built in this notebook use MUL='add' and SUM='logsumexp2'.)
__device__ __forceinline__ FLOAT add(FLOAT a, FLOAT b) {
    return a + b;
}

__device__ __forceinline__ FLOAT mul(FLOAT a, FLOAT b) {
    return a * b;
}

// Forward and backward recursions over the stay/move lattice, in one launch.
//
// Launch layout used by the cupy wrapper in this notebook:
//   grid = (N, 2, 1), block = (L, 1, 1),
//   dynamic shared memory of at least 2 * blockDim.x * sizeof(FLOAT) bytes.
//   blockIdx.x  : batch element
//   blockIdx.y  : 0 -> forward (alpha) pass, anything else -> backward (beta) pass
//   threadIdx.x : lattice position in [0, L)
//
// Buffer layouts (dense, row-major):
//   alpha       : (T + 1, N, L)  alpha[0] is read as the initial state, alpha[1..T] are written
//   beta_T      : (N, L)         final state, read only
//   beta_stay   : (T, N, L)      written
//   beta_move   : (T, N, L)      written for positions [0, L - 1); the last column is not
//                                touched and keeps whatever the caller initialised it to
//   stay_scores : (T, N, L)
//   move_scores : (T, N, L - 1)
//
// FLOAT, ZERO, SUM and MUL are macros supplied when the kernel is compiled.
extern "C" __global__ void fwd_bwd_logspace(
    FLOAT* __restrict__ alpha, FLOAT* __restrict__ beta_T,
    FLOAT* __restrict__ beta_stay, FLOAT* __restrict__ beta_move, 
    const FLOAT* __restrict__ stay_scores, const FLOAT* __restrict__ move_scores,
    int T, int N, int L
) {
    int bx = blockIdx.x, tx = threadIdx.x;
    // NOTE: this return precedes the __syncthreads() calls below. That is only safe
    // because the kernel is launched with blockDim.x == L, so no thread takes it.
    if (tx >= L) return;
    extern __shared__ FLOAT smem[];
    if (blockIdx.y == 0) {
        // Forward pass: a is this thread's alpha value, a1 the contribution from position tx - 1.
        FLOAT a = ZERO, a1 = ZERO;
        a = alpha[bx * L + tx];
        for (int t = 0; t < T; t++) {
            // Two shared-memory halves selected by t % 2, with one barrier per step:
            // the write for step t + 1 goes to the other half than the one neighbours
            // read during step t.
            FLOAT *buf = smem + (t % 2) * blockDim.x;
            buf[tx] = a; __syncthreads(); 
            if (tx > 0) {a1 = MUL(move_scores[(t * N + bx) * (L - 1) + tx - 1], buf[tx - 1]);}
            a = SUM(MUL(stay_scores[(t * N + bx) * L + tx], a), a1);
            alpha[((t + 1) * N + bx) * L + tx] = a;
        }
    }
    else {
        // Backward pass: b is this thread's beta value, b1 the contribution from position tx + 1.
        FLOAT b = ZERO, b1 = ZERO;
        b = beta_T[bx * L + tx];
        for (int t = T; t > 0; t--) {
            FLOAT *buf = smem + (t % 2) * blockDim.x;
            buf[tx] = b; __syncthreads();
            if (tx < L - 1) {
                b1 = MUL(buf[tx+1], move_scores[(((t - 1) * N + bx) * (L - 1)) + tx]);
                // beta_move rows have stride L (the caller allocates and reads it as
                // (T, N, L)), unlike move_scores whose rows have stride L - 1. Indexing it
                // with L - 1 scattered the values across the wrong rows.
                beta_move[((t - 1) * N + bx) * L + tx] = b1;
            }
            b = beta_stay[((t - 1) * N + bx) * L + tx] = MUL(b, stay_scores[(((t - 1) * N + bx) * L) + tx]);
            b = SUM(b, b1);
        }
    }
}

In [None]:
#export
from mctc.utils import *
import cupy as cp

# Compiled `fwd_bwd_logspace` kernels, keyed by (tensor dtype, semiring).
# Each dtype is paired with the C type substituted for the FLOAT macro.
cupy_funcs = {
    (dtype, Log): load_cupy_func(
        'cuda/ctc_simple.cu', 'fwd_bwd_logspace',
        FLOAT=c_type, SUM='logsumexp2', MUL='add', ZERO='{:E}'.format(Log.zero),
    )
    for dtype, c_type in ((torch.float32, 'float'), (torch.float64, 'double'))
}

def _simple_lattice_fwd_bwd_cupy(alpha, beta_T, beta_stay, beta_move, stay_scores, move_scores, S:semiring):
    """
    Fill `alpha`, `beta_stay` and `beta_move` in place using the compiled CUDA kernel.

    Args:
        alpha: (T + 1, N, L) buffer; `alpha[0]` holds the initial state
        beta_T: (N, L) final state
        beta_stay, beta_move: (T, N, L) output buffers
        stay_scores: (T, N, L) scores
        move_scores: (T, N, L - 1) scores
        S: semiring; together with the dtype of `stay_scores` it selects the kernel
            from `cupy_funcs`

    Raises:
        KeyError: if no kernel is registered for `(stay_scores.dtype, S)`.

    The kernel receives raw device pointers and indexes them as dense row-major
    arrays. The four output buffers must therefore already be contiguous (they
    cannot be copied, since the kernel writes into them).
    """
    T, N, L = stay_scores.shape
    # The score tensors come from the caller and may be strided views; the kernel
    # would silently misread those. `.contiguous()` is a no-op when already dense.
    stay_scores = stay_scores.contiguous()
    move_scores = move_scores.contiguous()
    # The kernel double-buffers one value per thread: 2 * L elements of the tensor's dtype.
    smem_bytes = 2 * L * stay_scores.element_size()
    with cp.cuda.Device(stay_scores.device.index):
        # One block per (batch element, pass) and one thread per lattice position.
        cupy_funcs[(stay_scores.dtype, S)](grid=(N, 2, 1), block=(L, 1, 1), shared_mem=smem_bytes,
               args=(alpha.data_ptr(), beta_T.data_ptr(), beta_stay.data_ptr(), beta_move.data_ptr(), 
                     stay_scores.data_ptr(), move_scores.data_ptr(), T, N, L))

def logZ_cupy(stay_scores, move_scores, target_lengths):
    """Evaluate `LogZ` using the CUDA (cupy) forward/backward lattice implementation."""
    fwd_bwd_impl = _simple_lattice_fwd_bwd_cupy
    return LogZ.apply(stay_scores, move_scores, target_lengths, fwd_bwd_impl)

In [None]:
# Check the cupy kernel against the pure-PyTorch forward/backward implementation.
# NOTE(review): the bwd diff in the output recorded below (1.27e-02) is many orders of
# magnitude larger than the 5.82e-11 recorded for the logZ_fwd vs logZ_py comparison, while
# the fwd diff is 0 — the kernel's backward outputs (beta_stay / beta_move) deserve a check.
fwds, bwds = compare_fwd_bwd(float64(mean(logZ_py)), float64(mean(logZ_cupy)), *sample_inputs)

fwd diff: 0.00e+00
bwd diff: 1.27e-02


In [None]:
# Time the forward and backward passes of the cupy implementation on the sample batch (nloops=100).
# `report` and `benchmark_fwd_bwd` are not defined in this notebook — presumably star-imported from mctc.utils.
report(benchmark_fwd_bwd(mean(logZ_cupy), *sample_inputs, nloops=100))

fwd: 19.44ms (19.22-19.71ms)
bwd: 7.28ms (7.22-7.39ms)
tot: 26.72ms (26.45-27.00ms)
