In [45]:
from itertools import chain

import nltk
from datasets import load_dataset
import sklearn
from sklearn_crfsuite import metrics
import scipy.stats
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_val_score, RandomizedSearchCV

import sklearn_crfsuite
from sklearn_crfsuite import scorers

# Data Preparation
This notebook uses the English ***`CoNLL-2003`*** dataset. The NER tags and POS tags of each word are stored as integer indices rather than as string labels. The mapping from each index to its label name is stored in the dataset's `features`. The cells below show example records, the full tag inventories, and basic size statistics for each split.

In [None]:
# Fetch CoNLL-2003 from the Hugging Face Hub.
# NOTE: trust_remote_code=True permits `datasets` to execute the dataset's own
# loading script downloaded from the Hub -- only keep it for sources you trust.
dataset = load_dataset("conll2003", trust_remote_code=True)

train_dataset, test_dataset = dataset['train'], dataset['test']

# Peek at the first raw (integer-encoded) record of each split.
print(train_dataset[0])
print(test_dataset[0])

In [36]:
# Label features for the per-token tag columns; `.names` lists the string label
# for each integer id and `.int2str` converts ids back to labels.
ner_tags, pos_tags = (
    dataset['train'].features[column].feature for column in ('ner_tags', 'pos_tags')
)

# NER label inventory and its size.
print("NER Tags:", ner_tags.names)
print("Number of NER Tags:", len(ner_tags.names))

# POS label inventory and its size.
print("POS Tags:", pos_tags.names)
print("Number of POS Tags:", len(pos_tags.names))

def convert_to_iob_with_labels(dataset, ner_tags, pos_tags):
    """Decode integer-encoded examples into lists of (token, POS tag, NER tag) triples.

    The NER tags are already in IOB form (B-/I-/O); this function only maps
    the stored class indices back to their string names so the sentences can
    be fed to the feature-extraction functions.

    Args:
        dataset: Iterable of examples. Each example is a mapping with
            'tokens', 'pos_tags' and 'ner_tags' sequences of equal length.
        ner_tags: Object exposing ``int2str`` for NER indices, e.g.
            ``dataset['train'].features['ner_tags'].feature``.
        pos_tags: Object exposing ``int2str`` for POS indices, e.g.
            ``dataset['train'].features['pos_tags'].feature``.

    Returns:
        A list with one entry per example, each a list of
        ``(token, pos_label, ner_label)`` tuples.

    Raises:
        ValueError: If an example's tokens, POS tags and NER tags differ in length.
    """
    iob_sents = []
    for example in dataset:
        tokens = example['tokens']
        pos_tag_indices = example['pos_tags']
        ner_tag_indices = example['ner_tags']
        # Misaligned annotations would otherwise raise a bare IndexError or
        # silently drop trailing tags; fail loudly with a clear message instead.
        if not (len(tokens) == len(pos_tag_indices) == len(ner_tag_indices)):
            raise ValueError(
                "Misaligned example: %d tokens, %d POS tags, %d NER tags"
                % (len(tokens), len(pos_tag_indices), len(ner_tag_indices))
            )
        iob_sents.append([
            (token, pos_tags.int2str(pos_idx), ner_tags.int2str(ner_idx))
            for token, pos_idx, ner_idx in zip(tokens, pos_tag_indices, ner_tag_indices)
        ])
    return iob_sents

# Decode both splits into sentences of (token, POS, NER) string triples.
train_sents, test_sents = (
    convert_to_iob_with_labels(dataset[split_name], ner_tags, pos_tags)
    for split_name in ('train', 'test')
)

# Show the first decoded sentence of each split.
print(train_sents[0])
print(test_sents[0])

def count_sentences_and_words(dataset_split):
    """Return ``(num_sentences, num_words)`` for a split.

    Each element of ``dataset_split`` is one sentence; words are counted as
    the total length of every sentence's 'tokens' list.
    """
    total_tokens = 0
    for record in dataset_split:
        total_tokens += len(record['tokens'])
    return len(dataset_split), total_tokens

# Sentence and token counts for each split.
num_train_sentences, num_train_words = count_sentences_and_words(dataset['train'])
num_test_sentences, num_test_words = count_sentences_and_words(dataset['test'])

print(f"Number of sentences in train dataset: {num_train_sentences}")
print(f"Number of words in train dataset: {num_train_words}")
print(f"Number of sentences in test dataset: {num_test_sentences}")
print(f"Number of words in test dataset: {num_test_words}")

NER Tags: ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
Number of NER Tags: 9
POS Tags: ['"', "''", '#', '$', '(', ')', ',', '.', ':', '``', 'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB']
Number of POS Tags: 47
[('EU', 'NNP', 'B-ORG'), ('rejects', 'VBZ', 'O'), ('German', 'JJ', 'B-MISC'), ('call', 'NN', 'O'), ('to', 'TO', 'O'), ('boycott', 'VB', 'O'), ('British', 'JJ', 'B-MISC'), ('lamb', 'NN', 'O'), ('.', '.', 'O')]
[('SOCCER', 'NN', 'O'), ('-', ':', 'O'), ('JAPAN', 'NNP', 'B-LOC'), ('GET', 'VB', 'O'), ('LUCKY', 'NNP', 'O'), ('WIN', 'NNP', 'O'), (',', ',', 'O'), ('CHINA', 'NNP', 'B-PER'), ('IN', 'IN', 'O'), ('SURPRISE', 'DT', 'O'), ('DEFEAT', 'NN', 'O'), ('.', '.', 'O')]
Number of sentences in train dataset: 14041
Number of words in train dataset: 203621


# Feature Extraction Functions
**`word2features`:** Extracts features for a single word in a sentence, using the word itself and its immediate left and right neighbours.  
Each word (token) in a sequence is represented by a set of features, and each feature contributes to the score of a particular label being assigned to that word. During training, the CRF:
- Examines patterns in the features. For instance, if the word "Paris" has `word.istitle()=True` and is often labeled `B-LOC`, the model increases the weight of that feature-label pair (`word.istitle()=True` → `B-LOC`).
- Learns which features and transitions are useful for predicting labels in different contexts (e.g., "John" is labeled `B-PER` when `word.istitle()=True`, but "lives" is labeled `O` regardless of capitalization).  

**`sent2features`:** Extracts features for all words in a sentence.  
**`sent2labels`:** Extracts the NER labels for all words in a sentence.  
**`sent2tokens`:** Extracts the tokens of a sentence.


In [37]:
def word2features(sent, i):
    """Build the CRF feature dict for the token at position ``i`` of ``sent``.

    ``sent`` is a list of tuples whose first two fields are (token, pos_tag);
    any further fields are ignored. Features describe the token itself plus a
    one-token window on each side. At a sentence edge, the missing neighbour's
    features are replaced by a 'BOS' (left) or 'EOS' (right) flag.
    """
    token, tag = sent[i][0], sent[i][1]

    feats = {'bias': 1.0}
    # Lexical form and shape of the current token.
    feats['word.lower()'] = token.lower()
    feats['word[-3:]'] = token[-3:]
    feats['word[-2:]'] = token[-2:]
    feats['word.isupper()'] = token.isupper()
    feats['word.istitle()'] = token.istitle()
    feats['word.isdigit()'] = token.isdigit()
    # Full POS tag and its two-character prefix (e.g. 'NNP' -> 'NN').
    feats['postag'] = tag
    feats['postag[:2]'] = tag[:2]

    # Left context, or a beginning-of-sentence marker.
    if i <= 0:
        feats['BOS'] = True
    else:
        prev_token, prev_tag = sent[i - 1][0], sent[i - 1][1]
        feats['-1:word.lower()'] = prev_token.lower()
        feats['-1:word.istitle()'] = prev_token.istitle()
        feats['-1:word.isupper()'] = prev_token.isupper()
        feats['-1:postag'] = prev_tag
        feats['-1:postag[:2]'] = prev_tag[:2]

    # Right context, or an end-of-sentence marker.
    if i >= len(sent) - 1:
        feats['EOS'] = True
    else:
        next_token, next_tag = sent[i + 1][0], sent[i + 1][1]
        feats['+1:word.lower()'] = next_token.lower()
        feats['+1:word.istitle()'] = next_token.istitle()
        feats['+1:word.isupper()'] = next_token.isupper()
        feats['+1:postag'] = next_tag
        feats['+1:postag[:2]'] = next_tag[:2]

    return feats


def sent2features(sent):
    """Featurise every token of ``sent``; the result is aligned position-by-position with the sentence."""
    features = []
    for position in range(len(sent)):
        features.append(word2features(sent, position))
    return features

def sent2labels(sent):
    """Return the NER label (third field) of each (token, pos, label) triple in ``sent``."""
    ner_labels = []
    for _token, _pos, ner_label in sent:
        ner_labels.append(ner_label)
    return ner_labels

def sent2tokens(sent):
    """Return the surface token (first field) of each (token, pos, label) triple in ``sent``."""
    surface_forms = []
    for word, _pos, _label in sent:
        surface_forms.append(word)
    return surface_forms



In [38]:
# Per-sentence feature dict sequences (X) and gold NER label sequences (y).
X_train = [sent2features(sentence) for sentence in train_sents]
X_test = [sent2features(sentence) for sentence in test_sents]
y_train = [sent2labels(sentence) for sentence in train_sents]
y_test = [sent2labels(sentence) for sentence in test_sents]

# Inspect the first training sentence next to its features and labels.
print("Original sentence:", ' '.join(sent2tokens(train_sents[0])))
print(X_train[0])
print(y_train[0])

Original sentence: EU rejects German call to boycott British lamb .
[{'bias': 1.0, 'word.lower()': 'eu', 'word[-3:]': 'EU', 'word[-2:]': 'EU', 'word.isupper()': True, 'word.istitle()': False, 'word.isdigit()': False, 'postag': 'NNP', 'postag[:2]': 'NN', 'BOS': True, '+1:word.lower()': 'rejects', '+1:word.istitle()': False, '+1:word.isupper()': False, '+1:postag': 'VBZ', '+1:postag[:2]': 'VB'}, {'bias': 1.0, 'word.lower()': 'rejects', 'word[-3:]': 'cts', 'word[-2:]': 'ts', 'word.isupper()': False, 'word.istitle()': False, 'word.isdigit()': False, 'postag': 'VBZ', 'postag[:2]': 'VB', '-1:word.lower()': 'eu', '-1:word.istitle()': False, '-1:word.isupper()': True, '-1:postag': 'NNP', '-1:postag[:2]': 'NN', '+1:word.lower()': 'german', '+1:word.istitle()': True, '+1:word.isupper()': False, '+1:postag': 'JJ', '+1:postag[:2]': 'JJ'}, {'bias': 1.0, 'word.lower()': 'german', 'word[-3:]': 'man', 'word[-2:]': 'an', 'word.isupper()': False, 'word.istitle()': True, 'word.isdigit()': False, 'postag'

# Training the model
1. Training Objective  
During training, the CRF maximizes the conditional probability of the correct label sequence 
$y$ given the feature sequence $X$. It does this by learning a weight for each feature-label pair and for each label-to-label transition.

The objective function is:  
$$
\log P(y|X) = \sum_{t=1}^T (\theta \cdot f(y_{t-1}, y_t, X, t)) - \log Z(X)
$$
- $f(y_{t-1}, y_t, X, t)$: Feature function that captures how well the label transition and the observed features match.
- $\theta$: Weights of features (learned during training).
- $Z(X)$: Normalization factor (partition function) to make the probability valid.


2. Using **`lbfgs`** Algorithm for Optimization

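**`lbfgs`** (Limited-memory Broyden–Fletcher–Goldfarb–Shanno) is a quasi-Newton optimization algorithm that minimizes the regularized negative log-likelihood efficiently. A CRF has a very large number of feature weights. L-BFGS never stores the full Hessian; instead, it approximates the inverse Hessian from a short history of recent gradient and parameter updates. This keeps memory roughly linear in the number of weights.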

Regularized negative log-likelihood (NLL) minimized by the optimizer:
$$L(\theta) = -\sum_{(X,y)} \log P(y | X; \theta) + c_1 \cdot \sum_i |\theta_i| + c_2 \cdot \sum_i \theta_i^2$$
- `c1`: L1 regularization coefficient (encourages sparsity in the model).  
- `c2`: L2 regularization coefficient (penalizes large weights, encouraging smoother models).   
- `max_iterations`: Upper bound on the number of optimizer iterations.  
- `all_possible_transitions=True`: Generates transition features for every pair of labels, including pairs that never occur in the training data (e.g. `O -> I-PER`). The model can then learn negative weights for such transitions, which discourages invalid label sequences at prediction time.


Update rule:
$$\theta_{k+1} = \theta_{k} - \alpha_{k} \cdot H_{k}^{-1} \nabla L(\theta_{k})$$

where:

* $\theta_{k}$ is the weight vector at iteration $k$. The index $k$ is used here because $t$ already denotes token positions in the training objective.
* $\alpha_{k}$ is the step size (learning rate), typically chosen by a line search.
* $\nabla L(\theta_{k})$ is the gradient of the regularized NLL with respect to the weights.
* $H_{k}^{-1}$ is the L-BFGS approximation of the inverse Hessian (the matrix of second-order derivatives of $L$).




In [39]:
# L-BFGS training with elastic-net regularisation (c1 = L1, c2 = L2).
# all_possible_transitions=True also creates weights for label pairs unseen in training.
CRF_PARAMS = dict(
    algorithm='lbfgs',
    c1=0.1,
    c2=0.1,
    max_iterations=100,
    all_possible_transitions=True,
)
crf = sklearn_crfsuite.CRF(**CRF_PARAMS)
crf.fit(X_train, y_train)



In [40]:
labels = list(crf.classes_)
y_pred = crf.predict(X_test)

# Token-level weighted F1 over every label, including the majority 'O' class.
# (Previously this value was computed and then discarded.)
f1_all = metrics.flat_f1_score(y_test, y_pred, average='weighted', labels=labels)
# 'O' accounts for 38323 of the 46435 test tokens in this cell's saved report, so it
# dominates the weighted average; the entity-only score is more informative.
# NOTE: these are token-level scores, not the entity-span F1 of the official CoNLL evaluation.
entity_labels = [label for label in labels if label != 'O']
f1_entities = metrics.flat_f1_score(y_test, y_pred, average='weighted', labels=entity_labels)
print(f"Weighted token F1 (all labels): {f1_all:.3f}")
print(f"Weighted token F1 (entity labels, excluding 'O'): {f1_entities:.3f}")

# Group B- and I- rows of the same entity type together ('O' sorts first).
sorted_labels = sorted(
    labels,
    key=lambda name: (name[1:], name[0])
)
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
))

              precision    recall  f1-score   support

           O      0.989     0.989     0.989     38323
       B-LOC      0.856     0.814     0.834      1668
       I-LOC      0.745     0.626     0.681       257
      B-MISC      0.819     0.754     0.785       702
      I-MISC      0.688     0.653     0.670       216
       B-ORG      0.775     0.727     0.750      1661
       I-ORG      0.679     0.734     0.705       835
       B-PER      0.822     0.860     0.841      1617
       I-PER      0.861     0.951     0.904      1156

    accuracy                          0.956     46435
   macro avg      0.804     0.790     0.795     46435
weighted avg      0.956     0.956     0.956     46435



In [42]:
# Tag a single test sentence with the trained CRF and show token/tag pairs.
example_index = 0  # change to inspect a different test sentence
example_features = X_test[example_index]

# predict() takes a batch of sentences: wrap the one sentence, then unwrap the result.
predicted_tags = crf.predict([example_features])[0]

# Surface tokens of the same sentence, from the decoded test split.
example_tokens = sent2tokens(test_sents[example_index])
original_sentence = ' '.join(example_tokens)
predicted_ner = list(zip(example_tokens, predicted_tags))

print("Example sentence:", original_sentence)
print(predicted_ner)

Example sentence: SOCCER - JAPAN GET LUCKY WIN , CHINA IN SURPRISE DEFEAT .
[('SOCCER', 'O'), ('-', 'O'), ('JAPAN', 'B-LOC'), ('GET', 'O'), ('LUCKY', 'O'), ('WIN', 'O'), (',', 'O'), ('CHINA', 'B-LOC'), ('IN', 'O'), ('SURPRISE', 'O'), ('DEFEAT', 'O'), ('.', 'O')]


In [43]:
from collections import Counter

def print_transitions(trans_features):
    """Print one aligned 'FROM -> TO weight' row per ((from_label, to_label), weight) item."""
    for pair, w in trans_features:
        src, dst = pair
        row = "%-6s -> %-7s %0.6f" % (src, dst, w)
        print(row)

print("Top likely transitions:")
# Rank every learned transition weight once (descending); reuse for both ends.
ranked_transitions = Counter(crf.transition_features_).most_common()
print_transitions(ranked_transitions[:20])

print("\nTop unlikely transitions:")
print_transitions(ranked_transitions[-20:])

Top likely transitions:
B-PER  -> I-PER   5.005456
B-LOC  -> I-LOC   4.899898
I-MISC -> I-MISC  4.819772
B-MISC -> I-MISC  4.608090
I-LOC  -> I-LOC   4.411735
I-ORG  -> I-ORG   4.178668
B-ORG  -> I-ORG   3.825759
O      -> O       3.343132
I-PER  -> I-PER   2.929024
O      -> B-PER   2.010622
O      -> B-ORG   1.291255
O      -> B-LOC   0.957378
O      -> B-MISC  0.879598
B-LOC  -> O       0.719279
B-MISC -> O       0.442459
B-PER  -> O       0.064878
I-LOC  -> O       -0.100254
I-ORG  -> O       -0.134502
I-MISC -> O       -0.205254
I-PER  -> O       -0.213852

Top unlikely transitions:
B-LOC  -> B-PER   -3.088280
B-LOC  -> I-PER   -3.134822
B-LOC  -> I-MISC  -3.150349
B-MISC -> I-LOC   -3.167058
B-PER  -> B-MISC  -3.176339
B-PER  -> I-ORG   -3.194907
B-ORG  -> I-MISC  -3.225981
I-ORG  -> B-LOC   -3.359505
B-ORG  -> B-PER   -3.585883
B-PER  -> B-ORG   -3.614308
I-PER  -> B-PER   -3.739063
B-ORG  -> I-PER   -3.918017
B-ORG  -> B-LOC   -3.978385
B-MISC -> I-ORG   -4.187154
O      -> I-P

In [44]:
def print_state_features(state_features):
    """Print one 'weight label attribute' row per ((attribute, label), weight) item."""
    for key, w in state_features:
        attr_name, tag = key
        row = "%0.6f %-8s %s" % (w, tag, attr_name)
        print(row)

print("Top positive:")
# Rank every learned (attribute, label) weight once (descending); reuse for both ends.
ranked_state_features = Counter(crf.state_features_).most_common()
print_state_features(ranked_state_features[:30])

print("\nTop negative:")
print_state_features(ranked_state_features[-30:])

Top positive:
5.635804 B-PER    word.lower():clinton
5.608407 B-ORG    -1:word.lower():v
5.492983 B-LOC    +1:word.lower():1996-08-26
5.483221 I-LOC    -1:word.lower():wisc
5.404337 I-LOC    -1:word.lower():colo
5.334371 O        word.lower():minister
5.322884 O        word[-3:]:day
5.186727 B-LOC    +1:word.lower():1996-08-27
5.033312 B-PER    word.lower():ata-ur-rehman
4.997335 B-LOC    word.lower():hungary
4.987009 B-LOC    +1:word.lower():1996-08-25
4.983907 B-LOC    +1:word.lower():1996-08-23
4.878637 O        BOS
4.877019 B-LOC    +1:word.lower():1996-08-22
4.850749 B-LOC    word.lower():france
4.796517 B-LOC    word.lower():chester-le-street
4.783352 B-LOC    +1:word.lower():1996-08-28
4.723597 O        word.lower():august
4.712476 B-MISC   word.lower():german
4.641039 B-PER    word.lower():stenning
4.620123 B-LOC    +1:word.lower():1996-08-29
4.619735 B-ORG    word.lower():sungard
4.610181 B-LOC    word.lower():england
4.602871 O        word.lower():march
4.576173 B-MISC   word