In [None]:
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling
from rio_cogeo.cogeo import cog_translate
from rio_cogeo.profiles import cog_profiles
from rasterio.io import MemoryFile
import xarray as xr
import requests
import numpy as np
import rioxarray
from datetime import datetime, timedelta
import os
import pandas as pd
import fsspec
from google.cloud import storage
import json
import argparse
from dotenv import load_dotenv, find_dotenv

In [None]:
import os
from dotenv import load_dotenv, find_dotenv
import logging

# Emit INFO-and-above records with a timestamp and level prefix.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


# Fallback settings used by load_environment() for any variable that is not
# present in the environment.
DEFAULT_CONFIG = dict(
    SERVICE_ACCOUNT_KEY='coiled-data-e4drr.json',
    BUCKET_NAME='gefs-wgl',
    # Ordered as [EXTENT_X1, EXTENT_X2, EXTENT_Y1, EXTENT_Y2]: filter_param()
    # slices longitude with the first pair and latitude with the second pair.
    EXTENT=[21.85, 51.50, 23.14, -11.72],
    START_TIME='0h',
    END_TIME='21h',
    STORE='https://data.dynamical.org/noaa/gefs/forecast-35-day/latest.zarr',
)

def load_environment(env_file=None):
    """
    Load settings from a .env file and the process environment.

    Parameters:
    - env_file: optional path to a .env file. When it is omitted, or the path
      does not exist, the file located by find_dotenv() is loaded instead.

    Returns:
    - dict with the keys SERVICE_ACCOUNT_KEY, BUCKET_NAME, EXTENT, START_TIME,
      END_TIME and STORE. EXTENT is a list of four floats read from EXTENT_X1,
      EXTENT_X2, EXTENT_Y1 and EXTENT_Y2 (in that order). Any variable that is
      not set falls back to the matching DEFAULT_CONFIG entry.

    Raises:
    - ValueError: if one of the EXTENT_* variables is not a valid number.
    """
    if env_file and os.path.exists(env_file):
        print(f"Loading environment from: {env_file}")
        load_dotenv(env_file)
    else:
        if env_file:
            # A mistyped path used to be ignored without any notice.
            print(f"Warning: env file '{env_file}' not found, searching for a .env file instead")
        env_path = find_dotenv()
        if env_path:
            print(f"Loading environment from: {env_path}")
            load_dotenv(env_path)
        else:
            print("Warning: No .env file found, using default values")

    # Read the four bounding-box values, naming the offending variable when one
    # of them cannot be parsed.
    extent = []
    for position, suffix in enumerate(('X1', 'X2', 'Y1', 'Y2')):
        variable = f'EXTENT_{suffix}'
        raw_value = os.getenv(variable, str(DEFAULT_CONFIG['EXTENT'][position]))
        try:
            extent.append(float(raw_value))
        except ValueError as exc:
            raise ValueError(f"{variable} must be a number, got {raw_value!r}") from exc

    # Load configuration using DEFAULT_CONFIG as fallback values
    config = {
        'SERVICE_ACCOUNT_KEY': os.getenv('SERVICE_ACCOUNT_KEY', DEFAULT_CONFIG['SERVICE_ACCOUNT_KEY']),
        'BUCKET_NAME': os.getenv('BUCKET_NAME', DEFAULT_CONFIG['BUCKET_NAME']),
        'EXTENT': extent,
        'START_TIME': os.getenv('START_TIME', DEFAULT_CONFIG['START_TIME']),
        'END_TIME': os.getenv('END_TIME', DEFAULT_CONFIG['END_TIME']),
        'STORE': os.getenv('STORE', DEFAULT_CONFIG['STORE'])
    }

    return config

In [None]:
# No explicit .env path is given, so load_environment() falls back to find_dotenv().
config = load_environment()
config

In [None]:
"""
Create directory if it doesn't exist
"""
def ensure_dir(directory):
    """
    Create a directory, including missing parents, if it doesn't exist.

    Parameters:
    - directory: path of the directory to create

    Returns:
    - the same path that was passed in

    Raises:
    - FileExistsError: if the path exists but is not a directory
    """
    # exist_ok=True removes the gap between checking for the directory and
    # creating it, so a concurrent creator can no longer make this call fail.
    os.makedirs(directory, exist_ok=True)
    return directory

