In [161]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from sklearn.metrics import classification_report
from tqdm import tqdm
import pickle
import os

import warnings
warnings.filterwarnings('ignore')

In [162]:
def create_vocab(file_path):
    word_set = set()
    pos_set = set()
    chunk_set = set()
    tag_set = set()

    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                word, pos, chunk, tag = line.split()
                word_set.add(word)
                pos_set.add(pos)
                chunk_set.add(chunk)
                tag_set.add(tag)

    word2idx = {word: idx + 2 for idx, word in enumerate(sorted(word_set))}
    word2idx["<PAD>"] = 0
    word2idx["<UNK>"] = 1

    pos2idx = {pos: idx for idx, pos in enumerate(sorted(pos_set))}
    chunk2idx = {chunk: idx for idx, chunk in enumerate(sorted(chunk_set))}
    tag2idx = {tag: idx + 1 for idx, tag in enumerate(sorted(tag_set))}
    tag2idx["O"] = 0 

    return word2idx, pos2idx, chunk2idx, tag2idx


In [163]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class MultiLabelNERDataset(Dataset):
    def __init__(self, file_path, word2idx, pos2idx, chunk2idx, tag2idx, max_len):
        self.data = self._read_file(file_path)
        self.word2idx = word2idx
        self.pos2idx = pos2idx
        self.chunk2idx = chunk2idx
        self.tag2idx = tag2idx
        self.max_len = max_len
        self.pad_idx = word2idx["<PAD>"]

    def _read_file(self, file_path):
        """Read the data file and return a list of sentences with words, POS, chunk, and tags."""
        sentences = []
        with open(file_path, "r", encoding="utf-8") as f:
            words, pos_tags, chunks, tags = [], [], [], []
            for line in f:
                line = line.strip()
                if line:
                    word, pos, chunk, tag = line.split()
                    words.append(word)
                    pos_tags.append(pos)
                    chunks.append(chunk)
                    tags.append(tag)
                else:
                    if words:
                        sentences.append((words, pos_tags, chunks, tags))
                        words, pos_tags, chunks, tags = [], [], [], []
            # Append the last sentence if file doesn't end with a newline
            if words:
                sentences.append((words, pos_tags, chunks, tags))
        return sentences

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        words, pos_tags, chunks, tags = self.data[idx]

        # Convert words, pos, chunk, and tags to indices
        word_indices = [self.word2idx.get(w, self.word2idx["<UNK>"]) for w in words]
        pos_indices = [self.pos2idx.get(p, 0) for p in pos_tags]
        chunk_indices = [self.chunk2idx.get(c, 0) for c in chunks]
        tag_indices = [self.tag2idx.get(t, 0) for t in tags]

        # Padding/truncation
        word_indices = word_indices[:self.max_len] + [self.pad_idx] * (self.max_len - len(word_indices))
        pos_indices = pos_indices[:self.max_len] + [0] * (self.max_len - len(pos_indices))
        chunk_indices = chunk_indices[:self.max_len] + [0] * (self.max_len - len(chunk_indices))
        tag_indices = tag_indices[:self.max_len] + [0] * (self.max_len - len(tag_indices))

        return (
            torch.tensor(word_indices, dtype=torch.long),
            torch.tensor(pos_indices, dtype=torch.long),
            torch.tensor(chunk_indices, dtype=torch.long),
            torch.tensor(tag_indices, dtype=torch.long),
        )

# Collate function for DataLoader
def collate_fn(batch):
    words, pos_tags, chunks, tags = zip(*batch)
    return (
        torch.stack(words, dim=0),
        torch.stack(pos_tags, dim=0),
        torch.stack(chunks, dim=0),
        torch.stack(tags, dim=0),
    )

In [164]:
class BiLSTMNERMultiLabel(nn.Module):
    def __init__(self, vocab_size, pos_size, chunk_size, ner_size, embed_dim, hidden_dim):
        super(BiLSTMNERMultiLabel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, bidirectional=True, batch_first=True)

        # Linear layers for multi-label outputs
        self.fc_pos = nn.Linear(hidden_dim * 2, pos_size)  # POS head
        self.fc_chunk = nn.Linear(hidden_dim * 2, chunk_size)  # Chunk head
        self.fc_ner = nn.Linear(hidden_dim * 2, ner_size)  # NER head

    def forward(self, x):
        x = self.embedding(x)  # Shape: (batch_size, max_len, embed_dim)
        lstm_out, _ = self.lstm(x)  # Shape: (batch_size, max_len, hidden_dim * 2)

        # Multi-label outputs
        pos_out = self.fc_pos(lstm_out)  # Shape: (batch_size, max_len, pos_size)
        chunk_out = self.fc_chunk(lstm_out)  # Shape: (batch_size, max_len, chunk_size)
        ner_out = self.fc_ner(lstm_out)  # Shape: (batch_size, max_len, ner_size)

        return pos_out, chunk_out, ner_out

