<a href="https://colab.research.google.com/github/Shriyatha/Named_Entity_Recognition/blob/main/BERT_ENGLISH_NER.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# BERT-based Named Entity Recognition (NER) on CoNLL-2003 Dataset
# =============================================================
# Fine-tunes `bert-base-cased` for token classification on CoNLL-2003 and
# reports seqeval precision / recall / F1 on the validation and test splits.

# Install required packages.
# `!pip` is an IPython shell escape: this line only works inside a notebook,
# not when the cell is run as a plain Python script.
# `seqeval` is the backend of the `load("seqeval")` metric used further down.
# NOTE(review): versions are unpinned. Newer `datasets` releases may refuse to
# load script-based Hub datasets such as "conll2003"; if
# `load_dataset("conll2003")` fails, pin compatible versions here. TODO confirm
# against the release actually installed.
!pip install transformers datasets evaluate seqeval torch tqdm matplotlib pandas seaborn -q

In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import (
    BertTokenizerFast,
    BertForTokenClassification,
    DataCollatorForTokenClassification,
    get_linear_schedule_with_warmup
)
from torch.optim import AdamW
from datasets import load_dataset
from evaluate import load
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
import random
import time

In [None]:
# Seed every RNG the notebook relies on (Python, NumPy, PyTorch CPU, and all
# CUDA devices when present) so shuffling and classifier-head initialisation
# repeat across runs. GPU kernels may still be nondeterministic.
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# 1. Data Loading and Exploration
# ==============================
print("\n==== Data Loading and Exploration ====")
# Download the CoNLL-2003 splits and the seqeval span-level metric.
dataset = load_dataset("conll2003")
metric = load("seqeval")

# Report how many examples each split holds.
print("\nDataset splits:")
for split_name, split_data in dataset.items():
    print(f"- {split_name}: {split_data.num_rows} examples")

# Show the feature schema (it includes the `tokens` and `ner_tags` columns
# consumed by the preprocessing step).
print("\nDataset features:", dataset["train"].features)

In [None]:
# 2. Data Preprocessing
# ====================
print("\n==== Data Preprocessing ====")
# A *fast* tokenizer is required: label alignment relies on `word_ids()`.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

# Ordered NER tag names: position i holds the string for integer label id i.
ner_tag_feature = dataset["train"].features["ner_tags"]
label_list = ner_tag_feature.feature.names

# Preprocess function to align labels with tokens
def tokenize_and_align_labels(examples, max_length=128):
    """Tokenize pre-split sentences and align word-level NER tags to sub-tokens.

    Args:
        examples: A batch from ``Dataset.map(batched=True)`` holding a
            ``"tokens"`` column (one list of words per sentence) and a
            ``"ner_tags"`` column (one list of integer tag ids per sentence).
        max_length: Sub-token budget per sentence; longer sentences are
            truncated.

    Returns:
        The tokenizer output with an added ``"labels"`` entry. Each sentence
        gets one label id per sub-token. Only the first sub-token of each word
        carries the word's tag. Special tokens and continuation sub-tokens get
        -100, which the token-classification loss and the evaluation filter
        both ignore.

    Sequences are deliberately left unpadded. ``DataCollatorForTokenClassification``
    pads each batch to its own longest sequence (labels with -100), so padding
    every sentence to ``max_length`` here would only spend compute on pad
    tokens. No ``return_tensors`` is requested either: ``Dataset.map`` stores
    plain lists, and ragged lists cannot form a single tensor.
    """
    tokenized_inputs = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True,
        max_length=max_length,
    )
    labels = []
    for i, word_labels in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)  # special tokens such as [CLS] / [SEP]
            elif word_idx != previous_word_idx:
                label_ids.append(word_labels[word_idx])  # first sub-token of a word
            else:
                label_ids.append(-100)  # continuation sub-tokens of the same word
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

print("Tokenizing dataset...")
start_time = time.time()
# Tokenize every split in one batched pass and drop the raw dataset columns so
# only model inputs plus `labels` remain. This assumes all splits share the
# train split's column names.
raw_columns = dataset["train"].column_names
tokenized_datasets = dataset.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=raw_columns,
)
tokenization_time = time.time() - start_time
print(f"Tokenization completed in {tokenization_time:.2f} seconds")



