In [65]:
import torch
import torch.nn.functional as F
from my_py_toolkit.torch.transformer_utils import gen_pos_emb
# from my_py_toolkit.torch.tensor_toolkit import mask

In [66]:
def mask(tensor, tensor_mask, mask_dim, mask_value=0):
  """
  Mask a tensor.
  Args:
    tensor(torch.Tensor): 输入
    tensor_mask(torch.Tensor): mask 位置信息. 注：数据类型必须是 Int, 否则不能正确 mask。
    mask_dim(int): 整数，指定需要 mask 的维度，example：mask_dim = -1, 表示在最后一维上做 mask 操作.
      example: tensor is shape(3,3,3), -1 表示对最后一维的 len=3 的数组做 mask.
  Returns:
  """
  if not mask_dim < 0:
    raise Exception(f"Mask dim only supports negative numbers! Mask dim: {mask_dim} ")
  tensor_mask = 1 - tensor_mask
  if mask_dim < 0:
    mask_dim = tensor.dim() + mask_dim
  for _ in range(mask_dim - tensor_mask.dim() + 1):
    tensor_mask = tensor_mask.unsqueeze(1)
  for _ in range(tensor.dim() - tensor_mask.dim()):
    tensor_mask = tensor_mask.unsqueeze(-1)
  return tensor.masked_fill(tensor_mask, mask_value)

In [67]:
# sin 位置编码
def rope(inputs):
    """rotary position embedding"""
    length, dim = inputs.shape[-2:]
    pos = gen_pos_emb(length, dim)
    inputs_2 = torch.zeros_like(inputs)
    inputs_2[..., ::2] = -inputs[..., 1::2]
    inputs_2[..., 1::2] = inputs[..., ::2]
    sin, cos = torch.zeros_like(pos), torch.zeros_like(pos)
    sin[:, ::2], sin[:, 1::2] = pos[:, ::2], pos[:, ::2]
    cos[:, ::2], cos[:, 1::2] = pos[:, 1::2], pos[:, 1::2]
    return inputs * cos + inputs_2 * sin



In [68]:
# efficient globalpointer
class EfficientGlobalPointer(torch.nn.Module):
    def __init__(self, num_head, dim):
        super().__init__()
        self.linear_1 = torch.nn.Linear(dim, num_head * 2)
        self.linear_2 = torch.nn.Linear(num_head * 2, num_head * 2)

    def forward(self, inputs, mask_t):
        inputs = self.linear_1(inputs)
        q, k = inputs[..., ::2], inputs[..., 1::2]
        q, k = rope(q), rope(k)
        att = q@k.transpose(-1, -2)
        bias = self.linear_2(inputs).transpose(-1, -2)//2
        logits = att.unsqueeze(1) + bias[:, ::2, :].unsqueeze(-1) + bias[:, 1::2, :].unsqueeze(-2)
        logits = mask(logits, mask_t, -1, -10000)
        logits = mask(logits, mask_t, -2, -10000)
        logits = logits.masked_fill(torch.tril(torch.ones_like(logits).to(int)), -10000)
        return logits

        



In [69]:
# 
def multilabel_categorical_crossentropy(y_true, y_pre):
    y_pre = (1 - 2* y_true) * y_pre
    y_neg = mask(y_pre, y_true, -1, -10000)
    y_pos = mask(y_pre, 1-y_true, -1, -10000)
    zero = torch.zeros_like(y_true[..., :1])
    y_neg = torch.cat([y_neg, zero], -1)
    y_pos = torch.cat([y_pos, zero], -1)
    return torch.logsumexp(y_neg, -1) + torch.logsumexp(y_pos, -1)

In [70]:
def global_pointer_loss(y_true, y_pre):
    b, h = y_pre.shape[:2]
    y_pre = y_pre.reshape(b*h, -1)
    y_true = y_true.permute(0, 3, 1, 2).reshape(b*h, -1)
    return torch.mean(multilabel_categorical_crossentropy(y_true, y_pre))

In [76]:
def f1_globalpointer(y_true, y_pre):
    y_pre = y_pre.greater(0)
    return 2 * (y_true * y_pre).sum() / (y_true.sum() + y_pre.sum())

In [91]:
def em_globalpoiter(y_true, y_pre):
    y_pre = y_pre.greater(0)
    return (y_true * y_pre).sum()/ y_pre.sum()

