- date: 2021-02-04 10:17:17
- author: Jerry Su
- slug: Conditional-Random-Field
- title: ConditionalRandomField
- category: 
- tags: Deep Learning, Pytorch, NLP

In [36]:
import torch
from torch import nn

In [144]:
# BMES四位序列标注法
PADDING = 0
B = 2
E = 3
S = 4
M = 5
START = 6
END = 7

LABEL_VOCAB = {0: '<pad>', 1: '<unk>', 2: 'B', 3: 'E', 4: 'S', 5: 'M'}

NUM_TAGS = 8

## 1.1 Initialize Input

In [145]:
# batch_size x max_len x num_tags : (2, 9, 6)

logits = torch.randn(3, 9, 6)
print(f"logits:\n {logits}\n")
tags = torch.tensor([[6, 4, 2, 5, 5, 3, 4, 7, 0],
                     [6, 4, 2, 3, 4, 4, 7, 0, 0],
                     [6, 2, 5, 3, 4, 7, 0, 0, 0]])
print(f"tags:\n {tags}\n")
mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0],
                     [1, 1, 1, 1, 1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1, 1, 0, 0, 0]])
print(f"mask:\n {mask}")

logits:
 tensor([[[ 1.0642e+00,  1.3428e+00, -2.9033e-01,  1.2010e+00,  1.9096e+00,
          -7.4157e-01],
         [ 5.9740e-01, -1.7286e-01,  7.1238e-01, -1.1167e+00,  3.8874e-01,
          -2.3098e-01],
         [-5.7265e-02,  9.0805e-01, -1.3484e+00,  8.4410e-01,  7.1468e-01,
          -8.3558e-01],
         [ 3.1047e-01, -1.8642e+00,  1.5984e-01,  1.3603e+00, -2.2918e+00,
          -2.9508e-01],
         [ 6.9469e-01,  3.5199e-01, -1.1742e+00,  1.1912e+00, -9.1062e-01,
          -5.7808e-01],
         [ 5.8577e-01, -1.2518e+00,  1.6601e-01, -7.7357e-01,  1.3718e+00,
          -1.1598e+00],
         [ 5.3231e-01,  4.9077e-01, -2.2624e-02, -1.1507e+00, -2.0269e+00,
           1.0734e+00],
         [-4.8735e-02,  5.8399e-01, -1.2517e-01, -6.9643e-01, -6.7877e-01,
          -1.7191e+00],
         [ 8.2640e-01, -8.9425e-01, -1.7231e-01,  2.9305e-01,  1.0767e+00,
           2.7385e-01]],

        [[-3.6398e-01,  8.9014e-01,  6.6145e-01,  4.8513e-01, -1.1231e-01,
           2.6683e-01],

In [146]:
logits = logits.transpose(0, 1)
tags = tags.transpose(0, 1).long()
mask = mask.transpose(0, 1).float()

In [147]:
trans_matrix = nn.Parameter(torch.randn(NUM_TAGS, NUM_TAGS))
trans_matrix

Parameter containing:
tensor([[ 0.6696, -0.4195,  0.4738,  0.0122, -0.8303,  0.0965,  1.9076, -1.9062],
        [ 0.9474,  0.7148,  1.0506,  0.0498,  0.0609, -0.6184,  1.4469, -0.5844],
        [-1.6855,  1.0079,  0.7865,  1.1741,  0.3916,  0.6720,  1.5643, -1.2791],
        [-0.7378, -0.5273,  1.1768,  0.6547, -1.2442,  0.0211,  1.2717, -0.5409],
        [ 1.7773,  0.1156,  1.4798, -0.7249,  0.0326,  0.5933, -1.1687,  0.3685],
        [-1.4618,  0.5730,  0.5229, -1.3505,  0.3255,  1.2817,  1.3270, -0.4554],
        [-0.6263, -1.0785,  1.3931,  0.0338,  0.3113,  0.2079, -1.0331,  0.3655],
        [ 3.1040, -1.7638,  0.3538,  0.0175,  0.0066, -0.6321, -1.5344,  2.1579]],
       requires_grad=True)

## 2. Compute the score for the gold path.

In [148]:
seq_len, batch_size, _ = logits.size()
print(f"seq_len: {seq_len}\nbatch_size: {batch_size}")

seq_len: 9
batch_size: 3


In [149]:
batch_idx = torch.arange(batch_size, dtype=torch.long)
batch_idx

tensor([0, 1, 2])

In [150]:
seq_idx = torch.arange(seq_len, dtype=torch.long)
seq_idx

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

### 2.1 transition probability score

In [151]:
mask

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 0.],
        [1., 0., 0.],
        [0., 0., 0.]])

In [152]:
mask = mask.eq(True)
mask

tensor([[ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True, False],
        [ True, False, False],
        [False, False, False]])

