In [47]:
from bilstm_crf import BiLSTMCRF
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Load materials (embedding matrix, vocabulary, tag map, train/dev data)

In [56]:
import json

# load embedding
# Embedding matrix saved as .npy; it is passed to BiLSTMCRF further down.
embedding_maxtrix = np.load('embedding/embedding_matrix.npy')

# load vocab
# Explicit UTF-8 so the file decodes the same way regardless of the platform's
# default locale encoding.
with open('data/vocab.txt', 'r', encoding='utf-8') as f:
    # NOTE(review): split('\n') yields a trailing '' entry when the file ends
    # with a newline. Left unchanged on purpose so token ids and len(vocab)
    # (used as vocab_size for the model) stay exactly as before.
    vocab = f.read().split('\n')

# load tag_to_id
# Mapping from tag string to integer id, stored as JSON.
with open('data/tag_to_id.json', 'r', encoding='utf-8') as f:
    tag_to_id = json.load(f)

# load train and dev data
TRAIN_PATH = 'data/span_detection_datasets_IOB/train.json'
DEV_PATH = 'data/span_detection_datasets_IOB/dev.json'

with open(TRAIN_PATH, 'r', encoding='utf-8') as f:
    train_data = json.load(f)

with open(DEV_PATH, 'r', encoding='utf-8') as f:
    dev_data = json.load(f)

# Each split is a dict with 'text' and 'labels' entries that are themselves
# dicts; only their values are kept, in stored order, so sentence i and label i
# line up as long as both inner dicts share the same key order.
train_sentences = list(train_data['text'].values())
dev_sentences = list(dev_data['text'].values())

train_labels = list(train_data['labels'].values())
dev_labels = list(dev_data['labels'].values())

## Convert data

In [27]:
import numpy as np

# Convert data to ids
def convert_to_ids(data, vocab, max_len=256):
    """Map whitespace-tokenised sentences to arrays of vocabulary ids.

    Args:
        data: iterable of sentence strings; each one is split on whitespace.
        vocab: list of tokens. A token's id is the index of its first
            occurrence in the list. Must contain '<PAD>' and '<UNK>'.
        max_len: minimum length of each output array. Shorter sentences are
            right-padded with the '<PAD>' id; longer sentences are returned
            at their full length (no truncation is performed).

    Returns:
        A list with one ``np.ndarray`` of ids per input sentence.

    Raises:
        ValueError: if '<PAD>' or '<UNK>' is not present in ``vocab``.
    """
    pad_token_id = vocab.index('<PAD>')
    ukn_token_id = vocab.index('<UNK>')

    # Build the token -> id lookup once. The previous `word in vocab` followed
    # by `vocab.index(word)` scanned the whole list twice for every token.
    # setdefault keeps the first position of a duplicated token, which is
    # exactly what list.index returns.
    token_to_id = {}
    for position, token in enumerate(vocab):
        token_to_id.setdefault(token, position)

    id_data = []
    for sentence in data:
        # Out-of-vocabulary words fall back to the '<UNK>' id.
        ids = [token_to_id.get(word, ukn_token_id) for word in sentence.split()]

        if len(ids) < max_len:
            ids += [pad_token_id] * (max_len - len(ids))
        id_data.append(np.array(ids))

    return id_data

In [28]:
# Encode both splits with the shared vocabulary; max_len keeps its default.
train_tokenized = convert_to_ids(data=train_sentences, vocab=vocab)
dev_tokenized = convert_to_ids(data=dev_sentences, vocab=vocab)

In [29]:
# Wrap every per-sentence id array in a LongTensor. The names are rebound in
# place because the training loop reads `train_tokenized`.
train_tokenized = list(map(torch.LongTensor, train_tokenized))
dev_tokenized = list(map(torch.LongTensor, dev_tokenized))

In [64]:
# load tag_to_id
# Same file as the one read in the loading cell; re-read here so this cell can
# be run on its own.
with open('data/tag_to_id.json', 'r') as f:
    tag_to_id = json.load(f)

# Convert labels to ids
# label = [tag, tag, ...]  (one tag per token; every tag must be a key of tag_to_id)
def convert_labels_to_ids(label, tag_to_id, max_len=256):
    """Encode one tag sequence as an array of tag ids.

    Args:
        label: iterable of tags; each tag is looked up directly in
            ``tag_to_id``.
        tag_to_id: mapping from tag to integer id. The 'O' entry is used as
            the padding value.
        max_len: minimum output length. Shorter sequences are right-padded
            with ``tag_to_id['O']``; longer sequences are returned unchanged
            (no truncation).

    Returns:
        ``np.ndarray`` holding the encoded (and possibly padded) sequence.

    Raises:
        KeyError: if a tag is missing from ``tag_to_id``, or if padding is
            needed and 'O' is missing.
    """
    encoded = []
    for tag in label:
        encoded.append(tag_to_id[tag])

    shortfall = max_len - len(encoded)
    if shortfall > 0:
        # Pad on the right with the id of the 'O' tag.
        encoded.extend([tag_to_id['O']] * shortfall)

    return np.array(encoded)

In [65]:
# Encode every tag sequence of each split, then hand the list of per-sentence
# arrays to np.array. NOTE(review): this only yields a regular 2-D array when
# all encoded sequences end up the same length (i.e. none exceeds max_len).
train_labels_encoding = np.array(
    [convert_labels_to_ids(tags, tag_to_id) for tags in train_labels]
)
dev_labels_encoding = np.array(
    [convert_labels_to_ids(tags, tag_to_id) for tags in dev_labels]
)

# Model

In [68]:
# Instantiate the tagger. The keyword `embedding_maxtrix` is spelled exactly as
# in the original call; presumably it matches BiLSTMCRF's own parameter name.
span_detection_model = BiLSTMCRF(
    vocab_size=len(vocab),
    tag_to_ix=tag_to_id,
    hidden_dim=200,
    embedding_maxtrix=embedding_maxtrix,
)

In [None]:
# train
# SGD over all model parameters, one sentence per update step.
optimizer = optim.SGD(span_detection_model.parameters(), lr=0.01, weight_decay=1e-4)
epoch_num = 10

for epoch in range(epoch_num):
    print('epoch: {}'.format(epoch))
    for i, sentence in enumerate(train_tokenized):
        # Label row i belongs to tokenized sentence i.
        label = train_labels_encoding[i]

        # Clear gradients left over from the previous step before backprop.
        span_detection_model.zero_grad()
        loss = span_detection_model.neg_log_likelihood(sentence, label)
        loss.backward()
        optimizer.step()

        # Report the loss of the current sentence every 100 steps.
        if i % 100 == 0:
            print('loss: {}'.format(loss.item()))

    # dev test
    # NOTE(review): the disabled block below would not work if re-enabled as is:
    # it calls loss.backward() and optimizer.step() inside torch.no_grad(), and
    # it iterates dev_sentences (raw strings) instead of dev_tokenized.
    # with torch.no_grad():
    #     for i in range(len(dev_sentences)):
    #         sentence = dev_sentences[i]
    #         label = dev_labels_encoding[i]
    #         span_detection_model.zero_grad()
    #         loss = span_detection_model.neg_log_likelihood(sentence, label)
    #         loss.backward()
    #         optimizer.step()
    #         if i % 100 == 0:
    #             print('loss: {}'.format(loss.item()))

# save model
# torch.save(span_detection_model.state_dict(), 'model/span_detection_model.pkl')