<a href="https://colab.research.google.com/github/davidcpage/seqdist/blob/master/notebooks/01_CTC_loss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# default_exp sparse

# Sparse

> Sparse partition function calculations.

In [None]:
#export
from functools import partial, lru_cache as cache
import numpy as np
import cupy as cp
import torch

from seqdist.core import semiring, Max, Log 
from seqdist.utils import *
from seqdist.ctc import interleave_blanks, generate_sample_inputs, loss_pytorch, benchmark_fwd_bwd, report, compare_fwd_bwd

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### 1. Basic pytorch

In [None]:
#export
def Mv_scan_py(Ms, idx, v0, S:semiring=Log):
    """Reference sparse matrix-vector scan over the semiring `S`, in plain PyTorch.

    `Ms` is unpacked as `(T, N, C, nz)`. Starting from `alpha[0] = v0`, each
    step computes
    `alpha[t+1][n, c] = S.sum_j S.mul(Ms[t, n, c, j], alpha[t][n, idx[c, j]])`.

    `idx` is used to gather along the last axis of `alpha[t]`; it is expected
    to be an integer tensor of shape `(C, nz)` (not checked here).

    Returns the full `(T+1, N, C)` tensor of intermediate vectors.
    """
    n_steps, batch, n_states, _ = Ms.shape
    alpha = Ms.new_full((n_steps + 1, batch, n_states), S.zero)
    alpha[0] = v0
    for step in range(n_steps):
        # Gather, for every state, the previous values of its nz source states.
        sources = alpha[step, :, idx]
        weighted = S.mul(Ms[step], sources)
        # Reduce over the nz axis with the semiring sum.
        alpha[step + 1] = S.sum(weighted, dim=2)
    return alpha

