# Imports and Definitions

In [None]:
import pandas as pd
from glob import glob
from tqdm import tqdm
import os
import random
import urllib

In [None]:
# Directory of the processed simplewiki dump (NS0, 20231001) that every shard below is read from.
root = (
    "/scratch/tsoares/wikidumps/"
    "simplewiki-NS0-20231001/"
    "processed_data/"
)

In [None]:
def unencode_title(title):
    """Turn a URL-encoded, underscore-separated title into plain text.

    Percent-escapes are decoded first (``%C3%A9`` -> ``é``), and only then are
    underscores replaced with spaces, so an escaped ``%5F`` also becomes a space.

    Args:
        title: Encoded title string, e.g. ``"Caf%C3%A9_au_lait"``.

    Returns:
        The decoded title, e.g. ``"Café au lait"``.
    """
    # A bare `import urllib` does not guarantee the `urllib.parse` submodule is
    # loaded; import it explicitly instead of relying on another library having done so.
    from urllib.parse import unquote

    return unquote(title).replace('_', ' ')

# Load data

In [None]:
# Shard paths for links and pages, sorted so they are always loaded in the same order.
link_files = sorted(glob(os.path.join(root, "good_links*")))
page_files = sorted(glob(os.path.join(root, "good_pages*")))

In [None]:
# Show the discovered shard lists, one list per line.
print(link_files, page_files, sep='\n')

In [None]:
# Load every links shard and keep a random subset of at most 100k links.
# `DataFrame.sample(n)` raises ValueError when n exceeds the row count, so the
# sample size is capped at the number of available rows.
# NOTE(review): the sample is unseeded, so each run draws different links;
# pass random_state=... to `sample` if reproducible datasets are needed.
df_links = pd.concat([pd.read_parquet(file) for file in link_files])
df_links = df_links.sample(min(100_000, len(df_links))).reset_index(drop=True)
# Decode titles so they match the decoded page titles used as lookup keys later.
df_links['source_title'] = df_links['source_title'].apply(unencode_title)
df_links['target_title'] = df_links['target_title'].apply(unencode_title)
df_links

In [None]:
# Read only the title and lead paragraph from each page shard, then decode titles.
dfs = [pd.read_parquet(path, columns=['title', 'lead_paragraph']) for path in page_files]
df_pages = pd.concat(dfs)
df_pages['title'] = df_pages['title'].map(unencode_title)
df_pages

In [None]:
# From here on both tables are lists of per-row dicts rather than DataFrames.
df_links, df_pages = (
    df_links.to_dict(orient='records'),
    df_pages.to_dict(orient='records'),
)

In [None]:
# Load the mention table as row dicts and group it by decoded target title:
# entity_map[title] is the set of mention strings recorded for that title.
mention_map = pd.read_parquet(os.path.join(root, "mention_map.parquet")).to_dict(orient='records')
entity_map = {}
for record in mention_map:
    title, mention = unencode_title(record['target_title']), record['mention']
    entity_map.setdefault(title, set()).add(mention)
entity_map

# Create auxiliary data structures

In [None]:
# Adjacency lists over the sampled links: source -> targets it links to, and
# target -> sources linking to it. One entry is appended per link row, so a
# pair that occurs in several rows appears several times.
# (The unused `source_section` parsing was removed; it did nothing except
# risk an AttributeError on a missing section.)
source_to_all_targets = {}
target_to_all_sources = {}
for link in tqdm(df_links):
    source, target = link['source_title'], link['target_title']
    source_to_all_targets.setdefault(source, []).append(target)
    target_to_all_sources.setdefault(target, []).append(source)

In [None]:
# Map each decoded page title to its lead paragraph (a later duplicate title overwrites an earlier one).
page_leads = {page['title']: page['lead_paragraph'] for page in tqdm(df_pages)}

# Set up positive samples

In [None]:
# One positive (label 1) example per sampled link: both endpoint titles with
# their lead paragraphs, the link context, and the part of the source section
# before the first '<sep>'.
positive_samples = [
    {
        'source_title': link['source_title'],
        'source_lead': page_leads[link['source_title']],
        'target_title': link['target_title'],
        'target_lead': page_leads[link['target_title']],
        'link_context': link['context'],
        'source_section': link['source_section'].split('<sep>')[0],
        'label': 1,
    }
    for link in tqdm(df_links)
]

# Set up negative samples

## Define hyperparameters

In [None]:
# Enable/disable each negative-sampling strategy used by the generation cell below.
negative_strategies = dict(
    easy_replace_source=True,   # random source page that does not link to the target
    easy_replace_target=True,   # random target page the source does not link to
    hard_replace_source=True,   # another page that also links to the same target
    hard_replace_target=True,   # another target linked from the same source
    replace_context=True,       # context from a different link
)
# Number of negatives drawn for every positive sample.
negative_samples_per_positive = 50

## Build negative samples from positive ones

In [None]:
# Names of the enabled strategies, in their configured order.
strategies = [name for name, enabled in negative_strategies.items() if enabled]
strategies

In [None]:
# Build `negative_samples_per_positive` corrupted copies (label 0) of every
# positive sample. Each copy keeps the positive's fields except the one its
# strategy replaces, and records that strategy in 'neg_type'.


def _random_positive_field(field):
    """Return `field` from a uniformly chosen positive sample."""
    return random.choice(positive_samples)[field]


def _context_mentions(context, title):
    """Return True if any mention string recorded for `title` occurs in `context`.

    Raises KeyError if `title` has no entry in `entity_map`.
    """
    return any(mention in context for mention in entity_map[title])


negative_samples = []
for positive in tqdm(positive_samples):
    source = positive['source_title']
    target = positive['target_title']
    # Sets give O(1) "is this a real link?" checks in the rejection loops below.
    sources_linking_target = set(target_to_all_sources[target])
    targets_linked_from_source = set(source_to_all_targets[source])
    # Hard candidates must differ from the positive's own source/target:
    # drawing the original back would reproduce the positive labelled 0.
    other_sources = [s for s in target_to_all_sources[target] if s != source]
    other_targets = [t for t in source_to_all_targets[source] if t != target]

    # Filter rather than list.remove(): remove() raised ValueError whenever a
    # hard strategy had been disabled in `negative_strategies`.
    valid_strategies = [
        strategy
        for strategy in strategies
        if not (strategy == 'hard_replace_source' and not other_sources)
        and not (strategy == 'hard_replace_target' and not other_targets)
    ]
    if not valid_strategies:
        continue  # random.choices raises IndexError on an empty population

    # Linked targets none of whose mentions appear in this context; built on first use.
    safe_targets = None
    for strategy in random.choices(valid_strategies, k=negative_samples_per_positive):
        new_sample = positive.copy()
        # NOTE(review): the rejection loops below assume at least one candidate
        # passes the check; on a tiny dataset they could spin forever.
        if strategy == 'easy_replace_source':
            # Random source page that does not link to the target.
            new_source = _random_positive_field('source_title')
            while new_source in sources_linking_target:
                new_source = _random_positive_field('source_title')
            new_sample['source_title'] = new_source
            new_sample['source_lead'] = page_leads[new_source]
        elif strategy == 'easy_replace_target':
            # Random target page the source does not link to.
            new_target = _random_positive_field('target_title')
            while new_target in targets_linked_from_source:
                new_target = _random_positive_field('target_title')
            new_sample['target_title'] = new_target
            new_sample['target_lead'] = page_leads[new_target]
        elif strategy == 'hard_replace_source':
            # A different page that also links to the same target.
            new_source = random.choice(other_sources)
            new_sample['source_title'] = new_source
            new_sample['source_lead'] = page_leads[new_source]
        elif strategy == 'hard_replace_target':
            if safe_targets is None:
                safe_targets = [
                    t for t in other_targets
                    if not _context_mentions(positive['link_context'], t)
                ]
            if safe_targets:
                new_target = random.choice(safe_targets)
            else:
                # Fallback: easy-style replacement, still tagged 'hard_replace_target'.
                new_target = _random_positive_field('target_title')
                while new_target in targets_linked_from_source:
                    new_target = _random_positive_field('target_title')
            new_sample['target_title'] = new_target
            new_sample['target_lead'] = page_leads[new_target]
        elif strategy == 'replace_context':
            # Context from another link that contains none of the target's mentions.
            new_context = _random_positive_field('link_context')
            while _context_mentions(new_context, target):
                new_context = _random_positive_field('link_context')
            new_sample['link_context'] = new_context
        new_sample['neg_type'] = strategy
        new_sample['label'] = 0
        negative_samples.append(new_sample)

In [None]:
# Combine positives and negatives into one frame and shuffle its rows (unseeded).
df = (
    pd.DataFrame([*positive_samples, *negative_samples])
    .sample(frac=1)
    .reset_index(drop=True)
)
df

In [None]:
# Random 80/10/10 row split: 80% train, and the remaining 20% halved into val/test.
# NOTE(review): rows are split independently, so a positive and its negatives
# (copies sharing its context/leads) can land in different splits; group by the
# originating link if that leakage matters.
train_df = df.sample(frac=0.8)
held_out = df.drop(train_df.index)
val_df = held_out.sample(frac=0.5)
test_df = held_out.drop(val_df.index)
del held_out

In [None]:
# Debug: print the lead paragraph of every page titled '1934'.
for lead in (p['lead_paragraph'] for p in df_pages if p['title'] == '1934'):
    print(lead)

In [None]:
# Debug: print every sampled link whose source or target is '1934'.
for link in df_links:
    if '1934' in (link['source_title'], link['target_title']):
        print(link)

In [None]:
# Debug: print every raw mention_map row whose mention or target title is '1934'.
for row in mention_map:
    if '1934' in (row['mention'], row['target_title']):
        print(row)