In [65]:
import pandas as pd
import numpy as np
import sys
import warnings
import re
from collections import defaultdict
import json
from tqdm import tqdm

Loading Dataset

In [66]:
# Hub location of the dataset, used as the prefix for every split file path.
base_url = "hf://datasets/cyanic-selkie/aida-conll-yago-wikidata/"

# Parquet file name for each split, keyed by split name.
splits = {
    'train': 'train.parquet',
    'validation': 'validation.parquet',
    'test': 'test.parquet',
}

In [67]:
# Read each split into its own DataFrame; the path is the base URL followed by
# the split's parquet file name.
df_train = pd.read_parquet(f"{base_url}{splits['train']}")
df_val = pd.read_parquet(f"{base_url}{splits['validation']}")
df_test = pd.read_parquet(f"{base_url}{splits['test']}")

In [68]:
# Report the number of rows in each split.
print(f"Train: {len(df_train)} examples")
print(f"Validation: {len(df_val)} examples")
print(f"Test: {len(df_test)} examples")

Train: 946 examples
Validation: 216 examples
Test: 231 examples


Showing Original Example

In [69]:
# Inspect the training frame: its shape, its column names, and the size of the
# first row's text and entity collection.
print(f"Train shape: {df_train.shape}")
print(f"Columns: {df_train.columns.tolist()}")
print(f"\nFirst row text length: {len(df_train.iloc[0]['text'])}")
print(f"First row entities: {len(df_train.iloc[0]['entities'])}")

# Show one raw entity record, but only if the first row has any entities.
if len(df_train.iloc[0]['entities']) > 0:
    print(f"First entity: {df_train.iloc[0]['entities'][0]}")

Train shape: (946, 3)
Columns: ['document_id', 'text', 'entities']

First row text length: 2790
First row entities: 48
First entity: {'start': 0, 'end': 2, 'tag': 'ORG', 'pageid': None, 'qid': None, 'title': None}


### Preprocessing

Apply preprocessing to a sample

In [70]:
# Work on an independent copy of the first 100 training rows so the
# preprocessing steps below never touch df_train.
df_sample = df_train.iloc[:100].copy()

Step 1: Filter valid Entities (if needed)

In [71]:
def filter_valid_entities(row):
    """Return the entities of ``row`` that carry a knowledge-base link.

    An entity is kept when its ``qid`` or its ``pageid`` is not ``None``.
    A row without an ``'entities'`` key, or whose value is ``None``, yields
    an empty list.

    Args:
        row: Mapping-like record (e.g. a DataFrame row) whose ``'entities'``
            value is an iterable of dict-like entities.

    Returns:
        list: The linked entities, in their original order.
    """
    candidates = row['entities'] if 'entities' in row else None
    if candidates is None:
        return []

    return [
        ent for ent in candidates
        if not (ent.get('qid') is None and ent.get('pageid') is None)
    ]

In [72]:
def apply_filter_valid_entities(df, inplace=False):
    """Drop entities without a knowledge-base link from every row of ``df``.

    The ``'entities'`` column is replaced by the output of
    ``filter_valid_entities`` and a ``'num_valid_entities'`` column holding
    the length of each filtered list is added.

    Args:
        df: DataFrame with an ``'entities'`` column.
        inplace: When truthy, modify ``df`` itself; otherwise work on a copy.

    Returns:
        The modified DataFrame (``df`` itself when ``inplace`` is truthy).
    """
    frame = df if inplace else df.copy()

    print("Filtering entities without KB links...")
    frame['entities'] = frame.apply(filter_valid_entities, axis=1)
    frame['num_valid_entities'] = frame['entities'].apply(len)

    return frame

In [73]:
#df_sample = apply_filter_valid_entities(df_sample, inplace=True)

Step 2: Add context to entities

In [74]:
def extract_mention_with_context(text, start, end, context_window=50):
    """Slice a mention and its surrounding context out of ``text``.

    Args:
        text: The full document text.
        start: Offset of the first character of the mention.
        end: Offset one past the last character of the mention.
        context_window: Maximum number of characters to take on each side of
            the mention; the window is clipped to the bounds of ``text``.

    Returns:
        dict: ``mention`` (``text[start:end]``), ``left_context``,
        ``right_context``, ``full_context`` (left + mention + right), and the
        unchanged offsets as ``mention_start`` / ``mention_end``.
    """
    # Clip the context window so it never reaches outside the text.
    window_lo = start - context_window
    if window_lo < 0:
        window_lo = 0
    window_hi = end + context_window
    if window_hi > len(text):
        window_hi = len(text)

    return {
        'mention': text[start:end],
        'left_context': text[window_lo:start],
        'right_context': text[end:window_hi],
        'full_context': text[window_lo:window_hi],
        'mention_start': start,
        'mention_end': end,
    }

