## Use Case

This notebook reads downloaded GRIB files and join weather attributes by location to each station, assigning grid data (df) to stations from station_df based on the nearest grid point.

Main function: assign_nearest_grid_values(df, station_df, grid_points, nearest_grid_idx)

Method: scipy.spatial.KDTree for nearest neighbor searches

In [47]:
# ! pip install xarray
# ! pip install cfgrib
import pandas as pd
import pygrib
import os
import numpy as np
import cfgrib 
import xarray as xr


In [58]:
# set path: input, output

input_dir = '../../data/GRIB'
output_dir = '../../data/GRIB/csv'


In [33]:
# Read GRIB file
def read_grib_file(file_path):
	grbs = pygrib.open(file_path)
	for grb in grbs:
		print(grb)
	grbs.close()
	
	return grbs

test_input_grib = '../../data/GRIB/2013/era5_land_20131229_2300.grib'


grbs = read_grib_file(test_input_grib)



with xr.open_dataset(test_input_grib) as ds:
	df = ds.to_dataframe()

1:2 metre temperature:K (instant):regular_ll:surface:level 0:fcst time 23 hrs:from 201312290000
2:10 metre U wind component:m s**-1 (instant):regular_ll:surface:level 0:fcst time 23 hrs:from 201312290000
3:10 metre V wind component:m s**-1 (instant):regular_ll:surface:level 0:fcst time 23 hrs:from 201312290000
4:Total precipitation:m (accum):regular_ll:surface:level 0:fcst time 22-23 hrs (accum):from 201312290000


In [39]:


df

Unnamed: 0_level_0,Unnamed: 1_level_0,number,time,step,surface,valid_time,t2m,u10,v10,tp
latitude,longitude,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
31.032,72.475,0,2013-12-29,0 days 23:00:00,0.0,2013-12-29 23:00:00,277.273468,-1.662138,1.093603,0.0
31.032,72.575,0,2013-12-29,0 days 23:00:00,0.0,2013-12-29 23:00:00,277.206085,-1.590117,1.204443,0.0
31.032,72.675,0,2013-12-29,0 days 23:00:00,0.0,2013-12-29 23:00:00,277.191437,-1.505644,1.310400,0.0
31.032,72.775,0,2013-12-29,0 days 23:00:00,0.0,2013-12-29 23:00:00,277.169952,-1.411161,1.422705,0.0
31.032,72.875,0,2013-12-29,0 days 23:00:00,0.0,2013-12-29 23:00:00,277.170929,-1.317411,1.519384,0.0
...,...,...,...,...,...,...,...,...,...,...
12.431,94.375,0,2013-12-29,0 days 23:00:00,0.0,2013-12-29 23:00:00,,,,
12.431,94.475,0,2013-12-29,0 days 23:00:00,0.0,2013-12-29 23:00:00,,,,
12.431,94.575,0,2013-12-29,0 days 23:00:00,0.0,2013-12-29 23:00:00,,,,
12.431,94.675,0,2013-12-29,0 days 23:00:00,0.0,2013-12-29 23:00:00,,,,


In [None]:

# df.head(30)

# df.to_csv('test.csv')

# in df, delete column: number, time, step, surfcace, valid_time

df = df.drop(columns=['number', 'time', 'step', 'surface', 'valid_time'])

# df.head(30)
date = test_input_grib.split('_')[2]
time = test_input_grib.split('_')[3].split('.')[0]

# add date and time to df

df['date'] = date 
df['time'] = time

df.head(30)



In [25]:
# check how many grids are there in df


# check index

df.index

# count number of unique latitudes
# lat_count = df['latitude'].nunique()
# lon_count = df['longitude'].nunique()

# print('lat_count:',lat_count,'lon_count:', lon_count)

# print('pixel number:',lat_count*lon_count)


MultiIndex([(31.032,            72.475),
            (31.032, 72.57499999999999),
            (31.032, 72.67499999999998),
            (31.032, 72.77499999999998),
            (31.032, 72.87499999999997),
            (31.032, 72.97499999999997),
            (31.032, 73.07499999999996),
            (31.032, 73.17499999999995),
            (31.032, 73.27499999999995),
            (31.032, 73.37499999999994),
            ...
            (12.431, 93.87499999999878),
            (12.431, 93.97499999999877),
            (12.431, 94.07499999999877),
            (12.431, 94.17499999999876),
            (12.431, 94.27499999999876),
            (12.431, 94.37499999999875),
            (12.431, 94.47499999999874),
            (12.431, 94.57499999999874),
            (12.431, 94.67499999999873),
            (12.431,            94.775)],
           names=['latitude', 'longitude'], length=41888)

