In [4]:
!pip install graphviz

Collecting graphviz
  Downloading graphviz-0.20.3-py3-none-any.whl (47 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.1/47.1 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: graphviz
Successfully installed graphviz-0.20.3
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [5]:
from graphviz import Digraph
import math
import heapq
import numpy as np
import os

class Node:
    def __init__(self, points=None, axis=-1, split_value=None):
        self.points = points if points is not None else None  # Points array for leaf nodes
        self.axis = axis  # Splitting axis
        self.split_value = split_value  # Value used for splitting
        self.next = [None, None]  # [left, right] children

class KDTree:
    def __init__(self, points=None, dim=None, leaf_size=16, discriminator="cyclic"):
        """
        Initialize KD-tree with optional points

        Args:
            points: Initial points to build tree (numpy array or list of Points)
            dim: Dimension of the space
            leaf_size: Maximum points in leaf nodes
            discriminator: Strategy for choosing split axis ("cyclic" or "variance")
        """
        self.dim = dim
        self.leaf_size = leaf_size
        self.discriminator = discriminator
        self.root = None

        if points is not None:
            # Convert points to numpy array if needed
            if not isinstance(points[0], np.ndarray):
                points = np.array([p.coordinates for p in points])
            if not self.dim:
              self.dim = points.shape[1]
            self.build_tree(points)

    def build_tree(self, points, depth=0):
        """Build KD-tree from points array"""
        def build(points, depth=0):
            if len(points) == 0:
                return None

            if len(points) <= self.leaf_size:
                return Node(points=points, axis=depth % self.dim)

            # Choose splitting axis
            if self.discriminator == "cyclic":
                axis = depth % self.dim
            else:  # variance-based
                axis = np.argmax(np.var(points, axis=0))

            # Sort points and find median
            idx = np.argsort(points[:, axis])
            points = points[idx]
            median_idx = len(points) // 2

            # Create node
            node = Node(
                points=None,
                axis=axis,
                split_value=points[median_idx, axis]
            )

            # Recursively build subtrees
            node.next[0] = build(points[:median_idx], depth + 1)
            node.next[1] = build(points[median_idx:], depth + 1)

            return node

        self.root = build(points)

    def insert(self, point):
        """Insert a new point into the tree"""
        point = np.asarray(point).reshape(1, -1)

        def insert_recursive(node, point, depth=0):
            if node is None:
                return Node(points=point, axis=depth % self.dim)

            if node.points is not None:
                # Leaf node
                if len(node.points) < self.leaf_size:
                    node.points = np.vstack([node.points, point])
                    return node

                # Split leaf node
                all_points = np.vstack([node.points, point])
                return self.build_tree(all_points, depth)

            # Internal node
            if point[0, node.axis] < node.split_value:
                node.next[0] = insert_recursive(node.next[0], point, depth + 1)
            else:
                node.next[1] = insert_recursive(node.next[1], point, depth + 1)
            return node

        self.root = insert_recursive(self.root, point)

    def nn_search(self, query):
        """Find nearest neighbor and its distance to query point"""
        query = np.asarray(query)

        def search_recursive(node, best_dist=float('inf'), best_point=None):
            if node is None:
                return best_point, best_dist

            if node.points is not None:
                # Leaf node: check all points
                dists = np.sum((node.points - query) ** 2, axis=1)
                min_idx = np.argmin(dists)
                if dists[min_idx] < best_dist:
                    return node.points[min_idx], dists[min_idx]
                return best_point, best_dist

            # Internal node
            diff = query[node.axis] - node.split_value
            near, far = node.next if diff >= 0 else node.next[::-1]

            # Search near subtree
            best_point, best_dist = search_recursive(near, best_dist, best_point)

            # Check if we need to search far subtree
            if diff ** 2 < best_dist:
                best_point, best_dist = search_recursive(far, best_dist, best_point)

            return best_point, best_dist

        point, dist = search_recursive(self.root)
        return point, math.sqrt(dist)

    def k_nearest_neighbors(self, query_point, k):
        """Find k nearest neighbors to query point"""
        query_point = np.asarray(query_point)

        def search(node, heap):
            if node is None:
                return

            if node.points is not None:
                # Process all points in leaf
                dists = np.sum((node.points - query_point) ** 2, axis=1)
                for dist, point in zip(dists, node.points):
                    if len(heap) < k:
                        heapq.heappush(heap, (-dist, tuple(point)))
                    elif -dist > heap[0][0]:
                        heapq.heappushpop(heap, (-dist, tuple(point)))
                return

            # Internal node
            if query_point[node.axis] < node.split_value:
                near, far = node.next
            else:
                near, far = node.next[::-1]

            search(near, heap)

            if len(heap) < k or (query_point[node.axis] - node.split_value) ** 2 < -heap[0][0]:
                search(far, heap)

        heap = []
        search(self.root, heap)
        return [np.array(point) for _, point in sorted(heap, key=lambda x: -x[0])]

    def radius_search(self, query, radius):
        """Find all points within radius of query point"""
        query = np.asarray(query)
        results = []

        def in_radius(points):
            return np.sum((points - query) ** 2, axis=1) <= radius ** 2

        def search_recursive(node):
            if node is None:
                return

            if node.points is not None:
                # Leaf node: check all points
                mask = in_radius(node.points)
                results.extend(node.points[mask])
                return

            # Internal node
            axis_dist = query[node.axis] - node.split_value

            if axis_dist < radius:
                search_recursive(node.next[0])
            if axis_dist > -radius:
                search_recursive(node.next[1])

        search_recursive(self.root)
        return results

    def visualize_tree(self, output_filename="kdtree"):
        """Visualize the KD-tree structure"""
        dot = Digraph(format='png')
        dot.attr('node', shape='circle')

        def add_nodes(node, parent=None):
            if node is None:
                return

            node_id = str(id(node))

            if node.points is not None:
                points_str = "\n".join(map(str, node.points[:3]))
                if len(node.points) > 3:
                    points_str += f"\n... ({len(node.points)} total)"
                label = f"Leaf\n{points_str}\nAxis: {node.axis}"
            else:
                label = f"Split: {node.split_value:.2f}\nAxis: {node.axis}"

            dot.node(node_id, label)

            if parent:
                dot.edge(parent, node_id)

            if not node.points:  # Internal node
                add_nodes(node.next[0], node_id)
                add_nodes(node.next[1], node_id)

        add_nodes(self.root)

        # Save visualization
        os.makedirs("images", exist_ok=True)
        output_path = os.path.join("images", output_filename)
        dot.render(output_path, cleanup=True)
        return dot

In [6]:
import h5py
import pandas as pd
import os

# Open the HDF5 file
with h5py.File('/kaggle/input/treedata/gist-960-euclidean.hdf5', 'r') as f:
    # List all keys (datasets) in the file
    print("Keys in the HDF5 file:", list(f.keys()))

    # Read the training set, queries, ground truth, and distances
    train_data = f['train'][:]  # Assuming 'train' is the key for training data
    query_data = f['test'][:]  # Assuming 'test' is the key for query data
    ground_truth = f['neighbors'][:]  # Assuming 'neighbors' is the key for ground truth neighbors
    distance = f['distances'][:]  # Assuming 'distances' is the key for ground truth distances

# Inspect the data
print("Training data shape:", train_data.shape)
print("Query data shape:", query_data.shape)
print("Ground truth shape:", ground_truth.shape)
print("Distance truth shape:",distance.shape)

Keys in the HDF5 file: ['distances', 'neighbors', 'test', 'train']
Training data shape: (1000000, 960)
Query data shape: (1000, 960)
Ground truth shape: (1000, 100)
Distance truth shape: (1000, 100)


In [7]:
tree = KDTree(train_data, dim=960)

In [10]:
query_points = query_data

k = 100

neighbors=[]
# Iterate over the selected query points
for i, query_point in enumerate(query_points):
    print(f"Processing query point {i+1}")
    # Get the k nearest neighbors using the KDTree
    neighbors.append(tree.k_nearest_neighbors(query_point,k))


Processing query point 1
Processing query point 2
Processing query point 3
Processing query point 4
Processing query point 5
Processing query point 6
Processing query point 7
Processing query point 8
Processing query point 9
Processing query point 10
Processing query point 11
Processing query point 12
Processing query point 13
Processing query point 14
Processing query point 15
Processing query point 16
Processing query point 17
Processing query point 18
Processing query point 19
Processing query point 20
Processing query point 21
Processing query point 22
Processing query point 23
Processing query point 24
Processing query point 25
Processing query point 26
Processing query point 27
Processing query point 28
Processing query point 29
Processing query point 30
Processing query point 31
Processing query point 32
Processing query point 33
Processing query point 34
Processing query point 35
Processing query point 36
Processing query point 37
Processing query point 38
Processing query poin

In [12]:
import numpy as np
from collections import defaultdict
from typing import Dict, List, Tuple, Any, Union
import numpy.typing as npt

def vectorTOindex(train_data: Union[List, npt.NDArray],
                 decimals: int = 10) -> Dict[Tuple[float, ...], int]:
    """
    Create a hashmap of vectors to their indices using vectorized operations.

    Args:
        train_data: Training data array or list
        decimals: Number of decimal places to round to

    Returns:
        Dictionary mapping vector tuples to their indices
    """
    # Convert to numpy array if input is list
    if isinstance(train_data, list):
        train_data = np.array(train_data)

    # Round all vectors at once
    rounded_data = np.round(train_data, decimals=decimals)

    # Create dictionary with vectorized operation
    return dict(zip(map(tuple, rounded_data), range(len(rounded_data))))

def FindIndices(query_vectors: Union[List, npt.NDArray],
               vector_map: Dict[Tuple[float, ...], int],
               decimals: int = 10,
               batch_size: int = 1000) -> npt.NDArray[np.int64]:
    """
    Find indices for vectors using batched processing.

    Args:
        query_vectors: Array or list of query vectors
        vector_map: Dictionary mapping vectors to indices
        decimals: Number of decimal places for rounding
        batch_size: Size of batches for processing

    Returns:
        Array of indices
    """
    # Convert to numpy array if input is list
    if isinstance(query_vectors, list):
        query_vectors = np.array(query_vectors)

    num_vectors = len(query_vectors)
    indices = np.full(num_vectors, -1, dtype=np.int64)

    # Process in batches to optimize memory usage
    for i in range(0, num_vectors, batch_size):
        batch = query_vectors[i:i + batch_size]
        rounded_batch = np.round(batch, decimals=decimals)

        # Vectorized tuple conversion and lookup
        batch_indices = [vector_map.get(tuple(v), -1) for v in rounded_batch]
        indices[i:i + len(batch)] = batch_indices

    return indices

def KNNAccuracy(knn_results: Union[List, npt.NDArray],
                ground_truth_indices: Union[List, npt.NDArray],
                train_data: Union[List, npt.NDArray],
                decimals: int = 10) -> float:
    """
    Compute KNN accuracy with optimized vectorized operations.

    Args:
        knn_results: Array/list of KNN results (num_queries x k x dimensions)
        ground_truth_indices: Ground truth indices (num_queries x k)
        train_data: Training data array/list
        decimals: Number of decimal places for rounding

    Returns:
        Accuracy score between 0 and 1
    """
    # Convert inputs to numpy arrays if they're lists
    if isinstance(knn_results, list):
        knn_results = np.array(knn_results)
    if isinstance(ground_truth_indices, list):
        ground_truth_indices = np.array(ground_truth_indices)
    if isinstance(train_data, list):
        train_data = np.array(train_data)

    # Get dimensions
    num_queries, num_neighbors = knn_results.shape[:2]

    # Create vector map once and share it
    vector_map = vectorTOindex(train_data, decimals)

    # Reshape knn_results for batch processing
    flattened_results = knn_results.reshape(-1, knn_results.shape[-1])

    # Find all indices in one batch operation
    all_indices = FindIndices(flattened_results, vector_map, decimals)
    predicted_indices = all_indices.reshape(num_queries, num_neighbors)

    # Compute accuracy using vectorized operations
    correct_matches = np.zeros(num_queries, dtype=np.int64)

    # Vectorized intersection computation
    for i in range(num_queries):
        correct_matches[i] = np.intersect1d(
            predicted_indices[i],
            ground_truth_indices[i],
            assume_unique=True
        ).size

    # Calculate final accuracy
    total_predictions = num_queries * num_neighbors
    accuracy = np.sum(correct_matches) / total_predictions

    return accuracy


accuracy = KNNAccuracy(neighbors, ground_truth, train_data)


In [16]:
print(f"Accuracy: {accuracy * 100:.4f}%")

Accuracy: 99.9950%


In [8]:
import numpy as np
import time
import matplotlib.pyplot as plt
from typing import List, Tuple
from tqdm import tqdm
import psutil
import gc

def get_memory_usage():
    """Get current memory usage in MB"""
    process = psutil.Process()
    return process.memory_info().rss / (1024 * 1024)

def benchmark_construction(points: np.ndarray, variants: List[Tuple[str, dict]]) -> tuple:
    """Benchmark tree construction for different variants."""
    construction_times = {}
    memory_usage = {}

    for name, params in variants:
        # Force garbage collection before each test
        gc.collect()
        initial_memory = get_memory_usage()

        # Measure construction time
        start_time = time.time()
        _ = KDTree(points, **params)
        construction_time = time.time() - start_time

        # Measure memory usage
        final_memory = get_memory_usage()
        memory_delta = final_memory - initial_memory

        construction_times[name] = construction_time
        memory_usage[name] = memory_delta

    return construction_times, memory_usage

def benchmark_queries(trees: List[Tuple[str, KDTree]],
                     query_points: np.ndarray,
                     k: int,
                     radius: float) -> Tuple[dict, dict]:
    """Benchmark kNN and range queries."""
    knn_results = {}
    range_results = {}

    for name, tree in trees:
        knn_times = []
        range_times = []

        # Use smaller batch for progress tracking
        for query in tqdm(query_points, desc=f"Benchmarking {name}", miniters=50):
            # Benchmark kNN search
            start_time = time.time()
            _ = tree.k_nearest_neighbors(query, k)
            knn_times.append(time.time() - start_time)

            # Benchmark range search
            start_time = time.time()
            _ = tree.radius_search(query, radius)
            range_times.append(time.time() - start_time)

        knn_results[name] = knn_times
        range_results[name] = range_times

    return knn_results, range_results

def plot_results(construction_times: dict,
                memory_usage: dict,
                knn_times: dict,
                range_times: dict,
                output_filename: str = "kdtree_benchmark"):
    """Create plots comparing the performance of different variants."""
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    names = list(construction_times.keys())

    # Plot construction times
    times = [construction_times[name] for name in names]
    ax1.bar(names, times)
    ax1.set_xlabel('Tree Variant')
    ax1.set_ylabel('Construction Time (s)')
    ax1.set_title('Tree Construction Time')
    plt.setp(ax1.xaxis.get_majorticklabels(), rotation=45)
    ax1.grid(True)

    # Plot memory usage
    memory = [memory_usage[name] for name in names]
    ax2.bar(names, memory)
    ax2.set_xlabel('Tree Variant')
    ax2.set_ylabel('Memory Usage (MB)')
    ax2.set_title('Additional Memory Usage')
    plt.setp(ax2.xaxis.get_majorticklabels(), rotation=45)
    ax2.grid(True)

    # Plot kNN query times
    bp1 = ax3.boxplot([knn_times[name] for name in names],
                      labels=names,
                      showfliers=False)
    ax3.set_xlabel('Tree Variant')
    ax3.set_ylabel('Query Time (s)')
    ax3.set_title('k-NN Query Time')
    plt.setp(ax3.xaxis.get_majorticklabels(), rotation=45)
    ax3.grid(True)

    # Plot range query times
    bp2 = ax4.boxplot([range_times[name] for name in names],
                      labels=names,
                      showfliers=False)
    ax4.set_xlabel('Tree Variant')
    ax4.set_ylabel('Query Time (s)')
    ax4.set_title('Range Query Time')
    plt.setp(ax4.xaxis.get_majorticklabels(), rotation=45)
    ax4.grid(True)

    plt.tight_layout()
    plt.savefig(f"{output_filename}.png", dpi=300, bbox_inches='tight')
    plt.close()

def run_benchmark(train_data: np.ndarray, query_data: np.ndarray):
    """Run complete benchmark suite."""
    # Parameters adjusted for large-scale dataset
    n_queries = 100  # Reduced number of queries for large dataset
    k = 10  # Increased k for better comparison
    radius = 0.1  # Radius for range search

    # Define tree variants with larger leaf sizes for better performance
    variants = [
        ("Cyclic-32", {"discriminator": "cyclic", "leaf_size": 32}),
        ("Cyclic-64", {"discriminator": "cyclic", "leaf_size": 64}),
        ("Variance-32", {"discriminator": "variance", "leaf_size": 32}),
        ("Variance-64", {"discriminator": "variance", "leaf_size": 64}),
    ]

    print(f"\nBenchmarking {train_data.shape[1]}-dimensional data")
    print(f"Training set size: {len(train_data):,} points")
    print(f"Query set size: {len(query_data):,} points")

    # Benchmark construction
    print("\nBenchmarking tree construction...")
    construction_times, memory_usage = benchmark_construction(train_data, variants)

    # Create trees for query benchmarking
    print("\nCreating trees for query benchmarking...")
    trees = []
    for name, params in variants:
        tree = KDTree(train_data, **params)
        trees.append((name, tree))

    # Benchmark queries
    print("\nBenchmarking queries...")
    query_points = query_data[:n_queries]  # Use subset of query points
    knn_times, range_times = benchmark_queries(trees, query_points, k, radius)

    # Plot results
    print("\nGenerating plots...")
    plot_results(construction_times, memory_usage, knn_times, range_times,
                f"kdtree_benchmark_{train_data.shape[1]}d_{len(train_data)}points")

    # Print summary statistics
    print(f"\nPerformance Summary:")
    print("\nConstruction metrics:")
    for name in construction_times:
        print(f"{name}:")
        print(f"  Time: {construction_times[name]:.2f} seconds")
        print(f"  Memory: {memory_usage[name]:.2f} MB")

    print("\nQuery metrics:")
    for name in knn_times:
        print(f"{name}:")
        print(f"  k-NN (avg): {np.mean(knn_times[name])*1000:.2f} ms")
        print(f"  k-NN (std): {np.std(knn_times[name])*1000:.2f} ms")
        print(f"  Range (avg): {np.mean(range_times[name])*1000:.2f} ms")
        print(f"  Range (std): {np.std(range_times[name])*1000:.2f} ms")

query_data=query_data[:100]
run_benchmark(train_data, query_data)


Benchmarking 960-dimensional data
Training set size: 1,000,000 points
Query set size: 100 points

Benchmarking tree construction...

Creating trees for query benchmarking...

Benchmarking queries...


Benchmarking Cyclic-32: 100%|██████████| 100/100 [04:59<00:00,  2.99s/it]
Benchmarking Cyclic-64: 100%|██████████| 100/100 [05:02<00:00,  3.03s/it]
Benchmarking Variance-32: 100%|██████████| 100/100 [04:37<00:00,  2.78s/it]
Benchmarking Variance-64: 100%|██████████| 100/100 [04:55<00:00,  2.96s/it]



Generating plots...


  bp1 = ax3.boxplot([knn_times[name] for name in names],
  bp2 = ax4.boxplot([range_times[name] for name in names],



Performance Summary:

Construction metrics:
Cyclic-32:
  Time: 15.73 seconds
  Memory: 3857.63 MB
Cyclic-64:
  Time: 15.12 seconds
  Memory: 3815.08 MB
Variance-32:
  Time: 38.45 seconds
  Memory: 1704.14 MB
Variance-64:
  Time: 35.62 seconds
  Memory: 10.79 MB

Query metrics:
Cyclic-32:
  k-NN (avg): 1852.19 ms
  k-NN (std): 40.24 ms
  Range (avg): 1138.89 ms
  Range (std): 441.47 ms
Cyclic-64:
  k-NN (avg): 1889.25 ms
  k-NN (std): 36.03 ms
  Range (avg): 1138.35 ms
  Range (std): 425.72 ms
Variance-32:
  k-NN (avg): 1967.23 ms
  k-NN (std): 109.48 ms
  Range (avg): 808.19 ms
  Range (std): 365.09 ms
Variance-64:
  k-NN (avg): 2090.66 ms
  k-NN (std): 37.05 ms
  Range (avg): 864.46 ms
  Range (std): 373.72 ms


## **RTEES**