"""
Delete temporary file
"""
def delete_file(file_path):
    """
    Delete a file, doing nothing when it is already absent.

    Parameters:
    - file_path: path of the file to remove

    Returns:
    - None
    """
    # Attempt the removal directly instead of checking first, so a file that
    # disappears between the check and the removal cannot raise.
    try:
        os.remove(file_path)
    except (FileNotFoundError, NotADirectoryError):
        pass
    return

"""
Upload temporary file to Google Storage Bucket
"""
def upload_tif_to_gcs(bucket_name, source_file_path, destination_blob_name, service_account_path):
    """
    Upload a local file to a Google Cloud Storage bucket.

    Parameters:
    - bucket_name: name of the target bucket
    - source_file_path: local path of the file to upload
    - destination_blob_name: object name to create inside the bucket
    - service_account_path: path to the service-account JSON key file

    A storage client is built from the key file on every call.
    """
    with open(service_account_path, 'r') as key_file:
        credentials_info = json.load(key_file)

    gcs_client = storage.Client.from_service_account_info(credentials_info)
    target_blob = gcs_client.bucket(bucket_name).blob(destination_blob_name)
    target_blob.upload_from_filename(source_file_path)

    print(f"Uploaded to gs://{bucket_name}/{destination_blob_name}")


"""
Merge u & v tif files to single raster file (Band 1 & 2)
"""
def merge_uv(file1, file2, output_path):
    """
    Combine the first band of two rasters into one two-band raster.

    Parameters:
    - file1: raster whose band 1 becomes band 1 of the output
    - file2: raster whose band 1 becomes band 2 of the output
    - output_path: path of the raster to write

    Returns:
    - output_path
    """
    with rasterio.open(file1) as first_src, rasterio.open(file2) as second_src:
        layers = (first_src.read(1), second_src.read(1))

        # The output reuses the first input's profile; only the band count changes.
        merged_profile = first_src.profile
        merged_profile.update(count=2)

        with rasterio.open(output_path, 'w', **merged_profile) as merged:
            for band_index, layer in enumerate(layers, start=1):
                merged.write(layer, band_index)
    return output_path

def reproject_tif(src_path, dst_path):
    """
    Reproject a raster to EPSG:4326 and write the result with cog_translate.

    Parameters:
    - src_path: path of the raster to read
    - dst_path: path of the output file

    Every band is warped with nearest-neighbour resampling into an in-memory
    GeoTIFF (512x512 tiles, deflate compression, band interleaving), which is
    then translated to dst_path using the "deflate" COG profile.
    """
    target_crs = 'EPSG:4326'
    with rasterio.open(src_path) as source:
        warp_transform, warp_width, warp_height = calculate_default_transform(
            source.crs, target_crs, source.width, source.height, *source.bounds
        )

        warped_profile = source.profile.copy()
        warped_profile.update(
            crs=target_crs,
            transform=warp_transform,
            width=warp_width,
            height=warp_height,
            driver="GTiff",
            tiled=True,
            blockxsize=512,
            blockysize=512,
            compress="deflate",
            interleave="band",
        )

        with MemoryFile() as buffer:
            with buffer.open(**warped_profile) as warped:
                band_numbers = range(1, source.count + 1)
                for band_number in band_numbers:
                    reproject(
                        source=rasterio.band(source, band_number),
                        destination=rasterio.band(warped, band_number),
                        src_transform=source.transform,
                        src_crs=source.crs,
                        dst_transform=warp_transform,
                        dst_crs=target_crs,
                        resampling=Resampling.nearest
                    )

            # The in-memory dataset is closed at this point; the MemoryFile
            # itself is still open and is what cog_translate reads from.
            deflate_profile = cog_profiles.get("deflate")
            cog_translate(buffer, dst_path, deflate_profile, in_memory=True)


