In [10]:
from IPython.display import display
import csv
from hashlib import md5

from src.retrieve import retrieve_with_scores

# Interactive controls for the retrieval-metrics cell.
# NOTE(review): `widgets` is not imported in this cell -- presumably
# `ipywidgets` imported earlier in the notebook; confirm.

# Path of the CSV file holding the evaluation questions.
file_name_widget = widgets.Text(
    description="File Name:",
    value="question_answer_pairs.csv",
)

# K passed to the metric computation, bounded to 1..50.
k_widget = widgets.BoundedIntText(
    description="K:",
    value=10,
    min=1,
    max=50,
    step=1,
    disabled=False,
)

# Button whose click handler runs the metric computation.
load_button = widgets.Button(description="Compute metrics")

# Output area used to show the printed results.
output = widgets.Output()

def load_csv(file_name):
    """Load a CSV file into a list of row dictionaries.

    Args:
        file_name: Path of the UTF-8 encoded CSV file. Its first row is
            used as the header that names the columns.

    Returns:
        On success, a list with one dict per data row, keyed by column name.
        On failure nothing is raised: a string starting with ``"Error: "``
        is returned instead, so callers must check the type of the result
        before iterating over it.
    """
    try:
        # newline='' hands raw line endings to the csv module, as its
        # documentation requires; otherwise "\r\n" embedded in a quoted
        # field is silently rewritten to "\n" before parsing.
        with open(file_name, mode='r', encoding='utf-8', newline='') as csv_file:
            return list(csv.DictReader(csv_file))
    except FileNotFoundError:
        return f"Error: File '{file_name}' not found."
    except Exception as e:
        # Deliberately broad: any read/parse problem is reported as text.
        return f"Error: {e}"

def on_button_click(b):
    """Handle a click on the "Compute metrics" button.

    Reads the file name and K from the widgets, loads the question CSV and
    prints the resulting precision and recall into the ``output`` widget.
    Loading problems are printed as a message instead of being passed on to
    the metric computation.

    Args:
        b: Argument supplied by the button's ``on_click`` callback; unused.
    """
    with output:
        output.clear_output()
        file_name = file_name_widget.value
        questions = load_csv(file_name)

        # load_csv reports failures by returning an error string rather than
        # raising; show it instead of iterating over its characters.
        if isinstance(questions, str):
            print(questions)
            return

        # Averages over zero questions are undefined; avoid dividing by zero.
        if not questions:
            print(f"Error: File '{file_name}' contains no questions.")
            return

        precision, recall = compute_precision_recall(questions, k_widget.value)
        print("Precision: ", precision)
        print("Recall: ", recall)


def compute_precision_recall(questions, k):
    """Compute precision and recall averaged over a set of questions.

    For every question, ``retrieve_with_scores`` is called with the question
    text and ``k``, and the content of each returned chunk is MD5-hashed.
    The question counts as a hit when its ``'content_hash'`` value is among
    those hashes.

    This calculation assumes there is exactly one expected chunk to retrieve
    per question, so a hit adds ``1`` to recall and ``1/k`` to precision.

    Args:
        questions: Sized collection of mappings that provide the
            ``'question'`` and ``'content_hash'`` keys.
        k: Value forwarded to ``retrieve_with_scores`` and used as the
            precision denominator -- presumably the number of results
            requested; confirm against ``src.retrieve``.

    Returns:
        A ``(precision, recall)`` tuple averaged over all questions, or
        ``(0.0, 0.0)`` when ``questions`` is empty.
    """
    total = len(questions)
    if total == 0:
        # Nothing to average over; avoid ZeroDivisionError below.
        return 0.0, 0.0

    precision = 0
    recall = 0

    for question in questions:
        # NOTE(review): the meaning of the trailing -1 argument is not
        # visible here -- confirm against retrieve_with_scores.
        results = retrieve_with_scores(question['question'], k, -1)
        # MD5 is used only as a content fingerprint, not for security.
        content_hashes = {
            md5(r.chunk.content.encode('utf-8'), usedforsecurity=False).hexdigest()
            for r in results
        }

        if question['content_hash'] in content_hashes:
            recall += 1
            precision += 1/k

    precision /= total
    recall /= total
    return precision, recall


# Register the click handler, then render the controls and the output area.
load_button.on_click(on_button_click)
display(
    file_name_widget,
    k_widget,
    load_button,
    output,
)

Text(value='question_answer_pairs.csv', description='File Name:')

BoundedIntText(value=10, description='K:', max=50, min=1)

Button(description='Compute metrics', style=ButtonStyle())

Output()