In [1]:
import os
import h3
import json
import boto3
import shutil
import shapely
import numpy as np
import pandas as pd
import gemgis as gg
from time import time
from glob import glob
import rasterio as rio
import geopandas as gpd
import dask.dataframe as ddf
from rasterio.plot import show
from rasterio.mask import mask
import matplotlib.pyplot as plt
from rasterio.fill import fillnodata
from sklearn.metrics import r2_score
from rasterio.features import rasterize
from rasterio.enums import MergeAlg, Resampling
from shapely.geometry import box, mapping, Polygon

# Read me

### __This script contains all the functions required for processing weather, soil, and elevation data that have been queried from CSW. The functions cover shapefile-to-raster conversion, reading/writing raster data, and resampling and masking raster data.__

## Reference rasters and nodata

In [2]:
cities_california_28km_refraster = '../reference_rasters/cities_California_28km_ref_raster.tif'
cities_california_8km_refraster = '../reference_rasters/cities_California_8km_ref_raster.tif' 
cities_california_4km_refraster = '../reference_rasters/cities_California_4km_ref_raster.tif'  
cities_California_100m_refraster = '../reference_rasters/cities_California_100m_ref_raster.tif'

cities_california_buffer_28km_refraster = '../reference_rasters/cities_California_buffer_28km_ref_raster.tif' 
cities_california_buffer_8km_refraster = '../reference_rasters/cities_California_buffer_8km_ref_raster.tif'   
cities_california_buffer_4km_refraster = '../reference_rasters/cities_California_buffer_4km_ref_raster.tif'   

no_data_value = -9999

## Raster / Vector operations functions

In [3]:
def read_raster_arr_object(raster_file, rasterio_obj=False, band=1, get_file=True, change_dtype=True):
    """
    Get raster array and raster file.

    :param raster_file: Input raster filepath, or an already opened rasterio dataset if rasterio_obj=True.
    :param rasterio_obj: Set True if raster_file is a rasterio object. In that case only the array is
                         returned (get_file is ignored) and the dataset is left open for the caller.
    :param band: Selected band to read. Default set to 1.
    :param get_file: Set to False if raster file is not required. When False and the dataset was opened
                     by this function, it is closed before returning.
    :param change_dtype: Set to True if want to change raster data type to float32 and replace nodata
                         pixels with np.nan. Default set to True.

    :return: Raster numpy array and rasterio object file (get_file=True, rasterio_obj=False),
             otherwise only the raster numpy array.
    """
    opened_here = not rasterio_obj
    if opened_here:
        raster_file = rio.open(raster_file)
    else:
        get_file = False

    try:
        raster_arr = raster_file.read(band)
        if change_dtype:
            raster_arr = raster_arr.astype(np.float32)
            nodata = raster_file.nodata
            # Compare against None explicitly: a nodata value of 0 is valid but falsy,
            # so a plain truthiness test would leave those pixels unmasked.
            if nodata is not None:
                raster_arr[np.isclose(raster_arr, nodata)] = np.nan
    except Exception:
        # do not leave a dataset opened by this function dangling when reading fails
        if opened_here:
            raster_file.close()
        raise

    if get_file:
        # the caller receives the open dataset and is responsible for closing it
        return raster_arr, raster_file

    if opened_here:
        raster_file.close()
    return raster_arr


def write_array_to_raster(raster_arr, raster_file, transform, output_path, ref_file=None, nodata=no_data_value):
    """
    Write raster array to Geotiff format (single band).

    :param raster_arr: 2D raster array data to be written. Its shape and dtype define the output
                       height/width/dtype.
    :param raster_file: Original rasterio raster file containing geo-coordinates (its crs is used).
                        Ignored when ref_file is given.
    :param transform: Affine transformation matrix. Ignored when ref_file is given.
    :param output_path: Output filepath.
    :param ref_file: Filepath of a reference raster. If given, the crs and transform are taken from it
                     instead of from raster_file/transform.
    :param nodata: no_data_value set as -9999.

    :return: Output filepath.
    """
    if ref_file:
        # Only crs and transform are needed from the reference raster, so copy them out and
        # close the dataset immediately instead of leaving the handle open.
        with rio.open(ref_file) as ref_dataset:
            crs = ref_dataset.crs
            transform = ref_dataset.transform
    else:
        crs = raster_file.crs

    with rio.open(
            output_path,
            'w',
            driver='GTiff',
            height=raster_arr.shape[0],
            width=raster_arr.shape[1],
            dtype=raster_arr.dtype,
            count=1,
            crs=crs,
            transform=transform,
            nodata=nodata
    ) as dst:
        dst.write(raster_arr, 1)

    return output_path
   
    
def rasterize_shapefile(input_file, output_raster, attribute, ref_raster, date=None, grid_shapefile=None, 
                        merge_alg = MergeAlg.replace, dtype='float32', no_data_value=-9999, paste_on_ref_raster=False):
    """
    rasterize shapefile.
    
    params:
    input_file : Filepath of parquet file, or an already read geodataframe, holding the attribute. 
                 A parquet file should have a grid_id column to be matched with the grid_shapefile.
    output_raster : Filepath of output raster file.
    grid_shapefile : If a parquet filepath is given as input_file, filepath of grid/geometry shapefile. 
                     Should have a grid_id column to be matched with the parquet file.
                     Default set to None (not needed when input_file is already a geodataframe).
    attribute : Attribute column (in str) of parquet/gdf to rasterize. 
    ref_raster : Reference raster to be used in assigning rasterization shape, transform.
    date : Default set to None. Set to a date value if want to filter the parquet/gdf on its 'date' column 
           (specially for weather data).
    merge_alg : Rasterio merge algorithm. Can be either MergeAlg.replace (to replace value) or 
                MergeAlg.add (to add value to existing value). Default set to MergeAlg.replace.
    dtype : Data type of raster. Default set to Float32.
    no_data_value : No data value assigned to raster. Default set to -9999.
    paste_on_ref_raster : Set to True if want to paste rasterized values on a reference rasters. In this case, the raster will
                          have similar no_data pixels as reference rasters.              
    
    returns: The output raster filepath.
    """
    # Decide by type, not by substring: `'parquet' in <GeoDataFrame>` would test the column names.
    if isinstance(input_file, (str, os.PathLike)):
        gdf = read_parquet_as_geodataframe(parquet_file=input_file, grid_geometry_file=grid_shapefile, save=False)
    else:  # an already loaded geodataframe
        gdf = input_file
        
    if date is not None:
        gdf = gdf[gdf['date'] == date]

    ref_arr, ref_file = read_raster_arr_object(ref_raster)
    try:
        # (geometry, value) pairs burned into the raster
        input_shape = zip(gdf.geometry, gdf[attribute])

        raster_arr = rasterize(shapes=input_shape, out_shape=ref_arr.shape, fill=no_data_value, out=None, 
                               transform=ref_file.transform, all_touched=True, 
                               default_value=no_data_value, dtype=dtype, merge_alg=merge_alg)
    
        if paste_on_ref_raster:
            # ref_arr has NaN where the reference raster is nodata (see read_raster_arr_object)
            raster_arr[np.isnan(ref_arr)] = no_data_value
        
        write_array_to_raster(raster_arr, raster_file=ref_file, transform=ref_file.transform, 
                              output_path=output_raster, ref_file=None, nodata=no_data_value)
    finally:
        # the reference dataset is only needed for shape/transform/crs; release the handle
        ref_file.close()

    return output_raster



def resample_raster_based_on_ref_raster(input_raster, ref_raster, output_dir, raster_name, resampling_alg=Resampling.bilinear,
                                        paste_value_on_ref_raster=False):
    """
    Resample raster based on a reference raster.
    
    params:
    input_raster : Filepath of input raster to resample.
    ref_raster : Filepath of reference raster used to determine the resampled height/width, the affine 
                 transformation and the crs of the output.
    output_dir : Directory where the resampled raster is saved (created if needed).
    raster_name : File name of the resampled raster inside output_dir.
    resampling_alg : resampling algorithm. Can be Resampling.nearest/ Resampling.bilinear/Resampling.cubic or 
                     any resampling algorithm rasterio supports. Default set to Resampling.bilinear.
    paste_value_on_ref_raster : Set to True if want to have nodata pixels on the resampled raster similar to reference raster. 
    
    returns: The resampled output raster filepath.
    """
    makedirs([output_dir])
    
    ref_arr, ref_file = read_raster_arr_object(ref_raster)
    try:
        # target shape. use a reference raster (created using GIS for a specific region) to decide.
        resampled_height, resampled_width = ref_arr.shape

        # The input is closed before anything is written, so writing stays safe even when
        # the output path is the same file as input_raster.
        with rio.open(input_raster) as dataset:
            # resample data to target shape
            resampled_arr = dataset.read(1,
                                         out_shape=(1,
                                                    resampled_height,
                                                    resampled_width),
                                         resampling=resampling_alg)

        resampled_arr = resampled_arr.squeeze()  # removing the 1 (for count) from the dimension
        
        if paste_value_on_ref_raster:
            # ref_arr is NaN wherever the reference raster is nodata
            resampled_arr = np.where(np.isnan(ref_arr), -9999, resampled_arr)
        
        # Saving the resampled data
        output_raster = os.path.join(output_dir, raster_name)
        write_array_to_raster(raster_arr=resampled_arr, raster_file=ref_file, 
                              transform=ref_file.transform, output_path=output_raster, 
                              ref_file=None, nodata=-9999)
    finally:
        ref_file.close()
        
    return output_raster
    

