# **SISAP 2025 Challenge: Data Preparation and Compression**

This notebook presents the initial steps of data preparation and compression for the SISAP 2025 challenge dataset.

- **Section 1:** Loads the dataset `benchmark-dev-ccnews.h5` from the SISAP 2025 challenge repository on Hugging Face.

- **Section 2:** Applies a rotation transformation to the original embeddings.

- **Section 3:** Encodes the rotated embeddings into a more memory-efficient format. The resulting encoding is a dictionary containing four key elements:

  1. **`packed_binary_matrix`**: A PyTorch `uint8` tensor that stores the index matrix in packed form. PyTorch has no bit-level boolean tensor type, so each index is written as `num_bits` bits and the bits are packed eight per byte. This tensor has the same number of rows as the original embedding database, and packing is performed along each row.

  2. **`outliers`**: A tensor with fewer columns than the original embeddings. It stores the original out-of-range (outlier) values, which fill, in row order, the positions whose entry in the index matrix is zero.

  3. **`avg_values`**: A one-dimensional tensor with a small set of values shared by all vectors. The first element is zero (reserved for outlier positions), and the remaining elements are the segment averages. Each value in the index matrix maps directly to a position in `avg_values`, so a matrix entry of `i` corresponds to `avg_values[i]`.

  4. **`og_shape_bin`**: Metadata used to correctly unpack the `packed_binary_matrix` back to its original shape.

  Because the compressed elements are organized as tensors with uniform row sizes, this structure supports efficient partial decoding: an index can be built over the compressed dataset where only the specific embedding(s) required during a query are decoded on demand.

- **Section 4:** Runs brute-force similarity search experiments to evaluate the recall@k of the proposed compression scheme.

- **Section 5:** Reports the results obtained with the best configuration.
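The encoding layout of Section 3 can be illustrated with a minimal pure-Python sketch that decodes a single row from its fields; this per-vector access is what partial decoding relies on. It assumes, as in the notebook's packing code, that each index is stored as `num_bits` bits, most significant bit first, and that the outliers fill the index-0 positions from left to right. The helpers `unpack_row_indices` and `decode_row` and the toy values are hypothetical.

```python
def unpack_row_indices(packed, num_bits, dim):
    """Recover `dim` segment indices of `num_bits` bits each from packed bytes."""
    # Expand every byte into its 8 bits, most significant bit first
    bits = [(byte >> shift) & 1 for byte in packed for shift in range(7, -1, -1)]
    indices = []
    for j in range(dim):
        value = 0
        # Each index occupies num_bits consecutive bits, MSB first
        for bit in bits[j * num_bits:(j + 1) * num_bits]:
            value = (value << 1) | bit
        indices.append(value)
    return indices


def decode_row(packed, outliers, avg_values, num_bits, dim):
    """Rebuild one approximate vector from its compressed fields."""
    indices = unpack_row_indices(packed, num_bits, dim)
    outlier_iter = iter(outliers)
    # Index 0 takes the next stored outlier (left-to-right order);
    # any other index i is replaced by avg_values[i]
    return [next(outlier_iter) if i == 0 else avg_values[i] for i in indices]


# Hypothetical toy row: 2-bit indices [0, 3, 1, 0] packed into one byte 0b00110100
row = decode_row(bytes([0b00110100]), outliers=[-2.0, 3.0],
                 avg_values=[0.0, -0.5, 0.1, 0.6], num_bits=2, dim=4)
print(row)  # [-2.0, 0.6, -0.5, 3.0]
```

Only the bytes of the requested row are touched, which is why fixed-size rows make on-demand decoding cheap.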



**Note:** The best recall values were obtained with a 6-bit resolution for the index matrix and the full dot product. Therefore, this notebook focuses on reporting the results for that configuration; however, all the functions required to evaluate other bit resolutions are included and can be used for further experimentation.


---
*Notebook created by Scarlett Magdaleno-Gatica, Master's student in Computer Science at CICESE, for the SISAP 2025 challenge. This is a preliminary version.*  
*Date: June 02, 2025*


### Install

In [1]:
!pip install datasets faiss-cpu h5py scikit-learn torch
!pip install huggingface_hub


### Imports

In [3]:
import numpy as np
import h5py
import faiss
import torch
import tensorflow as tf
import os
import math
import pandas as pd
from datasets import load_dataset
from sklearn.preprocessing import normalize
from sklearn.metrics import precision_score
from typing import Tuple
from huggingface_hub import hf_hub_download
from huggingface_hub import login
from tqdm import tqdm


In [None]:
# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)


# results_dir = '/content/drive/MyDrive/Tesis/SISAP/Cache'
# new_results_dir = '/content/drive/MyDrive/Tesis/SISAP/NewCache'

# login_token = ""

# # Log in to Hugging Face
# login(token=login_token)



## **1. Load Dataset from HuggingFace**

In [2]:
def load_sisap_benchmark(file_path: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Loads the SISAP benchmark HDF5 file and returns, as PyTorch tensors, the base
    vectors ('train'), the query vectors ('otest/queries'), and the ground-truth
    distances ('otest/dists') and nearest-neighbor indices ('otest/knns').
    """
    with h5py.File(file_path, 'r') as f:
        print("Keys available in the file:", list(f.keys()))

        # Ensure the expected keys are present
        if 'train' not in f or 'otest' not in f:
            raise KeyError("Key 'train' or 'otest' was not found in the file.")
        print(f"original_dtype: {f['train'].dtype}")

        base = torch.tensor(np.array(f['train']))  # Training vectors (base)
        queries = torch.tensor(np.array(f['otest']['queries']))  # Query vectors from the 'otest' group
        dists = torch.tensor(np.array(f['otest']['dists']))  # Distances to nearest neighbors
        knns = torch.tensor(np.array(f['otest']['knns']))  # Indices of nearest neighbors

    return base, queries, dists, knns

# Load benchmark and ground truth
# file_benchmark_path = hf_hub_download(
#     repo_id="sadit/SISAP2025",                # dataset name
#     filename="benchmark-dev-ccnews.h5",       # file name
#     repo_type="dataset"                       # important: specify that it is a dataset
# )

file_benchmark_path = "/users/cfoste18/data/cfoste18/knn-construction/data/ccnews/ccnews.h5"
base, queries, dists, knns = load_sisap_benchmark(file_benchmark_path)

print("Base:", base.shape)
print("Queries:", queries.shape)
print("Ground truth dists:", dists.shape)
print("Ground truth knns:", knns.shape)



## **2. Rotate embeddings with a random rotation**

In [5]:
def generate_rotation_matrix_Q(n, dtype=torch.float32):
    """
    Generates an n x n rotation matrix Q using PyTorch and returns it as a PyTorch tensor.

    The columns of Q are orthonormal vectors, and the matrix has a determinant of 1.

    Args:
        n (int): Size of the matrix.
        dtype (torch.dtype): Desired dtype for the output rotation matrix (e.g., torch.float32, torch.float16).

    Returns:
        torch.Tensor: Orthonormal rotation matrix Q as a PyTorch tensor.
    """

    # Generate a random matrix using PyTorch
    random_matrix = torch.randn(n, n, dtype=dtype)

    # Apply QR decomposition to obtain an orthonormal matrix Q
    Q, R = torch.linalg.qr(random_matrix)

    # Ensure that Q is a rotation matrix (det(Q) = 1)
    det_Q = torch.det(Q)
    if det_Q < 0:
        # Adjust the sign of the last column to ensure det(Q) = 1
        Q[:, -1] *= -1  # Multiply the last column by -1

    return Q

In [22]:
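# Generate the random rotation matrix
n = base.shape[1]
# Uncomment to generate and save a new rotation matrix:
# Q = generate_rotation_matrix_Q(n)
# torch.save(Q, os.path.join(results_dir, 'rotation_matrix.pt'))
# Reuse a previously saved matrix so every run applies the same rotation
# (results_dir must point to the cache directory; see the Colab setup cell above)
Q = torch.load(os.path.join(results_dir, 'rotation_matrix.pt'))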

In [23]:
rotated_queries = torch.matmul(queries, Q.T)
rotated_base = torch.matmul(base, Q.T)

print(f"Rotated Queries: {rotated_queries.shape}")
print(f"Rotated Base: {rotated_base.shape}")

Rotated Queries: torch.Size([11000, 384])
Rotated Base: torch.Size([603664, 384])
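Because Q is orthogonal, rotating both the queries and the base with the same Q leaves every dot product unchanged, so similarity rankings are not affected by this step. The pure-Python sketch below checks that property on a small example. It builds Q with modified Gram-Schmidt as a stand-in for `torch.linalg.qr` (the determinant sign fix is omitted because it does not affect this property), and the helper names are hypothetical.

```python
import math
import random


def random_orthogonal(n, seed=0):
    """Return an n x n matrix (list of rows) with orthonormal columns."""
    rng = random.Random(seed)
    # Columns of a random Gaussian matrix, orthonormalized with modified Gram-Schmidt
    columns = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(n)]
    basis = []
    for v in columns:
        w = list(v)
        for q in basis:
            proj = sum(a * b for a, b in zip(w, q))
            w = [a - proj * b for a, b in zip(w, q)]
        norm = math.sqrt(sum(a * a for a in w))
        basis.append([a / norm for a in w])
    # basis[j] is column j; transpose to row-major form
    return [[basis[j][i] for j in range(n)] for i in range(n)]


def rotate(Q, x):
    """Compute Q @ x, i.e. one row of `x @ Q.T` in the notebook."""
    return [sum(q * xi for q, xi in zip(row, x)) for row in Q]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


Q = random_orthogonal(8, seed=24)
rng = random.Random(1)
x = [rng.gauss(0.0, 1.0) for _ in range(8)]
y = [rng.gauss(0.0, 1.0) for _ in range(8)]
print(dot(x, y), dot(rotate(Q, x), rotate(Q, y)))  # equal up to rounding
```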


## **3. OrthoQuant compression**

### Encoding Functions

In [8]:
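def orthoQuant_encode(D, lower_limit, upper_limit, num_bits):
    """
    Encodes a batch of rotated vectors D of shape (B, N).

    In each row, the values below l = sorted_D[:, lower_limit] or above
    r = sorted_D[:, upper_limit] are kept verbatim as outliers (expand_mask_side
    makes their count equal in every row). The in-range values are split into
    2**num_bits - 1 segments, equal-width on each side ([l, 0) and [0, r]), and
    each value is replaced by its segment index; index 0 marks an outlier position.

    Returns a dictionary with 'outliers' (B, k), 'avg_values' (B, 2**num_bits),
    'packed_binary_matrix' (B, num_bits * N // 8) as uint8, and 'og_shape_bin'.
    """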
    dtype = D.dtype
    device = D.device
    B, N = D.shape
    num_segments = 2**num_bits - 1
    less_segments = num_segments // 2
    more_segments = num_segments - less_segments

    sorted_D, _ = torch.sort(D, dim=1)
    l = sorted_D[:, lower_limit]
    r = sorted_D[:, upper_limit]
    l = torch.where(l > 0, -r.abs(), l)
    r = torch.where(r < 0, l.abs(), r)

    mask_l = D < l.unsqueeze(1)
    mask_r = D > r.unsqueeze(1)

    # Initial counts per row
    counts_l = mask_l.sum(dim=1)
    counts_r = mask_r.sum(dim=1)

    # Expected counts (taken from the first row)
    expected_l = counts_l[0]
    expected_r = counts_r[0]

    # Expand masks if necessary
    mask_l = expand_mask_side(mask_l, counts_l, expected_l, sorted_D, D, side='left')
    mask_r = expand_mask_side(mask_r, counts_r, expected_r, sorted_D, D, side='right')

    # Recompute counts to ensure uniformity
    counts_l = mask_l.sum(dim=1)
    counts_r = mask_r.sum(dim=1)

    assert torch.all(counts_l == expected_l), "Left mask uniformity was not achieved"
    assert torch.all(counts_r == expected_r), "Right mask uniformity was not achieved"

    mask_og = mask_l | mask_r

    # Ensure all vectors have the same number of out-of-range values
    counts = mask_og.sum(dim=1)
    assert torch.all(counts == counts[0]), f"Vectors do not have the same number of out-of-range values! {counts}"

    non_zero_elements_tensor = D[mask_og].view(B, -1)

    mask_in_range = ~mask_og
    mask_neg = (l.unsqueeze(1) <= D) & (D < 0) & mask_in_range
    mask_pos = (0 <= D) & (D <= r.unsqueeze(1)) & mask_in_range

    # Split the segments between the negative and positive sides (decided from the first row)
    count_neg = mask_neg[0].sum()
    count_pos = mask_pos[0].sum()
    num_segments_l = more_segments if count_neg > count_pos else less_segments
    num_segments_r = num_segments - num_segments_l

    neg_edges = torch.linspace(-1, 0, num_segments_l + 1, device=device).view(1, -1)
    pos_edges = torch.linspace(0, 1, num_segments_r + 1, device=device).view(1, -1)

    # Scale edges based on thresholds
    neg_edges_scaled = l.view(-1, 1) * (-1 * neg_edges)
    pos_edges_scaled = r.view(-1, 1) * pos_edges

    # Initialization
    avg_values = torch.zeros((B, num_segments), dtype=dtype, device=device)

    for seg_idx in range(num_segments_l):
        lower = neg_edges_scaled[:, seg_idx].unsqueeze(1)
        upper = neg_edges_scaled[:, seg_idx + 1].unsqueeze(1)
        seg_mask = (D >= lower) & (D < upper) & mask_neg
        seg_vals = torch.where(seg_mask, D, torch.zeros_like(D))
        counts = seg_mask.sum(dim=1)
        valid = counts > 0
        sums = seg_vals.sum(dim=1)
        avg_values[valid, seg_idx] = sums[valid] / counts[valid]

    for seg_idx in range(num_segments_r):
        lower = pos_edges_scaled[:, seg_idx].unsqueeze(1)
        upper = pos_edges_scaled[:, seg_idx + 1].unsqueeze(1)
        seg_mask = (D >= lower) & (D <= upper) & mask_pos
        seg_vals = torch.where(seg_mask, D, torch.zeros_like(D))
        counts = seg_mask.sum(dim=1)
        valid = counts > 0
        sums = seg_vals.sum(dim=1)
        avg_values[valid, num_segments_l + seg_idx] = sums[valid] / counts[valid]

    # Add a zero value as the first segment (reserved for out-of-range positions)
    avg_values = torch.cat([
        torch.zeros((avg_values.size(0), 1), dtype=avg_values.dtype, device=avg_values.device),
        avg_values
    ], dim=1)  # shape (B, num_segments + 1)

    index_matrix = torch.zeros_like(D, dtype=torch.long)

    for seg_idx in range(num_segments_l):
        lower = neg_edges_scaled[:, seg_idx].unsqueeze(1)
        upper = neg_edges_scaled[:, seg_idx + 1].unsqueeze(1)
        seg_mask = (D >= lower) & (D < upper) & mask_neg
        index_matrix[seg_mask] = seg_idx + 1

    for seg_idx in range(num_segments_r):
        lower = pos_edges_scaled[:, seg_idx].unsqueeze(1)
        upper = pos_edges_scaled[:, seg_idx + 1].unsqueeze(1)
        seg_mask = (D >= lower) & (D <= upper) & mask_pos
        index_matrix[seg_mask] = num_segments_l + seg_idx + 1

    # Convert segment indices to binary
    binary_matrix = ((index_matrix.unsqueeze(-1) >> torch.arange(num_bits - 1, -1, -1, device=device)) & 1).to(torch.int8)
    binary_matrix = binary_matrix.reshape(B, num_bits * N)

    # Compress binary representation
    packed_binary_matrix = packbits_pytorch(binary_matrix)
    og_shape_bin = binary_matrix.shape

    compressed_batch = {
        'outliers': non_zero_elements_tensor,
        'avg_values': avg_values,
        'packed_binary_matrix': packed_binary_matrix,
        'og_shape_bin': og_shape_bin
    }

    return compressed_batch
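A simplified single-vector sketch of the segmentation performed by `orthoQuant_encode` may help when reading the tensor code. It is not the notebook's implementation: it assumes `num_bits >= 2` and `l < 0 < r`, skips the mask expansion done by `expand_mask_side`, and computes segment positions by floor division instead of `torch.linspace` edges, so a value lying exactly on an edge could be assigned differently. The function name and toy vector are hypothetical.

```python
def encode_vector(x, lower_limit, upper_limit, num_bits):
    """Simplified single-vector version of the segment encoding."""
    s = sorted(x)
    l, r = s[lower_limit], s[upper_limit]
    num_segments = 2 ** num_bits - 1  # index 0 is reserved for outliers
    less = num_segments // 2
    more = num_segments - less
    neg = [v for v in x if l <= v < 0]
    pos = [v for v in x if 0 <= v <= r]
    # The side holding more in-range values gets the extra segment
    n_l = more if len(neg) > len(pos) else less
    n_r = num_segments - n_l

    indices, outliers = [], []
    for v in x:
        if v < l or v > r:
            indices.append(0)
            outliers.append(v)
        elif v < 0:
            # Equal-width segments on [l, 0), numbered 1..n_l
            seg = min(int((v - l) / (-l / n_l)), n_l - 1)
            indices.append(seg + 1)
        else:
            # Equal-width segments on [0, r], numbered n_l+1..num_segments
            seg = min(int(v / (r / n_r)), n_r - 1)
            indices.append(n_l + seg + 1)

    # Average of the values that fell into each segment (0.0 if empty)
    avg_values = [0.0]
    for idx in range(1, num_segments + 1):
        members = [v for v, i in zip(x, indices) if i == idx]
        avg_values.append(sum(members) / len(members) if members else 0.0)
    return indices, outliers, avg_values


indices, outliers, avg_values = encode_vector([-4.0, -1.0, -0.5, 0.2, 0.9, 5.0], 1, 4, num_bits=2)
print(indices)     # [0, 1, 1, 2, 3, 0]
print(outliers)    # [-4.0, 5.0]
print(avg_values)  # [0.0, -0.75, 0.2, 0.9]
```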


In [66]:
def orthoQuant_encode_1bit(D, lower_limit, upper_limit):
    """
    Encodes vectors using a single segment per dimension (num_bits = 1).
    For each vector, elements outside the specified percentile range are discarded,
    and the average is computed over the remaining elements.

    Parameters:
    -----------
    D : torch.Tensor of shape (n, d)
        Rotated vectors to be compressed.
    lower_limit : int
        Lower index of the segment to use (percentile).
    upper_limit : int
        Upper index of the segment to use (percentile).

    Returns:
    --------
    A dictionary containing:
        - 'non_zero_elements': elements outside the [l, r] range
        - 'avg_values': per-vector average of the central segment, broadcast to shape (n, d)
        - 'packed_binary_matrix': packed boolean mask of the central segment
        - 'og_shape_bin': original shape before packing
    """
    n, d = D.shape
    device = D.device

    # Sort each row to compute percentiles
    sorted_D, _ = torch.sort(D, dim=1)
    l = sorted_D[:, lower_limit]  # (n,)
    r = sorted_D[:, upper_limit]  # (n,)

    # Ensure [l, r] defines a valid range
    l = torch.where(l > 0, -r.abs(), l)
    r = torch.where(r < 0, l.abs(), r)

    mask_l = D < l.unsqueeze(1)
    mask_r = D > r.unsqueeze(1)

    # Count initial number of elements outside the range per row
    counts_l = mask_l.sum(dim=1)
    counts_r = mask_r.sum(dim=1)

    # Use the counts from the first vector as the expected count
    expected_l = counts_l[0]
    expected_r = counts_r[0]

    # Expand masks to match expected counts if necessary
    mask_l = expand_mask_side(mask_l, counts_l, expected_l, sorted_D, D, side='left')
    mask_r = expand_mask_side(mask_r, counts_r, expected_r, sorted_D, D, side='right')

    # Verify that all rows now have the same number of masked elements
    counts_l = mask_l.sum(dim=1)
    counts_r = mask_r.sum(dim=1)

    assert torch.all(counts_l == expected_l), "Uniformity in left mask not achieved"
    assert torch.all(counts_r == expected_r), "Uniformity in right mask not achieved"

    mask_og = mask_l | mask_r

    # Ensure all vectors have the same number of out-of-range values
    counts = mask_og.sum(dim=1)
    assert torch.all(counts == counts[0]), f"Vectors do not have the same number of out-of-range values! {counts}"

    # Store out-of-range values as outliers
    non_zero_elements_tensor = D[mask_og].view(n, -1)

    # In-range segment
    mask_seg = ~mask_og          # (n, d)
    segment = D * mask_seg       # (n, d), zeros will not affect the average
    avg_values = segment.sum(dim=1, keepdim=True) / mask_seg.sum(dim=1, keepdim=True)  # (n, 1)

    # Add a column of zeros at the beginning
    zero_col = torch.zeros_like(avg_values)
    avg_values = torch.cat([zero_col, avg_values], dim=1)  # (n, 2)

    # Encode the binary mask
    binary_matrix = mask_seg.to(torch.uint8)  # (n, d)
    og_shape_bin = binary_matrix.shape
    packed_binary_matrix = packbits_pytorch(binary_matrix)

    compressed_batch = {
        'outliers': non_zero_elements_tensor,
        'avg_values': avg_values,
        'packed_binary_matrix': packed_binary_matrix,
        'og_shape_bin': og_shape_bin
    }

    return compressed_batch

In [10]:
def packbits_pytorch(unpacked, num_bits=8):
    """
    Packs a boolean tensor into uint8, equivalent to np.packbits.
    `unpacked`: boolean tensor of size (n, m * num_bits)
    Returns a uint8 tensor of size (n, m)
    """
    device = unpacked.device
    n, total_bits = unpacked.shape
    assert total_bits % num_bits == 0, "The total number of bits must be divisible by num_bits."

    m = total_bits // num_bits
    unpacked = unpacked.view(n, m, num_bits).to(torch.uint8)
    bits = torch.tensor([1 << i for i in range(num_bits - 1, -1, -1)], device=device, dtype=torch.uint8)
    packed = torch.sum(unpacked * bits, dim=-1)
    return packed.to(dtype=torch.uint8)
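The two packing steps (each index written as `num_bits` bits, then eight bits per byte) can be mirrored in plain Python to check the byte layout. This sketch, with the hypothetical name `pack_indices`, follows the same most-significant-bit-first order as `packbits_pytorch`, which is also the default `bitorder='big'` of `np.packbits`.

```python
def pack_indices(indices, num_bits):
    """Write each index as `num_bits` bits (MSB first) and pack 8 bits per byte."""
    bits = []
    for idx in indices:
        if not 0 <= idx < 2 ** num_bits:
            raise ValueError(f"index {idx} does not fit in {num_bits} bits")
        bits.extend((idx >> shift) & 1 for shift in range(num_bits - 1, -1, -1))
    # Same requirement as packbits_pytorch: no partial bytes
    if len(bits) % 8:
        raise ValueError("total number of bits must be a multiple of 8")
    packed = bytearray()
    for start in range(0, len(bits), 8):
        byte = 0
        for bit in bits[start:start + 8]:
            byte = (byte << 1) | bit
        packed.append(byte)
    return bytes(packed)


print(list(pack_indices([0, 3, 1, 0], num_bits=2)))  # [52]
# One 384-dimensional vector at 6 bits per index needs 384 * 6 / 8 = 288 bytes
print(len(pack_indices([5] * 384, num_bits=6)))  # 288
```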


In [11]:
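# Function to expand the mask per row if the count does not match
def expand_mask_side(mask_side, counts_side, expected, sorted_D, D, side='left'):
    """
    Adds elements to the out-of-range mask of the rows whose count is below
    `expected`, taking them from the sorted row next to the threshold, so that
    every row ends up with the same number of outliers.

    Note: the starting positions come from the global variables `lower_limit` and
    `upper_limit`, which must be defined before encoding (they are set in the
    6-bit example below).
    """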
    B, N = mask_side.shape
    mask_expanded = mask_side.clone()

    # Rows that need expansion
    rows_to_expand = (counts_side != expected).nonzero(as_tuple=True)[0]
    if len(rows_to_expand) == 0:
        return mask_expanded  # No rows to expand

    diffs = (expected - counts_side[rows_to_expand]).tolist()  # differences per row

    # For each row to expand, obtain additional indices according to the side
    # The maximum number of elements to add is max(diffs)
    max_diff = max(diffs)

    # Create a matrix with relative indices to add: shape (max_diff,)
    offsets = torch.arange(max_diff, device=D.device)

    if side == 'left':
        # Indices to add: lower_limit + 1 + offset, vectorized
        base_idx = lower_limit + 1
        idx_to_add = base_idx + offsets.unsqueeze(0)  # shape (1, max_diff)
        # For each row, only up to diffs[i] elements are desired
        # For safety, clamp indices to < N
        idx_to_add = torch.clamp(idx_to_add, max=N-1)
    else:
        # Indices to add: upper_limit - 1 - offset
        base_idx = upper_limit - 1
        idx_to_add = base_idx - offsets.unsqueeze(0)
        idx_to_add = torch.clamp(idx_to_add, min=0)

    # Repeat idx_to_add for all rows to expand: shape (num_rows, max_diff)
    idx_to_add = idx_to_add.repeat(len(rows_to_expand), 1)

    # For each row, limit the number of elements to add according to diffs
    # Create a boolean mask to avoid adding extra elements
    diffs_tensor = torch.tensor(diffs, device=D.device).unsqueeze(1)  # (num_rows, 1)
    valid_mask = offsets.unsqueeze(0) < diffs_tensor  # (num_rows, max_diff)

    # Extract the sorted values to be added: shape (num_rows, max_diff)
    rows_idx = rows_to_expand.unsqueeze(1).expand(-1, max_diff)
    vals_to_add = sorted_D[rows_idx, idx_to_add]

    # Now add only the valid values by creating a mask for each row in D:
    for i, row in enumerate(rows_to_expand):
        vals = vals_to_add[i][valid_mask[i]]
        # Instead of looping per value, broadcasting can be used to mark the mask:
        # mask_expanded[row] = mask_expanded[row] | ((D[row].unsqueeze(1) == vals).any(dim=1))
        # But for pure PyTorch, we use:
        mask_expanded[row] = mask_expanded[row] | torch.isin(D[row], vals)

    return mask_expanded


In [12]:
def find_concavity_changes(array, slope_threshold=1.0, atol=1e-2):
    """
    Finds the points where the slope is equal to ±45 degrees, indicating concavity changes.

    Parameters:
        array (torch.Tensor): One-dimensional (flattened) tensor.
        slope_threshold (float): Target slope value (default is 1.0 for 45 degrees).
        atol (float): Tolerance to consider a slope as equal to the target.

    Returns:
        list: Indices where the slope is approximately ±slope_threshold.
    """
    if array.dim() != 1:
        raise ValueError("The tensor must be one-dimensional (flattened).")

    downsample_factor = max(1,int(len(array)/(2**11)))

    # Calculate the necessary padding size to make the array divisible
    padding_size = (downsample_factor - len(array) % downsample_factor) % downsample_factor

    # Pad with zeros (on the same device and dtype) so the length is divisible by downsample_factor
    if padding_size > 0:
        padding = torch.zeros(padding_size, dtype=array.dtype, device=array.device)
        array_padded = torch.cat((array, padding))
    else:
        array_padded = array

    # Perform downsampling
    downsampled_array = array_padded.view(-1, downsample_factor).mean(dim=1)

    # Calculate dx based on the maximum value of the array
    #dx = 10**(math.floor(math.log10(torch.max(torch.abs(array)).item())) - 2)
    # Ensure the tensor is on the correct device before calling .item()
    dx = 10**(math.floor(math.log10(torch.max(torch.abs(array)).cpu().item())) - 2)

    # Compute the first derivative using vectorized finite differences
    first_derivative = (downsampled_array[1:] - downsampled_array[:-1]) / dx

    def find_indices(atol):
        condition = (
            (torch.abs(first_derivative - slope_threshold) <= atol) |
            (torch.abs(first_derivative + slope_threshold) <= atol)
        )
        return torch.nonzero(condition).squeeze(1).tolist()

    def ensure_two_elements(atol_initial, atol_max, step_size):
        atol = atol_initial
        while atol <= atol_max:
            slope_indices = find_indices(atol)
            if len(slope_indices) >= 2:
                return slope_indices
            atol += step_size  # Increase tolerance
        return slope_indices  # Return best option

    # Search for indices with the initial tolerance
    slope_indices = ensure_two_elements(atol_initial=1e-6, atol_max=1, step_size=1e-1)

    # If not enough indices are found, apply an alternative strategy
    if len(slope_indices) < 2:
        print("Warning: Fewer than 2 indices found. You may need to adjust your parameters.")
    if not slope_indices:
        raise ValueError("No point with the target slope was found.")

    # Adjust the indices to correspond to the original array (considering downsampling)
    adjusted_indices = [i * downsample_factor for i in slope_indices]


    first_curve = adjusted_indices[0]
    second_curve = adjusted_indices[-1]


    return first_curve, second_curve
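The idea behind `find_concavity_changes` can be sketched without downsampling: take forward differences of a sorted vector and report the first and last positions whose slope is within a tolerance of ±1, widening the tolerance until at least two positions match. This simplified version uses a fixed `dx` instead of the magnitude-based step computed in the notebook; `slope_thresholds` and the toy vector are hypothetical.

```python
def slope_thresholds(sorted_vals, dx=1.0, slope=1.0,
                     atol_start=1e-6, atol_max=1.0, atol_step=0.1):
    """Return the first and last positions where the slope is about ±slope."""
    # Forward finite differences of the sorted values, scaled by dx
    deriv = [(b - a) / dx for a, b in zip(sorted_vals, sorted_vals[1:])]
    hits = []
    atol = atol_start
    # Widen the tolerance until at least two matching positions are found
    while atol <= atol_max:
        hits = [i for i, d in enumerate(deriv)
                if abs(d - slope) <= atol or abs(d + slope) <= atol]
        if len(hits) >= 2:
            break
        atol += atol_step
    if not hits:
        raise ValueError("no position with the target slope was found")
    return hits[0], hits[-1]


# Toy sorted vector: steep tails around a flat middle
vals = [-5.0, -3.0, -2.0, -1.9, -1.8, -1.7, -1.6, -0.6, 1.4, 4.4]
print(slope_thresholds(vals))  # (1, 6)
```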

In [13]:
def estimate_global_concavity_thresholds(
    vectors: torch.Tensor,
    sample_size: int = 5000,
    seed: int = None
) -> Tuple[int, int]:
    """
    Estimates average lower and upper index thresholds based on concavity changes
    from a random sample of vectors.

    Args:
        vectors (Tensor): Tensor of shape (N, D) containing N vectors.
        sample_size (int): Number of vectors to sample.
        seed (int, optional): Random seed for reproducibility.

    Returns:
        Tuple[int, int]: Average lower and upper concavity-change indices, truncated to int.
    """
    N = vectors.size(0)
    sample_size = min(sample_size, N)

    if seed is not None:
        torch.manual_seed(seed)

    indices = torch.randperm(N)[:sample_size]

    lower_limits = []
    upper_limits = []

    for idx in indices:
        sorted_vec = torch.sort(vectors[idx])[0]
        lower, upper = find_concavity_changes(sorted_vec)
        lower_limits.append(lower)
        upper_limits.append(upper)

    avg_lower = float(torch.tensor(lower_limits, dtype=torch.float32).mean())
    avg_upper = float(torch.tensor(upper_limits, dtype=torch.float32).mean())

    return int(avg_lower), int(avg_upper)


In [14]:
def orthoQuant_encode_in_batches(dataset, lower_limit, upper_limit, num_bits, batch_size):
    """
    Processes a set of vectors using orthoQuant_encode_optimized in batches.

    Parameters:
    - dataset: tensor of shape (num_vectors, dim)
    - lower_limit: lower index for segmentation
    - upper_limit: upper index for segmentation
    - num_bits: number of bits per segment index
    - batch_size: batch size to use

    Returns:
    A dictionary with the concatenated fields:
    - non_zero_elements: tensor of shape (total_vectors, K)
    - avg_values: tensor of shape (total_vectors, num_segments + 1)
    - packed_binary_matrix: tensor of shape (total_vectors, compressed_dim)
    - og_shape_bin: tuple with the original binary shape before `packbits`
    """
    from collections import defaultdict

    device = dataset.device
    total_vectors = dataset.shape[0]

    # Initialize lists to accumulate results
    all_outliers = []
    all_avg_values = []
    all_packed_binaries = []

    for i in range(0, total_vectors, batch_size):
        batch = dataset[i:i + batch_size]
        if num_bits == 1:
            compressed = orthoQuant_encode_1bit(batch, lower_limit, upper_limit)
        else:
            compressed = orthoQuant_encode(batch, lower_limit, upper_limit, num_bits)

        all_outliers.append(compressed['outliers'].detach())
        all_avg_values.append(compressed['avg_values'].detach())
        all_packed_binaries.append(compressed['packed_binary_matrix'].detach())

        if i == 0:
            og_shape_bin = compressed['og_shape_bin']  # assumed to be the same for all batches

    result = {
        'outliers': torch.cat(all_outliers, dim=0),
        'avg_values': torch.cat(all_avg_values, dim=0),
        'packed_binary_matrix': torch.cat(all_packed_binaries, dim=0),
        'og_shape_bin': og_shape_bin
    }

    # Replace the per-vector averages by their global mean per column,
    # so a single avg_values vector is shared by the whole dataset
    mean_per_column = result['avg_values'].mean(dim=0)
    result['avg_values'] = mean_per_column  # Shape changes from (total_vectors, 2**num_bits) to (2**num_bits,)

    return result

In [19]:
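import sys
def get_size_in_bytes(obj):
    """Approximate memory footprint in bytes: raw data size for tensors and arrays,
    recursion into lists, tuples and dicts, and sys.getsizeof for anything else."""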
    if isinstance(obj, torch.Tensor):
        return obj.element_size() * obj.numel()
    elif isinstance(obj, np.ndarray):
        return obj.nbytes
    elif isinstance(obj, list) or isinstance(obj, tuple):
        return sum(get_size_in_bytes(item) for item in obj)
    elif isinstance(obj, dict):
        return sum(get_size_in_bytes(k) + get_size_in_bytes(v) for k, v in obj.items())
    else:
        return sys.getsizeof(obj)

### Decoding Functions

In [15]:
def orthoQuant_decode_database(outliers, avg_values, packed_binary_matrix, og_shape_bin, num_bits):
    """
    Reconstructs an approximation of the original tensor D using the
    out-of-range values and the segment averages.

    Args:
        outliers (torch.Tensor): (B, k) containing original out-of-range values.
        avg_values (torch.Tensor): segment averages including a reserved zero value at
            index 0, either per row (B, 2^num_bits) or shared by all rows (2^num_bits,),
            as returned by orthoQuant_encode_in_batches.
        packed_binary_matrix (torch.Tensor): (B, m) encoded as uint8.
        og_shape_bin (tuple): original shape of the binary matrix before packing.
        num_bits (int): Number of bits used to encode indices.

    Returns:
        D_reconstructed (torch.Tensor): Approximation of the original tensor (B, N).
    """

    device = packed_binary_matrix.device
    B = packed_binary_matrix.size(0)
    N = og_shape_bin[1] // num_bits

    # Unpack the encoded indices (reshape also handles a trimmed, non-contiguous slice)
    binary_matrix = unpackbits_pytorch(packed_binary_matrix, original_shape=og_shape_bin)
    index_matrix = binary_matrix.reshape(B, N, num_bits)

    indices = torch.zeros((B, N), dtype=torch.long, device=device)
    for bit in range(num_bits):
        indices = (indices << 1) | index_matrix[:, :, bit].to(torch.long)

    # Map each index to its segment average (index 0 maps to the reserved 0.0)
    avg_values = avg_values.to(device)
    if avg_values.dim() == 1:
        D_reconstructed = avg_values[indices]            # shared averages
    else:
        D_reconstructed = avg_values.gather(1, indices)  # per-row averages

    # Insert original out-of-range values (index 0)
    mask_out_of_range = (indices == 0)

    assert torch.all(mask_out_of_range.sum(dim=1) == outliers.size(1)), \
        "Number of out-of-range elements does not match"

    D_reconstructed[mask_out_of_range] = outliers.flatten().to(D_reconstructed.dtype)

    return D_reconstructed


In [16]:
def orthoQuant_decode_indices(packed_binary_matrix, og_shape_bin, num_bits):
    """
    Reconstructs an approximation of the original tensor D using the
    out-of-range values and the averages per segment.

    Args:
        packed_binary_matrix (torch.Tensor): (B, m) encoded as uint8.
        og_shape_bin (tuple): original shape of the binary matrix before packing.
        num_bits (int): Number of bits used to encode indices.

    Returns:
        D_reconstructed (torch.Tensor): Approximation of the original tensor (B, N).
    """

    device = packed_binary_matrix.device
    B = packed_binary_matrix.size(0)
    N = og_shape_bin[1] // num_bits

    # Unpack the encoded indices
    binary_matrix = unpackbits_pytorch(packed_binary_matrix, original_shape=og_shape_bin)
    index_matrix = binary_matrix.view(B, N, num_bits)

    indices = torch.zeros((B, N), dtype=torch.long, device=device)
    for bit in range(num_bits):
        indices = (indices << 1) | index_matrix[:, :, bit].to(torch.long)

    return indices


In [17]:
def unpackbits_pytorch(packed, num_bits=8, original_shape=None):
    """
    Unpacks a uint8 tensor into a boolean binary representation.
    Inverse operation of packbits_pytorch_puro.
    """
    device = packed.device
    unpacked = ((packed.unsqueeze(-1) >> torch.arange(num_bits - 1, -1, -1, device=device)) & 1).to(torch.int8)
    unpacked = unpacked.view(packed.size(0), -1)
    if original_shape is not None:
        # assert unpacked.numel() == original_shape[0] * original_shape[1]
        unpacked = unpacked[:, :original_shape[1]]
    return unpacked


In [18]:
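def unpack_vector_indices(compressed_data, idx, num_bits):
    """
    Decodes the segment indices of the single row `idx`, so only that vector is
    unpacked (partial decoding). Assumes num_bits * N is a multiple of 8, i.e. the
    packed row has no padding bits. Returns a (N,) long tensor.
    """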
    packed_row = compressed_data['packed_binary_matrix'][idx]
    device = packed_row.device
    grouped_bits = 8
    unpacked_row = ((packed_row.unsqueeze(-1) >> torch.arange(grouped_bits - 1, -1, -1, device=device)) & 1).to(torch.int8)
    unpacked_bits = unpacked_row.view(-1, num_bits).to(torch.long)

    indices = torch.zeros(unpacked_bits.size(0), dtype=torch.long, device=unpacked_row.device)
    for bit in range(num_bits):
        indices = (indices << 1) | unpacked_bits[:, bit]

    return indices

### Example for 6 bits

In [24]:
num_bits = 6
batch_size = 100000
lower_limit, upper_limit = estimate_global_concavity_thresholds(rotated_base, sample_size=5000, seed=24)
print(f'lower_limit: {lower_limit}, upper_limit: {upper_limit}')

rotated_base = rotated_base.to('cuda' if torch.cuda.is_available() else 'cpu')
compressed_rotated_base = orthoQuant_encode_in_batches(rotated_base, lower_limit, upper_limit, num_bits, batch_size)

# save_path = os.path.join(results_dir, f'compressed_rotated_base_{num_bits}b.pt')
# torch.save(compressed_rotated_base, save_path)
# compressed_rotated_base = torch.load(save_path, map_location=torch.device('cpu'))

# Size of the original rotated_base
original_size = get_size_in_bytes(rotated_base)
original_size_gb = original_size / 1e9
original_size_mb = original_size / 1e6
print(f"\nTotal memory size of rotated base: {original_size} bytes ({original_size_mb:.2f} MB, {original_size_gb:.6f} GB)")

# Size of the compressed rotated_base
compressed_size = get_size_in_bytes(compressed_rotated_base)
compressed_size_gb = compressed_size / 1e9
compressed_size_mb = compressed_size / 1e6
print(f"\nTotal memory size of compressed rotated base: {compressed_size} bytes ({compressed_size_mb:.2f} MB, {compressed_size_gb:.6f} GB)")

# Compression ratio
compression_ratio = 100 * (1 - compressed_size / original_size)
print(f"\nCompression ratio: {compression_ratio:.2f}%")

lower_limit: 17, upper_limit: 364

Total memory size of rotated base: 927227904 bytes (927.23 MB, 0.927228 GB)

Total memory size of compressed rotated base: 260783406 bytes (260.78 MB, 0.260783 GB)

Compression ratio: 71.87%
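The reported ratio can be checked by hand. Assuming every row keeps `lower_limit` values below l and `dim - 1 - upper_limit` values above r as float32 outliers (which holds when a row's values are distinct and l < 0 < r), each vector costs 288 packed bytes plus 144 outlier bytes:

```python
DIM = 384              # embedding dimension (from the rotated base shape)
N_BASE = 603_664       # number of base vectors
NUM_BITS = 6
LOWER_LIMIT, UPPER_LIMIT = 17, 364
FLOAT_BYTES = 4        # float32

# Assumption: LOWER_LIMIT outliers below l and DIM - 1 - UPPER_LIMIT above r per row
n_outliers = LOWER_LIMIT + (DIM - 1 - UPPER_LIMIT)           # 36
packed_bytes = DIM * NUM_BITS // 8                           # 288
bytes_per_vector = packed_bytes + n_outliers * FLOAT_BYTES   # 432
original_per_vector = DIM * FLOAT_BYTES                      # 1536

print(bytes_per_vector * N_BASE)  # 260782848
print(f"{100 * (1 - bytes_per_vector / original_per_vector):.3f}%")  # 71.875%
```

Under this assumption the per-vector saving is exactly 71.875%. The extra 558 bytes in the measured total should correspond to the shared `avg_values` vector (64 float32 values, 256 bytes) plus the dictionary keys and `og_shape_bin` counted by `get_size_in_bytes`, which is why the notebook prints 71.87%.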


In [25]:
rotated_queries = rotated_queries.to(rotated_base.device)
compressed_rotated_queries = orthoQuant_encode_in_batches(rotated_queries, lower_limit, upper_limit, num_bits, batch_size)

# save_path = os.path.join(results_dir, f'compressed_rotated_queries_{num_bits}b.pt')
# torch.save(compressed_rotated_queries, save_path)
# compressed_rotated_queries = torch.load(save_path, map_location=torch.device('cpu'))

# Size of the original rotated_queries
original_size = get_size_in_bytes(rotated_queries)
original_size_gb = original_size / 1e9
original_size_mb = original_size / 1e6
print(f"\nTotal memory size of rotated queries: {original_size} bytes ({original_size_mb:.2f} MB, {original_size_gb:.6f} GB)")

# Size of the compressed rotated_queries
compressed_size = get_size_in_bytes(compressed_rotated_queries)
compressed_size_gb = compressed_size / 1e9
compressed_size_mb = compressed_size / 1e6
print(f"\nTotal memory size of compressed rotated queries: {compressed_size} bytes ({compressed_size_mb:.2f} MB, {compressed_size_gb:.6f} GB)")

# Compression ratio
compression_ratio = 100 * (1 - compressed_size / original_size)
print(f"\nCompression ratio: {compression_ratio:.2f}%")



Total memory size of rotated queries: 16896000 bytes (16.90 MB, 0.016896 GB)

Total memory size of compressed rotated queries: 4752558 bytes (4.75 MB, 0.004753 GB)

Compression ratio: 71.87%


## **4. Recall experiments**

### Full dot product functions

In [32]:
import torch

def orthoquant_dot_product(compressed_base, compressed_queries, idx_query, num_bits):
    """
    Computes the cosine similarity (normalized dot product) between the compressed
    rotated query at position idx_query and every vector of the compressed rotated
    base. A single avg_values vector is shared by the entire base.

    Returns:
        similarities: tensor of shape (n_base,) with the normalized dot products.
    """
    # Decode the base indices
    packed_binary_matrix = compressed_base['packed_binary_matrix']
    og_shape_bin = compressed_base['og_shape_bin']
    base_indices = orthoQuant_decode_indices(packed_binary_matrix, og_shape_bin, num_bits)  # (n_base, D)
    base_avg_values = compressed_base['avg_values']  # (num_bins,)
    base_outliers = compressed_base['outliers']      # (n_base, n_outliers)
    n_base, D = base_indices.shape

    # Decode the query
    query_indices = unpack_vector_indices(compressed_queries, idx_query, num_bits)  # (D,)
    query_avg_values = compressed_queries['avg_values']                             # (num_bins,)
    query_outliers = compressed_queries['outliers'][idx_query]                      # (n_outliers,)

    # Expand the query to match the base size
    query_indices_exp = query_indices.unsqueeze(0).expand(n_base, -1)
    query_outliers_exp = query_outliers.unsqueeze(0).expand(n_base, -1)

    # Map indices to values: index i -> avg_values[i] (index 0 marks an outlier slot)
    base_vals = base_avg_values[base_indices]         # (n_base, D)
    query_vals = query_avg_values[query_indices_exp]  # (n_base, D)

    # Fill the outlier slots (index 0) with the stored outliers, row by row
    base_vals[base_indices == 0] = base_outliers.flatten()
    query_vals[query_indices_exp == 0] = query_outliers_exp.flatten()

    # Dot product and normalization
    dot = (base_vals * query_vals).sum(dim=1)
    norm_base = torch.norm(base_vals, dim=1)
    norm_query = torch.norm(query_vals, dim=1)

    similarities = dot / (norm_base * norm_query + 1e-8)

    return similarities
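
The following NumPy sketch makes the decoding step concrete on made-up toy data (not taken from the dataset). It mirrors the lookup-and-fill logic of `orthoquant_dot_product`. Each index `i` becomes `avg_values[i]`, and positions with index 0 are filled with the vector's outliers in column order. This only works because every row stores the same number of outliers. The helper names `decode_rows` and `cosine_to_query` are hypothetical. As a design variation (not what the notebook does), the decoded query is kept as a single `(D,)` vector and used in a matrix-vector product instead of being expanded to `(n_base, D)`.

```python
import numpy as np

def decode_rows(indices, avg_values, outliers):
    """Map index i to avg_values[i]; entries with index 0 take their row's outliers in order."""
    vals = avg_values[indices].astype(np.float64)  # (n, D) lookup, returns a copy
    # Boolean-mask assignment walks the mask in row-major order, so each row
    # must contain exactly outliers.shape[1] zero indices.
    vals[indices == 0] = outliers.reshape(-1)
    return vals

def cosine_to_query(base_vals, query_vals, eps=1e-8):
    """Cosine similarity of every row of base_vals (n, D) with one query (D,)."""
    dot = base_vals @ query_vals  # matrix-vector product, no (n, D) copy of the query
    return dot / (np.linalg.norm(base_vals, axis=1) * np.linalg.norm(query_vals) + eps)

# Toy data (illustrative values only): 3 bins, D = 4, one outlier per vector.
avg_values = np.array([0.0, -0.5, 0.5])
base_idx = np.array([[1, 0, 2, 2],
                     [0, 1, 1, 2]])
base_out = np.array([[3.0], [-2.0]])
query_idx = np.array([[2, 2, 0, 1]])
query_out = np.array([[1.5]])

base_vals = decode_rows(base_idx, avg_values, base_out)
query_vals = decode_rows(query_idx, avg_values, query_out)[0]
sims = cosine_to_query(base_vals, query_vals)

print(base_vals.tolist())  # [[-0.5, 3.0, 0.5, 0.5], [-2.0, -0.5, -0.5, 0.5]]
print(query_vals.tolist())  # [0.5, 0.5, 1.5, -0.5]
```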


In [34]:
import pandas as pd

def recall_tables_by_k(
    idx_query,
    base,
    queries,
    compressed_base,
    compressed_queries,
    num_bits,
    ks=[1, 5, 10, 15, 20, 25, 30],
    Kprimes=[1, 2, 4, 8, 16, 32, 64, 128, 256, 1024],
    eps=1e-8,
    batch_size=100000
):
    """
    Computes, for a given query, how many of the true top-k neighbors (cosine
    similarity on the original base) appear among the top-K' candidates ranked
    by the compressed rotated representation.

    Returns:
        recall_tables: dict {k: DataFrame with one row: 'idx' plus one count column per K'}
    """

    # === True similarity ===
    q = queries[idx_query].unsqueeze(0)
    dot = torch.matmul(q, base.T)
    norm_q = torch.norm(q, dim=1, keepdim=True)
    norm_base = torch.norm(base, dim=1, keepdim=True).T
    sim = dot / (norm_q * norm_base + eps)
    idx_real = torch.argsort(sim, dim=1, descending=True).squeeze(0)

    # === Rotated similarity computed in batches ===
    sim_rot_batches = []
    num_vectors = compressed_base['outliers'].shape[0]

    for i in range(0, num_vectors, batch_size):
        batch_compressed_base = {
            'outliers': compressed_base['outliers'][i:i + batch_size, :].clone(),
            'avg_values': compressed_base['avg_values'].clone(),
            'packed_binary_matrix': compressed_base['packed_binary_matrix'][i:i + batch_size, :].clone(),
            'og_shape_bin': compressed_base['og_shape_bin'],
        }

        sim_batch = orthoquant_dot_product(
            batch_compressed_base,
            compressed_queries,
            idx_query,
            num_bits
        )

        sim_rot_batches.append(sim_batch)

    sim_rot = torch.cat(sim_rot_batches, dim=0)
    idx_rot = torch.argsort(sim_rot, descending=True)

    # === Compute recall table for each k ===
    recall_tables = {}

    for k in ks:
        top_k_real = set(idx_real[:k].tolist())
        row = []

        for Kp in Kprimes:
            top_Kp_rot = set(idx_rot[:Kp].tolist())
            recall = len(top_k_real.intersection(top_Kp_rot))
            row.append(recall)

        df = pd.DataFrame([row], columns=[f"@'{Kp}" for Kp in Kprimes])
        df.insert(0, 'idx', idx_query)
        recall_tables[k] = df

    return recall_tables
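
Each entry stored for a (k, K') pair is the size of the intersection between the exact top-k and the approximate top-K'. When K' is larger than k, it counts how many true neighbors survive in a candidate list of size K'. The NumPy sketch below computes that count on toy similarity scores. The helper `recall_counts` is hypothetical, and it uses a stable sort so that ties are broken deterministically. Note that `torch.argsort` does not guarantee this by default.

```python
import numpy as np

def recall_counts(true_sims, approx_sims, ks, k_primes):
    """For each k, count how many true top-k ids appear in the approximate top-K' ids."""
    idx_real = np.argsort(-true_sims, kind="stable")   # ranking by exact similarity
    idx_rot = np.argsort(-approx_sims, kind="stable")  # ranking by approximate similarity
    return {
        k: [len(set(idx_real[:k].tolist()) & set(idx_rot[:kp].tolist())) for kp in k_primes]
        for k in ks
    }

# Toy similarities for one query against five base vectors (illustrative only).
true_sims = np.array([0.9, 0.8, 0.7, 0.1, 0.0])
approx_sims = np.array([0.85, 0.6, 0.75, 0.8, 0.0])

print(recall_counts(true_sims, approx_sims, ks=[1, 2], k_primes=[1, 2, 4]))
# {1: [1, 1, 1], 2: [1, 1, 2]}
```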


In [28]:
def recall_k_for_all_queries(
    base,
    queries,
    compressed_base,
    compressed_queries,
    num_bits,
    ks=[1, 5, 10, 15, 20, 25, 30],
    Kprimes=[1, 2, 4, 8, 16, 32, 64, 128, 256, 1024]
):
    """
    Applies recall_tables_by_k to all queries.

    Returns:
        dict_df_recalls: Dictionary with a DataFrame for each k, where columns are K' and rows are queries.
    """
    from collections import defaultdict
    from tqdm import tqdm

    # Dictionary of lists for each k
    all_recalls_by_k = defaultdict(list)

    for idx_query in tqdm(range(queries.shape[0]), desc="Calculating min_Ks and recall"):
        df_recalls = recall_tables_by_k(
            idx_query,
            base,
            queries,
            compressed_base,
            compressed_queries,
            num_bits,
            ks=ks,
            Kprimes=Kprimes
        )

        # Save each DataFrame in its corresponding list
        for k in ks:
            all_recalls_by_k[k].append(df_recalls[k])

    # Concatenate results by k
    dict_df_recalls = {
        k: pd.concat(all_recalls_by_k[k], ignore_index=True)
        for k in ks
    }

    return dict_df_recalls
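
Both functions rank the exact neighbors using the original `base` and `queries`, while the approximate ranking uses the rotated, compressed data. The comparison is meaningful because an orthogonal rotation leaves dot products and norms unchanged, and therefore cosine similarities as well. Any disagreement between the two rankings then comes from the compression. The NumPy check below assumes that the Section 2 rotation is an orthogonal matrix. It uses a random orthogonal matrix built from a QR decomposition, not the notebook's actual rotation.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 8

# Random orthogonal matrix: the Q factor of a Gaussian matrix's QR decomposition.
Q, _ = np.linalg.qr(rng.standard_normal((D, D)))

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

x = rng.standard_normal(D)
y = rng.standard_normal(D)

# Rotating row vectors (x @ Q) keeps the cosine similarity unchanged.
print(np.isclose(cosine(x, y), cosine(x @ Q, y @ Q)))  # True
```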


A small example on a subset (the first 1,000 base vectors and the first 100 queries):

In [36]:
num_bits=6

# save_path = os.path.join(results_dir, f'compressed_rotated_base_{num_bits}b.pt')
# compressed_rotated_base = torch.load(save_path, map_location=torch.device('cpu'))
# save_path = os.path.join(results_dir, f'compressed_rotated_queries_{num_bits}b.pt')
# compressed_rotated_queries = torch.load(save_path, map_location=torch.device('cpu'))


# Keep the first rows of every per-vector (2-D) tensor; copy shared entries
# (avg_values, og_shape_bin) unchanged so that avg_values is never truncated
subset_compressed_base = {}
for key, value in compressed_rotated_base.items():
    if isinstance(value, torch.Tensor) and value.ndim > 1:
        subset_compressed_base[key] = value[:1000, :]
    else:
        subset_compressed_base[key] = value

subset_compressed_queries = {}
for key, value in compressed_rotated_queries.items():
    if isinstance(value, torch.Tensor) and value.ndim > 1:
        subset_compressed_queries[key] = value[:100, :]
    else:
        subset_compressed_queries[key] = value

# Evaluate a single query
recalls = recall_tables_by_k(
    idx_query=0,
    base=base[:1000, :],
    queries=queries[:100, :],
    compressed_base=subset_compressed_base,
    compressed_queries=subset_compressed_queries,
    num_bits=num_bits
)


display(recalls)

{1:    idx  @'1  @'2  @'4  @'8  @'16  @'32  @'64  @'128  @'256  @'1024
 0    0    1    1    1    1     1     1     1      1      1       1,
 5:    idx  @'1  @'2  @'4  @'8  @'16  @'32  @'64  @'128  @'256  @'1024
 0    0    1    2    4    5     5     5     5      5      5       5,
 10:    idx  @'1  @'2  @'4  @'8  @'16  @'32  @'64  @'128  @'256  @'1024
 0    0    1    2    4    8    10    10    10     10     10      10,
 15:    idx  @'1  @'2  @'4  @'8  @'16  @'32  @'64  @'128  @'256  @'1024
 0    0    1    2    4    8    15    15    15     15     15      15,
 20:    idx  @'1  @'2  @'4  @'8  @'16  @'32  @'64  @'128  @'256  @'1024
 0    0    1    2    4    8    16    20    20     20     20      20,
 25:    idx  @'1  @'2  @'4  @'8  @'16  @'32  @'64  @'128  @'256  @'1024
 0    0    1    2    4    8    16    25    25     25     25      25,
 30:    idx  @'1  @'2  @'4  @'8  @'16  @'32  @'64  @'128  @'256  @'1024
 0    0    1    2    4    8    16    30    30     30     30      30}

In [37]:
dict_df_recalls = recall_k_for_all_queries(
    base[:1000, :],
    queries[:100, :],
    subset_compressed_base,
    subset_compressed_queries,
    num_bits
)


display(dict_df_recalls.keys())

Calculating min_Ks and recall: 100%|██████████| 100/100 [00:01<00:00, 92.35it/s]


dict_keys([1, 5, 10, 15, 20, 25, 30])

### Subsampled dot product functions

In [38]:
import torch

def get_fixed_indices(D, k, seed=24):
    """
    Returns a fixed subset of k indices out of D dimensions, drawn with a fixed seed.
    A local generator is used, so the global RNG state is not modified.
    """
    generator = torch.Generator().manual_seed(seed)
    return torch.randperm(D, generator=generator)[:k]

def orthoquant_dot_product_subsampled(compressed_base, compressed_queries, idx_query, num_bits, subsample_factor):
    """
    Same as orthoquant_dot_product, but the cosine similarity is computed only over a
    fixed random subset of D // subsample_factor dimensions, shared by the base and
    the query. A single avg_values vector is used for the whole base.

    Returns:
        similarities: tensor of shape (n_base,) with the normalized dot products.
    """
    # Decode the base indices
    packed_binary_matrix = compressed_base['packed_binary_matrix']
    og_shape_bin = compressed_base['og_shape_bin']
    base_indices = orthoQuant_decode_indices(packed_binary_matrix, og_shape_bin, num_bits)  # (n_base, D)
    base_avg_values = compressed_base['avg_values']  # (num_bins,)
    base_outliers = compressed_base['outliers']      # (n_base, n_outliers)
    n_base, D = base_indices.shape

    # Decode the query
    query_indices = unpack_vector_indices(compressed_queries, idx_query, num_bits)  # (D,)
    query_avg_values = compressed_queries['avg_values']                             # (num_bins,)
    query_outliers = compressed_queries['outliers'][idx_query]                      # (n_outliers,)

    # Expand the query to match the base size
    query_indices_exp = query_indices.unsqueeze(0).expand(n_base, -1)
    query_outliers_exp = query_outliers.unsqueeze(0).expand(n_base, -1)

    # Map indices to values: index i -> avg_values[i] (index 0 marks an outlier slot)
    base_vals = base_avg_values[base_indices]         # (n_base, D)
    query_vals = query_avg_values[query_indices_exp]  # (n_base, D)

    # Fill the outlier slots (index 0) with the stored outliers, row by row
    base_vals[base_indices == 0] = base_outliers.flatten()
    query_vals[query_indices_exp == 0] = query_outliers_exp.flatten()

    # Keep a fixed subset of D // subsample_factor dimensions
    k = D // subsample_factor
    selected_dims = get_fixed_indices(D, k).to(base_vals.device)

    base_vals = base_vals[:, selected_dims]
    query_vals = query_vals[:, selected_dims]

    # Dot product and normalization
    dot = (base_vals * query_vals).sum(dim=1)
    norm_base = torch.norm(base_vals, dim=1)
    norm_query = torch.norm(query_vals, dim=1)

    similarities = dot / (norm_base * norm_query + 1e-8)

    return similarities
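
The subsampled variant scores each pair on a fixed random subset of `D // subsample_factor` coordinates. The same subset is used for the base and the query, so the restricted vectors stay comparable. The NumPy sketch below shows that scoring step. It uses a local `numpy.random.Generator`, so drawing the subset does not reset any global random state. NumPy's permutation will not reproduce the indices drawn by `torch.randperm`, and the helper names are hypothetical. With `subsample_factor=1`, the subset is a permutation of all coordinates, so the score equals the full cosine similarity. This makes a convenient sanity check.

```python
import numpy as np

def fixed_dims(D, k, seed=24):
    """Fixed random subset of k out of D coordinates (local RNG, no global side effects)."""
    return np.random.default_rng(seed).permutation(D)[:k]

def subsampled_cosine(base_vals, query_vals, subsample_factor, seed=24, eps=1e-8):
    """Cosine similarity of each base row (n, D) with the query (D,) on D // subsample_factor coordinates."""
    D = base_vals.shape[1]
    dims = fixed_dims(D, D // subsample_factor, seed)
    b, q = base_vals[:, dims], query_vals[dims]  # same coordinates for base and query
    return (b @ q) / (np.linalg.norm(b, axis=1) * np.linalg.norm(q) + eps)

rng = np.random.default_rng(1)
base_vals = rng.standard_normal((5, 16))
query_vals = rng.standard_normal(16)

half = subsampled_cosine(base_vals, query_vals, subsample_factor=2)  # 8 of 16 coordinates
full = subsampled_cosine(base_vals, query_vals, subsample_factor=1)  # all coordinates
```

Halving the coordinates halves the work per dot product. The results in Section 5 show the recall cost of doing so.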


In [39]:
import pandas as pd

def recall_tables_by_k_subsampled(
    idx_query,
    base,
    queries,
    compressed_base,
    compressed_queries,
    num_bits,
    subsample_factor,
    ks=[1, 5, 10, 15, 20, 25, 30],
    Kprimes=[1, 2, 4, 8, 16, 32, 64, 128, 256, 1024],
    eps=1e-8,
    batch_size=100000
):
    """
    Computes, for a given query, how many of the true top-k neighbors (cosine
    similarity on the original base) appear among the top-K' candidates ranked
    by the compressed rotated representation restricted to D // subsample_factor
    dimensions.

    Returns:
        recall_tables: dict {k: DataFrame with one row: 'idx' plus one count column per K'}
    """

    # === True similarity ===
    q = queries[idx_query].unsqueeze(0)
    dot = torch.matmul(q, base.T)
    norm_q = torch.norm(q, dim=1, keepdim=True)
    norm_base = torch.norm(base, dim=1, keepdim=True).T
    sim = dot / (norm_q * norm_base + eps)
    idx_real = torch.argsort(sim, dim=1, descending=True).squeeze(0)

    # === Rotated similarity in batches ===
    sim_rot_batches = []
    num_vectors = compressed_base['outliers'].shape[0]

    for i in range(0, num_vectors, batch_size):
        batch_compressed_base = {
            'outliers': compressed_base['outliers'][i:i + batch_size, :].clone(),
            'avg_values': compressed_base['avg_values'].clone(),
            'packed_binary_matrix': compressed_base['packed_binary_matrix'][i:i + batch_size, :].clone(),
            'og_shape_bin': compressed_base['og_shape_bin'],
        }

        sim_batch = orthoquant_dot_product_subsampled(
            batch_compressed_base,
            compressed_queries,
            idx_query,
            num_bits,
            subsample_factor
        )

        sim_rot_batches.append(sim_batch)

    sim_rot = torch.cat(sim_rot_batches, dim=0)
    idx_rot = torch.argsort(sim_rot, descending=True)

    # === Compute recall table for each k ===
    recall_tables = {}

    for k in ks:
        top_k_real = set(idx_real[:k].tolist())
        row = []

        for Kp in Kprimes:
            top_Kp_rot = set(idx_rot[:Kp].tolist())
            recall = len(top_k_real.intersection(top_Kp_rot))
            row.append(recall)

        df = pd.DataFrame([row], columns=[f"@'{Kp}" for Kp in Kprimes])
        df.insert(0, 'idx', idx_query)
        recall_tables[k] = df

    return recall_tables


In [40]:
def recall_tables_by_k_subsampled_all_queries(
    base,
    queries,
    compressed_base,
    compressed_queries,
    num_bits,
    subsample_factor,
    ks=[1, 5, 10, 15, 20, 25, 30],
    Kprimes=[1, 2, 4, 8, 16, 32, 64, 128, 256, 1024]
):
    """
    Applies recall_tables_by_k_subsampled to all queries.

    Returns:
        dict_df_recalls: Dictionary with a DataFrame for each k, where columns are K' and rows are queries.
    """
    from collections import defaultdict
    from tqdm import tqdm

    # Dictionary of lists for each k
    all_recalls_by_k = defaultdict(list)

    for idx_query in tqdm(range(queries.shape[0]), desc="Calculating recalls"):
        df_recalls = recall_tables_by_k_subsampled(
            idx_query,
            base,
            queries,
            compressed_base,
            compressed_queries,
            num_bits,
            subsample_factor,
            ks=ks,
            Kprimes=Kprimes
        )

        # Store each DataFrame in its corresponding list
        for k in ks:
            all_recalls_by_k[k].append(df_recalls[k])

    # Concatenate results per k
    dict_df_recalls = {
        k: pd.concat(all_recalls_by_k[k], ignore_index=True)
        for k in ks
    }

    return dict_df_recalls


Example on the same subset, using half of the dimensions (`subsample_factor=2`):

In [41]:
dict_df_recalls_subsampled = recall_tables_by_k_subsampled_all_queries(
    base[:1000, :],
    queries[:100, :],
    subset_compressed_base,
    subset_compressed_queries,
    num_bits,
    subsample_factor=2
)


display(dict_df_recalls_subsampled.keys())

Calculating recalls: 100%|██████████| 100/100 [00:01<00:00, 93.21it/s]


dict_keys([1, 5, 10, 15, 20, 25, 30])

### Full dot product experiment

In [None]:
import torch
import os

# Set the seed for reproducibility
torch.manual_seed(42)

# queries is assumed to have shape (N, D)
num_queries = queries.shape[0]

# Randomly select 10,000 unique query indices
sample_indices = torch.randperm(num_queries)[:10000]

# Select the corresponding queries
sampled_queries = queries[sample_indices]

# Save the indices (shape (10000,)) so the same sample can be reused later
save_path = os.path.join(results_dir, 'sampled_query_indices_10k.pt')
torch.save(sample_indices, save_path)

# To reload later:
# save_path = os.path.join(results_dir, 'sampled_query_indices_10k.pt')
# sample_indices = torch.load(save_path)
# sampled_queries = queries[sample_indices]

In [42]:
num_bits = 6

save_path = os.path.join(results_dir, f'sampled_query_indices_10k.pt')
sample_indices = torch.load(save_path)
sampled_queries = queries[sample_indices]

# save_path = os.path.join(results_dir, f'compressed_rotated_base_{num_bits}b.pt')
# compressed_rotated_base = torch.load(save_path, map_location=torch.device('cpu'))
# save_path = os.path.join(results_dir, f'compressed_rotated_queries_{num_bits}b.pt')
# compressed_rotated_queries = torch.load(save_path, map_location=torch.device('cpu'))

subset_compressed_queries = {}
for key, value in compressed_rotated_queries.items():
    if isinstance(value, torch.Tensor) and value.ndim > 1:
        subset_compressed_queries[key] = value[sample_indices, :]  # sampled rows, all columns
    else:
        subset_compressed_queries[key] = value  # copy other entries unchanged


# Size of the original rotated_base
original_size = get_size_in_bytes(rotated_base)
original_size_gb = original_size / 1e9
original_size_mb = original_size / 1e6
print(f"\nTotal memory size of rotated base: {original_size} bytes ({original_size_mb:.2f} MB, {original_size_gb:.6f} GB)")

# Size of the compressed rotated_base
compressed_size = get_size_in_bytes(compressed_rotated_base)
compressed_size_gb = compressed_size / 1e9
compressed_size_mb = compressed_size / 1e6
print(f"\nTotal memory size of compressed rotated base: {compressed_size} bytes ({compressed_size_mb:.2f} MB, {compressed_size_gb:.6f} GB)")

# Compression ratio
compression_ratio = 100 * (1 - compressed_size / original_size)
print(f"\nCompression ratio: {compression_ratio:.2f}%")



Total memory size of rotated base: 927227904 bytes (927.23 MB, 0.927228 GB)

Total memory size of compressed rotated base: 260783406 bytes (260.78 MB, 0.260783 GB)

Compression ratio: 71.87%


In [None]:
def move_compressed_data_to_device(compressed_data, device):
    return {k: v.to(device=device) if torch.is_tensor(v) else v
            for k, v in compressed_data.items()}

device = 'cuda'
compressed_rotated_base = move_compressed_data_to_device(compressed_rotated_base, device)
subset_compressed_queries = move_compressed_data_to_device(subset_compressed_queries, device)

dict_df_recalls = recall_k_for_all_queries(
    base.to(device=device),
    sampled_queries.to(device=device),
    compressed_rotated_base,
    subset_compressed_queries,
    num_bits
)

display(dict_df_recalls.keys())

In [55]:
save_path = os.path.join(results_dir, f'dict_df_recalls_{num_bits}b.pt')
#torch.save(dict_df_recalls, save_path)
dict_df_recalls = torch.load(save_path, map_location=torch.device('cpu'), weights_only=False)

### Subsampled dot product experiment

In [None]:
num_bits = 6

dict_df_recalls = recall_tables_by_k_subsampled_all_queries(
    base.to(device=device),
    sampled_queries.to(device=device),
    compressed_rotated_base,
    subset_compressed_queries,
    num_bits,
    subsample_factor=2
)

display(dict_df_recalls.keys())

In [50]:
save_path = os.path.join(results_dir, f'sampled_dict_df_recalls_{num_bits}b.pt')
#torch.save(dict_df_recalls, save_path)
dict_df_recalls = torch.load(save_path, map_location=torch.device('cpu'), weights_only=False)

## **5. Results**

### Functions

In [51]:
def load_and_clean_dict_df_recalls(results_dir, filename):
    """
    Loads a .pt file containing a dictionary of DataFrames,
    removes the 'idx' column if it exists, and cleans quotes from column names.

    Args:
        results_dir (str): Path to the directory where the file is located.
        filename (str): Name of the .pt file containing the dictionary.

    Returns:
        dict[int, pd.DataFrame]: The cleaned dictionary of DataFrames, keyed by k.
    """
    import os
    import torch
    import pandas as pd  # Ensure pandas is available

    save_path = os.path.join(results_dir, filename)
    dict_df_recalls = torch.load(save_path, weights_only=False, map_location=torch.device('cpu'))

    for k in dict_df_recalls:
        df = dict_df_recalls[k]

        # Remove 'idx' column if it exists
        if 'idx' in df.columns:
            df = df.drop(columns=['idx'])

        # Clean quotation marks from column names
        clean_cols = {
            col: str(col).replace("'", "").replace('"', "") for col in df.columns
        }

        dict_df_recalls[k] = df.rename(columns=clean_cols)

    return dict_df_recalls


In [62]:
import pandas as pd

def recall_at_k_table(results_dir, filename, num_bits):
    """
    Builds the recall table from a saved dictionary of per-query recall counts.

    For each k, the per-query counts (true top-k intersected with approximate top-K')
    are averaged over queries and divided by k, giving the mean recall for every K'.

    Args:
        results_dir (str): Directory where the results are stored.
        filename (str): Name of the .pt file containing the recall data.
        num_bits (int): Bit resolution of the index matrix (kept for reference; not used here).

    Returns:
        pd.DataFrame: Recall table with one row per k and one column per K' (@K').
    """
    data = {}
    dict_df = load_and_clean_dict_df_recalls(results_dir, filename)

    for k in dict_df.keys():
        df = dict_df[k]
        data[k] = df.mean() / k  # average count over queries, normalized by k

    result_df = pd.DataFrame(data).T  # rows = k, columns = @K'
    return result_df
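
`recall_at_k_table` turns the stored intersection counts into recall by averaging over queries and dividing by k. This has a consequence for reading the tables below. When K' < k, the count is at most K', so the recall in those columns is capped at K'/k (for example, 0.2 for k = 5 and K' = 1). The following pandas sketch works on made-up counts for two queries.

```python
import pandas as pd

# Made-up intersection counts for k = 5: one row per query, one column per K'.
k = 5
counts = pd.DataFrame({"@1": [1, 1], "@4": [4, 3], "@8": [5, 5]})

recall = counts.mean() / k  # mean over queries, normalized by k
# The first column cannot exceed min(K', k) / k = 1 / 5.
print(recall.round(3).to_dict())  # {'@1': 0.2, '@4': 0.7, '@8': 1.0}
```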

### Results

Full dot product

In [63]:
num_bits = 6
filename = f'dict_df_recalls_{num_bits}b.pt'
results_df = recall_at_k_table(results_dir, filename, num_bits)

display(results_df)

| k | @1 | @2 | @4 | @8 | @16 | @32 | @64 | @128 | @256 | @1024 |
|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0.9201 | 0.9734 | 0.99 | 0.996 | 0.9988 | 0.9998 | 0.9999 | 0.9999 | 1.0 | 1.0 |
| 5 | 0.19928 | 0.39726 | 0.7773 | 0.9867 | 0.99716 | 0.9995 | 0.99986 | 0.9999 | 1.0 | 1.0 |
| 10 | 0.09993 | 0.19974 | 0.39886 | 0.78602 | 0.99288 | 0.9988 | 0.99967 | 0.99972 | 0.99991 | 1.0 |
| 15 | 0.066633 | 0.133253 | 0.26642 | 0.531833 | 0.968467 | 0.997707 | 0.999313 | 0.999607 | 0.999873 | 1.0 |
| 20 | 0.049985 | 0.099965 | 0.199895 | 0.399605 | 0.791845 | 0.996185 | 0.999 | 0.99947 | 0.99982 | 0.999995 |
| 25 | 0.039988 | 0.079976 | 0.159952 | 0.3198 | 0.637884 | 0.992008 | 0.998552 | 0.999336 | 0.999764 | 0.999984 |
| 30 | 0.033327 | 0.066657 | 0.133313 | 0.266587 | 0.53241 | 0.974393 | 0.99812 | 0.999217 | 0.9997 | 0.999967 |


Subsampled dot product (`subsample_factor=2`)

In [64]:
num_bits = 6
filename = f'sampled_dict_df_recalls_{num_bits}b.pt'
results_df = recall_at_k_table(results_dir, filename, num_bits)

display(results_df)

| k | @1 | @2 | @4 | @8 | @16 | @32 | @64 | @128 | @256 | @1024 |
|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0.6625 | 0.7926 | 0.8822 | 0.9334 | 0.9689 | 0.9855 | 0.9939 | 0.9975 | 0.9991 | 1.0 |
| 5 | 0.17786 | 0.33064 | 0.56386 | 0.76376 | 0.87398 | 0.93562 | 0.9689 | 0.9857 | 0.9945 | 0.9994 |
| 10 | 0.09324 | 0.17974 | 0.33271 | 0.56157 | 0.75816 | 0.8705 | 0.93508 | 0.96985 | 0.98765 | 0.99853 |
| 15 | 0.06356 | 0.123727 | 0.234273 | 0.419033 | 0.653887 | 0.809927 | 0.9019 | 0.95302 | 0.97994 | 0.99746 |
| 20 | 0.0482 | 0.09455 | 0.18092 | 0.331665 | 0.554565 | 0.75219 | 0.86862 | 0.93537 | 0.971345 | 0.996075 |
| 25 | 0.038832 | 0.076472 | 0.14754 | 0.2745 | 0.475536 | 0.697508 | 0.83668 | 0.917772 | 0.962944 | 0.994448 |
| 30 | 0.032527 | 0.064203 | 0.124493 | 0.23418 | 0.414923 | 0.644817 | 0.80512 | 0.90045 | 0.954013 | 0.99292 |
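
One way to read the two tables is to ask for the smallest candidate-list size K' that reaches a target recall. The sketch below does this for k = 10, using the rows copied from the tables above. `min_k_prime` is a hypothetical helper, and the 0.99 and 0.999 targets are arbitrary choices for illustration.

```python
# Recall rows for k = 10, copied from the two tables above.
K_PRIMES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 1024]
RECALL_K10 = {
    "full": [0.09993, 0.19974, 0.39886, 0.78602, 0.99288, 0.9988, 0.99967, 0.99972, 0.99991, 1.0],
    "subsampled": [0.09324, 0.17974, 0.33271, 0.56157, 0.75816, 0.8705, 0.93508, 0.96985, 0.98765, 0.99853],
}

def min_k_prime(recalls, k_primes, target):
    """Return the smallest K' whose recall reaches target, or None if none does."""
    for kp, r in zip(k_primes, recalls):
        if r >= target:
            return kp
    return None

for name, row in RECALL_K10.items():
    print(name, min_k_prime(row, K_PRIMES, 0.99), min_k_prime(row, K_PRIMES, 0.999))
# full 16 64
# subsampled 1024 None
```

With the full dot product, 16 candidates already hold more than 99% of the true top-10 on average. Halving the dimensions pushes that requirement to 1,024 candidates.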
