<a href="https://colab.research.google.com/github/Shriyatha/Named_Entity_Recognition/blob/main/BERT_BASED_NER_English.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# BERT-based Named Entity Recognition (NER) on CoNLL-2003 Dataset
# =============================================================

# Install required packages
!pip install transformers datasets evaluate seqeval torch tqdm matplotlib pandas seaborn -q

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m58.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m75.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m54.7 MB/s[0m eta [36m0

In [20]:
import torch
from torch.utils.data import DataLoader
from transformers import (
    BertTokenizerFast,
    BertForTokenClassification,
    DataCollatorForTokenClassification,
    get_linear_schedule_with_warmup
)
from torch.optim import AdamW
from datasets import load_dataset
from evaluate import load
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
import random
import time

In [22]:
# Set random seeds for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [23]:
# 1. Data Loading and Exploration
# ==============================
print("\n==== Data Loading and Exploration ====")
# Load dataset and metric
dataset = load_dataset("conll2003")
metric = load("seqeval")

# Basic dataset info
print("\nDataset splits:")
for split in dataset.keys():
    print(f"- {split}: {dataset[split].num_rows} examples")

# Examine data structure
print("\nDataset features:", dataset["train"].features)


==== Data Loading and Exploration ====

Dataset splits:
- train: 14041 examples
- validation: 3250 examples
- test: 3453 examples

Dataset features: {'id': Value(dtype='string', id=None), 'tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'pos_tags': Sequence(feature=ClassLabel(names=['"', "''", '#', '$', '(', ')', ',', '.', ':', '``', 'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB'], id=None), length=-1, id=None), 'chunk_tags': Sequence(feature=ClassLabel(names=['O', 'B-ADJP', 'I-ADJP', 'B-ADVP', 'I-ADVP', 'B-CONJP', 'I-CONJP', 'B-INTJ', 'I-INTJ', 'B-LST', 'I-LST', 'B-NP', 'I-NP', 'B-PP', 'I-PP', 'B-PRT', 'I-PRT', 'B-SBAR', 'I-SBAR', 'B-UCP', 'I-UCP', 'B-VP', 'I-VP'], id=None), length=-1, id=None), 'ner_tags': Sequence(feature=ClassLabel(names=['O', 'B-PER', '

In [24]:
# 2. Data Preprocessing
# ====================
print("\n==== Data Preprocessing ====")
# Initialize tokenizer
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

# Define label_list
label_list = dataset["train"].features["ner_tags"].feature.names

# Preprocess function to align labels with tokens
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt"
    )
    labels = []
    for i, label in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)  # Special tokens get -100
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])  # First token of word
            else:
                label_ids.append(-100)  # Subsequent tokens of word get -100
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

print("Tokenizing dataset...")
start_time = time.time()
# Tokenize dataset
tokenized_datasets = dataset.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=dataset["train"].column_names
)
tokenization_time = time.time() - start_time
print(f"Tokenization completed in {tokenization_time:.2f} seconds")

# Show tokenization example
print("\nTokenization example:")
example_idx = 0
example_tokens = dataset["train"][example_idx]["tokens"][:10]
example_tags = [label_list[idx] for idx in dataset["train"][example_idx]["ner_tags"][:10]]
print("Original:")
for token, tag in zip(example_tokens, example_tags):
    print(f"{token:<15} -> {tag}")

print("\nTokenized:")
tokenized_example = tokenizer(example_tokens, is_split_into_words=True)
tokenized_words = tokenizer.convert_ids_to_tokens(tokenized_example.input_ids)
for token in tokenized_words:
    print(token)



==== Data Preprocessing ====
Tokenizing dataset...


Map:   0%|          | 0/14041 [00:00<?, ? examples/s]

Map:   0%|          | 0/3250 [00:00<?, ? examples/s]

Map:   0%|          | 0/3453 [00:00<?, ? examples/s]

Tokenization completed in 5.05 seconds

Tokenization example:
Original:
EU              -> B-ORG
rejects         -> O
German          -> B-MISC
call            -> O
to              -> O
boycott         -> O
British         -> B-MISC
lamb            -> O
.               -> O

Tokenized:
[CLS]
EU
rejects
German
call
to
boycott
British
la
##mb
.
[SEP]


