In [1]:
import torch
import torch.optim as optim
from tqdm import tqdm
from model import BiLSTM_CRF
from utils import prepare

In [2]:
hparams = {
    'epochs': 3,
    'batch_size': 10,
    'embedding_dim': 300,
    'hidden_dim': 150,
    'device':'cuda:0'
}

In [3]:
tag2idx, vocab, loader = prepare(hparams)
hparams['vocab_size'] = len(vocab)
hparams['seq_length'] = loader.dataset.max_length

In [4]:
model = BiLSTM_CRF(hparams, tag2idx).to(hparams['device'])

In [9]:
with torch.no_grad():
    record = next(iter(loader))
    print(record['sentence'])
    print(model(record), record['label'])

tensor([[  2,   3,   2,  ...,   0,   0,   0],
        [119, 120,  51,  ...,   0,   0,   0],
        [ 43,  44,  45,  ...,   0,   0,   0],
        ...,
        [ 52, 159, 160,  ...,   0,   0,   0],
        [159, 160, 161,  ...,   0,   0,   0],
        [158,  46, 288,  ...,   0,   0,   0]], dtype=torch.int32)
(tensor([[-22697.1719],
        [-22725.3008],
        [-22712.2441],
        [-22714.0488],
        [-22678.3496],
        [-22704.0117],
        [-22661.1758],
        [-22708.7305],
        [-22665.4023],
        [-22712.5820]], device='cuda:0'), [tensor([[20],
        [20],
        [20],
        [20],
        [20],
        [20],
        [20],
        [20],
        [19],
        [20]], device='cuda:0'), tensor([[ 3],
        [ 3],
        [ 3],
        [ 3],
        [ 3],
        [ 3],
        [ 3],
        [ 3],
        [11],
        [ 3]], device='cuda:0'), tensor([[11],
        [11],
        [11],
        [11],
        [11],
        [11],
        [11],
        [11],
        [2

In [8]:
optimizer = optim.Adam(model.parameters(),lr=0.001)
for epoch in range(hparams['epochs']):
    tqdm_ = tqdm(enumerate(loader))
    total_loss = 0

    for step,x in tqdm_:
        model.zero_grad()
        loss = model.neg_log_likelihood(x)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        tqdm_.set_description("epoch {:d} , step {:d} , loss: {:.4f}".format(epoch+1, step, total_loss/(step+1)))

epoch 1 , step 97 , loss: -21724.5330: : 98it [00:26,  3.65it/s]
epoch 2 , step 97 , loss: -101667.9718: : 98it [00:26,  3.76it/s]
epoch 3 , step 97 , loss: -177464.2704: : 98it [00:28,  3.48it/s]


In [10]:
tag2idx

{'<START>': 0,
 '<END>': 1,
 '<PAD>': 2,
 'O': 3,
 'B-C': 4,
 'I-C': 5,
 'B-A': 6,
 'I-A': 7,
 'B-O': 8,
 'I-O': 9,
 'B-M': 10,
 'I-M': 11,
 'B-P': 12,
 'I-P': 13,
 'B-N': 14,
 'I-N': 15,
 'B-D': 16,
 'I-D': 17,
 'B-S': 18,
 'I-S': 19,
 'B-L': 20,
 'I-L': 21}