In [2]:
import numpy as np
import time
import os
from tqdm import tqdm
import pyarrow.parquet as pq
from sklearn.neighbors import KDTree
import json

# --- Configuration ---
# Input Parquet table; it must provide an 'objectId' column plus every column
# listed in FEATURE_NAMES.
DATA_FILE = "ad_features_all_gt1_14_08.parquet"
# Output JSON file: maps each target object ID that was found to the list of
# its nearest-neighbor object IDs (distances are printed but not saved).
OUTPUT_FILE = "tde_nearest_neighbors_results_all_without_mean_variance.json"

# Target Object IDs for different categories can be uncommented as needed.
# Exactly one TARGET_OBJECT_IDS assignment should be active at a time.
# TDE
TARGET_OBJECT_IDS = ['ZTF22abegjtx', 'ZTF22abkfhua', 'ZTF23aapyidj', 'ZTF23abohtqf', 'ZTF24aaecooj', 'ZTF24aajvvhj', 'ZTF24aakaiha', 'ZTF24aamfius', 'ZTF24aatxshz', 'ZTF24abfaake']
# SLNS
# NOTE(review): "SLNS" is presumably a typo for "SLSN" — confirm.
#TARGET_OBJECT_IDS = ['ZTF25aanxtou', 'ZTF25aaixrfr', 'ZTF24abvftmi', 'ZTF24abdhylt', 'ZTF24aaysowl', 'ZTF23aboebgh', 'ZTF23abcvbqq', 'ZTF23aanptpp', 'ZTF22abcvfgs', 'ZTF21ackxdos']
# CV
#TARGET_OBJECT_IDS = ['ZTF18abjcuhv', 'ZTF18aaudzmj', 'ZTF19aagxcga', 'ZTF18aaashju', 'ZTF21aagtpxy', 'ZTF18aboacjl', 'ZTF17aaantxj', 'ZTF18aajwgru', 'ZTF18abigrzf', 'ZTF18adibadz']

# Number of neighbors reported per target. The search itself requests
# NUM_NEIGHBORS + 1 candidates so that the target can be dropped from its own
# result list.
NUM_NEIGHBORS = 10

# Statistic names shared by both suffix groups. Each one appears in
# FEATURE_NAMES twice: first with a "_g" suffix, then with an "_r" suffix
# (presumably the g and r photometric bands — confirm against the table schema).
_PER_SUFFIX_STATISTICS = (
    'mean',
    'weighted_mean',
    'standard_deviation',
    'median',
    'amplitude',
    'beyond_1_std',
    'cusum',
    'inter_percentile_range_10',
    'kurtosis',
    'linear_trend',
    'linear_trend_sigma',
    'linear_trend_noise',
    'linear_fit_slope',
    'linear_fit_slope_sigma',
    'linear_fit_reduced_chi2',
    'magnitude_percentage_ratio_40_5',
    'magnitude_percentage_ratio_20_10',
    'maximum_slope',
    'median_absolute_deviation',
    'median_buffer_range_percentage_10',
    'percent_amplitude',
    'anderson_darling_normal',
    'chi2',
    'skew',
    'stetson_K',
)

# Column order matters: it fixes the column order of the feature matrix.
# Layout: all "_g" columns, then all "_r" columns, then the single
# suffix-less 'distnr' column.
FEATURE_NAMES = [
    f'{statistic}_{suffix}'
    for suffix in ('g', 'r')
    for statistic in _PER_SUFFIX_STATISTICS
] + ['distnr']

