In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, Counter

In [2]:
from typing import Any, Dict, List, Optional, Tuple

In [3]:
from tqdm.notebook import tqdm, trange

In [4]:
from scipy import stats

In [5]:
import faiss 

from sentence_transformers import SentenceTransformer
#from sentence_transformers.quantization import quantize_embeddings

  from tqdm.autonotebook import tqdm, trange


In [6]:
#model_id = "sentence-transformers/all-MiniLM-L6-v2"

In [7]:
# Parameters
# Hugging Face id of the sentence-embedding model being evaluated.
model_id: str = "mixedbread-ai/mxbai-embed-large-v1"


In [14]:
# load an embedding model
# NOTE(review): trust_remote_code=True allows executing Python code shipped in
# the model's Hub repository. It is only needed for models with custom
# architectures -- confirm that model_id actually requires it and drop the flag
# otherwise (or pin a trusted `revision=`).
model = SentenceTransformer(model_id, trust_remote_code=True)

In [15]:
def generate_single_test_case(num: int,
                              step: int = 1,
                              num_candidates: int = 10) -> Tuple[int, int, List[int]]:
    """
    Build one test case: a query number plus evenly spaced numbers after it.

    The candidates are num + step, num + 2*step, ..., num + num_candidates*step,
    so they lie at strictly increasing distance from the query and their
    list order is also their true ranking by distance.

    Args:
        num: the query number.
        step: spacing between consecutive candidates; must be positive.
        num_candidates: how many candidates to generate (default 10);
            must be positive.

    Returns:
        (num, step, candidates), where candidates is a list of ints.

    Raises:
        AssertionError: if step or num_candidates is not positive.
    """
    # make sure the step size and candidate count are positive
    assert step > 0, f"step={step} is not valid!  Step size must be positive."
    assert num_candidates > 0, (
        f"num_candidates={num_candidates} is not valid!  "
        f"Number of candidates must be positive."
    )

    start = num + step
    # range() excludes its stop value, so stopping one step past the last
    # candidate yields exactly num_candidates values.
    end = num + (num_candidates + 1) * step
    return (num, step, list(range(start, end, step)))

In [16]:
# peek at one generated test case: query 10, candidates spaced by 2
generate_single_test_case(num=10, step=2)

(10, 2, [12, 14, 16, 18, 20, 22, 24, 26, 28, 30])

In [17]:
# Sanity-check generate_single_test_case against hand-computed expectations,
# covering a zero query, a negative query and a large step size.
# Each row: (query number, step size, expected candidate list).
for num, step, expected_candidates in [
    (5, 2, [7, 9, 11, 13, 15, 17, 19, 21, 23, 25]),
    (10, 3, [13, 16, 19, 22, 25, 28, 31, 34, 37, 40]),
    (0, 5, [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]),
    (-3, 1, [-2, -1, 0, 1, 2, 3, 4, 5, 6, 7]),
    (1000, 25, [1025, 1050, 1075, 1100, 1125, 1150, 1175, 1200, 1225, 1250]),
]:
    expected_result = (num, step, expected_candidates)
    actual_result = generate_single_test_case(num, step)
    assert actual_result == expected_result, (
        f"generate_single_test_case({num}, {step}) returned {actual_result}, "
        f"expected {expected_result}"
    )

