# MasakhaNER Project: Named Entity Recognition for African Languages

## Introduction and Problem Statement

This notebook explores Named Entity Recognition (NER) for low-resource African languages using the MasakhaNER dataset. NER is a fundamental NLP task that identifies and classifies named entities in text into predefined categories such as person names, organizations, locations, and dates.

The MasakhaNER dataset covers 10 African languages: Amharic, Hausa, Igbo, Kinyarwanda, Luganda, Luo, Nigerian-Pidgin, Swahili, Wolof, and Yorùbá. This project addresses the challenge of developing effective NLP tools for languages with limited digital resources.

### Project Goals:
1. Explore and analyze the MasakhaNER dataset
2. Implement preprocessing pipelines for African languages
3. Develop and compare different NER models
4. Evaluate model performance across languages
5. Identify challenges and opportunities for low-resource NLP

### Setup and Imports

In [None]:
# Import required libraries
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter, defaultdict
import re
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForTokenClassification
from tqdm.notebook import tqdm
import requests
import zipfile
import io

# Set display options
# Print every DataFrame column (no column truncation) and use seaborn's white-grid plot style
pd.set_option('display.max_columns', None)
sns.set(style='whitegrid')

# Set random seeds for reproducibility
# The same seed (42) is applied to NumPy and PyTorch; the CUDA generators are
# seeded only when a CUDA device is available
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Check for GPU availability
# `device` is "cuda" when PyTorch reports an available CUDA device, otherwise "cpu"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

### Data Loading and Preparation

We'll download the MasakhaNER dataset and load it into our environment. The dataset is stored in CoNLL format: one token per line followed by its entity tag in BIO format, with blank lines separating sentences.

In [None]:
def download_masakhaner_dataset():
    """
    Download the MasakhaNER dataset from GitHub if not already present.

    The repository archive is downloaded into memory, only the members under
    ``masakhane-ner-main/data/`` are extracted into the current working
    directory, and that folder is then renamed to ``masakhane-ner-data``.

    Returns:
        Path to the dataset directory

    Raises:
        requests.HTTPError: If the server answers with an error status
        FileNotFoundError: If the archive contains no ``data/`` directory
    """
    # Define the URL and directories
    url = "https://github.com/masakhane-io/masakhane-ner/archive/refs/heads/main.zip"
    data_dir = "masakhane-ner-data"
    archive_root = 'masakhane-ner-main'
    data_prefix = f'{archive_root}/data/'

    # Check if directory already exists
    if os.path.exists(data_dir):
        print(f"Dataset directory already exists at {data_dir}")
        return data_dir

    # Download the archive; fail early on HTTP errors instead of handing an
    # error page to ZipFile, and do not hang forever on a stalled connection
    print("Downloading MasakhaNER dataset...")
    response = requests.get(url, timeout=60)
    response.raise_for_status()

    with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
        # Extract only the data directory to save space
        members = [name for name in zip_ref.namelist() if name.startswith(data_prefix)]
        if not members:
            # Without this check the rename below would fail with a less
            # helpful "no such file" error
            raise FileNotFoundError(
                f"No '{data_prefix}' entries found in the archive downloaded from {url}"
            )
        for member in members:
            zip_ref.extract(member, '.')

    # Rename the directory for easier access
    os.rename(os.path.join(archive_root, 'data'), data_dir)

    # Clean up the now-empty archive root (os.rmdir only removes empty directories)
    os.rmdir(archive_root)

    print(f"Dataset downloaded and extracted to {data_dir}")
    return data_dir

def _read_conll_sentences(file_path):
    """Parse one CoNLL-style file into a list of sentences of (word, tag) pairs."""
    sentences = []
    current_sentence = []

    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()

            if not parts:
                # Blank line marks the end of a sentence
                if current_sentence:
                    sentences.append(current_sentence)
                    current_sentence = []
            elif len(parts) >= 2:
                # First column is the word, last column is the NER tag
                current_sentence.append((parts[0], parts[-1]))
            # Lines with a single column are ignored

    # Don't forget the last sentence if file doesn't end with an empty line
    if current_sentence:
        sentences.append(current_sentence)

    return sentences


