<a href="https://colab.research.google.com/github/Shriyatha/Named_Entity_Recognition/blob/main/CRF_NER_english.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install sklearn_crfsuite datasets evaluate seqeval tabulate -q

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m19.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m38.8 MB/s[0m eta [36m0:0

In [2]:
import matplotlib
matplotlib.use('agg')

In [5]:
import numpy as np
import sklearn_crfsuite
from sklearn_crfsuite import metrics
from collections import Counter
from datasets import load_dataset
from evaluate import load

# Feature extraction functions for CRF
# ----------------------------------
def word2features(sent, i):
    """Extract features from word at position i."""
    word = sent[i]

    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),
        'word.contains_hyphen': '-' in word,
        'word.contains_dot': '.' in word,
        'word.length': len(word),
        'word.prefix-2': word[:2],
        'word.prefix-3': word[:3],
        'word.suffix-2': word[-2:],
        'word.suffix-3': word[-3:],
        'word.contains_digit': any(char.isdigit() for char in word),
        'word.contains_uppercase': any(char.isupper() for char in word),
        'word.is_alphanumeric': word.isalnum(),
        'word.is_alphabetic': word.isalpha(),
    }

    # Context features - previous word
    if i > 0:
        word1 = sent[i-1]
        features.update({
            '-1:word.lower()': word1.lower(),
            '-1:word.istitle()': word1.istitle(),
            '-1:word.isupper()': word1.isupper(),
            '-1:word.isdigit()': word1.isdigit(),
            '-1:word.prefix-2': word1[:2],
            '-1:word.suffix-2': word1[-2:],
        })
    else:
        features['BOS'] = True

    # Context features - next word
    if i < len(sent)-1:
        word1 = sent[i+1]
        features.update({
            '+1:word.lower()': word1.lower(),
            '+1:word.istitle()': word1.istitle(),
            '+1:word.isupper()': word1.isupper(),
            '+1:word.isdigit()': word1.isdigit(),
            '+1:word.prefix-2': word1[:2],
            '+1:word.suffix-2': word1[-2:],
        })
    else:
        features['EOS'] = True

    # Additional context features
    if i > 1:
        word2 = sent[i-2]
        features.update({
            '-2:word.lower()': word2.lower(),
            '-2:word.istitle()': word2.istitle(),
        })

    if i < len(sent)-2:
        word2 = sent[i+2]
        features.update({
            '+2:word.lower()': word2.lower(),
            '+2:word.istitle()': word2.istitle(),
        })

    return features

def sent2features(sent):
    """Convert sentence to list of features."""
    return [word2features(sent, i) for i in range(len(sent))]

def sent2labels(tags, id_to_label):
    """Convert tags to labels."""
    return [id_to_label(tag) for tag in tags]

# Main CRF NER class
# -----------------
class CRFNER:
    def __init__(self, c1=0.1, c2=0.1, max_iterations=100):
        self.crf = sklearn_crfsuite.CRF(
            algorithm='lbfgs',
            c1=c1,
            c2=c2,
            max_iterations=max_iterations,
            all_possible_transitions=True
        )
        self.allowed_entities = {"PER", "LOC", "ORG", "MISC"}

    def train(self, X_train, y_train):
        """Train CRF model."""
        self.crf.fit(X_train, y_train)

    def predict(self, X_test):
        """Predict using CRF model."""
        return self.crf.predict(X_test)

    def get_transition_features(self):
        """Get learned transition features."""
        if hasattr(self.crf, 'transition_features_'):
            return dict(self.crf.transition_features_)
        return {}

    def get_state_features(self):
        """Get learned state features."""
        if hasattr(self.crf, 'state_features_'):
            return dict(self.crf.state_features_)
        return {}

# Prepare data for CRF
# -------------------
def prepare_crf_data(dataset):
    """Prepare CoNLL data for CRF training and testing."""
    id_to_label = dataset["train"].features["ner_tags"].feature.int2str

    # Prepare training data
    X_train = []
    y_train = []
    for example in dataset["train"]:
        tokens = example["tokens"]
        tags = example["ner_tags"]
        X_train.append(sent2features(tokens))
        y_train.append(sent2labels(tags, id_to_label))

    # Prepare test data
    X_test = []
    y_test = []
    for example in dataset["test"]:
        tokens = example["tokens"]
        tags = example["ner_tags"]
        X_test.append(sent2features(tokens))
        y_test.append(sent2labels(tags, id_to_label))

    return X_train, y_train, X_test, y_test

