### Author: Md Fahim Hasan
### Work Email: mdfahim.hasan@bayer.com

In [1]:
import os
import numpy as np
import pandas as pd
from time import time
from glob import glob
import rasterio as rio
import geopandas as gpd
import dask.dataframe as ddf
from rasterio.plot import show
import matplotlib.pyplot as plt
from rasterio.features import rasterize
from rasterio.enums import MergeAlg, Resampling

In [2]:
# Module-wide sentinel for "no data" pixels; used as the default `nodata`
# argument of write_array_to_raster.
no_data_value=-9999

In [3]:
def read_raster_arr_object(raster_file, rasterio_obj=False, band=1, get_file=True, change_dtype=True):
    """
    Get raster array and (optionally) the opened rasterio dataset.

    :param raster_file: Input raster filepath, or an already opened rasterio dataset when rasterio_obj=True.
    :param rasterio_obj: Set True if raster_file is a rasterio object. In that case only the array is
                         returned (get_file is ignored) and the dataset is left open for the caller.
    :param band: Selected band to read. Default set to 1.
    :param get_file: Set to False if raster file is not required. The dataset opened here is then closed
                     before returning.
    :param change_dtype: Set to True if want to change raster data type to float32 and replace the
                         dataset's nodata value with np.nan. Default set to True.

    :return: (raster numpy array, opened rasterio dataset) if get_file=True and rasterio_obj=False,
             otherwise only the raster numpy array. A returned dataset is open; the caller must close it.
    """
    opened_here = not rasterio_obj
    if opened_here:
        raster_file = rio.open(raster_file)
    else:
        # the caller already holds the dataset, so there is nothing to hand back
        get_file = False

    try:
        raster_arr = raster_file.read(band)
        if change_dtype:
            raster_arr = raster_arr.astype(np.float32)
            nodata = raster_file.nodata
            # compare against None explicitly: a nodata value of 0 is valid but falsy
            if nodata is not None:
                raster_arr[np.isclose(raster_arr, nodata)] = np.nan
    except Exception:
        # do not leak the handle we opened if reading fails
        if opened_here:
            raster_file.close()
        raise

    if get_file:
        return raster_arr, raster_file

    # only close datasets this function opened; never one owned by the caller
    if opened_here:
        raster_file.close()
    return raster_arr


def write_array_to_raster(raster_arr, raster_file, transform, output_path, ref_file=None, nodata=no_data_value):
    """
    Write a 2D raster array to a single-band Geotiff.

    :param raster_arr: Raster array data to be written (2D; its shape sets the output height/width
                       and its dtype sets the output dtype).
    :param raster_file: Opened rasterio dataset whose crs is used. Ignored if ref_file is given.
    :param transform: Affine transformation matrix. Ignored if ref_file is given.
    :param output_path: Output filepath.
    :param ref_file: Filepath of a reference raster. If given, crs and transform are taken from it
                     instead of raster_file/transform.
    :param nodata: No data value stored in the output raster metadata. Default is the module-level
                   no_data_value.

    :return: Output filepath.
    """
    if ref_file:
        # read the georeferencing and close the reference raster right away
        with rio.open(ref_file) as ref:
            crs = ref.crs
            transform = ref.transform
    else:
        crs = raster_file.crs

    height, width = raster_arr.shape[0], raster_arr.shape[1]
    with rio.open(
            output_path,
            'w',
            driver='GTiff',
            height=height,
            width=width,
            dtype=raster_arr.dtype,
            count=1,
            crs=crs,
            transform=transform,
            nodata=nodata
    ) as dst:
        dst.write(raster_arr, 1)

    return output_path


