In [3]:
import numpy as np
import time
import os
from tqdm import tqdm
import pyarrow.parquet as pq
from sklearn.neighbors import KDTree
import json

# --- Configuration ---
DATA_FILE = "ad_features_all_gt1_14_08.parquet"

timestamp = time.strftime("%Y%m%d_%H%M%S")
OUTPUT_FILE = f"tde_nearest_neighbors_results_all_without_mean_variance_{timestamp}.json"

TARGET_OBJECT_IDS = ['ZTF22abegjtx', 'ZTF22abkfhua', 'ZTF23aapyidj', 'ZTF23abohtqf', 'ZTF24aaecooj', 'ZTF24aajvvhj', 'ZTF24aakaiha', 'ZTF24aamfius', 'ZTF24aatxshz', 'ZTF24abfaake']

NUM_NEIGHBORS = 10

FEATURE_NAMES = [
    'mean_g', 'weighted_mean_g',
    'standard_deviation_g', 'median_g', 'amplitude_g', 'beyond_1_std_g',
    'cusum_g', 'inter_percentile_range_10_g', 'kurtosis_g', 'linear_trend_g',
    'linear_trend_sigma_g', 'linear_trend_noise_g', 'linear_fit_slope_g',
    'linear_fit_slope_sigma_g', 'linear_fit_reduced_chi2_g',
    'magnitude_percentage_ratio_40_5_g', 'magnitude_percentage_ratio_20_10_g',
    'maximum_slope_g', 'median_absolute_deviation_g',
    'median_buffer_range_percentage_10_g', 'percent_amplitude_g',
    'anderson_darling_normal_g', 'chi2_g', 'skew_g', 'stetson_K_g', 'mean_r',
    'weighted_mean_r', 'standard_deviation_r', 'median_r', 'amplitude_r',
    'beyond_1_std_r', 'cusum_r', 'inter_percentile_range_10_r', 'kurtosis_r',
    'linear_trend_r', 'linear_trend_sigma_r', 'linear_trend_noise_r',
    'linear_fit_slope_r', 'linear_fit_slope_sigma_r', 'linear_fit_reduced_chi2_r',
    'magnitude_percentage_ratio_40_5_r', 'magnitude_percentage_ratio_20_10_r',
    'maximum_slope_r', 'median_absolute_deviation_g',
    'median_buffer_range_percentage_10_r', 'percent_amplitude_r',
    'anderson_darling_normal_r', 'chi2_r', 'skew_r', 'stetson_K_r', 'distnr'
]

def load_data_parquet(filepath):
    print(f"Step 1: Reading data from '{filepath}'...")
    try:
        pq_file = pq.ParquetFile(filepath)
    except Exception as e:
        print(f" ERROR: Could not open Parquet file '{filepath}'. {e}"); return None, None

    object_ids_list, feature_batches = [], []
    columns_to_read = ['objectId'] + FEATURE_NAMES

    for i in tqdm(range(pq_file.num_row_groups), desc="Reading Parquet batches"):
        batch = pq_file.read_row_group(i, columns=columns_to_read)
        object_ids_list.extend(batch.column('objectId').to_pylist())
        feature_batches.append(np.column_stack([batch.column(name).to_numpy(zero_copy_only=False) for name in FEATURE_NAMES]))

    print("Combining batches...")
    data_matrix = np.vstack(feature_batches).astype('float32')
    object_ids = np.array(object_ids_list)
    print(f" Data loaded. Matrix size: {data_matrix.shape}")
    return object_ids, data_matrix

def find_neighbors_kdtree_batch(data, query_vectors, k_candidates):
    print("\nStep 4: Building KD-Tree...")
    start_time = time.time()
    kdt = KDTree(data, leaf_size=40, metric='euclidean')
    end_time = time.time()
    print(f"Tree built in {end_time - start_time:.2f} sec.")

    print(f"\nPerforming batch search for {len(query_vectors)} objects...")
    start_time = time.time()
    distances, indices = kdt.query(query_vectors, k=k_candidates)
    end_time = time.time()
    print(f"Search completed in {end_time - start_time:.4f} sec.")

    return distances, indices

