<a href="https://colab.research.google.com/github/davidcpage/mctc/blob/master/notebooks/01_CTC_loss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# default_exp sparse

# Sparse

> Sparse partition function calculations.

In [17]:
#export
from functools import partial
import numpy as np
import cupy as cp
import torch
from mctc.utils import *
from mctc.ctc import semiring, Max, Log, interleave_blanks, generate_sample_inputs, loss_pytorch, benchmark_fwd_bwd, report, compare_fwd_bwd

# Default device for this module: CUDA when torch reports it as available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### 1. Basic pytorch

In [19]:
#export
def Mv_scan_py(Ms, idx, v0, S:semiring=Log):
    """Run the sparse matrix-vector recursion over time in plain PyTorch.

    Args:
        Ms: tensor of shape (T, N, C, nz) holding the non-zero entries of each
            per-step transition matrix.
        idx: integer index tensor selecting, for each (state, slot) pair, which
            entry of the previous vector is combined with the matching entry of
            Ms (expected shape (C, nz), as asserted by `transpose`).
        v0: initial vector written to alpha[0] (broadcast to (N, C)).
        S: semiring providing `zero`, `mul` and `sum`.

    Returns:
        alpha of shape (T+1, N, C), where alpha[0] = v0 and
        alpha[t+1] = S.sum(S.mul(Ms[t], alpha[t, :, idx]), dim=2).
    """
    n_steps, n_batch, n_states, _ = Ms.shape
    alpha = Ms.new_full((n_steps + 1, n_batch, n_states), S.zero)
    alpha[0] = v0
    for step in range(n_steps):
        # Gather the source-state values for every (state, slot) pair.
        gathered = alpha[step, :, idx]
        weighted = S.mul(Ms[step], gathered)
        alpha[step + 1] = S.sum(weighted, dim=2)
    return alpha

def transpose(Ms, idx):
    """Transpose a batch of sparse matrices stored as (values, indices).

    Args:
        Ms: values of shape (T, N, C, nz).
        idx: integer tensor of shape (C, nz); entry (c, k) is the column of the
            k-th stored value in row c.

    Returns:
        (Ms_T, idx_T) in the same layout, describing the transposed matrices.

    Note:
        The sorted flat positions are regrouped into C rows of nz entries, so the
        result only lines up when every index value occurs exactly nz times in idx.
    """
    T, N, C, nz = Ms.shape
    assert idx.shape == (C, nz)
    # Flat positions of idx ordered by the column they point at, nz per row.
    order = idx.flatten().argsort().reshape(C, nz)
    flat_values = Ms.reshape(T, N, -1)
    # A flat position p = c * nz + k originates from row c = p // nz.
    return flat_values[:, :, order], order // nz

class _LogZ(torch.autograd.Function):
    """Partition function of a sparse chain with a hand-written backward pass.

    The forward pass runs `scan` to obtain the forward vectors alpha and combines
    the last one with vT. The backward pass runs the same `scan` on the transposed,
    time-reversed matrices to obtain beta and returns a gradient for Ms only.
    """

    @staticmethod
    def forward(ctx, Ms, idx, v0, vT, S:semiring, scan):
        """Return S.sum(S.mul(alpha[-1], vT), dim=1) where alpha = scan(Ms, idx, v0, S).

        Ms has shape (T, N, C, nz); `scan` is called as scan(Ms, idx, v0, S) and
        must return a tensor indexable as alpha[t, n, c] with T+1 time steps.
        """
        alpha = scan(Ms, idx, v0, S)
        ctx.semiring, ctx.scan = S, scan
        ctx.save_for_backward(alpha, Ms, idx, vT)
        final_scores = S.mul(alpha[-1], vT)
        return S.sum(final_scores, dim=1)

    @staticmethod
    def backward(ctx, grad):
        """Gradient w.r.t. Ms; the remaining five inputs receive None."""
        S, scan = ctx.semiring, ctx.scan
        alpha, Ms, idx, vT = ctx.saved_tensors
        T, N, C, nz = Ms.shape

        # Backward vectors: scan the transposed matrices in reverse time, starting from vT.
        Ms_T, idx_T = transpose(Ms, idx)
        beta = scan(Ms_T.flip(0), idx_T, vT, S)

        # alpha at the source state of every stored entry, shape (T, N, C*nz).
        alpha_src = alpha[:-1, :, idx.flatten()]
        fwd_scores = S.mul(Ms.reshape(T, N, -1), alpha_src).reshape(T, N, C, nz)
        # beta is stored in reversed time; drop its last step and flip so that
        # position t holds the backward vector that follows step t.
        beta_dst = beta[:-1, :, :, None].flip(0)
        edge_scores = S.mul(fwd_scores, beta_dst)

        # S.dsum is applied jointly over all C*nz entries of each (t, n) slice
        # (presumably the derivative of S.sum - confirm against mctc.ctc).
        g = S.dsum(edge_scores.reshape(T, N, -1), dim=2).reshape(T, N, C, nz)
        return grad[None, :, None, None] * g, None, None, None, None, None