"""
Save dataset as tif
"""
def process_param_to_tif(e, t, param_data, file_path):
    """
    Save one ensemble member at one lead time as a GeoTIFF.

    Parameters:
    - e: ensemble_member label to select
    - t: lead time to select, as a timedelta value (it is converted with
      np.timedelta64(t, 'ns'))
    - param_data: data array indexed by lead_time and ensemble_member;
      assumes the remaining dimensions are (latitude, longitude) — TODO confirm
    - file_path: path of the GeoTIFF to write

    Returns:
    - file_path

    The slice is first written to file_path + "_tmp.tif", then passed through
    reproject_tif() to produce file_path. The intermediate file is removed
    afterwards, also when writing or reprojecting fails.
    """
    lead_time = np.timedelta64(t, 'ns')
    # Select with the timedelta itself. Formatting it as "<N>hr" truncated
    # fractional hours and depended on string parsing of the label.
    param_data_sel = param_data.sel(lead_time=lead_time, ensemble_member=e)

    param_array = xr.DataArray(
        param_data_sel.values[:, :].astype(np.float32),
        dims=['latitude', 'longitude'],
        coords={
            'latitude': param_data_sel.latitude.values,
            'longitude': param_data_sel.longitude.values
        }
    )
    param_array.rio.write_crs("EPSG:4326", inplace=True)

    tmp_file_path = file_path + "_tmp.tif"
    try:
        param_array.rio.to_raster(
            tmp_file_path,
            driver='COG',
            compress='LZW',
            dtype='float32',
            nodata=np.nan
        )
        reproject_tif(tmp_file_path, file_path)
    finally:
        # Never leave the intermediate file behind, even on failure.
        delete_file(tmp_file_path)
    return file_path

"""
Filter dataset for given parameter
"""
def filter_param(ds, param, config):
    """
    Select one variable, restrict it to the configured window and load it.

    Parameters:
    - ds: dataset to read from (already narrowed to one init_time by the caller)
    - param: name of the variable to extract
    - config: settings dict providing START_TIME, END_TIME and EXTENT

    Returns:
    - the result of .load() on the selected subset

    EXTENT[0]:EXTENT[1] is used as the longitude slice and EXTENT[2]:EXTENT[3]
    as the latitude slice.
    """
    variable = ds[param]
    window = {
        'lead_time': slice(config['START_TIME'], config['END_TIME']),
        'latitude': slice(config['EXTENT'][2], config['EXTENT'][3]),
        'longitude': slice(config['EXTENT'][0], config['EXTENT'][1]),
    }
    return variable.sel(**window).load()
"""
Process single parameter (Scalar)
"""
def process_scalar(ds, latest_init_time, param_obj, base_folder, config, init_date_folder=None):
    """
    Write and upload one GeoTIFF per ensemble member and lead time for a
    single-variable (scalar) parameter.

    Parameters:
    - ds: dataset to read from
    - latest_init_time: initialization time of the forecast; lead times are
      added to it to build the timestamp in each file name
    - param_obj: parameter description; param_obj['params'][0] is the variable name
    - base_folder: local folder the GeoTIFFs are written to
    - config: settings dict (BUCKET_NAME and SERVICE_ACCOUNT_KEY are read
      here, the rest by filter_param)
    - init_date_folder: bucket folder for the uploads. Defaults to the
      initialization date formatted as YYYYMMDD followed by 'z00'.

    Files are named <param>_<YYYYMMDDHH>.tif; ensemble members greater than 0
    get an extra _<member> suffix.
    """
    param_name = param_obj['params'][0]
    param_data = filter_param(ds, param_name, config)

    # Resolved once up front instead of on every loop iteration.
    init_datetime = pd.to_datetime(latest_init_time)
    if init_date_folder is None:
        # Previously read from a notebook global defined in a later cell.
        init_date_folder = init_datetime.strftime('%Y%m%d') + 'z00'

    for e in param_data.ensemble_member.values:
        for t in param_data.lead_time.values:
            forecast_datetime = init_datetime + t
            stamp = forecast_datetime.strftime('%Y%m%d%H')
            if e > 0:  # Add ensemble member number if not the first ensemble
                file_name = f"{param_name}_{stamp}_{e}.tif"
            else:
                file_name = f"{param_name}_{stamp}.tif"

            file_path = os.path.join(base_folder, file_name)

            process_param_to_tif(e, t, param_data, file_path)
            print(file_path)

            # Object names always use '/', whatever the local path separator is.
            destination_blob_path = f"{init_date_folder}/{file_name}"
            upload_tif_to_gcs(
                bucket_name=config['BUCKET_NAME'],
                source_file_path=file_path,
                destination_blob_name=destination_blob_path,
                service_account_path=config['SERVICE_ACCOUNT_KEY']
            )
            # NOTE(review): the uploaded file is kept locally; confirm whether
            # it should be deleted after a successful upload.

