In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import psycopg2
import re
import configparser
from tqdm import tqdm
import multiprocessing
from functools import partial

class SpatialHistogramBuilder:
    """
    Build 2D grid histograms of spatial datasets stored in a database.

    Each dataset's universe (bounding box) is divided into an N x N grid and
    every cell stores the number of rows of ``<dataset>_mbr`` whose geometry
    satisfies ``&&`` against the cell envelope. Dataset sizes and universes
    are read from a statistics CSV file when the builder is created.
    """

    # Upper bound on the total number of grid cells (512 x 512)
    MAX_GRID_CELLS = 512 * 512

    # Universe returned when a bounding-box string cannot be parsed
    DEFAULT_BBOX = (-180, -90, 180, 90)

    # A decimal number, optionally signed and optionally in scientific notation
    _NUMBER = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
    _BBOX_RE = re.compile(
        r"BOX\(\s*(" + _NUMBER + r")\s+(" + _NUMBER + r")\s*,"
        r"\s*(" + _NUMBER + r")\s+(" + _NUMBER + r")\s*\)"
    )

    def __init__(self, grid_resolution_factor=0.1, data_dir="../../large_files"):
        """
        Initialize the histogram-based spatial estimator

        Parameters:
        -----------
        grid_resolution_factor : float
            The factor to determine grid resolution (0.1 = ~10% of data objects)
        data_dir : str
            Root directory under which ``traditional_methods/histogram`` and
            its ``results`` / ``visualizations`` sub-directories are created
        """
        self.grid_resolution_factor = grid_resolution_factor
        self.data_dir = data_dir
        self.universe_boundaries = {}
        self.dataset_sizes = {}

        # Create directories for saving histograms and results
        os.makedirs(self.histogram_dir, exist_ok=True)
        os.makedirs(os.path.join(self.histogram_dir, "results"), exist_ok=True)
        os.makedirs(os.path.join(self.histogram_dir, "visualizations"), exist_ok=True)

        # Load dataset metadata from spatial_statistics.csv
        self.load_spatial_statistics()

    @property
    def histogram_dir(self):
        """Directory below ``data_dir`` where histogram artifacts are written"""
        return os.path.join(self.data_dir, "traditional_methods", "histogram")

    def load_spatial_statistics(self, stats_path="../../spatial_statistics.csv"):
        """
        Load dataset information from the spatial statistics CSV file

        Fills ``universe_boundaries`` and ``dataset_sizes`` keyed by table
        name. A row whose object count is not a valid integer is skipped so
        that one bad row does not prevent the remaining datasets from being
        loaded. File-level problems are reported on stdout and leave the
        dictionaries as they were.

        Parameters:
        -----------
        stats_path : str
            Path of the CSV file holding the 'Table Name',
            'Total Spatial Objects' and 'Universe Limits (Bounding Box)'
            columns
        """
        name_col = 'Table Name'
        size_col = 'Total Spatial Objects'
        bbox_col = 'Universe Limits (Bounding Box)'

        try:
            stats_df = pd.read_csv(stats_path)

            missing = [col for col in (name_col, size_col, bbox_col)
                       if col not in stats_df.columns]
            if missing:
                raise KeyError(f"missing column(s) {missing}")

            rows = zip(stats_df[name_col], stats_df[size_col], stats_df[bbox_col])
            for table_name, total_objects, bbox_str in rows:
                # Validate the count first so both dictionaries stay in sync
                try:
                    size = int(total_objects)
                except (TypeError, ValueError, OverflowError):
                    print(f"Skipping {table_name}: invalid object count {total_objects!r}")
                    continue

                self.universe_boundaries[table_name] = self.parse_bbox(bbox_str)
                self.dataset_sizes[table_name] = size

            print(f"Loaded metadata for {len(self.universe_boundaries)} datasets")
        except Exception as e:
            # Best-effort load: report the problem and continue with what we have
            print(f"Error loading spatial statistics: {e}")

    def parse_bbox(self, bbox_str):
        """
        Parse a bounding box string into coordinates

        Parameters:
        -----------
        bbox_str : str
            Text containing ``BOX(xmin ymin,xmax ymax)``; numbers may be
            signed and may use scientific notation

        Returns:
        --------
        tuple
            ``(xmin, ymin, xmax, ymax)`` as floats, or the default
            ``(-180, -90, 180, 90)`` when the value is not a string or does
            not contain a well-formed box
        """
        # Empty CSV cells arrive as NaN (a float), not as a string
        if not isinstance(bbox_str, str):
            return self.DEFAULT_BBOX

        match = self._BBOX_RE.search(bbox_str)
        if match is None:
            return self.DEFAULT_BBOX
        return tuple(float(value) for value in match.groups())

    def connect_to_database(self, config_path="../../dataset_generation/config.ini"):
        """
        Connect to the database containing spatial data

        Parameters:
        -----------
        config_path : str
            INI file with a ``[database]`` section providing ``dbname``,
            ``user``, ``password``, ``host`` and ``port``

        Returns:
        --------
        connection or None
            An open psycopg2 connection, or None when the configuration has
            no ``[database]`` section or the connection attempt fails
        """
        config = configparser.ConfigParser()
        # ConfigParser.read() silently ignores unreadable files, so a missing
        # file shows up as a missing section
        config.read(config_path)
        if not config.has_section("database"):
            print(f"Database connection error: no [database] section found in {config_path}")
            return None
        db_params = config["database"]

        try:
            return psycopg2.connect(
                dbname=db_params["dbname"],
                user=db_params["user"],
                password=db_params["password"],
                host=db_params["host"],
                port=db_params["port"]
            )
        except psycopg2.Error as e:
            print(f"Database connection error: {e}")
            return None

    def _require_connection(self):
        """Return an open connection or raise ConnectionError if none is available"""
        conn = self.connect_to_database()
        if conn is None:
            raise ConnectionError("database connection could not be established")
        return conn

    def _grid_dimension(self, total_objects):
        """Return the number of cells per axis for a dataset of the given size (at least 1)"""
        num_grids = int(total_objects * self.grid_resolution_factor)
        # Cap the grid at MAX_GRID_CELLS and never go below a single cell, so
        # the cell width/height computation cannot divide by zero
        num_grids = max(1, min(num_grids, self.MAX_GRID_CELLS))
        return int(np.sqrt(num_grids))

    def _prepare_grid(self, dataset_name):
        """Validate the dataset name and return (universe, total_objects, side, empty grid)"""
        if dataset_name not in self.universe_boundaries:
            raise ValueError(f"Unknown dataset: {dataset_name}")

        universe = tuple(self.universe_boundaries[dataset_name])
        total_objects = self.dataset_sizes[dataset_name]
        side = self._grid_dimension(total_objects)

        print(f"Creating {side}x{side} grid ({side * side} cells)")
        return universe, total_objects, side, np.zeros((side, side), dtype=int)

    @staticmethod
    def _cell_bounds(i, j, origin_x, origin_y, cell_width, cell_height):
        """Return (min_x, min_y, max_x, max_y) of cell (i, j); i indexes x, j indexes y"""
        min_x = origin_x + i * cell_width
        min_y = origin_y + j * cell_height
        return (min_x, min_y, min_x + cell_width, min_y + cell_height)

    def _save_histogram(self, dataset_name, grid, universe, total_objects):
        """Write the grid, its metadata and a PNG visualization below ``histogram_dir``"""
        np.save(os.path.join(self.histogram_dir, f"{dataset_name}_histogram.npy"), grid)

        metadata = {
            'dimensions': grid.shape,
            'universe': universe,
            'objects': total_objects,
            'grid_resolution_factor': self.grid_resolution_factor,
            'model_size_bytes': grid.nbytes,
            'num_cells': grid.shape[0] * grid.shape[1]
        }
        # A dict is stored as a pickled object array; reading it back needs
        # np.load(..., allow_pickle=True)
        np.save(os.path.join(self.histogram_dir, f"{dataset_name}_metadata.npy"), metadata)

        self.visualize_histogram(dataset_name, grid, universe, save=True)

    def build_histogram(self, dataset_name, save=True, batch_size=100):
        """
        Build a 2D histogram for the specified dataset, one query per cell

        Parameters:
        -----------
        dataset_name : str
            Name of the dataset to build histogram for
        save : bool
            Whether to save the histogram to a file
        batch_size : int
            Unused by this method; kept so existing callers keep working.
            Use build_histogram_batch for batched queries.

        Returns:
        --------
        np.ndarray or None
            2D histogram grid, or None when building failed

        Raises:
        -------
        ValueError
            If dataset_name has no loaded metadata
        """
        print(f"Building histogram for {dataset_name}...")

        universe, total_objects, side, grid = self._prepare_grid(dataset_name)
        univ_xmin, univ_ymin, univ_xmax, univ_ymax = universe

        conn = None
        try:
            conn = self._require_connection()
            cursor = conn.cursor()

            # Calculate cell dimensions
            cell_width = (univ_xmax - univ_xmin) / side
            cell_height = (univ_ymax - univ_ymin) / side

            # Count objects that intersect with one cell. The table name is
            # interpolated, but dataset_name was checked against the loaded
            # metadata above.
            count_query = f"""
                SELECT COUNT(*) 
                FROM {dataset_name}_mbr 
                WHERE geometry && ST_MakeEnvelope(%s, %s, %s, %s, 4326)
            """

            with tqdm(total=side * side, desc="Processing grid cells") as pbar:
                for i in range(side):
                    for j in range(side):
                        bounds = self._cell_bounds(i, j, univ_xmin, univ_ymin,
                                                   cell_width, cell_height)
                        cursor.execute(count_query, bounds)
                        grid[i, j] = cursor.fetchone()[0]
                        pbar.update(1)

            if save:
                self._save_histogram(dataset_name, grid, universe, total_objects)

            print(f"Histogram built successfully for {dataset_name}")
            return grid

        except Exception as e:
            print(f"Error building histogram for {dataset_name}: {e}")
            return None
        finally:
            # Release the connection on both the success and the failure path
            if conn is not None:
                conn.close()

    def visualize_histogram(self, dataset_name, grid, universe, save=False):
        """
        Generate a visualization of the histogram

        Parameters:
        -----------
        dataset_name : str
            Name used in the plot title and in the output file name
        grid : np.ndarray
            2D histogram grid indexed as grid[x_index, y_index]
        universe : tuple
            (xmin, ymin, xmax, ymax) used as the plot extent
        save : bool
            If True, write a PNG into the visualizations directory and close
            the figure; otherwise show the figure
        """
        univ_xmin, univ_ymin, univ_xmax, univ_ymax = universe

        plt.figure(figsize=(12, 10))
        # Transpose so the first grid axis (x) runs horizontally
        plt.imshow(grid.T, cmap='viridis', origin='lower',
                   extent=[univ_xmin, univ_xmax, univ_ymin, univ_ymax],
                   aspect='auto', interpolation='nearest')
        plt.colorbar(label='Objects per cell')
        plt.title(f'Histogram for {dataset_name} - {grid.shape[0]}x{grid.shape[1]} grid')
        plt.xlabel('Longitude')
        plt.ylabel('Latitude')

        if save:
            out_path = os.path.join(self.histogram_dir, "visualizations",
                                    f"{dataset_name}_histogram.png")
            plt.savefig(out_path, dpi=150)
            plt.close()
        else:
            plt.show()

    def build_histogram_batch(self, dataset_name, save=True, batch_size=500):
        """
        Build a 2D histogram using batch processing for better efficiency

        Cells are written to a temporary table in batches and counted with a
        single join query per batch.

        Parameters:
        -----------
        dataset_name : str
            Name of the dataset to build histogram for
        save : bool
            Whether to save the histogram to a file
        batch_size : int
            Number of cells to process in each batch (must be >= 1)

        Returns:
        --------
        np.ndarray or None
            2D histogram grid, or None when building failed

        Raises:
        -------
        ValueError
            If dataset_name has no loaded metadata
        """
        print(f"Building histogram for {dataset_name} with batch processing...")

        universe, total_objects, side, grid = self._prepare_grid(dataset_name)
        univ_xmin, univ_ymin, univ_xmax, univ_ymax = universe

        conn = None
        try:
            # A non-positive batch size would otherwise skip every cell and
            # return an all-zero grid
            if batch_size < 1:
                raise ValueError(f"batch_size must be >= 1, got {batch_size}")

            conn = self._require_connection()
            cursor = conn.cursor()

            # Calculate cell dimensions
            cell_width = (univ_xmax - univ_xmin) / side
            cell_height = (univ_ymax - univ_ymin) / side

            # Create temporary table with non-conflicting column names
            cursor.execute("""
                DROP TABLE IF EXISTS temp_grid_cells;
                CREATE TEMP TABLE temp_grid_cells (
                    row_idx INT, 
                    col_idx INT, 
                    min_x FLOAT, 
                    min_y FLOAT,
                    max_x FLOAT,
                    max_y FLOAT
                );
            """)

            # LEFT JOIN + COUNT(m.geometry) keeps cells without any match (count 0).
            # The table name is interpolated, but dataset_name was checked
            # against the loaded metadata above.
            count_query = f"""
                SELECT c.row_idx, c.col_idx, COUNT(m.geometry) 
                FROM temp_grid_cells c
                LEFT JOIN {dataset_name}_mbr m
                  ON m.geometry && ST_MakeEnvelope(c.min_x, c.min_y, c.max_x, c.max_y, 4326)
                GROUP BY c.row_idx, c.col_idx
            """

            total_cells = side * side
            with tqdm(total=total_cells, desc="Processing grid cells") as pbar:
                for start_idx in range(0, total_cells, batch_size):
                    end_idx = min(start_idx + batch_size, total_cells)

                    # Clear temp table
                    cursor.execute("TRUNCATE temp_grid_cells")

                    # Flat index -> (i, j), followed by the cell envelope
                    batch_cells = []
                    for idx in range(start_idx, end_idx):
                        i, j = divmod(idx, side)
                        batch_cells.append(
                            (i, j) + self._cell_bounds(i, j, univ_xmin, univ_ymin,
                                                       cell_width, cell_height)
                        )

                    # Insert the whole batch with one multi-row INSERT
                    args_str = ','.join(
                        cursor.mogrify("(%s,%s,%s,%s,%s,%s)", cell).decode('utf-8')
                        for cell in batch_cells
                    )
                    cursor.execute(f"INSERT INTO temp_grid_cells VALUES {args_str}")

                    # Count objects that intersect with each cell in this batch
                    cursor.execute(count_query)
                    for row_idx, col_idx, count in cursor.fetchall():
                        grid[row_idx, col_idx] = count

                    # Update progress
                    pbar.update(end_idx - start_idx)

            if save:
                self._save_histogram(dataset_name, grid, universe, total_objects)

            print(f"Histogram built successfully for {dataset_name}")
            return grid

        except Exception as e:
            print(f"Error building histogram for {dataset_name}: {e}")
            import traceback
            traceback.print_exc()
            return None
        finally:
            # Release the connection on both the success and the failure path
            if conn is not None:
                conn.close()

