<a href="https://colab.research.google.com/github/Shriyatha/Named_Entity_Recognition/blob/main/RULE_BASED_NER_Telugu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install datasets evaluate seqeval spacy tabulate -q

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.8/194.8 kB[0m [31m13.7 MB/s[0m eta [36

In [3]:
from datasets import load_dataset
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
import spacy
import re
from collections import defaultdict, Counter
from tabulate import tabulate
import random
import numpy as np
from tqdm import tqdm
from seqeval.scheme import IOB2 as IOB2Scheme

In [11]:
class TeluguNER:
    def __init__(self):
        """Initialize the Telugu NER system with multilingual spaCy model and Telugu-specific patterns"""
        # Load spaCy's multilingual model
        try:
            self.nlp = spacy.load("xx_ent_wiki_sm")
        except:
            print("Installing spaCy multilingual model...")
            spacy.cli.download("xx_ent_wiki_sm")
            self.nlp = spacy.load("xx_ent_wiki_sm")

        # Entity configuration
        self.allowed_entities = {"PER", "LOC", "ORG"}
        self.spacy_to_wiki = {
            "PERSON": "PER",
            "GPE": "LOC",
            "LOC": "LOC",
            "ORG": "ORG",
            "NORP": "ORG",
            "FAC": "LOC"
        }

        # Initialize pattern stores and statistical trackers
        self._initialize_patterns()
        self._initialize_trackers()

    def _initialize_patterns(self):
        """Initialize Telugu-specific patterns for different entity types"""
        # Name patterns for person entities
        self.correction_patterns = {
            "PER": {
                "name_patterns": [
                    re.compile(r'.*(?:రెడ్డి|శర్మ|వర్మ|నాయుడు|చౌదరి|రావు|కుమార్|దేవి|గౌడ్|జోషి|పాటేల్|మీనా)$'),
                    re.compile(r'^(?:శ్రీ|శ్రీమతి|శ్రీమాన్|డాక్టర్|ప్రొఫెసర్) .*')
                ],
                "full_name_patterns": [
                    re.compile(r'^[^\s]+\s+[^\s]+(?:\s+[^\s]+)?$') # Fixed: Added closing parenthesis and $
                ]
            },
            "ORG": {
                "company_patterns": [
                    re.compile(r'.*(?:లిమిటెడ్|ప్రైవేట్|కంపెనీ|సంస్థ|బ్యాంక్)$'), # Fixed: Added closing parenthesis and $
                ],
                "educational_patterns": [
                    re.compile(r'^(?:విశ్వవిద్యాలయం|కళాశాల|పాఠశాల|ఇన్స్టిట్యూట్).*$'), # Fixed: Added closing parenthesis and $
                    re.compile(r'.*(?:విశ్వవిద్యాలయం|కళాశాల|పాఠశాల|ఇన్స్టిట్యూట్)$') # Fixed: Added closing parenthesis and $
                ]
            },
            "LOC": {
                "geographic_patterns": [
                    re.compile(r'.*(?:జిల్లా|నగరం|గ్రామం|రాష్ట్రం|దేశం|మండలం|కేంద్రం|నగర్)$'), # Fixed: Added closing parenthesis and $
                ],
                "natural_features": [
                    re.compile(r'.*(?:నది|పర్వతం|సముద్రం|సరస్సు|దక్షిణ|ఉత్తర|పశ్చిమ|తూర్పు)$') # Fixed: Added closing parenthesis and $
                ]
            }
        }

        # High-precision patterns
        self.high_precision_patterns = {
            "PER": [
                re.compile(r'.*(?:గారు|జీ)$'), # Fixed: Added closing parenthesis and $
                re.compile(r'^(?:డాక్టర్|ప్రొఫెసర్|నేతాజీ) .*')
            ],
            "ORG": [
                re.compile(r'.*(?:కార్పొరేషన్|సంఘం|మండలి|పరిషత్)$'), # Fixed: Added closing parenthesis and $
                re.compile(r'^(?:భారత|ఆంధ్రప్రదేశ్|తెలంగాణ) .*(?:ప్రభుత్వం|సంస్థ)$') # Fixed: Added closing parenthesis and $
            ],
            "LOC": [
                re.compile(r'.*(?:రాష్ట్రం|దేశం|రాజధాని|ప్రాంతం)$'), # Fixed: Added closing parenthesis and $
            ]
        }

        # Context indicators
        self.context_indicators = {
            "PER": ["అయిన", "పేరు", "వ్యక్తి", "నటుడు", "నటి", "రచయిత", "కుమారుడు", "కుమార్తె", "అన్న", "అక్క"],
            "ORG": ["సంస్థ", "కంపెనీ", "బృందం", "సమూహం", "పార్టీ", "ఆఫీసు", "శాఖ", "యూనిట్"],
            "LOC": ["ప్రాంతం", "నగరం", "దేశం", "రాష్ట్రం", "గ్రామం"]
        }

        self.context_post_indicators = {
            "PER": ["చెప్పారు", "తెలిపారు", "అన్నారు", "పేర్కొన్నారు"],
            "ORG": ["ప్రకటించింది", "తెలిపింది", "నిర్వహించింది"],
            "LOC": ["లో", "నుండి", "వరకు", "లోని", "ప్రాంతంలో"]
        }
    def _initialize_trackers(self):
        """Initialize statistical trackers and performance metrics"""
        self.misclassification_corrections = defaultdict(list)
        self.confusion_matrix = {
            "PER": {"PER": 0, "LOC": 0, "ORG": 0, "O": 0},
            "LOC": {"PER": 0, "LOC": 0, "ORG": 0, "O": 0},
            "ORG": {"PER": 0, "LOC": 0, "ORG": 0, "O": 0},
            "O": {"PER": 0, "LOC": 0, "ORG": 0, "O": 0}
        }
        self.performance_metrics = {}
        self.error_examples = []
        self.entity_token_freq = {
            "PER": Counter(),
            "LOC": Counter(),
            "ORG": Counter()
        }
        self.token_confidence = {}
        self.entity_errors = defaultdict(lambda: {"missed": 0, "wrong_type": 0, "partial": 0})

    def find_token_span(self, tokens, text):
        """Find token span in a list of tokens by matching words.
        This is more robust than exact matching."""
        target_tokens = text.split()

        if not target_tokens:
            return None

        for i in range(len(tokens) - len(target_tokens) + 1):
            match = True
            for j, target_token in enumerate(target_tokens):
                if tokens[i+j].lower() != target_token.lower():
                    match = False
                    break
            if match:
                return i, i + len(target_tokens)
        return None

    def create_balanced_dataset(self, dataset_split="train", max_samples_per_class=2000, min_samples_per_class=500):
        """Create a balanced dataset by downsampling majority classes"""
        print(f"Creating balanced dataset from {dataset_split} split...")

        split_data = dataset[dataset_split]
        id_to_label = split_data.features["ner_tags"].feature.int2str

        # Count entity types
        entity_examples = defaultdict(list)
        other_examples = []

        for idx, (tokens, tags) in enumerate(zip(split_data["tokens"], split_data["ner_tags"])):
            string_tags = [id_to_label(tag) for tag in tags]
            entity_types_in_example = set()

            for tag in string_tags:
                if tag != "O":
                    entity_type = tag[2:] if "-" in tag else tag
                    entity_types_in_example.add(entity_type)

            if entity_types_in_example:
                for entity_type in entity_types_in_example:
                    entity_examples[entity_type].append(idx)
            else:
                other_examples.append(idx)

        # Print entity distribution
        print("Entity distribution in original dataset:")
        for entity_type, examples in entity_examples.items():
            print(f"  {entity_type}: {len(examples)} examples")
        print(f"  None (O): {len(other_examples)} examples")

        # Select balanced samples
        selected_indices = []
        for entity_type, examples in entity_examples.items():
            num_samples = min(max_samples_per_class, len(examples))
            num_samples = max(num_samples, min_samples_per_class)
            selected_indices.extend(random.sample(examples, num_samples))

        # Add proportional "O" examples
        num_other_samples = min(len(selected_indices) // 2, len(other_examples))
        selected_indices.extend(random.sample(other_examples, num_other_samples))

        return {
            "tokens": [split_data["tokens"][i] for i in selected_indices],
            "ner_tags": [split_data["ner_tags"][i] for i in selected_indices]
        }

    def load_correction_patterns_from_training(self, dataset_data):
        """Learn common misclassifications from training data"""
        print("Learning correction patterns from training data...")
        id_to_label = dataset_data["train"].features["ner_tags"].feature.int2str

        spacy_errors = defaultdict(Counter)
        entity_stats = defaultdict(int)
        entity_tokens = defaultdict(Counter)
        train_data = dataset_data["train"]

        for idx, (tokens, tags) in enumerate(tqdm(zip(train_data["tokens"], train_data["ner_tags"]),
                                             total=len(train_data["tokens"]), desc="Analyzing training data")):
            string_tags = [id_to_label(tag) for tag in tags]
            text = " ".join(tokens)
            doc = self.nlp(text)

            # Track entity statistics and token frequencies
            i = 0
            while i < len(tokens):
                tag = string_tags[i]
                if tag.startswith("B-"):
                    entity_type = tag[2:]
                    entity_stats[entity_type] += 1
                    entity_tokens[entity_type][tokens[i]] += 1

                    j = i + 1
                    while j < len(string_tags) and string_tags[j].startswith("I-") and string_tags[j][2:] == entity_type:
                        entity_tokens[entity_type][tokens[j]] += 1
                        j += 1
                    i = j
                else:
                    i += 1

            # Extract entities from spaCy
            spacy_entities = []
            for ent in doc.ents:
                if ent.label_ in self.spacy_to_wiki:
                    wiki_label = self.spacy_to_wiki[ent.label_]
                    spacy_entities.append((ent.text, wiki_label, ent.start_char, ent.end_char))

            # Extract gold entities
            gold_entities = []
            i = 0
            while i < len(tokens):
                if string_tags[i].startswith("B-"):
                    entity_type = string_tags[i][2:]
                    start_idx = i
                    i += 1
                    while i < len(string_tags) and string_tags[i].startswith("I-") and string_tags[i][2:] == entity_type:
                        i += 1
                    end_idx = i
                    entity_text = " ".join(tokens[start_idx:end_idx])
                    gold_entities.append((entity_text, entity_type))
                else:
                    i += 1

            # Find missed entities
            for gold_text, gold_type in gold_entities:
                found = False
                for spacy_text, spacy_type, _, _ in spacy_entities:
                    if gold_text.lower() == spacy_text.lower():
                        found = True
                        break

                if not found:
                    spacy_errors[gold_type][gold_text] += 1

        print(f"Entity statistics in training data: {dict(entity_stats)}")
        self.entity_token_freq = entity_tokens
        self._calculate_token_confidence_scores()

        # Create correction patterns
        pattern_count = 0
        for entity_type, error_counter in spacy_errors.items():
            for phrase, count in error_counter.most_common(100):
                if count > 1 and len(phrase) > 2:
                    try:
                        # Use word boundary for more precise matching
                        escaped_phrase = re.escape(phrase)
                        pattern = re.compile(r'\b' + escaped_phrase + r'\b')
                        self.misclassification_corrections[entity_type].append((pattern, count))
                        pattern_count += 1
                    except re.error:
                        continue

        print(f"Created {pattern_count} correction patterns from training data")

    def _calculate_token_confidence_scores(self):
        """Calculate confidence scores for tokens based on their frequency in different entity types"""
        all_tokens = Counter()
        for entity_type, counter in self.entity_token_freq.items():
            for token, count in counter.items():
                all_tokens[token] += count

        # Only consider tokens that appear multiple times
        for token, total_count in all_tokens.items():
            if total_count < 3:  # Require at least 3 occurrences
                continue

            scores = {}
            for entity_type in self.allowed_entities:
                entity_count = self.entity_token_freq[entity_type].get(token, 0)
                if entity_count > 0:
                    scores[entity_type] = entity_count / total_count

            if scores:
                max_type = max(scores, key=scores.get)
                # Only keep tokens with strong entity association
                if scores[max_type] > 0.7:  # Increased threshold for higher precision
                    self.token_confidence[token] = (max_type, scores[max_type])

    def get_spacy_predictions(self, tokens):
        """Get entity predictions from spaCy model"""
        text = " ".join(tokens)
        doc = self.nlp(text)

        labels = ["O"] * len(tokens)
        confidence = [0.0] * len(tokens)

        # Create a mapping from character positions to token indices
        char_to_token = {}
        char_pos = 0
        for i, token in enumerate(tokens):
            for j in range(len(token)):
                char_to_token[char_pos + j] = i
            char_pos += len(token) + 1  # +1 for space

        # Apply spaCy entities
        for ent in doc.ents:
            if ent.label_ in self.spacy_to_wiki:
                entity_type = self.spacy_to_wiki[ent.label_]
                if entity_type in self.allowed_entities:
                    # Find token spans using character positions
                    try:
                        start_token = char_to_token.get(ent.start_char, None)
                        end_token = char_to_token.get(ent.end_char - 1, None)

                        if start_token is not None and end_token is not None:
                            labels[start_token] = f"B-{entity_type}"
                            confidence[start_token] = 0.7

                            for i in range(start_token + 1, end_token + 1):
                                if i < len(tokens):
                                    labels[i] = f"I-{entity_type}"
                                    confidence[i] = 0.7
                    except:
                        # Fall back to text matching if character positions fail
                        span = self.find_token_span(tokens, ent.text)
                        if span:
                            start, end = span
                            labels[start] = f"B-{entity_type}"
                            confidence[start] = 0.7
                            for i in range(start + 1, end):
                                labels[i] = f"I-{entity_type}"
                                confidence[i] = 0.7

        return labels, confidence

    def apply_regex_patterns(self, tokens):
        """Apply regex pattern-based entity detection with confidence scoring"""
        regex_labels = ["O"] * len(tokens)
        confidence_scores = [0.0] * len(tokens)
        text = " ".join(tokens)

        # Apply high precision patterns first
        for entity_type, patterns in self.high_precision_patterns.items():
            for pattern in patterns:
                # Check full text for matches
                try:
                    for match in pattern.finditer(text):
                        matched_text = match.group()
                        span = self.find_token_span(tokens, matched_text)
                        if span:
                            start, end = span
                            regex_labels[start] = f"B-{entity_type}"
                            confidence_scores[start] = 0.95
                            for i in range(start + 1, end):
                                regex_labels[i] = f"I-{entity_type}"
                                confidence_scores[i] = 0.95
                except:
                    continue

        # Apply regular patterns with lower confidence
        for entity_type, pattern_groups in self.correction_patterns.items():
            for group_name, patterns in pattern_groups.items():
                base_confidence = 0.85
                if "name_patterns" in group_name or "full_name_patterns" in group_name:
                    base_confidence = 0.9

                for pattern in patterns:
                    try:
                        for match in pattern.finditer(text):
                            matched_text = match.group()
                            span = self.find_token_span(tokens, matched_text)
                            if span and all(regex_labels[i] == "O" for i in range(span[0], span[1])):
                                start, end = span
                                regex_labels[start] = f"B-{entity_type}"
                                confidence_scores[start] = base_confidence
                                for i in range(start + 1, end):
                                    regex_labels[i] = f"I-{entity_type}"
                                    confidence_scores[i] = base_confidence
                    except:
                        continue

        return regex_labels, confidence_scores

    def check_context_indicators(self, tokens):
        """Use context words to identify potential entities"""
        context_labels = ["O"] * len(tokens)
        confidence_scores = [0.0] * len(tokens)

        # Forward context indicators
        for i, token in enumerate(tokens):
            if i < len(tokens) - 1:  # Make sure we're not at the last token
                for entity_type, indicators in self.context_indicators.items():
                    if token.lower() in indicators:
                        # The token after the indicator is likely the start of the entity
                        if i+1 < len(tokens):
                            context_labels[i+1] = f"B-{entity_type}"
                            confidence_scores[i+1] = 0.7

                            # Look for continuation of the entity (up to 3 tokens)
                            for j in range(i+2, min(i+5, len(tokens))):
                                if any(tokens[j].lower() in inds for entity, inds in self.context_indicators.items()):
                                    break
                                if any(tokens[j].lower() in inds for entity, inds in self.context_post_indicators.items()):
                                    break
                                context_labels[j] = f"I-{entity_type}"
                                confidence_scores[j] = 0.7

        # Post-context indicators
        for i, token in enumerate(tokens):
            if i > 0:  # Make sure we're not at the first token
                for entity_type, indicators in self.context_post_indicators.items():
                    if token.lower() in indicators:
                        # The token before the indicator is likely the end of the entity
                        if context_labels[i-1] == "O":
                            context_labels[i-1] = f"B-{entity_type}"
                            confidence_scores[i-1] = 0.7

                            # Look backward for potential entity beginning (up to 3 tokens)
                            start_idx = max(0, i-4)
                            for j in range(i-2, start_idx-1, -1):
                                if j < 0:
                                    break
                                if any(tokens[j].lower() in inds for entity, inds in self.context_indicators.items()):
                                    break
                                if context_labels[j] == "O":
                                    context_labels[j] = f"I-{entity_type}"
                                    confidence_scores[j] = 0.65

                            # Fix the BIO scheme - first token should be B-
                            for j in range(start_idx, i):
                                if context_labels[j].startswith("I-"):
                                    prefix = context_labels[j][2:]
                                    if j == 0 or context_labels[j-1] == "O" or context_labels[j-1][2:] != prefix:
                                        context_labels[j] = f"B-{prefix}"

        # Apply token-level confidence
        for i, token in enumerate(tokens):
            if token in self.token_confidence and context_labels[i] == "O":
                entity_type, conf = self.token_confidence[token]
                if conf > 0.8:  # Higher threshold for isolated tokens
                    context_labels[i] = f"B-{entity_type}"
                    confidence_scores[i] = conf * 0.8

        return context_labels, confidence_scores

    def apply_learned_corrections(self, tokens, base_labels):
        """Apply corrections based on learned misclassifications"""
        custom_labels = base_labels.copy()
        confidence_scores = [0.0] * len(tokens)
        for i, label in enumerate(base_labels):
            if label != "O":
                confidence_scores[i] = 0.6

        text = " ".join(tokens)

        for entity_type, patterns_with_counts in self.misclassification_corrections.items():
            sorted_patterns = sorted(patterns_with_counts, key=lambda x: x[1], reverse=True)

            for pattern, count in sorted_patterns:
                try:
                    for match in pattern.finditer(text):
                        matched_text = match.group()
                        span = self.find_token_span(tokens, matched_text)

                        if span:
                            start, end = span
                            # Higher confidence for frequently missed patterns
                            pattern_confidence = min(0.5 + (count / 20) * 0.4, 0.9)

                            # Only override if current labels are "O" or lower confidence
                            if custom_labels[start] == "O" or confidence_scores[start] < pattern_confidence:
                                custom_labels[start] = f"B-{entity_type}"
                                confidence_scores[start] = pattern_confidence

                                for i in range(start + 1, end):
                                    custom_labels[i] = f"I-{entity_type}"
                                    confidence_scores[i] = pattern_confidence
                except:
                    continue

        return custom_labels, confidence_scores

    def hybrid_entity_extraction(self, tokens):
        """Combine multiple methods for entity extraction using weighted confidence scores"""
        # Get predictions from each method
        spacy_labels, spacy_confidence = self.get_spacy_predictions(tokens)
        regex_labels, regex_confidence = self.apply_regex_patterns(tokens)
        context_labels, context_confidence = self.check_context_indicators(tokens)
        custom_labels, custom_confidence = self.apply_learned_corrections(tokens, spacy_labels)

        # Combine predictions with confidence weighting
        final_labels = ["O"] * len(tokens)
        confidence_info = [""] * len(tokens)

        for i in range(len(tokens)):
            predictions = []

            if custom_labels[i] != "O":
                predictions.append((custom_labels[i], custom_confidence[i], "custom"))

            if regex_labels[i] != "O":
                predictions.append((regex_labels[i], regex_confidence[i], "regex"))

            if context_labels[i] != "O":
                predictions.append((context_labels[i], context_confidence[i], "context"))

            if spacy_labels[i] != "O":
                predictions.append((spacy_labels[i], spacy_confidence[i], "spacy"))

            if predictions:
                # Choose prediction with highest confidence
                best_pred, best_conf, method = max(predictions, key=lambda x: x[1])
                final_labels[i] = best_pred
                confidence_info[i] = f"{method}:{best_conf:.2f}"
            else:
                confidence_info[i] = "O:0.00"

        # Ensure BIO consistency
        final_labels = self._ensure_bio_consistency(final_labels)

        # Post-processing
        final_labels = self._post_process_entities(tokens, final_labels)

        return {
            "spacy": spacy_labels,
            "regex": regex_labels,
            "context": context_labels,
            "custom": custom_labels,
            "final": final_labels,
            "confidence": confidence_info
        }

    def _ensure_bio_consistency(self, labels):
        """Ensure BIO tagging is consistent"""
        consistent_labels = labels.copy()

        for i in range(1, len(labels)):
            # Fix I- tags that don't follow matching B- or I- tags
            if consistent_labels[i].startswith("I-"):
                entity_type = consistent_labels[i][2:]

                # If previous tag is O or different entity, convert to B-
                if (consistent_labels[i-1] == "O" or
                    (consistent_labels[i-1].startswith("B-") and consistent_labels[i-1][2:] != entity_type) or
                    (consistent_labels[i-1].startswith("I-") and consistent_labels[i-1][2:] != entity_type)):
                    consistent_labels[i] = f"B-{entity_type}"

        return consistent_labels

    def _post_process_entities(self, tokens, labels):
        """Apply post-processing rules to improve entity consistency"""
        processed_labels = labels.copy()

        # Rule 1: Fix single-token entities with low-confidence between same-type entities
        for i in range(1, len(tokens)-1):
            if processed_labels[i-1].startswith(("B-", "I-")) and processed_labels[i+1].startswith(("B-", "I-")):
                type_before = processed_labels[i-1][2:]
                type_after = processed_labels[i+1][2:]

                if type_before == type_after and processed_labels[i] == "O":
                    processed_labels[i] = f"I-{type_before}"

                # Fix B- at position i+1 if it's part of the same entity
                if processed_labels[i+1].startswith("B-") and type_before == type_after:
                    processed_labels[i+1] = f"I-{type_after}"

        # Rule 2: Fix adjacent entity boundaries
        for i in range(len(tokens)-1):
            if processed_labels[i].startswith("B-") and processed_labels[i+1].startswith("B-"):
                type_current = processed_labels[i][2:]
                type_next = processed_labels[i+1][2:]

                # If they're the same type, convert the second B- to I-
                if type_current == type_next:
                    processed_labels[i+1] = f"I-{type_next}"

        # Rule 3: Fix inconsistent I- tags
        for i in range(1, len(tokens)):
            if processed_labels[i].startswith("I-"):
                type_current = processed_labels[i][2:]

                # If previous is B- or I- of different type, convert to B-
                if processed_labels[i-1].startswith(("B-", "I-")):
                    type_prev = processed_labels[i-1][2:]
                    if type_prev != type_current:
                        processed_labels[i] = f"B-{type_current}"

        return processed_labels

    def evaluate_on_dataset(self, dataset_split="test", limit=None):
        """Evaluate the NER system on a dataset split"""
        print(f"Evaluating on {dataset_split} set...")

        split_data = dataset[dataset_split]
        id_to_label = split_data.features["ner_tags"].feature.int2str

        # Initialize metrics tracking
        true_predictions = []
        true_labels = []
        errors = []
        entity_errors = defaultdict(lambda: {"missed": 0, "wrong_type": 0, "partial": 0})
        confusion_matrix = {
            "PER": {"PER": 0, "LOC": 0, "ORG": 0, "O": 0},
            "LOC": {"PER": 0, "LOC": 0, "ORG": 0, "O": 0},
            "ORG": {"PER": 0, "LOC": 0, "ORG": 0, "O": 0},
            "O": {"PER": 0, "LOC": 0, "ORG": 0, "O": 0}
        }

        # Process each example
        max_examples = min(limit, len(split_data)) if limit else len(split_data)
        for idx, (tokens, tags) in enumerate(tqdm(zip(split_data["tokens"], split_data["ner_tags"]),
                                                total=max_examples,
                                                desc=f"Evaluating {dataset_split}")):
            if limit and idx >= limit:
                break

            # Convert numeric tags to string labels
            string_tags = [id_to_label(tag) for tag in tags]

            # Ensure gold labels are in proper BIO format
            string_tags = self._ensure_bio_consistency(string_tags)
            true_labels.append(string_tags)

            # Get predictions
            results = self.hybrid_entity_extraction(tokens)
            final_preds = results["final"]

            # Ensure predictions are in proper BIO format
            final_preds = self._ensure_bio_consistency(final_preds)
            true_predictions.append(final_preds)

            # Record errors and confusion
            error_examples = []
            for i, (token, true, pred) in enumerate(zip(tokens, string_tags, final_preds)):
                true_type = true[2:] if "-" in true else "O"
                pred_type = pred[2:] if "-" in pred else "O"

                # Update confusion matrix
                if true_type in confusion_matrix and pred_type in confusion_matrix[true_type]:
                    confusion_matrix[true_type][pred_type] += 1

                # Record errors
                if true != pred:
                    # Get context (3 tokens before and after)
                    start_idx = max(0, i - 3)
                    end_idx = min(len(tokens), i + 4)
                    context = " ".join(tokens[start_idx:end_idx])

                    # Determine error type
                    error_type = "other"
                    if true.startswith("B-") and pred == "O":
                        error_type = "missed"
                        entity_errors[true_type]["missed"] += 1
                    elif true.startswith("B-") and pred.startswith("B-"):
                        error_type = "wrong_type"
                        entity_errors[true_type]["wrong_type"] += 1
                    elif true.startswith("I-") and (pred == "O" or pred.startswith("B-")):
                        error_type = "partial"
                        entity_entity = true[2:]
                        entity_errors[entity_entity]["partial"] += 1

                    error_examples.append({
                        "token": token,
                        "true": true,
                        "predicted": pred,
                        "context": context,
                        "error_type": error_type
                    })

            if error_examples:
                errors.extend(error_examples)

        # Calculate metrics using seqeval
        metrics = classification_report(
            true_labels, true_predictions,
            digits=4, output_dict=True,
            mode='strict', scheme=IOB2Scheme  # Use the class, not a string
        )

        # Store results for later analysis
        self.true_labels = true_labels
        self.true_predictions = true_predictions
        self.error_examples = errors
        self.confusion_matrix = confusion_matrix
        self.entity_errors = entity_errors
        self.performance_metrics = metrics

        # Extract overall metrics
        results = {
            "overall_precision": metrics["micro avg"]["precision"],
            "overall_recall": metrics["micro avg"]["recall"],
            "overall_f1": metrics["micro avg"]["f1-score"],
        }

        # Extract entity-specific metrics
        for entity_type in self.allowed_entities:
            entity_key = f"B-{entity_type}"
            if entity_key in metrics:
                results[entity_type] = {
                    "precision": metrics[entity_key]["precision"],
                    "recall": metrics[entity_key]["recall"],
                    "f1": metrics[entity_key]["f1-score"],
                    "support": metrics[entity_key]["support"]
                }

        return results

    def show_detailed_analysis(self):
        """Display classification report using seqeval in the specified format"""
        if not hasattr(self, 'true_labels') or not hasattr(self, 'true_predictions'):
            print("No evaluation results available. Run evaluate_on_dataset() first.")
            return

        # Generate the classification report using seqeval
        report = classification_report(
            self.true_labels,
            self.true_predictions,
            digits=4,
            output_dict=True
        )

        # Print the classification report header
        print("\nEntity      Precision   Recall   F1 score   Support")
        print("-" * 50)

        # Define the order of entities to display
        entity_order = ['LOC', 'ORG', 'PER']

        # Print metrics for each entity type
        for entity in entity_order:
            entity_key = f"{entity}"
            if entity_key in report:
                metrics = report[entity_key]
                print(f"{entity:<10} {metrics['precision']:>9.4f} {metrics['recall']:>8.4f} {metrics['f1-score']:>9.4f} {metrics['support']:>9}")

        # Print overall metrics
        print("\nOverall metrics:")
        print(f"Precision: {report['micro avg']['precision']:.4f}")
        print(f"Recall:    {report['micro avg']['recall']:.4f}")
        print(f"F1-score:  {report['micro avg']['f1-score']:.4f}")
        print(f"Support:   {report['micro avg']['support']}")

        # Print error types by entity
        print("\nError Types by Entity:")
        for entity_type, counts in self.entity_errors.items():
            total = sum(counts.values())
            if total > 0:
                print(f"{entity_type}:")
                for error_type, count in counts.items():
                    print(f"  {error_type}: {count} ({count/total:.1%})")

        # Print example errors
        print("\nExample Errors (first 10):")
        for error in self.error_examples[:10]:
            print(f"Token: {error['token']}")
            print(f"True: {error['true']}, Predicted: {error['predicted']}")
            print(f"Context: {error['context']}")
            print(f"Error Type: {error['error_type']}\n")


# Main execution
if __name__ == "__main__":
    # Load dataset
    dataset = load_dataset("wikiann", "te")

    # Initialize and run NER system
    ner = TeluguNER()
    ner.load_correction_patterns_from_training(dataset)
    metrics = ner.evaluate_on_dataset(limit=1000)
    ner.show_detailed_analysis()

    # Print final results
    print("\n=== Final Evaluation Results ===")
    print(f"Overall F1: {metrics['overall_f1']:.4f}")
    print(f"Precision: {metrics['overall_precision']:.4f}")
    print(f"Recall: {metrics['overall_recall']:.4f}")

    for entity_type in ner.allowed_entities:
        if entity_type in metrics:
            print(f"\n{entity_type} Metrics:")
            print(f"F1: {metrics[entity_type]['f1']:.4f}")
            print(f"Precision: {metrics[entity_type]['precision']:.4f}")
            print(f"Recall: {metrics[entity_type]['recall']:.4f}")
            print(f"Support: {metrics[entity_type]['support']}")

Learning correction patterns from training data...


Analyzing training data: 100%|██████████| 1000/1000 [00:02<00:00, 477.49it/s]


Entity statistics in training data: {'LOC': 493, 'ORG': 347, 'PER': 364}
Created 155 correction patterns from training data
Evaluating on test set...


Evaluating test: 100%|██████████| 1000/1000 [00:02<00:00, 380.85it/s]



Entity      Precision   Recall   F1 score   Support
--------------------------------------------------
LOC           0.2982   0.1889    0.2313       450
ORG           0.1305   0.1824    0.1521       340
PER           0.5357   0.2362    0.3279       381

Overall metrics:
Precision: 0.2554
Recall:    0.2024
F1-score:  0.2258
Support:   1171

Error Types by Entity:
ORG:
  missed: 203 (34.2%)
  wrong_type: 24 (4.0%)
  partial: 366 (61.7%)
LOC:
  missed: 274 (69.9%)
  wrong_type: 4 (1.0%)
  partial: 114 (29.1%)
PER:
  missed: 236 (49.7%)
  wrong_type: 18 (3.8%)
  partial: 221 (46.5%)

Example Errors (first 10):
Token: ల
True: I-ORG, Predicted: O
Context: ప్రపంచ మస్జిద్ ల జాబితా
Error Type: partial

Token: జాబితా
True: I-ORG, Predicted: O
Context: ప్రపంచ మస్జిద్ ల జాబితా
Error Type: partial

Token: EU
True: O, Predicted: B-ORG
Context: EU BY BLR 112
Error Type: other

Token: BY
True: O, Predicted: I-ORG
Context: EU BY BLR 112 బెలారస్
Error Type: other

Token: బెలారస్
True: B-LOC, Predicted: