In [None]:
import pandas as pd
from tqdm import tqdm

In [None]:
# All training datasets share one root directory; edit DATASETS_DIR to run elsewhere.
# simple_stage_1 is loaded as the synthetic training set, simple_stage_2 as the real one.
DATASETS_DIR = '/dlabdata1/tsoares/linkrec-llms/data_modelling/training_ranking/datasets'
df_train_synth = pd.read_parquet(f'{DATASETS_DIR}/simple_stage_1/train')
df_train_real = pd.read_parquet(f'{DATASETS_DIR}/simple_stage_2/train')

In [None]:
# Every column whose name contains 'index' is cast to int; these values are used
# as string-slice bounds on link_context further down.
df_train_synth = df_train_synth.astype(
    {name: int for name in df_train_synth.columns if 'index' in name}
)
df_train_synth

In [None]:
# Keep only real rows with a usable link_context: its len() is taken later, which
# raises on nulls, and empty strings would only add zero-length samples.
n_real_before = len(df_train_real)
has_context = df_train_real['link_context'].notna() & (df_train_real['link_context'] != '')
df_train_real = df_train_real[has_context]
print(f'kept {len(df_train_real)} / {n_real_before} real rows with a non-empty link_context')
df_train_real.head()

In [None]:
# Convert both frames to lists of row dicts for the per-row loops below.
# The names are reused (DataFrame -> list), so guard each conversion: without it,
# re-running this cell fails because a list has no .to_dict().
if isinstance(df_train_real, pd.DataFrame):
    df_train_real = df_train_real.to_dict('records')
if isinstance(df_train_synth, pd.DataFrame):
    df_train_synth = df_train_synth.to_dict('records')

In [None]:
lengths = dict()  # dataset ('real' / 'synth') -> category -> list of link_context lengths

In [None]:
# Bucket each real row's link_context length by its missing_category.
# Rows whose category is not listed here are ignored.
category_to_bucket = {
    'present': 'text_present',
    'missing_mention': 'missing_mention',
    'missing_sentence': 'missing_sentence',
    'missing_span': 'missing_span',
}
lengths['real'] = {bucket: [] for bucket in category_to_bucket.values()}
for record in tqdm(df_train_real):
    bucket = category_to_bucket.get(record['missing_category'])
    if bucket is not None:
        lengths['real'][bucket].append(len(record['link_context']))

In [None]:
# For each synthetic row, record the full link_context length, plus the length of
# the text left after cutting out the mention / sentence / span slice
# [start_index:end_index]. Remainders that are only whitespace are not recorded.
removal_offsets = (
    ('missing_mention', 'context_mention_start_index', 'context_mention_end_index'),
    ('missing_sentence', 'context_sentence_start_index', 'context_sentence_end_index'),
    ('missing_span', 'context_span_start_index', 'context_span_end_index'),
)
synth_buckets = {'text_present': [], 'missing_mention': [], 'missing_sentence': [], 'missing_span': []}
lengths['synth'] = synth_buckets
for record in tqdm(df_train_synth):
    context = record['link_context']
    synth_buckets['text_present'].append(len(context))
    for bucket, start_key, end_key in removal_offsets:
        remainder = context[:record[start_key]] + context[record[end_key]:]
        if remainder.strip() != '':
            synth_buckets[bucket].append(len(remainder))

In [None]:
import matplotlib.pyplot as plt
import numpy as np

fig, axs = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Length of Link Contexts')

# One panel per category: (key in lengths[...], panel title), filled row-major.
panels = [
    ('text_present', 'Text Present'),
    ('missing_mention', 'Missing Mention'),
    ('missing_sentence', 'Missing Sentence'),
    ('missing_span', 'Missing Span'),
]
for ax, (category, title) in zip(axs.flat, panels):
    series = [
        (lengths['real'][category], 'Real', 'C0'),
        (lengths['synth'][category], 'Synthetic', 'C1'),
    ]
    # Shared bins so the two density histograms are directly comparable.
    # default=0 avoids ValueError from max() when a category has no samples.
    max_val = max(max(values, default=0) for values, _, _ in series)
    # Bin edges must be strictly increasing, so keep the upper edge >= 1.
    bins = np.linspace(0, max(max_val, 1), 50)
    for values, label, color in series:
        if values:  # density-normalising an empty sample would divide by zero
            ax.hist(values, bins=bins, density=True, alpha=0.6, label=label, color=color)
    ax.set_title(title)
    ax.set_xlabel('Link context length (characters)')
    ax.set_ylabel('Density')
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc='upper right')

fig.tight_layout()
plt.show()

In [None]:
# Summary statistics of real 'missing_span' lengths; a float Series still describes
# cleanly when the list is empty (pd.DataFrame([]).describe() raises instead).
pd.Series(lengths['real']['missing_span'], dtype='float64', name='real_missing_span_length').describe()

In [None]:
# Summary statistics of synthetic 'missing_span' lengths; a float Series still describes
# cleanly when the list is empty (pd.DataFrame([]).describe() raises instead).
pd.Series(lengths['synth']['missing_span'], dtype='float64', name='synth_missing_span_length').describe()