def rasterize_shapefile(input_file, output_raster, attribute, ref_raster, date=None, grid_shapefile=None,
                        merge_alg = MergeAlg.replace, dtype='float32', no_data_value=-9999, paste_on_ref_raster=False):
    """
    Rasterize an attribute of a vector dataset onto the grid of a reference raster.

    params:
    input_file : Filepath of parquet file, or an already read geodataframe, holding the attribute. A parquet
                 file should have a grid_id column to be matched with the grid_shapefile.
    output_raster : Filepath of output raster file.
    attribute : Attribute column (in str) of parquet/gdf to rasterize.
    ref_raster : Filepath of reference raster used for the output shape, transform and crs.
    date : Default set to None. Set to the date value if want to filter the parquet/gdf on its 'date'
           column (specially for weather data).
    grid_shapefile : If a parquet file is given as input_file, filepath of grid/geometry shapefile.
                     Should have a grid_id column to be matched with the parquet file. Default set to None.
    merge_alg : Rasterio merge algorithm. Can be either MergeAlg.replace or MergeAlg.add.
                Default set to MergeAlg.replace.
    dtype : Data type of raster. Default set to 'float32'.
    no_data_value : Value assigned to pixels not covered by any geometry. Default set to -9999.
    paste_on_ref_raster : Set to True to also assign no_data_value wherever the reference raster has no data.

    returns: The output raster filepath.

    raises: ValueError if no feature is left to rasterize (e.g. after the date filter).
    """
    # Decide by type: `'parquet' in gdf` would test column names, and fails on Path objects.
    if isinstance(input_file, (str, os.PathLike)):
        gdf = read_parquet_as_geodataframe(parquet_file=os.fspath(input_file),
                                           grid_geometry_file=grid_shapefile, save=False)
    else:  # already a geodataframe
        gdf = input_file

    if date is not None:
        gdf = gdf[gdf['date'] == date]

    if gdf.empty:
        raise ValueError(f'No features to rasterize for attribute {attribute!r} (date={date!r}).')

    ref_arr, ref_file = read_raster_arr_object(ref_raster)
    try:
        geometry_value_pairs = zip(gdf.geometry, gdf[attribute])

        raster_arr = rasterize(shapes=geometry_value_pairs, out_shape=ref_arr.shape, fill=no_data_value,
                               out=None, transform=ref_file.transform, all_touched=True,
                               default_value=no_data_value, dtype=dtype, merge_alg=merge_alg)

        if paste_on_ref_raster:
            # ref_arr holds np.nan where the reference raster has no data
            raster_arr[np.isnan(ref_arr)] = no_data_value

        write_array_to_raster(raster_arr, raster_file=ref_file, transform=ref_file.transform,
                              output_path=output_raster, ref_file=None, nodata=no_data_value)
    finally:
        # the reference dataset is only needed for its transform/crs
        ref_file.close()

    return output_raster


def resample_raster_based_on_ref_raster(input_raster, ref_raster, output_dir, raster_name, resampling_alg=Resampling.bilinear,
                                        paste_value_on_ref_raster=False, nodata=-9999):
    """
    Resample band 1 of a raster to the grid size of a reference raster.

    params:
    input_raster : Filepath of input raster to resample.
    ref_raster : Filepath of reference raster that determines the output height/width/affine transformation/crs.
    output_dir : Directory in which the resampled raster is saved. Created if it does not exist.
    raster_name : File name of the resampled raster inside output_dir.
    resampling_alg : Resampling algorithm. Can be Resampling.nearest/Resampling.bilinear/Resampling.cubic or
                     any resampling algorithm rasterio supports. Default set to Resampling.bilinear.
    paste_value_on_ref_raster : Set to True to assign the nodata value wherever the reference raster has no data.
    nodata : No data value written to masked pixels and to the output metadata. Default set to -9999.

    returns: The resampled output raster filepath.
    """
    makedirs([output_dir])

    ref_arr, ref_file = read_raster_arr_object(ref_raster)
    try:
        # target shape comes from the reference raster
        resampled_height, resampled_width = ref_arr.shape

        with rio.open(input_raster) as dataset:
            resampled_arr = dataset.read(1,
                                         out_shape=(1, resampled_height, resampled_width),
                                         resampling=resampling_alg)

        # Drop the band axis if present. reshape (unlike squeeze) cannot collapse a
        # real dimension when the reference grid has a single row or column.
        resampled_arr = resampled_arr.reshape(resampled_height, resampled_width)

        if paste_value_on_ref_raster:
            # ref_arr holds np.nan where the reference raster has no data
            resampled_arr = np.where(np.isnan(ref_arr), nodata, resampled_arr)

        output_raster = os.path.join(output_dir, raster_name)
        write_array_to_raster(raster_arr=resampled_arr, raster_file=ref_file,
                              transform=ref_file.transform, output_path=output_raster,
                              ref_file=None, nodata=nodata)
    finally:
        ref_file.close()

    return output_raster
    