In [61]:
# read station csv

station_file='../../data/STATION.csv'
station_df = pd.read_csv(station_file)

station_df.head(30)

# print lat and lon data type
# print(station_df['Lat'].dtype)


Unnamed: 0,station_id,Location name,ESMI_ID,From date,To date,District,State,Category,Connection Type,Lat,Lon
0,1,5th phase JP Nagar,,1/30/2017 0:00,6/12/2017 12:59,Bengaluru Urban,Karnataka,State Capital,Domestic,12.901092,77.58915
1,2,80 feet road,,9/6/2015 1:00,10/2/2018 14:59,Dhule,Maharashtra,District Headquarters,Domestic,20.895199,74.775982
2,3,Adarsh Nagar,,3/2/2016 0:00,1/22/2017 20:59,Saharsa,Bihar,District Headquarters,Domestic,25.883507,86.614919
3,4,Adgaon,,5/4/2018 0:00,12/31/2018 23:59,Nashik,Maharashtra,Gram Panchayat,Domestic,20.037184,73.850136
4,5,Agarchitti,,1/1/2017 0:00,6/12/2017 12:59,Chamoli,Uttarakhand,Gram Panchayat,Domestic,30.016116,79.308735
5,6,Agarwal Mohalla Jarwal Kasba,,1/26/2018 0:00,12/31/2018 23:59,Bahraich,Uttar Pradesh,Other Municipal Area,Domestic,27.574566,81.601251
6,7,Agri road Chidambaram,,4/28/2018 0:00,12/31/2018 23:59,Cuddalore,Tamil Nadu,Other Municipal Area,Domestic,11.390684,79.721755
7,8,Aishbagh,,11/3/2017 0:00,10/14/2018 9:59,Lucknow,Uttar Pradesh,State Capital,Domestic,26.841267,80.903863
8,9,Ajekar,,1/9/2015 17:00,5/15/2015 2:59,Udupi,Karnataka,Gram Panchayat,Domestic,13.320509,74.998807
9,10,Akoli Jehangir,,10/31/2014 23:00,3/8/2015 0:59,Akola,Maharashtra,Gram Panchayat,Domestic,21.154837,77.118466


In [63]:
import pandas as pd
import numpy as np
from scipy.spatial import KDTree

# Assuming df (grib data) has lat and lon as index
# And the columns are ['t2m', 'wind_u', 'wind_v']
# Station_df has ['station_lat', 'station_lon']

# Step 1: Extract the grid coordinates from the grib data
grid_points = np.array(list(zip(df.index.get_level_values('latitude'), df.index.get_level_values('longitude'))))

# Step 2: Build the KDTree
tree = KDTree(grid_points)

# Step 3: For each station coordinate, find the nearest grid point
station_coords = station_df[['Lat', 'Lon']].to_numpy()

# Query the nearest point in the KDTree
_, nearest_grid_idx = tree.query(station_coords)

# Step 4: Use the index to get the corresponding grid data
nearest_grid_points = grid_points[nearest_grid_idx]
print(nearest_grid_points)
print(type(nearest_grid_points))


[[12.932 77.575]
 [20.932 74.775]
 [25.932 86.575]
 ...
 [26.132 90.575]
 [26.732 91.675]
 [20.032 78.975]]
<class 'numpy.ndarray'>


In [64]:
df.index = pd.MultiIndex.from_tuples(
    [(round(lat, 3), round(lon, 3)) for lat, lon in df.index], 
)



In [67]:
# nearest_grid_points = grid_points[nearest_grid_idx]
df.index

# nearest_grid_values = df.loc[nearest_grid_points[:, 0], nearest_grid_points[:, 1]].reset_index()


# Step 5: Combine the nearest grid data with station data
# station_df = pd.concat([station_df.reset_index(drop=True), nearest_grid_values[['t2m', 'wind_u', 'wind_v']]], axis=1)
# preserve all df columns in nearest_grid_values
# The station_df now contains the nearest grid point attributes (t2m, wind_u, wind_v) for each station.


MultiIndex([(31.032, 72.475),
            (31.032, 72.575),
            (31.032, 72.675),
            (31.032, 72.775),
            (31.032, 72.875),
            (31.032, 72.975),
            (31.032, 73.075),
            (31.032, 73.175),
            (31.032, 73.275),
            (31.032, 73.375),
            ...
            (12.431, 93.875),
            (12.431, 93.975),
            (12.431, 94.075),
            (12.431, 94.175),
            (12.431, 94.275),
            (12.431, 94.375),
            (12.431, 94.475),
            (12.431, 94.575),
            (12.431, 94.675),
            (12.431, 94.775)],
           length=41888)

