In [4]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from sklearn.metrics import classification_report
from tqdm import tqdm
import pickle

import warnings
warnings.filterwarnings('ignore')

In [5]:
def build_vocab(file_path):
    word2idx = defaultdict(lambda: len(word2idx))
    tag2idx = defaultdict(lambda: len(tag2idx))
    word2idx["<PAD>"] = 0
    # tag2idx["O"] = 0
    with open(file_path, 'r') as f:
        for line in f:
            if line.strip():
                parts = line.strip().split(' ')
                word, tag = parts[0], parts[3]
                word2idx[word]
                tag2idx[tag]
    return dict(word2idx), dict(tag2idx)

In [6]:
class NERDataset(Dataset):
    def __init__(self, file_path, word2idx, tag2idx, max_len):
        self.sentences, self.labels = self._read_data(file_path)
        self.word2idx = word2idx
        self.tag2idx = tag2idx
        self.max_len = max_len

    def _read_data(self, file_path):
        sentences, labels = [], []
        sentence, label = [], []
        with open(file_path, 'r') as f:
            for line in f:
                if line.strip() == "":
                    if sentence:
                        sentences.append(sentence)
                        labels.append(label)
                        sentence, label = [], []
                else:
                    parts = line.strip().split()
                    word, tag = parts[0], parts[3]
                    sentence.append(word)
                    label.append(tag)
        if sentence:
            sentences.append(sentence)
            labels.append(label)
        return sentences, labels

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        label = self.labels[idx]

        # Convert to indices
        word_indices = [self.word2idx.get(w, self.word2idx["<PAD>"]) for w in sentence]
        tag_indices = [self.tag2idx.get(t, self.tag2idx["O"]) for t in label]

        # Pad sequences
        word_indices = word_indices[:self.max_len] + [self.word2idx["<PAD>"]] * (self.max_len - len(word_indices))
        tag_indices = tag_indices[:self.max_len] + [self.tag2idx["O"]] * (self.max_len - len(tag_indices))

        return torch.tensor(word_indices), torch.tensor(tag_indices)


In [8]:
class BiLSTMNER(nn.Module):
    def __init__(self, vocab_size, tag_size, embed_dim, hidden_dim):
        super(BiLSTMNER, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, tag_size)

    def forward(self, x):
        embeds = self.embedding(x)
        lstm_out, _ = self.lstm(embeds)
        logits = self.fc(lstm_out)
        return logits

In [9]:
def train_model(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for words, tags in tqdm(loader, desc="Training", unit="batch"):
        words, tags = words.to(device), tags.to(device)
        optimizer.zero_grad()
        outputs = model(words)
        # Flatten for loss computation
        outputs = outputs.view(-1, outputs.shape[-1])  
        tags = tags.view(-1)
        loss = criterion(outputs, tags)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

In [10]:
def evaluate_model(model, loader, idx2tag, device):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for words, tags in loader:
            words, tags = words.to(device), tags.to(device)
            outputs = model(words)
            predictions = torch.argmax(outputs, dim=-1)
            all_preds.extend(predictions.view(-1).tolist())
            all_labels.extend(tags.view(-1).tolist())
    # Remove padding tokens
    valid_preds = [idx2tag[p] for p, l in zip(all_preds, all_labels) if idx2tag[l] != "O"]
    valid_labels = [idx2tag[l] for p, l in zip(all_preds, all_labels) if idx2tag[l] != "O"]
    # valid_preds = [idx2tag[p] for p, l in zip(all_preds, all_labels)]
    # valid_labels = [idx2tag[l] for p, l in zip(all_preds, all_labels)]
    return classification_report(valid_labels, valid_preds, output_dict=False)

In [11]:
def save_model(model, path, word2idx, tag2idx):
    torch.save(model.state_dict(), f"{path}.pth")
    with open(f"{path}_vocab.pkl", "wb") as f:
        pickle.dump({"word2idx": word2idx, "tag2idx": tag2idx}, f)
    print(f"Model and vocab saved to {path}.pth and {path}_vocab.pkl")

def load_model(model_class, path, vocab_path, embed_dim, hidden_dim, device):
    with open(vocab_path, "rb") as f:
        vocab = pickle.load(f)
    word2idx, tag2idx = vocab["word2idx"], vocab["tag2idx"]
    model = model_class(len(word2idx), len(tag2idx), embed_dim, hidden_dim)
    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)
    model.eval()
    print(f"Model loaded from {path}")
    return model, word2idx, tag2idx

