In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
import os
import math
from tqdm.notebook import tqdm 

# --- Configuration ---
# Where the per-location night-light TFRecords live, and the input/output CSVs.
_DATA_DIR = './data'
NL_DATA_BASE_DIR = "./downloaded_gcs_data/senegal_NL/"
ORIGINAL_CSV_PATH = os.path.join(_DATA_DIR, 'senegal_shuffled.csv')
OUTPUT_CSV_PATH = os.path.join(_DATA_DIR, 'senegal_shuffled_coords_imputed_2.csv')

# Image dimensions (CONFIRMED 384x384); each TFRecord feature must hold exactly
# EXPECTED_PIXELS values to be reshaped into a patch.
IMG_HEIGHT, IMG_WIDTH = 384, 384
EXPECTED_PIXELS = IMG_WIDTH * IMG_HEIGHT
# Meters per pixel: converts the displacement radii below into pixels and the
# sampled pixel offsets back into meters.
PIXEL_SCALE = 30

# Displacement radii (meters): maximum distance from the patch centre that a
# location may be moved, chosen per row from the 'rural' flag.
RADIUS_METERS_URBAN, RADIUS_METERS_RURAL = 2000, 5000

# Night-light threshold for the binary prior mask P(x,y|NL): only pixels whose
# value is strictly greater than this are eligible as imputed locations.
NL_THRESHOLD = 0.05

# --- Helper Function: Coordinate Approximation ---
def displace_lat_lon(lat, lon, dx_meters, dy_meters):
    """Approximate the coordinates reached by a small metric offset.

    Args:
        lat: Starting latitude in degrees.
        lon: Starting longitude in degrees.
        dx_meters: Offset applied along longitude (positive increases lon).
        dy_meters: Offset applied along latitude (positive increases lat).

    Returns:
        Tuple ``(new_lat, new_lon)`` in degrees.

    Both scale factors are evaluated at the starting latitude, so this is a
    local approximation intended for small displacements. No wrap-around or
    clamping of the resulting coordinates is performed.
    """
    lat_rad = math.radians(lat)
    earth_radius_m = 6378137  # same value as the WGS84 equatorial radius

    # Meters per degree of latitude, as a short cosine series in the latitude.
    meters_per_deg_lat = (
        111132.954
        - 559.822 * math.cos(2 * lat_rad)
        + 1.175 * math.cos(4 * lat_rad)
    )
    # Meters per degree of longitude shrink with cos(lat).
    meters_per_deg_lon = (math.pi / 180) * earth_radius_m * math.cos(lat_rad)

    # Near the poles a degree of longitude spans ~0 m; skip the east-west
    # shift there instead of dividing by a vanishing number.
    if abs(meters_per_deg_lon) < 1e-6:
        lon_shift = 0
    else:
        lon_shift = dx_meters / meters_per_deg_lon

    return lat + dy_meters / meters_per_deg_lat, lon + lon_shift

# --- Main Processing Logic ---

# 1. Load original DataFrame
# Every column listed here is read further down (UID generation, radius
# selection and the original coordinates), so validate them all up front
# instead of failing later with a KeyError part-way through the run.
REQUIRED_COLUMNS = ('country', 'year', 'rural', 'lat', 'lon')

print(f"Loading original data from: {ORIGINAL_CSV_PATH}")
try:
    df_orig = pd.read_csv(ORIGINAL_CSV_PATH)
except FileNotFoundError:
    print(f"Error: Original CSV file not found at {ORIGINAL_CSV_PATH}")
    # SystemExit is a builtin (unlike the site-provided `exit()` helper) and a
    # non-zero status signals the failure to whatever launched the script.
    raise SystemExit(1) from None

missing_columns = [col for col in REQUIRED_COLUMNS if col not in df_orig.columns]
if missing_columns:
    print(f"Error: CSV is missing expected columns: {missing_columns}")
    raise SystemExit(1)
print(f"Loaded {len(df_orig)} records.")

# 2. Generate CORRECT UID based on within-year index
print("Generating correct UIDs based on within-year index...")
# Position of each row inside its (country, year) group, in file order.
df_orig['group_index'] = df_orig.groupby(['country', 'year']).cumcount()
# UID format: "<country>_<year>_<5-digit zero-padded group index>".
# Built with vectorized string operations rather than a row-wise apply: this
# avoids constructing a Series per row and still yields a Series when the
# frame is empty.
df_orig['uid'] = (
    df_orig['country'].astype(str)
    + '_' + df_orig['year'].astype(str)
    + '_' + df_orig['group_index'].astype(str).str.zfill(5)
)
print("UID generation complete.")
print("Sample generated UIDs:")
print(df_orig[['country', 'year', 'group_index', 'uid']].head())


# 3. Prepare results storage
# One dict per input row, in input order; merged back on 'uid' afterwards.
results = []