# CRF evaluation with seqeval - without tabulate
# -------------------------
def evaluate_crf_ner_system(crf_model, X_test, y_test):
    """Evaluate CRF NER system using seqeval."""
    # Get predictions
    y_pred = crf_model.predict(X_test)

    # Evaluate using seqeval
    metric = load("seqeval")
    results = metric.compute(predictions=y_pred, references=y_test)

    # Display results
    print("CRF NER Evaluation Results")
    print("==========================")

    # Create detailed output for per-entity results
    print("{:<10} {:<10} {:<10} {:<10} {:<10}".format("Entity", "Precision", "Recall", "F1 Score", "Support"))
    print("-" * 50)

    for entity, metrics in results.items():
        if isinstance(metrics, dict) and any(e in entity for e in crf_model.allowed_entities):
            print("{:<10} {:<10.4f} {:<10.4f} {:<10.4f} {:<10}".format(
                entity,
                metrics['precision'],
                metrics['recall'],
                metrics['f1'],
                metrics['number']
            ))

    # Overall results
    print("\nOverall Metrics")
    print("==============")
    print(f"Overall Precision: {results['overall_precision']:.4f}")
    print(f"Overall Recall: {results['overall_recall']:.4f}")
    print(f"Overall F1 Score: {results['overall_f1']:.4f}")
    print(f"Overall Accuracy: {results['overall_accuracy']:.4f}")

    return results

# Generate detailed seqeval report
# ------------------------------
def generate_crf_seqeval_report(crf_model, X_test, y_test):
    """Generate detailed seqeval classification report for CRF."""
    # Get predictions
    y_pred = crf_model.predict(X_test)

    # Generate seqeval classification report
    metric = load("seqeval")
    print("CRF NER Classification Report (seqeval)")
    print("======================================")
    print(metric.compute(predictions=y_pred, references=y_test, mode='strict', scheme='IOB2'))

    return y_pred, y_test

# Analyze CRF model features
# ------------------------
def analyze_crf_features(crf_model):
    """Analyze important features learned by CRF."""
    # Get state and transition features
    state_features = crf_model.get_state_features()
    transition_features = crf_model.get_transition_features()

    # Top state features for each label
    print("Top State Features by Label:")
    print("===========================")
    label_features = {}

    for (label, feature), weight in state_features.items():
        if label not in label_features:
            label_features[label] = []
        label_features[label].append((feature, weight))

    # Sort and display top features for each label
    for label, features in label_features.items():
        if any(e in label for e in ["PER", "LOC", "ORG", "MISC"]):
            print(f"\n{label}")
            print("-" * len(label))
            top_features = sorted(features, key=lambda x: abs(x[1]), reverse=True)[:10]
            for feature, weight in top_features:
                print(f"  {feature}: {weight:.4f}")

    # Top transition features
    print("\nTop Transition Features:")
    print("======================")
    top_transitions = sorted(transition_features.items(), key=lambda x: abs(x[1]), reverse=True)[:20]
    for (from_label, to_label), weight in top_transitions:
        print(f"  {from_label} -> {to_label}: {weight:.4f}")

# Modified analysis functions that don't use tabulate
# -------------------------------------------------
def extract_entities(tokens, tags):
    """Extract entity spans from tokens and tags."""
    entities = []
    current_entity = None
    current_type = None

    for i, (token, tag) in enumerate(zip(tokens, tags)):
        if tag.startswith('B-'):
            if current_entity:
                entities.append((current_type, ' '.join(current_entity)))
            current_entity = [token]
            current_type = tag[2:]
        elif tag.startswith('I-') and current_entity:
            current_entity.append(token)
        else:
            if current_entity:
                entities.append((current_type, ' '.join(current_entity)))
                current_entity = None
                current_type = None

    if current_entity:
        entities.append((current_type, ' '.join(current_entity)))

    return entities

def analyze_examples(crf_model, dataset, num_examples=10):
    """Analyze CRF predictions on specific examples without tabulate."""
    id_to_label = dataset["train"].features["ner_tags"].feature.int2str
    test_examples = dataset["test"]

    print("Analyzing Specific Examples:")
    print("-" * 80)

    found_examples = 0

    for i in range(len(test_examples)):
        if found_examples >= num_examples:
            break

        tokens = test_examples[i]["tokens"]
        true_tags = [id_to_label(tag) for tag in test_examples[i]["ner_tags"]]

        # Get CRF predictions
        features = sent2features(tokens)
        pred_tags = crf_model.predict([features])[0]

        # Check if any prediction is different from the true label
        has_mismatches = any(true != pred for true, pred in zip(true_tags, pred_tags))

        # Always show examples with mismatches, or until we have enough examples
        if has_mismatches or found_examples < num_examples:
            found_examples += 1
            print(f"Example {found_examples}:")
            print(f"Sentence: {' '.join(tokens)}")

            # Create alignment table - using simple format without tabulate
            print("\nToken analysis:")
            print("{:<6} {:<15} {:<15} {:<15}".format("Index", "Token", "True Label", "Predicted Label"))
            print("-" * 60)

            for j, (token, true, pred) in enumerate(zip(tokens, true_tags, pred_tags)):
                print("{:<6} {:<15} {:<15} {:<15}".format(j, token, true, pred))

            # Extract entities
            true_entities = extract_entities(tokens, true_tags)
            pred_entities = extract_entities(tokens, pred_tags)

            print(f"\nTrue Entities: {true_entities}")
            print(f"Predicted Entities: {pred_entities}")
            print("-" * 80)