def logZ_py(Ms, idx, v0, vT, S:semiring=Log, scan=Mv_scan_py):
    """Sparse logZ using the pure-PyTorch scan.

    Thin wrapper around `_LogZ.apply`. `torch.autograd.Function.apply` only accepts
    positional arguments, so binding `S`/`scan` as keywords with `functools.partial`
    raises a TypeError at call time; the arguments are forwarded positionally here.

    Args:
        Ms, idx, v0, vT: passed through to `_LogZ.apply` unchanged.
        S: semiring to evaluate in (defaults to Log).
        scan: scan implementation (defaults to Mv_scan_py).

    Returns:
        Whatever `_LogZ.apply` returns: S.sum(S.mul(alpha[-1], vT), dim=1).
    """
    return _LogZ.apply(Ms, idx, v0, vT, S, scan)

### 2. CTC loss using sparse LogZ

NB: This is intended only as a test/demo - it is slower than the previous CTC loss implementation and only supports the case where all input_lengths are equal to T (although this could be fixed).

In [29]:
#export
from torch.nn.functional import pad

def _ctc_loss(logits, targets, input_lengths, target_lengths, scan, S:semiring=Log):
    """CTC loss expressed as a sparse partition function (test/demo implementation).

    Args:
        logits: tensor of shape (T, N, C); log_softmax is taken over dim 2.
        targets: label tensor passed to `interleave_blanks` with blank_idx=0
            (presumably (N, L), giving (N, 2L+1) states - confirm against mctc.ctc).
        input_lengths: must all equal T; variable-length inputs are not supported.
        target_lengths: per-sample target lengths, used to pick the two accepting
            states (2*len-1 and 2*len) and to normalise the loss.
        scan: scan implementation called as scan(Ms, idx, v0, S).
        S: semiring to evaluate in.

    Returns:
        Scalar tensor: mean over the batch of -logZ / target_lengths.

    Raises:
        AssertionError: if any input length differs from T.
    """
    zero, one = [logits.new_full((1,), x) for x in (S.zero, S.one)]
    scores = logits.log_softmax(2)
    states = interleave_blanks(targets, blank_idx=0)
    state_scores = torch.gather(scores, 2, states.expand(scores.size(0), -1, -1))
    final_states = torch.stack([target_lengths*2-1, target_lengths*2], 1)

    T, N, Lp = state_scores.shape
    assert torch.all(input_lengths == T), 'only input_lengths == T is supported'

    # Three stored entries per state: stay, advance by one, and skip by two.
    # The skip is disabled (S.zero) where the state two positions back carries
    # the same label; padding with S.zero disables transitions that would wrap.
    Ms = torch.stack([
        state_scores,
        pad(state_scores[:, :, 1:], (1, 0), value=S.zero),
        pad(torch.where(states[:, 2:] == states[:, :-2], zero.expand(T, N, Lp-2), state_scores[:, :, 2:]), (2, 0), value=S.zero)
    ], -1)

    # Build the indices on the same device as the data rather than on the
    # module-level default device, so inputs on CPU or another GPU also work.
    i = torch.arange(Lp, device=logits.device)
    # Source state for each entry: s, s-1 and s-2 (cyclically; the wrapped
    # entries carry S.zero weight in Ms).
    idx = torch.stack([i, torch.roll(i, 1), torch.roll(i, 2)], dim=1)

    v0 = torch.cat([one.expand(N, 1), zero.expand(N, Lp - 1)], dim=1)
    vT = zero.expand(N, Lp).clone().scatter_(1, final_states, S.one)

    logZ = _LogZ.apply(Ms, idx, v0, vT, S, scan)
    return -(logZ / target_lengths).mean()

