In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class BiLSTM(nn.Module):
    def __init__(self, vocab_size, tag_size, embedding_dim, hidden_dim):
        super(BiLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(hidden_dim, tag_size)

    def forward(self, x):
        embeds = self.embedding(x)
        lstm_out, _ = self.lstm(embeds)
        tag_space = self.fc(lstm_out)
        tag_scores = nn.functional.log_softmax(tag_space, dim=2)
        return tag_scores

class SequenceTaggingDataset(Dataset):
    def __init__(self, sentences, tags, word_to_ix, tag_to_ix):
        self.sentences = sentences
        self.tags = tags
        self.word_to_ix = word_to_ix
        self.tag_to_ix = tag_to_ix

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, index):
        sentence = self.sentences[index]
        tag = self.tags[index]
        sentence_ix = [self.word_to_ix[w] for w in sentence]
        tag_ix = [self.tag_to_ix[t] for t in tag]
        return sentence_ix, tag_ix

# sample data
sentences = [['The', 'cat', 'is', 'on', 'the', 'mat'], ['The', 'dog', 'is', 'in', 'the', 'house']]
tags = [['DET', 'NOUN', 'VERB', 'PREP', 'DET', 'NOUN'], ['DET', 'NOUN', 'VERB', 'PREP', 'DET', 'NOUN']]
word_to_ix = {'The': 0, 'cat': 1, 'is': 2, 'on': 3, 'the': 4, 'mat': 5, 'dog': 6, 'in': 7, 'house': 8}
tag_to_ix = {'DET': 0, 'NOUN': 1, 'VERB': 2, 'PREP': 3}

# set up dataset and dataloader
dataset = SequenceTaggingDataset(sentences, tags, word_to_ix, tag_to_ix)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# set up model and optimizer
model = BiLSTM(len(word_to_ix), len(tag_to_ix), 16, 16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# train model
for epoch in range(100):
    for batch_sentence, batch_tag in dataloader:
        optimizer.zero_grad()
        tag_scores = model(torch.FloatTensor(batch_sentence))
        loss = nn.functional.nll_loss(tag_scores.view(-1, len(tag_to_ix)), torch.tensor(batch_tag).view(-1))
        loss.backward()
        optimizer.step()
    if epoch % 10 == 0:
        print('Epoch:', epoch, 'Loss:', loss.item())

ValueError: only one element tensors can be converted to Python scalars