def load_masakhaner_data(data_dir, languages=None, sample_size=None):
    """
    Load the MasakhaNER dataset for specified languages.

    For each language, the files ``<data_dir>/<lang>/{train,dev,test}.txt`` are
    read. A missing file produces a warning on stdout and is skipped.

    Args:
        data_dir: Path to the directory containing MasakhaNER data
        languages: List of languages to load (if None, load all available)
        sample_size: If specified, keep only the first ``sample_size``
            sentences of each split of each language

    Returns:
        Dictionary mapping language codes to DataFrames with one row per
        sentence and the columns ``tokens``, ``tags`` and ``split``. The
        columns are present even when no sentence could be loaded.
    """
    if languages is None:
        # All 10 languages in MasakhaNER
        languages = ['amh', 'hau', 'ibo', 'kin', 'lug', 'luo', 'pcm', 'swa', 'wol', 'yor']

    columns = ['tokens', 'tags', 'split']
    datasets = {}

    for lang in languages:
        lang_data = []

        # Load train, dev, and test sets
        for split in ['train', 'dev', 'test']:
            file_path = os.path.join(data_dir, lang, f"{split}.txt")

            if not os.path.exists(file_path):
                print(f"Warning: File not found: {file_path}")
                continue

            sentences = _read_conll_sentences(file_path)

            # Sample if needed
            if sample_size and len(sentences) > sample_size:
                sentences = sentences[:sample_size]

            # Convert to row dictionaries
            for sentence in sentences:
                lang_data.append({
                    'tokens': [word for word, _ in sentence],
                    'tags': [tag for _, tag in sentence],
                    'split': split
                })

        # Fix the column set so an empty language still exposes
        # 'tokens'/'tags'/'split' to downstream code
        datasets[lang] = pd.DataFrame(lang_data, columns=columns)
        print(f"Loaded {len(datasets[lang])} sentences for {lang}")

    return datasets

# Download and load the dataset
# (the download is skipped when the "masakhane-ner-data" directory already exists)
data_dir = download_masakhaner_dataset()

# For quick testing, let's use a smaller sample. Remove the sample_size parameter for the full dataset.
# For final analysis, use all data with: datasets = load_masakhaner_data(data_dir)
# NOTE: sample_size=100 keeps at most the first 100 sentences of each split
# (train/dev/test) per language, so every statistic below is computed on that subset.
datasets = load_masakhaner_data(data_dir, sample_size=100)

# Display a sample from Swahili dataset
# Shows the first loaded Swahili sentence: its tokens, their tags, and its split
print("\nSample from Swahili (swa) dataset:")
sample_row = datasets['swa'].iloc[0]
print(f"Tokens: {sample_row['tokens']}")
print(f"Tags: {sample_row['tags']}")
print(f"Split: {sample_row['split']}")

## Exploratory Data Analysis

Now, let's explore the MasakhaNER dataset to understand its characteristics across different languages.

