In [1]:
import torch
import torch.optim as optim
from tqdm import tqdm
from model import BiLSTM_CRF
from utils import prepare, predict

In [2]:
# Hyper-parameters and runtime configuration passed to prepare() and BiLSTM_CRF.
hparams = {
    # Labelled training corpus.
    # NOTE(review): machine-specific absolute path — adjust for your environment.
    'path':'/home/peitian_zhang/Data/NER/labeled_train.txt',
    'epochs': 200,  # number of passes over the loader in the training cell
    'batch_size': 100,
    'embedding_dim': 300,
    'hidden_dim': 150,
    # Prefer the first GPU, but fall back to CPU so the notebook still runs
    # on machines without CUDA instead of failing at `.to(device)`.
    'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',
}

In [3]:
# Build the tag -> index mapping, the vocabulary and the data loader from the config.
tag2idx, vocab, loader = prepare(hparams)

# Store the data-derived sizes in the config; hparams is handed to BiLSTM_CRF
# when the model is constructed.
hparams.update(
    vocab_size=len(vocab),
    seq_length=loader.dataset.max_length,
)

tag2idx  # display the tag -> index mapping

{'<START>': 0,
 '<END>': 1,
 '<PAD>': 2,
 'O': 3,
 'B-C': 4,
 'I-C': 5,
 'B-A': 6,
 'I-A': 7,
 'B-O': 8,
 'I-O': 9,
 'B-M': 10,
 'I-M': 11,
 'B-P': 12,
 'I-P': 13,
 'B-N': 14,
 'I-N': 15,
 'B-D': 16,
 'I-D': 17,
 'B-S': 18,
 'I-S': 19,
 'B-L': 20,
 'I-L': 21}

In [4]:
# Instantiate the tagger from the config and tag mapping,
# then move it to the configured device.
model = BiLSTM_CRF(hparams, tag2idx)
model = model.to(hparams['device'])

In [5]:
# Sanity check before training: run the untrained model on the first batch and
# print its predicted tag ids next to the gold label ids.
with torch.no_grad():  # no gradients needed for a forward-only check
    record = next(iter(loader))
    # The model returns a pair; only the second element (the tag sequence) is used.
    _, tag_seq = model(record['token'])
    gold = record['label']
    print("Prediction:{}\n Ground Truth:{}".format(tag_seq, gold))

tensor([[21, 20, 20,  ...,  3, 21, 13],
        [21,  3,  3,  ...,  3, 21, 13],
        [21,  3,  3,  ...,  3, 21, 13],
        ...,
        [21, 20, 20,  ...,  3, 21, 13],
        [21, 20, 20,  ...,  3, 21, 13],
        [21, 20, 20,  ...,  3, 21, 13]], device='cuda:0') tensor([[ 3,  3,  3,  ...,  2,  2,  2],
        [ 3,  3,  3,  ...,  2,  2,  2],
        [14, 15, 15,  ...,  2,  2,  2],
        ...,
        [ 3,  3,  3,  ...,  2,  2,  2],
        [ 8,  9,  9,  ...,  2,  2,  2],
        [16, 17, 17,  ...,  2,  2,  2]])


In [6]:
# Train the model with Adam by minimising model.neg_log_likelihood over every
# batch, for hparams['epochs'] passes over the loader.
#
# NOTE(review): in the recorded run the reported loss becomes negative from
# epoch 4 onwards and keeps decreasing. A correctly normalised negative
# log-likelihood cannot drop below zero, so this points at a problem inside
# neg_log_likelihood (possibly how padded positions or the partition term are
# handled) — verify in model.py before trusting the trained weights.
optimizer = optim.Adam(model.parameters(),lr=0.001)
model.train()  # ensure training mode even if an earlier cell switched to eval
for epoch in range(hparams['epochs']):
    # Wrap the loader itself rather than enumerate(loader): enumerate has no
    # length, which left tqdm with a bare "10it" counter and no total or ETA.
    tqdm_ = tqdm(loader)
    total_loss = 0

    for step,x in enumerate(tqdm_):
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        loss = model.neg_log_likelihood(x)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Show the running mean loss of the current epoch.
        tqdm_.set_description("epoch {:d} , step {:d} , loss: {:.4f}".format(epoch+1, step, total_loss/(step+1)))

epoch 1 , step 9 , loss: 1358.4268: : 10it [00:03,  3.21it/s]
epoch 2 , step 9 , loss: 818.9688: : 10it [00:03,  3.20it/s]
epoch 3 , step 9 , loss: 233.4179: : 10it [00:03,  3.17it/s]
epoch 4 , step 9 , loss: -379.6076: : 10it [00:02,  3.34it/s]
epoch 5 , step 9 , loss: -1000.3245: : 10it [00:03,  3.14it/s]
epoch 6 , step 9 , loss: -1653.9761: : 10it [00:02,  3.42it/s]
epoch 7 , step 9 , loss: -2331.6022: : 10it [00:02,  3.41it/s]
epoch 8 , step 9 , loss: -3020.8292: : 10it [00:03,  3.25it/s]
epoch 9 , step 9 , loss: -3733.6577: : 10it [00:03,  3.09it/s]
epoch 10 , step 9 , loss: -4429.4866: : 10it [00:02,  3.34it/s]
epoch 11 , step 9 , loss: -5124.9079: : 10it [00:02,  3.41it/s]
epoch 12 , step 9 , loss: -5812.1964: : 10it [00:02,  3.45it/s]
epoch 13 , step 9 , loss: -6489.9605: : 10it [00:03,  3.09it/s]
epoch 14 , step 9 , loss: -7159.4641: : 10it [00:03,  3.06it/s]
epoch 15 , step 9 , loss: -7822.5516: : 10it [00:03,  3.05it/s]
epoch 16 , step 9 , loss: -8479.5468: : 10it [00:02,  3

In [7]:
# Post-training check: decode the first batch again and print the predicted tag
# ids next to the gold label ids.
# NOTE(review): the model is not switched to eval mode here; if BiLSTM_CRF
# contains dropout or similar layers, call model.eval() first — confirm.
# NOTE(review): the recorded output shows only tag id 3 ('O') in the visible
# predictions, so the trained model looks degenerate — investigate together
# with the training loss.
template = "Prediction:{}\n Ground Truth:{}"
with torch.no_grad():
    record = next(iter(loader))
    # Only the second element of the model's return value is printed.
    _, tag_seq = model(record['token'])
    print(template.format(tag_seq, record['label']))

Prediction:tensor([[3, 3, 3,  ..., 3, 3, 3],
        [3, 3, 3,  ..., 3, 3, 3],
        [3, 3, 3,  ..., 3, 3, 3],
        ...,
        [3, 3, 3,  ..., 3, 3, 3],
        [3, 3, 3,  ..., 3, 3, 3],
        [3, 3, 3,  ..., 3, 3, 3]], device='cuda:0')
 Ground Truth:tensor([[ 3,  3,  3,  ...,  2,  2,  2],
        [ 3,  3,  3,  ...,  2,  2,  2],
        [14, 15, 15,  ...,  2,  2,  2],
        ...,
        [ 3,  3,  3,  ...,  2,  2,  2],
        [ 8,  9,  9,  ...,  2,  2,  2],
        [16, 17, 17,  ...,  2,  2,  2]])


In [8]:
# Tag a single raw sentence with the trained model and display the result.
# NOTE(review): the recorded output is far longer than this 5-character input,
# so predict() appears to return tags for the whole padded sequence — consider
# trimming to the input length; confirm against utils.predict.
demo_sentences = ['窦志成获奖']
predict(demo_sentences, model, vocab)

[['B-M',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O