"""
Process vector parameter
"""
def process_vector(ds, latest_init_time, param_obj, base_folder, config, init_date_folder=None):
    """
    Write and upload one two-band GeoTIFF per ensemble member and lead time
    for a two-component (vector) parameter.

    Parameters:
    - ds: dataset to read from
    - latest_init_time: initialization time of the forecast; lead times are
      added to it to build the timestamp in each file name
    - param_obj: parameter description; param_obj['params'][0] and [1] are the
      two component variables, param_obj['id'] names the merged output
    - base_folder: local folder the GeoTIFFs are written to
    - config: settings dict (BUCKET_NAME and SERVICE_ACCOUNT_KEY are read
      here, the rest by filter_param)
    - init_date_folder: bucket folder for the uploads. Defaults to the
      initialization date formatted as YYYYMMDD followed by 'z00'.

    Each component is written to its own *_tmp.tif file, then merge_uv() puts
    the first component in band 1 and the second in band 2 of
    <id>_<YYYYMMDDHH>.tif. Ensemble members greater than 0 get an extra
    _<member> suffix. Ensemble members and lead times are taken from the
    first component.
    """
    u_name = param_obj['params'][0]
    v_name = param_obj['params'][1]
    param_data_u = filter_param(ds, u_name, config)
    param_data_v = filter_param(ds, v_name, config)

    # Resolved once up front instead of on every loop iteration.
    init_datetime = pd.to_datetime(latest_init_time)
    if init_date_folder is None:
        # Previously read from a notebook global defined in a later cell.
        init_date_folder = init_datetime.strftime('%Y%m%d') + 'z00'

    for e in param_data_u.ensemble_member.values:
        for t in param_data_u.lead_time.values:
            forecast_datetime = init_datetime + t
            stamp = forecast_datetime.strftime('%Y%m%d%H')

            # Per-component intermediates; overwritten on every iteration.
            file_path_u = os.path.join(base_folder, f"{u_name}_{stamp}_tmp.tif")
            file_path_v = os.path.join(base_folder, f"{v_name}_{stamp}_tmp.tif")

            process_param_to_tif(e, t, param_data_u, file_path_u)
            process_param_to_tif(e, t, param_data_v, file_path_v)

            if e > 0:  # Add ensemble member number if not the first ensemble
                file_name = f"{param_obj['id']}_{stamp}_{e}.tif"
            else:
                file_name = f"{param_obj['id']}_{stamp}.tif"

            file_path = os.path.join(base_folder, file_name)

            merge_uv(file_path_u, file_path_v, file_path)
            print(file_path)

            # Object names always use '/', whatever the local path separator is.
            destination_blob_path = f"{init_date_folder}/{file_name}"
            upload_tif_to_gcs(
                bucket_name=config['BUCKET_NAME'],
                source_file_path=file_path,
                destination_blob_name=destination_blob_path,
                service_account_path=config['SERVICE_ACCOUNT_KEY']
            )

            # NOTE(review): the merged file and both component files are kept
            # locally; confirm whether they should be deleted after upload.

