# utils 中的一些方法

* predict_globalpointer_compare
* compare_res
* spare_multilable_categorical_crossentropy
* f1_globalpointer
* f1_globalpointer_sparse
* em_globalpointer
* em_globalpointer_sparse
* multilabel_categorical_crossentropy
* global_pointer_loss

In [167]:
import torch
from torch.functional import F
# from my_py_toolkit.torch.tensor_toolkit import mask

In [2]:
def predict_globalpointer_compare(y_pre, tags):
    """Decode GlobalPointer scores into per-sample lists of tagged spans.

    Args:
        y_pre (torch.Tensor): scores indexed as (batch, tag, start, end).
            Every entry strictly greater than 0 is reported as a span.
        tags: mapping (dict or sequence) from tag index to tag label.

    Returns:
        list[list[tuple]]: one list per batch element containing
        ``(tag_label, (start, end))`` for each positive entry of that sample.
    """
    # Build one independent list per sample. ``[[]] * n`` would repeat a
    # reference to a single list, so every sample would end up holding the
    # spans of the whole batch.
    res = [[] for _ in range(y_pre.shape[0])]
    for b, tag_idx, start, end in (y_pre > 0).nonzero().tolist():
        res[b].append((tags[tag_idx], (start, end)))
    return res


def compare_res(y_pre, y_true, tags, txts, new2ori_idxs):
    """List the spans on which the prediction and the gold labels disagree.

    Args:
        y_pre (torch.Tensor): predicted scores, (batch, tag, start, end);
            entries > 0 count as predicted spans.
        y_true (torch.Tensor): gold labels with the same layout; entries > 0
            count as gold spans.
        tags: mapping from tag index to tag label.
        txts: one original text per sample.
        new2ori_idxs: per sample, a sequence mapping a token index to its
            ``(char_start, char_end)`` offsets in the original text.

    Returns:
        list[dict]: one dict per sample. Key ``'pre'`` holds spans that were
        predicted but are not gold, key ``'true'`` holds gold spans that were
        not predicted; a key is omitted when its list is empty. Each span is
        ``(tag, text, (start, end), (char_start, char_end))``.
    """
    def describe(spans, txt, new2ori_idx):
        # Attach the covered text and character offsets to every span.
        described = []
        for tag, (start, end) in spans:
            start_txt, end_txt = new2ori_idx[start][0], new2ori_idx[end][1]
            described.append(
                (tag, txt[start_txt:end_txt], (start, end), (start_txt, end_txt)))
        return described

    y_pre = predict_globalpointer_compare(y_pre, tags)
    y_true = predict_globalpointer_compare(y_true, tags)
    res = []
    for y_pre_cur, y_true_cur, txt, new2ori_idx in zip(y_pre, y_true, txts, new2ori_idxs):
        y_pre_cur, y_true_cur = set(y_pre_cur), set(y_true_cur)
        pre_res = describe(y_pre_cur - y_true_cur, txt, new2ori_idx)
        true_res = describe(y_true_cur - y_pre_cur, txt, new2ori_idx)
        cur_res = {}
        if pre_res:
            cur_res['pre'] = pre_res
        if true_res:
            # Missed gold spans go under 'true' (this used to store pre_res).
            cur_res['true'] = true_res
        res.append(cur_res)
    return res

In [3]:
# Toy inputs for trying out compare_res.
# Two sample texts; only len(txt[0]) is used below to size the tensors.
txt = ['非常威锋网服务', '二二多大的大额的我']
# Identity token->character mapping: token i covers characters [i, i + 1).
# The same mapping (built from txt[0]) is reused for both samples.
new2ori = [[(i, i + 1) for i in range(len(txt[0]))] for _ in range(2)]
# Tag index -> tag label ('a' .. 'e').
tags = {i:l for i,l in enumerate('abcde')}
# Random scores indexed as (batch, tag, start, end).
y_pre = torch.randn(2, len(tags), len(txt[0]), len(txt[0]))
# Random 0/1 labels restricted to the upper triangle (start <= end).
y_true =torch.triu( torch.randint(0, 2, (2, len(tags), len(txt[0]), len(txt[0]))) )

In [4]:
compare_res(y_pre, y_true, tags, txt, new2ori)