In [25]:
# 3. Model Setup
# =============
print("\n==== Model Setup ====")
# Create data collator
data_collator = DataCollatorForTokenClassification(tokenizer)

# Create DataLoaders
batch_size = 16
print(f"Using batch size: {batch_size}")
train_dataloader = DataLoader(
    tokenized_datasets["train"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=batch_size
)
val_dataloader = DataLoader(
    tokenized_datasets["validation"],
    collate_fn=data_collator,
    batch_size=batch_size
)
test_dataloader = DataLoader(
    tokenized_datasets["test"],
    collate_fn=data_collator,
    batch_size=batch_size
)

# Initialize model
print("\nInitializing BERT model for token classification...")
model = BertForTokenClassification.from_pretrained(
    "bert-base-cased",
    num_labels=len(label_list),
    id2label={i: label for i, label in enumerate(label_list)},
    label2id={label: i for i, label in enumerate(label_list)}
).to(device)

# Print model architecture
print("\nModel architecture:")
print(model.__class__.__name__)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Non-trainable parameters: {total_params - trainable_params:,}")


==== Model Setup ====
Using batch size: 16

Initializing BERT model for token classification...


Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Model architecture:
BertForTokenClassification
Total parameters: 107,726,601
Trainable parameters: 107,726,601
Non-trainable parameters: 0


In [26]:
# 4. Training Setup
# ================
print("\n==== Training Setup ====")
# Hyperparameters
learning_rate = 2e-5
epochs = 5
weight_decay = 0.01
warmup_steps = 0
print(f"Learning rate: {learning_rate}")
print(f"Epochs: {epochs}")
print(f"Weight decay: {weight_decay}")
print(f"Warmup steps: {warmup_steps}")

# Optimizer and scheduler
optimizer = AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay
)
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)



==== Training Setup ====
Learning rate: 2e-05
Epochs: 5
Weight decay: 0.01
Warmup steps: 0


In [27]:
# 5. Training Loop
# ==============
print("\n==== Training Loop ====")
# Helper function for evaluation
def evaluate(dataloader, desc="Evaluating"):
    model.eval()
    predictions, true_labels = [], []
    for batch in tqdm(dataloader, desc=desc):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
            logits = outputs.logits
            batch_preds = torch.argmax(logits, dim=-1).cpu().numpy()
            batch_labels = batch["labels"].cpu().numpy()
        for preds, labels in zip(batch_preds, batch_labels):
            # Filter out ignored index (-100)
            true_indices = [i for i, l in enumerate(labels) if l != -100]
            true_labels.append([label_list[labels[i]] for i in true_indices])
            predictions.append([label_list[preds[i]] for i in true_indices])
    results = metric.compute(predictions=predictions, references=true_labels)
    return results, predictions, true_labels

# Store metrics for plotting
train_losses = []
val_f1_scores = []
print("\nStarting training...")
start_training_time = time.time()
for epoch in range(epochs):
    epoch_start_time = time.time()
    # Training
    model.train()
    total_loss = 0
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs} [Training]"):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    avg_train_loss = total_loss / len(train_dataloader)
    train_losses.append(avg_train_loss)

    # Validation
    val_results, _, _ = evaluate(val_dataloader, desc=f"Epoch {epoch+1}/{epochs} [Validation]")
    val_f1_scores.append(val_results["overall_f1"])
    epoch_time = time.time() - epoch_start_time

    # Print metrics
    print(f"\nEpoch {epoch+1}/{epochs} completed in {epoch_time:.2f} seconds")
    print(f"Training Loss: {avg_train_loss:.4f}")
    print(f"Validation F1: {val_results['overall_f1']:.4f}")
    print(f"Validation Precision: {val_results['overall_precision']:.4f}")
    print(f"Validation Recall: {val_results['overall_recall']:.4f}")

total_training_time = time.time() - start_training_time
print(f"\nTraining completed in {total_training_time:.2f} seconds ({total_training_time/60:.2f} minutes)")


==== Training Loop ====

Starting training...


Epoch 1/5 [Training]: 100%|██████████| 878/878 [05:16<00:00,  2.77it/s]
Epoch 1/5 [Validation]: 100%|██████████| 204/204 [00:25<00:00,  8.13it/s]