In [75]:
def add_context_to_entities(row, context_window=50):
    """Return copies of the row's entities enriched with mention context.

    Each entity is copied and updated with the fields produced by
    ``extract_mention_with_context`` for its ``start`` / ``end`` offsets in
    the row's ``'text'``. The original entities are not modified.

    Args:
        row: Record with a ``'text'`` value and an ``'entities'`` collection
            of dict-like entities that have ``'start'`` and ``'end'`` keys.
        context_window: Context size forwarded to
            ``extract_mention_with_context``.

    Returns:
        list: The enriched entity copies; empty when the row has no entities
        (missing, ``None`` or zero-length).
    """
    text = row['text']
    entities = row.get('entities', [])

    def _with_context(entity):
        # Copy first so the source entity is left untouched.
        enriched = entity.copy()
        enriched.update(
            extract_mention_with_context(
                text, entity['start'], entity['end'], context_window
            )
        )
        return enriched

    if entities is not None and len(entities) > 0:
        return [_with_context(entity) for entity in entities]
    return []

In [76]:
def apply_add_context(df, context_window=50, inplace=False):
    """Add mention/context fields to the entities of every row of ``df``.

    The ``'entities'`` column is replaced by the output of
    ``add_context_to_entities`` for each row.

    Args:
        df: DataFrame with ``'text'`` and ``'entities'`` columns.
        context_window: Context size forwarded to ``add_context_to_entities``.
        inplace: When truthy, modify ``df`` itself; otherwise work on a copy.

    Returns:
        The modified DataFrame (``df`` itself when ``inplace`` is truthy).
    """
    # Truthiness test, matching the other apply_* helpers: the previous
    # `inplace is False` check skipped the copy for falsy non-bool values
    # such as None or 0 and therefore mutated the caller's frame.
    if not inplace:
        df = df.copy()

    print(f"Adding context (window={context_window}) to entities...")
    df['entities'] = df.apply(
        lambda row: add_context_to_entities(row, context_window),
        axis=1
    )

    return df

In [77]:
# Attach the mention text and up to 200 characters of context on each side to
# every entity. inplace=False makes the helper work on a copy, so df_sample is
# rebound to the returned frame.
df_sample = apply_add_context(df_sample, context_window=200, inplace=False)

Adding context (window=200) to entities...


Step 3: Normalize mentions

In [78]:
def normalize_mention(mention):
    """Return ``mention`` with whitespace collapsed and edge punctuation removed.

    Runs of whitespace become a single space, the result is stripped, and any
    leading or trailing run of non-word characters is then removed.

    Args:
        mention: The mention string to normalize.

    Returns:
        str: The normalized mention (may be empty).
    """
    # Collapse whitespace runs, then trim the ends.
    collapsed = re.sub(r'\s+', ' ', mention)
    trimmed = collapsed.strip()

    # Strip non-word characters from both ends only; interior ones are kept.
    return re.sub(r'^[^\w]+|[^\w]+$', '', trimmed)

In [79]:
def add_normalized_mentions(row):
    """Return copies of the row's entities with a ``normalized_mention`` field.

    Every entity is copied; entities that have a ``'mention'`` key also get
    ``'normalized_mention'`` set to ``normalize_mention(entity['mention'])``.
    Entities without a ``'mention'`` key are copied unchanged.

    Args:
        row: Record whose ``'entities'`` value is a collection (list, tuple or
            NumPy array) of dict-like entities, or ``None``.

    Returns:
        list: The entity copies; empty when the row has no entities (missing,
        ``None`` or zero-length).
    """
    entities = row.get('entities', [])

    # Explicit None/length test instead of `not entities`: the truth value of
    # a multi-element NumPy array is ambiguous and raises ValueError, and
    # entity collections may be arrays (e.g. list columns loaded from parquet).
    if entities is None or len(entities) == 0:
        return []

    entities_normalized = []
    for entity in entities:
        entity_copy = entity.copy()
        if 'mention' in entity:
            entity_copy['normalized_mention'] = normalize_mention(entity['mention'])
        entities_normalized.append(entity_copy)

    return entities_normalized