# Main execution function
def build_all_histograms(resolution_factor=0.1):
    """
    Build histograms for all datasets in the spatial_statistics.csv file

    A dataset counts as failed when build_histogram_batch raises or returns
    None (its way of signalling an error it already reported).

    Parameters:
    -----------
    resolution_factor : float
        The factor to determine grid resolution (0.1 = ~10% of data objects)

    Returns:
    --------
    list
        Names of the datasets whose histogram could not be built (empty when
        everything succeeded)
    """
    # Create histogram builder
    builder = SpatialHistogramBuilder(grid_resolution_factor=resolution_factor)

    # Get all dataset names
    dataset_names = list(builder.universe_boundaries)

    print(f"Found {len(dataset_names)} datasets to process")

    failed = []
    for dataset_name in dataset_names:
        print(f"\nProcessing dataset: {dataset_name}")
        try:
            # Use the batch version which avoids column name conflicts
            grid = builder.build_histogram_batch(dataset_name, save=True)
        except Exception as e:
            print(f"Error processing {dataset_name}: {e}")
            grid = None

        if grid is None:
            failed.append(dataset_name)

    if failed:
        names = ', '.join(str(name) for name in failed)
        print(f"\n{len(failed)} of {len(dataset_names)} histograms failed: {names}")

    return failed