Epoch 1/5 completed in 342.35 seconds
Training Loss: 0.1127
Validation F1: 0.9254
Validation Precision: 0.9206
Validation Recall: 0.9303


Epoch 2/5 [Training]: 100%|██████████| 878/878 [05:16<00:00,  2.77it/s]
Epoch 2/5 [Validation]: 100%|██████████| 204/204 [00:25<00:00,  8.14it/s]



Epoch 2/5 completed in 342.64 seconds
Training Loss: 0.0279
Validation F1: 0.9419
Validation Precision: 0.9390
Validation Recall: 0.9447


Epoch 3/5 [Training]: 100%|██████████| 878/878 [05:16<00:00,  2.77it/s]
Epoch 3/5 [Validation]: 100%|██████████| 204/204 [00:24<00:00,  8.16it/s]



Epoch 3/5 completed in 342.25 seconds
Training Loss: 0.0137
Validation F1: 0.9431
Validation Precision: 0.9382
Validation Recall: 0.9481


Epoch 4/5 [Training]: 100%|██████████| 878/878 [05:16<00:00,  2.77it/s]
Epoch 4/5 [Validation]: 100%|██████████| 204/204 [00:24<00:00,  8.19it/s]



Epoch 4/5 completed in 342.38 seconds
Training Loss: 0.0081
Validation F1: 0.9468
Validation Precision: 0.9427
Validation Recall: 0.9510


Epoch 5/5 [Training]: 100%|██████████| 878/878 [05:16<00:00,  2.77it/s]
Epoch 5/5 [Validation]: 100%|██████████| 204/204 [00:24<00:00,  8.18it/s]



Epoch 5/5 completed in 342.44 seconds
Training Loss: 0.0052
Validation F1: 0.9491
Validation Precision: 0.9452
Validation Recall: 0.9530

Training completed in 1712.07 seconds (28.53 minutes)


In [28]:
# Plot training progress
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(range(1, epochs + 1), train_losses, marker='o')
plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(range(1, epochs + 1), val_f1_scores, marker='o', color='green')
plt.title("Validation F1 Score")
plt.xlabel("Epoch")
plt.ylabel("F1 Score")
plt.grid(True)
plt.tight_layout()
plt.savefig('training_progress.png')
plt.close()

In [29]:
# 6. Test Evaluation and Analysis
# ==============================
print("\n==== Test Evaluation and Analysis ====")
print("Evaluating on test set...")
test_results, test_predictions, test_true_labels = evaluate(test_dataloader, "Testing")

# Print overall metrics
print("\nTest Results:")
print(f"Accuracy: {test_results['overall_accuracy']:.4f}")
print(f"Precision: {test_results['overall_precision']:.4f}")
print(f"Recall: {test_results['overall_recall']:.4f}")
print(f"F1 Score: {test_results['overall_f1']:.4f}")

# Print per-entity metrics
print("\nPer-Entity Type Metrics:")
entity_results = {}
for key in sorted(test_results.keys()):
    if key not in ['overall_accuracy', 'overall_precision', 'overall_recall', 'overall_f1']:
        entity_results[key] = {
            'precision': test_results[key]['precision'],
            'recall': test_results[key]['recall'],
            'f1': test_results[key]['f1'],
            'number': test_results[key]['number']
        }
        print(f"{key}:")
        print(f" Precision: {test_results[key]['precision']:.4f}")
        print(f" Recall: {test_results[key]['recall']:.4f}")
        print(f" F1: {test_results[key]['f1']:.4f}")
        print(f" Support: {test_results[key]['number']} entities")



==== Test Evaluation and Analysis ====
Evaluating on test set...


Testing: 100%|██████████| 216/216 [00:26<00:00,  8.15it/s]



Test Results:
Accuracy: 0.9832
Precision: 0.9097
Recall: 0.9231
F1 Score: 0.9164

Per-Entity Type Metrics:
LOC:
 Precision: 0.9335
 Recall: 0.9358
 F1: 0.9347
 Support: 1666 entities
MISC:
 Precision: 0.7791
 Recall: 0.8191
 F1: 0.7986
 Support: 702 entities
ORG:
 Precision: 0.8897
 Recall: 0.9181
 F1: 0.9037
 Support: 1661 entities
PER:
 Precision: 0.9664
 Recall: 0.9604
 F1: 0.9634
 Support: 1615 entities
