In [4]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from datasets import load_dataset
from sklearn_crfsuite import CRF
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_curve, auc
from collections import Counter

# Plot styling: whitegrid background with the "muted" colour palette.
# `sns.set` is only an alias of `sns.set_theme` that seaborn may remove in a
# future release, so the preferred name is called directly.
sns.set_theme(style="whitegrid", palette="muted")

# Step 1: Load CoNLL-2003 dataset
conll2003 = load_dataset("conll2003")
# `ner_tags` values are used below as indices into this list, i.e.
# `label_names[tag_id]` is the tag string for a given id.
label_names = conll2003["train"].features["ner_tags"].feature.names
# Step 2: Prepare data for CRF
def prepare_data_for_crf(dataset, tag_names=None):
    """Split a dataset into parallel token and tag-name sequences.

    Args:
        dataset: Iterable of examples. Each example is indexed with
            ``'tokens'`` (the words of one sentence) and ``'ner_tags'``
            (one tag id per word, used as an index into ``tag_names``).
        tag_names: Optional sequence mapping a tag id to its string name.
            When omitted, the notebook-level ``label_names`` is used, which
            keeps existing one-argument calls working unchanged.

    Returns:
        A ``(sentences, labels)`` tuple of two lists with one entry per
        example: ``sentences[k]`` is the example's ``'tokens'`` value and
        ``labels[k]`` is the list of tag names for that example.
    """
    if tag_names is None:
        # Fall back to the mapping loaded alongside the dataset.
        tag_names = label_names

    sentences, labels = [], []
    for example in dataset:
        sentences.append(example['tokens'])
        labels.append([tag_names[tag] for tag in example['ner_tags']])
    return sentences, labels

# Materialise every split as parallel (token sequences, tag-name sequences) lists.
X_train, y_train = prepare_data_for_crf(dataset=conll2003['train'])
X_val, y_val = prepare_data_for_crf(dataset=conll2003['validation'])
X_test, y_test = prepare_data_for_crf(dataset=conll2003['test'])

# Step 3: Enhanced NER Tag Distribution Plot
def plot_ner_distribution(y_data, title):
    """Show a bar chart of how often each tag occurs in ``y_data``.

    Args:
        y_data: Iterable of tag sequences; every tag of every sequence is
            counted once.
        title: Text used as the figure title.

    The figure is displayed and then closed; nothing is returned.
    """
    # Tally tags in first-seen order, which is also the bar order.
    tag_totals = Counter()
    for sequence in y_data:
        for tag in sequence:
            tag_totals[tag] += 1
    tags = list(tag_totals)
    counts = [tag_totals[tag] for tag in tags]

    plt.figure(figsize=(12, 6))
    # hue mirrors x so each bar gets its own palette colour; the legend would
    # only repeat the tick labels, so it is switched off.
    sns.barplot(x=tags, y=counts, hue=tags, palette="viridis", legend=False)
    plt.title(title, fontsize=14, pad=10)
    plt.xlabel("NER Tags", fontsize=12)
    plt.ylabel("Count", fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()
    plt.close()

# One tag-distribution figure per split.
plot_ner_distribution(y_data=y_train, title="NER Tag Distribution in Training Set")
plot_ner_distribution(y_data=y_val, title="NER Tag Distribution in Validation Set")
plot_ner_distribution(y_data=y_test, title="NER Tag Distribution in Test Set")

# Step 4: Feature extraction for CRF
def word2features(sent, i):
    """Build the CRF feature dictionary for the token at position ``i``.

    The dictionary always holds a constant bias, the lower-cased token, its
    last three and last two characters, and its upper/title/digit flags.
    For each neighbour that exists, the neighbour's lower-cased form and its
    title/upper flags are added; where a neighbour is missing, ``'BOS'``
    (no left neighbour) or ``'EOS'`` (no right neighbour) is set to ``True``.

    Args:
        sent: Sequence of token strings.
        i: Index of the token to describe.

    Returns:
        dict mapping feature name to feature value.
    """
    token = sent[i]
    feats = {
        'bias': 1.0,
        'word.lower': token.lower(),
        'word[-3:]': token[-3:],
        'word[-2:]': token[-2:],
        'word.isupper': token.isupper(),
        'word.istitle': token.istitle(),
        'word.isdigit': token.isdigit()
    }

    # Left context, or the beginning-of-sentence marker.
    if i <= 0:
        feats['BOS'] = True
    else:
        left = sent[i - 1]
        feats['prev_word.lower'] = left.lower()
        feats['prev_word.istitle'] = left.istitle()
        feats['prev_word.isupper'] = left.isupper()

    # Right context, or the end-of-sentence marker.
    if i >= len(sent) - 1:
        feats['EOS'] = True
    else:
        right = sent[i + 1]
        feats['next_word.lower'] = right.lower()
        feats['next_word.istitle'] = right.istitle()
        feats['next_word.isupper'] = right.isupper()

    return feats

def extract_features(sentences):
    """Turn every sentence into a list of per-token feature dictionaries.

    Args:
        sentences: Iterable of token sequences.

    Returns:
        List with one entry per sentence; each entry holds the
        ``word2features`` result for every position of that sentence, in order.
    """
    featurised = []
    for tokens in sentences:
        per_token = [word2features(tokens, pos) for pos in range(len(tokens))]
        featurised.append(per_token)
    return featurised

# Featurise each split with the same extractor.
X_train_feats = extract_features(sentences=X_train)
X_val_feats = extract_features(sentences=X_val)
X_test_feats = extract_features(sentences=X_test)

# Step 5: Train CRF Model
# The model is fitted on the training split only; validation and test
# features are kept for evaluation.
crf = CRF(
    algorithm='lbfgs',
    max_iterations=100,
    c1=0.1,
    c2=0.1,
    all_possible_transitions=True,
)
crf.fit(X_train_feats, y_train)

# Step 6: Evaluation with all plots
def _plot_pr_curves(y_true_flat, token_marginals, labels, split_name, pr_labels):
    """Draw one precision-recall curve per label in ``pr_labels`` and show the figure."""
    plt.figure(figsize=(10, 6))
    plotted_curves = False
    for label in pr_labels:
        # A label that was not requested for evaluation, or an empty split,
        # cannot produce a curve; skip it instead of failing the evaluation.
        if label not in labels or not y_true_flat:
            continue
        try:
            y_score = [probs[label] for probs in token_marginals]
        except KeyError:
            # The model reports no marginal probability for this label.
            continue
        y_true_binary = [1 if tag == label else 0 for tag in y_true_flat]
        precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
        pr_auc = auc(recall, precision)
        plt.plot(recall, precision, label=f'{label} (AUC = {pr_auc:.2f})')
        plotted_curves = True
    plt.xlabel('Recall', fontsize=12)
    plt.ylabel('Precision', fontsize=12)
    plt.title(f'Precision-Recall Curve - {split_name}', fontsize=14)
    if plotted_curves:
        plt.legend(loc='best')
    else:
        plt.text(0.5, 0.5, 'No valid curves plotted', ha='center', va='center', fontsize=12)
    plt.grid(True)
    plt.tight_layout()
    plt.show()
    plt.close()


def _plot_confusion_matrix(y_true_flat, y_pred_flat, labels, split_name):
    """Show the token-level confusion matrix as an annotated heatmap."""
    cm = confusion_matrix(y_true_flat, y_pred_flat, labels=labels)
    plt.figure(figsize=(12, 10))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        xticklabels=labels,
        yticklabels=labels,
        cmap='RdYlBu',
        cbar_kws={'label': 'Count'},
        linewidths=0.5,
        linecolor='gray'
    )
    plt.xlabel("Predicted", fontsize=12)
    plt.ylabel("True", fontsize=12)
    plt.title(f"Confusion Matrix - {split_name}", fontsize=14, pad=10)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()
    plt.close()