ctc_loss_py = partial(_ctc_loss, scan=Mv_scan_py)

In [33]:
# T_min == T_max because _ctc_loss asserts that every input length equals T.
sample_inputs = logits, targets, input_lengths, target_lengths = generate_sample_inputs(T_min=500, T_max=500, N=128, C=20, L_min=80, L_max=100)
# Check the pure-PyTorch sparse loss against loss_pytorch (see the fwd/bwd diffs in the output).
fwd, bwd = compare_fwd_bwd(loss_pytorch, ctc_loss_py, *sample_inputs)

fwd diff: 0.00e+00
bwd diff: 9.20e-08


### 3. Cupy

In [None]:
%%writefile cuda/sparse.cu
// Binary semiring primitives; the kernel picks them through its ADD / MUL macros.

// Returns a when a > b, otherwise b.
__device__ __forceinline__ FLOAT max2(FLOAT a, FLOAT b) {
    if (a > b) return a;
    return b;
}

// log(exp(a) + exp(b)), evaluated relative to the larger argument.
__device__ __forceinline__ FLOAT logsumexp2(FLOAT a, FLOAT b) {
    const bool a_is_larger = a > b;
    const FLOAT hi = a_is_larger ? a : b;
    const FLOAT lo = a_is_larger ? b : a;
    return log1p(exp(lo - hi)) + hi;
}

// Plain addition.
__device__ __forceinline__ FLOAT add(FLOAT a, FLOAT b) {
    return a + b;
}

// Sparse matrix-vector scan over time.
//
// Layout (dense, row-major):
//   alpha : (T+1, N, C)  - alpha[0] is read as the initial vector, steps 1..T are written
//   Ms    : (T, N, C, nz) - stored matrix entries
//   idx   : (C, nz)       - for state c, slot j: index of the previous-step state to read
//
// Each block handles one batch element (blockIdx.x) and each thread one state
// (threadIdx.x). Requires at least C threads per block and 2 * C * sizeof(FLOAT)
// bytes of dynamic shared memory.
extern "C" __global__ void sparse_Mv_scan(
    FLOAT* __restrict__ alpha,
    const FLOAT* __restrict__ Ms,
    const int* __restrict__ idx,
    int T, int N, int C, int nz
) {
    const int bx = blockIdx.x, tx = threadIdx.x;
    // Surplus threads must not exit early: every thread of the block has to
    // reach each __syncthreads() below. They simply skip the loads and stores.
    const bool active = tx < C;
    extern __shared__ FLOAT smem[];

    FLOAT a = 0;
    if (active) a = alpha[(long long)bx * C + tx];
    for (int t = 0; t < T; t++) {
        // Two buffers of C values used alternately: values published at step t
        // are still being read by slower threads while faster threads already
        // publish step t+1 into the other buffer, so one barrier per step suffices.
        FLOAT *buf = smem + (t % 2) * C;
        if (active) buf[tx] = a;
        __syncthreads();
        if (active) {
            // 64-bit flat index so large T * N * C * nz does not overflow int.
            const long long i = ((long long)t * N + bx) * C + tx;
            const FLOAT *m = Ms + i * nz;
            const int *src = idx + tx * nz;
            a = MUL(buf[src[0]], m[0]);
            for (int j = 1; j < nz; j++) {
                a = ADD(a, MUL(buf[src[j]], m[j]));
            }
            // Result of step t is stored at time index t + 1.
            alpha[i + (long long)N * C] = a;
        }
    }
}

In [41]:
#export
# Compiled kernel variants keyed by (tensor dtype, semiring).
# FLOAT selects the C scalar type, ADD/MUL the semiring operations defined in
# cuda/sparse.cu, and ZERO is each semiring's own zero element (the Max entries
# previously passed Log.zero).
cupy_funcs = {
    (dtype, S): load_cupy_func('cuda/sparse.cu', 'sparse_Mv_scan', FLOAT=ctype, ADD=add_op, MUL='add', ZERO='{:E}'.format(S.zero))
    for S, add_op in ((Log, 'logsumexp2'), (Max, 'max2'))
    for dtype, ctype in ((torch.float32, 'float'), (torch.float64, 'double'))
}