In [4]:
def makedirs(directory_list):
    """
    Make directory (if not exists) from a list of directory.

    Missing parent directories are created as well. Directories that already
    exist are left untouched.

    :param directory_list: A list of directories to create.

    :returns: None.
    """
    for directory in directory_list:
        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
        os.makedirs(directory, exist_ok=True)
            
            
def make_lat_lon_array_from_raster(input_raster, nodata=-9999):
    """
    Make lon, lat arrays holding the cell-center coordinate of each pixel of the input raster.

    params:
    input_raster : Input raster filepath that will be used as reference raster (band 1 is read).
    nodata : No data value. Default set to -9999.

    returns: (lon_arr, lat_arr), in that order. Both have the shape of the input raster, and pixels where
             the input raster equals nodata are set to nodata.
    """
    # only the band values and the transform are needed, so the file is closed right away
    with rio.open(input_raster) as raster_file:
        raster_arr = raster_file.read(1)
        transform = raster_file.transform

    # row/col index of every pixel, flattened in row-major order
    height, width = raster_arr.shape
    cols, rows = np.meshgrid(np.arange(width), np.arange(height))

    # x/y of each cell center; 1-D index arrays are passed and the result is reshaped afterwards
    xs, ys = rio.transform.xy(rows=rows.ravel(), cols=cols.ravel(), transform=transform)

    lon_arr = np.array(xs).reshape(raster_arr.shape)
    lat_arr = np.array(ys).reshape(raster_arr.shape)

    # assigning no_data_value
    nodata_mask = raster_arr == nodata
    lon_arr[nodata_mask] = nodata
    lat_arr[nodata_mask] = nodata

    return lon_arr, lat_arr

In [5]:
def compile_twc_daily_data_to_dataframe(savename, twc_data_folder, output_folder):
    """
    Compile twc daily data in a dataframe. All datasets have to be of same shape.

    Each sub-directory of twc_data_folder is treated as one variable. Its '*.tif' files are read in sorted
    order and the date is taken from the last '_'-separated part of the file name (e.g. '..._20211231.tif').
    The 'lat' and 'lon' sub-directories hold a single raster each, which is repeated once per day.

    params:
    savename : (str) Name of the output parquet file ('.parquet' is appended if missing).
    twc_data_folder : TWC daily data main folder. The code will automatically get data in the sub-directories.
    output_folder : Main output folder. Created if it does not exist.

    returns: compiled TWC dask dataframe (rows with missing values dropped, index reset).
    """
    start_time = time()

    makedirs([output_folder])

    # making list of variables in the twc raster data folder
    variable_names = os.listdir(twc_data_folder)
    variable_paths = [os.path.join(twc_data_folder, folder) for folder in variable_names]

    # Number of daily rasters; used to repeat the lat/lon data once per day.
    # Not every TWC dataset has max_temp, so fall back to total_precip only when
    # max_temp has no rasters at all (a single-day max_temp folder is valid).
    num_days = len(glob(os.path.join(twc_data_folder, 'max_temp', '*.tif')))
    if num_days == 0:
        num_days = len(glob(os.path.join(twc_data_folder, 'total_precip', '*.tif')))

    variable_dict = {}  # a dictionary where daily dataset values will be stored under variable_name

    for path in variable_paths:
        # sorted by name (i.e. by date) so that all variables are compiled in the same order
        all_data = sorted(glob(os.path.join(path, '*.tif')))

        variable_name = os.path.basename(path).split('.')[0]  # extracted variable name
        if variable_name not in ['lat', 'lon']:
            print(f'compiling data for {variable_name}...')

            values, date_list, year_list, month_list, day_list = [], [], [], [], []
            for data in all_data:
                # read data as array and flatten it
                data_arr = read_raster_arr_object(data, get_file=False).flatten()

                # extracting and formatting date info
                date = os.path.basename(data).split('.')[0].split('_')[-1]
                year, month, day = date[:4], date[4:6], date[6:]

                len_data = len(data_arr)  # number of pixels in each daily dataset (array)
                values.extend(data_arr)
                date_list.extend([int(date)] * len_data)
                year_list.extend([int(year)] * len_data)
                month_list.extend([int(month)] * len_data)
                day_list.extend([int(day)] * len_data)

            # the date columns are rewritten for every variable; all variables are
            # expected to cover the same dates and grid
            if all_data:
                variable_dict[variable_name] = values
                variable_dict['date'] = date_list
                variable_dict['year'] = year_list
                variable_dict['month'] = month_list
                variable_dict['day'] = day_list
        else:  # for lat/lon
            data = glob(os.path.join(path, '*.tif'))[0]
            data_arr = read_raster_arr_object(data, get_file=False).flatten()
            # lat/lon do not change between days, so repeat them for every day
            variable_dict[variable_name] = list(data_arr) * num_days

    twc_variable_df = pd.DataFrame(variable_dict)
    twc_variable_ddf = ddf.from_pandas(twc_variable_df, npartitions=20)
    twc_variable_ddf = twc_variable_ddf.dropna()
    twc_variable_ddf = twc_variable_ddf.reset_index()

    if '.parquet' in savename:
        output_parquet_file = os.path.join(output_folder, savename)
    else:
        output_parquet_file = os.path.join(output_folder, savename+'.parquet')

    twc_variable_ddf.to_parquet(output_parquet_file)

    end_time = time()
    print('time taken', round((end_time-start_time)/60, 3), 'mins')

    return twc_variable_ddf