def evaluate_model(X, y_true, split_name, labels, model=None,
                   pr_labels=('B-PER', 'I-PER', 'B-LOC', 'O')):
    """Predict on one split, print a classification report and plot diagnostics.

    Three things are produced, in this order: the printed token-level
    classification report, a precision-recall figure built from the model's
    marginal probabilities, and a confusion-matrix heatmap.

    Args:
        X: Per-sentence feature sequences passed to ``model.predict`` and
            ``model.predict_marginals``.
        y_true: Gold tag sequences, one per sentence of ``X``.
        split_name: Human-readable split name used in the printed header and
            in figure titles.
        labels: Tag names to report on; also fixes the row/column order of
            the report and of the confusion matrix.
        model: Fitted tagger to evaluate. Defaults to the notebook-level
            ``crf`` so existing four-argument calls keep working.
        pr_labels: Tags for which a precision-recall curve is drawn. Tags
            that are not in ``labels`` or that the model has no marginal for
            are skipped.

    Returns:
        ``(y_pred, report)``: the predicted tag sequences and the
        classification report as a dictionary (``output_dict=True``).
        NOTE(review): when ``labels`` does not match the tags present in the
        data, scikit-learn may emit a ``'micro avg'`` row instead of
        ``'accuracy'`` - confirm against the installed version.
    """
    if model is None:
        model = crf  # notebook-level fitted model
    labels = list(labels)

    y_pred = model.predict(X)
    y_true_flat = [tag for seq in y_true for tag in seq]
    y_pred_flat = [tag for seq in y_pred for tag in seq]

    # Without `labels=`, scikit-learn orders report rows by the sorted tags
    # found in the data and applies `target_names` positionally; because the
    # caller's label order is not sorted, rows ended up under the wrong
    # names. Passing the same list for both keeps name and row aligned.
    report_kwargs = {'labels': labels, 'target_names': labels}
    print(f"{split_name} Classification Report:")
    report = classification_report(y_true_flat, y_pred_flat, output_dict=True, **report_kwargs)
    print(classification_report(y_true_flat, y_pred_flat, **report_kwargs))

    # Precision-Recall Curve (using marginal probabilities).
    # Marginals do not depend on the label being plotted, so they are
    # computed once and flattened to token level here.
    token_marginals = [probs for seq in model.predict_marginals(X) for probs in seq]
    _plot_pr_curves(y_true_flat, token_marginals, labels, split_name, pr_labels)

    # Confusion Matrix
    _plot_confusion_matrix(y_true_flat, y_pred_flat, labels, split_name)

    return y_pred, report