In [165]:
def compute_loss(pos_out, chunk_out, ner_out, pos_labels, chunk_labels, ner_labels, pos_weight=1.0, chunk_weight=1.0, ner_weight=1.0):
    criterion = nn.CrossEntropyLoss()
    pos_loss = criterion(pos_out.view(-1, pos_out.shape[-1]), pos_labels.view(-1))
    chunk_loss = criterion(chunk_out.view(-1, chunk_out.shape[-1]), chunk_labels.view(-1))
    ner_loss = criterion(ner_out.view(-1, ner_out.shape[-1]), ner_labels.view(-1))
    total_loss = pos_weight * pos_loss + chunk_weight * chunk_loss + ner_weight * ner_loss
    return total_loss


In [166]:
def train_model(model, train_loader, optimizer, device, pos_weight=1.0, chunk_weight=1.0, ner_weight=1.0):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc="Training", unit="batch"):
        words, pos_labels, chunk_labels, ner_labels = batch
        words = words.to(device)
        pos_labels = pos_labels.to(device)
        chunk_labels = chunk_labels.to(device)
        ner_labels = ner_labels.to(device)

        optimizer.zero_grad()
        pos_out, chunk_out, ner_out = model(words)
        loss = compute_loss(pos_out, chunk_out, ner_out, pos_labels, chunk_labels, ner_labels, pos_weight, chunk_weight, ner_weight)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)


In [167]:
def flatten(list_of_lists):
    return [item for sublist in list_of_lists for item in sublist]

In [168]:
def evaluate_model(model, loader, idx2pos, idx2chunk, idx2tag, device):
    model.eval()
    all_preds_pos, all_labels_pos = [], []
    all_preds_chunk, all_labels_chunk = [], []
    all_preds_ner, all_labels_ner = [], []

    with torch.no_grad():
        for batch in loader:
            words, pos_tags, chunk_tags, ner_tags = batch
            words = words.to(device)
            pos_tags = pos_tags.to(device)
            chunk_tags = chunk_tags.to(device)
            ner_tags = ner_tags.to(device)
            
            # Forward pass
            outputs = model(words)
            pos_output, chunk_output, ner_output = outputs[0], outputs[1], outputs[2]
            
            # Convert predictions to labels
            pos_preds = torch.argmax(pos_output, dim=-1)
            chunk_preds = torch.argmax(chunk_output, dim=-1)
            ner_preds = torch.argmax(ner_output, dim=-1)
            
            all_preds_pos.extend(pos_preds.cpu().numpy().tolist())
            all_labels_pos.extend(pos_tags.cpu().numpy().tolist())

            all_preds_chunk.extend(chunk_preds.cpu().numpy().tolist())
            all_labels_chunk.extend(chunk_tags.cpu().numpy().tolist())

            all_preds_ner.extend(ner_preds.cpu().numpy().tolist())
            all_labels_ner.extend(ner_tags.cpu().numpy().tolist())
    
    # Flatten the lists
    all_preds_pos = flatten(all_preds_pos)
    all_labels_pos = flatten(all_labels_pos)

    all_preds_chunk = flatten(all_preds_chunk)
    all_labels_chunk = flatten(all_labels_chunk)

    all_preds_ner = flatten(all_preds_ner)
    all_labels_ner = flatten(all_labels_ner)

    # Generate classification reports
    pos_report = classification_report(all_labels_pos, all_preds_pos)
    chunk_report = classification_report(all_labels_chunk, all_preds_chunk)
    ner_report = classification_report(all_labels_ner, all_preds_ner)

    return pos_report, chunk_report, ner_report

In [169]:
def save_model(model, path, vocab):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    
    # Save model weights
    torch.save(model.state_dict(), f"{path}.pth")
    
    # Save vocab
    with open(f"{path}_vocab.pkl", "wb") as f:
        pickle.dump(vocab, f)
    
    print(f"Model and vocab saved to {path}.pth and {path}_vocab.pkl")

In [170]:
def load_model(model_class, path, device, embed_dim, hidden_dim):
    # Load vocab
    with open(f"{path}_vocab.pkl", "rb") as f:
        vocab = pickle.load(f)
    
    word2idx, pos2idx, chunk2idx, tag2idx = (
        vocab["word2idx"],
        vocab["pos2idx"],
        vocab["chunk2idx"],
        vocab["tag2idx"],
    )
    
    # Initialize and load model
    model = model_class(
        len(word2idx),
        len(pos2idx),
        len(chunk2idx),
        len(tag2idx),
        embed_dim,
        hidden_dim,
    )
    model.load_state_dict(torch.load(f"{path}.pth", map_location=device))
    model.to(device)
    model.eval()
    
    print("Model and vocab loaded successfully.")
    return model, word2idx, pos2idx, chunk2idx, tag2idx


In [183]:
def predict(model, text, word2idx, pos2idx, chunk2idx, idx2tag, max_len, device):
    words = text.split()
    word_indices = [word2idx.get(w, word2idx["<UNK>"]) for w in words]
    word_indices = word_indices[:max_len] + [word2idx["<PAD>"]] * (max_len - len(word_indices))
    
    model_input = torch.tensor([word_indices]).to(device)

    with torch.no_grad():
        pos_out, chunk_out, ner_out = model(model_input) 
        predictions = torch.sigmoid(ner_out).cpu().numpy()
    
    predicted_tags = []
    for word_pred in predictions[0][:len(words)]:
        max_prob_idx = word_pred.argmax() 
        predicted_tags.append([idx2tag[max_prob_idx]])
    
    result = [(w, tags if tags else ["O"]) for w, tags in zip(words, predicted_tags)]
    return result