def predict(model, text, word2idx, idx2tag, max_len, device):
    words = text.split()
    word_indices = [word2idx.get(w, word2idx["<PAD>"]) for w in words]
    word_indices = word_indices[:max_len] + [word2idx["<PAD>"]] * (max_len - len(word_indices))
    
    model_input = torch.tensor([word_indices]).to(device)
    with torch.no_grad():
        outputs = model(model_input)
        predictions = torch.argmax(outputs, dim=-1).squeeze(0).tolist()

    # Convert predictions to tags
    tags = [idx2tag[idx] for idx in predictions[:len(words)]]
    return list(zip(words, tags))

In [12]:
# Paths
train_file = "data/eng/eng.train"
val_file = "data/eng/eng.testa"
test_file = "data/eng/eng.testb"

# Build vocabulary
word2idx, tag2idx = build_vocab(train_file)
idx2tag = {idx: tag for tag, idx in tag2idx.items()}

In [13]:
# Hyperparameters
embed_dim = 100
hidden_dim = 128
max_len = 50
batch_size = 32
epochs = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [14]:
train_dataset = NERDataset(train_file, word2idx, tag2idx, max_len)
val_dataset = NERDataset(val_file, word2idx, tag2idx, max_len)
test_dataset = NERDataset(test_file, word2idx, tag2idx, max_len)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [15]:
# # Calculate class weights
# num_tags = len(tag2idx)
# tag_counts = [0] * num_tags
# for _, labels in train_dataset:
#     for tag in labels.tolist():
#         tag_counts[tag] += 1
# total_tags = sum(tag_counts)
# class_weights = [total_tags / count if count > 0 else 0.0 for count in tag_counts]

# # Convert to tensor and move to device
# weights = torch.tensor(class_weights).to(device)

# # Define loss function with weights
# criterion = nn.CrossEntropyLoss(ignore_index=0, weight=weights)

In [16]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, ignore_index=-1):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        ce_loss = nn.CrossEntropyLoss(ignore_index=self.ignore_index, reduction='none')(inputs, targets)
        pt = torch.exp(-ce_loss)  # Probabilities of the true class
        focal_loss = self.alpha * ((1 - pt) ** self.gamma) * ce_loss
        return focal_loss.mean()

# Replace criterion with FocalLoss
criterion = FocalLoss(alpha=1, gamma=2, ignore_index=0)