def load_and_prepare_data_parquet(filepath):
    """Load the feature table, drop unusable rows and standardize the columns.

    The Parquet file is read one row group at a time, keeping only the
    'objectId' column and the columns named in FEATURE_NAMES. Rows that contain
    any non-finite feature value (NaN or +/-inf) are removed, and every
    remaining column is standardized to zero mean and unit standard deviation
    (constant columns are only centered).

    Args:
        filepath: Path of the Parquet file to read.

    Returns:
        A tuple ``(object_ids, data_matrix)`` where ``object_ids`` is a list of
        the object IDs of the rows that were kept and ``data_matrix`` is the
        matching float32 array with one row per ID and one column per entry of
        FEATURE_NAMES. Returns ``(None, None)`` if the file cannot be opened.
        If no usable rows exist, the list is empty and the matrix has zero rows.
    """
    print(f"Step 1: Reading data from '{filepath}'...")
    try:
        pq_file = pq.ParquetFile(filepath)
    except Exception as e:
        print(f" ERROR: Could not open Parquet file '{filepath}'. {e}")
        return None, None

    num_objects_initial = pq_file.metadata.num_rows
    print(f"   The file contains {num_objects_initial:,} objects.")

    object_ids_list, feature_batches = [], []
    columns_to_read = ['objectId'] + FEATURE_NAMES

    for i in tqdm(range(pq_file.num_row_groups), desc="Reading Parquet batches"):
        batch = pq_file.read_row_group(i, columns=columns_to_read)
        object_ids_list.extend(batch.column('objectId').to_pylist())
        batch_matrix = np.column_stack(
            [batch.column(name).to_numpy() for name in FEATURE_NAMES]
        )
        # Down-cast each batch right away so that only float32 copies are held
        # while the remaining row groups are read.
        feature_batches.append(batch_matrix.astype('float32'))

    print("Combining batches...")
    if feature_batches:
        data_matrix = np.vstack(feature_batches)
    else:
        # np.vstack rejects an empty list; a file without row groups simply
        # yields an empty matrix with the expected number of columns.
        data_matrix = np.empty((0, len(FEATURE_NAMES)), dtype='float32')
    object_ids = np.array(object_ids_list)
    print(f" Data loaded. Matrix size: {data_matrix.shape}")

    print("\nStep 2: Removing rows containing NaN or infinite values...")
    # Infinite values (including ones produced by overflow in the float32
    # cast) must go too: a single inf turns the column mean/std into inf/NaN
    # and would corrupt every standardized row.
    invalid_mask = ~np.isfinite(data_matrix).all(axis=1)
    num_invalid_rows = np.sum(invalid_mask)

    if num_invalid_rows > 0:
        print(f"   Found {num_invalid_rows:,} rows with NaN or infinite values. Removing them...")
        data_matrix = data_matrix[~invalid_mask]
        object_ids = object_ids[~invalid_mask]
        print(f"Cleanup complete. Matrix size after cleanup: {data_matrix.shape}")
    else:
        print("No missing or infinite values found.")

    if data_matrix.shape[0] == 0:
        # Mean/std of an empty matrix are undefined; hand back the empty result
        # and let the caller report that no target could be found.
        print("   Warning: No rows left after cleanup. Skipping normalization.")
        return object_ids.tolist(), data_matrix

    print("\n Step 3: Normalizing data (Standardization)...")
    mean = np.mean(data_matrix, axis=0, dtype='float32')
    std = np.std(data_matrix, axis=0, dtype='float32')
    # Constant columns have zero spread; dividing by 1 keeps them at 0 after
    # centering instead of producing NaN.
    std[std == 0] = 1.0
    # In-place arithmetic gives the same float32 result as
    # (data_matrix - mean) / std without allocating two extra full-size arrays.
    data_matrix -= mean
    data_matrix /= std
    print("Data successfully normalized.")

    return object_ids.tolist(), data_matrix

def find_neighbors_kdtree_batch(data, query_vectors, k_candidates):
    """Build a KD-Tree over ``data`` and query it for all vectors at once.

    Args:
        data: 2-D array of points to index, one row per object.
        query_vectors: 2-D array of points to look up, one row per query.
        k_candidates: Number of nearest candidates requested per query. It is
            reduced to the number of rows in ``data`` when it exceeds that
            number, because the tree cannot return more points than it holds.

    Returns:
        A tuple ``(distances, indices)`` as returned by ``KDTree.query``: for
        every query row, the Euclidean distances to its nearest candidates and
        the row indices of those candidates in ``data``.
    """
    print("\nStep 4: Building KD-Tree...")
    print("   (This may take a few minutes depending on data size and CPU power)")
    # perf_counter is monotonic, so the elapsed time cannot be distorted by
    # system clock adjustments.
    build_start = time.perf_counter()
    kdt = KDTree(data, leaf_size=40, metric='euclidean')
    build_elapsed = time.perf_counter() - build_start
    print(f"Tree built in {build_elapsed:.2f} sec.")

    # KDTree.query raises if k is larger than the number of indexed points.
    k = min(k_candidates, len(data))
    if k < k_candidates:
        print(f"   Warning: Only {k} points are indexed; requesting {k} candidates instead of {k_candidates}.")

    print(f"\nPerforming batch search for {len(query_vectors)} objects...")
    search_start = time.perf_counter()
    distances, indices = kdt.query(query_vectors, k=k)
    search_elapsed = time.perf_counter() - search_start
    print(f"Search completed in {search_elapsed:.4f} sec.")

    return distances, indices