In [69]:
# Step 5: Convert nearest_grid_points to tuples (lat, lon)
nearest_grid_points = [(round(lat, 3), round(lon, 3)) for lat, lon in nearest_grid_points]

# Step 6: Make sure the index of the df is also rounded and in the same format
df.index = pd.MultiIndex.from_tuples(
    [(round(lat, 3), round(lon, 3)) for lat, lon in df.index], 
    names=['lat', 'lon']
)

# Step 7: Use the nearest grid points to get corresponding values from df
# Ensure nearest_grid_points is a list of tuples
nearest_grid_values = df.loc[nearest_grid_points].reset_index()

# Now you can join this with the station_df
station_df = pd.concat([station_df.reset_index(drop=True), nearest_grid_values[['t2m', 'u10', 'v10','tp','date','time']]], axis=1)

In [72]:
station_df.head(30)

station_df.to_csv('test_station.csv')

## Main Function

In [8]:
from scipy.spatial import KDTree
import numpy as np


In [31]:
def read_grib_file(file_path):
	grbs = pygrib.open(file_path)
	for grb in grbs:
		print(grb)
	grbs.close()
	
	return grbs

In [61]:

def build_kdtree(df):
    """
    Build a KDTree from the gridded dataset.
    
    Args:
        df (pd.DataFrame): DataFrame containing the grib data with lat and lon as the index.

    Returns:
        KDTree: A KDTree built on lat/lon coordinates.
        np.array: Grid points as an array of lat/lon pairs.
    """
    # Ensure lat/lon values are floats and create an array of lat/lon pairs
    grid_points = np.array(list(zip(df.index.get_level_values('latitude').astype(float), df.index.get_level_values('longitude').astype(float))))
    
    # Build and return the KDTree
    return KDTree(grid_points), grid_points


In [65]:
def find_nearest_grid_points(tree, station_df):
    """
    Find the nearest grid points for each station using the KDTree.
    
    Args:
        tree (KDTree): The KDTree built from grid points.
        station_df (pd.DataFrame): DataFrame containing station coordinates with 'station_lat' and 'station_lon' columns.

    Returns:
        np.array: Nearest grid points for each station in station_df.
    """
    # Extract station coordinates as numpy array
    station_coords = station_df[['Lat', 'Lon']].to_numpy()
    
    # Query the KDTree for nearest grid points
    _, nearest_grid_idx = tree.query(station_coords)
    
    return nearest_grid_idx


In [67]:
def assign_nearest_grid_values(df, station_df, grid_points, nearest_grid_idx):
    """
    Assign the nearest grid data (t2m, wind_u, wind_v) to each station in station_df.
    
    Args:
        df (pd.DataFrame): The grib data with lat/lon index and variables like t2m, wind_u, wind_v.
        station_df (pd.DataFrame): DataFrame containing station coordinates.
        grid_points (np.array): Grid points from the KDTree.
        nearest_grid_idx (np.array): Indices of the nearest grid points for each station.

    Returns:
        pd.DataFrame: station_df with grid data (t2m, wind_u, wind_v) appended.
    """
    # Get the nearest grid points as lat/lon tuples
    nearest_grid_points = [(round(lat, 3), round(lon, 3)) for lat, lon in grid_points[nearest_grid_idx]]
    
    # Ensure the index of df is also rounded and in tuple form
    df.index = pd.MultiIndex.from_tuples(
        [(round(lat, 3), round(lon, 3)) for lat, lon in df.index], 
        names=['lat', 'lon']
    )
    
    # Extract the grid data for the nearest points
    nearest_grid_values = df.loc[nearest_grid_points].reset_index()
    
    # Join the nearest grid values with the station data
    station_df = pd.concat([station_df.reset_index(drop=True), nearest_grid_values[['t2m', 'u10', 'v10','tp']]], axis=1)
    
    return station_df


In [37]:


def read_grib_file(file_path):
	grbs = pygrib.open(file_path)
	for grb in grbs:
		print(grb)
	grbs.close()
	
	return grbs

test_input_grib = '../../data/GRIB/2013/era5_land_20131229_2300.grib'


grbs = read_grib_file(test_input_grib)

with xr.open_dataset(file) as ds:
	df = ds.to_dataframe()