def Mv_scan_cupy(Ms, idx, v0, S:semiring=Log):
    """Run the sparse matrix-vector recursion over time with the CUDA kernel.

    Args:
        Ms: CUDA tensor of shape (T, N, C, nz), float32 or float64.
        idx: integer tensor of shape (C, nz) with the source state of each entry.
        v0: initial vector written to alpha[0] (broadcast to (N, C)).
        S: semiring; (Ms.dtype, S) must be a key of `cupy_funcs`.

    Returns:
        alpha of shape (T+1, N, C) with alpha[0] = v0 and steps 1..T filled in
        by the kernel.

    Raises:
        AssertionError: if idx does not have shape (C, nz) or Ms is not on a CUDA device.
        KeyError: if no kernel was compiled for (Ms.dtype, S).
    """
    T, N, C, nz = Ms.shape
    assert idx.shape == (C, nz)
    assert Ms.is_cuda, 'Mv_scan_cupy requires CUDA tensors'
    # The kernel addresses raw pointers assuming a dense row-major layout.
    Ms = Ms.contiguous()
    # Bind the converted index tensor to a name so its storage stays referenced
    # while the launch is issued, instead of taking data_ptr() of a temporary.
    idx_int = idx.to(dtype=torch.int, device=Ms.device).contiguous()
    alpha = Ms.new_full((T+1, N, C), S.zero)
    alpha[0] = v0
    with cp.cuda.Device(Ms.device.index):
        # One block per batch element, one thread per state; the kernel needs
        # two shared buffers of C values of Ms' element type.
        cupy_funcs[(Ms.dtype, S)](grid=(N, 1, 1), block=(C, 1, 1), shared_mem=2*Ms.element_size()*C,
               args=(alpha.data_ptr(), Ms.data_ptr(), idx_int.data_ptr(), T, N, C, nz))
    return alpha

def logZ(Ms, idx, v0, vT, S:semiring=Log, scan=Mv_scan_cupy):
    """Sparse logZ using the CUDA scan.

    Thin wrapper around `_LogZ.apply`. `torch.autograd.Function.apply` only accepts
    positional arguments, so binding `S`/`scan` as keywords with `functools.partial`
    raises a TypeError at call time; the arguments are forwarded positionally here.

    Args:
        Ms, idx, v0, vT: passed through to `_LogZ.apply` unchanged.
        S: semiring to evaluate in (defaults to Log).
        scan: scan implementation (defaults to Mv_scan_cupy).

    Returns:
        Whatever `_LogZ.apply` returns: S.sum(S.mul(alpha[-1], vT), dim=1).
    """
    return _LogZ.apply(Ms, idx, v0, vT, S, scan)

# _ctc_loss is an ordinary Python function, so keyword binding via partial is fine here.
ctc_loss = partial(_ctc_loss, scan=Mv_scan_cupy)

In [42]:
# T_min == T_max because _ctc_loss asserts that every input length equals T.
sample_inputs = logits, targets, input_lengths, target_lengths = generate_sample_inputs(T_min=500, T_max=500, N=128, C=20, L_min=80, L_max=100)
# Check the CUDA-backed sparse loss against loss_pytorch (see the fwd/bwd diffs in the output).
fwd, bwd = compare_fwd_bwd(loss_pytorch, ctc_loss, *sample_inputs)

fwd diff: 0.00e+00
bwd diff: 1.00e-07


In [44]:
# Baseline timing: loss_pytorch on the sample batch.
report(benchmark_fwd_bwd(loss_pytorch, *sample_inputs))

fwd: 1.71ms (1.64-1.81ms)
bwd: 4.47ms (4.39-4.56ms)
tot: 6.18ms (6.02-6.36ms)


In [43]:
# Timing of the sparse CUDA-backed ctc_loss on the same sample batch, for comparison with the baseline.
report(benchmark_fwd_bwd(ctc_loss, *sample_inputs))

fwd: 9.26ms (9.05-10.20ms)
bwd: 21.41ms (20.89-24.15ms)
tot: 30.67ms (29.94-34.32ms)