# --- Main execution block ---
# Loads and standardizes the feature table, looks up the nearest neighbors of
# every configured target object, prints a ranked table per target and writes
# the neighbor IDs to OUTPUT_FILE as JSON.
if __name__ == "__main__":
    # Failures terminate via SystemExit(1): unlike the site-provided exit(),
    # this is always available and reports a non-zero status to the caller.
    if not os.path.exists(DATA_FILE):
        print(f"ERROR: Data file '{DATA_FILE}' not found.")
        raise SystemExit(1)

    all_ids, norm_data = load_and_prepare_data_parquet(DATA_FILE)
    if all_ids is None:
        raise SystemExit(1)

    print("\nStep 5: Finding target objects and preparing query vectors...")
    # Object ID -> row index in norm_data. If an ID occurs more than once, the
    # last occurrence wins.
    id_to_idx = {oid: i for i, oid in enumerate(all_ids)}

    query_vectors_list, found_target_ids = [], []
    for target_id in TARGET_OBJECT_IDS:
        target_idx = id_to_idx.get(target_id)
        if target_idx is None:
            print(f"   Warning: Target object '{target_id}' not found in the data (possibly removed due to NaN) and will be skipped.")
            continue
        query_vectors_list.append(norm_data[target_idx])
        found_target_ids.append(target_id)

    if not found_target_ids:
        print("ERROR: None of the target objects were found in the cleaned data. Exiting.")
        raise SystemExit(1)

    query_vectors = np.array(query_vectors_list, dtype='float32')
    print(f"Found {len(found_target_ids)} target objects. Preparing to search.")

    # One extra candidate is requested because each target is itself part of
    # the indexed data and is discarded from its own result list below.
    distances_raw, indices_raw = find_neighbors_kdtree_batch(norm_data, query_vectors, NUM_NEIGHBORS + 1)

    print("\n" + "="*80)
    print("NEAREST NEIGHBOR SEARCH RESULTS (KD-Tree algorithm, data cleaned of NaN)")
    print("="*80)

    results_to_save = {}

    for target_id, candidate_distances, candidate_indices in zip(
        found_target_ids, distances_raw, indices_raw
    ):
        final_neighbors_for_print = []
        final_neighbor_ids_for_json = []
        # Seeded with the target so it never shows up as its own neighbor;
        # also filters repeated object IDs among the candidates.
        seen_ids = {target_id}

        for neighbor_idx, neighbor_dist in zip(candidate_indices, candidate_distances):
            if len(final_neighbor_ids_for_json) >= NUM_NEIGHBORS:
                break

            neighbor_id = all_ids[neighbor_idx]
            if neighbor_id in seen_ids:
                continue

            seen_ids.add(neighbor_id)
            final_neighbor_ids_for_json.append(neighbor_id)
            final_neighbors_for_print.append({'id': neighbor_id, 'distance': neighbor_dist})

        results_to_save[target_id] = final_neighbor_ids_for_json

        print(f"\n--- {NUM_NEIGHBORS} nearest unique neighbors for: {target_id} ---\n")
        print(f"{'Rank':<5} {'ObjectID':<20} {'L2 Distance':<20}")
        print("-"*50)
        for rank, neighbor in enumerate(final_neighbors_for_print, 1):
            print(f"{rank:<5} {neighbor['id']:<20} {neighbor['distance']:<20.4f}")

    print("\n" + "="*80)
    print(f"Saving results to file '{OUTPUT_FILE}'...")
    try:
        with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
            json.dump(results_to_save, f, indent=4)
    except (OSError, TypeError, ValueError) as e:
        # OSError: the file could not be written; TypeError/ValueError: the
        # results could not be serialized as JSON.
        print(f"ERROR: Could not save the file. {e}")
    else:
        print("Results successfully saved.")

    print("\n" + "="*80)

Step 1: Reading data from 'ad_features_all_gt1_14_08.parquet'...
   The file contains 3,922,520 objects.


Reading Parquet batches: 100%|████████████████████| 1/1 [00:08<00:00,  8.02s/it]


Combining batches...
 Data loaded. Matrix size: (3922520, 51)

Step 2: Removing rows containing NaN...
   Found 3,642,757 rows with NaN values. Removing them...
Cleanup complete. Matrix size after cleanup: (279763, 51)

 Step 3: Normalizing data (Standardization)...
Data successfully normalized.

Step 5: Finding target objects and preparing query vectors...
Found 4 target objects. Preparing to search.

Step 4: Building KD-Tree...
   (This may take a few minutes depending on data size and CPU power)
Tree built in 1.69 sec.

Performing batch search for 4 objects...
Search completed in 0.0992 sec.

NEAREST NEIGHBOR SEARCH RESULTS (KD-Tree algorithm, data cleaned of NaN)

--- 100 nearest unique neighbors for: ZTF18aboacjl ---

Rank  ObjectID             L2 Distance         
--------------------------------------------------
1     ZTF18abwbjjy         15.6274             
2     ZTF17aaajeks         16.6986             
3     ZTF18aazsdnv         16.8160             
4     ZTF18absnnsr      