1:2 metre temperature:K (instant):regular_ll:surface:level 0:fcst time 23 hrs:from 201312290000
2:10 metre U wind component:m s**-1 (instant):regular_ll:surface:level 0:fcst time 23 hrs:from 201312290000
3:10 metre V wind component:m s**-1 (instant):regular_ll:surface:level 0:fcst time 23 hrs:from 201312290000
4:Total precipitation:m (accum):regular_ll:surface:level 0:fcst time 22-23 hrs (accum):from 201312290000


In [63]:
def read_grib_file(file_path):
	grbs = pygrib.open(file_path)
	for grb in grbs:
		print(grb)

	return grbs



In [69]:

input_dir = '/Users/rainylty/STUDY/fall24/1-GRA/Resilient_energy_network/data/GRIB/2022'
station_file='../../data/STATION.csv'
station_df = pd.read_csv(station_file)
output_dir = '/Users/rainylty/STUDY/fall24/1-GRA/Resilient_energy_network/data/csv/2022'

for file in os.listdir(input_dir):

	if  file.endswith('.grib'):
		print(file)

		grbs = read_grib_file(os.path.join(input_dir, file))
		
		with xr.open_dataset(test_input_grib) as ds:
			df = ds.to_dataframe()
		
			df = df.drop(columns=['number', 'time', 'step', 'surface', 'valid_time'])
			# get date and time
			date = file.split('_')[2]
			time = file.split('_')[3].split('.')[0]
			# add date and time to df
			
			df['date'] = date
			df['time'] = time

			print(df.head(10))

			
			# build KDTree
			tree, grid_points = build_kdtree(df)
			# find nearest grid points
			nearest_grid_idx = find_nearest_grid_points(tree, station_df)
			# assign nearest grid values

			station_df = assign_nearest_grid_values(df, station_df, grid_points, nearest_grid_idx)

			# save to csv
			station_df.to_csv(f'{output_dir}/{date}_{time}_station.csv')

			print(f'{output_dir}/{date}_{time}_station.csv')

			print('done')







era5_land_20220101_0100.grib
1:2 metre temperature:K (instant):regular_ll:surface:level 0:fcst time 1 hrs:from 202201010000
2:10 metre U wind component:m s**-1 (instant):regular_ll:surface:level 0:fcst time 1 hrs:from 202201010000
3:10 metre V wind component:m s**-1 (instant):regular_ll:surface:level 0:fcst time 1 hrs:from 202201010000
4:Total precipitation:m (accum):regular_ll:surface:level 0:fcst time 0-1 hrs (accum):from 202201010000
                           t2m       u10       v10   tp      date  time
latitude longitude                                                     
31.032   72.475     277.273468 -1.662138  1.093603  0.0  20220101  0100
         72.575     277.206085 -1.590117  1.204443  0.0  20220101  0100
         72.675     277.191437 -1.505644  1.310400  0.0  20220101  0100
         72.775     277.169952 -1.411161  1.422705  0.0  20220101  0100
         72.875     277.170929 -1.317411  1.519384  0.0  20220101  0100
         72.975     277.181671 -1.253202  1.580664  0.0

In [None]:

			tree, grid_points = build_kdtree(df)
			# find nearest grid points
			nearest_grid_idx = find_nearest_grid_points(tree, station_df)
			# assign nearest grid values

			station_df = assign_nearest_grid_values(df, station_df, grid_points, nearest_grid_idx)

			# save to csv
			station_df.to_csv(f'{output_dir}/{date}_{time}_station.csv')

			print(f'{output_dir}/{date}_{time}_station.csv')

			print('done')


In [None]:


		
	df = df.drop(columns=['number', 'time', 'step', 'surface', 'valid_time'])
	# get date and time
	date = file.split('_')[2]
	time = file.split('_')[3].split('.')[0]
	# add date and time to df
	
	df['date'] = date
	df['time'] = time
	# build KDTree

	tree, grid_points = build_kdtree(df)
	# find nearest grid points
	nearest_grid_idx = find_nearest_grid_points(tree, station_df)
	# assign nearest grid values

	station_df = assign_nearest_grid_values(df, station_df, grid_points, nearest_grid_idx)

	# save to csv
	station_df.to_csv(f'{output_dir}/{date}_{time}_station.csv')

	print(f'{output_dir}/{date}_{time}_station.csv')

	print('done')






In [None]:
from concurrent.futures import ProcessPoolExecutor
import pandas as pd
import numpy as np
from scipy.spatial import KDTree