def resample_raster_with_height_width(input_raster, height, width, ref_raster, output_dir, raster_name, 
                                      resampling_alg=Resampling.bilinear):
    """
    Resample raster to an explicitly given height and width.
    
    params:
    input_raster : Filepath of input raster to resample.
    height, width : integer of resampling height and width.
    ref_raster : Filepath of reference raster whose affine transformation and crs are written to the output.
    output_dir : Directory where the resampled raster is saved (created if needed).
    raster_name : File name of the resampled raster inside output_dir.
    resampling_alg : resampling algorithm. Can be Resampling.nearest/ Resampling.bilinear/Resampling.cubic or 
                     any resampling algorithm rasterio supports. Default set to Resampling.bilinear.
    
    returns: The resampled output raster filepath.
    """
    makedirs([output_dir])

    with rio.open(input_raster) as dataset:
        # resample data to target shape
        resampled_arr = dataset.read(1,
                                     out_shape=(1,
                                                height, width),
                                     resampling=resampling_alg)

    resampled_arr = resampled_arr.squeeze()  # removing the 1 (for count) from the dimension

    # Saving the resampled data
    output_raster = os.path.join(output_dir, raster_name)

    # Only the georeferencing of the reference raster is needed, so its pixel data is not read
    # and the dataset is closed as soon as the output has been written.
    # NOTE(review): the reference transform is written unchanged; if height/width differ from the
    # reference raster's own shape the output pixel size will not match the transform - confirm intended.
    with rio.open(ref_raster) as ref_file:
        write_array_to_raster(raster_arr=resampled_arr, raster_file=ref_file, 
                              transform=ref_file.transform, output_path=output_raster, 
                              ref_file=None, nodata=-9999)
        
    return output_raster
    
    
def mask_raster_array_by_shapefile(input_raster, mask_shape, output_dir=None, raster_name=None, invert=False,
                                   crop=True, save_masked_arr=False):
    """
    Mask a raster using a input shapefile.

    Parameters:
    input_raster: Input raster filepath.
    mask_shape : Reference shape file to crop input_raster. Only its first geometry is used as the mask.
    output_dir : Defaults to None. Set a output raster directory path if save_masked_arr is True.
    raster_name : Defaults to None. Set a output raster name if save_masked_arr is True.
    invert : If False (default) pixels outside shapes will be masked.
             If True, pixels inside shape will be masked.
    crop : Whether to crop the raster to the extent of the shapes. Set to False if invert=True is used.
    save_masked_arr : Set to true if want to save cropped/masked raster array. If True, must provide raster_name and
                       output_dir.

    returns : Masked raster array and masked raster filepath if save_masked_arr is True, 
              otherwise only the masked raster array.
    """
    # The dataset is opened directly (its full pixel array is not needed here) and is
    # closed automatically once masking/saving is done.
    with rio.open(input_raster) as input_file:
        shapefile = gpd.read_file(mask_shape)
        # NOTE(review): only the first feature of the shapefile is used; any further features are ignored.
        first_geom = shapefile['geometry'].values[0]
        geoms = [mapping(first_geom)]

        # masking
        masked_arr, masked_transform = mask(dataset=input_file, shapes=geoms, filled=True, crop=crop, invert=invert, 
                                            all_touched=False)
        masked_arr = masked_arr.squeeze()  # Remove axes of length 1 from the array

        if save_masked_arr:
            # naming output file
            makedirs([output_dir])
            output_raster = os.path.join(output_dir, raster_name)

            # saving output raster (crs comes from the input dataset, transform from the masking step)
            masked_raster = write_array_to_raster(raster_arr=masked_arr, raster_file=input_file, 
                                                  transform=masked_transform, output_path=output_raster)
            return masked_arr, masked_raster

    # in case raster is not saved return only masked raster array
    return masked_arr


def convert_point_geom_to_poly_geom_from_centroid(point_shapefile, output_shapefile, crs='EPSG:4326'):
    """
    Convert point shapefile to polygon shapefile. 
    ** The cell size will be figured out from distance between centroids.
    
    params:
    point_shapefile : Filepath of point shapefile.
    output_shapefile : Filepath of output shapefile.
    crs : Default crs set to 'EPSG:4326'. 
    
    return: The output polygon geometry filepath.

    raises: ValueError if the cell size cannot be inferred, i.e. no point lies more than 0.04 (in x)
            to the right of the first point.
    """
    point_gdf = gpd.read_file(point_shapefile)
    points = list(point_gdf['geometry'])

    # We are not directly inputting cell size as an argument because our calculated/presumed cell size
    # might not exact match with grid points' center-to-center distance. Instead, the cell size is taken
    # from the x-offset between the first point and the first later point lying more than 0.04 to its right.
    # 0.04 is the filter resolution, as the TWC/TWC precip/ERA5 resolutions are equal to or coarser than this.
    cell_size = None
    if points:
        first_x = points[0].x
        for point in points[1:]:
            x_offset = point.x - first_x
            if x_offset > 0.04:
                # x_offset is already positive; using it directly also stays correct when the two
                # x coordinates have opposite signs (where abs(abs(x2) - abs(x1)) would not).
                cell_size = x_offset
                break

    if cell_size is None:
        raise ValueError('Could not infer cell size from point_shapefile: no point lies more than '
                         '0.04 (in x) to the right of the first point.')

    # create the cells/polygons centred on each point
    half_cell = cell_size / 2
    grid_cells = [
        shapely.geometry.box(point.x - half_cell, point.y - half_cell,
                             point.x + half_cell, point.y + half_cell)
        for point in points
    ]

    poly_gdf = point_gdf.drop(columns=['geometry'])
    poly_gdf['geometry'] = grid_cells
    poly_gdf = gpd.GeoDataFrame(poly_gdf, geometry='geometry', crs=crs)
    poly_gdf.to_file(output_shapefile)

    return output_shapefile


def mask_datasets_with_shapefile(input_raster_dir, main_output_dir, mask_shape, exclude_datasets=None,
                                 resample=False, resample_target_raster=None):
    """
    Mask datasets in a folder. If input_raster_dir has sub-directories, the .tif files of each sub-directory 
    are masked. If there are no sub-directories, all the .tif files directly inside input_raster_dir are masked.
    
    input_raster_dir : Main directory filepath that has sub-directories with the .tif files. 
                      If there are no sub-directories, will collect the .tif files and process them.
    main_output_dir : The main output directory filepath where masked rasters will be saved. With sub-directories, 
                      each one gets a matching output sub-directory.
    mask_shape : Filepath of shapefile that will be used to mask the rasters.
    exclude_datasets : List of datasets (sub-directory names) to exclude from processing. Default set to None. 
    
    resample : Set to True if want to resample data to height and width of any target reference raster.
               In that case the masked rasters go into a 'masked' sub-folder and the resampled rasters take 
               their place in the output folder. Default set to False.
    resample_target_raster : Filepath of target raster to use if resampling. Default set to None.

    returns: None.
    raises: ValueError if resample is True but resample_target_raster is not given.
    """
    if resample and resample_target_raster is None:
        raise ValueError('resample_target_raster must be provided when resample=True.')

    def _mask_folder(raster_paths, final_output_dir):
        """Mask (and optionally resample) every raster in raster_paths into final_output_dir."""
        # Masked intermediates are kept apart from the resampled results so that the resampling
        # step never writes onto the file it is reading.
        masked_output_dir = os.path.join(final_output_dir, 'masked') if resample else final_output_dir

        for data in raster_paths:
            raster_name = os.path.basename(data).split('.')[0] + '.tif'
            _, masked_raster_fp = mask_raster_array_by_shapefile(input_raster=data, mask_shape=mask_shape, 
                                                                 output_dir=masked_output_dir, 
                                                                 raster_name=raster_name, invert=False,
                                                                 crop=True, save_masked_arr=True)
            if resample:
                # resampling to make sure the processed raster has the same height*width as the reference raster
                resample_raster_based_on_ref_raster(input_raster=masked_raster_fp, ref_raster=resample_target_raster, 
                                                    output_dir=final_output_dir, raster_name=raster_name)

    # Detect the layout explicitly. Relying on an exception does not work here: glob() on
    # "<file>.tif/*.tif" just returns an empty list, so a flat folder would silently do nothing.
    variables = [entry for entry in os.listdir(input_raster_dir)
                 if os.path.isdir(os.path.join(input_raster_dir, entry))]

    if variables:  # sub-directories inside the input_raster_dir
        if exclude_datasets is not None:
            variables = [i for i in variables if i not in exclude_datasets]

        for var in variables:
            print(f'Masking data for {var}...')
            datasets = glob(os.path.join(input_raster_dir, var, '*.tif'))
            _mask_folder(datasets, os.path.join(main_output_dir, var))

    else:  # input_raster_dir itself holds all the .tif files
        datasets = glob(os.path.join(input_raster_dir, '*.tif'))
        _mask_folder(datasets, main_output_dir)
                
                