if __name__ == "__main__":
    # Build histograms for all datasets with default resolution factor
    failed = build_all_histograms(resolution_factor=0.1)

    # Only claim success when the driver reported no failures; a None return
    # value is treated as "nothing reported"
    if failed:
        print(f"Histogram building finished with {len(failed)} failure(s)")
    else:
        print("All histograms built successfully!")

Loaded metadata for 14 datasets
Found 14 datasets to process

Processing dataset: yago2
Building histogram for yago2 with batch processing...
Creating 512x512 grid (262144 cells)


Processing grid cells: 100%|██████████| 262144/262144 [1:22:50<00:00, 52.74it/s] 


Histogram built successfully for yago2

Processing dataset: craftwaysorted
Building histogram for craftwaysorted with batch processing...
Creating 104x104 grid (10816 cells)


Processing grid cells: 100%|██████████| 10816/10816 [00:00<00:00, 21252.66it/s]


Histogram built successfully for craftwaysorted

Processing dataset: zcta5
Building histogram for zcta5 with batch processing...
Creating 57x57 grid (3249 cells)


Processing grid cells: 100%|██████████| 3249/3249 [00:00<00:00, 19863.12it/s]


Histogram built successfully for zcta5

Processing dataset: areawater
Building histogram for areawater with batch processing...
Creating 478x478 grid (228484 cells)