# Function for the whole batch process
def process_batch(df, station_batch, tree, grid_points):
    """
    Processes a batch of stations, finding the nearest grid points and assigning grid data.
    
    Args:
        df (pd.DataFrame): The grib data with lat/lon index and variables.
        station_batch (pd.DataFrame): A subset of station data.
        tree (KDTree): A pre-built KDTree for grid points.
        grid_points (np.array): Array of grid points (lat, lon).
        
    Returns:
        pd.DataFrame: A batch of station data with nearest grid values assigned.
    """
    # Find nearest grid points for the batch of stations
    nearest_grid_idx = find_nearest_grid_points(tree, station_batch)
    
    # Assign the grid values for the batch
    return assign_nearest_grid_values(df, station_batch, grid_points, nearest_grid_idx)

# Main function to parallelize the process
def parallel_process(df, station_df, num_workers=4, batch_size=100):
    """
    Parallelizes the batch processing of station data using ProcessPoolExecutor.
    
    Args:
        df (pd.DataFrame): Gridded dataset.
        station_df (pd.DataFrame): Full station dataset.
        num_workers (int): Number of parallel workers.
        batch_size (int): Size of each batch to process.
        
    Returns:
        pd.DataFrame: The full station dataframe with nearest grid values assigned.
    """
    # Build the KDTree from the grid data
    tree, grid_points = build_kdtree(df)
    
    # Split station_df into batches
    station_batches = [station_df[i:i+batch_size] for i in range(0, len(station_df), batch_size)]
    
    # Create a process pool
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Submit tasks in parallel
        futures = [executor.submit(process_batch, df, batch, tree, grid_points) for batch in station_batches]
        
        # Collect the results
        results = [f.result() for f in futures]
    
    # Combine all the results into a single dataframe
    return pd.concat(results, ignore_index=True)


In [None]:
# Parallel process the station data
num_workers = 4  # Set the number of workers (depends on CPU)
batch_size = 100  # Adjust batch size as needed

station_df_with_grid_data = parallel_process(df, station_df, num_workers=num_workers, batch_size=batch_size)

# Now station_df_with_grid_data contains the station data along with nearest grid values.


In [12]:
import os
import pandas as pd

# Define the input and output directories
input_base_dir = '../../data/GRIB'
output_base_dir = '../../data/csv'

# Function to process each file
def process_grib_file(grib_file, station_df, df):
    """
    Process a single GRIB file: find nearest grid points and save station data to CSV.
    
    Args:
        grib_file (str): The path to the GRIB file.
        station_df (pd.DataFrame): The station data DataFrame.
        df (pd.DataFrame): The grid data DataFrame extracted from the GRIB file.
    
    Returns:
        pd.DataFrame: Station data with grid attributes for the given file.
    """
    # Here you should load the GRIB data into `df`
    # Placeholder for GRIB loading logic
    # df = load_grib_file(grib_file)  # You need to implement this function
    
    # Build the KDTree and process the stations
    tree, grid_points = build_kdtree(df)
    nearest_grid_idx = find_nearest_grid_points(tree, station_df)
    station_df_with_grid_data = assign_nearest_grid_values(df, station_df, grid_points, nearest_grid_idx)
    
    return station_df_with_grid_data

# Loop through each year folder
for year in range(2022, 2022):
    input_year_dir = os.path.join(input_base_dir, str(year))
    output_year_dir = os.path.join(output_base_dir, str(year))
    
    # Create the output directory if it doesn't exist
    os.makedirs(output_year_dir, exist_ok=True)
    
    # Loop through all GRIB files in the current year folder
    for grib_file in os.listdir(input_year_dir):
        if grib_file.endswith('.grib'):
            # Extract the timestamp (YYYYMMDD_HHMM) from the filename
            timestamp = grib_file.replace('era5_land_', '').replace('.grib', '')
            
            # Construct the output CSV file path
            output_csv_file = os.path.join(output_year_dir, f"era5_land_{timestamp}.csv")
            
            # Full path to the current GRIB file
            grib_file_path = os.path.join(input_year_dir, grib_file)
            
            # Placeholder for station_df (you need to load station data before the loop)
            station_df = load_station_data()  # Assuming you have station data loaded here
            
            # Placeholder for grib_df (grid data loaded from GRIB file)
            grib_df = load_grib_data(grib_file_path)  # You need to implement this function
            
            # Process the GRIB file and get the output
            station_df_with_grid_data = process_grib_file(grib_file_path, station_df, grib_df)
            
            # Save the resulting DataFrame to CSV
            station_df_with_grid_data.to_csv(output_csv_file, index=False)

            print(f"Processed and saved: {output_csv_file}")