def get_forecast_data(ds_global, date=None):
    """
    Get forecast data for a specific date or the latest available

    Parameters:
    - ds_global: xarray dataset containing forecast data
    - date: string in YYYYMMDD format, or None / "" for latest available

    Returns:
    - dataset subset to one initialization time, and that initialization time

    A valid date selects the initialization time closest to it, which is not
    necessarily on that date. A date that cannot be parsed as YYYYMMDD is
    reported and the latest initialization time is used instead.
    """
    init_values = ds_global.init_time.values

    # If no specific date is requested, use the latest
    if date is None or date == "":
        latest_init_time = init_values[-1]
        return ds_global.sel(init_time=latest_init_time), latest_init_time

    # Only the parsing is guarded, so a ValueError raised by the lookup further
    # down is no longer misreported as an invalid date format.
    try:
        target_date = pd.to_datetime(date, format="%Y%m%d")
    except ValueError:
        print(f"Warning: Invalid date format '{date}'. Using latest available date.")
        latest_init_time = init_values[-1]
        ds_subset = ds_global.sel(init_time=latest_init_time)
        print(f"Using latest initialization time: {pd.to_datetime(latest_init_time).strftime('%Y-%m-%d %H:%M')}")
        return ds_subset, latest_init_time

    # Find the closest initialization time to the requested date
    time_diffs = abs(pd.to_datetime(init_values) - target_date)
    selected_init_time = init_values[time_diffs.argmin()]
    ds_subset = ds_global.sel(init_time=selected_init_time)

    print(f"Requested date: {target_date.strftime('%Y-%m-%d')}")
    print(f"Using initialization time: {pd.to_datetime(selected_init_time).strftime('%Y-%m-%d %H:%M')}")

    return ds_subset, selected_init_time

        

In [None]:
# Open the forecast store named in the configuration.
ds_global = xr.open_zarr(
    config['STORE'],
    chunks=None,
    consolidated=True,
    decode_timedelta=True,
)

In [None]:
ds = ds_global

# Forecast date to process, in YYYYMMDD form. The dataset's last init_time was
# previously read here and then immediately overwritten by this value, so that
# dead assignment has been dropped.
# NOTE(review): the date is hard-coded; confirm whether the latest available
# initialization time should be used instead.
latest_init_time = '20250407'

# Format the initialization date as YYYYMMDDz00
init_datetime = pd.to_datetime(latest_init_time)
init_date_str = init_datetime.strftime('%Y%m%d') + 'z00'

In [None]:
# Narrow the global dataset to the initialization time closest to the requested date.
ds, init_time = get_forecast_data(ds_global, latest_init_time)

In [None]:
output_dir = './'

# Label the outputs with the initialization time that get_forecast_data()
# actually selected. The requested date can differ from it when the dataset has
# no exact match, which would otherwise mislabel the folder and file timestamps.
# NOTE(review): the 'z00' suffix is fixed regardless of the selected hour.
init_date_folder = pd.to_datetime(init_time).strftime('%Y%m%d') + 'z00'

# Create a local temporary directory for processing
temp_dir = os.path.join(output_dir, init_date_folder)
ensure_dir(temp_dir)

# Parameters to export: 'scalar' entries name one variable, 'vector' entries
# name two component variables that are merged into a two-band file.
params = [
    {
        'id': 'precipitation_surface',
        'type': 'scalar',
        'dType': 'float',
        'params': ['precipitation_surface']
    },
    {
        'id': 'wind_10m',
        'type': 'vector',
        'dType': 'float',
        'params': ['wind_u_10m', 'wind_v_10m']
    },
    {
        'id': 'temperature_2m',
        'type': 'scalar',
        'dType': 'float',
        'params': ['temperature_2m']
    }
]

print(f"Processing data for initialization date: {init_date_folder}")
print(f"Output folder: {init_date_folder}")

# The former try/finally wrapper only contained `pass`, so it has been removed.
for param_obj in params:
    print(f"Processing {param_obj['id']}...")
    if param_obj['type'] == 'vector':
        process_vector(ds, init_time, param_obj, temp_dir, config)
    else:
        process_scalar(ds, init_time, param_obj, temp_dir, config)

print("Processing complete!")