def combine_twc_era5_datasets(twc_dataset, era5_dataset, output_file, merge_on=['date', 'lat', 'lon'], how='inner'):
    """
    Combine twc era5 datasets.

    params:
    twc_dataset : TWC dataframe filepath (parquet) or pandas dataframe.
    era5_dataset : ERA5 dataframe filepath (parquet) or pandas dataframe. Does not have to be of the same
                   kind as twc_dataset.
    output_file : Combined TWC and ERA5 dataframe output filepath (parquet).
    merge_on : List of columns to use in dataframe merging. Default set to ['date', 'lat', 'lon'].
               The default list is only read, never modified.
    how : Type of merging. Default set to 'inner'.

    returns: Combined TWC and ERA5 dataframe. Overlapping non-key columns get the suffixes '_twc' and '_era5'.
    """
    def _as_dataframe(dataset):
        # each input is resolved on its own, so a dataframe and a filepath can be mixed
        if isinstance(dataset, pd.DataFrame):
            return dataset
        return pd.read_parquet(dataset)

    twc_df = _as_dataframe(twc_dataset)
    era5_df = _as_dataframe(era5_dataset)

    twc_era5_combined = twc_df.merge(era5_df, on=merge_on, suffixes=('_twc', '_era5'), how=how)
    twc_era5_combined.to_parquet(output_file)

    return twc_era5_combined

In [6]:
def copy_file(input_dir_file, copy_dir, search_by='*.tif', rename=None):
    """
    Copy a file, or the matching files of a directory, to the specified directory.

    If '.tif' appears in input_dir_file it is treated as a single file, otherwise as a directory whose
    files matching search_by are copied.

    :param input_dir_file: File path of input directory/ Path of the file to copy.
    :param copy_dir: File path of copy directory. Created if it does not exist.
    :param search_by: Glob pattern used when a directory is given. Default set to '*.tif'.
    :param rename: New name of file (without extension; '.tif' is appended) if required. Default set to None.
                   Doesn't work if a directory is being copied.

    :returns: File path of copied file. For a directory, the path of the last file copied, or None if no
              file matched search_by.
    """
    # imported here because shutil is not among the imports at the top of the file
    import shutil

    os.makedirs(copy_dir, exist_ok=True)

    copied_path = None  # stays None when a directory holds no matching file
    if '.tif' not in input_dir_file:
        for source in glob(os.path.join(input_dir_file, search_by)):
            copied_path = os.path.join(copy_dir, os.path.basename(source))
            shutil.copyfile(source, copied_path)
    else:
        if rename is not None:
            target_name = f'{rename}.tif'
        else:
            target_name = os.path.basename(input_dir_file)
        copied_path = os.path.join(copy_dir, target_name)
        shutil.copyfile(input_dir_file, copied_path)

    return copied_path

