In [70]:
kb_path: str = r"C:\Users\t-isazawat\Downloads\7178_KBLaM_Knowledge_Base_augm_Supplementary Material\submission files\knowledge bases\enron.json"

In [71]:
import sys
import os
sys.path.append("./src")
from kblam.utils.data_utils import load_entities

entities = load_entities(kb_path)[0] # TODO: Fix this as this is a hack to account for load_entities being wrongly implemented

In [73]:
import bm25s
import numpy as np

class Statistics:
    def __init__(self):
        self.top_1_correct_count = 0
        self.top_5_correct_count = 0
        self.total_count = 0

    def update(self, retrieved_entity_indices, correct_index):
        self.total_count += 1
        if correct_index in retrieved_entity_indices[:1]:
            self.top_1_correct_count += 1
        if correct_index in retrieved_entity_indices[:5]:
            self.top_5_correct_count += 1

    def __str__(self):
        return f"Top-1 accuracy: {self.top_1_correct_count / self.total_count}, Top-5 accuracy: {self.top_5_correct_count / self.total_count}"

def create_retriever(entities):
    corpus = [f"The {entity['description_type']} of {entity['name']} is {entity['description']}" for entity in entities]
    corpus_tokens = bm25s.tokenize(corpus, stopwords="en")
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)
    return retriever

def experiment_iteration(entities, num_entities, num_queries, stats = None):
    if num_entities > len(entities):
        raise ValueError("num_entities must be less than the number of entities in the dataset")
    if num_queries > num_entities:
        raise ValueError("num_queries must be less than or equal to num_entities")
    sampled_entities = np.random.choice(entities, num_entities, replace=False)
    retriever = create_retriever(sampled_entities)
    queries = [entity["Q"] for entity in sampled_entities]

    tokenized_queries = bm25s.tokenize(queries, stopwords="en")
    retrieved = retriever.retrieve(tokenized_queries, k=5)

    if stats is None:
        stats = Statistics()
    for index, results in enumerate(retrieved[0]):
        stats.update(results, index)

    return stats

def experiment(entities, num_entities, num_iterations, num_queries_per_iteration):
    stats = Statistics()
    total_experiments = num_iterations * num_queries_per_iteration
    print(f"Running experiment with the following parameters: {num_entities}/{len(entities)} entities, {num_iterations} iterations, {num_queries_per_iteration} queries per iteration, {total_experiments} total queries being run")
    for _ in range(num_iterations):
        stats = experiment_iteration(entities, num_entities, num_queries_per_iteration, stats)
    return stats

In [74]:
def run_experiment(entities):
    num_entities = [50, 100, 200, 400, 800, 1600, 3200, 6400]
    all_stats = []
    for num_entity in num_entities:
        stats = experiment(entities, num_entity, int(1000/50) + 1, 50)
        all_stats.append(stats)
        print(f"Num entities: {num_entity}, {stats}")
    return all_stats

all_stats = run_experiment(entities)

Running experiment with the following parameters: 50/29096 entities, 21 iterations, 50 queries per iteration, 1050 total queries being run


                                                            

Num entities: 50, Top-1 accuracy: 0.9980952380952381, Top-5 accuracy: 0.9990476190476191
Running experiment with the following parameters: 100/29096 entities, 21 iterations, 50 queries per iteration, 1050 total queries being run


                                                             

Num entities: 100, Top-1 accuracy: 0.9890476190476191, Top-5 accuracy: 0.9990476190476191
Running experiment with the following parameters: 200/29096 entities, 21 iterations, 50 queries per iteration, 1050 total queries being run


                                                             

Num entities: 200, Top-1 accuracy: 0.981904761904762, Top-5 accuracy: 0.9988095238095238
Running experiment with the following parameters: 400/29096 entities, 21 iterations, 50 queries per iteration, 1050 total queries being run


                                                             

Num entities: 400, Top-1 accuracy: 0.9678571428571429, Top-5 accuracy: 0.9978571428571429
Running experiment with the following parameters: 800/29096 entities, 21 iterations, 50 queries per iteration, 1050 total queries being run


                                                             

Num entities: 800, Top-1 accuracy: 0.9480952380952381, Top-5 accuracy: 0.9956547619047619
Running experiment with the following parameters: 1600/29096 entities, 21 iterations, 50 queries per iteration, 1050 total queries being run


                                                                      

Num entities: 1600, Top-1 accuracy: 0.8867857142857143, Top-5 accuracy: 0.990625
Running experiment with the following parameters: 3200/29096 entities, 21 iterations, 50 queries per iteration, 1050 total queries being run


                                                                      

Num entities: 3200, Top-1 accuracy: 0.8625595238095238, Top-5 accuracy: 0.9794047619047619
Running experiment with the following parameters: 6400/29096 entities, 21 iterations, 50 queries per iteration, 1050 total queries being run


                                                                            

Num entities: 6400, Top-1 accuracy: 0.8009300595238096, Top-5 accuracy: 0.9596205357142857


In [77]:
top_1_acc = []
top_5_acc = []
for stat in all_stats:
    top_1_acc.append(stat.top_1_correct_count / stat.total_count)
    top_5_acc.append(stat.top_5_correct_count / stat.total_count)

In [78]:
print(list(list(float('%.3g' % i) for i in el) for el in zip(top_1_acc, top_5_acc)))

[[0.998, 0.999], [0.989, 0.999], [0.982, 0.999], [0.968, 0.998], [0.948, 0.996], [0.887, 0.991], [0.863, 0.979], [0.801, 0.96]]


In [76]:
print("Synthetic:")
num_entities = [50, 100, 200, 400, 800, 1600, 3200, 6400]
print(list(zip(num_entities, list(float('%.3g' % i) for i in top_1_acc))))
print(list(zip(num_entities, list(float('%.3g' % i) for i in top_5_acc))))

Synthetic:
[(50, 0.992), (100, 0.999), (200, 0.994), (400, 0.992), (800, 0.988), (1600, 0.969), (3200, 0.95), (6400, 0.924)]
[(50, 1.0), (100, 1.0), (200, 1.0), (400, 1.0), (800, 1.0), (1600, 1.0), (3200, 1.0), (6400, 0.999)]


In [79]:
print("Enron:")
num_entities = [50, 100, 200, 400, 800, 1600, 3200, 6400]
print(list(zip(num_entities, list(float('%.3g' % i) for i in top_1_acc))))
print(list(zip(num_entities, list(float('%.3g' % i) for i in top_5_acc))))

Enron:
[(50, 0.998), (100, 0.989), (200, 0.982), (400, 0.968), (800, 0.948), (1600, 0.887), (3200, 0.863), (6400, 0.801)]
[(50, 0.999), (100, 0.999), (200, 0.999), (400, 0.998), (800, 0.996), (1600, 0.991), (3200, 0.979), (6400, 0.96)]