In [17]:
# Model, optimizer, loss
model = BiLSTMNER(len(word2idx), len(tag2idx), embed_dim, hidden_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [18]:
# Train and validate
for epoch in range(epochs):
    print(f"Epoch [{epoch + 1}/{epochs}]")
    train_loss = train_model(model, train_loader, optimizer, criterion, device)
    print(f"Training Loss: {train_loss:.4f}")

    print("Validation:")
    val_report = evaluate_model(model, val_loader, idx2tag, device)
    print(val_report)
    print("===================================================================================")

# Final test evaluation
print("Final Test Evaluation:")
test_report = evaluate_model(model, test_loader, idx2tag, device)
print(test_report)

Epoch [1/5]


Training: 100%|██████████| 469/469 [00:02<00:00, 178.02batch/s]


Training Loss: 0.0822
Validation:
              precision    recall  f1-score   support

      B-MISC       0.00      0.00      0.00         4
       I-LOC       0.85      0.49      0.62      2088
      I-MISC       0.89      0.29      0.43      1258
       I-ORG       0.00      0.00      0.00      2085
       I-PER       0.92      0.43      0.59      3053
           O       0.00      0.00      0.00         0

    accuracy                           0.32      8488
   macro avg       0.44      0.20      0.27      8488
weighted avg       0.67      0.32      0.43      8488

Epoch [2/5]


Training: 100%|██████████| 469/469 [00:02<00:00, 205.23batch/s]


Training Loss: 0.0199
Validation:
              precision    recall  f1-score   support

      B-MISC       0.00      0.00      0.00         4
       I-LOC       0.85      0.68      0.76      2088
      I-MISC       0.88      0.55      0.67      1258
       I-ORG       0.00      0.00      0.00      2085
       I-PER       0.92      0.60      0.73      3053
           O       0.00      0.00      0.00         0

    accuracy                           0.46      8488
   macro avg       0.44      0.30      0.36      8488
weighted avg       0.67      0.46      0.55      8488

Epoch [3/5]


Training: 100%|██████████| 469/469 [00:02<00:00, 206.08batch/s]


Training Loss: 0.0105
Validation:
              precision    recall  f1-score   support

      B-MISC       0.00      0.00      0.00         4
       I-LOC       0.87      0.72      0.79      2088
      I-MISC       0.90      0.61      0.73      1258
       I-ORG       0.00      0.00      0.00      2085
       I-PER       0.93      0.63      0.75      3053
           O       0.00      0.00      0.00         0

    accuracy                           0.50      8488
   macro avg       0.45      0.33      0.38      8488
weighted avg       0.68      0.50      0.57      8488

Epoch [4/5]


Training: 100%|██████████| 469/469 [00:02<00:00, 204.60batch/s]


Training Loss: 0.0057
Validation:
              precision    recall  f1-score   support

      B-MISC       0.00      0.00      0.00         4
       I-LOC       0.84      0.79      0.81      2088
      I-MISC       0.87      0.71      0.78      1258
       I-ORG       0.00      0.00      0.00      2085
       I-PER       0.91      0.71      0.80      3053
           O       0.00      0.00      0.00         0

    accuracy                           0.56      8488
   macro avg       0.44      0.37      0.40      8488
weighted avg       0.66      0.56      0.60      8488

Epoch [5/5]


Training: 100%|██████████| 469/469 [00:02<00:00, 208.51batch/s]


Training Loss: 0.0029
Validation:
              precision    recall  f1-score   support

      B-MISC       0.00      0.00      0.00         4
       I-LOC       0.86      0.77      0.81      2088
      I-MISC       0.87      0.72      0.79      1258
       I-ORG       0.00      0.00      0.00      2085
       I-PER       0.92      0.67      0.77      3053
           O       0.00      0.00      0.00         0

    accuracy                           0.54      8488
   macro avg       0.44      0.36      0.40      8488
weighted avg       0.67      0.54      0.59      8488

Final Test Evaluation:
              precision    recall  f1-score   support

       B-LOC       0.00      0.00      0.00         6
      B-MISC       0.00      0.00      0.00         9
       B-ORG       0.00      0.00      0.00         5
       I-LOC       0.81      0.68      0.74      1905
      I-MISC       0.79      0.61      0.69       908
       I-ORG       0.00      0.00      0.00      2480
       I-PER       0.

In [19]:
# Save model
model_path = "save/models/bilstm"
save_model(model, model_path, word2idx, tag2idx)

# Load model and predict
loaded_model, loaded_word2idx, loaded_tag2idx = load_model(
    BiLSTMNER, f"{model_path}.pth", f"{model_path}_vocab.pkl", embed_dim, hidden_dim, device
)
loaded_idx2tag = {idx: tag for tag, idx in loaded_tag2idx.items()}

Model and vocab saved to save/models/bilstm.pth and save/models/bilstm_vocab.pkl
Model loaded from save/models/bilstm.pth


In [20]:
test_text = "EU rejects German call to boycott British lamb ."
predictions = predict(
    loaded_model, test_text, loaded_word2idx, loaded_idx2tag, max_len, device
)
print("Predictions:", predictions)

Predictions: [('EU', 'O'), ('rejects', 'O'), ('German', 'I-MISC'), ('call', 'O'), ('to', 'O'), ('boycott', 'O'), ('British', 'I-MISC'), ('lamb', 'O'), ('.', 'O')]