In [172]:
# Paths
train_file = "data/eng/eng.train"
val_file = "data/eng/eng.testa"
test_file = "data/eng/eng.testb"

# Build vocabulary
word2idx, pos2idx, chunk2idx, tag2idx = create_vocab(train_file)
idx2pos = {idx: pos for pos, idx in pos2idx.items()}
idx2chunk = {idx: chunk for chunk, idx in chunk2idx.items()}
idx2tag = {idx: tag for tag, idx in tag2idx.items()}

In [173]:
tag2idx

{'B-LOC': 1,
 'B-MISC': 2,
 'B-ORG': 3,
 'I-LOC': 4,
 'I-MISC': 5,
 'I-ORG': 6,
 'I-PER': 7,
 'O': 0}

In [174]:
embed_dim = 128
hidden_dim = 256
batch_size = 32
epochs = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = MultiLabelNERDataset(train_file, word2idx, pos2idx, chunk2idx, tag2idx, max_len=50)
val_dataset = MultiLabelNERDataset(val_file, word2idx, pos2idx, chunk2idx, tag2idx, max_len=50)

train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn)

In [175]:
model = BiLSTMNERMultiLabel(len(word2idx), len(pos2idx), len(chunk2idx), len(tag2idx), embed_dim, hidden_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [176]:
for epoch in range(epochs):
    print(f"Epoch [{epoch + 1}/{epochs}]")
    train_loss = train_model(model, train_loader, optimizer, device)
    print(f"Training Loss: {train_loss:.4f}")

    print("Validation:")
    pos_report, chunk_report, ner_report = evaluate_model(model, val_loader, idx2pos, idx2chunk, idx2tag, device)
    print("NER Classification Report:\n", ner_report)

Epoch [1/3]


Training: 100%|██████████| 469/469 [00:06<00:00, 77.52batch/s]


Training Loss: 0.8617
Validation:
NER Classification Report:
               precision    recall  f1-score   support

           0       0.97      1.00      0.98    164812
           2       0.00      0.00      0.00         4
           4       0.81      0.39      0.52      2088
           5       0.68      0.14      0.23      1258
           6       0.87      0.20      0.33      2085
           7       0.63      0.56      0.59      3053

    accuracy                           0.97    173300
   macro avg       0.66      0.38      0.44    173300
weighted avg       0.96      0.97      0.96    173300

Epoch [2/3]


Training: 100%|██████████| 469/469 [00:06<00:00, 77.70batch/s]


Training Loss: 0.3267
Validation:
NER Classification Report:
               precision    recall  f1-score   support

           0       0.98      1.00      0.99    164812
           2       0.00      0.00      0.00         4
           4       0.90      0.57      0.70      2088
           5       0.89      0.34      0.49      1258
           6       0.84      0.41      0.56      2085
           7       0.81      0.60      0.69      3053

    accuracy                           0.97    173300
   macro avg       0.74      0.49      0.57    173300
weighted avg       0.97      0.97      0.97    173300

Epoch [3/3]


Training: 100%|██████████| 469/469 [00:05<00:00, 80.14batch/s]


Training Loss: 0.2177
Validation:
NER Classification Report:
               precision    recall  f1-score   support

           0       0.99      1.00      0.99    164812
           2       0.00      0.00      0.00         4
           4       0.90      0.71      0.79      2088
           5       0.85      0.58      0.69      1258
           6       0.84      0.59      0.69      2085
           7       0.85      0.67      0.75      3053

    accuracy                           0.98    173300
   macro avg       0.74      0.59      0.65    173300
weighted avg       0.98      0.98      0.98    173300



In [177]:
save_model(model, "save/models/multilabel_bilstm", {
    "word2idx": word2idx,
    "pos2idx": pos2idx,
    "chunk2idx": chunk2idx,
    "tag2idx": tag2idx,
})


Model and vocab saved to save/models/multilabel_bilstm.pth and save/models/multilabel_bilstm_vocab.pkl


In [179]:
loaded_model, loaded_word2idx, loaded_pos2idx, loaded_chunk2idx, loaded_tag2idx = load_model(
    BiLSTMNERMultiLabel,
    "save/models/multilabel_bilstm",
    device,
    embed_dim,
    hidden_dim
)
loaded_idx2tag = {idx: tag for tag, idx in loaded_tag2idx.items()}


Model and vocab loaded successfully.


In [184]:
text = "The European Union is headquartered in Brussels"
predictions = predict(
    loaded_model,
    text,
    loaded_word2idx,
    loaded_pos2idx,
    loaded_chunk2idx,
    loaded_idx2tag,
    max_len=50,
    device=device
)
print(predictions)


[('The', ['O']), ('European', ['I-ORG']), ('Union', ['I-ORG']), ('is', ['O']), ('headquartered', ['O']), ('in', ['O']), ('Brussels', ['I-LOC'])]
