In [12]:
import pandas as pd
import pygrib
import numpy as np
from datetime import datetime, timedelta
from scipy.spatial import cKDTree
from scipy.interpolate import griddata
import warnings
warnings.filterwarnings('ignore')

class GRIBClimateExtractor:
    """Look up GRIB field values at the points and times listed in a CSV file.

    Typical use is ``load_grib_data()`` once, followed by
    ``process_csv_batch()`` (or ``extract_point_data()`` for single points).

    Attributes:
        grib_file: Path handed to ``pygrib.open``.
        grib_data: Nested dict ``variable -> valid time -> level -> payload``
            where the payload holds ``data``, ``lats``, ``lons`` and ``units``.
        spatial_tree: ``cKDTree`` over the (lat, lon) pairs of the first
            loaded message; every other message is assumed to share that grid.
        grid_coords: ``(n_points, 2)`` array of (lat, lon) backing the tree.
        time_index: ``variable -> sorted list of valid times``.
    """

    # Values accepted by the ``interpolation`` argument.
    _INTERPOLATION_METHODS = ('nearest', 'bilinear')

    def __init__(self, grib_file):
        self.grib_file = grib_file
        self.grib_data = {}
        self.spatial_tree = None
        self.grid_coords = None
        self.time_index = {}

    @staticmethod
    def _to_naive_utc(value):
        """Return ``value`` as a timezone-naive datetime expressed in UTC.

        Strings are parsed with ``datetime.fromisoformat`` after a trailing
        ``Z`` or ``UTC`` marker is rewritten to ``+00:00``.  Aware datetimes
        are shifted by their UTC offset before the tzinfo is dropped, so that
        ``12:00+05:00`` becomes ``07:00`` instead of ``12:00``.  Naive values
        are returned unchanged.

        Raises:
            ValueError: If a string is not a valid ISO-8601 timestamp.
        """
        if isinstance(value, str):
            text = value.strip()
            for marker in ('Z', 'UTC'):
                if text.endswith(marker):
                    # Drop the marker *and* any space before it; a space in
                    # front of the offset is not accepted by every Python
                    # version's fromisoformat.
                    text = text[:-len(marker)].rstrip() + '+00:00'
                    break
            value = datetime.fromisoformat(text)

        if getattr(value, 'tzinfo', None) is not None:
            offset = value.utcoffset()
            value = value.replace(tzinfo=None)
            if offset is not None:
                value = value - offset
        return value

    @staticmethod
    def _nan_filled(data):
        """Return ``data`` with masked cells replaced by NaN.

        ``np.isnan`` ignores masks, so masked arrays must be converted before
        any NaN-based validity test.  Plain arrays are returned as-is.
        """
        if np.ma.isMaskedArray(data):
            return data.astype(float).filled(np.nan)
        return np.asarray(data)

    def load_grib_data(self, variables=None):
        """Load GRIB messages into memory and build the spatial index.

        Args:
            variables: Iterable of ``parameterName`` values to keep, or
                ``None`` to keep every variable found in the file.

        Raises:
            ValueError: If the file contains no GRIB messages.
        """
        print("🔄 Loading GRIB data...")

        grbs = pygrib.open(self.grib_file)
        try:
            # First pass: catalog all available data
            available_vars = set()
            available_times = set()
            for grb in grbs:
                available_vars.add(grb.parameterName)
                available_times.add(grb.validDate)

            if not available_times:
                raise ValueError(f"No GRIB messages found in {self.grib_file}")

            # Use specified variables or all available
            if variables is None:
                variables = list(available_vars)
            wanted = set(variables)

            print(f"📊 Loading variables: {variables}")
            print(f"📅 Time range: {min(available_times)} to {max(available_times)}")

            missing = wanted - available_vars
            if missing:
                print(f"⚠️  Variables not found in GRIB file: {sorted(missing)}")

            # Second pass: load data
            grbs.rewind()
            for grb in grbs:
                var_name = grb.parameterName
                if var_name not in wanted:
                    continue

                # Masked cells become NaN so later NaN checks see them.
                data = self._nan_filled(grb.values)
                lats, lons = grb.latlons()

                # Handle longitude format (convert to -180 to 180 if needed)
                if lons.max() > 180:
                    lons = np.where(lons > 180, lons - 360, lons)

                # Nested structure: variable -> time -> level
                per_time = self.grib_data.setdefault(var_name, {})
                per_level = per_time.setdefault(grb.validDate, {})
                per_level[grb.level] = {
                    'data': data,
                    'lats': lats,
                    'lons': lons,
                    'units': grb.parameterUnits
                }

                # Build spatial index from first variable/time (assuming same grid)
                if self.spatial_tree is None:
                    self.grid_coords = np.column_stack([lats.ravel(), lons.ravel()])
                    self.spatial_tree = cKDTree(self.grid_coords)
                    print(f"🗺️  Built spatial index with {len(self.grid_coords)} grid points")
        finally:
            # Release the file handle even if reading a message fails.
            grbs.close()

        # Build time index for each variable
        for var_name, per_time in self.grib_data.items():
            self.time_index[var_name] = sorted(per_time)

        print(f"✅ Loaded {len(self.grib_data)} variables")

    def find_nearest_time(self, target_time, var_name, max_hours_diff=6):
        """Find the loaded valid time closest to ``target_time``.

        Args:
            target_time: ISO-8601 string or datetime-like object.  Aware
                values are converted to UTC; naive values are compared as-is.
            var_name: Variable whose time index is searched.
            max_hours_diff: Largest accepted gap, in hours.

        Returns:
            ``(closest_time, hours_diff)``.  ``closest_time`` is the original
            key stored in ``time_index`` (usable to index ``grib_data``), or
            ``None`` when the gap exceeds ``max_hours_diff`` or no times are
            loaded (in which case ``hours_diff`` is NaN).

        Raises:
            KeyError: If ``var_name`` has no time index.
            ValueError: If ``target_time`` is an unparseable string.
        """
        available_times = self.time_index[var_name]
        if not available_times:
            return None, np.nan

        target = self._to_naive_utc(target_time)
        gaps = [
            abs((self._to_naive_utc(t) - target).total_seconds())
            for t in available_times
        ]
        # min() keeps the first of equally distant candidates.
        best = min(range(len(gaps)), key=gaps.__getitem__)

        hours_diff = gaps[best] / 3600
        if hours_diff > max_hours_diff:
            return None, hours_diff

        return available_times[best], hours_diff

    def extract_point_data(self, lat, lon, timestamp, variables, interpolation='nearest'):
        """Extract climate data for a single point.

        Args:
            lat: Latitude of the query point.
            lon: Longitude of the query point; values above 180 are shifted
                by -360 to match the convention applied to the grid at load.
            timestamp: Passed to ``find_nearest_time``.
            variables: Variable names to look up.
            interpolation: ``'nearest'`` (grid cell closest to the point) or
                ``'bilinear'`` (``griddata`` linear, falling back to nearest
                when the linear result is NaN).

        Returns:
            Dict ``variable -> {'value', 'units', 'time_diff_hours'}``.
            ``value`` is NaN when the variable is not loaded, no time lies
            within range, the coordinates are not finite, or interpolation
            is not possible.

        Raises:
            ValueError: If ``interpolation`` is not a supported method.
        """
        if interpolation not in self._INTERPOLATION_METHODS:
            raise ValueError(
                f"Unknown interpolation {interpolation!r}; "
                f"expected one of {self._INTERPOLATION_METHODS}"
            )

        if lon > 180:
            lon = lon - 360
        # cKDTree.query rejects non-finite input, so such rows yield NaN.
        has_coords = bool(np.isfinite(lat) and np.isfinite(lon))

        results = {}
        for var_name in variables:
            if var_name not in self.grib_data:
                results[var_name] = {'value': np.nan, 'units': '', 'time_diff_hours': np.nan}
                continue

            # Find nearest time
            nearest_time, time_diff = self.find_nearest_time(timestamp, var_name)
            if nearest_time is None:
                results[var_name] = {'value': np.nan, 'units': '', 'time_diff_hours': time_diff}
                continue

            # Take the first available level (surface fields carry only one).
            time_data = self.grib_data[var_name][nearest_time]
            data_info = time_data[next(iter(time_data))]
            data = data_info['data']

            if not has_coords:
                value = np.nan
            elif interpolation == 'nearest':
                _, flat_idx = self.spatial_tree.query([lat, lon], k=1)
                value = data[np.unravel_index(flat_idx, data.shape)]
            else:
                value = self._interpolate_linear(data, lat, lon)

            # filled() turns a masked scalar into NaN; ravel()[0] copes with
            # both 0-d and 1-element results from griddata.
            results[var_name] = {
                'value': float(np.ravel(np.ma.filled(value, np.nan))[0]),
                'units': data_info['units'],
                'time_diff_hours': time_diff
            }

        return results

    def _interpolate_linear(self, data, lat, lon):
        """Linearly interpolate ``data`` at (lat, lon), NaN when impossible."""
        flat_data = self._nan_filled(data).ravel()
        valid_mask = ~np.isnan(flat_data)

        if valid_mask.sum() < 4:  # Not enough points for interpolation
            return np.nan

        points = self.grid_coords[valid_mask]
        samples = flat_data[valid_mask]
        try:
            value = griddata(points, samples, (lat, lon), method='linear')
            if np.isnan(value):  # Outside the convex hull: use nearest
                value = griddata(points, samples, (lat, lon), method='nearest')
        except (ValueError, RuntimeError):
            # RuntimeError covers Qhull failures on degenerate point sets.
            return np.nan
        return value

    def process_csv_batch(self, csv_file, output_file, variables,
                         batch_size=10000, interpolation='nearest',
                         lat_col='Latitude', lon_col='Longitude',
                         time_col='Datetime'):
        """Process CSV file in batches.

        Each input row produces one output row holding ``original_index``,
        ``latitude``, ``longitude``, ``timestamp`` and, per variable,
        ``<name>_value``, ``<name>_units`` and ``<name>_time_diff_hours``.

        Args:
            csv_file: Input CSV path.
            output_file: Output CSV path (overwritten).
            variables: Variable names to extract.
            batch_size: Number of rows read and written per chunk.
            interpolation: See ``extract_point_data``.
            lat_col: Name of the latitude column in ``csv_file``.
            lon_col: Name of the longitude column in ``csv_file``.
            time_col: Name of the timestamp column in ``csv_file``.

        Returns:
            ``output_file``.

        Raises:
            KeyError: If a required column is missing from the CSV.
        """
        print(f"🔄 Processing CSV file: {csv_file}")
        print(f"📦 Batch size: {batch_size}")
        print(f"🎯 Variables to extract: {variables}")

        # Count lines for progress reporting (binary mode: no decoding needed).
        with open(csv_file, 'rb') as handle:
            total_rows = max(sum(1 for _ in handle) - 1, 0)  # Subtract header
        print(f"📊 Total rows to process: {total_rows:,}")

        first_batch = True
        processed_rows = 0

        # Process in chunks
        for chunk_num, df_chunk in enumerate(pd.read_csv(csv_file, chunksize=batch_size)):
            print(f"🔄 Processing batch {chunk_num + 1} ({len(df_chunk):,} rows)...")

            absent = [c for c in (lat_col, lon_col, time_col) if c not in df_chunk.columns]
            if absent:
                raise KeyError(f"Columns missing from {csv_file}: {absent}")

            results_list = []
            # Column-wise zip avoids building a Series per row (iterrows).
            rows = zip(df_chunk.index, df_chunk[lat_col],
                       df_chunk[lon_col], df_chunk[time_col])
            for idx, lat, lon, timestamp in rows:
                climate_data = self.extract_point_data(
                    lat, lon, timestamp, variables, interpolation
                )

                result_row = {
                    'original_index': idx,
                    'latitude': lat,
                    'longitude': lon,
                    'timestamp': timestamp
                }
                for var_name, var_data in climate_data.items():
                    result_row[f"{var_name}_value"] = var_data['value']
                    result_row[f"{var_name}_units"] = var_data['units']
                    result_row[f"{var_name}_time_diff_hours"] = var_data['time_diff_hours']

                results_list.append(result_row)

            # Write to file (append after first batch)
            pd.DataFrame(results_list).to_csv(
                output_file,
                mode='w' if first_batch else 'a',
                header=first_batch,
                index=False,
            )
            first_batch = False

            processed_rows += len(df_chunk)
            # Guard against a zero line count (e.g. header-only input).
            progress = (processed_rows / total_rows) * 100 if total_rows else 100.0
            print(f"✅ Batch {chunk_num + 1} complete. Progress: {progress:.1f}%")

        print(f"🎉 Processing complete! Results saved to: {output_file}")
        return output_file

