In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

class NERDataset(Dataset):
    """Map-style dataset of tokenized sentences and their NER tags.

    Each item is a pair of 1-D ``torch.long`` tensors of length ``max_len``:
    the word indices and the tag indices of one sentence. Shorter sentences
    are right-padded (words with ``word2idx['<PAD>']``, tags with
    ``tag2idx['O']``); longer ones are truncated to ``max_len`` so every item
    has the same length and can be batched by the default collate function.

    Args:
        text: Sequence of sentences, each a sequence of word tokens.
        tags: Sequence of tag sequences, parallel to ``text``.
        word2idx: Word-to-index mapping; must contain ``'<PAD>'`` and ``'<UNK>'``.
        tag2idx: Tag-to-index mapping; must contain ``'O'`` and every tag used.
        max_len: Fixed length of every returned tensor.

    Raises:
        ValueError: if ``text`` and ``tags`` have different lengths, or (on
            access) a sentence and its tag sequence differ in length.
    """

    def __init__(self, text, tags, word2idx, tag2idx, max_len):
        if len(text) != len(tags):
            raise ValueError(
                f"text has {len(text)} sentences but tags has {len(tags)}"
            )
        # The parameter keeps its original name ``text`` so keyword callers
        # still work; the previous body read an undefined name ``texts``.
        self.texts = text
        self.tags = tags
        self.word2idx = word2idx
        self.tag2idx = tag2idx
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        words = self.texts[idx]
        tags = self.tags[idx]
        if len(words) != len(tags):
            raise ValueError(
                f"item {idx}: {len(words)} words but {len(tags)} tags"
            )

        # Map tokens to indices, truncating to max_len; unknown words -> <UNK>.
        unk_id = self.word2idx['<UNK>']
        word_ids = [self.word2idx.get(w, unk_id) for w in words[:self.max_len]]
        tag_ids = [self.tag2idx[t] for t in tags[:self.max_len]]

        # Right-pad both sequences to exactly max_len.
        # NOTE(review): padded tag positions are labelled 'O', so a loss that
        # does not mask padding will train on them — confirm this is intended.
        n_pad = self.max_len - len(word_ids)
        word_ids += [self.word2idx['<PAD>']] * n_pad
        tag_ids += [self.tag2idx['O']] * n_pad

        # Explicit dtype: torch.tensor([]) would otherwise default to float32.
        return (
            torch.tensor(word_ids, dtype=torch.long),
            torch.tensor(tag_ids, dtype=torch.long),
        )


In [None]:
import torch.nn as nn

class NERLSTM(nn.Module):
    """Single-layer, unidirectional LSTM token tagger.

    Embeds word indices, runs them through an LSTM with ``batch_first=True``,
    and projects every time step to ``tagset_size`` unnormalized tag scores.

    Args:
        vocab_size: Number of rows in the embedding table.
        tagset_size: Number of tag scores produced per token.
        embedding_dim: Width of the word embeddings.
        hidden_dim: Width of the LSTM hidden state.
        padding_idx: Embedding index whose vector is kept at zero and receives
            no gradient. If ``None``, falls back to ``word2idx['<PAD>']`` from
            this module's globals when such a mapping exists (the notebook
            global the original code read); otherwise no padding index is set.
    """

    def __init__(self, vocab_size, tagset_size, embedding_dim=128, hidden_dim=128,
                 padding_idx=None):
        # The previous call, super(NERLSTM).__init__(), omitted `self`, so
        # nn.Module.__init__ never ran and the submodule assignments below failed.
        super().__init__()

        if padding_idx is None:
            # Backward compatibility: the original implementation silently
            # depended on a global `word2idx`; honour it only if it exists.
            vocab = globals().get('word2idx')
            if vocab is not None and '<PAD>' in vocab:
                padding_idx = vocab['<PAD>']

        # Stored as `embedding` because that is the name forward() uses.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)

        # LSTM over the embedded sequence; batch dimension first.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

        # Per-time-step projection from hidden state to tag scores.
        self.fc = nn.Linear(hidden_dim, tagset_size)

    def forward(self, x):
        """Return per-token tag scores.

        Args:
            x: Integer tensor of word indices, (batch_size, seq_len).

        Returns:
            Tensor of shape (batch_size, seq_len, tagset_size) holding
            unnormalized scores (no softmax is applied here).
        """
        embeds = self.embedding(x)
        lstm_out, _ = self.lstm(embeds)
        return self.fc(lstm_out)
