In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import rioxarray
import xarray as xr
import xesmf as xe

# http://www.earthstat.org/cropland-pasture-area-2000/
# https://docs.xarray.dev/en/stable/user-guide/io.html#rasterio
# http://www.earthstat.org/harvested-area-yield-175-crops/


# Regrid cropland masks

In [None]:
# GCM = "mpi"
GCM = "cesm2"

# for crop in ("wheat", "rice", "soy", "maize", "all"):
for crop in ("soybean", "maize", "all"):
    if crop == "all":
        crop_name = crop
        geotiff_path = "data/cropland/CroplandPastureArea2000_Geotiff/Cropland2000_5m.tif"
    else:
        if crop == "wheat":
            crop_name = "wht"
        elif crop == "maize":
            crop_name = "gro"
        elif crop == "rice":
            crop_name = "pdr"
        elif crop == "soybean":
            crop_name = "osd"
        else:
            raise NotImplementedError()

        geotiff_path = "data/cropland/" + crop + "_HarvAreaYield_Geotiff/" + crop + "_HarvestedAreaFraction.tif"

    # Open into an xarray.DataArray
    crop_da = rioxarray.open_rasterio(geotiff_path)
    crop_da = crop_da.to_dataset("band")
    crop_da = crop_da.rename({1: "data"})["data"]
    crop_da = crop_da.rename({"x": "lon", "y": "lat"})

    data_lons = crop_da["lon"].values
    i = np.where(data_lons <= 0)
    data_lons[i] = data_lons[i] + 360
    i = np.where(data_lons >= 359.93)
    data_lons[i] = 0.0
    crop_da = crop_da.assign_coords(lon=data_lons)
    crop_da = crop_da.sortby("lon")

    crop_da = xr.where(crop_da < 0, 0.0, crop_da)
    crop_da

    # get the data grid
    if GCM == "mpi":
        data = xr.open_dataarray("processed_data/exp201_processed_data.nc")
    elif GCM == "cesm2":
        data = xr.open_dataarray("processed_data/exp2001_processed_data.nc")

    data = data.stack(sample=("member", "window")).transpose("sample", "lat", "lon")

    # perform the regridding and save the netcdf file
    grid_in = xr.Dataset(
        {
            "lat": (["lat"], crop_da["lat"].values, {"units": "degrees_north"}),
            "lon": (["lon"], crop_da["lon"].values, {"units": "degrees_east"}),
        }
    )
    grid_out = xr.Dataset(
        {
            "lat": (["lat"], data["lat"].values, {"units": "degrees_north"}),
            "lon": (["lon"], data["lon"].values, {"units": "degrees_east"}),
        }
    )

    regridder = xe.Regridder(grid_in, grid_out, "conservative", periodic=True)
    crop_da_regrid = regridder(crop_da, keep_attrs=True)
    crop_da_regrid = xr.where(crop_da_regrid == 0, np.nan, crop_da_regrid)

    crop_da_regrid.to_netcdf("data/cropland/regridded_cropland/" + crop_name + "_cropped_regrid_" + GCM + ".nc")

In [None]:
# data = xr.open_dataarray("processed_data/exp201_processed_data.nc")
# data = data.stack(sample=("member", "window")).transpose("sample", "lat", "lon")

In [None]:
# grid_in = xr.Dataset(
#     {
#         "lat": (["lat"], crop_da["lat"].values, {"units": "degrees_north"}),
#         "lon": (["lon"], crop_da["lon"].values, {"units": "degrees_east"}),
#     }
# )
# grid_out = xr.Dataset(
#     {
#         "lat": (["lat"], data["lat"].values, {"units": "degrees_north"}),
#         "lon": (["lon"], data["lon"].values, {"units": "degrees_east"}),
#     }
# )

# regridder = xe.Regridder(grid_in, grid_out, "conservative", periodic=True)
# crop_da_regrid = regridder(crop_da, keep_attrs=True)
# crop_da_regrid = xr.where(crop_da_regrid==0, np.nan, crop_da_regrid)
# crop_da_regrid.to_netcdf("data/cropland/" + crop + "_HarvAreaYield_Geotiff/" + crop + "_cropped_regrid.nc")