In [80]:
def apply_normalize_mentions(df, inplace=False):
    """Add normalized mentions to the entities of every row of ``df``.

    The ``'entities'`` column is replaced by the output of
    ``add_normalized_mentions`` for each row.

    Args:
        df: DataFrame with an ``'entities'`` column.
        inplace: When truthy, modify ``df`` itself; otherwise work on a copy.

    Returns:
        The modified DataFrame (``df`` itself when ``inplace`` is truthy).
    """
    frame = df if inplace else df.copy()

    print("Normalizing entity mentions...")
    frame['entities'] = frame.apply(add_normalized_mentions, axis=1)

    return frame

In [81]:
# Add a 'normalized_mention' field to every entity that has a 'mention'.
# inplace=True: the helper rewrites df_sample's 'entities' column directly and
# returns the same frame.
df_sample = apply_normalize_mentions(df_sample, inplace=True)

Normalizing entity mentions...


Step 4: Remove overlapping entities

In [82]:
def remove_overlapping_entities(row):
    """Greedily drop entities whose span overlaps an earlier kept entity.

    Entities are ordered by ``start`` and, for equal starts, longest span
    first. Walking that order, an entity is kept only if it starts at or after
    the end of the previously kept entity, so touching spans (``start`` equal
    to the previous ``end``) are both kept.

    Args:
        row: Record whose ``'entities'`` value is a collection (list, tuple or
            NumPy array) of dict-like entities with ``'start'`` and ``'end'``
            keys, or ``None``.

    Returns:
        tuple: ``(non_overlapping, num_removed)`` where ``non_overlapping`` is
        the list of kept entities (the same objects, in sorted order) and
        ``num_removed`` is the number of entities dropped.
    """
    entities = row.get('entities', [])

    # Explicit None/length test instead of `not entities`: the truth value of
    # a multi-element NumPy array is ambiguous and raises ValueError, and
    # entity collections may be arrays (e.g. list columns loaded from parquet).
    if entities is None or len(entities) == 0:
        return [], 0

    # Sort by start position, then by length (descending) so that the longest
    # of several entities sharing a start is the one that survives.
    sorted_entities = sorted(
        entities,
        key=lambda e: (e['start'], -(e['end'] - e['start']))
    )

    non_overlapping = []
    last_end = -1  # end offset of the most recently kept entity

    for entity in sorted_entities:
        if entity['start'] >= last_end:
            non_overlapping.append(entity)
            last_end = entity['end']

    num_removed = len(entities) - len(non_overlapping)

    return non_overlapping, num_removed

In [83]:
def apply_remove_overlaps(df, inplace=False):
    """Remove overlapping entities from every row of ``df``.

    The ``'entities'`` column is replaced by the kept entities returned by
    ``remove_overlapping_entities`` and a ``'removed_overlaps'`` column with
    the per-row count of dropped entities is added. The total is printed.

    Args:
        df: DataFrame with an ``'entities'`` column.
        inplace: When truthy, modify ``df`` itself; otherwise work on a copy.

    Returns:
        The modified DataFrame (``df`` itself when ``inplace`` is truthy).
    """
    frame = df if inplace else df.copy()

    print("Removing overlapping entities...")
    # Each row yields a (kept_entities, num_removed) pair; unpack it into
    # two separate columns.
    per_row = frame.apply(remove_overlapping_entities, axis=1)
    frame['entities'] = per_row.apply(lambda pair: pair[0])
    frame['removed_overlaps'] = per_row.apply(lambda pair: pair[1])

    total_removed = frame['removed_overlaps'].sum()
    print(f"  Removed {total_removed} overlapping entities")

    return frame

In [84]:
# Drop entities that overlap an earlier kept entity and record the per-row
# count in 'removed_overlaps'. inplace=True: df_sample itself is modified.
df_sample = apply_remove_overlaps(df_sample, inplace=True)

Removing overlapping entities...
  Removed 0 overlapping entities


Step 5: Create candidate pairs

In [85]:
def create_mention_candidate_pairs(row, max_candidates=10):
    """Build one mention/gold-label record per entity of ``row``.

    Candidate generation is not implemented: every record gets its own empty
    ``'candidates'`` list and ``max_candidates`` is currently unused.

    Args:
        row: Record whose ``'entities'`` value is an iterable of dict-like
            entities, or ``None``.
        max_candidates: Reserved for candidate generation; has no effect.

    Returns:
        list: One dict per entity with keys ``mention``, ``context``,
        ``entity_type``, ``true_qid``, ``true_pageid``, ``true_title`` and
        ``candidates``. Missing text fields default to ``''`` and missing
        gold labels to ``None``. Empty when the row has no entities.
    """
    entities = row.get('entities', [])
    # `.get` only falls back to [] when the key is absent; a present-but-null
    # value comes back as None and would make the loop raise TypeError.
    if entities is None:
        entities = []

    pairs = []

    for entity in entities:
        pair = {
            'mention': entity.get('mention', ''),
            'context': entity.get('full_context', ''),
            'entity_type': entity.get('tag', ''),
            'true_qid': entity.get('qid'),
            'true_pageid': entity.get('pageid'),
            'true_title': entity.get('title'),
            # we'd generate candidates here
            'candidates': []  # for candidate entities
        }
        pairs.append(pair)

    return pairs

