In [65]:
import pyarrow.parquet as pq
import numpy as np

# Load the data
tb1 = pq.read_table('data/3-ja.parquet', columns=['emb'])
table = tb1[0].to_numpy()
flat_ds = list()
for vec in table:
    flat_ds.append(vec)
np_flat_ds = np.array(flat_ds)
table = np_flat_ds

In [66]:
docs = table[0:-1_000]
queries = table[-1_000:]

In [138]:
def quant_size_steps_universal(docs, quantiles):
    ranges = np.vstack((np.quantile(docs, quantiles[0]), np.quantile(docs, quantiles[1])))
    starts = np.full(768, ranges[0])
    steps = np.full(768, (ranges[1] - ranges[0]) / 255)
    return (starts, steps)

def quant_size_steps_universal_min_max(docs):
    ranges = np.vstack((np.min(docs), np.max(docs)))
    starts = np.full(768, ranges[0])
    steps = np.full(768, (ranges[1] - ranges[0]) / 255)
    return (starts, steps)

def quant_size_steps(docs, quantiles):
    ranges = np.vstack((np.quantile(docs, quantiles[0], axis=0), np.quantile(docs, quantiles[1], axis=0)))
    starts = ranges[0,:]
    steps = (ranges[1,:] - ranges[0,:]) / 255
    return (starts, steps)

def quant_size_steps_min_max(docs):
    ranges = np.vstack((np.min(docs, axis=0), np.max(docs, axis=0)))
    starts = ranges[0,:]
    steps = (ranges[1,:] - ranges[0,:]) / 255
    return (starts, steps)

def quantize(docs, starts, steps):
    return np.uint8((docs - starts) / steps)

def find_k_nearest_neighbors_dot(queries, docs, k=10):
    dot_product = -np.dot(queries, docs.T)  # Shape: (num_queries, num_docs)
    partitioned =  np.argpartition(dot_product, kth=k-1, axis=1)[:, :k] 
    return np.take_along_axis(partitioned, np.take_along_axis(dot_product, partitioned, axis=1).argsort(axis=1), axis=1)

def calculate_recall_overlap(baseline_knn, candidate_knn):
    # this calculates the average overlap between the baseline and candidate knn
    # for each query
    overlap = np.zeros(baseline_knn.shape[0])
    for i in range(baseline_knn.shape[0]):
        overlap[i] = len(np.intersect1d(baseline_knn[i], candidate_knn[i])) / baseline_knn.shape[1]
    return np.mean(overlap)

def calculate_recall_overlap_for_quantized(baseline_knns, ks, quant_starts, quant_steps, queries, docs):
    quantized_queries = quantize(queries, quant_starts, quant_steps)
    quantized_docs = quantize(docs, quant_starts, quant_steps)
    for k in ks:
        quantized_baseline_knn = find_k_nearest_neighbors_dot(quantized_queries, quantized_docs, k)
        print(k, calculate_recall_overlap(baseline_knns, quantized_baseline_knn))


In [135]:
baseline_knns = find_k_nearest_neighbors_dot(queries, docs, k=5) 
ks = [5, 10, 20, 50, 100, 200, 500, 1000]

In [144]:
test_cases = [
    {"name": "quantization_min_max_total", "value": quant_size_steps_universal_min_max(docs)},
    {"name": "quantization_min_max", "value": quant_size_steps_min_max(docs)},
    {"name": "quantization_total_99", "value": quant_size_steps_universal(docs, [0.01, 0.99])},
    {"name": "quantization_99", "value": quant_size_steps(docs, [0.01, 0.99])},
    {"name": "quantization_total_90", "value": quant_size_steps_universal(docs, [0.1, 0.90])},
    {"name": "quantization_90", "value": quant_size_steps(docs, [0.1, 0.90])},
]

In [145]:
for test_case in test_cases:
    print(test_case["name"])
    (starts, steps) = test_case["value"]
    calculate_recall_overlap_for_quantized(baseline_knns, ks, starts, steps, queries, docs)

quantization_min_max_total
10 0.0
50 0.0002
100 0.0008
1000 0.0074
quantization_min_max
10 0.0002
50 0.0004
100 0.0006000000000000001
1000 0.007200000000000001
quantization_total_99
10 0.0
50 0.0002
100 0.0006000000000000001
1000 0.007
quantization_99
10 0.0002
50 0.0002
100 0.0006000000000000001
1000 0.0058000000000000005
quantization_total_90
10 0.0
50 0.0004
100 0.0008
1000 0.0094
quantization_90
10 0.0002
50 0.0004
100 0.0006000000000000001
1000 0.01


In [140]:
(min_max_starts, min_max_step) = quant_size_steps_min_max(docs)
calculate_recall_overlap_for_quantized(baseline_knns, ks, min_max_starts, min_max_step, queries, docs)

10 0.0002
50 0.0004
100 0.0006000000000000001
1000 0.007200000000000001
