In [1]:
import torch
import torch.optim as optim
from tqdm import tqdm
from model import BiLSTM_CRF
from utils import prepare, predict

In [2]:
hparams = {
    'path':'/home/peitian_zhang/Data/NER/labeled_train.txt',
    'epochs': 150,
    'batch_size': 100,
    'embedding_dim': 300,
    'hidden_dim': 150,
    'device':'cuda:0',
}

In [3]:
tag2idx, vocab, loader = prepare(hparams)
hparams['vocab_size'] = len(vocab)
hparams['seq_length'] = loader.dataset.max_length
tag2idx

{'<START>': 0,
 '<END>': 1,
 '<PAD>': 2,
 'O': 3,
 'B-C': 4,
 'I-C': 5,
 'B-A': 6,
 'I-A': 7,
 'B-O': 8,
 'I-O': 9,
 'B-M': 10,
 'I-M': 11,
 'B-P': 12,
 'I-P': 13,
 'B-N': 14,
 'I-N': 15,
 'B-D': 16,
 'I-D': 17,
 'B-S': 18,
 'I-S': 19,
 'B-L': 20,
 'I-L': 21}

In [4]:
model = BiLSTM_CRF(hparams, tag2idx).to(hparams['device'])

In [5]:
with torch.no_grad():
    record = next(iter(loader))
    _, tag_seq = model(record['token'])
    print(tag_seq, record['label'])

tensor([[ 9, 11, 10,  ..., 11, 10,  8],
        [ 9, 11, 10,  ..., 11, 10,  8],
        [ 9, 11, 10,  ...,  8,  6,  8],
        ...,
        [ 9, 11, 10,  ..., 11, 10,  8],
        [ 9, 11, 10,  ..., 11, 10,  8],
        [ 9, 11, 10,  ..., 11, 10,  8]], device='cuda:0') tensor([[ 3,  3,  3,  ...,  2,  2,  2],
        [ 3,  3,  3,  ...,  2,  2,  2],
        [14, 15, 15,  ...,  2,  2,  2],
        ...,
        [ 3,  3,  3,  ...,  2,  2,  2],
        [ 8,  9,  9,  ...,  2,  2,  2],
        [16, 17, 17,  ...,  2,  2,  2]])


In [6]:
optimizer = optim.Adam(model.parameters(),lr=0.001)
for epoch in range(hparams['epochs']):
    tqdm_ = tqdm(enumerate(loader))
    total_loss = 0

    for step,x in tqdm_:
        model.zero_grad()
        loss = model.neg_log_likelihood(x)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        tqdm_.set_description("epoch {:d} , step {:d} , loss: {:.4f}".format(epoch+1, step, total_loss/(step+1)))

epoch 1 , step 9 , loss: 1421.0073: : 10it [00:03,  3.29it/s]
epoch 2 , step 9 , loss: 899.5295: : 10it [00:02,  3.37it/s]
epoch 3 , step 9 , loss: 295.0562: : 10it [00:02,  3.37it/s]
epoch 4 , step 9 , loss: -333.2220: : 10it [00:02,  3.36it/s]
epoch 5 , step 9 , loss: -970.6711: : 10it [00:02,  3.44it/s]
epoch 6 , step 9 , loss: -1623.1125: : 10it [00:02,  3.38it/s]
epoch 7 , step 9 , loss: -2296.8902: : 10it [00:02,  3.34it/s]
epoch 8 , step 9 , loss: -2991.2155: : 10it [00:02,  3.35it/s]
epoch 9 , step 9 , loss: -3682.3815: : 10it [00:02,  3.41it/s]
epoch 10 , step 9 , loss: -4383.0485: : 10it [00:02,  3.38it/s]
epoch 11 , step 9 , loss: -5080.5594: : 10it [00:02,  3.43it/s]
epoch 12 , step 9 , loss: -5774.6074: : 10it [00:02,  3.43it/s]
epoch 13 , step 9 , loss: -6459.7176: : 10it [00:02,  3.39it/s]
epoch 14 , step 9 , loss: -7135.3322: : 10it [00:02,  3.47it/s]
epoch 15 , step 9 , loss: -7804.2098: : 10it [00:03,  3.27it/s]
epoch 16 , step 9 , loss: -8465.2892: : 10it [00:02,  3.

In [7]:
with torch.no_grad():
    record = next(iter(loader))
    _, tag_seq = model(record['token'])
    print("Prediction:{}\n Ground Truth:{}".format(tag_seq, record['label']))

Prediction:tensor([[ 3,  3,  3,  ...,  2,  2,  2],
        [ 3,  3,  3,  ...,  2,  2,  2],
        [ 5, 15,  3,  ...,  2,  2,  2],
        ...,
        [ 3,  3,  3,  ...,  2,  2,  2],
        [ 9,  9,  9,  ...,  2,  2,  2],
        [12, 16, 17,  ...,  2,  2,  2]], device='cuda:0')
 Ground Truth:tensor([[ 3,  3,  3,  ...,  2,  2,  2],
        [ 3,  3,  3,  ...,  2,  2,  2],
        [14, 15, 15,  ...,  2,  2,  2],
        ...,
        [ 3,  3,  3,  ...,  2,  2,  2],
        [ 8,  9,  9,  ...,  2,  2,  2],
        [16, 17, 17,  ...,  2,  2,  2]])


In [8]:
predict(['窦志成获奖'],model,vocab)

[['O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<P