In [7]:
def plot_dated_images_rasters(list_of_images, num_cols, figsize=(8, 6), shapefile=None,
                              title=None):
    """
    Plot images with date information. Date should be in the end like this '20211231'

    Images are placed row by row on a grid with num_cols columns, in the order given.

    params:
    list_of_images : List of images filepath to plot.
    num_cols : number of columns in the plot.
    figsize : Figsize. Default set to (8, 6).
    shapefile : Shapefile filepath if want to show in plot. Default set to None.
    title : Title string. Default set to None.

    returns: None.
    """
    # number of rows needed to hold all images (ceiling division)
    num_rows = -(-len(list_of_images) // num_cols)

    if shapefile is not None:
        gdf = gpd.read_file(shapefile)

    # squeeze=False keeps `ax` 2D, so ax[row, col] also works for a single row or column
    fig, ax = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=figsize, squeeze=False)

    # enumerate gives the grid position directly; list.index would be quadratic and
    # would send duplicated paths to the same axes
    for position, image in enumerate(list_of_images):
        row, col = divmod(position, num_cols)

        # extracting date from the end of the file name
        stamp = os.path.basename(image).split('.')[0].split('_')[-1]
        date_label = f'{stamp[:4]}_{stamp[4:6]}_{stamp[6:]}'

        if shapefile is not None:
            gdf.plot(facecolor='none', edgecolor='black', ax=ax[row, col])

        # the dataset is only needed while it is drawn
        with rio.open(image) as img:
            show(img, ax=ax[row, col], cmap='Spectral_r')
        ax[row, col].set_title(f'Date {date_label}', fontsize=9)

    if title is not None:
        plt.suptitle(title)
    plt.tight_layout()
        
        