class _LogZ_scan(torch.autograd.Function):
    """Autograd function computing the partition function via a forward scan.

    The forward pass runs `scan` (a callable with the signature
    `scan(Ms, idx, v0, S)` returning the `(T+1, N, C)` tensor of intermediate
    vectors) and combines the last vector with `vT`. The backward pass runs
    the same `scan` on a transposed, time-reversed problem and combines the
    two scans into a gradient for `Ms` only.
    """
    @staticmethod
    def forward(ctx, Ms, idx, v0, vT, S:semiring, scan):
        """Return `S.sum(S.mul(alpha[-1], vT), dim=1)` where `alpha = scan(Ms, idx, v0, S)`."""
        alpha = scan(Ms, idx, v0, S)
        # alpha is kept so that backward does not have to recompute the forward scan.
        ctx.save_for_backward(alpha, Ms, idx, vT)
        ctx.semiring, ctx.scan = S, scan
        return S.sum(S.mul(alpha[-1], vT), dim=1)
    
    @staticmethod
    def backward(ctx, grad):
        """Return the gradient w.r.t. `Ms`; the other five inputs get `None`."""
        alpha, Ms, idx, vT = ctx.saved_tensors
        S, scan = ctx.semiring, ctx.scan
        T, N, C, nz = Ms.shape
        # Sorting the flattened idx yields, grouped by source state, the flat (c * nz + j)
        # positions of the entries that read from that state.
        # NOTE(review): reshaping to idx.shape assumes every state occurs exactly nz times in idx - confirm.
        idx_T = idx.flatten().argsort().reshape(*idx.shape) #transpose
        # Entries of Ms regrouped by source state: shape (T, N, C, nz).
        Ms_T = Ms.reshape(T, N, -1)[:, :, idx_T]
        # idx_T // nz turns a flat position back into its row (destination state) index.
        # The scan is run on the time-reversed matrices, starting from vT.
        beta = scan(Ms_T.flip(0), idx_T // nz, vT, S)
        # Per-entry product: Ms[t, n, c, j] (x) alpha[t, n, idx[c, j]] (x) backward value of state c.
        # beta is in reversed time order, so drop its last entry and flip it to align with t.
        g = S.mul(S.mul(Ms.reshape(T, N, -1), alpha[:-1, :, idx.flatten()]).reshape(T, N, C, nz), beta[:-1, :, :, None].flip(0))
        # S.dsum is applied jointly over all C * nz entries of each (t, n) slice
        # (presumably the derivative of S.sum - defined in seqdist.core, not shown here).
        g = S.dsum(g.reshape(T, N, -1), dim=2).reshape(T, N, C, nz)
        # grad has one value per batch element; broadcast it over T, C and nz.
        return grad[None, :, None, None] * g, None, None, None, None, None 

def logZ_scan_py(Ms, idx, v0, vT, S:semiring):
    """Differentiable partition function using the pure-PyTorch scan `Mv_scan_py`."""
    return _LogZ_scan.apply(Ms, idx, v0, vT, S, Mv_scan_py)

### 2. CTC loss using sparse LogZ

NB: This is only a test/demo: it is slower than the previous CTC loss implementation and only supports the case where all `input_lengths` are equal to `T` (although this could be fixed).

In [None]:
#export
from torch.nn.functional import pad

def _ctc_loss(logits, targets, input_lengths, target_lengths, logZ_impl, S:semiring=Log):
    """CTC loss expressed as a sparse partition function with three transitions per state.

    Args:
        logits: tensor indexed as (T, N, n_classes); log-softmax is taken over dim 2.
        targets: 2D tensor of labels per batch element, passed to
            `interleave_blanks` with `blank_idx=0`.
        input_lengths: must all equal T (asserted).
        target_lengths: per-example target length; the two accepting states of
            each example are `2*len - 1` and `2*len`.
        logZ_impl: callable `(Ms, idx, v0, vT, S)` returning one value per batch element.
        S: semiring providing `zero`, `one` and the operations used by `logZ_impl`.

    Returns:
        `-(logZ / target_lengths).mean()`.

    Raises:
        AssertionError: if any input length differs from T.
    """
    zero, one = [logits.new_full((1,), x) for x in (S.zero, S.one)]
    scores = logits.log_softmax(2)
    states = interleave_blanks(targets, blank_idx=0)
    state_scores = torch.gather(scores, 2, states.expand(scores.size(0), -1, -1))
    final_states = torch.stack([target_lengths*2-1, target_lengths*2], 1)

    T, N, Lp = state_scores.shape
    assert torch.all(input_lengths == T)

    # Three transition scores into every state, stacked on the last axis:
    #   0: stay in the same state,
    #   1: advance from the previous state (left-padded with S.zero),
    #   2: skip from two states back, disabled (S.zero) when both states carry the same label.
    Ms = torch.stack([
        state_scores, 
        pad(state_scores[:, :, 1:], (1, 0), value=S.zero),
        pad(torch.where(states[:, 2:] == states[:, :-2], zero.expand(T, N, Lp-2), state_scores[:, :, 2:]), (2, 0), value=S.zero)
    ], -1)

    # Source state of each transition: s, s-1, s-2. The indices wrap around for the
    # first states, but the matching entries of Ms are S.zero there.
    # Built on the same device as the inputs rather than on the module-level default
    # device, so CPU inputs (or a non-default GPU) are not paired with a mismatched index.
    i = torch.arange(Lp, device=logits.device)
    rot = lambda x, n: torch.cat([x[-n:], x[:-n]])
    idx = torch.stack([i, rot(i, 1), rot(i, 2)], dim=1)

    # Start in state 0; accept in either of the two final states of each example.
    v0 = torch.cat([one.expand(N, 1), zero.expand(N, Lp - 1)], dim=1)
    vT = zero.expand(N, Lp).clone().scatter_(1, final_states, S.one)
    
    logZ = logZ_impl(Ms, idx, v0, vT, S)
    return -(logZ / target_lengths).mean()

# CTC loss backed by the pure-PyTorch scan implementation.
ctc_loss_scan_py = partial(_ctc_loss, logZ_impl=logZ_scan_py)

In [None]:
# Check the pure-PyTorch sparse loss against loss_pytorch on fixed-length (T=500) inputs;
# the cell output reports the forward and backward differences.
sample_inputs = logits, targets, input_lengths, target_lengths = generate_sample_inputs(T_min=500, T_max=500, N=128, C=20, L_min=80, L_max=100)
fwd, bwd = compare_fwd_bwd(loss_pytorch, ctc_loss_scan_py, *sample_inputs)

fwd diff: 0.00e+00
bwd diff: 8.51e-08


### 3. Cupy

In [None]:
%%writefile cuda/sparse_scan.cu
// Binary semiring operations. FLOAT, ADD and MUL are not defined in this file;
// the notebook supplies them (together with ZERO) when the file is loaded via
// load_cupy_func (substitution mechanism not shown here).

// Larger of two values.
__device__ __forceinline__ FLOAT max2(FLOAT a, FLOAT b) {return a > b ? a : b;}
// log(exp(a) + exp(b)), evaluated relative to the larger argument.
__device__ __forceinline__ FLOAT logsumexp2(FLOAT a, FLOAT b) {return a > b ? log1p(exp(b - a)) + a : log1p(exp(a - b)) + b;}
// Plain addition.
__device__ __forceinline__ FLOAT add(FLOAT a, FLOAT b) {return a + b;}

// Sparse matrix-vector scan.
//
// Indexing used below implies these flat layouts:
//   alpha: (T+1, N, C) - slice 0 is read as the initial vector, slices 1..T are written
//   Ms:    (T, N, C, nz)
//   idx:   (C, nz)     - for state c, the nz states whose previous values are combined
//
// blockIdx.x selects the batch element and threadIdx.x the state.
// Dynamic shared memory must hold 2 * blockDim.x values of type FLOAT.
extern "C" __global__ void sparse_Mv_scan(
    FLOAT* __restrict__ alpha,
    const FLOAT* __restrict__ Ms,  
    const int* __restrict__ idx,
    int T, int N, int C, int nz
) {
    int bx = blockIdx.x, tx = threadIdx.x;
    // NOTE(review): threads returning here never reach the __syncthreads() below, so this
    // guard is only safe when blockDim.x <= C (the notebook's launcher uses exactly C threads).
    if (tx >= C) return;
    extern __shared__ FLOAT smem[];
    
    // Initial value of this thread's state, taken from alpha[0].
    FLOAT a = alpha[bx * C + tx];
    for (int t = 0; t < T; t++) {
        // Alternate between the two halves of shared memory on successive steps,
        // presumably so that a single barrier per step is sufficient.
        FLOAT *buf = smem + (t % 2) * blockDim.x;
        // Publish the current value, then wait until every state has been published.
        buf[tx] = a; __syncthreads();      
        // Flat offset of element (t, bx, tx) in a (T, N, C) array.
        int i = ((t * N + bx) * C) + tx;
        // Combine the nz source values: a = ADD_j MUL(previous[idx[tx, j]], Ms[t, bx, tx, j]).
        a = MUL(buf[idx[tx * nz]], Ms[i * nz]);
        for (int j = 1; j < nz; j++) {
            a = ADD(a, MUL(buf[idx[tx * nz + j]], Ms[i * nz + j]));
        }
        // Store as alpha[t + 1, bx, tx].
        alpha[i + N * C] = a;
    }
}

Overwriting cuda/sparse_scan.cu


In [None]:
#export
# Compiled `sparse_Mv_scan` kernels, keyed by (tensor dtype, semiring).
# ADD is the semiring's binary sum, MUL its product. ZERO is taken from the
# semiring each kernel is built for (the Max entries previously reused Log.zero).
cupy_funcs = {
    (dtype, S): load_cupy_func('cuda/sparse_scan.cu', 'sparse_Mv_scan', FLOAT=ctype, ADD=add_op, MUL='add', ZERO='{:E}'.format(S.zero))
    for S, add_op in ((Log, 'logsumexp2'), (Max, 'max2'))
    for dtype, ctype in ((torch.float32, 'float'), (torch.float64, 'double'))
}

def Mv_scan_cupy(Ms, idx, v0, S:semiring):
    """Sparse matrix-vector scan over the semiring `S`, using the `sparse_Mv_scan` CUDA kernel.

    Args:
        Ms: CUDA tensor of shape (T, N, C, nz); float32 or float64 (the dtypes
            for which a kernel exists in `cupy_funcs`).
        idx: integer tensor of shape (C, nz) with the source state of each entry.
        v0: initial vector, assigned to `alpha[0]`.
        S: semiring; `(Ms.dtype, S)` must be a key of `cupy_funcs`.

    Returns:
        Tensor of shape (T+1, N, C) holding the initial vector and the T scan steps.

    Raises:
        AssertionError: if `Ms` is not on a CUDA device or `idx` has the wrong shape.
        KeyError: if no kernel exists for `(Ms.dtype, S)`.
    """
    # The kernel dereferences raw device pointers, so a CPU tensor must never reach it.
    assert Ms.device.index is not None
    T, N, C, nz = Ms.shape
    assert idx.shape == (C, nz) 
    # The kernel indexes Ms and idx as flat row-major arrays.
    Ms = Ms.contiguous()
    # Keep the converted index tensor referenced by a local name so that its
    # memory cannot be released before the kernel has been launched.
    idx = idx.to(dtype=torch.int, device=Ms.device).contiguous()
    alpha = Ms.new_full((T+1, N, C), S.zero)
    alpha[0] = v0
    with cp.cuda.Device(Ms.device.index):
        # One block per batch element, one thread per state; the kernel uses two
        # shared-memory buffers of C elements each.
        cupy_funcs[(Ms.dtype, S)](grid=(N, 1, 1), block=(C, 1, 1), shared_mem=2*Ms.element_size()*C,
               args=(alpha.data_ptr(), Ms.data_ptr(), idx.data_ptr(), T, N, C, nz))
    return alpha

def logZ_scan(Ms, idx, v0, vT, S:semiring):
    """Differentiable partition function using the CuPy scan `Mv_scan_cupy`."""
    return _LogZ_scan.apply(Ms, idx, v0, vT, S, Mv_scan_cupy)

# CTC loss backed by the CuPy scan implementation.
ctc_loss_scan = partial(_ctc_loss, logZ_impl=logZ_scan)

In [None]:
# Check the CuPy scan loss against loss_pytorch on freshly generated fixed-length inputs;
# the cell output reports the forward and backward differences.
sample_inputs = logits, targets, input_lengths, target_lengths = generate_sample_inputs(T_min=500, T_max=500, N=128, C=20, L_min=80, L_max=100)
fwd, bwd = compare_fwd_bwd(loss_pytorch, ctc_loss_scan, *sample_inputs)

fwd diff: 0.00e+00
bwd diff: 9.95e-08


In [None]:
# Baseline timings: loss_pytorch on the sample inputs generated above.
report(benchmark_fwd_bwd(loss_pytorch, *sample_inputs))

bwd: 2.14ms (2.00-2.23ms)
fwd: 1.26ms (0.94-6.39ms)
tot: 3.40ms (3.00-8.54ms)


In [None]:
# Timings for the CuPy scan loss on the same inputs, for comparison with the baseline.
report(benchmark_fwd_bwd(ctc_loss_scan, *sample_inputs))

bwd: 7.23ms (7.15-7.32ms)
fwd: 3.46ms (3.30-4.24ms)
tot: 10.69ms (10.50-11.55ms)


### 4. Faster grads in Cupy

In [None]:
%%writefile cuda/sparse_logZ.cu
// Semiring helpers. FLOAT, NZ, K, MUL and SUM are not defined in this file; the
// notebook supplies them (together with ZERO) when the file is loaded via
// load_cupy_func (substitution mechanism not shown here).

// Plain addition.
__device__ __forceinline__ FLOAT add(FLOAT a, FLOAT b) {return a + b;}
// Maximum of the NZ values in s.
__device__ __forceinline__ FLOAT max_(FLOAT *s) {
    FLOAT mx = s[0];
    for (int j = 1; j < NZ; j++) {
        mx = mx > s[j] ? mx : s[j];
    }
    return mx;
}
// log(sum_j exp(s[j])) over the NZ values in s, with the maximum subtracted
// before exponentiating.
__device__ __forceinline__ FLOAT logsumexp(FLOAT *s) {
    FLOAT mx = max_(s);
    FLOAT res = exp(s[0] - mx);
    for (int j = 1; j < NZ; j++) {
        res += exp(s[j] - mx);
    }
    return log(res) + mx;
}
    
// Fused forward and backward sweeps for the sparse partition function.
//
// Indexing used below implies these flat layouts:
//   logZ:    (N, C)        - written with MUL(final forward value, vT) per state
//   Ms_grad: (T, N, C, NZ) - written in the forward sweep, updated in the backward sweep
//   Ms:      (T, N, C, NZ)
//   v0, vT:  (N, C)
//   idx:     (C, NZ)       - source state of each entry
//   idx_T:   (C, NZ)       - flat positions (c * NZ + j) into an (C, NZ) slice
//
// blockIdx.x selects the batch element; each thread handles K consecutive states
// starting at threadIdx.x * K. Dynamic shared memory must hold 2 * blockDim.x * K
// values of type FLOAT.
// NOTE(review): nothing here checks tx + k < C, so C is presumably required to be
// a multiple of K - confirm against the launcher.
extern "C" __global__ void logZ_fwd_bwd(
    FLOAT* __restrict__ logZ,
    FLOAT* __restrict__ Ms_grad,
    const FLOAT* __restrict__ Ms,
    const FLOAT* __restrict__ v0,
    const FLOAT* __restrict__ vT,
    const int* __restrict__ idx,
    const int* __restrict__ idx_T,
    int T, int N, int C
) {
    int bx = blockIdx.x;
    int tx = threadIdx.x * K;
    // NOTE(review): threads returning here skip every __syncthreads() below, so this
    // guard is only safe if no launched thread actually satisfies it.
    if (tx >= C) return;
    extern __shared__ FLOAT smem[];
    
    // Per-thread values for its K states, initialised from v0.
    FLOAT a[K];
    for (int k = 0; k < K; k++) {
        a[k] = v0[bx * C + tx + k]; 
    }
    __syncthreads();
    
    // Forward sweep.
    FLOAT s[NZ];
    for (int t = 0; t < T; t++) {
        // Alternate between the two halves of shared memory on successive steps.
        FLOAT *buf = smem + (t % 2) * blockDim.x * K;
        for (int k = 0; k < K; k++) {
            buf[tx+k] = a[k];
        }
        __syncthreads();
        // Flat offset of slice (t, bx) in a (T, N, C, NZ) array.
        int i = (t * N + bx) * C * NZ;
        for (int k = 0; k < K; k++) {
            for (int j = 0; j < NZ; j++) {
                // Forward term of entry (state, j): previous value of its source state times Ms.
                s[j] = MUL(buf[idx[(tx + k) * NZ + j]], Ms[i + (tx + k) * NZ + j]);
                // Stored so the backward sweep can combine it with the backward value.
                Ms_grad[i + (tx + k) * NZ + j] = s[j];
            }
            a[k] = SUM(s);        
        }
    }

    // Per-state contribution to the result, then re-initialise a[] from vT
    // for the backward sweep.
    for (int k = 0; k < K; k++) {
        logZ[bx * C + tx + k] = MUL(a[k], vT[bx * C + tx + k]);
        a[k] = vT[bx * C + tx + k];
    }
    __syncthreads();

    // Backward sweep, from the last time step to the first.
    for (int t = T - 1; t >= 0; t--) {
        FLOAT *buf = smem + (t % 2) * blockDim.x * K;
        for (int k = 0; k < K; k++) {
            buf[tx+k] = a[k];
        }
        __syncthreads(); 
        int i = (t * N + bx) * C * NZ;
        for (int k = 0; k < K; k++) {
            for (int j = 0; j < NZ; j++) {
                int ix = idx_T[(tx + k) * NZ + j];
                // Combine the stored forward term with this state's current backward value.
                Ms_grad[i + (tx + k) * NZ + j] = MUL(Ms_grad[i + (tx + k) * NZ + j], a[k]);
                // ix is a flat position into the (C, NZ) slice: ix / NZ is the state whose
                // backward value is read, and Ms[i + ix] the matching entry.
                s[j] = MUL(buf[ix / NZ], Ms[i + ix]);
            }            
            a[k] = SUM(s);
        }        
    }
}

In [None]:
#export
@cache(None)
def cupy_func(dtype, S, NZ, K):
    """Load (and memoise) the `logZ_fwd_bwd` kernel for a dtype / semiring / NZ / K combination.

    `dtype` must be `torch.float32` or `torch.float64` and `S` must be `Log` or
    `Max`; any other value raises `KeyError`.
    """
    c_type = {torch.float32: 'float', torch.float64: 'double'}[dtype]
    op_names = {
        Log: {'sum': 'logsumexp', 'mul': 'add'},
        Max: {'sum': 'max_', 'mul': 'add'},
    }[S]
    # Values handed to load_cupy_func alongside the source file and kernel name.
    macros = dict(
        FLOAT=c_type,
        MUL=op_names['mul'],
        ZERO='{:E}'.format(S.zero),
        SUM=op_names['sum'],
        NZ=NZ,
        K=K,
    )
    return load_cupy_func('cuda/sparse_logZ.cu', 'logZ_fwd_bwd', **macros)

def _logZ_fwd_bwd_cupy(Ms, idx, v0, vT, S:semiring=Log, K=4):
    """Run the fused `logZ_fwd_bwd` CUDA kernel.

    Args:
        Ms: CUDA tensor of shape (T, N, C, NZ); float32 or float64.
        idx: integer tensor of shape (C, NZ) with the source state of each entry.
        v0: initial vector, read by the kernel as a flat (N, C) array.
        vT: final vector, read by the kernel as a flat (N, C) array.
        S: semiring (`Log` or `Max`).
        K: number of consecutive states handled by each CUDA thread; must divide C.

    Returns:
        `(logZ, Ms_grad)`: `logZ` has one value per batch element and `Ms_grad`
        has the shape of `Ms`, holding the per-entry terms written by the kernel
        (`_LogZ.backward` applies `S.dsum` to them).

    Raises:
        AssertionError: if `Ms` is not on a CUDA device, `idx` has the wrong
            shape, or C is not a positive multiple of K.
        KeyError: if the dtype or semiring is unsupported.
    """
    assert Ms.device.index is not None
    T, N, C, NZ = Ms.shape
    assert idx.shape == (C, NZ)
    # The launch uses C // K threads, each covering K states. If K does not divide C
    # the trailing states would never be computed and the kernel would read shared
    # memory that was never written.
    assert K > 0 and C % K == 0, 'C ({}) must be a positive multiple of K ({})'.format(C, K)
    # The kernel indexes every tensor as a flat row-major array of Ms's dtype.
    Ms = Ms.contiguous()
    v0, vT = [v.to(dtype=Ms.dtype, device=Ms.device).contiguous() for v in (v0, vT)]
    idx = idx.to(dtype=torch.int, device=Ms.device).contiguous()
    Ms_grad = Ms.new_full((T, N, C, NZ), S.zero)
    logZ = Ms.new_full((N, C), S.zero)
    idx_T = idx.flatten().argsort().to(torch.int) #transpose
    with cp.cuda.Device(Ms.device.index):
        # Two shared-memory buffers of C elements each.
        cupy_func(Ms.dtype, S, NZ, K)(grid=(N, 1, 1), block=(C//K, 1, 1), shared_mem=2*Ms.element_size()*C,
               args=(logZ.data_ptr(), Ms_grad.data_ptr(), Ms.data_ptr(), v0.data_ptr(), vT.data_ptr(), idx.data_ptr(), idx_T.data_ptr(), T, N, C))
    # The kernel leaves one value per state; reduce over states here.
    return S.sum(logZ, dim=1), Ms_grad

class _LogZ(torch.autograd.Function):
    """Autograd wrapper around the fused kernel launcher `_logZ_fwd_bwd_cupy`.

    The forward pass already produces the per-entry terms needed for the
    gradient, so the backward pass only applies `S.dsum` and scales by the
    incoming gradient.
    """

    @staticmethod
    def forward(ctx, Ms, idx, v0, vT, S:semiring, K):
        """Return the partition function and stash the per-entry terms for backward."""
        result, edge_terms = _logZ_fwd_bwd_cupy(Ms, idx, v0, vT, S, K)
        ctx.semiring = S
        ctx.save_for_backward(edge_terms)
        return result

    @staticmethod
    def backward(ctx, grad):
        """Return the gradient w.r.t. `Ms`; the other five inputs get `None`."""
        (edge_terms,) = ctx.saved_tensors
        T, N, C, nz = edge_terms.shape
        # dsum is applied jointly over all C * nz entries of each (t, n) slice.
        flat_terms = edge_terms.reshape(T, N, -1)
        dMs = ctx.semiring.dsum(flat_terms, dim=2).reshape(T, N, C, nz)
        # grad has one value per batch element; broadcast it over T, C and nz.
        scaled = grad[None, :, None, None] * dMs
        return (scaled,) + (None,) * 5

def logZ(Ms, idx, v0, vT, S:semiring=Log, K=1):
    """Differentiable partition function using the fused CuPy kernel.

    `K` is the number of states handled per CUDA thread (forwarded to `_LogZ`).
    """
    return _LogZ.apply(Ms, idx, v0, vT, S, K)

# CTC loss backed by the fused forward/backward CuPy kernel.
ctc_loss = partial(_ctc_loss, logZ_impl=logZ)

In [None]:
# Check the fused-kernel loss against loss_pytorch on freshly generated fixed-length inputs;
# the cell output reports the forward and backward differences.
sample_inputs = logits, targets, input_lengths, target_lengths = generate_sample_inputs(T_min=500, T_max=500, N=128, C=20, L_min=80, L_max=100)
fwd, bwd = compare_fwd_bwd(loss_pytorch, ctc_loss, *sample_inputs)

fwd diff: 0.00e+00
bwd diff: 9.08e-08


In [None]:
# Baseline timings: loss_pytorch on the sample inputs generated above.
report(benchmark_fwd_bwd(loss_pytorch, *sample_inputs))

bwd: 2.17ms (2.13-2.23ms)
fwd: 1.07ms (1.03-1.10ms)
tot: 3.24ms (3.17-3.32ms)


In [None]:
# Timings for the fused-kernel loss on the same inputs, for comparison with the baseline.
report(benchmark_fwd_bwd(ctc_loss, *sample_inputs))

bwd: 4.03ms (3.99-4.07ms)
fwd: 4.29ms (4.23-4.34ms)
tot: 8.32ms (8.26-8.37ms)