In [89]:
def predict_globalpointer(y_pre, tags):
    res = [{}] * y_pre.shape[0]
    idx = (y_pre>0).nonzero()
    for b, tag_idx, start, end in idx.tolist():
        if tags[tag_idx] not in res[b]:
            res[b][tags[tag_idx]] = []
        res[b][tags[tag_idx]].append((start, end))
    return res


In [96]:
y_pre = torch.randn(2,3)
y_true = torch.randint(0, 2, (2,3))
y_pre, y_true

(tensor([[ 0.2778,  0.1780, -0.7899],
         [-0.5808,  1.0634, -0.0862]]),
 tensor([[1, 0, 0],
         [0, 0, 1]]))

In [97]:
em_globalpoiter(y_true, y_pre)

tensor(0.3333)

In [92]:
a = torch.randint(0, 2, (2,3,4,4))
a

tensor([[[[0, 0, 1, 1],
          [1, 1, 0, 0],
          [0, 1, 0, 1],
          [1, 1, 1, 1]],

         [[1, 0, 0, 1],
          [1, 1, 1, 1],
          [0, 0, 0, 1],
          [1, 0, 0, 0]],

         [[0, 0, 1, 1],
          [1, 0, 1, 1],
          [1, 1, 0, 0],
          [1, 0, 0, 1]]],


        [[[1, 0, 0, 0],
          [1, 1, 0, 0],
          [1, 0, 1, 0],
          [1, 1, 0, 0]],

         [[1, 1, 1, 0],
          [0, 1, 1, 0],
          [0, 0, 0, 1],
          [1, 0, 1, 1]],

         [[0, 0, 0, 1],
          [1, 0, 1, 0],
          [1, 0, 1, 1],
          [1, 1, 0, 1]]]])

In [90]:
predict_globalpointer(a, ['a', 'b', 'c'])