[{'pre': [('a', '服务', (5, 6), (5, 7)),
   ('b', '', (5, 0), (5, 1)),
   ('c', '', (3, 0), (3, 1)),
   ('a', '', (5, 4), (5, 5)),
   ('e', '威锋网', (2, 4), (2, 5)),
   ('e', '服务', (5, 6), (5, 7)),
   ('a', '', (5, 2), (5, 3)),
   ('c', '', (3, 2), (3, 3)),
   ('a', '', (5, 0), (5, 1)),
   ('b', '', (5, 4), (5, 5)),
   ('a', '常威锋', (1, 3), (1, 4)),
   ('a', '', (6, 5), (6, 6)),
   ('c', '', (2, 1), (2, 2)),
   ('e', '', (6, 5), (6, 6)),
   ('e', '常', (1, 1), (1, 2)),
   ('b', '', (6, 4), (6, 5)),
   ('a', '', (4, 1), (4, 2)),
   ('b', '', (3, 1), (3, 2)),
   ('a', '', (6, 2), (6, 3)),
   ('a', '', (4, 3), (4, 4)),
   ('a', '非', (0, 0), (0, 1)),
   ('a', '', (3, 1), (3, 2)),
   ('a', '', (3, 0), (3, 1)),
   ('e', '', (6, 0), (6, 1)),
   ('a', '', (3, 2), (3, 3)),
   ('a', '', (5, 1), (5, 2)),
   ('c', '', (2, 0), (2, 1)),
   ('e', '', (3, 1), (3, 2)),
   ('e', '非常威', (0, 2), (0, 3)),
   ('d', '', (3, 2), (3, 3)),
   ('a', '', (6, 1), (6, 2)),
   ('c', '', (6, 5), (6, 6)),
   ('d', '', (3, 0

In [316]:
def sparse_global_loss(y_true, y_pre):
    """GlobalPointer loss for gold spans given as (start, end) index pairs.

    Args:
        y_true: integer tensor whose last axis holds ``(start, end)``. The
            gather performed downstream needs one index axis per label, so
            the expected layout is (batch_size, labels_num, max_spans, 2).
            Flat index 0, i.e. the pair (0, 0), is treated as padding.
        y_pre: scores of shape (batch_size, labels_num, seq_len, seq_len).

    Returns:
        Scalar tensor: the loss summed over labels and averaged over the batch.
    """
    n_batch, n_labels, width, _ = y_pre.shape
    starts, ends = y_true[..., 0], y_true[..., 1]
    # Position of each (start, end) cell in the flattened seq_len * seq_len grid.
    flat_index = starts * width + ends
    flat_scores = y_pre.reshape(n_batch, n_labels, -1)
    per_label = spare_multilable_categorical_crossentropy(flat_index, flat_scores, True)
    return torch.mean(per_label.sum(1))

def spare_multilable_categorical_crossentropy(y_true, y_pre, mask_zero=False, mask_value=-10000):
    """Multi-label categorical cross-entropy with sparse (index) targets.

    Computes ``log(1 + sum_pos exp(-s)) + log(1 + sum_neg exp(s))`` where the
    positive classes are given by index instead of by a dense 0/1 tensor.

    Args:
        y_true: integer tensor (..., num_targets) with the indices of the
            positive classes along the last axis of ``y_pre``.
        y_pre: float tensor (..., num_classes) of class scores.
        mask_zero: when True, index 0 is padding; class 0 is then left out of
            both the positive and the negative term.
        mask_value: large negative score used to switch class 0 off.

    Returns:
        Tensor of shape ``y_pre.shape[:-1]`` with the loss per row.
    """
    # Take the zero column from the scores (not from the integer targets) so
    # it already has the right floating dtype and device.
    zeros = torch.zeros_like(y_pre[..., :1])
    scores = torch.cat([y_pre, zeros], -1)
    if mask_zero:
        # Rebuild the tensors by concatenation instead of writing in place.
        pad_col = zeros + mask_value
        pos_scores = torch.cat([-pad_col, scores[..., 1:]], -1)  # exp(-s) -> 0
        neg_scores = torch.cat([pad_col, scores[..., 1:]], -1)   # exp(s) -> 0
    else:
        pos_scores = neg_scores = scores

    y_pos = pos_scores.gather(-1, y_true)
    loss_pos = torch.logsumexp(torch.cat([-y_pos, zeros], dim=-1), -1)

    loss_all = torch.logsumexp(neg_scores, dim=-1)
    loss_aux = torch.logsumexp(neg_scores.gather(-1, y_true), dim=-1)
    # sum_neg = sum_all - sum_pos, evaluated in log space. Rounding can push
    # the ratio to 0 or slightly below, so clip it before taking the log to
    # avoid -inf / NaN.
    eps = torch.finfo(loss_all.dtype).eps
    neg_ratio = torch.clamp(1 - torch.exp(loss_aux - loss_all), eps, 1)
    loss_neg = loss_all + torch.log(neg_ratio)
    return loss_pos + loss_neg



In [317]:
y_pre = torch.randn(2,3,4,4)
y_true = torch.randint(0, 3, (2,3, 5,2))
y_pre, y_true

(tensor([[[[ 0.5784,  0.7584,  0.8789,  0.8394],
           [-0.7072,  0.7751,  0.1053,  0.5991],
           [-0.0086, -0.1709, -0.7987, -0.7811],
           [-0.0680, -0.8091, -0.8300,  0.8150]],
 
          [[ 0.9281, -0.1977,  0.5939, -1.7586],
           [ 0.7674, -0.5807, -1.6056, -0.6681],
           [ 0.4122,  0.5615, -0.5178, -0.6037],
           [ 0.4126,  1.0954,  0.5267,  0.1089]],
 
          [[-0.1875, -0.7717,  1.7173, -0.2593],
           [ 0.8962, -0.9508, -0.7852,  0.1863],
           [ 0.7808, -0.1469,  0.0374,  0.4198],
           [-2.0249,  0.5224, -0.7146,  3.4525]]],
 
 
         [[[ 0.1420,  0.0084, -0.9874,  0.3218],
           [ 0.1670,  1.8775,  1.7553, -0.6546],
           [-1.3691, -0.8345,  1.0055,  0.6572],
           [ 0.6825, -0.2435,  0.7463,  0.2962]],
 
          [[ 0.1991,  0.1816,  1.7086, -0.5378],
           [-0.5559, -1.0820, -0.8937,  0.1389],
           [-0.1933,  0.5735, -1.7005, -0.2799],
           [ 0.2437, -1.0049,  1.4619,  0.4143]],
 
  

In [318]:
sparse_global_loss(y_true, y_pre)

tensor(14.7517)

In [8]:
torch.zeros_like(y_true[..., :1]).shape

torch.Size([2, 3, 1])

In [12]:
(y_true[..., :1] + y_true[..., 1:]).shape

torch.Size([2, 3, 1])

In [290]:
# f1_globalpointer
def f1_globalpointer(y_true, y_pre, tp_pre=None, tpfp_pre=None, tpfn_pre=None):
    """F1 score of dense GlobalPointer predictions.

    A score counts as a prediction when it is greater than 0.

    Args:
        y_true: dense 0/1 gold labels.
        y_pre: scores with the same shape as ``y_true``.
        tp_pre, tpfp_pre, tpfn_pre: optional running counts from earlier
            batches (true positives, predicted positives, gold positives).

    Returns:
        The F1 tensor alone when ``tp_pre`` is None; otherwise the tuple
        ``(f1, tp, tpfp, tpfn)`` with the running counts included.
    """
    predicted = y_pre.greater(0)
    hits = (y_true * predicted).sum()
    n_gold = y_true.sum()
    n_predicted = predicted.sum()
    if tp_pre is not None:
        tp = hits + tp_pre
        tpfp = n_predicted + tpfp_pre
        tpfn = n_gold + tpfn_pre
        return 2 * tp / (tpfp + tpfn), tp, tpfp, tpfn
    return 2 * hits / (n_gold + n_predicted)

In [291]:
y_pre = torch.randn(2,3,4)
y_true = torch.randint(0, 2, (2, 3, 4))
f1_globalpointer(y_true, y_pre), f1_globalpointer(y_true, y_pre, 2, 3, 4)

(tensor(0.5455), (tensor(0.5517), tensor(8), tensor(14), tensor(15)))

In [293]:
16/(14 + 15), 12/(7 + 15)

(0.5517241379310345, 0.5454545454545454)

In [296]:
def em_globalpoiter(y_true, y_pre, tp_pre=None, tpfp_pre=None):
    """Share of predicted spans that are correct, for dense labels.

    A score counts as a prediction when it is greater than 0.

    Args:
        y_true: dense 0/1 gold labels.
        y_pre: scores with the same shape as ``y_true``.
        tp_pre, tpfp_pre: optional running counts from earlier batches
            (correct predictions, all predictions).

    Returns:
        The ratio tensor alone when ``tp_pre`` is None; otherwise the tuple
        ``(ratio, tp, tpfp)`` with the running counts included.
    """
    predicted = y_pre.greater(0)
    correct = (y_true * predicted).sum()
    total = predicted.sum()
    if tp_pre is None:
        return correct / total
    correct, total = correct + tp_pre, total + tpfp_pre
    return correct / total, correct, total

In [297]:
em_globalpoiter(y_true, y_pre), em_globalpoiter(y_true, y_pre, 2, 5)

(tensor(0.5455), (tensor(0.5000), tensor(8), tensor(16)))

In [27]:
# f1_globalpointer_sparse
# def f1_globalpointer_sparse(y_true, y_pre):
#     batch_size, labels_num, seq_len, _ = y_pre.shape
#     y_true = y_true[..., :1] * seq_len + y_true[..., 1:]
#     y_pre = y_pre.reshape(batch_size, labels_num, -1)
#     y_pre = (y_pre)>0

#     tp = y_pre.gather(-1, y_true).sum()
#     tpfp = y_pre.sum()
#     tpfn = (y_true > 0).sum()
#     return 2*tp / (tpfp + tpfn)


In [312]:
# f1_globalpointer_sparse
def f1_globalpointer_sparse(y_true, y_pre, tp_pre=None, tpfp_pre=None, tpfn_pre=None):
    """F1 score of GlobalPointer predictions against sparse gold spans.

    Args:
        y_true: integer tensor (batch_size, labels_num, max_spans, 2) whose
            last axis is ``(start, end)``. The pair (0, 0), i.e. flat index 0,
            is treated as padding.
        y_pre: scores of shape (batch_size, labels_num, seq_len, seq_len);
            a score > 0 counts as a predicted span.
        tp_pre, tpfp_pre, tpfn_pre: optional running counts from earlier
            batches (true positives, predicted positives, gold positives).

    Returns:
        The F1 tensor alone when ``tp_pre`` is None; otherwise the tuple
        ``(f1, tp, tpfp, tpfn)`` with the running counts included.
    """
    batch_size, labels_num, seq_len, _ = y_pre.shape
    y_true = y_true[..., 0] * seq_len + y_true[..., 1]
    y_pre = y_pre.reshape(batch_size, labels_num, -1) > 0

    is_gold = y_true > 0
    # Padding entries all look up cell 0; without the is_gold filter a positive
    # score there would count once per padding slot as a true positive, even
    # though tpfn leaves padding out.
    tp = (y_pre.gather(-1, y_true) & is_gold).sum()
    tpfp = y_pre.sum()
    tpfn = is_gold.sum()

    if tp_pre is None:
        return 2 * tp / (tpfp + tpfn)
    tp = tp + tp_pre
    tpfp = tpfp + tpfp_pre
    tpfn = tpfn + tpfn_pre
    return 2 * tp / (tpfp + tpfn), tp, tpfp, tpfn


In [310]:
a = torch.randn((1,2,3,3))
b = torch.randint(0, 3, (1,2, 5,2))
a, b 



(tensor([[[[ 0.5509, -0.3341,  0.1795],
           [-0.4284, -1.9252,  0.3193],
           [ 0.8755,  1.9773, -0.0037]],
 
          [[-0.7280,  0.0338, -1.0060],
           [-1.1954, -1.1370, -0.4052],
           [ 0.9667,  0.0255,  1.2255]]]]),
 tensor([[[[1, 0],
           [2, 0],
           [2, 2],
           [0, 1],
           [1, 1]],
 
          [[2, 2],
           [1, 2],
           [2, 0],
           [2, 0],
           [2, 1]]]]))

In [313]:
f1_globalpointer_sparse(b, a), f1_globalpointer_sparse(b, a, 2, 4, 3)

(tensor(0.5263), (tensor(0.5385), tensor(7), tensor(13), tensor(13)))

In [315]:
(b[..., :1] *3 + b[..., 1:]).squeeze(-1).shape

torch.Size([1, 2, 5])

In [35]:
a_1 = a.reshape(1,2, -1) > 0
a_1

tensor([[[False,  True, False, False, False, False,  True, False,  True],
         [False,  True,  True,  True,  True, False, False,  True,  True]]])

In [36]:
b_1 = b[..., :1] * 3 +b [..., 1::]
b_1

tensor([[[8],
         [1]]])

In [47]:
tp = a_1.gather(-1, b_1).sum()
tp

tensor(2)

In [40]:
tpfp = a_1.sum()
tpfp

tensor(9)

In [42]:
tpfn = (b_1 > 0).sum()
tpfn

tensor(2)

In [48]:
2 * tp/ (tpfp + tpfn)

tensor(0.3636)

In [54]:
tp/tpfp

tensor(0.2222)

In [307]:
# em_globalpointer_sparse
def em_globalpointer_sparse(y_true, y_pre, tp_pre=None, tpfp_pre=None):
    """Share of predicted spans that are correct, for sparse gold spans.

    Args:
        y_true: integer tensor (batch_size, labels_num, max_spans, 2) whose
            last axis is ``(start, end)``. The pair (0, 0), i.e. flat index 0,
            is treated as padding.
        y_pre: scores of shape (batch_size, labels_num, seq_len, seq_len);
            a score > 0 counts as a predicted span.
        tp_pre, tpfp_pre: optional running counts from earlier batches
            (correct predictions, all predictions).

    Returns:
        The ratio tensor alone when ``tp_pre`` is None; otherwise the tuple
        ``(ratio, tp, tpfp)`` with the running counts included.
    """
    batch_size, labels_num, seq_len, _ = y_pre.shape
    y_true = y_true[..., 0] * seq_len + y_true[..., 1]
    y_pre = y_pre.reshape(batch_size, labels_num, -1) > 0
    # Ignore padding slots (flat index 0) so that a positive score at cell 0
    # is not counted as a hit once per padding entry.
    hits = (y_pre.gather(-1, y_true) & (y_true > 0)).sum()
    if tp_pre is None:
        return hits / y_pre.sum()
    tp = hits + tp_pre
    tpfp = y_pre.sum() + tpfp_pre
    return tp / tpfp, tp, tpfp



In [301]:
a = torch.randn(2, 3, 4, 4)
b = torch.randint(0, 4, (2, 3, 5, 2))
a, b

(tensor([[[[ 0.4278, -0.1485,  0.8134,  0.4763],
           [ 0.4506, -0.9256,  0.6311, -0.9998],
           [-0.5012,  1.4041, -0.4142, -0.6175],
           [-0.4538,  0.2776, -2.1733,  0.4106]],
 
          [[ 0.1179,  0.5701, -0.0847,  0.7752],
           [ 1.3515,  0.7486, -1.9082, -1.9032],
           [ 0.1734, -0.8240,  0.4001, -0.5463],
           [ 0.2290,  1.3607, -0.3709,  0.0617]],
 
          [[-1.7603, -0.2351,  0.0529, -1.5433],
           [-0.6146, -1.8675, -1.6590, -0.4500],
           [ 0.3181,  0.4151,  0.5632, -1.5208],
           [ 1.5337, -1.5245, -1.4476, -0.6147]]],
 
 
         [[[ 1.7203,  1.7013, -0.3912,  0.9440],
           [-1.7835,  0.4110,  1.0061, -1.3277],
           [ 1.3365,  1.7706,  1.6693, -0.7272],
           [ 2.0811, -0.6681, -0.1015, -0.6598]],
 
          [[ 0.3039,  0.6207, -1.8062, -0.2783],
           [-0.0533, -0.7125, -0.4873, -0.5468],
           [ 0.0462,  1.2818, -1.8496,  0.3464],
           [ 1.1191,  0.6175,  1.5951, -1.4643]],
 
  

In [308]:
em_globalpointer_sparse(b, a), em_globalpointer_sparse(b, a, 2, 3)

(tensor(0.2400), (tensor(0.2642), tensor(14), tensor(53)))

In [55]:
2/9

0.2222222222222222

In [177]:
def mask(tensor, tensor_mask, mask_dim, mask_value=0):
    """Fill the positions of ``tensor`` that ``tensor_mask`` marks as invalid.

    Args:
        tensor (torch.Tensor): input tensor.
        tensor_mask (torch.Tensor): 0/1 integer (or bool) mask; positions
            holding 1 / True are kept, all others are overwritten.
        mask_dim (int): negative index of the dimension of ``tensor`` that the
            last dimension of ``tensor_mask`` lines up with. Example: for a
            (3, 3, 3) tensor, -1 masks along the last axis of length 3.
        mask_value: value written into the masked positions.

    Returns:
        torch.Tensor: a new tensor; ``tensor`` itself is not modified.

    Raises:
        ValueError: if ``mask_dim`` is not negative.
    """
    if not mask_dim < 0:
        raise ValueError(f"Mask dim only supports negative numbers! Mask dim: {mask_dim} ")
    # masked_fill needs a boolean mask that is True where a value must be
    # replaced, i.e. the inverse of the "keep" mask that was passed in.
    if tensor_mask.dtype == torch.bool:
        fill_here = ~tensor_mask
    else:
        fill_here = (1 - tensor_mask).bool()
    mask_dim = tensor.dim() + mask_dim
    # Insert singleton axes after the batch axis until the last axis of the
    # mask sits at position mask_dim ...
    for _ in range(mask_dim - fill_here.dim() + 1):
        fill_here = fill_here.unsqueeze(1)
    # ... then pad on the right until the mask has as many axes as the tensor.
    for _ in range(tensor.dim() - fill_here.dim()):
        fill_here = fill_here.unsqueeze(-1)
    return tensor.masked_fill(fill_here, mask_value)

In [179]:
a = torch.randn(2, 2)
b = torch.randint(0, 2, (2, 2))
a, b, mask(a, b, -1, -9)

(tensor([[ 0.0955, -0.4033],
         [ 2.5606,  0.3287]]),
 tensor([[0, 0],
         [0, 1]]),
 tensor([[-9.0000, -9.0000],
         [-9.0000,  0.3287]]))

In [180]:
# multilabel_categorical_crossentropy
# global_pointer_loss
def multilabel_categorical_crossentropy(y_true, y_pre):
    """Multi-label categorical cross-entropy with dense 0/1 targets.

    Computes, along the last axis,
    ``log(1 + sum_pos exp(-s)) + log(1 + sum_neg exp(s))``.

    Args:
        y_true: 0/1 labels with the same shape as ``y_pre``.
        y_pre: class scores; the classes lie on the last axis.

    Returns:
        Tensor of shape ``y_pre.shape[:-1]`` with the loss per row.
    """
    is_pos = y_true.bool()
    # Flip the sign of the positive scores so that both terms are a plain
    # logsumexp over their own subset of classes.
    signed = (1 - 2 * y_true) * y_pre
    # The appended 0 contributes the "1 +" inside each logarithm. It is taken
    # from the scores so it already has a floating dtype.
    zero = torch.zeros_like(signed[..., :1])
    # masked_fill already returns a new tensor, so no clone is needed.
    pos_scores = torch.cat([signed.masked_fill(~is_pos, -10000), zero], -1)
    neg_scores = torch.cat([signed.masked_fill(is_pos, -10000), zero], -1)
    return torch.logsumexp(pos_scores, -1) + torch.logsumexp(neg_scores, -1)


def global_pointer_loss(y_true, y_pre):
    """GlobalPointer loss for dense 0/1 span labels.

    Args:
        y_true: 0/1 labels with the same number of elements as ``y_pre``.
        y_pre: scores whose first two axes are flattened together and whose
            remaining axes are treated as the class axis.

    Returns:
        Scalar tensor: mean multi-label cross-entropy over all flattened rows.
    """
    n_batch, n_heads = y_pre.shape[:2]
    n_rows = n_batch * n_heads
    # One row per (sample, head); every remaining cell becomes a class.
    flat_scores = y_pre.reshape(n_rows, -1)
    flat_labels = y_true.reshape(n_rows, -1)
    per_row = multilabel_categorical_crossentropy(flat_labels, flat_scores)
    return torch.mean(per_row)

In [278]:
from my_py_toolkit.torch.transformer_utils import gen_pos_emb
def rope(inputs):
    """Apply rotary position embedding to the last two axes of ``inputs``.

    Args:
        inputs: tensor whose last two axes are (length, dim).

    Returns:
        Tensor of the same shape: ``inputs * cos + rotated * sin``, where
        ``rotated`` swaps each (even, odd) feature pair and negates the first
        element of the swapped pair.
    """
    seq_len, feat_dim = inputs.shape[-2:]
    target = inputs.device
    # NOTE(review): presumably gen_pos_emb returns a (length, dim) table with
    # sine values in the even columns and cosine values in the odd columns --
    # confirm against my_py_toolkit.
    table = gen_pos_emb(seq_len, feat_dim)
    even_cols, odd_cols = table[:, ::2], table[:, 1::2]

    # Duplicate every table column so both members of a feature pair share
    # the same angle.
    sin = torch.zeros_like(table)
    sin[:, ::2] = even_cols
    sin[:, 1::2] = even_cols
    cos = torch.zeros_like(table)
    cos[:, ::2] = odd_cols
    cos[:, 1::2] = odd_cols

    # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    rotated = torch.zeros_like(inputs)
    rotated[..., ::2] = -inputs[..., 1::2]
    rotated[..., 1::2] = inputs[..., ::2]
    return inputs * cos.to(target) + rotated * sin.to(target)
# efficient globalpointer
class EfficientGlobalPointer(torch.nn.Module):
    """Efficient GlobalPointer head: one shared span score plus per-label biases.

    Args:
        num_head (int): number of labels (one score matrix per label).
        head_size (int): size of the query / key vectors.
        dim (int): size of the incoming token representations.
    """

    def __init__(self, num_head, head_size, dim):
        super().__init__()
        # Projects every token to an interleaved query/key vector.
        self.linear_1 = torch.nn.Linear(dim, head_size * 2)
        # Produces two bias values per label for every token.
        self.linear_2 = torch.nn.Linear(head_size * 2, num_head * 2)

    def forward(self, inputs, mask_t):
        """Score every (start, end) token pair for every label.

        Args:
            inputs: token representations, (batch, seq_len, dim).
            mask_t: 0/1 tensor (batch, seq_len); 1 marks a valid token.

        Returns:
            Tensor (batch, num_head, seq_len, seq_len). Cells that touch an
            invalid token, that lie in row 0 or column 0, or that have
            end < start are set to -10000.
        """
        inputs = self.linear_1(inputs)
        # Even features form the query, odd features the key.
        q, k = inputs[..., ::2], inputs[..., 1::2]
        q, k = rope(q), rope(k)
        # NOTE(review): the reference implementation divides this product by
        # sqrt(head_size); it is left unscaled here to keep behavior unchanged.
        att = q @ k.transpose(-1, -2)
        bias = self.linear_2(inputs).transpose(-1, -2) / 2
        logits = att.unsqueeze(1) + bias[:, ::2, :].unsqueeze(-1) + bias[:, 1::2, :].unsqueeze(-2)
        logits = mask(logits, mask_t, -1, -10000)
        logits = mask(logits, mask_t, -2, -10000)
        # Row 0 / column 0 are never valid span boundaries.
        logits[:, :, 0, :] = -10000
        logits[:, :, :, 0] = -10000
        # masked_fill requires a boolean mask; one (seq_len, seq_len) triangle
        # broadcasts over batch and labels instead of allocating a mask the
        # size of the logits.
        seq_len = logits.shape[-1]
        below_diag = torch.tril(torch.ones(seq_len, seq_len, device=logits.device), -1).bool()
        return logits.masked_fill(below_diag, -10000)

In [279]:
import random

In [280]:
# End-to-end smoke-test inputs for EfficientGlobalPointer + global_pointer_loss.
batch_size = 70
sqe_len = 64
labels = 11
dim = 768
head_size = 64
# Random stand-in for the encoder output: (batch_size, sqe_len, dim).
input_idx = torch.randn(batch_size, sqe_len, dim)
# Per-sample token mask: position 0 is always 0, followed by `tmp` ones
# (a random number of valid tokens) and zeros up to sqe_len.
mask_tensor = []
for _ in range(batch_size):
    tmp = random.randint(0, 63)
    mask_tensor.append([0] + [1] * tmp + [0] * (sqe_len - tmp - 1))
mask_tensor = torch.tensor(mask_tensor)
# Random 0/1 labels on the upper triangle (start <= end) ...
y_true = torch.triu(torch.randint(0, 2, (batch_size, labels, sqe_len, sqe_len)))
# ... with every cell that touches a masked-out token reset to 0.
y_true = mask(y_true, mask_tensor, -1, 0)
y_true = mask(y_true, mask_tensor, -2, 0)
y_true = torch.triu(y_true)
# Model under test: one head per label.
fg = EfficientGlobalPointer(labels, head_size, dim)

In [281]:
y_pre = fg(input_idx, mask_tensor)

In [282]:
global_pointer_loss(y_true, y_pre)

tensor(16.9475, grad_fn=<MeanBackward0>)

In [247]:
y_true, y_pre

(tensor([[[[0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]],
 
          [[0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]],
 
          [[0, 0, 0,  ..., 0, 0, 0],
           [0, 1, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]],
 
          ...,
 
          [[0, 0, 0,  ..., 0, 0, 0],
           [0, 1, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]],
 
          [[0, 0, 0

In [207]:
mask_tensor

tensor([[0, 1, 1,  ..., 0, 0, 0],
        [0, 1, 1,  ..., 0, 0, 0],
        [0, 1, 1,  ..., 0, 0, 0],
        ...,
        [0, 1, 1,  ..., 0, 0, 0],
        [0, 1, 1,  ..., 0, 0, 0],
        [0, 1, 1,  ..., 0, 0, 0]])

In [255]:
a = torch.randn(batch_size, labels, sqe_len, sqe_len)
a

tensor([[[[ 9.9535e-01,  1.3340e+00, -5.3121e-01,  ..., -2.9744e-01,
            8.1377e-01,  3.3625e+00],
          [-1.0398e-01, -2.8573e-01,  8.6282e-01,  ..., -1.2219e-01,
            4.8551e-01, -8.5907e-01],
          [-5.5943e-01,  1.8919e+00,  3.7684e-02,  ..., -1.4333e+00,
           -1.3917e+00, -1.7400e+00],
          ...,
          [-5.2352e-01, -4.5030e-01, -8.4085e-01,  ..., -1.1597e+00,
           -1.2906e+00, -4.1241e-01],
          [ 2.1684e-01,  8.0790e-01, -7.9786e-01,  ...,  5.0231e-01,
           -1.3848e+00, -3.6179e-01],
          [-2.8068e-02, -5.4057e-01, -1.3864e+00,  ..., -5.2880e-01,
           -1.2116e+00,  1.0247e-01]],

         [[-8.4520e-01,  1.4655e+00, -6.0748e-01,  ..., -1.2112e-01,
           -3.2309e-01,  1.0780e+00],
          [-8.4809e-01,  4.1728e-01, -6.1097e-01,  ..., -1.5629e+00,
            1.2860e+00, -1.8204e+00],
          [-3.7374e-02,  1.2143e-01,  6.6585e-01,  ..., -8.1875e-01,
            1.3172e+00, -9.7002e-01],
          ...,
     

In [256]:
b = mask(a, mask_tensor, -1, -10000)
b = mask(b, mask_tensor, -2, -10000)
f = mask(y_true, mask_tensor, -1, 0)
f = mask(f, mask_tensor, -2, 0)
b

tensor([[[[-1.0000e+04, -1.0000e+04, -1.0000e+04,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04],
          [-1.0000e+04, -2.8573e-01, -1.0000e+04,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04],
          [-1.0000e+04, -1.0000e+04, -1.0000e+04,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04],
          ...,
          [-1.0000e+04, -1.0000e+04, -1.0000e+04,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04],
          [-1.0000e+04, -1.0000e+04, -1.0000e+04,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04],
          [-1.0000e+04, -1.0000e+04, -1.0000e+04,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04]],

         [[-1.0000e+04, -1.0000e+04, -1.0000e+04,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04],
          [-1.0000e+04,  4.1728e-01, -1.0000e+04,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04],
          [-1.0000e+04, -1.0000e+04, -1.0000e+04,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04],
          ...,
     

In [258]:
c = b.reshape(batch_size * labels, -1)
d = f.reshape(batch_size * labels, -1)
e = (1 - 2 * d) * c
# c = y_true_t * c
torch.logsumexp(mask(e,  1 - d, -1, -10000), -1).max()

tensor(8.5494)

In [259]:
torch.logsumexp(c * d, -1).max()

tensor(8.5069)

In [224]:
-d, d

(tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]))

In [206]:
y_pre

tensor([[[[-1.0000e+04, -1.0000e+04, -1.0000e+04,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04],
          [-1.0000e+04, -1.0000e+04,  4.2126e+00,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04],
          [-1.0000e+04, -1.0000e+04, -1.0000e+04,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04],
          ...,
          [-1.0000e+04, -1.0000e+04, -1.0000e+04,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04],
          [-1.0000e+04, -1.0000e+04, -1.0000e+04,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04],
          [-1.0000e+04, -1.0000e+04, -1.0000e+04,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04]],

         [[-1.0000e+04, -1.0000e+04, -1.0000e+04,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04],
          [-1.0000e+04, -1.0000e+04,  4.2107e+00,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04],
          [-1.0000e+04, -1.0000e+04, -1.0000e+04,  ..., -1.0000e+04,
           -1.0000e+04, -1.0000e+04],
          ...,
     

In [173]:
a = torch.randn(2,3)
b = torch.randint(0,2, (2,3))
a, b

(tensor([[-0.0987, -1.0744,  0.1928],
         [ 0.0962,  1.3312, -1.2819]]),
 tensor([[1, 1, 0],
         [0, 0, 1]]))

In [178]:
mask(a, b, -1)

tensor([[-0.0987, -1.0744,  0.0000],
        [ 0.0000,  0.0000, -1.2819]])

In [185]:
(y_pre == -10000).sum(), (y_pre == 10000 ).sum()

(tensor(2605339), tensor(0))

In [253]:
y_true_t = y_true.reshape(batch_size * labels, -1)
y_pre_t = y_pre.reshape(batch_size * labels, -1)
y_pre_t = (1 - 2 * y_true_t) * y_pre_t
y_true_t, y_pre_t

(tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 tensor([[-10000., -10000., -10000.,  ..., -10000., -10000., -10000.],
         [-10000., -10000., -10000.,  ..., -10000., -10000., -10000.],
         [-10000., -10000., -10000.,  ..., -10000., -10000., -10000.],
         ...,
         [-10000., -10000., -10000.,  ..., -10000., -10000., -10000.],
         [-10000., -10000., -10000.,  ..., -10000., -10000., -10000.],
         [-10000., -10000., -10000.,  ..., -10000., -10000., -10000.]],
        grad_fn=<MulBackward0>))

In [254]:
(y_pre_t==10000).sum()

tensor(11977)

In [252]:
(y_pre_t * y_true_t).sum()

tensor(1.1976e+08, grad_fn=<SumBackward0>)

In [188]:
y_pre.reshape(batch_size * labels, -1)

tensor([[-10000., -10000., -10000.,  ..., -10000., -10000., -10000.],
        [-10000., -10000., -10000.,  ..., -10000., -10000., -10000.],
        [-10000., -10000., -10000.,  ..., -10000., -10000., -10000.],
        ...,
        [-10000., -10000., -10000.,  ..., -10000., -10000., -10000.],
        [-10000., -10000., -10000.,  ..., -10000., -10000., -10000.],
        [-10000., -10000., -10000.,  ..., -10000., -10000., -10000.]],
       grad_fn=<ViewBackward>)

In [190]:
mask(y_pre_t, y_true_t, -1, 0)

tensor([[    0., 10000., 10000.,  ...,     0.,     0.,     0.],
        [    0., 10000.,     0.,  ...,     0.,     0.,     0.],
        [10000.,     0.,     0.,  ...,     0.,     0.,     0.],
        ...,
        [10000.,     0.,     0.,  ...,     0.,     0.,     0.],
        [    0.,     0., 10000.,  ...,     0.,     0., 10000.],
        [10000.,     0., 10000.,  ...,     0.,     0.,     0.]],
       grad_fn=<MaskedFillBackward0>)

In [194]:
y_true_t2 = mask(y_true, mask_tensor, -1, 0)
y_true_t2= mask(y_true_t2, mask_tensor, -2, 0)
y_true_t

tensor([[0, 1, 1,  ..., 0, 0, 0],
        [0, 1, 0,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 0, 0, 0],
        ...,
        [1, 0, 0,  ..., 0, 0, 0],
        [0, 0, 1,  ..., 0, 0, 1],
        [1, 0, 1,  ..., 0, 0, 0]])

In [196]:
mask_tensor

tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 0, 0,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])

In [193]:
y_true_t.shape, mask_tensor.shape

(torch.Size([770, 4096]), torch.Size([70, 64]))

In [187]:
mask(y_pre_t, y_true_t, -1, 0)

tensor([[    0., 10000., 10000.,  ...,     0.,     0.,     0.],
        [    0., 10000.,     0.,  ...,     0.,     0.,     0.],
        [10000.,     0.,     0.,  ...,     0.,     0.,     0.],
        ...,
        [10000.,     0.,     0.,  ...,     0.,     0.,     0.],
        [    0.,     0., 10000.,  ...,     0.,     0., 10000.],
        [10000.,     0., 10000.,  ...,     0.,     0.,     0.]],
       grad_fn=<MaskedFillBackward0>)

In [157]:
(y_pre_t == 10000).sum()

tensor(520808)

In [160]:
(mask(y_pre_t, 1 - y_true_t, -1, 0)==10000).sum()

tensor(0)

In [164]:
y_pre_t.shape, y_true_t.shape

(torch.Size([770, 4096]), torch.Size([770, 4096]))

In [161]:
a = torch.arange(8).reshape(2,2,2)
b = torch.randint(0, 2, (2,2,2))
a, b

(tensor([[[0, 1],
          [2, 3]],
 
         [[4, 5],
          [6, 7]]]),
 tensor([[[0, 1],
          [0, 0]],
 
         [[0, 0],
          [1, 0]]]))

In [163]:
(1 - 2 * b) * a

tensor([[[ 0, -1],
         [ 2,  3]],

        [[ 4,  5],
         [-6,  7]]])

In [144]:
y_neg = mask(y_pre_t, y_true_t, -1, -10000)
y_pos = mask(y_pre_t, 1 - y_true_t, -1, -10000)
zeros = torch.zeros_like(y_pre[..., :1])


In [149]:
torch.logsumexp(mask(y_neg, y_true_t, -1 , 0), -1)

tensor([10005.8809, 10005.8555, 10005.7871, 10005.8750, 10005.8350, 10005.8320,
        10005.8916, 10005.8975, 10005.7744, 10005.8662, 10005.7842, 10005.4727,
        10005.5010, 10005.4463, 10005.4424, 10005.4844, 10005.5098, 10005.5010,
        10005.4971, 10005.4248, 10005.5049, 10005.5293, 10006.8857, 10006.8340,
        10006.8486, 10006.8730, 10006.8740, 10006.8564, 10006.9131, 10006.8643,
        10006.8848, 10006.8662, 10006.8701, 10006.7021, 10006.7021, 10006.7090,
        10006.7217, 10006.7061, 10006.7344, 10006.7031, 10006.7266, 10006.6719,
        10006.7695, 10006.7061, 10006.8262, 10006.8369, 10006.8789, 10006.8389,
        10006.8555, 10006.8418, 10006.8506, 10006.8623, 10006.8086, 10006.8193,
        10006.8408, 10006.5439, 10006.4922, 10006.4971, 10006.4453, 10006.4834,
        10006.4912, 10006.5205, 10006.5312, 10006.5352, 10006.5146, 10006.5176,
        10006.9355, 10006.8906, 10006.9717, 10006.9248, 10006.9414, 10006.9053,
        10006.8711, 10006.8936, 10006.92

In [151]:
torch.logsumexp(y_pos, -1).max()

tensor(13.1222, grad_fn=<MaxBackward1>)

In [147]:
y_neg[0]

tensor([-10000.,  10000.,  10000.,  ..., -10000., -10000., -10000.],
       grad_fn=<SelectBackward>)

In [133]:
y_pre_lse = torch.logsumexp(y_pre.reshape(batch_size * labels, -1), -1)

In [134]:
torch.mean(y_pre_lse), y_pre_lse.sum()


(tensor(-562.3373, grad_fn=<MeanBackward0>),
 tensor(-432999.7500, grad_fn=<SumBackward0>))

In [135]:
y_pre_lse

tensor([ 1.0677e+01,  1.0822e+01,  1.0750e+01,  1.0693e+01,  1.0740e+01,
         1.0923e+01,  1.0747e+01,  1.0925e+01,  1.0802e+01,  1.0909e+01,
         1.0693e+01,  1.1115e+01,  1.1126e+01,  1.1079e+01,  1.1121e+01,
         1.1068e+01,  1.1020e+01,  1.1189e+01,  1.1078e+01,  1.1201e+01,
         1.1139e+01,  1.1121e+01,  7.9992e+00,  8.1396e+00,  8.3342e+00,
         8.0956e+00,  8.2037e+00,  8.1329e+00,  8.1696e+00,  8.0001e+00,
         8.1750e+00,  8.1314e+00,  8.2063e+00,  9.8333e+00,  9.8029e+00,
         9.8627e+00,  9.8792e+00,  9.6459e+00,  9.7904e+00,  9.8967e+00,
         9.9686e+00,  9.8246e+00,  9.8847e+00,  9.7843e+00,  8.6664e+00,
         8.6213e+00,  8.6628e+00,  8.4220e+00,  8.5865e+00,  8.4927e+00,
         8.6092e+00,  8.6626e+00,  8.6846e+00,  8.7370e+00,  8.4756e+00,
         1.0656e+01,  1.0792e+01,  1.0880e+01,  1.0773e+01,  1.0725e+01,
         1.0813e+01,  1.0734e+01,  1.0779e+01,  1.0896e+01,  1.0849e+01,
         1.0615e+01,  6.3006e+00,  6.4676e+00,  6.2

In [65]:
torch.logsumexp(y_pre, -1)

tensor([[[1.3776, 2.0593, 1.4604, 1.3654],
         [1.3465, 2.5925, 1.4304, 1.7803],
         [1.7513, 2.2055, 1.7226, 1.9992]],

        [[2.5427, 1.1659, 1.8437, 1.3016],
         [1.9105, 1.1027, 1.1330, 1.2608],
         [1.2859, 1.4064, 1.0067, 1.4719]]])

In [74]:
a = torch.randn(2,3)
a

tensor([[ 1.3754,  0.7536,  2.8579],
        [-0.1465,  1.8424,  1.6103]])

In [73]:
torch.logsumexp(a, -1)

tensor([8.4927, 8.5206])

In [75]:
a[0][1] = -10000
a

tensor([[ 1.3754e+00, -1.0000e+04,  2.8579e+00],
        [-1.4648e-01,  1.8424e+00,  1.6103e+00]])

In [77]:
torch.logsumexp(a, -1)

tensor([3.0625, 2.4998])

In [89]:
import math
math.exp(10)

22026.465794806718

In [91]:
y_pre.max()

tensor(3.8043)

In [88]:
torch.logsumexp(y_pre.reshape(-1), -1)

tensor(10.3947)

In [90]:
64 * 64

4096

In [105]:
a = torch.tensor([1] * 4).to(float)
torch.logsumexp(a, -1)

tensor(2.3863, dtype=torch.float64)

In [106]:
math.log(math.e * 10000)

10.210340371976184