# 4. Loop through DataFrame and Process TFRecords
print(f"Processing TFRecords and imputing coordinates for {len(df_orig)} locations...")
center_row, center_col = IMG_HEIGHT // 2, IMG_WIDTH // 2
y_coords, x_coords = np.ogrid[:IMG_HEIGHT, :IMG_WIDTH]
# Distance (in pixels) of every pixel from the patch centre.
dist_from_center = np.sqrt((x_coords - center_col)**2 + (y_coords - center_row)**2)
# The centre itself is at distance 0; treat it as distance 1 so that the 1/d
# weighting below stays finite.
dist_from_center[center_row, center_col] = 1.0

# Displacement likelihood P(x',y' | x,y) ~ 1/d inside the allowed radius and 0
# outside it. It depends only on the radius, so build it once per radius
# instead of once per row.
displacement_likelihoods = {}
for radius_meters in (RADIUS_METERS_URBAN, RADIUS_METERS_RURAL):
    mask_within_radius = dist_from_center <= radius_meters / PIXEL_SCALE
    likelihood_map = np.zeros_like(dist_from_center, dtype=np.float32)
    likelihood_map[mask_within_radius] = 1.0 / dist_from_center[mask_within_radius]
    displacement_likelihoods[radius_meters] = likelihood_map

for _, row in tqdm(df_orig.iterrows(), total=len(df_orig)):
    uid = row['uid']
    country = row['country']
    year = row['year']
    is_rural = row['rural']
    orig_lat, orig_lon = row['lat'], row['lon']

    # Defaults describe a failed imputation that keeps the original
    # coordinates. The dict is appended immediately and filled in place, so
    # every exit path below records exactly one result for this row.
    location_result = {
        'uid': uid, 'imputed_lat': orig_lat, 'imputed_lon': orig_lon,
        'imputation_success': False, 'imputation_error': None
    }
    results.append(location_result)

    tfrecord_path = os.path.join(NL_DATA_BASE_DIR, f"{country}_{year}", f"{uid}.tfrecord.gz")
    if not os.path.exists(tfrecord_path):
        location_result['imputation_error'] = 'TFRecord file not found'
        continue

    try:
        # --- Load NL Patch (only the first record of the file is used) ---
        raw_dataset = tf.data.TFRecordDataset(tfrecord_path, compression_type='GZIP')
        serialized_example = next(iter(raw_dataset))
        example = tf.train.Example()
        example.ParseFromString(serialized_example.numpy())
        feature_map = example.features.feature

        if len(feature_map) != 1:
            location_result['imputation_error'] = f'Found {len(feature_map)} features, expected 1'
            continue

        feature_key = list(feature_map.keys())[0]
        pixel_values = feature_map[feature_key].float_list.value
        if not pixel_values:
            # Stop here so this error is reported as-is. Falling through to
            # sampling would relabel it as "no settlement pixels".
            location_result['imputation_error'] = 'TFRecord feature empty'
            continue
        if len(pixel_values) != EXPECTED_PIXELS:
            location_result['imputation_error'] = 'Pixel length mismatch'
            continue

        nl_image_patch = np.array(pixel_values, dtype=np.float32).reshape((IMG_HEIGHT, IMG_WIDTH))

        # --- Create Prior based on NL (P(x,y | NL)) ---
        # Binary mask: 1 where the night-light value exceeds the threshold.
        prior_map_nl = (nl_image_patch > NL_THRESHOLD).astype(np.float32)
        prior_map_nl[np.isnan(nl_image_patch)] = 0.0

        # --- Likelihood based on Displacement (P(x',y' | x,y) ~ 1/d) ---
        radius_meters = RADIUS_METERS_RURAL if is_rural == 1 else RADIUS_METERS_URBAN
        likelihood_map_displacement = displacement_likelihoods[radius_meters]

        # --- Calculate Posterior Probability Map ---
        # Both factors are non-negative, so the product is too. float64 keeps
        # the normalized probabilities summing to 1 within np.random.choice's
        # tolerance.
        posterior_map = (likelihood_map_displacement * prior_map_nl).astype(np.float64)

        # --- Sample from Posterior ---
        posterior_sum = posterior_map.sum()
        if not posterior_sum > 1e-9:
            location_result['imputation_error'] = 'No valid settlement pixels found within radius (weighted)'
            continue

        flat_probs = posterior_map.ravel() / posterior_sum
        chosen_flat_index = np.random.choice(flat_probs.size, p=flat_probs)
        chosen_row, chosen_col = np.unravel_index(chosen_flat_index, (IMG_HEIGHT, IMG_WIDTH))

        # Columns grow with +x; rows grow with -y, hence the flipped sign.
        # NOTE(review): assumes row 0 is the northern edge of the patch -
        # confirm against how the patches were exported.
        dx_pixels = chosen_col - center_col
        dy_pixels = center_row - chosen_row
        dx_meters = dx_pixels * PIXEL_SCALE
        dy_meters = dy_pixels * PIXEL_SCALE

        imputed_lat, imputed_lon = displace_lat_lon(orig_lat, orig_lon, dx_meters, dy_meters)

        location_result['imputed_lat'] = imputed_lat
        location_result['imputed_lon'] = imputed_lon
        location_result['imputation_success'] = True

    except StopIteration:
        # The dataset yielded no records at all.
        location_result['imputation_error'] = 'TFRecord file is empty'
    except Exception as e:
        # Per-row boundary: record the failure and keep processing other rows.
        location_result['imputation_error'] = f'Processing error: {str(e)}'