# Evaluate on all sets, always reporting on the labels the fitted model knows.
y_train_pred, train_report = evaluate_model(
    X=X_train_feats, y_true=y_train, split_name="Training Set", labels=crf.classes_
)
y_val_pred, val_report = evaluate_model(
    X=X_val_feats, y_true=y_val, split_name="Validation Set", labels=crf.classes_
)
y_test_pred, test_report = evaluate_model(
    X=X_test_feats, y_true=y_test, split_name="Test Set", labels=crf.classes_
)

# Step 7: Transition Weights Visualization
def plot_transition_weights(crf, top_n=20):
    """Bar-plot the label transitions with the largest absolute weight.

    Args:
        crf: Object exposing ``transition_features_``, a mapping from a
            ``(from_label, to_label)`` pair to a numeric weight.
        top_n: Maximum number of transitions to draw, taken from the
            ranking by absolute weight (largest first).

    Raises:
        ValueError: If there is nothing to plot, i.e. the model has no
            transitions or ``top_n`` selects none. Previously this surfaced
            as an opaque tuple-unpacking ``ValueError``.
    """
    ranked = sorted(
        crf.transition_features_.items(),
        key=lambda item: abs(item[1]),
        reverse=True,
    )[:top_n]
    if not ranked:
        raise ValueError(
            "No transitions to plot: transition_features_ is empty or top_n selects none"
        )

    # Build the tick labels once; they double as the hue so that every bar
    # gets its own palette colour.
    names = [f"{src} → {dst}" for (src, dst), _ in ranked]
    weights = [weight for _, weight in ranked]

    plt.figure(figsize=(12, 6))
    sns.barplot(x=names, y=weights, hue=names, palette="magma", legend=False)
    plt.xticks(rotation=45, ha='right')
    plt.title("Top Transition Weights in CRF", fontsize=14)
    plt.xlabel("Transitions", fontsize=12)
    plt.ylabel("Weight", fontsize=12)
    plt.tight_layout()
    plt.show()
    plt.close()

# Visualise the 20 strongest learned label transitions of the fitted model.
plot_transition_weights(crf, top_n=20)

# Step 8: F1-Score Comparison Plot
def plot_f1_comparison(train_report, val_report, test_report, title):
    """Plot per-class F1 for the three splits as grouped bars.

    Args:
        train_report: Report dictionary for the training split, mapping a
            class name to a dict that contains ``'f1-score'``.
        val_report: Same for the validation split; its keys define which
            classes are drawn and in which order.
        test_report: Same for the test split.
        title: Figure title.

    Summary rows (accuracy and the averaged rows) are not classes and are
    left out. A class that is missing from one of the reports is drawn with
    an F1 of 0.0 instead of raising ``KeyError``.
    """
    # 'micro avg' (and 'samples avg') can replace or accompany 'accuracy' in
    # a scikit-learn report dict, so they are excluded along with the rest.
    summary_rows = {'accuracy', 'micro avg', 'macro avg', 'weighted avg', 'samples avg'}
    classes = [c for c in val_report.keys() if c not in summary_rows]

    def f1_scores(report):
        """Return the F1 of every plotted class, 0.0 where the class is absent."""
        return [report.get(c, {}).get('f1-score', 0.0) for c in classes]

    train_f1 = f1_scores(train_report)
    val_f1 = f1_scores(val_report)
    test_f1 = f1_scores(test_report)

    x = np.arange(len(classes))
    width = 0.25  # three bars per class, centred on the class tick
    plt.figure(figsize=(14, 6))
    plt.bar(x - width, train_f1, width, label='Training', color='#FF6F61', edgecolor='black')
    plt.bar(x, val_f1, width, label='Validation', color='#6B5B95', edgecolor='black')
    plt.bar(x + width, test_f1, width, label='Test', color='#88B04B', edgecolor='black')
    plt.xlabel('Classes', fontsize=12)
    plt.ylabel('F1-Score', fontsize=12)
    plt.title(title, fontsize=14)
    plt.xticks(x, classes, rotation=45, ha='right')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    plt.close()

# Compare per-class F1 across the three evaluated splits.
plot_f1_comparison(
    train_report=train_report,
    val_report=val_report,
    test_report=test_report,
    title="F1-Score Comparison: Training vs Validation vs Test",
)

ModuleNotFoundError: No module named 'datasets'

In [1]:
# Scratch arithmetic check; the value of the cell's last expression is echoed as its output.
2+2


4