# --- Main execution block ---
if __name__ == "__main__":
    if not os.path.exists(DATA_FILE):
        print(f"ERROR: Data file '{DATA_FILE}' not found.")
        exit()

    all_ids_raw, all_data_raw = load_data_parquet(DATA_FILE)
    if all_ids_raw is None:
        exit()

    id_to_idx = {oid: i for i, oid in enumerate(all_ids_raw)}
    results_to_save = {}

    print("\n" + "="*80)
    print("NEAREST NEIGHBOR SEARCH (KD-Tree algorithm)")
    print("="*80)

    for target_id in TARGET_OBJECT_IDS:
        print(f"\n--- Processing target object: {target_id} ---")

        if target_id not in id_to_idx:
            print(f"   Warning: Target object '{target_id}' not found in the initial data. Skipping.")
            continue

        target_idx = id_to_idx[target_id]
        target_vector_raw = all_data_raw[target_idx]

        nan_features_mask_target = np.isnan(target_vector_raw)
        if np.all(nan_features_mask_target):
            print(f"   Warning: All features for '{target_id}' are NaN. Search is not possible. Skipping.")
            continue
        
        target_vector_clean = target_vector_raw[~nan_features_mask_target]
        
        print("Step 2: Filtering data based on the target object's NaN values...")
        
        data_filtered_cols = all_data_raw[:, ~nan_features_mask_target]
        
        nan_rows_mask = np.isnan(data_filtered_cols).any(axis=1)
        
        data_clean = data_filtered_cols[~nan_rows_mask]
        ids_clean = all_ids_raw[~nan_rows_mask]
        
        print(f"   Matrix size after filtering: {data_clean.shape}")

        if not any(ids_clean == target_id):
            print(f"   Warning: Target object '{target_id}' was removed during cleanup. Skipping.")
            continue
            
        print("\nStep 3: Normalizing filtered data (Standardization)...")
        mean = np.mean(data_clean, axis=0, dtype='float32')
        std = np.std(data_clean, axis=0, dtype='float32')
        std[std == 0] = 1.0
        norm_data = (data_clean - mean) / std
        print("Data successfully normalized.")
        
        norm_target_vector = ((target_vector_clean - mean) / std).reshape(1, -1)

        distances_raw, indices_raw = find_neighbors_kdtree_batch(norm_data, norm_target_vector, NUM_NEIGHBORS + 1)
        
        final_neighbors_for_print = []
        final_neighbor_ids_for_json = []
        seen_ids = {target_id}

        distances_for_target, indices_for_target = distances_raw[0], indices_raw[0]

        for j in range(len(indices_for_target)):
            if len(final_neighbor_ids_for_json) >= NUM_NEIGHBORS:
                break

            neighbor_id = ids_clean[indices_for_target[j]]
            if neighbor_id in seen_ids:
                continue
            
            final_neighbor_ids_for_json.append(neighbor_id)
            
            neighbor_dist = distances_for_target[j]
            final_neighbors_for_print.append({'id': neighbor_id, 'distance': neighbor_dist})
            
            seen_ids.add(neighbor_id)
        
        results_to_save[target_id] = final_neighbor_ids_for_json

        print(f"\n--- {NUM_NEIGHBORS} nearest unique neighbors for: {target_id} ---\n")
        print(f"{'Rank':<5} {'ObjectID':<20} {'L2 Distance':<20}")
        print("-"*50)
        for rank, neighbor in enumerate(final_neighbors_for_print, 1):
            print(f"{rank:<5} {neighbor['id']:<20} {neighbor['distance']:<20.4f}")

    print("\n" + "="*80)
    print(f"Saving results to file '{OUTPUT_FILE}'...")
    try:
        with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
            json.dump(results_to_save, f, indent=4)
        print(f"Results successfully saved.")
    except Exception as e:
        print(f"ERROR: Could not save the file. {e}")

    print("\n" + "="*80)

Step 1: Reading data from 'ad_features_all_gt1_14_08.parquet'...


Reading Parquet batches: 100%|████████████████████| 1/1 [00:06<00:00,  6.54s/it]


Combining batches...
 Data loaded. Matrix size: (3922520, 51)

NEAREST NEIGHBOR SEARCH (KD-Tree algorithm)

--- Processing target object: ZTF22abegjtx ---
Step 2: Filtering data based on the target object's NaN values...
   Matrix size after filtering: (709506, 27)

Step 3: Normalizing filtered data (Standardization)...
Data successfully normalized.

Step 4: Building KD-Tree...
Tree built in 2.54 sec.

Performing batch search for 1 objects...
Search completed in 0.0018 sec.

--- 10 nearest unique neighbors for: ZTF22abegjtx ---

Rank  ObjectID             L2 Distance         
--------------------------------------------------
1     ZTF18abncwah         0.2161              
2     ZTF20acuxaln         0.2392              
3     ZTF23aablqzq         0.2545              
4     ZTF19aarghka         0.2699              
5     ZTF20abezxdt         0.2839              
6     ZTF19abscstb         0.2849              
7     ZTF19abhlztd         0.2912              
8     ZTF22aanaxtu         0.2