# 5. Create DataFrame from results and Merge
print("\nMerging imputed coordinates back into the original DataFrame...")
merge_cols = ['uid', 'imputed_lat', 'imputed_lon', 'imputation_success', 'imputation_error']
# Passing columns= fixes the schema even when `results` is empty; without it
# an empty list yields a column-less frame and selecting merge_cols raises
# KeyError.
df_results = pd.DataFrame(results, columns=merge_cols)
df_output = pd.merge(df_orig, df_results, on='uid', how='left')
# 'group_index' was only needed to build the UID.
df_output = df_output.drop(columns=['group_index'])

# 5a. Preserve original coordinates
print("Creating origin_lat/origin_lon columns...")
df_output['origin_lat'] = df_output['lat']
df_output['origin_lon'] = df_output['lon']

# 5b. Update lat/lon with imputed values where successful
print("Updating lat/lon columns with imputed values where successful...")
# .eq(True) also copes with missing values (rows with no merge match), which
# compare as False.
success_mask = (
    df_output['imputation_success'].eq(True)
    & df_output['imputed_lat'].notna()
    & df_output['imputed_lon'].notna()
)

df_output.loc[success_mask, 'lat'] = df_output.loc[success_mask, 'imputed_lat']
df_output.loc[success_mask, 'lon'] = df_output.loc[success_mask, 'imputed_lon']
print(f"Updated lat/lon for {success_mask.sum()} locations.")

# 7. Report Summary
# Failures are bucketed by substring match on the recorded error message:
# missing file, no eligible night-light pixel, and everything else.
print("\n--- Imputation Summary ---")
error_messages = df_output['imputation_error']
missing_file_mask = error_messages.str.contains('not found', na=False)
no_light_mask = error_messages.str.contains('No valid settlement', na=False)

num_success = df_output['imputation_success'].sum()
num_fail_file = missing_file_mask.sum()
num_fail_nolight = no_light_mask.sum()
# Whatever has an error message but fits neither bucket above.
num_fail_other = error_messages.notna().sum() - num_fail_file - num_fail_nolight

print("\n".join([
    f"Successfully imputed coordinates for: {num_success} locations",
    f"Failures due to missing TFRecord:   {num_fail_file}",
    f"Failures due to no NL signal in radius: {num_fail_nolight}",
    f"Failures due to other errors:       {num_fail_other}",
]))

if num_fail_other > 0:
    print("Sample other errors:")
    # na=True marks rows without a message as "known" so the negation below
    # excludes them.
    matches_known_failure = error_messages.str.contains('not found|No valid settlement', na=True)
    other_error_mask = (
        (df_output['imputation_success'] == False)
        & error_messages.notna()
        & ~matches_known_failure
    )
    print(df_output.loc[other_error_mask, ['uid', 'imputation_error']].head())

# 5c. Drop temporary imputation-related columns
print("Dropping temporary imputation columns...")
cols_to_drop = ['imputed_lat', 'imputed_lon', 'imputation_success', 'imputation_error']
# Only drop what is actually present so a partially built frame does not raise.
cols_to_drop_existing = [col for col in cols_to_drop if col in df_output.columns]
df_output.drop(columns=cols_to_drop_existing, inplace=True)

# 6. Save the output CSV
print(f"Saving imputed DataFrame to: {OUTPUT_CSV_PATH}")
output_dir = os.path.dirname(OUTPUT_CSV_PATH)
# A bare filename has an empty directory part and os.makedirs('') raises;
# in that case the file is written to the current working directory.
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
df_output.to_csv(OUTPUT_CSV_PATH, index=False)

print("\nImputation script finished.")

Loading original data from: ./data\senegal_shuffled.csv
Loaded 2165 records.
Generating correct UIDs based on within-year index...
UID generation complete.
Sample generated UIDs:
   country  year  group_index                 uid
0  senegal  1992            0  senegal_1992_00000
1  senegal  1992            1  senegal_1992_00001
2  senegal  1992            2  senegal_1992_00002
3  senegal  1992            3  senegal_1992_00003
4  senegal  1992            4  senegal_1992_00004
Processing TFRecords and imputing coordinates for 2165 locations...


  0%|          | 0/2165 [00:00<?, ?it/s]


Merging imputed coordinates back into the original DataFrame...
Creating origin_lat/origin_lon columns...
Updating lat/lon columns with imputed values where successful...
Updated lat/lon for 462 locations.

--- Imputation Summary ---
Successfully imputed coordinates for: 462 locations
Failures due to missing TFRecord:   1565
Failures due to no NL signal in radius: 138
Failures due to other errors:       0
Dropping temporary imputation columns...
Saving imputed DataFrame to: ./data\senegal_shuffled_coords_imputed_2.csv

Imputation script finished.