Processing grid cells: 100%|██████████| 228484/228484 [00:13<00:00, 17369.57it/s]


Histogram built successfully for areawater

Processing dataset: aerowaythingnodesorted
Building histogram for aerowaythingnodesorted with batch processing...
Creating 89x89 grid (7921 cells)


Processing grid cells: 100%|██████████| 7921/7921 [00:00<00:00, 21664.82it/s]


Histogram built successfully for aerowaythingnodesorted

Processing dataset: emergencythingwaysorted
Building histogram for emergencythingwaysorted with batch processing...
Creating 284x284 grid (80656 cells)


Processing grid cells: 100%|██████████| 80656/80656 [00:03<00:00, 23602.39it/s]


Histogram built successfully for emergencythingwaysorted

Processing dataset: historicthingwaysorted
Building histogram for historicthingwaysorted with batch processing...
Creating 423x423 grid (178929 cells)


Processing grid cells: 100%|██████████| 178929/178929 [00:08<00:00, 21017.06it/s]


Histogram built successfully for historicthingwaysorted

Processing dataset: aerowaythingwaysorted
Building histogram for aerowaythingwaysorted with batch processing...
Creating 429x429 grid (184041 cells)


Processing grid cells: 100%|██████████| 184041/184041 [00:10<00:00, 17323.34it/s]