def analyze_confidence_scores(crf_model, X_test, num_examples=5):
    """Analyze prediction confidence scores without tabulate."""
    print("Analyzing Prediction Confidence:")
    print("-" * 80)

    for i in range(min(num_examples, len(X_test))):
        features = X_test[i]
        # Get marginal probabilities
        marginals = crf_model.crf.predict_marginals_single(features)
        predictions = crf_model.predict([features])[0]

        print(f"Example {i+1}:")
        print("{:<10} {:<10} {:<15} {:<40}".format("Position", "Predicted", "Confidence", "Top 3 Labels"))
        print("-" * 80)

        for j, (pred, probs) in enumerate(zip(predictions, marginals)):
            # Get top 3 probable labels
            sorted_probs = sorted(probs.items(), key=lambda x: x[1], reverse=True)[:3]
            top_labels = ', '.join([f"{l}({p:.3f})" for l, p in sorted_probs])

            # Highlight low confidence predictions
            confidence = probs[pred]
            confidence_str = f"{confidence:.6f}"

            print("{:<10} {:<10} {:<15} {:<40}".format(j, pred, confidence_str, top_labels))

        print("-" * 80)

def run_detailed_error_analysis(crf_model, dataset):
    """Run detailed error analysis on the CRF model predictions."""
    id_to_label = dataset["train"].features["ner_tags"].feature.int2str
    test_examples = dataset["test"]

    # Error pattern analysis
    error_patterns = {
        'LOC_as_PER': 0,
        'PER_as_LOC': 0,
        'LOC_as_ORG': 0,
        'ORG_as_LOC': 0,
        'MISC_as_ORG': 0,
        'ORG_as_MISC': 0,
        'entity_as_O': 0,
        'O_as_entity': 0,
        'boundary_errors': 0
    }

    total_errors = 0
    total_tokens = 0

    for example in test_examples:
        tokens = example["tokens"]
        true_tags = [id_to_label(tag) for tag in example["ner_tags"]]

        # Get CRF predictions
        features = sent2features(tokens)
        pred_tags = crf_model.predict([features])[0]

        total_tokens += len(tokens)

        for true, pred in zip(true_tags, pred_tags):
            if true != pred:
                total_errors += 1

                # Analyze error type
                if true == 'O' and pred != 'O':
                    error_patterns['O_as_entity'] += 1
                elif true != 'O' and pred == 'O':
                    error_patterns['entity_as_O'] += 1
                elif true.endswith('LOC') and pred.endswith('PER'):
                    error_patterns['LOC_as_PER'] += 1
                elif true.endswith('PER') and pred.endswith('LOC'):
                    error_patterns['PER_as_LOC'] += 1
                elif true.endswith('LOC') and pred.endswith('ORG'):
                    error_patterns['LOC_as_ORG'] += 1
                elif true.endswith('ORG') and pred.endswith('LOC'):
                    error_patterns['ORG_as_LOC'] += 1
                elif true.endswith('MISC') and pred.endswith('ORG'):
                    error_patterns['MISC_as_ORG'] += 1
                elif true.endswith('ORG') and pred.endswith('MISC'):
                    error_patterns['ORG_as_MISC'] += 1
                elif true[0] != pred[0]:  # B/I mismatch
                    error_patterns['boundary_errors'] += 1

    # Print error analysis
    print("Detailed Error Pattern Analysis:")
    print("-" * 80)
    print(f"Total tokens analyzed: {total_tokens}")
    print(f"Total errors: {total_errors}")
    print(f"Error rate: {total_errors/total_tokens*100:.2f}%\n")

    print("Error patterns:")
    for pattern, count in sorted(error_patterns.items(), key=lambda x: x[1], reverse=True):
        if count > 0:
            percentage = count/total_errors*100
            print(f"  {pattern}: {count} ({percentage:.1f}% of errors)")

    return error_patterns, total_errors, total_tokens

