# Comparison of RNN-T loss implementations
Note: this will only run inside the awni-speech docker image. 


In [116]:
cd '/home/ubuntu/speech/libs'

/home/ubuntu/speech/libs


In [117]:
import transducer.functions.transducer as awni_transducer

In [118]:
from warprnnt_pytorch import RNNTLoss as WarpRNNTLoss
import torch
import numpy as np

In [381]:
# inputs_awni = [log_probs, y_flat, lengths, label_lengths]
# inputs_warp =  [logits, y_mat, logit_lens, y_lens]
B = 2
T = 4
K = 5
U = 2
logits = torch.randn((B, T, U + 1, K + 1))
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
probs = log_probs.exp()

In [382]:
# print(log_probs)
# print(logits)
# print(probs)

In [383]:
def label_collate(labels):
    # Doesn't matter what we pad the end with
    # since it will be ignored.
    batch_size = len(labels)
    end_tok = labels[0][-1]
    max_len = max(len(l) for l in labels)
    cat_labels = np.full((batch_size, max_len),
                    fill_value=end_tok, dtype=np.int64)
    for e, l in enumerate(labels):
        cat_labels[e, :len(l)] = l
    labels = torch.LongTensor(cat_labels)

    return labels

In [384]:
y = [torch.IntTensor([0, 1]), torch.IntTensor([1, 0])]
y_flat = torch.IntTensor([l for label in y for l in label])
y_flat
y_mat = label_collate(y).int()
x_lens = torch.IntTensor([4, 2]) #T=3
y_lens = torch.IntTensor([len(l) for l in y])

In [385]:
inputs_awni = [log_probs, y_flat, x_lens, y_lens]
inputs_warp =  [logits, y_mat, x_lens, y_lens]

In [386]:
gpu = False
if gpu:
    inputs_awni = [x.cpu() for x in inputs_awni] #this has no gpu implementation
    inputs_warp = [x.cuda() for x in inputs_warp]


In [387]:
blank_label = K
awni_loss = awni_transducer.TransducerLoss(blank_label=blank_label)
warp_loss = WarpRNNTLoss(blank=blank_label)
# awni_loss = awni_transducer.TransducerLoss()
# warp_loss = WarpRNNTLoss()

In [388]:

awni_loss(*inputs_awni)

tensor([7.0178])

In [389]:
for x in inputs_warp:
    print(x.shape)
print(inputs_warp[-2], inputs_warp[-1])

warp_loss(*inputs_warp)

torch.Size([2, 4, 3, 6])
torch.Size([2, 2])
torch.Size([2])
torch.Size([2])
tensor([4, 2], dtype=torch.int32) tensor([2, 2], dtype=torch.int32)


tensor([7.0178])

## Conclusions
* warp_loss and awni_loss give the same value. It is obviously possible they are both wrong (I think warp-rnnt was based on awni's implementation).
* warp_loss appears to give the same values reagrdless of whether logits or log_probs are passed. This is quite weird - there must be a check in the implementation that determines whether log_probs or logits have been passed.  

In [435]:
def assert_same(inp1, inp2):
    epsilon = 1e-5
    val1 = warp_loss(*inp1)
    val2 = warp_loss(*inp2)
    assert (val1.item() - val2.item()) < epsilon, f"log probs and logits do not give the same values. i.e {val1.item()} != {val2.item()}"
    print(f"Losses are the same since {val1.item()} == {val2.item()}")

In [438]:
logits = torch.randn((B, T, U + 1, K + 1))
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
inputs_warp1 =  [logits, y_mat, x_lens, y_lens]
inputs_warp2 = [log_probs, y_mat, x_lens, y_lens]

assert_same(inputs_warp1, inputs_warp2)
#perhaps the check is seeing if all the values are negative?

logits2 = logits.clone()
logits2 -= 1000 #make all negative
log_probs2 = torch.nn.functional.log_softmax(logits2, dim=-1)

inputs_warp3 = [log_probs2, y_mat, x_lens, y_lens]
inputs_warp4 =  [logits2, y_mat, x_lens, y_lens]

assert_same(inputs_warp3, inputs_warp4)
logits

Losses are the same since 6.881084442138672 == 6.881083965301514
Losses are the same since 6.881107330322266 == 6.881138801574707


tensor([[[[-0.4129, -1.6022, -0.9453,  0.5857,  0.6219, -0.7563],
          [ 2.1744,  2.4167, -0.1291, -0.3900, -0.4496, -0.7070],
          [ 0.3734,  0.0791, -0.8818, -0.7426, -0.8053, -1.5643]],

         [[ 0.0635, -0.8184,  1.1282, -1.3565,  0.3035,  0.3024],
          [-1.2073,  0.9886, -0.6912, -1.5064, -0.2042,  0.5708],
          [-0.2789, -0.7848, -1.1220,  0.6504,  0.0577,  0.4951]],

         [[ 0.4040,  0.3438,  0.8030,  0.4080, -1.1770,  1.1413],
          [ 1.5867,  1.8462, -0.8938, -0.6556,  1.0368,  0.6939],
          [ 0.6407,  0.4395,  0.0393,  1.6467,  0.4994,  0.6421]],

         [[ 0.9332,  0.5397, -0.9760, -1.7922, -1.2370, -1.3334],
          [-1.6242, -1.4349,  0.0672, -0.2375,  0.7944,  1.1286],
          [ 0.0827, -0.3559, -1.1331, -0.0338, -0.9112,  1.2635]]],


        [[[ 0.1498, -0.3466, -2.0149, -0.2680, -0.4186, -0.9724],
          [ 0.9700, -0.7038,  0.3109,  0.5435,  1.3241,  0.1239],
          [ 1.0267,  2.0639, -0.3451,  0.8338, -0.1505,  0.3615]],

### Logits vs log_probs conclusion

Ok so as expected, the implementation treats the input as log_probs if it is all negative (or all positive by the looks of it) but if the values are mixed signs it assumes they are logits

# Worked example 1
Consider three timesteps with perfectly trained model and exact alignments such that correct output should be:  
* a @ T=1
* b @ T=2
* blank @ T=3


In [286]:
probs_t1 = torch.tensor([[1., 0, 0], [0, 0, 1], [0, 0, 1]])
probs_t2 = torch.tensor([[0., 0, 1], [0, 1, 0], [0, 0, 1]])
probs_t3 = torch.tensor([[0, 0, 1.], [0, 0, 1], [0, 0, 1]])

probs = [probs_t1, probs_t2, probs_t3]
probs = [x.unsqueeze(0) for x in probs]
probs = torch.cat(probs, dim=0)
probs = probs.unsqueeze(0).float()

EPS = 0.00001 #Non-zero value of other 
r = 2 * (1 - 3 * EPS)
probs *= r
probs += 2 * EPS
probs /= 2.0
probs

tensor([[[[9.9998e-01, 1.0000e-05, 1.0000e-05],
          [1.0000e-05, 1.0000e-05, 9.9998e-01],
          [1.0000e-05, 1.0000e-05, 9.9998e-01]],

         [[1.0000e-05, 1.0000e-05, 9.9998e-01],
          [1.0000e-05, 9.9998e-01, 1.0000e-05],
          [1.0000e-05, 1.0000e-05, 9.9998e-01]],

         [[1.0000e-05, 1.0000e-05, 9.9998e-01],
          [1.0000e-05, 1.0000e-05, 9.9998e-01],
          [1.0000e-05, 1.0000e-05, 9.9998e-01]]]])

In [287]:
log_probs = probs.log()
log_probs

tensor([[[[ -0.0000, -11.5129, -11.5129],
          [-11.5129, -11.5129,  -0.0000],
          [-11.5129, -11.5129,  -0.0000]],

         [[-11.5129, -11.5129,  -0.0000],
          [-11.5129,  -0.0000, -11.5129],
          [-11.5129, -11.5129,  -0.0000]],

         [[-11.5129, -11.5129,  -0.0000],
          [-11.5129, -11.5129,  -0.0000],
          [-11.5129, -11.5129,  -0.0000]]]])

In [288]:
y = [torch.IntTensor([0, 1])]
y_flat = torch.IntTensor([l for label in y for l in label])
y_mat = label_collate(y).int()
x_lens = torch.IntTensor([3]) #T=3
y_lens = torch.IntTensor([len(l) for l in y])

In [289]:
blank_label = 2
inputs_awni = [log_probs, y_flat, x_lens, y_lens]
awni_loss = awni_transducer.TransducerLoss(blank_label=blank_label)
awni_loss(*inputs_awni)

tensor([0.0001])

In [290]:
warp_loss = WarpRNNTLoss(blank=blank_label)
inputs_warp =  [log_probs, y_mat, x_lens, y_lens]
warp_loss(*inputs_warp)

tensor([0.0001])

In [291]:
for x in inputs_warp:
    print(x.shape)

torch.Size([1, 3, 3, 3])
torch.Size([1, 2])
torch.Size([1])
torch.Size([1])


This is behaving as expected - i.e. monotonically increasing from 0 as EPS increases.  

# Worked example 2
Consider untrained model (all probabilities equally likely). Work out expected loss. 
* Alphabet = {a, blank}
* i.e. K = 1 (hence all probabilities are 0.5)
* T = 2
* There are two paths through the graph: (blank, a) or (a, blank)
* Each of these requires x3 transitions (i.e. must append blank at end)

* expected total: 0.5 ** 3 x 2 = 0.25

* LOSS = - ln(0.25)


In [456]:
EXPECTED_LOSS = - torch.log(torch.tensor([1 / 4.]))
probs = torch.ones((1, 2, 2, 2)) * 0.5
log_probs = probs.log()
y = [torch.IntTensor([0])]
y_flat = torch.IntTensor([l for label in y for l in label])
y_mat = label_collate(y).int()
x_lens = torch.IntTensor([2]) #T=3
y_lens = torch.IntTensor([len(l) for l in y])

In [457]:
blank_label = 1
inputs_awni = [log_probs, y_flat, x_lens, y_lens]
awni_loss = awni_transducer.TransducerLoss(blank_label=blank_label)
loss1 = awni_loss(*inputs_awni)

In [458]:
warp_loss = WarpRNNTLoss(blank=blank_label)
inputs_warp =  [log_probs, y_mat, x_lens, y_lens]
loss2 = warp_loss(*inputs_warp)

In [459]:
print(loss1, loss2)
assert torch.allclose(loss1, EXPECTED_LOSS), f"loss={loss1} != {EXPECTED_LOSS}"
assert torch.allclose(loss2, EXPECTED_LOSS), f"loss={loss2} != {EXPECTED_LOSS}"

tensor([1.3863]) tensor([1.3863])