def fill_nodata_by_interpolation(input_raster, reference_raster, output_raster,
                                 use_pixel_to_interp=100):
    """
    Interpolate values to fill nodata gaps in input raster.
    
    params:
    input_raster : Filepath of input raster. 
    reference_raster : Filepath of reference raster. The reference raster will be used to decide where nodata values 
                       will be filled/interpolated. 
                       NOTE(review): it is indexed against the input array, so it presumably must have the same 
                       shape as input_raster - confirm.
    output_raster : Filepath of output raster.
    use_pixel_to_interp : Passed to rasterio's fillnodata as max_search_distance. Default set to 100.
    
    returns :None.
    """
    data_arr, data_file = read_raster_arr_object(input_raster)
    try:
        ref_arr, ref_file = read_raster_arr_object(reference_raster)
        try:
            # changing no data from -9999 to np.nan. Otherwise, the masking operation and further interpolation doesn't work.
            data_arr[data_arr == data_file.nodata] = np.nan
            ref_arr[ref_arr == ref_file.nodata] = np.nan
        finally:
            # only the nodata value of the reference dataset is needed
            ref_file.close()

        # creating mask data: 1 marks valid pixels, 0 marks the locations that will be interpolated
        mask_arr = data_arr.copy()
        mask_arr[~np.isnan(data_arr)] = 1
        mask_arr[np.isnan(data_arr)] = 0

        # Interpolation of array to fill nodata locations
        interp_arr = fillnodata(data_arr, mask=mask_arr, max_search_distance=use_pixel_to_interp, smoothing_iterations=0)
        interp_arr[np.isnan(ref_arr)] = -9999  # interpolation causes some increase in extent. removing that with ref array
        # NOTE(review): pixels that are still NaN here (valid in the reference but not filled) are written
        # as NaN rather than as the -9999 nodata value - confirm this is intended.

        write_array_to_raster(raster_arr=interp_arr, raster_file=data_file, transform=data_file.transform, 
                              output_path=output_raster)
    finally:
        data_file.close()

## Database (parquet / geodataframe) operations

In [4]:
def add_geometry_to_h3(row):
    """
    Build the polygon geometry of the h3 cell stored in a dataframe row.
    
    params:
    row : dataframe row (or any mapping) holding the h3 index under the 'h3' key. 
    
    returns: shapely Polygon built from the boundary points of the h3 index.
    """
    h3_index = row['h3']
    # second positional argument is presumably h3's geo_json flag - confirm against the installed h3 version
    boundary_points = h3.h3_to_geo_boundary(h3_index, True)
    cell_polygon = Polygon(boundary_points)
    return cell_polygon


def read_h3_parquet_save_as_geodataframe(parquet_file, h3_geometry_file, save=False, output_folder=None, savename=None):
    """
    Join a parquet table carrying h3 indices with its h3 geometries and return the result as a geodataframe.
    
    params:
    parquet_file : Filepath of parquet file. Must have an 'h3' column.
    h3_geometry_file: Filepath of h3 geometry file. Must have an 'h3' column matching the parquet file and 
                      a 'geometry' column.
    save : Set to true if want to write the geodataframe to disk (output_folder/savename).
    output_folder : str of output folder to save the data. Default set to None.
    savename : str of name of the shapefile. Default set to None.
    
    returns: geopandas dataframe with data information.
    """
    h3_attributes = pd.read_parquet(parquet_file)
    h3_geometries = gpd.read_file(h3_geometry_file)

    # inner join: only h3 cells present in both tables are kept
    gdf_compiled = gpd.GeoDataFrame(
        h3_attributes.merge(h3_geometries, on='h3', how='inner'),
        geometry='geometry',
    )

    if not save:
        return gdf_compiled

    gdf_compiled.to_file(os.path.join(output_folder, savename))
    return gdf_compiled

def read_parquet_as_geodataframe(parquet_file, grid_geometry_file, save=False, output_folder=None, savename=None):
    """
    Join a parquet table carrying twc/era5 grid ids with the grid geometries and return it as a geodataframe.
    
    params:
    parquet_file : Filepath of parquet file. Must have a 'grid_id' column.
    grid_geometry_file: Filepath of twc/era5 grid geometry file. Must have a 'grid_id' column matching the 
                        parquet file and a 'geometry' column.
    save : Set to true if want to write the geodataframe to disk (output_folder/savename). 
           Large files can crash the kernel while saving; better to leave it False in that case.
    output_folder : str of output folder to save the data. Default set to None.
    savename : str of name of the shapefile. Default set to None.
    
    returns: geopandas dataframe with data information.
    
    """
    grid_attributes = pd.read_parquet(parquet_file)
    grid_geometries = gpd.read_file(grid_geometry_file)

    # inner join on grid_id: rows without a matching grid cell are dropped
    merged = grid_attributes.merge(grid_geometries, on='grid_id', how='inner')
    result_gdf = gpd.GeoDataFrame(merged, geometry='geometry')

    if save:
        destination = os.path.join(output_folder, savename)
        result_gdf.to_file(destination)

    return result_gdf


def clip_grids_by_admin(grids_file, admin_file, output_folder, savename):
    """
    Clip a shapefile/geodataframe with another shapefile/geodataframe.
    
    params:
    grids_file : shapefile path/geodataframe with twc_grid/era5_grid information. Must provide 'lon' and 'lat' 
                 columns, which are used to rebuild point geometries for the clipped grids.
    admin_file : shapefile path/geodataframe used to clip the grids_file.
    output_folder : str of output folder to save the data. 
    savename : str of name of the shapefile.
    
    returns: geopandas dataframe of clipped shapefile/geodataframe.
    """
    def _as_geodataframe(source):
        """Read source from disk if it is a filepath, otherwise use it as-is."""
        if isinstance(source, (str, os.PathLike)):
            return gpd.read_file(source)
        return source

    # Each input is resolved on its own, so paths and geodataframes can be mixed. The former
    # combined test (`'.shp' not in grids_file or admin_file`) never loaded file paths.
    grids_df = _as_geodataframe(grids_file)
    admin_df = _as_geodataframe(admin_file)
        
    clipped_gdf = gpd.clip(grids_df['geometry'], admin_df['geometry'])
    clipped_gdf = gpd.GeoDataFrame(clipped_gdf, geometry='geometry')
    clipped_gdf = clipped_gdf.join(grids_df, on=None, how='left', lsuffix='', rsuffix='R')  # merging lost grids_df info to the clipped grids 
    clipped_gdf = clipped_gdf.drop(columns=['geometry', 'geometryR'])
    # clipped geometries are replaced by points built from the lon/lat columns
    clipped_gdf = gpd.GeoDataFrame(clipped_gdf, geometry=gpd.points_from_xy(clipped_gdf.lon, clipped_gdf.lat), 
                                   crs="EPSG:4326")
    clipped_gdf = clipped_gdf.dropna()
    clipped_gdf = clipped_gdf.reset_index()
    
    makedirs([output_folder])
    savefile = os.path.join(output_folder, savename)
    clipped_gdf.to_file(savefile)
    
    return clipped_gdf

## Weather Data processing codes

In [6]:
def make_lat_lon_array_from_raster(input_raster, nodata=-9999):
    """
    Make lat, lon array for each pixel using the input raster.
    
    params:
    input_raster : Input raster filepath that will be used as reference raster.
    nodata : No data value. Pixels of band 1 equal to this value get this value in both output arrays. 
             Default set to -9999.
    
    returns: lon array and lat array (in that order), shaped like the raster, with the nodata value applied.
    """
    # read what is needed and close the dataset right away
    with rio.open(input_raster) as raster_file:
        raster_arr = raster_file.read(1)
        raster_transform = raster_file.transform

    # calculating x (lon), y (lat) of each pixel with rasterio's transform.xy
    height, width = raster_arr.shape
    cols, rows = np.meshgrid(np.arange(width), np.arange(height))
    xs, ys = rio.transform.xy(rows=rows, cols=cols, transform=raster_transform)
    
    # flattening and reshaping to the input_raster's array size
    lon_arr = np.array(xs).flatten().reshape(raster_arr.shape)
    lat_arr = np.array(ys).flatten().reshape(raster_arr.shape)
    
    # assigning no_data_value
    nodata_pixels = raster_arr == nodata
    lon_arr[nodata_pixels] = nodata
    lat_arr[nodata_pixels] = nodata
    
    return lon_arr, lat_arr


