# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from torchtext.datasets import SequenceTaggingDataset
from torchtext.vocab import Vocab
from torchtext.data import Field, BucketIterator

import numpy as np

# Fields: words are lower-cased; both fields wrap every sequence in the
# '<bos>' / '<eos>' marker tokens
marker_tokens = dict(init_token='<bos>', eos_token='<eos>')
word_field = Field(lower=True, **marker_tokens)
label_field = Field(**marker_tokens)

# Read the train / validation / test splits from the current directory,
# binding the example attributes 'words' and 'labels' to the two fields
column_fields = [('words', word_field), ('labels', label_field)]
train_data, val_data, test_data = SequenceTaggingDataset.splits(
    fields=column_fields,
    path='./',
    train='train.txt',
    validation='val.txt',
    test='test.txt',
)

# Vocabularies are built from the training split only
for field in (word_field, label_field):
    field.build_vocab(train_data)

# Use the GPU when one is available, otherwise fall back to the CPU
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# Define the model
class BiLSTM_CRF(nn.Module):
    def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim):
        super(BiLSTM_CRF, self).__init__()
        
        # Embedding layer
        self.embedding_dim = embedding_dim
        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        
        # BiLSTM layer
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, num_layers=1, bidirectional=True)
        
        # Linear layer to project LSTM outputs to tag space
        self.hidden2tag = nn.Linear(hidden_dim, len(tag_to_ix))
        
        # CRF layer
        self.transitions = nn.Parameter(torch.randn(len(tag_to_ix), len(tag_to_ix)))
        self.transitions.data[tag_to_ix['<bos>'], :] = -10000.
        self.transitions.data[:, tag_to_ix['<eos>']] = -10000.
        self.tag_to_ix = tag_to_ix
        
    def forward(self, sentence):
        # Embed the input sentence
        embeds = self.word_embeds(sentence)
        
        # Pass the embeddings through the BiLSTM layer
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
        
        # Project the LSTM outputs to tag space
        tag_space = self.hidden2tag(lstm_out)
        
        return tag_space
    
    def _score_sentence(self, feats, tags):
        score = torch.zeros(1).to(device)
        tags = torch.cat([torch.tensor([self.tag_to_ix['<bos>']], dtype=torch.long).to(device), tags])
        for i, feat in enumerate(feats):
            score = score + self.transitions[tags[i+1], tags[i]] + feat[tags[i+1]]
        score = score + self.transitions[self.tag_to_ix['<eos>'], tags[-1]]
        return score
    
    def _forward_alg(self, feats):
        init_alphas = torch.full((1, len(self.tag_to_ix)), -10000.).to(device)
        init_alphas[0][self.tag_to_ix['<bos>']] = 0.
        
        forward_var = init_alphas
        
        for feat in feats:
            alphas_t = []
            for next_tag in range(len(self.tag_to_ix)):
                emit_score = feat[next_tag].view