# utils 中的一些方法

* predict_globalpointer_compare
* compare_res
* spare_multilable_categorical_crossentropy

In [1]:
import torch
from torch.functional import F

In [2]:
def predict_globalpointer_compare(y_pre, tags):
    res = [[]] * y_pre.shape[0]
    idx = (y_pre>0).nonzero()
    for b, tag_idx, start, end in idx.tolist():
        res[b].append((tags[tag_idx], (start, end)))
        
    return res


def compare_res(y_pre, y_true, tags, txts, new2ori_idxs):
    y_pre = predict_globalpointer_compare(y_pre, tags)
    y_true = predict_globalpointer_compare(y_true, tags)
    res = []
    for y_pre_cur, y_true_cur, txt, new2ori_idx in zip(y_pre, y_true, txts, new2ori_idxs):
        y_pre_cur, y_true_cur = set(y_pre_cur), set(y_true_cur)
        cur_res = {}
        pre_res, true_res = [], []
        for tag, (start, end) in y_pre_cur - y_true_cur:
            start_txt, end_txt = new2ori_idx[start][0], new2ori_idx[end][1]
            ner_txt = txt[start_txt:end_txt]
            pre_res.append((tag, ner_txt, (start, end), (start_txt, end_txt)))
        for tag, (start, end) in y_true_cur - y_pre_cur:
            start_txt, end_txt = new2ori_idx[start][0], new2ori_idx[end][1]
            ner_txt = txt[start_txt:end_txt]
            true_res.append((tag, ner_txt, (start, end), (start_txt, end_txt)))
        if pre_res:
            cur_res['pre'] = pre_res
        if true_res:
            cur_res['true'] = pre_res
        res.append(cur_res)
    return res

In [3]:
txt = ['非常威锋网服务', '二二多大的大额的我']
new2ori = [[(i, i + 1) for i in range(len(txt[0]))] for _ in range(2)]
tags = {i:l for i,l in enumerate('abcde')}
y_pre = torch.randn(2, len(tags), len(txt[0]), len(txt[0]))
y_true =torch.triu( torch.randint(0, 2, (2, len(tags), len(txt[0]), len(txt[0]))) )

In [4]:
compare_res(y_pre, y_true, tags, txt, new2ori)

[{'pre': [('a', '服务', (5, 6), (5, 7)),
   ('b', '', (5, 0), (5, 1)),
   ('c', '', (3, 0), (3, 1)),
   ('a', '', (5, 4), (5, 5)),
   ('e', '威锋网', (2, 4), (2, 5)),
   ('e', '服务', (5, 6), (5, 7)),
   ('a', '', (5, 2), (5, 3)),
   ('c', '', (3, 2), (3, 3)),
   ('a', '', (5, 0), (5, 1)),
   ('b', '', (5, 4), (5, 5)),
   ('a', '常威锋', (1, 3), (1, 4)),
   ('a', '', (6, 5), (6, 6)),
   ('c', '', (2, 1), (2, 2)),
   ('e', '', (6, 5), (6, 6)),
   ('e', '常', (1, 1), (1, 2)),
   ('b', '', (6, 4), (6, 5)),
   ('a', '', (4, 1), (4, 2)),
   ('b', '', (3, 1), (3, 2)),
   ('a', '', (6, 2), (6, 3)),
   ('a', '', (4, 3), (4, 4)),
   ('a', '非', (0, 0), (0, 1)),
   ('a', '', (3, 1), (3, 2)),
   ('a', '', (3, 0), (3, 1)),
   ('e', '', (6, 0), (6, 1)),
   ('a', '', (3, 2), (3, 3)),
   ('a', '', (5, 1), (5, 2)),
   ('c', '', (2, 0), (2, 1)),
   ('e', '', (3, 1), (3, 2)),
   ('e', '非常威', (0, 2), (0, 3)),
   ('d', '', (3, 2), (3, 3)),
   ('a', '', (6, 1), (6, 2)),
   ('c', '', (6, 5), (6, 6)),
   ('d', '', (3, 0

In [15]:
def sparse_global_loss(y_true, y_pre):
    """
    y_true( batch_size, labes_num, 2)
    y_pre(batch_size, labels_num, seq_len, seq_len)

    """
    batch_size, labels_num, seq_len, _ = y_pre.shape
    y_true = y_true[..., :1] * seq_len + y_true[..., 1:]
    y_pre = y_pre.reshape(batch_size, labels_num, -1)
    return torch.mean(spare_multilable_categorical_crossentropy(y_true, y_pre, True).sum(1))

def spare_multilable_categorical_crossentropy(y_true, y_pre, mask_zero=False, mask_value=-10000):
    """
    y_true(batch_size, labels_num, 1)
    y_pre(batch_size, labels_num, seq_len * seq_len)
    """
    # device = y_pre.device
    zeros = torch.zeros_like(y_true[..., :1])
    y_pre = torch.cat([y_pre, zeros], -1)
    if mask_zero:
        y_pre[..., 0] = - mask_value

    y_pos = y_pre.gather(-1, y_true)
    y_pos_2 = torch.cat([ - y_pos, zeros], dim=-1)
    loss_pos = torch.logsumexp(y_pos_2, -1)

    if mask_zero:
        y_pre[..., 0] = mask_value

    loss_all = torch.logsumexp(y_pre, dim=-1)
    y_pos_2 = y_pre.gather(-1, y_true)
    loss_aux = torch.logsumexp(y_pos_2, dim=-1)
    # 可能需要加一个 clip 操作
    loss_neg = loss_all + torch.log(1 - torch.exp(loss_aux - loss_all))
    return loss_pos + loss_neg



In [3]:
y_pre = torch.randn(2,3,4,4)
y_true = torch.randint(0, 3, (2,3,2))
y_pre, y_true

(tensor([[[[ 0.3848, -0.5780,  1.0659,  0.6921],
           [-1.4639,  0.1571,  0.6209,  0.5790],
           [-0.6177,  0.3236, -1.5864, -1.0411],
           [-1.2318,  0.4377, -2.3634,  1.8104]],
 
          [[ 0.0398,  0.9710, -0.8936,  1.1077],
           [-1.3833, -0.5975,  0.5423,  1.4013],
           [-0.2987,  0.0986,  1.5766, -1.5616],
           [ 1.2109, -1.2952, -1.3710,  0.3016]],
 
          [[-0.6435,  0.7178,  0.8809,  0.5234],
           [-1.6487, -0.4828, -0.8422, -0.3171],
           [ 0.5108,  1.5061,  1.2202, -1.8011],
           [ 1.5057, -0.1612,  0.6365,  0.4732]]],
 
 
         [[[-1.0466, -0.8931, -0.2011,  1.3203],
           [-1.2187, -0.4006,  0.4841,  1.6694],
           [-0.3622, -0.2215, -0.1142, -0.9360],
           [-0.9339, -1.5359,  1.6864, -1.1572]],
 
          [[-1.0286, -0.0028,  0.2846, -1.1157],
           [-1.8599, -1.5566, -0.2134, -0.5406],
           [-0.7019,  1.4854,  1.0335,  0.9046],
           [ 1.1977, -0.9960, -0.5565, -0.4716]],
 
  

In [16]:
sparse_global_loss(y_true, y_pre)

tensor(12.0147)

In [8]:
torch.zeros_like(y_true[..., :1]).shape

torch.Size([2, 3, 1])

In [12]:
(y_true[..., :1] + y_true[..., 1:]).shape

torch.Size([2, 3, 1])