[{'a': [(0, 0),
   (0, 2),
   (0, 3),
   (1, 1),
   (1, 2),
   (1, 3),
   (2, 0),
   (2, 3),
   (3, 0),
   (3, 1),
   (3, 2),
   (3, 3),
   (0, 3),
   (1, 1),
   (1, 3),
   (2, 3),
   (3, 0)],
  'b': [(0, 0),
   (0, 3),
   (1, 0),
   (1, 2),
   (2, 0),
   (3, 2),
   (0, 1),
   (0, 3),
   (1, 0),
   (1, 1),
   (1, 3),
   (2, 0),
   (3, 1),
   (3, 3)],
  'c': [(0, 0),
   (0, 2),
   (0, 3),
   (1, 0),
   (1, 1),
   (1, 3),
   (2, 0),
   (2, 1),
   (2, 2),
   (0, 0),
   (0, 2),
   (0, 3),
   (1, 1),
   (1, 2),
   (1, 3),
   (2, 0),
   (3, 2),
   (3, 3)]},
 {'a': [(0, 0),
   (0, 2),
   (0, 3),
   (1, 1),
   (1, 2),
   (1, 3),
   (2, 0),
   (2, 3),
   (3, 0),
   (3, 1),
   (3, 2),
   (3, 3),
   (0, 3),
   (1, 1),
   (1, 3),
   (2, 3),
   (3, 0)],
  'b': [(0, 0),
   (0, 3),
   (1, 0),
   (1, 2),
   (2, 0),
   (3, 2),
   (0, 1),
   (0, 3),
   (1, 0),
   (1, 1),
   (1, 3),
   (2, 0),
   (3, 1),
   (3, 3)],
  'c': [(0, 0),
   (0, 2),
   (0, 3),
   (1, 0),
   (1, 1),
   (1, 3),
   (2, 0),
   (2, 

In [71]:
model = EfficientGlobalPointer(10, 768)

In [72]:
input = torch.randn(2, 512, 768)
mask_t = torch.cat([torch.ones(2, 511).to(int), torch.zeros(2,1).to(int)], dim=-1)
y_true =F.one_hot(torch.randint(0, 10, (2, 512, 512)))


In [73]:
logits = model(input, mask_t)

In [74]:
logits.shape, y_true.shape

(torch.Size([2, 10, 512, 512]), torch.Size([2, 512, 512, 10]))

In [75]:
global_pointer_loss(y_true, logits)

tensor(10021.7607, grad_fn=<MeanBackward0>)

In [60]:
t_logits = logits.reshape(20, -1)
t_y_true = y_true.permute(0, 3, 1, 2).reshape(20, -1)
multilabel_categorical_crossentropy(t_y_true, t_logits)

tensor([10021.7715, 10021.7861, 10021.7705, 10021.7842, 10021.7842, 10021.7891,
        10021.7959, 10021.7861, 10021.7852, 10021.7510, 10021.7266, 10021.7422,
        10021.7197, 10021.7324, 10021.7471, 10021.7354, 10021.7383, 10021.7227,
        10021.7217, 10021.7334], grad_fn=<AddBackward0>)

In [62]:
t_logits = (1 - 2* t_y_true) * t_logits


In [77]:
a = torch.randint(0, 10, (2,3,3))
b = torch.randint(0, 2, (2,3,3))
a, b

(tensor([[[7, 8, 1],
          [6, 1, 3],
          [3, 9, 9]],
 
         [[5, 1, 1],
          [7, 4, 1],
          [1, 3, 6]]]),
 tensor([[[0, 0, 1],
          [1, 0, 0],
          [1, 0, 0]],
 
         [[0, 1, 1],
          [0, 1, 1],
          [0, 0, 0]]]))

In [78]:
f1_globalpointer(b, a-5)

tensor(0.1429)

In [79]:
2/15

0.13333333333333333

In [63]:
torch.logsumexp(t_logits, -1)

tensor([10009.4805, 10009.4951, 10009.4727, 10009.4912, 10009.4893, 10009.4941,
        10009.5039, 10009.4941, 10009.4912, 10009.4541, 10009.4814, 10009.4971,
        10009.4766, 10009.4863, 10009.5029, 10009.4893, 10009.4941, 10009.4775,
        10009.4746, 10009.4893], grad_fn=<LogsumexpBackward>)

In [64]:
torch.logsumexp(logits.reshape(20, -1), -1)

tensor([12.3991, 12.3991, 12.3991, 12.3991, 12.3991, 12.3991, 12.3991, 12.3991,
        12.3991, 12.3991, 12.3503, 12.3503, 12.3503, 12.3503, 12.3503, 12.3503,
        12.3503, 12.3503, 12.3503, 12.3503], grad_fn=<LogsumexpBackward>)

In [57]:
torch.logsumexp(logits.reshape(20, -1), -1)

tensor([12.3991, 12.3991, 12.3991, 12.3991, 12.3991, 12.3991, 12.3991, 12.3991,
        12.3991, 12.3991, 12.3503, 12.3503, 12.3503, 12.3503, 12.3503, 12.3503,
        12.3503, 12.3503, 12.3503, 12.3503], grad_fn=<LogsumexpBackward>)

In [10]:
logits = logits.reshape(20, -1)
y_true = y_true.permute(0, 3, 1, 2).reshape(20, -1)

In [37]:
multilabel_categorical_crossentropy(y_true, logits)

tensor([10016.7969, 10016.3760, 10016.5176, 10016.6191, 10016.8467, 10016.2705,
        10016.6943, 10016.6387, 10016.7490, 10016.5791, 10016.7041, 10016.6104,
        10016.7402, 10016.7080, 10016.7930, 10016.4590, 10016.2832, 10016.7734,
        10016.5234, 10016.6299], grad_fn=<AddBackward0>)

In [12]:
logits

tensor([[ 1.2524e+00, -1.5804e+00, -2.7494e-01,  ...,  0.0000e+00,
          0.0000e+00, -1.0000e+04],
        [ 1.2524e+00, -1.5804e+00, -2.7494e-01,  ...,  0.0000e+00,
          0.0000e+00, -1.0000e+04],
        [ 1.2524e+00, -1.5804e+00, -2.7494e-01,  ...,  0.0000e+00,
          0.0000e+00, -1.0000e+04],
        ...,
        [-6.4900e-01,  9.2355e-01,  1.0836e+00,  ...,  0.0000e+00,
          0.0000e+00, -1.0000e+04],
        [-6.4900e-01,  9.2355e-01,  1.0836e+00,  ...,  0.0000e+00,
          0.0000e+00, -1.0000e+04],
        [-6.4900e-01,  9.2355e-01,  1.0836e+00,  ...,  0.0000e+00,
          0.0000e+00, -1.0000e+04]], grad_fn=<UnsafeViewBackward>)

In [14]:
torch.logsumexp(logits, -1), torch.logsumexp(y_true, -1)

RuntimeError: value cannot be converted to type int64_t without overflow: inf

In [27]:
torch.logsumexp(y_true.to(float), -1)

tensor([12.6359, 12.6353, 12.6348, 12.6350, 12.6348, 12.6350, 12.6337, 12.6366,
        12.6354, 12.6356, 12.6349, 12.6357, 12.6343, 12.6341, 12.6344, 12.6357,
        12.6349, 12.6360, 12.6358, 12.6364], dtype=torch.float64)

In [18]:
a = torch.sum(torch.exp(y_true), -1)
a

tensor([307412.1250, 307228.2812, 307071.8750, 307135.5000, 307061.5938,
        307114.8438, 306711.0625, 307608.0000, 307236.8750, 307297.0000,
        307090.8125, 307326.1875, 306896.6250, 306845.0938, 306944.7500,
        307345.1250, 307101.1250, 307417.2812, 307372.5938, 307537.5625])

In [19]:
torch.log(a)

tensor([12.6359, 12.6353, 12.6348, 12.6350, 12.6348, 12.6350, 12.6337, 12.6366,
        12.6354, 12.6356, 12.6349, 12.6357, 12.6343, 12.6341, 12.6344, 12.6357,
        12.6349, 12.6360, 12.6358, 12.6364])

In [24]:
y_true.shape

torch.Size([20, 262144])

In [125]:
mask_t = torch.cat([torch.ones(2,511), torch.zeros(2, 1)], dim=-1)
for i in range(2):
    mask_t = mask_t.unsqueeze(1)
mask_t.shape

torch.Size([2, 1, 1, 512])

In [121]:
c = mask(a, mask_t, -1, 0)

In [128]:
a.masked_fill(mask_t, 0)

tensor([[[[ 6.1681e-01,  9.2337e-01,  6.8976e-01,  ...,  8.1657e-01,
            4.9162e-01, -3.2247e-01],
          [-5.0044e-01,  9.0326e-01,  4.9031e-01,  ...,  8.9177e-01,
           -3.7672e-02, -6.6672e-02],
          [ 3.2430e-01, -1.0819e+00, -9.6953e-01,  ..., -1.8766e+00,
            5.2220e-01, -1.3683e+00],
          ...,
          [ 1.3276e+00, -3.4962e-02,  4.1144e-01,  ..., -1.2348e+00,
           -1.3301e+00,  1.4601e+00],
          [ 1.8981e-01,  6.5370e-01,  1.8940e+00,  ...,  3.5038e-01,
            6.2200e-01,  8.4701e-01],
          [ 1.3901e-01,  5.4822e-02,  8.2181e-01,  ..., -5.4730e-02,
           -2.8861e+00,  1.0642e+00]],

         [[ 1.1364e-01, -1.0161e+00,  1.6453e-01,  ..., -1.3152e+00,
            2.7253e-01,  2.9808e+00],
          [-9.6373e-01, -1.1081e+00, -2.1150e-01,  ...,  1.1506e+00,
           -5.9712e-01, -1.2213e+00],
          [-5.0249e-01, -1.2690e+00,  1.7970e+00,  ...,  1.5459e+00,
           -1.0546e+00, -1.7689e+00],
          ...,
     

In [123]:
c [0, 0, 0,]

tensor(-0.3225)

In [111]:
a = torch.randn(2,3,4,4)

a

tensor([[[[-0.7083,  0.6395,  0.9476,  1.6247],
          [-0.7797, -0.5809, -0.1415, -0.2204],
          [-0.7749, -1.5189,  0.0349,  0.5796],
          [ 0.9347, -1.0274,  2.0106,  0.2294]],

         [[-0.2601, -0.2294, -0.4697, -1.3011],
          [-0.7359,  0.2582, -1.7443, -1.4112],
          [ 0.2663, -0.2478,  0.6486,  0.3078],
          [ 2.4497, -1.0616, -2.0192,  0.4455]],

         [[ 1.0837,  0.1428, -0.2025,  0.1203],
          [-0.3908, -0.2356,  0.9879,  0.4874],
          [-1.2784, -0.8443,  1.4261, -0.3821],
          [ 0.3224,  0.3406,  0.6424, -1.9446]]],


        [[[-1.8480,  0.2892, -0.5073, -0.6431],
          [-1.0819, -0.9472,  1.5423, -0.0196],
          [-1.5680,  0.0045, -0.0828,  0.6299],
          [ 0.7570, -1.0380,  0.7313, -1.1656]],

         [[-0.9150, -0.6224,  0.4106, -0.7184],
          [ 2.0652, -0.8757, -1.4853, -0.5003],
          [-1.0999, -0.9408, -0.9605,  0.1260],
          [-1.6409,  0.1670,  0.6775,  0.0262]],

         [[ 1.5063, -0.4165,

torch.Size([2, 4])

In [112]:
m_t = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
# for i in range(1):
#     m_t = m_t.unsqueeze(1)
m_t.shape

torch.Size([2, 4])

In [113]:
mask(a, mask_t, -1, 0)

RuntimeError: The size of tensor a (512) must match the size of tensor b (4) at non-singleton dimension 3

In [99]:
a.masked_fill(m_t, 0)

RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

In [20]:
a = torch.randn(2,3,3)
b = torch.randn(4, 3)
a, b

(tensor([[[-0.1185,  1.7025, -2.1981],
          [-1.9562,  1.0296,  0.1012],
          [ 1.6400, -1.7243,  0.6125]],
 
         [[ 1.5104, -1.4737, -1.2959],
          [ 0.9215, -0.1239,  0.9529],
          [ 0.2102, -0.7892,  0.0976]]]),
 tensor([[-0.5226, -0.5000, -0.6552],
         [-1.2526, -0.1788, -0.9412],
         [-0.5666,  1.1390, -0.6709],
         [ 1.1663,  0.4761, -0.3054]]))

In [21]:
c = a.unsqueeze(1) + b.unsqueeze(-1) + b.unsqueeze(-2)
c

tensor([[[[-1.1637e+00,  6.7989e-01, -3.3760e+00],
          [-2.9788e+00,  2.9596e-02, -1.0540e+00],
          [ 4.6214e-01, -2.8795e+00, -6.9789e-01]],

         [[-2.6237e+00,  2.7115e-01, -4.3920e+00],
          [-3.3876e+00,  6.7206e-01, -1.0188e+00],
          [-5.5386e-01, -2.8443e+00, -1.2699e+00]],

         [[-1.2516e+00,  2.2750e+00, -3.4355e+00],
          [-1.3837e+00,  3.3077e+00,  5.6933e-01],
          [ 4.0255e-01, -1.2562e+00, -7.2924e-01]],

         [[ 2.2142e+00,  3.3450e+00, -1.3372e+00],
          [-3.1375e-01,  1.9818e+00,  2.7192e-01],
          [ 2.5009e+00, -1.5536e+00,  1.7480e-03]]],


        [[[ 4.6511e-01, -2.4964e+00, -2.4737e+00],
          [-1.0114e-01, -1.1239e+00, -2.0235e-01],
          [-9.6761e-01, -1.9444e+00, -1.2128e+00]],

         [[-9.9484e-01, -2.9051e+00, -3.4897e+00],
          [-5.0988e-01, -4.8147e-01, -1.6714e-01],
          [-1.9836e+00, -1.9092e+00, -1.7849e+00]],

         [[ 3.7727e-01, -9.0125e-01, -2.5333e+00],
          [ 1.494

In [22]:
c.shape

torch.Size([2, 4, 3, 3])

In [18]:
b.transpose(1, 0)

tensor([[[ 0.5364,  0.0000,  0.0000],
         [-0.5761,  0.0000,  0.0000]],

        [[ 0.6233,  0.8007,  0.0000],
         [-0.8446,  1.0971,  0.0000]],

        [[ 1.2545,  0.2262,  0.0550],
         [-1.4512,  0.7378,  1.4985]]])

In [6]:
b[:, ::2] = -a[:, 1::2]
b[:, 1::2] = a[:, ::2]
b

tensor([[ 0.4895, -0.6885,  1.4711,  0.1186],
        [-1.5220,  0.7241,  1.3750,  0.4996]])

In [8]:
b[..., ::2]

tensor([[ 0.4895,  1.4711],
        [-1.5220,  1.3750]])

In [9]:
a = torch.randn(2,3,4)
a

tensor([[[-1.4303, -1.3988,  0.1014,  0.1428],
         [-0.0617, -0.3197,  0.5897, -0.2629],
         [ 0.8926,  1.7246,  0.0836, -0.2862]],

        [[-0.9023, -0.2006, -0.9303,  0.9663],
         [-1.7485,  1.5907, -1.2352,  1.4900],
         [ 1.5242, -1.2376, -0.3518, -1.1154]]])

In [11]:
a[..., ::2]

tensor([[[-1.4303,  0.1014],
         [-0.0617,  0.5897],
         [ 0.8926,  0.0836]],

        [[-0.9023, -0.9303],
         [-1.7485, -1.2352],
         [ 1.5242, -0.3518]]])