In [18]:
def generate_test_cases(min_num: int,
                        max_num: int,
                        min_step: int,
                        max_step: int,
                        num_cases: int,
                        seed: Optional[int] = None) -> List[Tuple[int, int, List[int]]]:
    """
    Generate `num_cases` random test cases via generate_single_test_case.

    Query numbers are drawn uniformly from [min_num, max_num] and step sizes
    from [min_step, max_step]; both ranges include their upper bound.

    Args:
        min_num, max_num: inclusive range for the query numbers.
        min_step, max_step: inclusive range for the step sizes; min_step must
            be positive. Equal bounds are allowed (e.g. a single step size).
        num_cases: number of test cases to generate; must be positive.
        seed: if given, draws come from a private np.random.RandomState(seed),
            so the result is reproducible without touching numpy's global
            state. If None (default), numpy's global RNG is used.

    Returns:
        A list of (query number, step size, candidate list) tuples, with
        plain Python ints throughout.

    Raises:
        AssertionError: if any of the parameters is invalid.
    """
    # make sure the input parameters are valid
    assert min_num <= max_num, (
        f"min_num={min_num} must be less than or equal to max_num={max_num}")
    assert min_step > 0, f"min_step={min_step} must be positive."
    assert min_step <= max_step, (
        f"min_step={min_step} must be less than or equal to max_step={max_step}")
    assert num_cases > 0, f"num_cases={num_cases} must be a positive integer."

    # np.random (the module) and RandomState share the randint(low, high, size)
    # API; with seed=None the global stream is consumed exactly as before.
    rng = np.random if seed is None else np.random.RandomState(seed)
    nums = rng.randint(min_num, max_num + 1, num_cases)
    steps = rng.randint(min_step, max_step + 1, num_cases)

    # Cast numpy integer scalars to plain ints so the test cases (and anything
    # keyed on them downstream) hold ordinary Python ints.
    return [
        generate_single_test_case(int(num), int(step))
        for num, step in zip(nums, steps)
    ]

In [19]:
# generate 10k test cases
test_cases = generate_test_cases(min_num=1,
                                 max_num=100,
                                 min_step=1,
                                 max_step=5,
                                 num_cases=10_000)

In [20]:
# group test cases by step size
# The step sizes are read from the data (ascending) rather than hard-coded as
# range(1, 6), so the grouping stays correct if the min_step/max_step used to
# generate test_cases change. A step size that drew no test cases gets no entry.
test_cases_by_step = {
    step: [t_case for t_case in test_cases if t_case[1] == step]
    for step in sorted({int(t_case[1]) for t_case in test_cases})
}

In [21]:
# Report how many test cases fell into each step-size bucket.
for step in test_cases_by_step:
    t_cases = test_cases_by_step[step]
    print(f"Step = {step}: {len(t_cases)} test cases")

Step = 1: 1962 test cases
Step = 2: 1996 test cases
Step = 3: 2034 test cases
Step = 4: 1973 test cases
Step = 5: 2035 test cases


In [22]:
# Every number that needs an embedding: all candidates first, then all query
# numbers. This is a list, so duplicates are still present at this point.
all_numbers = [cand for t_case in test_cases for cand in t_case[2]] + [
    t_case[0] for t_case in test_cases
]

In [23]:
# Embed each distinct number exactly once. Sorting gives an explicit,
# reproducible order instead of relying on set iteration order, and int()
# turns any numpy integer scalars into plain Python ints.
unique_numbers = sorted({int(x) for x in all_numbers})

# the model embeds each number from its decimal string form
unique_numbers_embeddings = model.encode([str(i) for i in unique_numbers])

In [24]:
# Look-up table from each number to its embedding (row i of
# unique_numbers_embeddings belongs to unique_numbers[i]).
num_to_embeddings = dict(zip(unique_numbers, unique_numbers_embeddings))

In [27]:
# Evaluate every test case: rank its candidates by cosine similarity between
# their embeddings and the query's embedding, then use Kendall's tau to measure
# how well that ranking agrees with distance on the number line.
debug = True          # print per-case details for the first few test cases
max_debug_cases = 10  # number of test cases printed when debug is True

step_to_tau_correlations = defaultdict(list)


def _rank_by_embedding(query_num: int,
                       candidates: List[int]) -> Tuple[List[int], np.ndarray]:
    """
    Rank `candidates` by cosine similarity of their embeddings to `query_num`'s.

    Embeddings are looked up in num_to_embeddings. Returns the candidates
    ordered from most to least similar, together with the matching unrounded
    cosine distances (1 - similarity).
    """
    # Copy into fresh float32 arrays: faiss.normalize_L2 works in place, and a
    # plain reshape of the cached vector is only a view, so normalizing it
    # would silently overwrite the entry stored in num_to_embeddings.
    query_emb = np.array(num_to_embeddings[query_num], dtype=np.float32).reshape(1, -1)
    candidate_embs = np.array([num_to_embeddings[c] for c in candidates],
                              dtype=np.float32)

    # inner product of L2-normalized vectors == cosine similarity
    faiss.normalize_L2(query_emb)
    faiss.normalize_L2(candidate_embs)

    index = faiss.IndexFlatIP(candidate_embs.shape[1])
    index.add(candidate_embs)

    # Retrieve every candidate so the full ranking is available; a fixed k
    # larger than len(candidates) would make faiss pad the result with -1.
    similarities, indices = index.search(query_emb, len(candidates))

    ranked = [candidates[i] for i in indices[0]]
    return ranked, 1.0 - similarities[0]