def process_twc_daily_data(twc_parquet_file, twc_geom_shp, ref_raster,
                           remove_cols = ['grid_id', 'date', 'index', 'elevation', 'time_zone', 'geometry', 'lat', 'lon'],
                           twc_output_dir='../../datasets/weather_raster_data/twc_data'):
    """
    Process TWC daily data for each column attribute. One raster is written per attribute and date 
    (<twc_output_dir>/<attribute>/<attribute>_<YYYYMMDD>.tif), plus one lon and one lat raster.
    
    params:
    twc_parquet_file : Filepath of parquet file with the attributes. The parquet file should have a 
                       grid_id column to be matched with the grid_shapefile, and a date column.
    twc_geom_shp : Filepath of grid/geometry shapefile. 
                   Should have a grid_id column to be matched with the parquet file.
    ref_raster : Filepath of reference raster to be used in shapefile to raster conversion. 
                 For example, california_4km_refraster as TWC's original grid size is 4km.
    remove_cols : Columns in the TWC file that will not be rasterized (the list is only read, never modified).
    twc_output_dir : Filepath of folder where rasterized data will be saved.

    returns: None.
    """
    twc_gdf = read_parquet_as_geodataframe(parquet_file=twc_parquet_file, grid_geometry_file=twc_geom_shp, 
                                           save=False, output_folder=None, savename=None)
    # creating output folder
    makedirs([twc_output_dir])
    
    # making list of columns to be rasterized
    keep_attr = [col for col in twc_gdf.columns if col not in remove_cols]
    
    # making list of unique dates in twc data
    dates = twc_gdf['date'].unique()
    
    # Looping through each attribute and each date to rasterize the data
    for attr_col in keep_attr:
        print(f'Processing TWC {attr_col} dataset...')
        
        # making new output directory for specific attribute
        process_to_dir = os.path.join(twc_output_dir, attr_col)  
        makedirs([process_to_dir])    
        
        for date in dates:
            # making raster name. String dates ('YYYY-MM-DD') only lose their dashes;
            # datetime-like dates are formatted explicitly so they work as well.
            if isinstance(date, str):
                date_str = ''.join(date.split('-'))
            else:
                date_str = pd.Timestamp(date).strftime('%Y%m%d')
            output_raster_fp = os.path.join(process_to_dir, f'{attr_col}_{date_str}.tif')

            # rasterization
            rasterize_shapefile(input_file=twc_gdf, grid_shapefile=None, attribute=attr_col, 
                                date=date, ref_raster=ref_raster, output_raster=output_raster_fp, 
                                merge_alg = MergeAlg.replace, dtype='float32', no_data_value=-9999)
    
    # making latitude and longitude rasters
    print(f'Processing lat, lon dataset...')
    lon_arr, lat_arr = make_lat_lon_array_from_raster(ref_raster)
    
    lon_dir = os.path.join(twc_output_dir, 'lon')
    lat_dir = os.path.join(twc_output_dir, 'lat')
    makedirs([lon_dir, lat_dir])

    # Only crs/transform of the reference raster are needed for writing, so it is opened
    # without reading its pixels and closed once both rasters are written.
    with rio.open(ref_raster) as ref_file:
        write_array_to_raster(raster_arr=lon_arr, raster_file=ref_file, transform=ref_file.transform, 
                              output_path=os.path.join(lon_dir, 'lon.tif'))
        write_array_to_raster(raster_arr=lat_arr, raster_file=ref_file, transform=ref_file.transform, 
                              output_path=os.path.join(lat_dir, 'lat.tif'))


def compile_twc_daily_data_to_dataframe(savename, twc_data_folder, output_folder):
    """
    Compile twc daily data in a dataframe. All datasets have to be of same shape, and every variable 
    folder has to hold the same number of daily rasters.
    
    params:
    savename : (str) Name of the output parquet file ('.parquet' is appended if missing).
    twc_data_folder : TWC daily data main folder. Each sub-directory is treated as one variable; its 
                      rasters must be named <variable>_<YYYYMMDD>.tif. 'lat' and 'lon' sub-directories hold 
                      a single raster each, which is repeated for every day.
    output_folder : Output folder where the parquet file is saved (created if needed).
    
    returns: compiled TWC dataframe (dask dataframe, rows with missing values dropped).
    """
    start_time = time()
    
    makedirs([output_folder])
    
    # making list of variables in the twc raster data folder (sub-directories only, stray files are ignored)
    variable_names = [name for name in os.listdir(twc_data_folder)
                      if os.path.isdir(os.path.join(twc_data_folder, name))]
    variable_paths = [os.path.join(twc_data_folder, folder) for folder in variable_names]

    def _count_daily_rasters(variable):
        """Number of .tif files inside the folder of the given variable."""
        return len(glob(os.path.join(twc_data_folder, variable, '*.tif')))

    # num_days will be used to repeat lat/lon data once per day.
    # Not every TWC dataset has max_temp or total_precip, so fall back step by step. A count of exactly
    # one day is valid and is kept (only a count of zero triggers the fallback).
    num_days = _count_daily_rasters('max_temp')
    if num_days == 0:
        num_days = _count_daily_rasters('total_precip')
    if num_days == 0:
        num_days = max((_count_daily_rasters(name) for name in variable_names if name not in ['lat', 'lon']),
                       default=0)
    
    variable_dict = {}  # a dictionary where daily dataset values will be stored under variable_name 
    
    for path in variable_paths:
        # sorted by file name, i.e. by date, so that all variables are compiled in the same order
        all_data = sorted(glob(os.path.join(path, '*.tif')))
        
        variable_name = os.path.basename(path).split('.')[0]  # extracted variable name 
        if variable_name not in ['lat', 'lon']:
            print(f'compiling data for {variable_name}...')

            # loop for reading datasets and storing pixel info in a dictionary
            for count, data in enumerate(all_data):
                # read data as array and flatten it
                data_arr = read_raster_arr_object(data, get_file=False).flatten()

                # extracting and formatting date info from the file name (<variable>_<YYYYMMDD>.tif)
                date = os.path.basename(data).split('.')[0].split('_')[-1]
                year, month, day = date[:4], date[4:6], date[6:]

                len_data = len(data_arr)  # number of pixels in each daily dataset (array)
                year_list = [int(year)] * len_data
                month_list = [int(month)] * len_data
                day_list = [int(day)] * len_data
                date_list = [int(date)] * len_data

                # Assigning all values to the variable_dict. The date columns are rebuilt from scratch for 
                # every variable (count == 0 resets them), so they always match the length of one variable.
                if count == 0:
                    variable_dict[variable_name] = list(data_arr)
                    variable_dict['date'] = date_list
                    variable_dict['year'] = year_list
                    variable_dict['month'] = month_list
                    variable_dict['day'] = day_list

                else:
                    variable_dict[variable_name].extend(list(data_arr))
                    variable_dict['date'].extend(date_list)
                    variable_dict['year'].extend(year_list)
                    variable_dict['month'].extend(month_list)
                    variable_dict['day'].extend(day_list)
        else:  # for lat/lon: a single raster repeated for every day
            data = glob(os.path.join(path, '*.tif'))[0]
            data_arr = read_raster_arr_object(data, get_file=False).flatten()  # read data as array and flattened it
            data_duplicated_for_days = list(data_arr) * num_days
            variable_dict[variable_name] = data_duplicated_for_days

    twc_variable_df = pd.DataFrame(variable_dict)
    twc_variable_ddf = ddf.from_pandas(twc_variable_df, npartitions=20)
    twc_variable_ddf = twc_variable_ddf.dropna()
    twc_variable_ddf = twc_variable_ddf.reset_index()
    
    if '.parquet' in savename:
        output_parquet_file = os.path.join(output_folder, savename)
    else:
        output_parquet_file = os.path.join(output_folder, savename+'.parquet')
    
    twc_variable_ddf.to_parquet(output_parquet_file)
    
    end_time = time()
    print('time taken', round((end_time-start_time)/60, 3), 'mins')

    return twc_variable_ddf
      
    