Histogram built successfully for aerowaythingwaysorted

Processing dataset: cyclewaythingwaysorted
Building histogram for cyclewaythingwaysorted with batch processing...
Creating 512x512 grid (262144 cells)


Processing grid cells: 100%|██████████| 262144/262144 [00:19<00:00, 13660.56it/s]


Histogram built successfully for cyclewaythingwaysorted

Processing dataset: powerthingwaysorted
Building histogram for powerthingwaysorted with batch processing...
Creating 512x512 grid (262144 cells)


Processing grid cells: 100%|██████████| 262144/262144 [01:11<00:00, 3649.49it/s]


Histogram built successfully for powerthingwaysorted

Processing dataset: leisurewaysorted
Building histogram for leisurewaysorted with batch processing...
Creating 512x512 grid (262144 cells)


Processing grid cells: 100%|██████████| 262144/262144 [01:36<00:00, 2721.48it/s]


Histogram built successfully for leisurewaysorted

Processing dataset: barrierthingwaysorted
Building histogram for barrierthingwaysorted with batch processing...
Creating 512x512 grid (262144 cells)


Processing grid cells: 100%|██████████| 262144/262144 [01:29<00:00, 2920.63it/s]


Histogram built successfully for barrierthingwaysorted

Processing dataset: powerthingnodesorted
Building histogram for powerthingnodesorted with batch processing...
Creating 512x512 grid (262144 cells)


Processing grid cells: 100%|██████████| 262144/262144 [01:07<00:00, 3910.79it/s]


Histogram built successfully for powerthingnodesorted

Processing dataset: arealm
Building histogram for arealm with batch processing...
Creating 113x113 grid (12769 cells)


Processing grid cells: 100%|██████████| 12769/12769 [00:00<00:00, 21789.96it/s]


Histogram built successfully for arealm
All histograms built successfully!