In [153]:
flip_mask = mask.eq(False)
flip_mask

tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False,  True],
        [False,  True,  True],
        [ True,  True,  True]])

In [154]:
tags

tensor([[6, 6, 6],
        [4, 4, 2],
        [2, 2, 5],
        [5, 3, 3],
        [5, 4, 4],
        [3, 4, 7],
        [4, 7, 0],
        [7, 0, 0],
        [0, 0, 0]])

In [155]:
tags[: seq_len -1]

tensor([[6, 6, 6],
        [4, 4, 2],
        [2, 2, 5],
        [5, 3, 3],
        [5, 4, 4],
        [3, 4, 7],
        [4, 7, 0],
        [7, 0, 0]])

In [156]:
tags[1:]

tensor([[4, 4, 2],
        [2, 2, 5],
        [5, 3, 3],
        [5, 4, 4],
        [3, 4, 7],
        [4, 7, 0],
        [7, 0, 0],
        [0, 0, 0]])

In [157]:
trans_matrix

Parameter containing:
tensor([[ 0.6696, -0.4195,  0.4738,  0.0122, -0.8303,  0.0965,  1.9076, -1.9062],
        [ 0.9474,  0.7148,  1.0506,  0.0498,  0.0609, -0.6184,  1.4469, -0.5844],
        [-1.6855,  1.0079,  0.7865,  1.1741,  0.3916,  0.6720,  1.5643, -1.2791],
        [-0.7378, -0.5273,  1.1768,  0.6547, -1.2442,  0.0211,  1.2717, -0.5409],
        [ 1.7773,  0.1156,  1.4798, -0.7249,  0.0326,  0.5933, -1.1687,  0.3685],
        [-1.4618,  0.5730,  0.5229, -1.3505,  0.3255,  1.2817,  1.3270, -0.4554],
        [-0.6263, -1.0785,  1.3931,  0.0338,  0.3113,  0.2079, -1.0331,  0.3655],
        [ 3.1040, -1.7638,  0.3538,  0.0175,  0.0066, -0.6321, -1.5344,  2.1579]],
       requires_grad=True)

In [158]:
tags

tensor([[6, 6, 6],
        [4, 4, 2],
        [2, 2, 5],
        [5, 3, 3],
        [5, 4, 4],
        [3, 4, 7],
        [4, 7, 0],
        [7, 0, 0],
        [0, 0, 0]])

In [159]:
trans_score = trans_matrix[tags[:seq_len - 1], tags[1:]]
trans_score

tensor([[ 0.3113,  0.3113,  1.3931],
        [ 1.4798,  1.4798,  0.6720],
        [ 0.6720,  1.1741, -1.3505],
        [ 1.2817, -1.2442, -1.2442],
        [-1.3505,  0.0326,  0.3685],
        [-1.2442,  0.3685,  3.1040],
        [ 0.3685,  3.1040,  0.6696],
        [ 3.1040,  0.6696,  0.6696]], grad_fn=<IndexBackward>)

In [160]:
flip_mask

tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False,  True],
        [False,  True,  True],
        [ True,  True,  True]])

In [161]:
flip_mask[1:, :]

tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False,  True],
        [False,  True,  True],
        [ True,  True,  True]])

In [162]:
trans_score = trans_score.masked_fill(flip_mask[1:, :], 0)
trans_score

tensor([[ 0.3113,  0.3113,  1.3931],
        [ 1.4798,  1.4798,  0.6720],
        [ 0.6720,  1.1741, -1.3505],
        [ 1.2817, -1.2442, -1.2442],
        [-1.3505,  0.0326,  0.3685],
        [-1.2442,  0.3685,  0.0000],
        [ 0.3685,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000]], grad_fn=<MaskedFillBackward0>)

### 2.2 emission probability score

In [164]:
# emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0)
# emit_score

In [167]:
logits.size()

torch.Size([9, 3, 6])

In [168]:
seq_idx

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

In [169]:
seq_idx.size()

torch.Size([9])

In [170]:
seq_idx.view(-1, 1)

tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7],
        [8]])

In [171]:
batch_idx

tensor([0, 1, 2])

In [173]:
batch_idx.size()

torch.Size([3])

In [175]:
batch_idx.view(1, -1).size()

torch.Size([1, 3])

In [176]:
batch_idx.view(1, -1)

tensor([[0, 1, 2]])

In [177]:
tags

tensor([[6, 6, 6],
        [4, 4, 2],
        [2, 2, 5],
        [5, 3, 3],
        [5, 4, 4],
        [3, 4, 7],
        [4, 7, 0],
        [7, 0, 0],
        [0, 0, 0]])

In [178]:
logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags]

IndexError: index 6 is out of bounds for dimension 2 with size 6