def process_era5_daily_data(era5_parquet_file, era5_geom_shp, ref_raster_rasterize,
                            ref_raster_resample, resampling_alg=Resampling.bilinear, 
                            remove_cols = ['grid_id', 'date', 'index', 'lat', 'lon', 'time_zone', 'geometry'],
                            era5_output_dir='../../datasets/raster_data/era5_data',
                            resampled_output_dir= '../../datasets/raster_data/era5_data/resampled_4km_rasters_nearest'):
    """
    Rasterize every ERA5 attribute for every date and resample each resulting raster.

    For each attribute column and each unique value of the 'date' column, one raster is written to
    <era5_output_dir>/original_28km_rasters/<attribute>/<attribute>_<YYYYMMDD>.tif and its resampled
    counterpart to <resampled_output_dir>/<attribute>/<attribute>_<YYYYMMDD>.tif.

    params:
    era5_parquet_file : Filepath of the parquet file holding the attributes. It should have a
                        grid_id column to be matched with the grid shapefile.
    era5_geom_shp : Filepath of the grid/geometry shapefile.
                    Should have a grid_id column to be matched with the parquet file.
    ref_raster_rasterize : Filepath of the reference raster (a 28km/0.25 deg raster) used in the
                           shapefile to raster conversion.
    ref_raster_resample : Filepath of the reference raster used in resampling.
    resampling_alg : Resampling algorithm, e.g. Resampling.nearest / Resampling.bilinear / Resampling.cubic
                     or any other algorithm rasterio supports. Default set to Resampling.bilinear.
    remove_cols : Columns of the ERA5 data that will not be rasterized.
    era5_output_dir : Folder where the rasterized data will be saved.
    resampled_output_dir : Folder where the resampled data will be saved. Kept as a parameter so that
                           outputs of different resampling algorithms can go to different folders.
                           Note that the default folder name ends in 'nearest' although the default
                           algorithm is bilinear.

    returns: None.
    """
    # attributes joined with their grid geometry; rows with any missing value are discarded
    gdf = read_parquet_as_geodataframe(parquet_file=era5_parquet_file, grid_geometry_file=era5_geom_shp,
                                       save=False, output_folder=None, savename=None).dropna()

    # root folders for the original and the resampled rasters
    original_root = os.path.join(era5_output_dir, 'original_28km_rasters')
    resampled_root = os.path.join(resampled_output_dir)
    makedirs([era5_output_dir, original_root, resampled_root])

    # columns to rasterize and the dates to rasterize them for
    attributes = [column for column in list(gdf.columns) if column not in remove_cols]
    unique_dates = gdf['date'].unique()

    for attribute in attributes:
        print(f'Processing ERA5 {attribute} dataset...')

        # every attribute gets its own sub-folder under both roots
        attribute_original_dir = os.path.join(original_root, attribute)
        attribute_resampled_dir = os.path.join(resampled_root, attribute)
        makedirs([attribute_original_dir, attribute_resampled_dir])

        for date in unique_dates:
            # 'YYYY-MM-DD' -> 'YYYYMMDD'; the same file name is used for both outputs
            compact_date = ''.join(date.split('-'))
            tif_name = f'{attribute}_{compact_date}.tif'

            rasterized_fp = rasterize_shapefile(input_file=gdf, grid_shapefile=None, attribute=attribute,
                                                date=date, ref_raster=ref_raster_rasterize,
                                                output_raster=os.path.join(attribute_original_dir, tif_name),
                                                merge_alg=MergeAlg.replace,
                                                dtype='float32', no_data_value=-9999)

            resample_raster_based_on_ref_raster(input_raster=rasterized_fp, resampling_alg=resampling_alg,
                                                ref_raster=ref_raster_resample,
                                                output_dir=attribute_resampled_dir, raster_name=tif_name)
            
            
def compile_era5_daily_data_to_multiple_dataframe(dataset_in_each_chunk, save_keyword, era5_data_folder, output_folder):
    
    """
    Compile ERA5 daily rasters into one parquet dataset per chunk of variables.

    Every sub-directory of era5_data_folder is treated as one variable. Variables are processed in
    chunks of dataset_in_each_chunk; 'lat' and 'lon' are appended to every chunk so that the chunk
    outputs can later be merged on date/lat/lon. Each pixel of each daily raster becomes one row.

    *** All input daily datasets have to be of same shape.
    *** If datasets are processed in chunks, multiple dataframes will be created and the user has to read
    them separately for further processing (use compile_era5_multiDF_to_singleDF() to compile them into a
    single dataframe).

    params:
    dataset_in_each_chunk : (int). Number of datasets to process in each chunk. If the dataset is small, give
                            the total number of variables to process everything in a single chunk.
    save_keyword : A keyword (str) to distinguish between 4km/8km parquet files. Can set to '4km'/'8km'.
    era5_data_folder : ERA5 daily data main folder. The code automatically picks up the sub-directories.
    output_folder : Output folder. Each chunk is written there as
                    '<save_keyword>_era5_daily_data_<chunk number>.parquet'.

    returns: None.
    """
    makedirs([output_folder])
    
    # every entry of the era5 raster data folder is taken as a variable name, then split into chunks
    variable_names = os.listdir(era5_data_folder)
    variable_chunks = [variable_names[x:x+dataset_in_each_chunk] for x in range(0, len(variable_names), dataset_in_each_chunk)]

    # number of daily rasters; the static (lat/lon/...) rasters are repeated this many times.
    # NOTE(review): the count is taken from the 'total_precip' sub-folder only. If that folder is missing
    # this is 0 and the static columns end up empty - confirm every run includes 'total_precip'.
    num_days = len(glob(os.path.join(era5_data_folder, 'total_precip', '*.tif')))
    
    # first loop: one iteration per chunk of variables
    for num, chunk in enumerate(variable_chunks):
        start_time = time()
        
        # Removing lat/lon if they were selected in this chunk, then adding lat+lon at the end of each chunk
        # for merging purposes with TWC. If all data are processed in one chunk this isn't required, but it
        # keeps every chunk homogeneous.
        if 'lat' in chunk:
            chunk.remove('lat')
        if 'lon' in chunk:
            chunk.remove('lon')
        chunk.extend(['lat', 'lon'])
        
        print(f'processing for {chunk}..')
        
        variable_paths = [os.path.join(era5_data_folder, folder) for folder in chunk]  
    
        variable_dict = {}  # daily pixel values are stored here under the variable name
    
        for path in variable_paths:
            all_data = glob(os.path.join(path, '*.tif')) # list of all rasters in this variable's folder
            all_data = sorted(all_data)  # sort by file name (date) so all variables are compiled in the same order
            
            # the folder name is the variable name
            variable_name = os.path.basename(path).split('.')[0]   
            print(f'compiling data for {variable_name}...')

            if variable_name not in ['lat', 'lon', 'elevation', 'slope', 'aspect']:
                # daily variables: read every raster and append its pixels to the dictionary
                for count, data in enumerate(all_data):
                    # retrieving the data
                    data_arr = read_raster_arr_object(data, get_file=False).flatten()  # read data as array and flatten it
            
                    # the date is the last '_'-separated token of the file name and must parse as an integer
                    date = os.path.basename(data).split('.')[0].split('_')[-1]

                    len_data = len(data_arr)  # number of pixels in each daily dataset (array)
                    date_list = [int(date)] * len_data

                    # The first raster of a variable (re)starts both lists, so 'date' is rebuilt for every
                    # daily variable; this relies on all variables covering the same dates in the same order.
                    if count == 0:
                        variable_dict[variable_name] = list(data_arr)  # start the flattened-value list for this variable
                        variable_dict['date'] = date_list

                    else:
                        variable_dict[variable_name].extend(list(data_arr))  # append the flattened values of this day
                        variable_dict['date'].extend(date_list)

            else: # static data (lat/lon/elevation/slope/aspect): a single raster repeated for every day
                data = glob(os.path.join(path, '*.tif'))[0]
                data_arr = read_raster_arr_object(data, get_file=False).flatten()  # read data as array and flatten it
                data_duplicated_for_days = list(data_arr) * num_days
                variable_dict[variable_name] = data_duplicated_for_days

        # build the dataframe, drop rows with any NaN and move the index into a regular column
        era5_variable_df = pd.DataFrame(variable_dict)
        era5_variable_ddf = ddf.from_pandas(era5_variable_df, npartitions=20)
        era5_variable_ddf = era5_variable_ddf.dropna()
        era5_variable_ddf = era5_variable_ddf.reset_index()

        output_parquet_file = os.path.join(output_folder, f'{save_keyword}_era5_daily_data_{num}.parquet')
        era5_variable_ddf.to_parquet(output_parquet_file)

        end_time = time()
        print('time taken', round((end_time-start_time)/60, 3), 'mins')

        
def compile_era5_multiDF_to_singleDF(parquet_folder, output_folder, save_keyword):
    """
    Compile multiple dataframes of era5 data (generated by compile_era5_daily_data_to_multiple_dataframe())
    into a single dataframe by merging them on ['date', 'lat', 'lon'].

    params:
    parquet_folder : Filepath of folder where multiple parquet files (dataframes) are saved.
    output_folder : Filepath of output folder where the single parquet file (dataframe) with all era5
                    variables will be saved as '<save_keyword>_era5_daily_data.parquet'. Created if missing.
    save_keyword : A keyword (str) to distinguish between 4km/8km parquet files. Can set to '4km'/'8km'.

    returns: Compiled single dataframe.

    raises: FileNotFoundError if parquet_folder holds no '*.parquet' input.
    """
    output_parquet = os.path.join(output_folder, f'{save_keyword}_era5_daily_data.parquet')

    # Sorted so that the merge order (and therefore the column order) is reproducible. A compiled file
    # left behind by an earlier run is skipped, otherwise it would be merged back into itself whenever
    # parquet_folder and output_folder are the same directory.
    parquet_files = [parq for parq in sorted(glob(os.path.join(parquet_folder, '*.parquet')))
                     if os.path.abspath(parq) != os.path.abspath(output_parquet)]
    if not parquet_files:
        raise FileNotFoundError(f'No parquet files found in {parquet_folder!r}')

    compiled_df = None
    for parq in parquet_files:
        # 'index' is the column produced by reset_index() when the chunks were written;
        # inputs without that column are accepted as well
        df = pd.read_parquet(parq).drop(columns=['index'], errors='ignore')
        if compiled_df is None:
            compiled_df = df
        else:
            compiled_df = compiled_df.merge(df, on=['date', 'lat', 'lon'])

    if output_folder:
        os.makedirs(output_folder, exist_ok=True)
    compiled_df.to_parquet(output_parquet)

    return compiled_df
        
    