def analyze_feature_contributions(crf_model, tokens, features, true_tags, pred_tags):
    """Analyze which features contributed to each prediction."""
    for i, (token, true, pred) in enumerate(zip(tokens, true_tags, pred_tags)):
        if true != pred:
            print(f"Token '{token}' - True: {true}, Predicted: {pred}")

            # Get top features for this token
            token_features = features[i]
            feature_weights = {}

            # Calculate feature contributions
            for fname, fval in token_features.items():
                if (pred, fname) in crf_model.crf.state_features_:
                    weight = crf_model.crf.state_features_[(pred, fname)]
                    feature_weights[fname] = weight * fval

                # Also check for true label weights for comparison
                if (true, fname) in crf_model.crf.state_features_:
                    true_weight = crf_model.crf.state_features_[(true, fname)]
                    # Store as tuple (pred_contrib, true_contrib)
                    if fname not in feature_weights:
                        feature_weights[fname] = (0, true_weight * fval)
                    else:
                        feature_weights[fname] = (feature_weights[fname], true_weight * fval)

            # Sort and display top contributing features
            sorted_features = sorted(feature_weights.items(),
                                  key=lambda x: abs(x[1] if isinstance(x[1], float) else x[1][0]),
                                  reverse=True)[:5]

            print("Top contributing features (predicted value contribution, true value contribution):")
            for fname, contrib in sorted_features:
                if isinstance(contrib, tuple):
                    print(f"  {fname}: predicted={contrib[0]:.4f}, true={contrib[1]:.4f}")
                else:
                    print(f"  {fname}: {contrib:.4f}")

# Main execution function with the improved analysis
def run_crf_evaluation_with_analysis():
    """Run CRF model training, evaluation, and detailed analysis."""
    # Load dataset
    print("Loading CoNLL2003 dataset...")
    dataset = load_dataset("conll2003")

    # Prepare data
    print("Preparing data for CRF...")
    X_train, y_train, X_test, y_test = prepare_crf_data(dataset)

    # Initialize and train CRF model
    print("Training CRF model...")
    crf_model = CRFNER(c1=0.1, c2=0.1, max_iterations=100)
    crf_model.train(X_train, y_train)

    # Evaluate model
    print("\nEvaluating CRF NER system...")
    results = evaluate_crf_ner_system(crf_model, X_test, y_test)

    # Generate detailed classification report
    print("\nGenerating detailed classification report...")
    generate_crf_seqeval_report(crf_model, X_test, y_test)

    # Analyze model features
    print("\nAnalyzing CRF model features...")
    analyze_crf_features(crf_model)

    # Analyze specific examples
    analyze_examples(crf_model, dataset, num_examples=10)

    # Run detailed error analysis
    run_detailed_error_analysis(crf_model, dataset)

    # Analyze confidence scores
    analyze_confidence_scores(crf_model, X_test, num_examples=5)

    return results

# Simple main function
def run_crf_evaluation():
    """Run basic CRF model training and evaluation."""
    # Load dataset
    print("Loading CoNLL2003 dataset...")
    dataset = load_dataset("conll2003")

    # Prepare data
    print("Preparing data for CRF...")
    X_train, y_train, X_test, y_test = prepare_crf_data(dataset)

    # Initialize and train CRF model
    print("Training CRF model...")
    crf_model = CRFNER(c1=0.1, c2=0.1, max_iterations=100)
    crf_model.train(X_train, y_train)

    # Evaluate model
    print("\nEvaluating CRF NER system...")
    results = evaluate_crf_ner_system(crf_model, X_test, y_test)

    # Generate detailed classification report
    print("\nGenerating detailed classification report...")
    generate_crf_seqeval_report(crf_model, X_test, y_test)

    # Analyze model features
    print("\nAnalyzing CRF model features...")
    analyze_crf_features(crf_model)

    return results

# Run the CRF evaluation
if __name__ == "__main__":
    run_crf_evaluation_with_analysis()

Loading CoNLL2003 dataset...
Preparing data for CRF...
Training CRF model...

Evaluating CRF NER system...
CRF NER Evaluation Results
Entity     Precision  Recall     F1 Score   Support   
--------------------------------------------------
LOC        0.8756     0.8735     0.8745     1668      
MISC       0.7731     0.7621     0.7676     702       
ORG        0.7936     0.7339     0.7626     1661      
PER        0.8479     0.8652     0.8564     1617      

Overall Metrics
Overall Precision: 0.8318
Overall Recall: 0.8162
Overall F1 Score: 0.8239
Overall Accuracy: 0.9617

Generating detailed classification report...
CRF NER Classification Report (seqeval)
{'LOC': {'precision': np.float64(0.8756009615384616), 'recall': np.float64(0.8735011990407674), 'f1': np.float64(0.8745498199279712), 'number': np.int64(1668)}, 'MISC': {'precision': np.float64(0.773121387283237), 'recall': np.float64(0.7621082621082621), 'f1': np.float64(0.7675753228120517), 'number': np.int64(702)}, 'ORG': {'precision