In [None]:
# 3. Model Setup
# =============
print("\n==== Model Setup ====")
# Collator that pads each batch (including `labels`) at batching time.
data_collator = DataCollatorForTokenClassification(tokenizer)

# Create DataLoaders
batch_size = 16
print(f"Using batch size: {batch_size}")


def _make_dataloader(split, shuffle=False):
    """Wrap one tokenized split in a DataLoader batched by ``data_collator``."""
    return DataLoader(
        tokenized_datasets[split],
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=data_collator,
    )


# Only the training split is shuffled; validation/test keep their order.
train_dataloader = _make_dataloader("train", shuffle=True)
val_dataloader = _make_dataloader("validation")
test_dataloader = _make_dataloader("test")

# Initialize model
print("\nInitializing BERT model for token classification...")
# Label-name <-> id maps stored in the model config.
id2label = dict(enumerate(label_list))
label2id = {name: idx for idx, name in id2label.items()}
model = BertForTokenClassification.from_pretrained(
    "bert-base-cased",
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
).to(device)

# Print model architecture
print("\nModel architecture:")
print(type(model).__name__)
# Count parameters once, tagging each tensor with whether it receives gradients.
param_counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_counts)
trainable_params = sum(count for count, needs_grad in param_counts if needs_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Non-trainable parameters: {total_params - trainable_params:,}")

In [None]:
# 4. Training Setup
# ================
print("\n==== Training Setup ====")
# Hyperparameters
learning_rate = 2e-5
epochs = 5
weight_decay = 0.01
warmup_steps = 0
print(f"Learning rate: {learning_rate}")
print(f"Epochs: {epochs}")
print(f"Weight decay: {weight_decay}")
print(f"Warmup steps: {warmup_steps}")

# Optimizer and scheduler.
# Standard BERT fine-tuning recipe (as in Hugging Face's reference scripts):
# weight decay is applied to weight matrices only. Biases and LayerNorm
# parameters go in a zero-decay group, because shrinking them toward zero is
# not useful regularisation. ("bias" also matches every LayerNorm.bias.)
no_decay_markers = ("bias", "LayerNorm.weight")
decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if any(marker in param_name for marker in no_decay_markers):
        no_decay_params.append(param)
    else:
        decay_params.append(param)
optimizer = AdamW(
    [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=learning_rate,
)
# The scheduler is stepped once per batch, so its horizon is batches x epochs.
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)


In [None]:
# 5. Training Loop
# ==============
print("\n==== Training Loop ====")


def evaluate(dataloader, desc="Evaluating"):
    """Run the global ``model`` over ``dataloader`` and score it with seqeval.

    Positions whose gold label is -100 (special tokens, continuation
    sub-tokens, padding) are dropped before scoring. Each sentence is then
    compared as a list of tag strings taken from ``label_list``.

    Args:
        dataloader: Yields dict batches of tensors that include ``"labels"``.
        desc: Caption for the ``tqdm`` progress bar.

    Returns:
        ``(results, predictions, true_labels)``: the dict from
        ``metric.compute`` plus the per-sentence predicted and gold tag lists.
    """
    model.eval()
    predicted_tags, gold_tags = [], []
    for raw_batch in tqdm(dataloader, desc=desc):
        batch = {name: tensor.to(device) for name, tensor in raw_batch.items()}
        with torch.no_grad():
            logits = model(**batch).logits
        pred_ids = logits.argmax(dim=-1).cpu().numpy()
        gold_ids = batch["labels"].cpu().numpy()
        for pred_row, gold_row in zip(pred_ids, gold_ids):
            keep = gold_row != -100  # boolean mask of positions that are scored
            gold_tags.append([label_list[label_id] for label_id in gold_row[keep]])
            predicted_tags.append([label_list[label_id] for label_id in pred_row[keep]])
    results = metric.compute(predictions=predicted_tags, references=gold_tags)
    return results, predicted_tags, gold_tags

# Per-epoch histories, read later by the plotting cell.
train_losses = []
val_f1_scores = []


def _run_train_epoch(epoch_idx):
    """Do one optimisation pass over ``train_dataloader``.

    For each batch: forward pass (the model computes the loss from the batch's
    ``labels``), backward pass, gradient-norm clipping at 1.0, optimizer step,
    scheduler step, then gradient reset.

    Returns:
        The mean of the per-batch training losses for this epoch.
    """
    model.train()
    loss_sum = 0
    bar = tqdm(train_dataloader, desc=f"Epoch {epoch_idx+1}/{epochs} [Training]")
    for raw_batch in bar:
        batch = {name: tensor.to(device) for name, tensor in raw_batch.items()}
        loss = model(**batch).loss
        loss_sum += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()  # the LR schedule advances once per batch
        optimizer.zero_grad()
    return loss_sum / len(train_dataloader)


print("\nStarting training...")
start_training_time = time.time()
for epoch in range(epochs):
    epoch_start_time = time.time()
    avg_train_loss = _run_train_epoch(epoch)
    train_losses.append(avg_train_loss)

    # Validate after every epoch; only the F1 score is kept for plotting.
    val_results, _, _ = evaluate(val_dataloader, desc=f"Epoch {epoch+1}/{epochs} [Validation]")
    val_f1_scores.append(val_results["overall_f1"])
    epoch_time = time.time() - epoch_start_time

    print(f"\nEpoch {epoch+1}/{epochs} completed in {epoch_time:.2f} seconds")
    print(f"Training Loss: {avg_train_loss:.4f}")
    for metric_label, metric_key in (
        ("F1", "overall_f1"),
        ("Precision", "overall_precision"),
        ("Recall", "overall_recall"),
    ):
        print(f"Validation {metric_label}: {val_results[metric_key]:.4f}")

total_training_time = time.time() - start_training_time
print(f"\nTraining completed in {total_training_time:.2f} seconds ({total_training_time/60:.2f} minutes)")

In [None]:
# Plot training progress and save it to 'training_progress.png'.
# The figure is closed after saving, so it is not displayed inline.
# Each x-axis is derived from its own history's length rather than `epochs`.
# If the training cell was interrupted, the lists can be shorter than `epochs`
# or differ in length from each other, and plotting against
# range(1, epochs + 1) would then raise.
fig, (loss_ax, f1_ax) = plt.subplots(1, 2, figsize=(12, 5))

loss_ax.plot(range(1, len(train_losses) + 1), train_losses, marker='o')
loss_ax.set_title("Training Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.grid(True)

f1_ax.plot(range(1, len(val_f1_scores) + 1), val_f1_scores, marker='o', color='green')
f1_ax.set_title("Validation F1 Score")
f1_ax.set_xlabel("Epoch")
f1_ax.set_ylabel("F1 Score")
f1_ax.grid(True)

fig.tight_layout()
fig.savefig('training_progress.png')
plt.close(fig)

In [None]:
# 6. Test Evaluation and Analysis
# ==============================
print("\n==== Test Evaluation and Analysis ====")
print("Evaluating on test set...")
test_results, test_predictions, test_true_labels = evaluate(test_dataloader, "Testing")

# Overall (micro-averaged across entity types) scores.
print("\nTest Results:")
for display_name, overall_key in (
    ("Accuracy", "overall_accuracy"),
    ("Precision", "overall_precision"),
    ("Recall", "overall_recall"),
    ("F1 Score", "overall_f1"),
):
    print(f"{display_name}: {test_results[overall_key]:.4f}")

# Every key other than the four overall_* scalars is an entity type whose
# value is a dict of precision / recall / f1 / number (support).
print("\nPer-Entity Type Metrics:")
overall_keys = {'overall_accuracy', 'overall_precision', 'overall_recall', 'overall_f1'}
entity_results = {}
for entity_type in sorted(k for k in test_results if k not in overall_keys):
    scores = test_results[entity_type]
    entity_results[entity_type] = {
        field: scores[field] for field in ('precision', 'recall', 'f1', 'number')
    }
    print(f"{entity_type}:")
    print(f" Precision: {scores['precision']:.4f}")
    print(f" Recall: {scores['recall']:.4f}")
    print(f" F1: {scores['f1']:.4f}")
    print(f" Support: {scores['number']} entities")