def combine_twc_era5_datasets(twc_dataset, era5_dataset, output_file, merge_on=['date', 'lat', 'lon'], how='inner'):
    """
    Combine TWC and ERA5 datasets into one dataframe and save it as a parquet file.

    params:
    twc_dataset : TWC dataframe filepath or dataframe.
    era5_dataset : ERA5 dataframe filepath or dataframe.
                   The two inputs do not have to be of the same kind.
    output_file : Combined TWC and ERA5 dataframe output filepath.
    merge_on : List of columns to use in dataframe merging. Default set to ['date', 'lat', 'lon'].
    how : Type of merging. Default set to 'inner'.

    returns: Combined TWC and ERA5 dataframe. Columns present in both inputs (other than the merge
             columns) get the suffixes '_twc' and '_era5'.
    """
    # each input is resolved on its own, so a dataframe can be combined with a filepath
    twc_df = twc_dataset if isinstance(twc_dataset, pd.DataFrame) else pd.read_parquet(twc_dataset)
    era5_df = era5_dataset if isinstance(era5_dataset, pd.DataFrame) else pd.read_parquet(era5_dataset)

    twc_era5_combined = twc_df.merge(era5_df, on=merge_on, suffixes=('_twc', '_era5'), how=how)
    twc_era5_combined.to_parquet(output_file)

    return twc_era5_combined
            
            
def resample_weather_datasets_to_100m(variables_to_resample, input_data_main_dir, resampled_output_main_dir, 
                                     target_raster):
    """
    Resample weather datasets (TWC/ERA5) to 100m resolution using a reference raster of 100m resolution.

    params:
    variables_to_resample : List of weather variables to resample.
    input_data_main_dir : Data main directory path. The sub-directory of each variable is selected by the code.
    resampled_output_main_dir : Resampled output data main directory. A sub-directory per variable is
                                created by the code.
    target_raster : 100m ref raster. Can set to cities_California_100m_refraster.

    returns: None.
    """
    # resampling only the required variables
    for var in variables_to_resample:
        print(f'Resampling data for {var}...')
        variable_dir = os.path.join(input_data_main_dir, var)
        all_rasters = glob(os.path.join(variable_dir, '*.tif'))

        # create the per-variable folder the rasters are written to (this also creates the main
        # output directory), instead of only the main directory
        resampled_output_folder = os.path.join(resampled_output_main_dir, var)
        makedirs([resampled_output_folder])

        for raster in all_rasters:
            raster_name = os.path.basename(raster).split('.')[0] + '.tif'
            resample_raster_based_on_ref_raster(input_raster=raster, ref_raster=target_raster,
                                                output_dir=resampled_output_folder, raster_name=raster_name,
                                                resampling_alg=Resampling.bilinear)

## Soil Data Processing

In [7]:
def create_point_from_h3(row):
    """
    Look up the boundary coordinates of the H3 cell referenced by a dataframe row.
    The result can be further processed into a polygon.

    params:
    row : dataframe row holding the H3 cell index under the 'h3' key.

    returns: Boundary point coords of the h3 index, as returned by h3.h3_to_geo_boundary()
             called with True as its second argument.
    """
    h3_index = row['h3']
    return h3.h3_to_geo_boundary(h3_index, True)

def compile_soil250_data_in_single_dataframe(input_data_dir, output_parquet, search_by):
    """
    Compiles multiple soil250 parquet files for individual aoi/regions into a single dataframe.

    params:
    input_data_dir : Input directory filepath where the parquet files are located.
    output_parquet : Filepath of output parquet file.
    search_by : Glob pattern selecting the inputs. For example use "*V1*.parquet" for soil250 V1 parquet
                files and "*V2*.parquet" for soil250 V2 parquet files.

    returns: Compiled single geodataframe (EPSG:4326, one polygon per h3 cell), which is also saved
             as a parquet file.

    raises: FileNotFoundError if no file in input_data_dir matches search_by.
    """
    soil250_datasets = glob(os.path.join(input_data_dir, search_by))
    if not soil250_datasets:
        raise FileNotFoundError(f'No files matching {search_by!r} found in {input_data_dir!r}')

    # read everything first and concatenate once; growing the dataframe inside the loop
    # copies all previously read rows again on every iteration
    final_df = pd.concat([pd.read_parquet(each) for each in soil250_datasets])

    # adding polygon geometry based on h3
    final_df['points'] = final_df.apply(create_point_from_h3, axis=1)
    final_df['geometry'] = final_df['points'].apply(Polygon)
    final_df = final_df.drop(columns=['points'])
    final_gdf = gpd.GeoDataFrame(final_df, geometry='geometry')
    final_gdf = final_gdf.set_crs('EPSG:4326')

    final_gdf.to_parquet(output_parquet)
    return final_gdf


def process_soil250_data(soil250_file, ref_raster, output_dir,
                         remove_cols = ['h3', 'hid', 'hids', 'geometry'],
                         interpolation_distance=30):
    """
    Rasterize every soil250 attribute and fill the nodata gaps of each raster by interpolation.

    Depth-averaged columns (average_awct, average_wwp, average_nitrogen, average_soc) are added
    when the first depth column of the respective group is present, and are rasterized as well.
    For each attribute the rasterized data goes to <output_dir>/<attribute>/rasterized/<attribute>.tif
    and the gap-filled raster to <output_dir>/<attribute>/<attribute>.tif.

    params:
    soil250_file : Filepath of the parquet file with the attributes and the h3-based geometry.
    ref_raster : Filepath of the reference raster used in the shapefile to raster conversion and in the
                 gap filling, e.g. a 100m reference raster to bring soil250 data to 100m.
    output_dir : Filepath of folder where the processed data will be saved.
    remove_cols : Columns in the soil250 file that will not be rasterized. Also list the attributes
                  at individual depths here if only their averages are needed.
    interpolation_distance : Integer value of how many surrounding pixels to use in interpolation.

    returns: None.
    """
    soil250_gdf = gpd.read_parquet(soil250_file)
    makedirs([output_dir])

    # depth-averaged column -> the per-depth columns it is computed from;
    # the first per-depth column decides whether the group is available
    depth_groups = {
        'average_awct': ['awct_0cm', 'awct_5cm', 'awct_15cm', 'awct_30cm'],
        'average_wwp': ['wwp_0cm', 'wwp_5cm', 'wwp_15cm', 'wwp_30cm'],
        'average_nitrogen': ['nit0_5', 'nit5_15', 'nit15_30'],
        'average_soc': ['soc0_5', 'soc5_15', 'soc15_30'],
    }
    columns_on_load = list(soil250_gdf.columns)
    for average_name, depth_columns in depth_groups.items():
        if depth_columns[0] in columns_on_load:
            soil250_gdf[average_name] = soil250_gdf[depth_columns].mean(axis=1)

    # everything that is not excluded gets rasterized, including the averages added above
    attributes = [column for column in list(soil250_gdf.columns) if column not in remove_cols]

    for attribute in attributes:
        print(f'Processing Soil250 {attribute} dataset...')

        attribute_dir = os.path.join(output_dir, attribute)
        rasterized_dir = os.path.join(output_dir, attribute, 'rasterized' )
        makedirs([rasterized_dir])

        # rasterization
        rasterized_fp = rasterize_shapefile(input_file=soil250_gdf, grid_shapefile=None, attribute=attribute,
                                            date=None, ref_raster=ref_raster,
                                            output_raster=os.path.join(rasterized_dir, f'{attribute}.tif'),
                                            merge_alg=MergeAlg.replace, dtype='float32', no_data_value=-9999)

        # interpolate nodata pixels/gaps; the gap-filled raster sits one level above the rasterized one
        fill_nodata_by_interpolation(input_raster=rasterized_fp, reference_raster=ref_raster,
                                     output_raster=os.path.join(attribute_dir, f'{attribute}.tif'),
                                     use_pixel_to_interp=interpolation_distance)

## Elevation Data Processing

In [None]:
def compile_elevation_data_in_single_dataframe(input_data_dir, output_parquet, search_by='*.parquet'):
    """
    Compiles multiple elevation parquet files for individual aoi/regions into a single dataframe.

    params:
    input_data_dir : Input directory filepath where the parquet files are located.
    output_parquet : Filepath of output parquet file.
    search_by : Glob pattern selecting the inputs. Default set to '*.parquet'.

    returns: Compiled single geodataframe (EPSG:4326, one polygon per h3 cell), which is also saved
             as a parquet file.

    raises: FileNotFoundError if no file in input_data_dir matches search_by.
    """
    elevation_datasets = glob(os.path.join(input_data_dir, search_by))
    if not elevation_datasets:
        raise FileNotFoundError(f'No files matching {search_by!r} found in {input_data_dir!r}')

    # read everything first and concatenate once; growing the dataframe inside the loop
    # copies all previously read rows again on every iteration
    final_df = pd.concat([pd.read_parquet(each) for each in elevation_datasets])

    # adding polygon geometry based on h3
    final_df['points'] = final_df.apply(create_point_from_h3, axis=1)
    final_df['geometry'] = final_df['points'].apply(Polygon)
    final_df = final_df.drop(columns=['points'])
    final_gdf = gpd.GeoDataFrame(final_df, geometry='geometry')
    final_gdf = final_gdf.set_crs('EPSG:4326')

    final_gdf.to_parquet(output_parquet)
    return final_gdf


def calculate_slope_aspect(dem_filepath, slope_raster_fp, aspect_raster_fp):
    """
    Derive slope and aspect rasters (degrees, per the original author's note) from a DEM and save them.

    params:
    dem_filepath : DEM raster filepath.
    slope_raster_fp : Filepath the slope raster is written to.
    aspect_raster_fp : Filepath the aspect raster is written to.

    returns: None.
    """
    # only the raster file object is needed; gemgis computes from it directly
    _, dem_file = read_raster_arr_object(dem_filepath)

    # both derivatives are computed before anything is written
    derived_rasters = (
        (gg.raster.calculate_slope(dem_file), slope_raster_fp),
        (gg.raster.calculate_aspect(dem_file), aspect_raster_fp),
    )

    # saving slope first, then aspect, each into its own (possibly new) folder
    for terrain_arr, target_fp in derived_rasters:
        makedirs([os.path.dirname(target_fp)])
        write_array_to_raster(raster_arr=terrain_arr, raster_file=dem_file, transform=dem_file.transform,
                              output_path=target_fp)

## Satellite Data Processing

In [8]:
def save_raster_as_single_band(input_raster, output_dir,  
                               scale=False, scale_factor=0.001, 
                               add_value=False, subtract_value=False, value=None, 
                               change_dtype=False,  reformat_date=True,
                               nodata=no_data_value):
    """
    Read a raster with multiple bands and save only band 1. Can perform datatype change, scaling and
    offsetting if enabled.

    params:
    input_raster : Input raster filepath.
    output_dir : Output directory filepath where processed rasters will be saved.
    scale : Set to True if data need to be scaled. Default set to False.
    scale_factor : Scaling factor to multiply with if scale=True.
    add_value : Set to True to add a value. Default set to False.
    subtract_value : Set to True to subtract a value. Default set to False.
    value : Value to add or subtract. Always give a positive value and choose add_value or subtract_value
            depending on the direction. Default set to None.
    change_dtype : Set to True to change the raster data type to float32. Default set to False.
    reformat_date : Set to True to rewrite the trailing '_<date>' part of the file name without hyphens
                    (e.g. 'name_2021-12-31' -> 'name_20211231'). File names without an underscore are
                    kept unchanged. Default set to True.
    nodata : No data value assigned to the output raster. Default set to no_data_value (-9999).

    returns: Filepath of the output raster.
    """
    makedirs([output_dir])
    raster_arr, raster_file = read_raster_arr_object(input_raster)

    if change_dtype:  # changing dtype to float32
        raster_arr = raster_arr.astype(np.float32)

    # the arithmetic below is applied to non-NaN pixels only
    if scale:  # scaling data
        raster_arr[~np.isnan(raster_arr)] *= scale_factor

    if add_value:  # adding a value
        raster_arr[~np.isnan(raster_arr)] += value

    if subtract_value:  # subtracting a value
        raster_arr[~np.isnan(raster_arr)] -= value

    # file name without its extension(s)
    stem = os.path.basename(input_raster).split('.')[0]
    raster_name = stem + '.tif'

    # Formatting final raster name (making the date pattern uniform).
    # This is set considering the file name format of SMAP L-band soil moisture and temperature data.
    # Names without an underscore carry no '_<date>' part and are left as they are; slicing at
    # rfind('_') == -1 would otherwise cut off the last character of the name.
    if reformat_date and '_' in stem:
        main_raster_name, date_part = stem.rsplit('_', 1)
        date = date_part.replace('-', '')
        # avoid a doubled underscore when the prefix already ends with one
        separator = '' if main_raster_name.endswith('_') else '_'
        raster_name = main_raster_name + separator + date + '.tif'

    output_path = os.path.join(output_dir, raster_name)

    with rio.open(
            output_path,
            'w',
            driver='GTiff',
            height=raster_arr.shape[0],
            width=raster_arr.shape[1],
            dtype=raster_arr.dtype,
            count=1,
            crs=raster_file.crs,
            transform=raster_file.transform,
            nodata=nodata
    ) as dst:
        dst.write(raster_arr, 1)

    return output_path

def harmonize_resample_satellite_data(input_data_dir, search_by, main_output_dir, height, width, ref_raster, 
                                      scale_factor, scale=True, add_value=False, subtract_value=False, value=None,
                                      change_dtype=True):
    """
    Extract band 1 of every selected satellite raster and resample it to a common height and width.

    ***Look into the raw data and pick one dataset as reference raster. All data will be harmonized based on it.

    params:
    input_data_dir : Filepath of input data directory.
    search_by : Search keyword (glob pattern) for selecting datasets. For example: '*.tif'.
    main_output_dir : Filepath of output data main directory. A 'single_band' sub-directory holding the
                      intermediate single band data is created inside it by the code.
    height, width : Target height and width (integers).
    ***ref_raster*** : Filepath of the raster used in determining resample affine transformation/crs/dtype/nodata.
    scale_factor : Scaling factor to multiply with if scale=True.
    scale : Set to True if data need to be scaled. Default set to True.
    add_value : Set to True to add a value. Default set to False.
    subtract_value : Set to True to subtract a value. Default set to False.
    value : Value to add or subtract. Always give a positive value and choose add_value or subtract_value
            depending on the direction. Default set to None.
    change_dtype : Set to True to change the raster data type to float. Default set to True.

    returns: None.
    """
    # inputs are collected before the intermediate folder is created
    raw_rasters = glob(os.path.join(input_data_dir, search_by))

    single_band_output_dir = os.path.join(main_output_dir, 'single_band')
    makedirs([single_band_output_dir])

    for raw_data in raw_rasters:
        # step 1: band 1 only, with optional dtype change / scaling / offset
        single_band_raster = save_raster_as_single_band(
            input_raster=raw_data, output_dir=single_band_output_dir,
            change_dtype=change_dtype, scale=scale, scale_factor=scale_factor,
            add_value=add_value, subtract_value=subtract_value, value=value,
            nodata=-9999)

        # step 2: resample to the requested shape; the output keeps the raw file's base name
        resample_raster_with_height_width(
            input_raster=single_band_raster, height=height, width=width,
            ref_raster=ref_raster, output_dir=main_output_dir,
            raster_name=os.path.basename(raw_data).split('.')[0] + '.tif',
            resampling_alg=Resampling.bilinear)
        
        
def mask_resample_weather_datasets_to_bbox(variables_to_resample, input_data_main_dir, output_main_dir, bounding_box,
                                           target_raster):
    """
    Mask weather rasters (resampled 100m resolution) with a bounding box, then resample the masked
    rasters against a target satellite raster (100m).

    ##########################
    Background (woodland site): ideally masking alone (without resampling) should be enough, but the
    rasterio.mask output caused an issue. gdal.warp could have done it correctly (tested in QGIS), but
    gdal was not available on AWS. The workaround is to first mask the data with the bounding box into an
    interim directory and then resample the masked data using the satellite raster.
    ##########################

    params:
    variables_to_resample : List of weather variables to process.
    input_data_main_dir : Data main directory path. The sub-directory of each variable is selected by the code.
    output_main_dir : Output data main directory. Resampled rasters are saved in <output_main_dir>/<variable>
                      and the masked intermediates in <output_main_dir>/<variable>/interim.
    bounding_box : Mask shape handed to mask_raster_array_by_shapefile() as mask_shape.
    target_raster : 100m ref raster. Should be a satellite raster of 100m resolution that is inside the bounding box.

    returns: None.
    """
    for var in variables_to_resample:
        print(f'Masking and resampling data for {var}...')

        # final rasters go to the variable's folder, masked intermediates to its 'interim' sub-folder
        resampled_output_dir = os.path.join(output_main_dir, var)
        masked_output_dir = os.path.join(output_main_dir, var, 'interim')
        makedirs([masked_output_dir])

        for raster in glob(os.path.join(input_data_main_dir, var, '*.tif')):
            raster_name = os.path.basename(raster).split('.')[0] + '.tif'

            # masking to bounding box
            mask_raster_array_by_shapefile(input_raster=raster, mask_shape=bounding_box,
                                           output_dir=masked_output_dir, raster_name=raster_name,
                                           invert=False, crop=True, save_masked_arr=True)

            # resampling to make sure the processed raster is inside the bounding box
            resample_raster_based_on_ref_raster(input_raster=os.path.join(masked_output_dir, raster_name),
                                                ref_raster=target_raster,
                                                output_dir=resampled_output_dir, raster_name=raster_name)

## Miscellaneous