In [None]:
def explore_dataset(datasets):
    """
    Perform exploratory analysis on the MasakhaNER dataset.

    Prints a per-language statistics table and shows three bar charts:
    entity type counts per language, entity density per language, and
    average sentence length per language.

    Note that "entities" counts every token whose tag is not 'O' (both B-
    and I- tokens), not the number of entity spans.

    Args:
        datasets: Dictionary mapping language codes to DataFrames with
            'tokens' and 'tags' columns (one row per sentence)

    Returns:
        Tuple ``(stats_df, entity_df)``: ``stats_df`` has one row per language
        (with 'entity_density' and 'avg_sentence_len' formatted as strings),
        ``entity_df`` has one row per entity type and one column per language.
    """
    # Language name mapping for better readability
    lang_names = {
        'amh': 'Amharic', 'hau': 'Hausa', 'ibo': 'Igbo',
        'kin': 'Kinyarwanda', 'lug': 'Luganda', 'luo': 'Luo',
        'pcm': 'Nigerian-Pidgin', 'swa': 'Swahili', 'wol': 'Wolof',
        'yor': 'Yorùbá'
    }

    # Overall statistics
    print("Dataset Statistics:")
    stats = {}
    entity_counts = {}

    for lang, df in datasets.items():
        display_name = lang_names.get(lang, lang)

        # Count sentences and tokens
        num_sentences = len(df)
        num_tokens = sum(df['tokens'].apply(len))

        # Entity type of every non-O token (e.g., PER from B-PER)
        type_labels = [
            tag.split('-')[1] if '-' in tag else tag
            for tags in df['tags']
            for tag in tags
            if tag != 'O'
        ]
        num_entities = len(type_labels)
        entity_types = set(type_labels)

        # Store statistics
        stats[display_name] = {
            'sentences': num_sentences,
            'tokens': num_tokens,
            'entities': num_entities,
            'entity_density': num_entities / num_tokens if num_tokens > 0 else 0,
            'entity_types': ', '.join(sorted(entity_types)) if entity_types else 'None',
            'avg_sentence_len': num_tokens / num_sentences if num_sentences > 0 else 0
        }
        entity_counts[display_name] = Counter(type_labels)

    stats_df = pd.DataFrame(stats).T

    # Keep numeric copies for plotting BEFORE the columns are turned into
    # display strings, so nothing has to be parsed back out of formatted text
    entity_density = stats_df['entity_density'].astype(float)
    avg_sent_len = stats_df['avg_sentence_len'].astype(float)

    # Format the columns for better readability
    stats_df['entity_density'] = stats_df['entity_density'].map('{:.2%}'.format)
    stats_df['avg_sentence_len'] = stats_df['avg_sentence_len'].map('{:.1f}'.format)

    print(stats_df)

    # Convert to DataFrame for easier plotting
    entity_df = pd.DataFrame(entity_counts).fillna(0)

    # Plot entity type distribution. DataFrame.plot opens its own figure unless
    # it is given an Axes, so the Axes is created first and passed in; this
    # makes figsize take effect and avoids a stray empty figure.
    if entity_df.empty:
        print("No entity tags found; skipping entity type distribution plot.")
    else:
        _, ax = plt.subplots(figsize=(14, 8))
        entity_df.plot(kind='bar', stacked=True, ax=ax)
        plt.title('Entity Type Distribution Across Languages')
        plt.ylabel('Count')
        plt.xlabel('Entity Type')
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.show()

    # Plot entity density by language
    _, ax = plt.subplots(figsize=(12, 6))
    entity_density.sort_values().plot(kind='bar', ax=ax)
    plt.title('Entity Density by Language')
    plt.ylabel('Entity Density (entities/token)')
    plt.xlabel('Language')
    plt.tight_layout()
    plt.show()

    # Plot average sentence length by language
    _, ax = plt.subplots(figsize=(12, 6))
    avg_sent_len.sort_values().plot(kind='bar', ax=ax)
    plt.title('Average Sentence Length by Language')
    plt.ylabel('Average Tokens per Sentence')
    plt.xlabel('Language')
    plt.tight_layout()
    plt.show()

    return stats_df, entity_df

# Run exploratory data analysis
# Prints the statistics table, shows the plots, and binds the two returned
# DataFrames (per-language statistics, entity-type counts) at notebook level
stats_df, entity_df = explore_dataset(datasets)

