# Some experiments regarding the loss function for the line finder framework

In [1]:
import os, sys
sys.path.append('C:\\Users\\matthias\\Documents\\myProjects\\TU_Bibliothek\\code\\baseline-extract')
import torch
import torch.nn as nn
from scipy.optimize import linear_sum_assignment

In [3]:
pred = torch.tensor([[100, 100,0.1, 1.0], [200,200,0.12, 1.0], [300, 300, 0.05, 1.0], [193,234, -0.2, 0.009], [654,234, 0.25, 0.006], [4,3450, 0.8, 0.001], [345,6, 0.12, 0.006], [23,634, 0.18, 0.003], [345,64,0.1, 0.008]])
label = torch.tensor([[100,100, 0], [200,200, 0.1], [300,300, 0.05]])

In [4]:
pred.shape

torch.Size([9, 4])

In [5]:
label.shape

torch.Size([3, 3])

In [28]:
crit = nn.MSELoss()

In [29]:
pred = pred.float()
label = label.float()

In [33]:
#crit(label[0], pred[0,0:2])

In [31]:
N = pred.shape[0]
M = label.shape[0]

In [32]:
cost = torch.zeros(N, M)

In [11]:
for n in range(N):
    for m in range(M):
        cost[n, m] = crit(pred[n,0:2], label[m])

In [12]:
X = linear_sum_assignment(cost)

In [13]:
cost

tensor([[1.2560e+03, 1.3256e+04, 4.5256e+04],
        [5.8600e+03, 3.0260e+04, 7.4660e+04],
        [1.4500e+01, 1.0314e+04, 4.0614e+04],
        [1.3262e+04, 2.5625e+03, 1.1862e+04],
        [1.6244e+05, 1.0364e+05, 6.4836e+04],
        [5.6159e+06, 5.3005e+06, 5.0051e+06],
        [3.4430e+04, 2.9330e+04, 4.4230e+04],
        [1.4554e+05, 1.0984e+05, 9.4142e+04],
        [3.0660e+04, 1.9760e+04, 2.8860e+04]])

In [14]:
X

(array([0, 2, 3], dtype=int64), array([0, 1, 2], dtype=int64))

In [41]:
class LineFinderLoss(nn.Module):
    """
    Loss = Sum_{n=0}^N Sum_{m=0}^M    X_nm [alpha*MSE(l_n, p_m) - Log(c_m)] - (1- X_nm) Log(1-c_m)
    where:
      N:      prediction dimension
      M:      label dimension
      X_mn:   linear assignement matrix
      l_n:    label coordinates
      p_m:    prediction coordinates
      c_m:    confidence scores
    """
    def __init__(self, alpha=0.1):
        super(LineFinderLoss, self).__init__()
        self.mse = nn.MSELoss()
        self.alpha = alpha

    def forward(self, pred, label, label_len):
        batch_size = pred.shape[0]

        #TODO: find better solution
        if pred.shape[1] < label.shape[1]:
            label = label[:, 0:pred.shape[1], :]

        location_loss = 0
        confidence_loss = 0

        for b in range(batch_size):
            # If the page is empty punish the model if it finds anything at all:
            if label.shape[1] == 0:
                conf_scores = pred[b, :, -1]
                confidence_loss += -torch.log(1 - conf_scores + 0.01).sum()
            else:
                # I get P predictions and T true labels.
                inp = pred[b, :, 0:3]
                targ = label[b, 0:label_len[b], :]

                conf_scores = pred[b, :, -1]

                # Compute the confidence for all P predictions.
                log_c = torch.log(conf_scores + 0.00000001)
                log_c_anti = torch.log(1 - conf_scores + 0.00001)

                # Expand such that for all T true lables I have a row of all predicted confidence logs.
                # The result is a P x T matrix.
                log_c_exp = log_c[:, None].expand(-1, targ.shape[0])
                log_c_anti_exp = log_c_anti[:, None].expand(-1, targ.shape[0])

                # Expand such that I get P x T x 4 matrices.
                inp_exp = inp[:, None, :].expand(-1, targ.shape[0], -1)
                targ_exp = targ[None, :, :].expand(inp.shape[0], -1, -1)

                # Compute the difference between every pair of prediction and true label locations.
                diff = (inp_exp[:, :, 0:3] - targ_exp[:, :, 0:3])
                normed_diff = torch.norm(diff, 2, 2) ** 2

                # Compute the cost matrix. This is a P x T matrix.
                C = self.alpha * normed_diff/2.0 - log_c_exp + log_c_anti_exp
                C = C.cpu().detach().numpy()

                X = torch.zeros(C.shape)
                x_c = torch.ones(C.shape[0])

                # For every row index (true), compute the column index (pred) where the cost is minimal.
                row_idx, col_idx = linear_sum_assignment(C.T)

                X[(col_idx, row_idx)] = 1.0
                x_c[col_idx] = 0.0

                X = X.to(inp.device)
                x_c = x_c.to(inp.device)

                location_loss += (normed_diff * X).sum()/2.0
                confidence_loss += -(log_c_exp * X).sum() - (log_c_anti * x_c).sum()+0.0001

        loss = self.alpha* location_loss + confidence_loss

        return loss, self.alpha * location_loss, -(log_c_exp * X).sum(), -(log_c_anti * x_c).sum()

In [42]:
lfl = LineFinderLoss()

In [82]:
pred = torch.tensor([[100, 100,0.1, 0.9999], 
                     [200, 200,0.12, 0.9999], 
                     [300, 300, 0.05, 0.001], 
                     [193, 234, -0.2, 0.999], 
                     [654, 234, 0.25, 0.0001], 
                     [4, 3450, 0.8, 0.001], 
                     [345, 6, 0.12, 0.006], 
                     [23, 634, 0.18, 0.003],
                     [345, 64,0.1, 0.008]])
label = torch.tensor([[100, 100, 0], 
                      [200, 200, 0.1], 
                      [300, 300, 0.05]])

In [85]:
16*16

256

In [83]:
p = pred.unsqueeze(0)
l = label.unsqueeze(0)

In [84]:
d, loc, c, c_anti = lfl(p, l, [p.shape[1]])
print('Loss:        {:4f}'.format(d.item()))
print('Loc_loss:    {:4f}'.format(loc.item()))
print('conf_loss:   {:4f}'.format(c.item()))
print('conf_a_loss: {:4f}'.format(c_anti.item()))

Loss:        13.824489
Loc_loss:    0.000520
conf_loss:   6.907946
conf_a_loss: 6.915923


In [17]:
lfl(p, l, [p.shape[1]])

(tensor(0.0336), tensor(0.0005), tensor(-3.0041e-05), tensor(0.0331))

In [203]:
d

tensor(2.5633)

In [208]:
X.data

tensor([[1., 0., 0.],
        [0., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [115]:
X_c

tensor([[0., 1., 1.],
        [1., 1., 1.],
        [1., 0., 1.],
        [1., 1., 0.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

In [210]:
X.device

device(type='cpu')