In [9]:
def copy_file(input_dir_file, copy_dir, search_by='*.tif', rename=None):
    """
    Copy a file, or all matching files of a directory, to the specified directory.

    :param input_dir_file: File path of input directory/ Path of the file to copy. Whether it is a
                           directory is checked on disk, not guessed from the name.
    :param copy_dir: File path of copy directory. Created if it does not exist.
    :param search_by: Glob pattern selecting the files when a directory is given. Default set to '*.tif'.
    :param rename: New name (without extension, '.tif' is appended) of the file if required.
                   Default set to None. Doesn't work if a directory is being copied.

    :returns: File path of copied file. For a directory, the path of the last copied file
              (in sorted order), or None if no file matched search_by.
    """
    os.makedirs(copy_dir, exist_ok=True)

    if os.path.isdir(input_dir_file):
        copied_path = None  # stays None when nothing matches
        for each in sorted(glob(os.path.join(input_dir_file, search_by))):
            copied_path = os.path.join(copy_dir, os.path.basename(each))
            shutil.copyfile(each, copied_path)
        return copied_path

    if rename is not None:
        copied_path = os.path.join(copy_dir, f'{rename}.tif')
    else:
        copied_path = os.path.join(copy_dir, os.path.basename(input_dir_file))

    shutil.copyfile(input_dir_file, copied_path)

    return copied_path


def makedirs(directory_list):
    """
    Make directory (if not exists) from a list of directory.

    :param directory_list: A list of directories to create. Empty entries are skipped: '' is what
                           os.path.dirname() returns for a bare filename, i.e. the current directory.

    :returns: None.
    """
    for directory in directory_list:
        if not directory:
            continue
        if not os.path.exists(directory):
            # exist_ok guards against the directory appearing between the check and the creation
            os.makedirs(directory, exist_ok=True)
            
def make_folder_in_s3_bucket(new_folder_path, bucket_name='data-pipeline-env-model'):
    """
    Make directory/folder in AWS S3 Bucket by putting an object whose key ends with '/'.

    params:
    new_folder_path : Folder path to create in the S3 bucket, like "Main_folder/subfolder".
                      A trailing '/' is accepted and not doubled.
    bucket_name : S3 bucket name. Default set to 'data-pipeline-env-model'

    returns: None.
    """
    s3 = boto3.client('s3')
    # strip trailing slashes first so the key always ends with exactly one '/'
    folder_key = new_folder_path.rstrip('/') + '/'
    s3.put_object(Bucket=bucket_name, Key=folder_key)

## Plotting

In [10]:
def plot_dated_images_rasters(list_of_images, num_cols, figsize=(8, 6), shapefile=None,
                              title=None):
    """
    Plot images with date information, one panel per image, filled row by row.
    The date should be at the end of the file name like this '20211231'.

    params:
    list_of_images : List of images filepath to plot.
    num_cols : number of columns in the plot.
    figsize : Figsize. Default set to (8, 6).
    shapefile : Shapefile filepath if want to show in plot. Default set to None.
    title : Title string. Default set to None.

    returns: None.
    """
    # number of rows = ceil(len / num_cols); at least one row so an empty list still gives a figure
    num_rows = max(1, -(-len(list_of_images) // num_cols))

    gdf = gpd.read_file(shapefile) if shapefile is not None else None

    # squeeze=False keeps `ax` two-dimensional even for a single row and/or column,
    # so ax[row, col] is always valid
    fig, ax = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=figsize, squeeze=False)

    # the position comes from enumerate, so a path listed twice gets two panels
    for position, image in enumerate(list_of_images):
        # extracting date
        date = os.path.basename(image).split('.')[0].split('_')[-1]
        year, month, day = date[:4], date[4:6], date[6:]
        date_label = f'{year}_{month}_{day}'

        row, col = divmod(position, num_cols)
        panel = ax[row, col]

        if gdf is not None:
            gdf.plot(facecolor='none', edgecolor='black', ax=panel)

        # the dataset is only needed while it is drawn; close it right after
        with rio.open(image) as img:
            show(img, ax=panel, cmap='Spectral_r')
        panel.set_title(f'Date {date_label}', fontsize=9)

    if title is not None:
        plt.suptitle(title)
    plt.tight_layout()
        
        
def plot_dated_images_scatter(images_list_data1, images_list_data2, num_cols, figsize=(8, 6), 
                                      title=None, xlabel=None, ylabel=None, nodata=-9999):
    """
    Plot a grid of scatter plots comparing two lists of rasters, matched by date.
    Each filename in the 1st list must end with an underscore-separated date like this '20211231';
    an image from the 2nd list is paired with it when that date string occurs in its filepath.
    Each panel shows band-1 pixel values (data1 on x, data2 on y), a 1:1 line and the R2 score.
    Panels whose date has no match in the 2nd list are left empty.
    
    params:
    images_list_data1 : 1st List of images filepath to plot (x axis).
    images_list_data2 : 2nd List of images filepath to plot (y axis).
    num_cols : number of columns in the plot.
    figsize : Figure size. Default set to (8, 6).
    title, xlabel, ylabel : Title, xlabel, ylabel string. Default set to None.
    nodata : Pixel value treated as missing and excluded from the plot. Default set to -9999.
    
    returns: None.
    
    raises: ValueError if images_list_data1 is empty.
    """
    num_images = len(images_list_data1)
    if num_images == 0:
        raise ValueError('images_list_data1 is empty; nothing to plot.')
    
    # ceiling division: enough rows to hold every image
    num_rows = (num_images + num_cols - 1) // num_cols
    
    # squeeze=False keeps `ax` 2-D even for a single row/column, so ax[row, col] always works
    fig, ax = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=figsize, squeeze=False)
    
    # enumerate gives each image its own panel, even if the same path appears twice
    for idx, image1 in enumerate(images_list_data1):
        # extracting date from the end of the filename
        date = os.path.basename(image1).split('.')[0].split('_')[-1]
        year, month, day = date[:4], date[4:6], date[6:]
        date_fmt = f'{year}_{month}_{day}'
        
        row, col = divmod(idx, num_cols)
        panel = ax[row, col]
        
        # loop to find data of matching dates in 2nd image list
        # (substring match against the whole filepath, as before)
        for image2 in images_list_data2:
            if date not in image2:
                continue
            
            # reading band 1 of the matching images; datasets are closed right after reading
            with rio.open(image1) as src1:
                img1 = src1.read(1).flatten()
            with rio.open(image2) as src2:
                img2 = src2.read(1).flatten()
            
            # filtering out nan or nodata values
            img_df = pd.DataFrame({'img1': img1, 'img2': img2})
            img_df = img_df.dropna()
            img_df = img_df[(img_df['img1'] != nodata) & (img_df['img2'] != nodata)]
            
            # calculating r2 score between the variables
            r2_val = r2_score(img_df.img1, img_df.img2)
            
            # scatter plot
            panel.plot(img_df.img1, img_df.img2, 'b.', alpha=0.05)
            
            # both axes share one range so the 1:1 line runs corner to corner
            lo = min(img_df.img1.min(), img_df.img2.min())
            hi = max(img_df.img1.max(), img_df.img2.max())
            panel.set_xlim(lo, hi)
            panel.set_ylim(lo, hi)
            
            # including R2 value; axes-fraction coordinates keep the label inside
            # the top-left corner whatever the data range is
            panel.annotate('R2 = {:.3f}'.format(r2_val), xy=(0.04, 0.9),
                           xycoords='axes fraction', fontsize=8)
            
            # setting 1:1 line
            panel.axline((0, 0), slope=1, color='red', linewidth=0.3)
            
            # setting individual plot title
            panel.set_title(f'Date {date_fmt}', fontsize=9)
            
            # setting xlabel and ylabel (each one independently)
            if xlabel is not None:
                panel.set_xlabel(xlabel, fontsize=9)
            if ylabel is not None:
                panel.set_ylabel(ylabel, fontsize=9)
    
    # setting overall plot title and layout once, after every panel is drawn
    if title is not None:
        fig.suptitle(title)
    fig.tight_layout()

In [10]:
##### shapefile to raster using geocube; it can rasterize and resample (cubic/nearest/linear) at the same time

# from geocube.api.core import make_geocube
# from functools import partial
# from geocube.rasterize import rasterize_points_griddata

# twc_weather_shape = '../datasets/weather_data_shapefiles/twc_weatherdata_California.parquet'
# twc_grid = '../datasets/weather_data_shapefiles/twcGrid_California.shp'

# data_gdf = read_twc_parquet_save_as_geodataframe(parquet_file=twc_weather_shape, twc_grid_geometry_file=twc_grid,
#                                                 save=False)

# data_gdf = data_gdf[data_gdf['date']=='2021-01-01']

# converted_raster = make_geocube(
#         vector_data=data_gdf,
#         measurements=["total_precip"],
#         resolution=(-0.036, 0.036),
#         fill = no_data_value,
#     rasterize_function=partial(rasterize_points_griddata, method="nearest"))
    
# # Save the converted raster
# converted_raster.rio.to_raster('../datasets/reference_rasters/twc_precip2.tif')