# Check my work

In [None]:
# Unconditionally raises AssertionError when this cell is executed.
# NOTE(review): presumably a deliberate stop so that "Run All" halts after the
# regridding above and the check cells below are only run by hand -- confirm.
assert False

In [None]:
# Reload the mask for the last crop processed by the regridding loop, reading
# from the same location that loop writes to, and plot it as a quick visual check.
crop_da_regrid = xr.load_dataarray("data/cropland/regridded_cropland/" + crop_name + "_cropped_regrid_" + GCM + ".nc")

plt.pcolor(crop_da_regrid)
plt.title("Regridded crop mask: " + crop_name + " (" + GCM + ")")
plt.colorbar()
plt.show()  # must be called; a bare `plt.show` only references the function

In [None]:
import data_processing

# Root of the large input data. The original local path remains the default;
# set the DATA_DIRECTORY environment variable to run on another machine.
# os.path.join(..., "") guarantees a trailing separator, as the default has.
DATA_DIRECTORY = os.path.join(os.environ.get("DATA_DIRECTORY", "/Users/eabarnes/big_data/"), "")
SHAPE_DIRECTORY = "shapefiles/gadm_shapefiles_20230301_gtapv11/"

# Country-index mask and region shapes, as returned by the project helper.
mask_country, regs_shp = data_processing.get_country_masks(SHAPE_DIRECTORY, DATA_DIRECTORY)


In [None]:
# Index (in mask_country) of the partner region to inspect.
ishp_partner = 3

# 1.0 inside the partner region, NaN everywhere else.
mask_partner = xr.where(mask_country == ishp_partner, 1.0, np.nan)

# Same region mask, weighted by the regridded crop field.
mask_partner_crop = mask_partner * crop_da_regrid

In [None]:
import importlib as imp

# Re-import data_processing so edits made to the module since its first import take effect.
imp.reload(data_processing)

# Crop-weighted response: data weighted by the crop mask, normalised by the
# mask's own global sum, then summed with the same project helper.
crop_weight_total = data_processing.compute_global_sum(mask_partner_crop)
response_crop = data_processing.compute_global_sum(data * mask_partner_crop / crop_weight_total)

# Same calculation with the unweighted region mask.
region_weight_total = data_processing.compute_global_sum(mask_partner)
response = data_processing.compute_global_sum(data * mask_partner / region_weight_total)


In [None]:
# Position of the largest crop-weighted global sum (index into the flattened values).
# NOTE(review): presumably one value per sample, making this a sample index -- confirm.
x = data_processing.compute_global_sum(data * mask_partner_crop)
i = x.values.argmax()
i

In [None]:
sample = 0  # 100#139
print(response[sample].values, response_crop[sample].values)

# Third panel: the sample weighted by the crop mask, with exact zeros blanked to NaN.
x_plot = data[sample, :, :] * mask_partner_crop
x_plot = np.where(x_plot == 0, np.nan, x_plot)

# (field, title, colormap) for each panel, left to right; None = matplotlib default colormap.
panels = (
    (mask_partner_crop, "Fraction of Area Cropped", "plasma"),
    (data[sample, :, :] * mask_partner, "Boolean Stress Index for Sample #" + str(sample), None),
    (
        x_plot,
        "Percent Cropped that is Stressed = " + str(np.round(100 * response_crop[sample].values)) + "%",
        "plasma",
    ),
)

fig, axes = plt.subplots(1, 3, figsize=(18, 4))
for ax, (field, title, cmap) in zip(axes, panels):
    mesh = ax.pcolor(field, cmap=cmap)
    mesh.set_clim(0, 1)
    fig.colorbar(mesh, ax=ax)
    ax.set_title(title)
    # Zoom window, in array-index units of the plotted field.
    ax.set_xlim(25, 85)
    ax.set_ylim(55, 80)

fig.tight_layout()
plt.show()
