In [1]:
import os
import psycopg2
import numpy as np

N_values = ["10k", "1m"]
K_values = [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
MAX_QUERIES = 100

# Establish connection
db_connection_string = os.environ.get('DATABASE_URL')
conn = psycopg2.connect(db_connection_string)
cur = conn.cursor()

# Create index if needed
print("creating indices")
with open('./create_indices.sql', 'r') as file:
    sql = file.read()
    cur.execute(sql)
print("created indices")

for N in N_values:
    OUTPUT_FILE = f"recall/{N}.txt"
    print(f"about to calculate Recall for N={N}")

    with open(OUTPUT_FILE, "w") as f:
        for K in K_values:
            # Get all query ids
            cur.execute(f"SELECT id FROM sift_query{N}")
            query_ids = cur.fetchall()[:MAX_QUERIES]

            recall_at_k_sum = 0
            for query_id in query_ids:
                query_id = query_id[0]  # tuple unpacking
                cur.execute(f"""
                    SELECT
                        CARDINALITY(ARRAY(SELECT UNNEST(base_ids) INTERSECT SELECT UNNEST(truth_ids)))
                    FROM 
                    (
                        SELECT 
                            q.id AS query_id,
                            (SELECT ARRAY_AGG(b.id ORDER BY q.v <-> b.v) FROM sift_base{N} b LIMIT {K}) AS base_ids,
                            t.indices[1:{K}] AS truth_ids
                        FROM 
                            sift_query{N} q
                        JOIN 
                            sift_truth{N} t
                        ON 
                            q.id = t.id
                    ) subquery
                    WHERE
                        query_id = {query_id}
                """)
                recall_query = cur.fetchone()[0]

                print(f"Recall@{K} for query_id {query_id}: {recall_query}")
                recall_at_k_sum += int(recall_query)

            # Calculate the average recall for this K
            recall_at_k = recall_at_k_sum / len(query_ids)
            f.write(f"Recall@{K}: {recall_at_k}\n")
            print(f"Recall@{K}: {recall_at_k}")
            

        print(f"Completed all recall for {N}")

cur.close()
conn.close()

creating indices
created indices
about to calculate Recall for N=10k
Recall@1 for query_id 1: 1.00000000000000000000
Recall@1 for query_id 2: 1.00000000000000000000
Recall@1 for query_id 3: 1.00000000000000000000
Recall@1 for query_id 4: 1.00000000000000000000
Recall@1 for query_id 5: 1.00000000000000000000
Recall@1 for query_id 6: 1.00000000000000000000
Recall@1 for query_id 7: 1.00000000000000000000
Recall@1 for query_id 8: 1.00000000000000000000
Recall@1 for query_id 9: 1.00000000000000000000
Recall@1 for query_id 10: 1.00000000000000000000
Recall@1 for query_id 11: 1.00000000000000000000
Recall@1 for query_id 12: 1.00000000000000000000
Recall@1 for query_id 13: 1.00000000000000000000
Recall@1 for query_id 14: 1.00000000000000000000
Recall@1 for query_id 15: 1.00000000000000000000
Recall@1 for query_id 16: 1.00000000000000000000
Recall@1 for query_id 17: 1.00000000000000000000
Recall@1 for query_id 18: 1.00000000000000000000
Recall@1 for query_id 19: 1.00000000000000000000
Recall@1 

KeyboardInterrupt: 

In [2]:
import sys
sys.path.append('../util')

In [10]:
def parse_file(file_path):
    results = []
    with open(file_path, 'r') as file:
        lines = file.readlines()
    for line in lines:
        line = line.strip()
        parts = line.split(":")
        k = int(parts[0].split("@")[1])
        match = float(parts[1].strip())
        recall = match / k
        results.append((k, recall))
    return results

In [8]:
import os
dir = 'recall'
file_names = os.listdir(dir)

In [11]:
from converters import convert_string_to_number
import os
import plotly.graph_objects as go

def parse_data_and_generate_plot():
    dir = 'recall'
    file_names = os.listdir(dir)
    
    plot_items = []
    
    indexed_count = 0
    unindexed_count = 0
    for file_name in file_names:
        N = convert_string_to_number(file_name.split('.')[0])
        key = f"N = {N}"
        results = parse_file(dir + '/' + file_name)
        x_values, y_values = zip(*results)
        plot_items.append((key, x_values, y_values))
    
    # Plot data
    fig = go.Figure()
    for (key, x_values, y_values) in plot_items:
        fig.add_trace(go.Scatter(
            x=x_values,
            y=y_values,
            mode='lines+markers',
            name=key
        ))
    fig.update_layout(
        title=f"Recall vs. K",
        xaxis_title='Number of similar vectors retrieved (K)',
        yaxis_title='Recall'
    )
    fig.show()
parse_data_and_generate_plot()