In [1]:
import os
import geopandas as gpd
import shapely
import rasterio
import rasterio.features
import rasterio.windows
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import xarray as xr
import tqdm
import time
import gc

In [2]:
import sys

# Put the parent directory on the module search path -- presumably so the
# project-local `rsutils` package imported in the next cell can be resolved.
sys.path += ['..']

In [3]:
import rsutils.utils

In [4]:
# --- Outputs ------------------------------------------------------------------
# Country-wise outputs: labels.tif, regions.geojson and one sub-folder per region.
# NOTE(review): this folder is on cmongp1 while every input below is on
# cmongp2 -- confirm that is intentional.
output_folderpath = '/gpfs/data1/cmongp1/sasirajann/nh_crop_calendar/crop_calendar/data/outputs/sub-saharan-africa-countrywise'
labels_tif_filepath = os.path.join(output_folderpath, 'labels.tif')

# --- Inputs -------------------------------------------------------------------
# Region polygons (grouped by FAO_NAME further down) and the maize crop mask
# whose grid (shape / transform / CRS) the label raster is built on.
shapefilepath = '/gpfs/data1/cmongp2/sasirajann/nh_crop_calendar/crop_calendar/data/shapefiles/AfSP012Qry_ISRIC/GIS_Shape/AfSP012Qry_SubSaharanAfrica.shp'
cropmask_filepath = '/gpfs/data1/cmongp2/sasirajann/nh_crop_calendar/crop_calendar/data/outputs/sub-saharan-africa/merged_WC_cropmask/maize.tif'

# Continent-wide NetCDF datasets, one file per variable.
continent_data_folderpath = '/gpfs/data1/cmongp2/sasirajann/nh_crop_calendar/crop_calendar/data/outputs/sub-saharan-africa'
datasets = [
    'chirps.nc', 'cpc-tmax.nc', 'cpc-tmin.nc',
    'esi-4wk.nc', 'fpar.nc', 'gcvi.nc', 'ndvi.nc',
    'nsidc-rootzone.nc', 'nsidc-surface.nc',
]

In [5]:
# The crop mask defines the target grid: the label raster is rasterized onto
# its shape and transform, and written with an edited copy of its profile.
with rasterio.open(cropmask_filepath) as cropmask_src:
    profile = cropmask_src.profile.copy()
    raster_crs = cropmask_src.crs
    out_shape = (cropmask_src.height, cropmask_src.width)
    transform = cropmask_src.transform

In [6]:
# Load the region polygons and reproject them into the crop mask's CRS so they
# can be burned onto its grid.
regions_gdf = gpd.read_file(shapefilepath).to_crs(raster_crs)

# Dissolve all polygons that share an FAO_NAME into one (multi)geometry.
# groupby sorts its keys by default, so the ids below are assigned in sorted
# name order and are stable across runs.
name_geometry_dict = regions_gdf.groupby('FAO_NAME')['geometry'].apply(shapely.unary_union).to_dict()

# Ids start at 1 so that 0 stays free for "no region" (the rasterize fill and
# the nodata value of labels.tif).
name_id_dict = {name: _id for _id, name in enumerate(name_geometry_dict.keys(), 1)}

# (geometry, id) pairs for rasterize. Materialised as a list rather than a
# generator: a generator is exhausted after one pass, so re-running the
# rasterize cell would silently burn nothing and yield an all-zero raster.
geometry_id_tuple = [(geometry, name_id_dict[name]) for name, geometry in name_geometry_dict.items()]

In [9]:
# One row per dissolved region: its FAO name, its integer label id and its
# geometry. Saved as GeoJSON next to the label rasters.
region_names = list(name_geometry_dict)
data = {
    'FAO_NAME': region_names,
    'label': [name_id_dict[region] for region in region_names],
    'geometry': [name_geometry_dict[region] for region in region_names],
}

union_region_gdf = gpd.GeoDataFrame(data=data, crs=regions_gdf.crs)
union_region_gdf.to_file(os.path.join(output_folderpath, 'regions.geojson'))

In [None]:
# Inspect the dissolved geometry for each FAO region name.
dict(name_geometry_dict)

In [7]:
# Burn each region's id onto the crop-mask grid; pixels outside every region
# get 0. The raster is produced directly as uint8 -- the dtype labels.tif is
# written with -- instead of Python `int` (int64), which older rasterio
# versions reject in rasterize and which costs 8x the memory. Ids therefore
# have to fit in 1..255.
assert len(name_id_dict) <= np.iinfo(np.uint8).max, 'Too many regions for a uint8 label raster.'

labels = rasterio.features.rasterize(
    shapes = geometry_id_tuple,
    out_shape = out_shape,
    transform = transform,
    fill = 0,
    dtype = np.uint8,
    # Only pixels whose centre falls inside a region are labelled, so very
    # small regions may end up with no pixels at all.
    all_touched = False,
)

In [8]:
# Output settings for the single-band label raster: uint8 ids, LZW
# compression, and 0 (pixels outside every region) declared as nodata.
profile.update({
    'dtype': rasterio.uint8,
    'count': 1,
    'compress': 'LZW',
    'nodata': 0,
})

In [None]:
# Persist the label raster; the per-region masks are cropped out of this file.
with rasterio.open(labels_tif_filepath, mode='w', **profile) as labels_dst:
    labels_dst.write(labels, indexes=1)

In [11]:
def crop_tif_by_bbindices(
    input_tif:str,
    output_tif:str,
    row_start:int,
    row_stop:int,
    col_start:int,
    col_stop:int,
    target_value:"int | None" = None,
):
    """Crop band 1 of ``input_tif`` to a pixel window and write it to ``output_tif``.

    The window spans rows ``row_start:row_stop`` and columns
    ``col_start:col_stop`` with Python slice semantics: start inclusive,
    stop exclusive (the rasterio ``Window`` is built from offset + size).

    Args:
        input_tif: Path of the raster to crop. Only band 1 is read.
        output_tif: Path of the GeoTIFF to write (overwritten if present).
        row_start, row_stop: First row and one-past-last row of the window.
        col_start, col_stop: First column and one-past-last column of the window.
        target_value: If given, write a binary mask instead of the raw values:
            1 where the cropped pixel equals ``target_value``, 0 elsewhere.

    The output profile is the source profile with height, width, transform
    and band count updated. Dtype, CRS, nodata and compression are inherited,
    so for a source declaring ``nodata=0`` the mask's 0s read back as nodata.

    Raises:
        ValueError: If the window is empty (stop <= start on either axis).
    """
    height = row_stop - row_start
    width = col_stop - col_start
    if height <= 0 or width <= 0:
        raise ValueError(
            f'Empty crop window: rows [{row_start}, {row_stop}), '
            f'cols [{col_start}, {col_stop}).'
        )

    # Window(col_off, row_off, width, height) -- columns come first.
    window = rasterio.windows.Window(col_start, row_start, width, height)

    with rasterio.open(input_tif) as src:
        cropped = src.read(1, window=window)

        if target_value is not None:
            # Keep the source dtype so the array matches the copied profile's
            # dtype (np.zeros() would default to float64).
            cropped = (cropped == target_value).astype(cropped.dtype)

        profile = src.profile  # a fresh Profile object; safe to mutate
        profile.update({
            "height": height,
            "width": width,
            "count": 1,  # only band 1 is written
            "transform": rasterio.windows.transform(window, src.transform),
        })

    with rasterio.open(output_tif, "w", **profile) as dst:
        dst.write(cropped, 1)

In [None]:
# Write one binary mask per region, cropped to the region's pixel bounding box.
for name, _id in name_id_dict.items():
    # np.where on a 2-D array returns (row indices, column indices).
    rows, cols = np.where(labels == _id)
    if rows.size == 0:
        # With all_touched=False a region that covers no pixel centre gets no
        # pixels at all; min()/max() on the empty arrays would raise.
        print(f'Skipping {name!r}: no pixels carry label {_id}.')
        continue

    folderpath = os.path.join(output_folderpath, name)
    os.makedirs(folderpath, exist_ok=True)

    crop_tif_by_bbindices(
        input_tif = labels_tif_filepath,
        output_tif = os.path.join(folderpath, 'region_mask.tif'),
        row_start = int(rows.min()),
        row_stop = int(rows.max()) + 1,  # stop is exclusive: keep the last row
        col_start = int(cols.min()),
        col_stop = int(cols.max()) + 1,  # stop is exclusive: keep the last column
        target_value = _id,
    )

In [12]:
# Inclusive pixel bounding box of label 1, for quick inspection.
# For a 2-D mask the first index array holds rows and the second columns,
# so the "x" values here are row indices and the "y" values column indices.
xs, ys = np.nonzero(labels == 1)
min_x, max_x = min(xs), max(xs)
min_y, max_y = min(ys), max(ys)

In [None]:
# (first row, last row, first col, last col) of label 1 -- all inclusive.
(min_x, max_x, min_y, max_y)

In [None]:
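# NOTE(review): this disabled cell would not work as written:
#   * The xr.DataArray(...) call pairs sliced data with the full, unsliced
#     `dataarray.coords`; xarray should reject that with a dimension-size
#     mismatch. Slicing with `dataarray.isel(...)` keeps the coords aligned
#     (assumes the spatial dims are the last two -- confirm).
#   * `min_x:max_x` / `min_y:max_y` drop the region's last row and column,
#     because slice stops are exclusive; use `max_x + 1` / `max_y + 1`.
#   * `os.remove(export_filepath)` raises FileNotFoundError if `to_netcdf`
#     failed before the file was created.
# for dataset in tqdm.tqdm(datasets):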
#     print('dataset:', dataset)
#     dataarray = xr.load_dataarray(os.path.join(continent_data_folderpath, dataset))

#     for name, _id in tqdm.tqdm(list(name_id_dict.items())):
#         print('name:', name)

#         folderpath = os.path.join(output_folderpath, name)
#         os.makedirs(folderpath, exist_ok=True)

#         export_filepath = os.path.join(folderpath, dataset)

#         if os.path.exists(export_filepath):
#             print('Exists.')
#             continue

#         xs, ys = np.where(labels == _id)
#         min_x = min(xs)
#         max_x = max(xs)
#         min_y = min(ys)
#         max_y = max(ys)

#         _dataarray = xr.DataArray(
#             data = dataarray.values[:, min_x:max_x, min_y:max_y],
#             dims = dataarray.dims,
#             coords = dataarray.coords,
#         )
#         try:
#             _dataarray.to_netcdf(export_filepath)
#         except Exception as e:
#             os.remove(export_filepath)
#             print(f'Failed. {str(e)}')

#         del _dataarray
    
#     del dataarray
#     gc.collect()