In [86]:
def apply_create_candidate_pairs(df, max_candidates=10, inplace=False):
    """Add a ``'mention_candidate_pairs'`` column to ``df``.

    Each cell holds the list returned by ``create_mention_candidate_pairs``
    for that row.

    Args:
        df: DataFrame with an ``'entities'`` column.
        max_candidates: Forwarded to ``create_mention_candidate_pairs``.
        inplace: When truthy, modify ``df`` itself; otherwise work on a copy.

    Returns:
        The modified DataFrame (``df`` itself when ``inplace`` is truthy).
    """
    frame = df if inplace else df.copy()

    def _pairs_for(row):
        # Bind max_candidates so the per-row function takes a single argument.
        return create_mention_candidate_pairs(row, max_candidates)

    print("Creating mention-candidate pairs...")
    frame['mention_candidate_pairs'] = frame.apply(_pairs_for, axis=1)

    return frame

In [87]:
# Add the 'mention_candidate_pairs' column (one record per entity, with an
# empty 'candidates' list). inplace=True: df_sample itself is modified.
df_sample = apply_create_candidate_pairs(df_sample, inplace=True)

Creating mention-candidate pairs...


Step 6: Create NIL detection examples

In [88]:
def create_nil_detection_examples(row):
    """Split the row's entities into NIL and linked detection examples.

    An entity is NIL when both its ``qid`` and its ``pageid`` are ``None``.

    Args:
        row: Record whose ``'entities'`` value is an iterable of dict-like
            entities, or ``None``.

    Returns:
        tuple: ``(nil_examples, linked_examples)``, two lists of dicts with
        keys ``mention``, ``context``, ``entity_type`` and ``is_nil``. Both
        are empty when the row has no entities.
    """
    entities = row.get('entities', [])
    # `.get` only falls back to [] when the key is absent; a present-but-null
    # value comes back as None and would make the loop raise TypeError.
    if entities is None:
        entities = []

    nil_examples = []
    linked_examples = []

    for entity in entities:
        entity_example = {
            'mention': entity.get('mention', ''),
            'context': entity.get('full_context', ''),
            'entity_type': entity.get('tag', ''),
            'is_nil': entity.get('qid') is None and entity.get('pageid') is None
        }

        if entity_example['is_nil']:
            nil_examples.append(entity_example)
        else:
            linked_examples.append(entity_example)

    return nil_examples, linked_examples

In [89]:
def apply_create_nil_examples(df, inplace=False):
    """Create NIL detection examples for an entire DataFrame.

    Adds ``'nil_examples'`` and ``'linked_examples'`` columns holding the two
    lists returned by ``create_nil_detection_examples`` for each row, and
    prints the total number of examples of each kind.

    Args:
        df: DataFrame with an ``'entities'`` column.
        inplace: When truthy, modify ``df`` itself; otherwise work on a copy.

    Returns:
        The modified DataFrame (``df`` itself when ``inplace`` is truthy).
    """
    frame = df if inplace else df.copy()

    print("Creating NIL detection examples...")
    # Each row yields a (nil_examples, linked_examples) pair; unpack it into
    # two separate columns.
    per_row = frame.apply(create_nil_detection_examples, axis=1)
    frame['nil_examples'] = per_row.apply(lambda pair: pair[0])
    frame['linked_examples'] = per_row.apply(lambda pair: pair[1])

    total_nil = frame['nil_examples'].apply(len).sum()
    total_linked = frame['linked_examples'].apply(len).sum()
    print(f"  NIL entities: {total_nil}, Linked entities: {total_linked}")

    return frame

In [90]:
# Add the 'nil_examples' and 'linked_examples' columns and print their totals.
# inplace=True: df_sample itself is modified.
df_sample = apply_create_nil_examples(df_sample, inplace=True)

Creating NIL detection examples...
  NIL entities: 504, Linked entities: 2139


Step 7: Split long documents (if needed)