# Usage functions
def quick_grib_info(grib_file):
    """Print and return a quick overview of a GRIB file's contents.

    Args:
        grib_file: Path handed to ``pygrib.open``.

    Returns:
        ``(variables, times)``: the sorted distinct ``parameterName`` values
        and the sorted distinct ``validDate`` values found in the file.  Both
        lists are empty when the file holds no messages.
    """
    variables = set()
    times = set()

    grbs = pygrib.open(grib_file)
    try:
        for grb in grbs:
            variables.add(grb.parameterName)
            times.add(grb.validDate)
    finally:
        # Release the handle even if a message cannot be decoded.
        grbs.close()

    variable_names = sorted(variables)
    time_steps = sorted(times)

    print("📊 GRIB File Overview:")
    print(f"Variables: {variable_names}")
    if time_steps:
        print(f"Time range: {time_steps[0]} to {time_steps[-1]}")
    else:
        # min()/max() would raise on an empty file; report it instead.
        print("Time range: (no messages found)")
    print(f"Total times: {len(time_steps)}")

    return variable_names, time_steps

def process_climate_data(grib_file, csv_file, output_file, 
                        variables=None, batch_size=10000, interpolation='nearest'):
    """Run the full pipeline: load a GRIB file, then annotate a CSV with it.

    When ``variables`` is ``None`` the GRIB file is inspected with
    ``quick_grib_info`` and the first three variable names (in sorted order)
    are used.  The chosen variables are loaded into a
    ``GRIBClimateExtractor`` and ``csv_file`` is processed in batches into
    ``output_file``.

    Returns:
        The value returned by ``process_csv_batch`` (the output path).
    """
    extractor = GRIBClimateExtractor(grib_file)

    if variables is None:
        # No explicit selection: list what the file offers and pick a default.
        overview = quick_grib_info(grib_file)
        available_vars = overview[0]
        print(f"\n🔍 Available variables: {available_vars}")
        variables = available_vars[:3]
        print(f"🎯 Using variables: {variables}")

    extractor.load_grib_data(variables)

    return extractor.process_csv_batch(
        csv_file, output_file, variables, batch_size, interpolation
    )