def _r2_score(y_true, y_pred):
    """Coefficient of determination of y_pred against y_true: 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    if ss_tot == 0:
        # constant y_true: R2 is undefined; report 1 for a perfect match, else 0
        return 1.0 if ss_res == 0 else 0.0
    return 1.0 - ss_res / ss_tot


def plot_dated_images_scatter(images_list_data1, images_list_data2, num_cols, figsize=(8, 6),
                                      title=None, xlabel=None, ylabel=None):
    """
    Plot scatter plot of two list of images. Each image should have date info the end like this '20211231'

    Every image of the first list is paired with the images of the second list whose path contains the same
    date string. Each pair is drawn as img1 (x) vs img2 (y) with a 1:1 line and an R2 annotation, where
    img1 is treated as the reference.

    params:
    images_list_data1 : 1st List of images filepath to plot.
    images_list_data2 : 2nd List of images filepath to plot.
    num_cols : number of columns in the plot.
    figsize : Figsize. Default set to (8, 6).
    title, xlabel, ylabel : Title, xlabel, ylabel string. Default set to None. Axis labels are only set
                            if both xlabel and ylabel are given.

    returns: None.
    """
    # number of rows needed to hold all images (ceiling division)
    num_rows = -(-len(images_list_data1) // num_cols)

    # squeeze=False keeps `ax` 2D, so ax[row, col] also works for a single row or column
    fig, ax = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=figsize, squeeze=False)

    for position, image1 in enumerate(images_list_data1):
        row, col = divmod(position, num_cols)
        panel = ax[row, col]

        # extracting date
        date = os.path.basename(image1).split('.')[0].split('_')[-1]
        date_fmt = f'{date[:4]}_{date[4:6]}_{date[6:]}'

        # loop to find data of matching dates in 2nd image list
        for image2 in images_list_data2:
            if date not in image2:
                continue

            # reading band 1 of the matching images
            with rio.open(image1) as src1, rio.open(image2) as src2:
                img1 = src1.read(1).flatten()
                img2 = src2.read(1).flatten()

            # filtering out nan or -9999 values
            img_df = pd.DataFrame({'img1': img1, 'img2': img2})
            img_df = img_df.dropna()
            img_df = img_df[(img_df['img1'] != -9999) & (img_df['img2'] != -9999)]
            if img_df.empty:
                # nothing valid to plot for this pair
                continue

            # calculating r2 score between the variables
            r2_val = _r2_score(img_df.img1, img_df.img2)

            # scatter plot
            panel.plot(img_df.img1, img_df.img2, 'b.', alpha=0.05)

            # both axes share one value range so the 1:1 line runs along the diagonal
            low = min(img_df.img1.min(), img_df.img2.min())
            high = max(img_df.img1.max(), img_df.img2.max())
            panel.set_xlim(low, high)
            panel.set_ylim(low, high)

            # including R2 value near the upper-left corner
            if high > 5:
                min_pos = low + 0.1
                max_pos = high - 1
            else:
                min_pos = 0.01
                max_pos = high - 0.3
            panel.annotate(text='R2 = {:.3f}'.format(r2_val), xy=(min_pos, max_pos), fontsize=8)

            # setting 1:1 line
            panel.axline((0, 0), slope=1, color='red', linewidth=0.3)

            # setting individual plot title
            panel.set_title(f'Date {date_fmt}', fontsize=9)

            # setting xlabel and ylabel
            if xlabel is not None and ylabel is not None:
                panel.set_xlabel(xlabel, fontsize=9)
                panel.set_ylabel(ylabel, fontsize=9)

    # overall title and layout are applied once, after all panels are drawn
    if title is not None:
        plt.suptitle(title)
    plt.tight_layout()
    

    
def plot_era5_twc_downscaled_rasters(era5_data, twc_data, downscaled_data,
                                     title=None, suptitle_pos=0.75,
                                     xlabels=['ERA5 Data (28km res.)',
                                              'TWC Data (4km res.)',
                                              'Model Interpolated Data (4km res.)'],
                                     cbar_axes_pos=[1, 0.27, 0.02, 0.42]):
    """
    Plot ERA5 (original 28km), TWC, and model-downscaled datasets.

    :param era5_data : Filepath of a single day data from ERA5 datasets.
    :param twc_data : Filepath of a single day data from TWC datasets.
    :param downscaled_data : Filepath of a single day data from downscaled datasets.
    :param title : A title to use in the plot. Default set to None.
    :param suptitle_pos : A float value that designates suptitle position in the y-direction. Default set to 0.75.
    :param xlabels : A list of xlabel for the three rasters. Default set to:
                    ['ERA5 Data (28km res.)', 'TWC Data (4km res.)', 'Model Interpolated Data (4km res.)']
    :param cbar_axes_pos : A list of float values indicating colorbar position. Change it to place the colorbar perfectly.
                           Default set to [1, 0.27, 0.02, 0.42].

    returns: None. Draws ERA5, TWC, and downscaled raster side by side on one color scale with a single colorbar.
             The default lists are only read, never modified.
    """
    def _read_band_masked(path):
        # band 1 with values below -100 (e.g. the -9999 nodata value) replaced by np.nan
        with rio.open(path) as src:
            band = src.read(1)
        return np.where(band < -100, np.nan, band)

    arr_to_plot = [_read_band_masked(path) for path in (era5_data, twc_data, downscaled_data)]

    # One value range for all panels. Without it each panel is scaled to its own
    # min/max and the single colorbar would only describe the last panel.
    valid_values = [arr[np.isfinite(arr)] for arr in arr_to_plot]
    valid_values = [values for values in valid_values if values.size]
    vmin = min(values.min() for values in valid_values) if valid_values else None
    vmax = max(values.max() for values in valid_values) if valid_values else None

    plt.rcParams['font.size'] = 12
    # Plotting
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 8))
    for i, ax in enumerate(axes.flat):
        im = ax.imshow(arr_to_plot[i], cmap='RdYlGn_r', vmin=vmin, vmax=vmax)
        ax.set_xlabel(xlabels[i])

    # Title
    if title is not None:
        fig.suptitle(title, y=suptitle_pos, fontsize=16)

    # Placing colorbar
    cbar_ax = fig.add_axes(cbar_axes_pos)
    fig.colorbar(im, cax=cbar_ax)
    plt.tight_layout()