# Additional analysis: Look at tag distribution for a specific language
def analyze_entity_distribution(df, lang_name):
    """
    Analyze the tag distribution for a specific language.

    Counts every tag (including 'O') across all sentences, prints the
    resulting table, and plots the ten most frequent tags as a bar chart.

    Args:
        df: DataFrame with a 'tags' column holding one list of tags per sentence
        lang_name: Language name used in the printed header and plot title

    Returns:
        DataFrame with columns 'Tag', 'Count' and 'Percentage', sorted by
        'Count' in descending order
    """
    tag_counts = Counter(tag for tags in df['tags'] for tag in tags)

    # Create a DataFrame for better display
    tag_df = pd.DataFrame({
        'Tag': list(tag_counts.keys()),
        'Count': list(tag_counts.values())
    })
    tag_df = tag_df.sort_values('Count', ascending=False).reset_index(drop=True)

    # Calculate percentages
    total = tag_df['Count'].sum()
    tag_df['Percentage'] = tag_df['Count'] / total * 100

    print(f"\nEntity tag distribution for {lang_name}:")
    print(tag_df)

    # Nothing to draw when the language has no tagged tokens
    if tag_df.empty:
        return tag_df

    # Plot the distribution. DataFrame.plot opens its own figure unless it is
    # given an Axes, so the Axes is created first and passed in; this makes
    # figsize take effect and avoids a stray empty figure.
    _, ax = plt.subplots(figsize=(12, 6))
    tag_df.head(10).plot(kind='bar', x='Tag', y='Count', ax=ax)
    plt.title(f'Top 10 Entity Tags for {lang_name}')
    plt.ylabel('Count')
    plt.xlabel('Tag')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

    return tag_df

# Analyze entity distribution for Swahili
# Prints and plots the tag counts for the 'swa' dataset and keeps the returned table
swahili_tag_df = analyze_entity_distribution(datasets['swa'], 'Swahili')

# Analyze token distribution (how many unique tokens)
def analyze_token_distribution(datasets):
    """
    Analyze token distributions across languages.

    For each language, counts total tokens and distinct tokens, prints the
    resulting table, and plots the vocabulary ratio (unique tokens divided
    by total tokens) per language as a bar chart.

    Args:
        datasets: Dictionary mapping language codes to DataFrames with a
            'tokens' column holding one list of tokens per sentence

    Returns:
        DataFrame with one row per language and the columns 'total_tokens',
        'unique_tokens' and 'vocabulary_ratio' (formatted as a percent string)
    """
    # Language name mapping for better readability
    lang_names = {
        'amh': 'Amharic', 'hau': 'Hausa', 'ibo': 'Igbo',
        'kin': 'Kinyarwanda', 'lug': 'Luganda', 'luo': 'Luo',
        'pcm': 'Nigerian-Pidgin', 'swa': 'Swahili', 'wol': 'Wolof',
        'yor': 'Yorùbá'
    }

    token_stats = {}

    for lang, df in datasets.items():
        # Get all tokens
        all_tokens = [token for tokens in df['tokens'] for token in tokens]
        unique_tokens = set(all_tokens)

        token_stats[lang_names.get(lang, lang)] = {
            'total_tokens': len(all_tokens),
            'unique_tokens': len(unique_tokens),
            'vocabulary_ratio': len(unique_tokens) / len(all_tokens) if all_tokens else 0
        }

    token_stats_df = pd.DataFrame(token_stats).T

    # Keep the numeric ratio for plotting BEFORE the column is turned into a
    # display string, so the plot does not depend on parsing formatted text
    vocab_ratio = token_stats_df['vocabulary_ratio'].astype(float)
    token_stats_df['vocabulary_ratio'] = token_stats_df['vocabulary_ratio'].map('{:.2%}'.format)

    print("\nToken Distribution Statistics:")
    print(token_stats_df)

    # Plot vocabulary ratio by language
    _, ax = plt.subplots(figsize=(12, 6))
    vocab_ratio.sort_values().plot(kind='bar', ax=ax)
    plt.title('Vocabulary Ratio by Language (unique tokens/total tokens)')
    plt.ylabel('Vocabulary Ratio')
    plt.xlabel('Language')
    plt.tight_layout()
    plt.show()

    return token_stats_df

# Analyze token distribution
# Prints and plots per-language vocabulary statistics and keeps the returned table
token_stats_df = analyze_token_distribution(datasets)