n_printed = 0
for query_num, step, candidates in tqdm(test_cases, desc="evaluating test cases"):

    ranked_candidates, embedding_distances = _rank_by_embedding(query_num, candidates)

    # distance on the number line, listed in the embedding-ranked order
    numerical_distances = [int(abs(query_num - c)) for c in ranked_candidates]

    # Correlate against the raw distances: rounding first (as the printout
    # does) can introduce artificial ties and distort tau.
    kendall_tau, _ = stats.kendalltau(numerical_distances, embedding_distances)

    step_to_tau_correlations[int(step)].append(kendall_tau)

    # Only the printing is limited; every test case is still evaluated so the
    # per-step statistics cover the full set.
    if debug and n_printed < max_debug_cases:
        embedding_distances_rounded = [round(float(x), 3) for x in embedding_distances]

        # print results
        print(f"Query Number: {query_num}")
        print(f"Step: {step}")
        print(f"Candidates: {candidates}")
        print(f"    Ranked: {ranked_candidates}")
        print()
        print(f"Numerical Distances: {numerical_distances}")
        print(f"Embedding Distances: {embedding_distances_rounded}")
        print(f"Kendall's Tau: {kendall_tau}")
        print()

        n_printed += 1

  0%|          | 0/10000 [00:00<?, ?it/s]

Query Number: 13
Step: 1
Candidates: [14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
    Ranked: [14, 21, 23, 20, 17, 18, 16, 15, 19, 22]

Numerical Distances: [1, 8, 10, 7, 4, 5, 3, 2, 6, 9]
Embedding Distances: [0.293, 0.352, 0.359, 0.394, 0.405, 0.465, 0.48, 0.483, 0.499, 0.5]
Kendall's Tau: -0.022222222222222223

Query Number: 19
Step: 5
Candidates: [24, 29, 34, 39, 44, 49, 54, 59, 64, 69]
    Ranked: [49, 39, 59, 29, 69, 24, 34, 54, 64, 44]

Numerical Distances: [30, 20, 40, 10, 50, 5, 15, 35, 45, 25]
Embedding Distances: [0.315, 0.322, 0.346, 0.378, 0.38, 0.382, 0.389, 0.426, 0.427, 0.486]
Kendall's Tau: 0.06666666666666667

Query Number: 90
Step: 4
Candidates: [94, 98, 102, 106, 110, 114, 118, 122, 126, 130]
    Ranked: [98, 94, 110, 114, 106, 130, 118, 102, 122, 126]

Numerical Distances: [8, 4, 20, 24, 16, 40, 28, 12, 32, 36]
Embedding Distances: [0.281, 0.296, 0.298, 0.32, 0.351, 0.403, 0.438, 0.445, 0.45, 0.487]
Kendall's Tau: 0.5111111111111111

Query Number: 68
Step: 1
Candidates

In [26]:
# Summarize Kendall's tau per step size: one row per step, ordered by step,
# indexed by step. Relies on the evaluation loop having already filled
# step_to_tau_correlations for all test cases.
# std_tau uses np.std's default population standard deviation (ddof=0).
sorted_steps = sorted(step_to_tau_correlations)
step_statistics_df = pd.DataFrame(
    {
        "step": sorted_steps,
        "mean_tau": [np.mean(step_to_tau_correlations[s]) for s in sorted_steps],
        "std_tau": [np.std(step_to_tau_correlations[s]) for s in sorted_steps],
    },
    index=sorted_steps,
)

step_statistics_df

Unnamed: 0,step,mean_tau,std_tau
1,1,0.278791,0.24046
2,2,0.179938,0.208732
3,3,0.121105,0.254391
4,4,0.145038,0.248601
5,5,0.102117,0.217201