In [91]:
def split_long_documents(row, max_length=512, overlap=50):
    """Split a document into overlapping character windows.

    A document whose text is at most ``max_length`` characters long is
    returned unchanged as a single record (``row.to_dict()``), without any
    chunk metadata. Longer documents are cut into windows of up to
    ``max_length`` characters, each starting ``overlap`` characters before
    the end of the previous one. An entity is assigned to every window that
    contains it completely, so entities inside the overlap appear in two
    chunks and entities crossing a window edge without fitting in either
    window are dropped.

    Args:
        row: DataFrame row (``pandas.Series``) with a ``'text'`` value and an
            ``'entities'`` collection of dict-like entities that have
            ``'start'`` and ``'end'`` character offsets (or ``None``).
        max_length: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared by consecutive chunks; must
            satisfy ``0 <= overlap < max_length``.

    Returns:
        list: Either ``[row.to_dict()]`` or a list of chunk dicts with keys
        ``text``, ``entities``, ``chunk_start``, ``chunk_end``, ``chunk_id``
        and ``is_chunk``. Chunk entities are copies whose ``start`` / ``end``
        are relative to the chunk text.

    Raises:
        ValueError: If ``max_length`` is not positive or ``overlap`` is
            outside ``[0, max_length)``.
    """
    # Fail fast on window parameters that cannot make progress: with
    # overlap >= max_length the window never advances (previously an endless
    # loop on long documents), and a negative overlap silently skips text.
    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")
    if not 0 <= overlap < max_length:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < max_length, "
            f"got overlap={overlap}, max_length={max_length}"
        )

    text = row['text']

    if len(text) <= max_length:
        return [row.to_dict()]  # No splitting needed

    entities = row.get('entities', [])
    # `.get` returns None (not the default) for a present-but-null value.
    if entities is None:
        entities = []

    chunks = []
    start_pos = 0
    chunk_id = 0

    while True:
        end_pos = min(start_pos + max_length, len(text))

        # Keep only entities that lie entirely inside this window.
        chunk_entities = []
        for entity in entities:
            if entity['start'] >= start_pos and entity['end'] <= end_pos:
                # Re-base the offsets on the chunk text. Only 'start' and
                # 'end' are adjusted; every other key is copied unchanged.
                adjusted_entity = entity.copy()
                adjusted_entity['start'] = entity['start'] - start_pos
                adjusted_entity['end'] = entity['end'] - start_pos
                chunk_entities.append(adjusted_entity)

        chunks.append({
            'text': text[start_pos:end_pos],
            'entities': chunk_entities,
            'chunk_start': start_pos,
            'chunk_end': end_pos,
            'chunk_id': chunk_id,
            'is_chunk': True
        })

        # The window that reaches the end of the text is the last one.
        if end_pos >= len(text):
            break

        chunk_id += 1
        # Move to next chunk with overlap
        start_pos = end_pos - overlap

    return chunks

In [92]:
def apply_split_long_documents(df, max_length=512, overlap=50):
    """Split every document of ``df`` and collect the pieces in a new frame.

    Each row is passed to ``split_long_documents``; every record it returns
    becomes one row of the result, tagged with the index label of the source
    row in ``'original_index'``. The before/after row counts are printed.

    Args:
        df: DataFrame with ``'text'`` and ``'entities'`` columns.
        max_length: Forwarded to ``split_long_documents``.
        overlap: Forwarded to ``split_long_documents``.

    Returns:
        pandas.DataFrame: A new frame with one row per returned record.
    """
    # Preserve original index info on every record.
    all_chunks = [
        {**chunk, 'original_index': idx}
        for idx, row in df.iterrows()
        for chunk in split_long_documents(row, max_length, overlap)
    ]

    df_chunks = pd.DataFrame(all_chunks)

    print(f"  Original docs: {len(df)}, After splitting: {len(df_chunks)}")

    return df_chunks

In [93]:
#df_sample = apply_split_long_documents(df_sample, max_length=512)

In [94]:
# Summarize the processed sample: its final shape and its full column list.
print(f"Processed shape: {df_sample.shape}")
print(f"New columns: {df_sample.columns.tolist()}")

Processed shape: (100, 7)
New columns: ['document_id', 'text', 'entities', 'removed_overlaps', 'mention_candidate_pairs', 'nil_examples', 'linked_examples']


Export the Sample

In [95]:
# Persist the processed sample as a parquet file in the current working
# directory, without writing the DataFrame index.
df_sample.to_parquet("preprocessed_sample_aida.parquet", index=False)

In [96]:
#x = df_sample[['mention_candidate_pairs', 'nil_examples', 'linked_examples']].head(1)

In [97]:
# df_mcp = x['mention_candidate_pairs'].explode().apply(pd.Series)
# df_mcp