In [13]:
if __name__ == "__main__":
    # Input GRIB file, input CSV and destination CSV for this run.
    grib_file = "copernicus_era5land_data/azuay_era5land_abril2024.grib"
    csv_file = "raw_data/hash queryabril2024.csv"
    output_file = "transformed_data/hash queryabril2024_with-temp-precip.csv"

    # GRIB parameter names to extract.  Setting this to None instead lets
    # process_climate_data pick the first three variables found in the file.
    variables_to_extract = ['2 metre temperature', 'Total precipitation']

    try:
        result_file = process_climate_data(
            grib_file,
            csv_file,
            output_file,
            variables=variables_to_extract,
            batch_size=5000,  # rows per chunk; lower it if memory is tight
            interpolation='nearest',  # or 'bilinear'
        )

        print(f"\n🎉 Success! Climate data extracted to: {result_file}")

    except Exception as e:
        # Top-level boundary of the script: report the failure and stop.
        print(f"❌ Error: {e}")
        print("Make sure your file paths are correct and pygrib is installed.")

🔄 Loading GRIB data...
📊 Loading variables: ['2 metre temperature', 'Total precipitation']
📅 Time range: 2024-03-31 00:00:00 to 2024-04-30 00:00:00
🗺️  Built spatial index with 169 grid points
✅ Loaded 2 variables
🔄 Processing CSV file: raw_data/hash queryabril2024.csv
📦 Batch size: 5000
🎯 Variables to extract: ['2 metre temperature', 'Total precipitation']
📊 Total rows to process: 5,616,509
🔄 Processing batch 1 (5,000 rows)...
✅ Batch 1 complete. Progress: 0.1%
🔄 Processing batch 2 (5,000 rows)...
✅ Batch 2 complete. Progress: 0.2%
🔄 Processing batch 3 (5,000 rows)...
✅ Batch 3 complete. Progress: 0.3%
🔄 Processing batch 4 (5,000 rows)...
✅ Batch 4 complete. Progress: 0.4%
🔄 Processing batch 5 (5,000 rows)...
✅ Batch 5 complete. Progress: 0.4%
🔄 Processing batch 6 (5,000 rows)...
✅ Batch 6 complete. Progress: 0.5%
🔄 Processing batch 7 (5,000 rows)...
✅ Batch 7 complete. Progress: 0.6%
🔄 Processing batch 8 (5,000 rows)...
✅ Batch 8 complete. Progress: 0.7%
🔄 Processing batch 9 (5,000 r