In [1]:
import pandas as pd
import geopandas as gpd
import shapely
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import plotly.graph_objects as go
from scipy.stats import zscore

import os
import xarray as xr

import tqdm

In [2]:
import sys

# Make the parent directory and the sibling ../scripts folder importable so the
# project modules imported in the next cell (era5_csu, csu, rsutils) resolve.
sys.path.extend(['..', '../scripts'])

In [3]:
import era5_csu
import csu
import rsutils.rich_data_filter

In [4]:
era5_catalog_df = pd.read_csv(era5_csu.ERA5_CATALOG_FILEPATH)

In [None]:
era5_catalog_df

In [6]:
# region_gdf = gpd.read_file(era5_csu.MALAWI_SHAPEFILEPATH)
# region_gdf = gpd.read_file('/gpfs/data1/cmongp2/sasirajann/nh_crop_calendar/crop_calendar/data/shapefiles/AfSP012Qry_ISRIC/GIS_Shape/AfSP012Qry_SubSaharanAfrica.shp')

region_gdf = gpd.read_file('/gpfs/data1/cmongp1/sasirajann/togo/shapefiles/Shapefiles/tgo_admbnda_adm0_inseed_itos_20210107.shp')

# output_folderpath = '/gpfs/data1/cmongp2/sasirajann/nh_crop_calendar/crop_calendar/data/era5_csu/sub-saharan-africa/maize_pollination'

output_folderpath = '/gpfs/data1/cmongp2/sasirajann/nh_crop_calendar/crop_calendar/data/era5_csu/togo/'

os.makedirs(output_folderpath, exist_ok=True)

shapes = [shapely.unary_union(region_gdf['geometry']).envelope]
crs = region_gdf.crs

In [7]:
startdate = '1994-01-01'
enddate = '2024-12-01'
cutoffdate = '2024-01-01'

In [None]:
region_temp_data = era5_csu.load_clipped_data_by_daterange(
    startdate = startdate,
    enddate = enddate,
    var = era5_csu.VAR_TEMP,
    catalog_df = era5_catalog_df,
    shapes = shapes,
    crs = crs,
    njobs = 120,
)

In [None]:
region_prec_data = era5_csu.load_clipped_data_by_daterange(
    startdate = startdate,
    enddate = enddate,
    var = era5_csu.VAR_PREC,
    catalog_df = era5_catalog_df,
    shapes = shapes,
    crs = crs,
    njobs = 120,
)

In [None]:
region_prec_data.data.shape

In [20]:
def plot_the_df(plot_df, cols=None, x=None, title=None):
    """Plot z-score-normalised time series, showing the original values on hover.

    For every base column ``col`` the trace's y-values come from
    ``plot_df[col + "_z"]``. ``plot_df[col]`` is attached as ``customdata`` so the
    hover label shows the un-normalised value.

    Parameters
    ----------
    plot_df : pandas.DataFrame
        Must contain both ``col`` and ``col + "_z"`` for every plotted column.
    cols : iterable of str, optional
        Base column names to plot. Defaults to every column of ``plot_df`` that
        does not end in ``"_z"`` and has a matching ``"_z"`` companion.
        Several cells call ``plot_the_df(plot_df=plot_df)`` and relied on
        this default.
    x : array-like, optional
        Shared x values (dates). Defaults to ``plot_df.index``.
    title : str, optional
        Figure title. Defaults to the original fixed title.

    The figure is displayed with ``fig.show()``; nothing is returned.
    """
    if cols is None:
        cols = [
            c for c in plot_df.columns
            if not str(c).endswith("_z") and f"{c}_z" in plot_df.columns
        ]
    if x is None:
        x = plot_df.index
    if title is None:
        title = "Multiple time series (z-score normalized) — hover shows original values"

    fig = go.Figure()

    for col in cols:
        zcol = col + "_z"
        # Plot the z-score, but pass the original values as customdata so the
        # hover label can display them.
        fig.add_trace(
            go.Scatter(
                x = x,
                y = plot_df[zcol],
                name=f"{col} (z-score)",
                mode="lines",
                customdata=np.column_stack([plot_df[col].values]),  # shape (N, 1)
                hovertemplate=(
                    "%{x}<br>" +
                    f"{col} (orig): "+"%{customdata[0]:.3f}<extra></extra>"
                ),
            )
        )

    fig.update_layout(
        title=title,
        yaxis_title="Normalized (z-score)",
        hovermode="x unified",
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        width=1200,   # in pixels
        height=600,   # in pixels
        # Date axis with quick-zoom buttons and a range slider.
        xaxis=dict(
            title=dict(text="Date"),
            rangeselector=dict(buttons=[
                dict(count=7, label="1w", step="day", stepmode="backward"),
                dict(count=1, label="1m", step="month", stepmode="backward"),
                dict(count=6, label="6m", step="month", stepmode="backward"),
                dict(step="all"),
            ]),
            rangeslider=dict(visible=True),
            type="date",
        ),
    )

    fig.show()

In [21]:
def shift_right(arr, n, fill_value=0):
    """Advance a time series by ``n`` steps along axis 0, padding the tail.

    Despite the name, values move *toward* index 0:
    ``out[t] == arr[t + n]`` for ``t < len(arr) - n``, and the last ``n``
    entries are ``fill_value``. In this notebook this lines up the suitability
    of a season starting ``n`` days later with sowing day ``t``.

    Parameters
    ----------
    arr : numpy.ndarray
        Array whose first axis is time. Any number of trailing axes is
        accepted; the notebook passes ``(time, lat, lon)``.
    n : int
        Non-negative shift. Values larger than ``len(arr)`` give an array made
        entirely of ``fill_value``, still with the input's shape. Previously
        this case returned a longer array, which broke the following ``&``.
    fill_value : scalar, default 0
        Value used for the padded tail.

    Returns
    -------
    numpy.ndarray
        Same shape as ``arr``. The dtype follows NumPy's concatenation promotion
        of ``arr`` and an array filled with ``fill_value``, as before.

    Raises
    ------
    ValueError
        If ``n`` is negative.
    """
    arr = np.asarray(arr)
    n = int(n)
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    n = min(n, arr.shape[0])
    padding = np.full(shape=(n, *arr.shape[1:]), fill_value=fill_value)
    return np.concatenate([arr[n:], padding])


def updated_compute_csu(
    region_temp_data,
    region_prec_data,
    t_base,
    required_gdd_for_pollination,
    required_gdd_for_maturity,
    max_tolerable_temp,
    min_tolerable_temp,
    max_duration,
    min_total_prec_till_maturity,
    max_consecutive_dry_days,
    dryspell_threshold,
    days_to_germination,
    min_total_prec_for_germination,
):
    """Compute daily crop-suitability flags from pollination, maturity and germination criteria.

    For every start day and pixel this combines:
      * the days needed to accumulate ``required_gdd_for_pollination`` and
        ``required_gdd_for_maturity`` growing-degree-days, from
        ``csu.calculate_days_to_maturity``. Infinite results are stored as -1;
      * the consecutive-dry-day count (precipitation below
        ``dryspell_threshold``), looked up at the pollination offset with
        ``csu.lookup``;
      * the total precipitation over the days to maturity, from
        ``csu.total_prec_in_days_to_maturity`` on the cumulative sum.

    A day is "suitable" when all of these hold:
      * maturity is reached (not -1) within ``max_duration`` days;
      * the dry-day count at pollination is below ``max_consecutive_dry_days``;
      * precipitation till maturity is at least ``min_total_prec_till_maturity``.

    The final flag for day ``t`` is ``suitable[t + days_to_germination]`` (via
    ``shift_right``) AND'ed with the precipitation in the
    ``days_to_germination``-day window from ``t`` being strictly greater than
    ``min_total_prec_for_germination``.

    Parameters
    ----------
    region_temp_data, region_prec_data :
        Objects exposing ``.values``. Axis 0 is cumsummed and shifted, so it
        is presumably time, i.e. (time, lat, lon) xarray DataArrays — TODO
        confirm.
    t_base, required_gdd_for_pollination, required_gdd_for_maturity,
    max_tolerable_temp, min_tolerable_temp :
        Forwarded to ``csu.calculate_days_to_maturity``.
    max_duration, min_total_prec_till_maturity, max_consecutive_dry_days,
    dryspell_threshold, days_to_germination, min_total_prec_for_germination :
        Thresholds for the criteria described above.

    Returns
    -------
    tuple
        ``(final_suitability_days, suitable_days, days_to_maturity,
        days_to_pollination, total_prec_in_days_to_maturity,
        total_prec_in_days_to_germination, consecutive_dry_days)``.
        ``suitable_days`` is uint8 0/1. ``days_to_maturity`` and
        ``days_to_pollination`` are int, with -1 meaning "GDD target never reached".
    """
    days_to_pollination, _ = \
    csu.calculate_days_to_maturity(
        temp_ts = region_temp_data.values,
        t_base = t_base,
        required_gdd = required_gdd_for_pollination,
        max_tolerable_temp = max_tolerable_temp,
        min_tolerable_temp = min_tolerable_temp,
    )

    days_to_maturity, _ = \
    csu.calculate_days_to_maturity(
        temp_ts = region_temp_data.values,
        t_base = t_base,
        required_gdd = required_gdd_for_maturity,
        max_tolerable_temp = max_tolerable_temp,
        min_tolerable_temp = min_tolerable_temp,
    )

    # inf marks a GDD target that is never reached. Store it as -1 so the
    # arrays can be cast to int; the suitability mask below tests for -1.
    days_to_maturity[days_to_maturity == np.inf] = -1
    days_to_maturity = days_to_maturity.astype(int)

    days_to_pollination[days_to_pollination == np.inf] = -1
    days_to_pollination = days_to_pollination.astype(int)

    # Consecutive-dry-day count on the 0/1 "dry day" flags (prec < dryspell_threshold).
    # swapaxes(0, -1) moves axis 0 to the end for the helper, and moves it back afterwards.
    consecutive_dry_days = rsutils.rich_data_filter.get_continuous_sum(
        (region_prec_data.values < dryspell_threshold).astype(int).swapaxes(0, -1), spot_fill = False,
    ).swapaxes(0, -1)

    # NOTE(review): days_to_pollination holds -1 where pollination is never
    # reached; how csu.lookup treats a -1 shift is not visible here — confirm.
    consecutive_dry_days_at_pollination = \
    csu.lookup(
        value_arr = consecutive_dry_days.astype(float),
        shift_arr = days_to_pollination,
    )

    total_prec_in_days_to_maturity = \
    csu.total_prec_in_days_to_maturity(
        cumsum_prec_ts = region_prec_data.values.cumsum(axis=0),
        days_to_maturity = days_to_maturity,
    )

    suitable_days = np.zeros(shape=days_to_maturity.shape, dtype=np.uint8)
    suitable_days[np.where(
        (days_to_maturity <= max_duration)
        & (days_to_maturity != -1)
        & (consecutive_dry_days_at_pollination < max_consecutive_dry_days)
        & (total_prec_in_days_to_maturity >= min_total_prec_till_maturity)
    )] = 1

    # Precipitation over a fixed window of days_to_germination days from each day.
    # NOTE(review): the same cumulative sum is recomputed here (memory-heavy for long records).
    total_prec_in_days_to_germination = \
    csu.total_prec_in_days_to_maturity(
        cumsum_prec_ts = region_prec_data.values.cumsum(axis=0),
        days_to_maturity = np.full(shape=days_to_maturity.shape, fill_value=days_to_germination, dtype=int),
    )
    
    # final[t] = suitable[t + days_to_germination] AND enough rain to germinate from t.
    # NOTE(review): this uses a strict '>' while the maturity test uses '>=' — confirm intent.
    final_suitability_days = shift_right(suitable_days, days_to_germination, 0) & (total_prec_in_days_to_germination > min_total_prec_for_germination)
    
    return final_suitability_days, \
        suitable_days, \
        days_to_maturity, \
        days_to_pollination, \
        total_prec_in_days_to_maturity, \
        total_prec_in_days_to_germination, \
        consecutive_dry_days

In [13]:
# Thresholds for one pollination/maturity/germination suitability run.
# Presumably degC for temperatures, mm for precipitation and days for
# durations — TODO confirm against csu.
t_base = 8
required_gdd_for_pollination = 842
required_gdd_for_maturity = 2400
max_tolerable_temp = 45
min_tolerable_temp = 0
max_duration = 500
min_total_prec_till_maturity = 450
max_consecutive_dry_days = 5
dryspell_threshold = 1
days_to_germination = 3
min_total_prec_for_germination = 20

(
    final_suitability_days,
    suitable_days,
    days_to_maturity,
    days_to_pollination,
    total_prec_in_days_to_maturity,
    total_prec_in_days_to_germination,
    consecutive_dry_days,
) = updated_compute_csu(
    region_temp_data = region_temp_data,
    region_prec_data = region_prec_data,
    t_base = t_base,
    required_gdd_for_pollination = required_gdd_for_pollination,
    required_gdd_for_maturity = required_gdd_for_maturity,
    max_tolerable_temp = max_tolerable_temp,
    min_tolerable_temp = min_tolerable_temp,
    max_duration = max_duration,
    min_total_prec_till_maturity = min_total_prec_till_maturity,
    max_consecutive_dry_days = max_consecutive_dry_days,
    dryspell_threshold = dryspell_threshold,
    days_to_germination = days_to_germination,
    min_total_prec_for_germination = min_total_prec_for_germination,
)

In [15]:
def get_cutoff_index(
    dates,
    cutoffdate,
):
    """Return the position of ``cutoffdate`` within ``dates``.

    Parameters
    ----------
    dates : array-like of datetimes
        For example an xarray ``valid_time`` coordinate; converted with ``np.asarray``.
    cutoffdate : str or datetime-like
        Anything ``pd.Timestamp`` accepts.

    Returns
    -------
    int
        Index of the first element equal to ``cutoffdate``.

    Raises
    ------
    IndexError
        If ``cutoffdate`` does not occur in ``dates``. The exception type is the
        same as before, but the message now names the missing date instead of
        "index 0 is out of bounds".
    """
    target = pd.Timestamp(cutoffdate)
    matches = np.flatnonzero(np.asarray(dates) == target)
    if matches.size == 0:
        raise IndexError(f"cutoff date {target} not found in dates")
    return int(matches[0])

In [17]:
CUTOFFDATE = '2024-01-01'

cutoffindex = get_cutoff_index(
    dates = region_prec_data.valid_time,
    cutoffdate = CUTOFFDATE,
)

In [None]:
region_prec_data.shape

In [28]:
x, y = 10, 4

df = pd.DataFrame(data = {
    # 'temperature': region_temp_data[:cutoffindex, x, y],
    # 'precipitation': region_prec_data[:cutoffindex, x, y],
    # 'days_to_maturity': days_to_maturity[:cutoffindex, x, y],
    # 'days_to_pollination': days_to_pollination[:cutoffindex, x, y],
    # 'consecutive_dry_days': consecutive_dry_days[:cutoffindex, x, y],
    f'total_prec_in_days_to_germination>{min_total_prec_for_germination}': total_prec_in_days_to_germination[:cutoffindex, x, y] > min_total_prec_for_germination,
    'suitable_days': suitable_days[:cutoffindex, x, y],
    'final_suitability_days': final_suitability_days[:cutoffindex, x, y],
})

norm_df = df.apply(
    lambda x: zscore(x, nan_policy='omit')).rename(columns=lambda c: f"{c}_z"
)

plot_df = pd.concat([df, norm_df], axis=1)

In [None]:
plot_the_df(plot_df=plot_df, cols=df.columns, x=region_prec_data.valid_time.values[:cutoffindex])

In [None]:
filename = (
    'sum-suitable-days'
    f'_tbase={t_base}'
    f'_reqgddpol={required_gdd_for_pollination}'
    f'_reqgddmat={required_gdd_for_maturity}'
    f'_maxtoltemp={max_tolerable_temp}'
    f'_mintoltemp={min_tolerable_temp}'
    f'_minreqprecmat={min_total_prec_till_maturity}'
    f'_maxduration={max_duration}'
    f'_dryspellthreshold={dryspell_threshold}'
    f'_maxconsecdryspell={max_consecutive_dry_days}'
    f'_daystogerm={days_to_germination}'
    f'_minreqprecgerm={min_total_prec_for_germination}'
    '.tif'
)

filename

In [None]:
output_folderpath

In [32]:
export_folderpath = output_folderpath

In [22]:
sum_suitable_days = xr.DataArray(
    final_suitability_days[:cutoffindex].sum(axis=(0)).astype(np.uint16),
    coords = {
        'latitude': region_temp_data.latitude,
        'longitude': region_temp_data.longitude,
    },
    dims = ('latitude', 'longitude'),
)

sum_suitable_days = sum_suitable_days.rio.set_spatial_dims('longitude', 'latitude')
sum_suitable_days = sum_suitable_days.rio.write_crs('epsg:4326')
sum_suitable_days.rio.to_raster(os.path.join(export_folderpath, filename))

In [None]:
days_to_maturity = days_to_maturity.astype(float)
days_to_maturity[final_suitability_days != 1] = np.nan
mean_days_to_maturity = np.nanmean(days_to_maturity, axis=0)
mean_days_to_maturity.shape

In [None]:
days_to_maturity[final_suitability_days != 1]

In [None]:
mean_days_to_maturity_da = xr.DataArray(
    mean_days_to_maturity,
    coords = {
        'latitude': region_temp_data.latitude,
        'longitude': region_temp_data.longitude,
    },
    dims = ('latitude', 'longitude'),
)

mean_days_to_maturity_da = mean_days_to_maturity_da.rio.set_spatial_dims('longitude', 'latitude')
mean_days_to_maturity_da = mean_days_to_maturity_da.rio.write_crs('epsg:4326')
mean_days_to_maturity_da.rio.to_raster(os.path.join(export_folderpath, 'mean_days_to_maturity.tif'))

In [106]:
days_to_maturity_copy = days_to_maturity.copy().astype(float)

days_to_maturity_copy[days_to_maturity_copy == -1] = np.nan
days_to_maturity_copy[final_suitability_days != 1] = np.nan

mean_days_to_maturity = np.nanmean(days_to_maturity_copy, axis=0)

In [None]:
days_to_maturity[days_to_maturity != -1].min()

In [None]:
final_suitability_days[:cutoffindex].sum(axis=(0)).min()

In [32]:
# Alternative suitability score: each suitable sowing day contributes
# (1 / days_to_maturity)**2, so shorter seasons weigh more.
# The -1 "never matured" sentinel (and any non-positive or non-finite value)
# must be excluded BEFORE the transform. Checking for -1 afterwards never
# matches, because (1 / -1)**2 == 1, and 0 days would give inf.
days_to_maturity_f = np.asarray(days_to_maturity, dtype=float)
valid_maturity = np.isfinite(days_to_maturity_f) & (days_to_maturity_f > 0)

alt_suitability_measure = np.zeros_like(days_to_maturity_f)
np.divide(1.0, days_to_maturity_f, out=alt_suitability_measure, where=valid_maturity)
alt_suitability_measure **= 2

alt_suitability_measure[final_suitability_days != 1] = 0

In [33]:
sum_suitable_days = xr.DataArray(
    alt_suitability_measure[:cutoffindex].sum(axis=(0)).astype(np.float64),
    coords = {
        'latitude': region_temp_data.latitude,
        'longitude': region_temp_data.longitude,
    },
    dims = ('latitude', 'longitude'),
)

sum_suitable_days = sum_suitable_days.rio.set_spatial_dims('longitude', 'latitude')
sum_suitable_days = sum_suitable_days.rio.write_crs('epsg:4326')
sum_suitable_days.rio.to_raster(os.path.join(export_folderpath, 'alt_' + filename))

In [33]:
final_suitability_days_da = xr.DataArray(
    final_suitability_days[:cutoffindex].astype(np.float64),
    coords = {
        'valid_time': region_temp_data.valid_time[:cutoffindex],
        'latitude': region_temp_data.latitude,
        'longitude': region_temp_data.longitude,
    },
    dims = ('valid_time', 'latitude', 'longitude'),
)

In [34]:
# final_suitability_days_da

In [35]:
final_suitability_days_da_doy_mean = final_suitability_days_da.groupby('valid_time.dayofyear').mean('valid_time')
# final_suitability_days_da_doy_mean

In [36]:
import sklearn.cluster
import rsutils.utils

In [None]:
final_suitability_days_da_doy_mean.values.shape

In [None]:
n_ts, height, width = final_suitability_days_da_doy_mean.values.shape

final_suitability_days_2d = final_suitability_days_da_doy_mean.values.reshape(n_ts, height*width).swapaxes(0, 1)
final_suitability_days_2d.shape

In [39]:
def relabel_clusters_by_count(cluster_ids:np.ndarray):
    """Renumber cluster labels so 0 is the most populous cluster, 1 the next, and so on.

    Ties in cluster size are broken deterministically: the smaller original id
    gets the smaller new id. The previous pandas ``sort_values`` used a
    non-stable sort, so tie order was not guaranteed.

    Parameters
    ----------
    cluster_ids : numpy.ndarray
        Labels of any shape (anything ``np.unique`` can sort).

    Returns
    -------
    numpy.ndarray
        Int array with the same shape as ``cluster_ids``, holding new labels
        ``0 .. n_unique - 1``.
    """
    cluster_ids = np.asarray(cluster_ids)
    _, inverse, counts = np.unique(cluster_ids, return_inverse=True, return_counts=True)
    # A stable sort on descending count keeps ascending-original-id order within ties.
    order = np.argsort(-counts, kind='stable')
    new_id_for_unique = np.empty(order.shape[0], dtype=int)
    new_id_for_unique[order] = np.arange(order.shape[0])
    # reshape: np.unique's inverse is flat in some NumPy versions, input-shaped in others.
    return new_id_for_unique[inverse].reshape(cluster_ids.shape).astype(int)

In [None]:
filename = 'doy_mean_final_suitability'

n_clusters = 4
nrows, ncols = 2, 2

cluster_ids = sklearn.cluster.MiniBatchKMeans(
    n_clusters = n_clusters,
    random_state = 42,
).fit(final_suitability_days_2d).labels_


cluster_ids = relabel_clusters_by_count(cluster_ids=cluster_ids)


rsutils.utils.plot_clustered_lineplots(
    crop_name = '',
    band_name = 'mean suitability',
    timeseries = final_suitability_days_2d,
    cluster_ids = cluster_ids,
    y_min = -0.1,
    y_max = 1.1,
    nrows = nrows,
    ncols = ncols,
    x = range(1, 367),
    x_label = 'doy',
    save_filepath = os.path.join(export_folderpath, f'{filename}.png'),
    alpha = 0.1,
)

In [None]:
filename

In [43]:
cluster_ids_da = xr.DataArray(
    cluster_ids.reshape(height, width).astype(np.uint8),
    coords = {
        'latitude': region_temp_data.latitude,
        'longitude': region_temp_data.longitude,
    },
    dims = ('latitude', 'longitude'),
)

cluster_ids_da = cluster_ids_da.rio.set_spatial_dims('longitude', 'latitude')
cluster_ids_da = cluster_ids_da.rio.write_crs('epsg:4326')
cluster_ids_da.rio.to_raster(os.path.join(export_folderpath, 'cluster_ids_' + filename + '.tif'))

In [76]:
t_base = 8
required_gdd_for_pollination = 842 # to pollination
required_gdd_for_maturity = 2400
required_gdd_for_emergence = 100 # to 150 , 90-120 T-base 10, recalculated for T-base 8
max_tolerable_temp = 45
min_tolerable_temp = 0

days_to_pollination, _ = \
csu.calculate_days_to_maturity(
    temp_ts = region_temp_data.values,
    t_base = t_base,
    required_gdd = required_gdd_for_pollination,
    max_tolerable_temp = max_tolerable_temp,
    min_tolerable_temp = min_tolerable_temp,
)

days_to_maturity, _ = \
csu.calculate_days_to_maturity(
    temp_ts = region_temp_data.values,
    t_base = t_base,
    required_gdd = required_gdd_for_maturity,
    max_tolerable_temp = max_tolerable_temp,
    min_tolerable_temp = min_tolerable_temp,
)

days_to_emergence, _ = \
csu.calculate_days_to_maturity(
    temp_ts = region_temp_data.values,
    t_base = t_base,
    required_gdd = required_gdd_for_emergence,
    max_tolerable_temp = max_tolerable_temp,
    min_tolerable_temp = min_tolerable_temp,
)

In [20]:
days_to_maturity[days_to_maturity == np.inf] = -1
days_to_maturity = days_to_maturity.astype(int)

days_to_pollination[days_to_pollination == np.inf] = -1
days_to_pollination = days_to_pollination.astype(int)

In [None]:
region_prec_data.values.shape

In [12]:
dryspell_threshold = 1 # mm

consecutive_dry_days = rsutils.rich_data_filter.get_continuous_sum(
    (region_prec_data.values < dryspell_threshold).astype(int).swapaxes(0, -1), spot_fill = False,
).swapaxes(0, -1)

In [21]:
consecutive_dry_days_at_pollination = \
csu.lookup(
    value_arr = consecutive_dry_days.astype(float),
    shift_arr = days_to_pollination,
)

In [22]:
total_prec_in_days_to_maturity = \
csu.total_prec_in_days_to_maturity(
    cumsum_prec_ts = region_prec_data.values.cumsum(axis=0),
    days_to_maturity = days_to_maturity,
)

In [37]:
total_prec_in_3days = \
csu.total_prec_in_days_to_maturity(
    cumsum_prec_ts = region_prec_data.values.cumsum(axis=0),
    days_to_maturity = np.full(shape=days_to_maturity.shape, fill_value=3, dtype=int),
)

In [23]:
def get_cutoff_index(
    dates,
    cutoffdate,
):
    """Return the position of ``cutoffdate`` within ``dates``.

    Parameters
    ----------
    dates : array-like of datetimes
        For example an xarray ``valid_time`` coordinate; converted with ``np.asarray``.
    cutoffdate : str or datetime-like
        Anything ``pd.Timestamp`` accepts.

    Returns
    -------
    int
        Index of the first element equal to ``cutoffdate``.

    Raises
    ------
    IndexError
        If ``cutoffdate`` does not occur in ``dates``. The exception type is the
        same as before, but the message now names the missing date instead of
        "index 0 is out of bounds".
    """
    target = pd.Timestamp(cutoffdate)
    matches = np.flatnonzero(np.asarray(dates) == target)
    if matches.size == 0:
        raise IndexError(f"cutoff date {target} not found in dates")
    return int(matches[0])


In [24]:
CUTOFFDATE = '2024-01-01'

cutoffindex = get_cutoff_index(
    dates = region_prec_data.valid_time,
    cutoffdate = CUTOFFDATE,
)

In [26]:
valid_days_to_maturity = days_to_maturity[:cutoffindex]
valid_consecutive_dry_days_at_pollination = consecutive_dry_days_at_pollination[:cutoffindex]
valid_total_prec_in_days_to_maturity = total_prec_in_days_to_maturity[:cutoffindex]

In [68]:
max_duration = 320
max_consecutive_dry_days = 5
min_total_prec_till_maturity = 450
min_total_prec_till_pollination = 300

suitable_days = np.zeros(shape=days_to_maturity.shape, dtype=np.uint8)
suitable_days[np.where(
    (days_to_maturity <= max_duration)
    & (days_to_maturity != -1)
    & (consecutive_dry_days_at_pollination < max_consecutive_dry_days)
    & (total_prec_in_days_to_maturity >= min_total_prec_till_maturity)
)] = 1

In [69]:
def shift_right(arr, n, fill_value=0):
    """Advance a time series by ``n`` steps along axis 0, padding the tail.

    Despite the name, values move *toward* index 0:
    ``out[t] == arr[t + n]`` for ``t < len(arr) - n``, and the last ``n``
    entries are ``fill_value``.

    Parameters
    ----------
    arr : numpy.ndarray
        Array whose first axis is time. Any number of trailing axes is accepted.
    n : int
        Non-negative shift. Values larger than ``len(arr)`` give an array made
        entirely of ``fill_value``, still with the input's shape. Previously
        this case returned a longer array.
    fill_value : scalar, default 0
        Value used for the padded tail.

    Returns
    -------
    numpy.ndarray
        Same shape as ``arr``. The dtype follows NumPy's concatenation promotion,
        as before.

    Raises
    ------
    ValueError
        If ``n`` is negative.
    """
    arr = np.asarray(arr)
    n = int(n)
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    n = min(n, arr.shape[0])
    padding = np.full(shape=(n, *arr.shape[1:]), fill_value=fill_value)
    return np.concatenate([arr[n:], padding])

In [71]:
final_suitability_days = shift_right(suitable_days, 7, 0) & (total_prec_in_3days>20)

In [51]:
x, y = 10, 4

In [None]:
(total_prec_in_3days[:cutoffindex, x, y]>20).astype(int)

In [None]:
suitable_days[:, x, y]

In [None]:
A = np.arange(10)
B = np.arange(10)
A, B

In [None]:
n = 3

np.concatenate([B[n:], np.full(shape=(n), fill_value=-1)])

In [72]:
x, y = 10, 4

df = pd.DataFrame(data = {
    # 'temperature': region_temp_data[:cutoffindex, x, y],
    # 'precipitation': region_prec_data[:cutoffindex, x, y],
    # 'days_to_maturity': days_to_maturity[:cutoffindex, x, y],
    # 'days_to_pollination': days_to_pollination[:cutoffindex, x, y],
    # 'consecutive_dry_days': consecutive_dry_days[:cutoffindex, x, y],
    'total_prec_in_3days>20': total_prec_in_3days[:cutoffindex, x, y]>20,
    'suitable_days': suitable_days[:cutoffindex, x, y],
    'final_suitability_days': final_suitability_days[:cutoffindex, x, y],
})

norm_df = df.apply(
    lambda x: zscore(x, nan_policy='omit')).rename(columns=lambda c: f"{c}_z"
)

plot_df = pd.concat([df, norm_df], axis=1)

In [None]:
plot_df

In [None]:
plot_the_df(plot_df=plot_df)

In [18]:
CUTOFFDATE = '2024-01-01'

cutoffindex = get_cutoff_index(
    dates = region_prec_data.valid_time,
    cutoffdate = CUTOFFDATE,
)

In [19]:
days_to_maturity[days_to_maturity == np.inf] = -1
days_to_maturity = days_to_maturity.astype(int)

In [None]:
consecutive_dry_days.shape

In [None]:
days_to_maturity.shape

In [22]:
consecutive_dry_days_at_maturity = \
csu.lookup(
    value_arr = consecutive_dry_days.astype(float),
    shift_arr = days_to_maturity,
)

In [None]:
np.nanmax(consecutive_dry_days_at_maturity)

In [26]:
total_prec_in_days_to_maturity = \
csu.total_prec_in_days_to_maturity(
    cumsum_prec_ts = region_prec_data.values.cumsum(axis=0),
    days_to_maturity = days_to_maturity,
)

In [34]:
valid_days_to_maturity = days_to_maturity[:cutoffindex]
valid_consecutive_dry_days_at_maturity = consecutive_dry_days_at_maturity[:cutoffindex]
valid_total_prec_in_days_to_maturity = total_prec_in_days_to_maturity[:cutoffindex]

In [41]:
max_duration = 360
max_consecutive_dry_days = 5
min_total_prec_till_pollination = 300

suitable_days = np.zeros(shape=valid_days_to_maturity.shape, dtype=np.uint8)
suitable_days[np.where(
    (valid_days_to_maturity <= max_duration)
    & (valid_days_to_maturity != -1)
    & (valid_consecutive_dry_days_at_maturity < max_consecutive_dry_days)
    & (valid_total_prec_in_days_to_maturity >= min_total_prec_till_pollination)
)] = 1

In [42]:
x, y = 10, 4

df = pd.DataFrame(data = {
    'temperature': region_temp_data[:cutoffindex, x, y],
    'precipitation': region_prec_data[:cutoffindex, x, y],
    'days_to_maturity': days_to_maturity[:cutoffindex, x, y],
    'consecutive_dry_days_at_maturity': consecutive_dry_days_at_maturity[:cutoffindex, x, y],
    'suitable_days': suitable_days[:, x, y],
})

norm_df = df.apply(
    lambda x: zscore(x, nan_policy='omit')).rename(columns=lambda c: f"{c}_z"
)

plot_df = pd.concat([df, norm_df], axis=1)

In [None]:
plot_the_df(plot_df=plot_df)

In [11]:
def compute_csu(
    region_temp_data,
    region_prec_data,
    t_base,
    required_gdd,
    max_tolerable_temp,
    min_tolerable_temp,
    max_duration,
    min_total_prec,
    max_total_prec,
    cutoffdate = '2024-01-01'
):
    """Maturity-only crop suitability, truncated at ``cutoffdate``.

    Days to maturity come from ``csu.calculate_days_to_maturity``; infinite
    results are stored as -1. Precipitation over that span comes from
    ``csu.total_prec_in_days_to_maturity`` on the cumulative sum along axis 0.
    Both series are cut at the position of ``cutoffdate`` in
    ``region_temp_data.valid_time``. A day is suitable when maturity is reached
    within ``max_duration`` days and the precipitation lies in
    ``[min_total_prec, max_total_prec]``.

    Returns
    -------
    tuple of 7 arrays, all covering only the days before the cutoff:
        ``suitable_days`` (uint8 0/1), ``valid_days_to_maturity`` (int, -1 =
        never), ``valid_total_prec_in_days_to_maturity``, then the four boolean
        criterion masks: duration <= max, maturity reached, precip >= min,
        precip <= max.
    """
    raw_days, _ = csu.calculate_days_to_maturity(
        temp_ts = region_temp_data.values,
        t_base = t_base,
        required_gdd = required_gdd,
        max_tolerable_temp = max_tolerable_temp,
        min_tolerable_temp = min_tolerable_temp,
    )
    # inf = GDD target never reached -> -1 sentinel, so the array can be int.
    raw_days[raw_days == np.inf] = -1
    days_to_maturity = raw_days.astype(int)

    prec_to_maturity = csu.total_prec_in_days_to_maturity(
        cumsum_prec_ts = region_prec_data.values.cumsum(axis=0),
        days_to_maturity = days_to_maturity,
    )

    stop = get_cutoff_index(
        dates = region_temp_data.valid_time,
        cutoffdate = cutoffdate,
    )
    valid_days_to_maturity = days_to_maturity[:stop]
    valid_total_prec_in_days_to_maturity = prec_to_maturity[:stop]

    # Each criterion is computed once; it feeds the combined flag and is also returned.
    within_duration = valid_days_to_maturity <= max_duration
    reaches_maturity = valid_days_to_maturity != -1
    enough_prec = valid_total_prec_in_days_to_maturity >= min_total_prec
    not_too_much_prec = valid_total_prec_in_days_to_maturity <= max_total_prec

    suitable_days = (
        within_duration & reaches_maturity & enough_prec & not_too_much_prec
    ).astype(np.uint8)

    return (
        suitable_days,
        valid_days_to_maturity,
        valid_total_prec_in_days_to_maturity,
        within_duration,
        reaches_maturity,
        enough_prec,
        not_too_much_prec,
    )

In [12]:
T_BASE = 10
REQUIRED_GDD = 3000
MAX_TOLERABLE_TEMP = 45
MIN_TOLERABLE_TEMP = 0
REQUIRED_PRECP = 500 # mm / total growing period
MAX_CROP_DURATION = 365 # days
MIN_TOTAL_PREC = 800
MAX_TOTAL_PREC = 1300
CUTOFFDATE = '2024-01-01'

In [None]:
t_bases = [8, 10]
required_gdds = [1500, 1800, 2400, 2800, 3000, 3500]
max_tolerable_temps = [38]
min_tolerable_temps = [0]
max_durations = [80, 160, 320]
min_total_precs = [450, 500, 800]
max_total_precs = [1300]

params = [
    (_t_base, _req_gdd, _max_tol_temp, _min_tol_temp, _max_dur, _min_tot_prec, _max_tot_prec)
    for _t_base in t_bases
    for _req_gdd in required_gdds
    for _max_tol_temp in max_tolerable_temps
    for _min_tol_temp in min_tolerable_temps
    for _max_dur in max_durations
    for _min_tot_prec in min_total_precs
    for _max_tot_prec in max_total_precs
]

len(params)

In [None]:
108 * 6 / 60

In [12]:
mean_temperature = xr.DataArray(
    region_temp_data.values.mean(axis=(0)),
    coords = {
        'latitude': region_temp_data.latitude,
        'longitude': region_temp_data.longitude,
    },
    dims = ('latitude', 'longitude'),
)

mean_temperature = mean_temperature.rio.set_spatial_dims('longitude', 'latitude')
mean_temperature = mean_temperature.rio.write_crs('epsg:4326')
mean_temperature.rio.to_raster(os.path.join(output_folderpath, 'mean_temperature.tif'))

In [13]:
sum_precipitation = xr.DataArray(
    region_prec_data.values.sum(axis=(0)),
    coords = {
        'latitude': region_temp_data.latitude,
        'longitude': region_temp_data.longitude,
    },
    dims = ('latitude', 'longitude'),
)

sum_precipitation = sum_precipitation.rio.set_spatial_dims('longitude', 'latitude')
sum_precipitation = sum_precipitation.rio.write_crs('epsg:4326')
sum_precipitation.rio.to_raster(os.path.join(output_folderpath, 'sum_precipitation.tif'))

In [17]:
# export_folderpath = '/gpfs/data1/cmongp2/sasirajann/nh_crop_calendar/crop_calendar/data/era5_csu/malawi/experiments'

# os.makedirs(export_folderpath, exist_ok=True)

# data = {
#     't_base' : [],
#     'required_gdd' : [],
#     'max_tolerable_temp' : [],
#     'min_tolerable_temp' : [],
#     'max_duration' : [],
#     'min_total_prec' : [],
#     'max_total_prec' : [],
#     'valid_days_to_maturity (sum)' : [],
#     'valid_total_prec_in_days_to_maturity (sum)' : [],
#     'days_to_maturity_under_max_duration (sum)' : [],
#     'days_to_maturity_valid (sum)' : [],
#     'total_prec_above_minimum (sum)' : [],
#     'total_prec_under_maximum (sum)' : [],
#     'suitable_days (sum)': [],
# }

# for t_base, required_gdd, max_tolerable_temp, \
#     min_tolerable_temp, max_duration, min_total_prec, \
#     max_total_prec in tqdm.tqdm(params):

#     suitable_days, \
#     valid_days_to_maturity, \
#     valid_total_prec_in_days_to_maturity, \
#     days_to_maturity_under_max_duration, \
#     days_to_maturity_valid, \
#     total_prec_above_minimum, \
#     total_prec_under_maximum, \
#     = compute_csu(
#         malawi_temp_data = malawi_temp_data,
#         malawi_prec_data = malawi_prec_data,
#         t_base = t_base,
#         required_gdd = required_gdd,
#         max_tolerable_temp = max_tolerable_temp,
#         min_tolerable_temp = min_tolerable_temp,
#         max_duration = max_duration,
#         min_total_prec = min_total_prec,
#         max_total_prec = max_total_prec,
#         cutoffdate = CUTOFFDATE
#     )

#     data['t_base'].append(t_base)
#     data['required_gdd'].append(required_gdd)
#     data['max_tolerable_temp'].append(max_tolerable_temp)
#     data['min_tolerable_temp'].append(min_tolerable_temp)
#     data['max_duration'].append(max_duration)
#     data['min_total_prec'].append(min_total_prec)
#     data['max_total_prec'].append(max_total_prec)
#     data['valid_days_to_maturity (sum)'].append(np.nansum(valid_days_to_maturity))
#     data['valid_total_prec_in_days_to_maturity (sum)'].append(np.nansum(valid_total_prec_in_days_to_maturity))
#     data['days_to_maturity_under_max_duration (sum)'].append(np.nansum(days_to_maturity_under_max_duration))
#     data['days_to_maturity_valid (sum)'].append(np.nansum(days_to_maturity_valid))
#     data['total_prec_above_minimum (sum)'].append(np.nansum(total_prec_above_minimum))
#     data['total_prec_under_maximum (sum)'].append(np.nansum(total_prec_under_maximum))
#     data['suitable_days (sum)'].append(np.nansum(suitable_days))

#     filename = (
#         f'sum-suitable-days_tbase={t_base}_'
#         f'reqgdd={required_gdd}_'
#         f'maxtoltemp={max_tolerable_temp}_'
#         f'mintoltemp={min_tolerable_temp}_'
#         f'minreqprec={min_total_prec}_'
#         f'maxreqprec={max_total_prec}_'
#         f'maxduration={max_duration}.tif'
#     )

#     sum_suitable_days = xr.DataArray(
#         suitable_days.sum(axis=(0)).astype(np.uint16),
#         coords = {
#             'latitude': malawi_temp_data.latitude,
#             'longitude': malawi_temp_data.longitude,
#         },
#         dims = ('latitude', 'longitude'),
#     )

#     sum_suitable_days = sum_suitable_days.rio.set_spatial_dims('longitude', 'latitude')
#     sum_suitable_days = sum_suitable_days.rio.write_crs('epsg:4326')
#     sum_suitable_days.rio.to_raster(os.path.join(export_folderpath, filename))

# experiment_stats_df = pd.DataFrame(data = data)

In [18]:
# experiment_stats_df[[
#     't_base',
#     'required_gdd',
#     # 'max_tolerable_temp',
#     # 'min_tolerable_temp',
#     'max_duration',
#     'min_total_prec',
#     # 'max_total_prec',
#     'suitable_days (sum)',
# ]].corr()

In [19]:
# scale = 5
# aspect_ratio = 1

# fig, ax = plt.subplots(figsize=(scale*aspect_ratio, scale))

# g = sns.histplot(
#     ax = ax,
#     data = experiment_stats_df,
#     x = 'max_tolerable_temp',
#     y = 'suitable_days (sum)',
#     bins = (25, 25),
#     cmap = 'Spectral_r',
#     cbar = True,
# )

In [20]:
# experiment_stats_df.groupby(by=['max_duration', 'min_total_prec'])['suitable_days (sum)'].sum().reset_index()

In [21]:
# experiment_stats_df[
#     (experiment_stats_df['min_total_prec'] == 450)
#     & (experiment_stats_df['max_duration'] == 160)
# ].sort_values(by='suitable_days (sum)', ascending=False)[[
#     't_base',
#     'required_gdd',
#     # 'max_tolerable_temp',
#     # 'min_tolerable_temp',
#     'max_duration',
#     'min_total_prec',
#     # 'max_total_prec',
#     'suitable_days (sum)',
# ]]

In [22]:
suitable_days, \
valid_days_to_maturity, \
valid_total_prec_in_days_to_maturity, \
days_to_maturity_under_max_duration, \
days_to_maturity_valid, \
total_prec_above_minimum, \
total_prec_under_maximum, \
= compute_csu(
    malawi_temp_data = malawi_temp_data,
    malawi_prec_data = malawi_prec_data,
    t_base = 8,
    required_gdd = 2400,
    max_tolerable_temp = 38,
    min_tolerable_temp = 0,
    max_duration = 320,
    min_total_prec = 450,
    max_total_prec = 1300,
    cutoffdate = CUTOFFDATE
)

In [23]:
cutoffindex = get_cutoff_index(
    dates = malawi_prec_data.valid_time,
    cutoffdate = CUTOFFDATE,
)

In [30]:
x, y = 10, 4

df = pd.DataFrame(data = {
    'temperature': malawi_temp_data[:cutoffindex, x, y],
    'precipitation': malawi_prec_data[:cutoffindex, x, y],
    'days_to_maturity': valid_days_to_maturity[:, x, y],
    'total_prec_in_d2m': valid_total_prec_in_days_to_maturity[:, x, y],
    'suitable_days': suitable_days[:, x, y],
})

norm_df = df.apply(
    lambda x: zscore(x, nan_policy='omit')).rename(columns=lambda c: f"{c}_z"
)

plot_df = pd.concat([df, norm_df], axis=1)

In [None]:
plot_df

In [None]:
fig = go.Figure()

for col in df.columns:
    zcol = col + "_z"
    # pass original values as customdata so hover can display them
    fig.add_trace(
        go.Scatter(
            x = malawi_prec_data.valid_time.values[:cutoffindex],
            y = plot_df[zcol],
            name=f"{col} (z-score)",
            mode="lines",
            customdata=np.column_stack([plot_df[col].values]),  # shape (N,1)
            hovertemplate=(
                "%{x}<br>" +
                # f"{col} " + "(z): %{y:.2f}<br>" +
                f"{col} (orig): "+"%{customdata[0]:.3f}<extra></extra>"
            )
        )
    )

# layout polish
fig.update_layout(
    title="Multiple time series (z-score normalized) — hover shows original values",
    xaxis_title="Date",
    yaxis_title="Normalized (z-score)",
    hovermode="x unified",
    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    width=1200,   # in pixels
    height=600,   # in pixels
)

# add range slider and selectors
fig.update_layout(
    xaxis=dict(rangeselector=dict(buttons=list([
        dict(count=7, label="1w", step="day", stepmode="backward"),
        dict(count=1, label="1m", step="month", stepmode="backward"),
        dict(count=6, label="6m", step="month", stepmode="backward"),
        dict(step="all")
    ])), rangeslider=dict(visible=True), type="date")
)

fig.show()

In [None]:
df.columns

In [None]:
start_index = get_cutoff_index(
    dates = malawi_prec_data.valid_time,
    cutoffdate = '2021-05-26',
)

list(malawi_prec_data.valid_time.values[start_index:start_index+10]), \
list(df['total_prec_in_d2m'].to_numpy()[start_index:start_index+10])

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Dual-axis view of two per-pixel series over one season, with the days
# flagged in df['suitable_days'] shaded green.
y1_var = 'days_to_maturity'
y2_var = 'total_prec_in_d2m'

# Alternative pairing:
# y1_var = 'temperature'
# y2_var = 'precipitation'

units = {
    'temperature': 'deg C',
    'precipitation': 'mm',
    'total_prec_in_d2m': 'mm',
    'days_to_maturity': 'days',
}

start_index = get_cutoff_index(
    dates = malawi_prec_data.valid_time,
    cutoffdate = '2021-05-26',
)
end_index = get_cutoff_index(
    dates = malawi_prec_data.valid_time,
    cutoffdate = '2022-05-26',
)

# Use a dedicated name for the time axis: other cells use `x`/`y` as the
# sample pixel's grid indices (`[:, x, y]`), so reassigning `x` to a date
# array here would break those cells when they are re-run.
plot_dates = malawi_prec_data.valid_time.values[start_index:end_index]
y1 = df[y1_var].to_numpy()[start_index:end_index]   # left-axis series (y1_var)
y2 = df[y2_var].to_numpy()[start_index:end_index]   # right-axis series (y2_var)
mask = df['suitable_days'].to_numpy()[start_index:end_index]

fig, ax1 = plt.subplots(figsize=(8, 5))

# Left axis: first variable
ax1.plot(plot_dates, y1, color='tab:blue', label=y1_var)
ax1.set_xlabel('Date')
ax1.set_ylabel(y1_var + f" ({units[y1_var]})", color='tab:blue')
ax1.tick_params(axis='y', labelcolor='tab:blue')

# Right axis: second variable, sharing the same x-axis
ax2 = ax1.twinx()
ax2.plot(plot_dates, y2, color='tab:red', label=y2_var)
ax2.set_ylabel(y2_var + f" ({units[y2_var]})", color='tab:red')
ax2.tick_params(axis='y', labelcolor='tab:red')

# Shade each run of consecutive suitable days (mask == 1) with ONE translucent
# span covering [first day, last day + 1 day) instead of one span per day.
# Assumes consecutive entries are one day apart — TODO confirm.
is_suitable = np.concatenate([[False], mask == 1, [False]])
run_edges = np.flatnonzero(np.diff(is_suitable.astype(np.int8)))
for run_start, run_stop in zip(run_edges[::2], run_edges[1::2]):
    ax1.axvspan(
        pd.Timestamp(plot_dates[run_start]),
        pd.Timestamp(plot_dates[run_stop - 1]) + pd.Timedelta(days=1),
        color='green', alpha=0.15, lw=0,
    )

# Combine legends from both axes into one box
lines_1, labels_1 = ax1.get_legend_handles_labels()
lines_2, labels_2 = ax2.get_legend_handles_labels()
ax1.legend(lines_1 + lines_2, labels_1 + labels_2, loc='upper left')

ax1.tick_params(axis='x', rotation=90)

ax1.set_title(f'{y1_var} and {y2_var} over time')
fig.tight_layout()
plt.show()


In [50]:
# Export, per grid cell, the number of days flagged in `suitable_days` as a
# uint16 GeoTIFF; the filename encodes the UPPER_CASE threshold constants.
# NOTE(review): `suitable_days` is a global that other cells in this notebook
# reassign; confirm it was computed with the constants named in the filename
# before running this cell, otherwise the file name misdescribes its contents.
export_folderpath = '/gpfs/data1/cmongp2/sasirajann/nh_crop_calendar/crop_calendar/data/era5_csu/malawi'
export_filepath = os.path.join(
    export_folderpath,
    f'sum-suitable-days_tbase={T_BASE}_reqgdd={REQUIRED_GDD}_maxtoltemp={MAX_TOLERABLE_TEMP}_mintoltemp={MIN_TOLERABLE_TEMP}_minreqprec={MIN_TOTAL_PREC}_maxreqprec={MAX_TOTAL_PREC}_maxduration={MAX_CROP_DURATION}.tif'
)
# Make sure the output folder exists before writing the raster.
os.makedirs(export_folderpath, exist_ok=True)

sum_suitable_days = xr.DataArray(
    suitable_days.sum(axis=(0)).astype(np.uint16),
    coords = {
        'latitude': malawi_temp_data.latitude,
        'longitude': malawi_temp_data.longitude,
    },
    dims = ('latitude', 'longitude'),
)

# `.rio` is the rioxarray accessor; presumably registered by an import made
# elsewhere (it is not among the top-of-notebook imports) — TODO confirm.
sum_suitable_days = sum_suitable_days.rio.set_spatial_dims('longitude', 'latitude')
sum_suitable_days = sum_suitable_days.rio.write_crs('epsg:4326')
sum_suitable_days.rio.to_raster(export_filepath)

In [11]:
# Run the GDD-based maturity model over the temperature cube; returns the
# days-to-maturity array and the GDD reached at maturity.
(days_to_maturity, gdd_at_maturity) = csu.calculate_days_to_maturity(
    temp_ts=malawi_temp_data.values,
    t_base=T_BASE,
    required_gdd=REQUIRED_GDD,
    max_tolerable_temp=MAX_TOLERABLE_TEMP,
    min_tolerable_temp=MIN_TOLERABLE_TEMP,
)

In [12]:
# Locate cutoffdate on the time axis; arrays are truncated to [:cutoff_index]
# in later cells. Fail with an explicit message (instead of a bare IndexError)
# when the date is not present in the data.
dates = malawi_temp_data.valid_time.values
_cutoff_hits = np.flatnonzero(np.asarray(dates) == pd.Timestamp(cutoffdate))
if _cutoff_hits.size == 0:
    raise ValueError(f"cutoffdate {cutoffdate!r} not found in malawi_temp_data.valid_time")
cutoff_index = _cutoff_hits[0]

In [None]:
# Position of cutoffdate on the time axis; later cells keep only [:cutoff_index].
cutoff_index

In [14]:
# Replace +inf days-to-maturity (presumably "maturity never reached") with the
# integer sentinel -1, then store the array as integers.
days_to_maturity = np.where(np.isposinf(days_to_maturity), -1, days_to_maturity).astype(int)

In [15]:
# Precipitation over each day's days-to-maturity window (presumably differenced
# from the running time-axis cumulative precipitation passed in).
total_prec_in_days_to_maturity = csu.total_prec_in_days_to_maturity(
    cumsum_prec_ts=np.cumsum(malawi_prec_data.values, axis=0),
    days_to_maturity=days_to_maturity,
)

In [16]:
# Keep only the time steps before the cutoff date.
valid_days_to_maturity, valid_total_prec_in_days_to_maturity = (
    arr[:cutoff_index] for arr in (days_to_maturity, total_prec_in_days_to_maturity)
)

In [None]:
# Shape of the precipitation cube (time is the leading axis, given the
# [:cutoff_index, x, y] indexing used in this notebook).
malawi_prec_data.shape

In [None]:
# Raw precipitation series at grid indices (15, 7), the sample pixel used in
# the plots below.
malawi_prec_data.values[:,15,7]

In [None]:
# Running (cumulative) precipitation at grid indices (15, 7). Selecting the
# pixel first gives the same values as a full-grid cumsum(axis=0) indexed at
# (15, 7), without materialising a full-size cumulative array.
np.cumsum(malawi_prec_data.values[:, 15, 7])

In [None]:
# Plot precipitation accumulated over the days-to-maturity window for one
# sample pixel, restricted to the pre-cutoff period.
scale = 5
aspect_ratio = 5

# Sample pixel as positional grid indices (used as `[:, x, y]` below).
x, y = 15, 7

fig, ax = plt.subplots(figsize=(scale * aspect_ratio, scale))

# Select the pixel once; the 1-D cumsum equals the full-grid cumsum(axis=0)
# indexed at (x, y) but avoids building a full-size cumulative array.
pixel_prec = malawi_prec_data.values[:, x, y]

_df = pd.DataFrame(data={
    'date': malawi_prec_data.valid_time.values[:cutoff_index],
    'prec': pixel_prec[:cutoff_index],
    'cumsum_prec': np.cumsum(pixel_prec)[:cutoff_index],
    'total_prec_in_d2m': total_prec_in_days_to_maturity[:cutoff_index, x, y],
})

g = sns.lineplot(
    ax=ax,
    data=_df,
    x='date',
    y='total_prec_in_d2m',
)
ax.set(
    title=f'Total precipitation within days-to-maturity window, pixel ({x}, {y})',
    xlabel='Date',
    ylabel='total_prec_in_d2m (mm)',
);

In [22]:
# Pre-cutoff series for sample pixel (x, y) plus z-scored copies ("_z"
# suffix), combined into plot_df for the interactive overlay.
# Selecting the pixel before cumsum gives the same values as a full-grid
# cumsum(axis=0) indexed at (x, y), without a full-size temporary array.
pixel_prec = malawi_prec_data.values[:, x, y]

df = pd.DataFrame(data={
    'cumsum_prec': np.cumsum(pixel_prec)[:cutoff_index],
    'total_prec_in_d2m': total_prec_in_days_to_maturity[:cutoff_index, x, y],
    'days_to_maturity': days_to_maturity[:cutoff_index, x, y],
})

norm_df = df.apply(lambda series: zscore(series, nan_policy='omit')).add_suffix("_z")

plot_df = pd.concat([df, norm_df], axis=1)

In [None]:
# Display the raw columns alongside their "_z" z-scored counterparts.
plot_df

In [None]:
# Interactive overlay of every column of `df` on a shared z-score axis, over
# the pre-cutoff time range; hovering shows each series' original value,
# carried as (N, 1) customdata.
fig = go.Figure(data=[
    go.Scatter(
        x=malawi_prec_data.valid_time.values[:cutoff_index],
        y=plot_df[col + "_z"],
        name=f"{col} (z-score)",
        mode="lines",
        customdata=np.column_stack([plot_df[col].values]),
        hovertemplate="%{x}<br>" + f"{col} (orig): " + "%{customdata[0]:.3f}<extra></extra>",
    )
    for col in df.columns
])

fig.update_layout(
    title="Multiple time series (z-score normalized) — hover shows original values",
    xaxis_title="Date",
    yaxis_title="Normalized (z-score)",
    hovermode="x unified",
    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    width=1200,
    height=600,
)

# Range-selector shortcuts (1 week, 1 month, 6 months, everything) + slider.
range_buttons = [
    dict(count=7, label="1w", step="day", stepmode="backward"),
    dict(count=1, label="1m", step="month", stepmode="backward"),
    dict(count=6, label="6m", step="month", stepmode="backward"),
    dict(step="all"),
]
fig.update_layout(
    xaxis=dict(
        rangeselector=dict(buttons=range_buttons),
        rangeslider=dict(visible=True),
        type="date",
    )
)

fig.show()

In [25]:
import xarray as xr
import os

In [None]:
# Shape of the pre-cutoff days-to-maturity array (time truncated at cutoff_index).
valid_days_to_maturity.shape

In [27]:
# Export, per grid cell, how many pre-cutoff days satisfy the duration and
# window-precipitation thresholds below, as a uint16 GeoTIFF.
max_duration = 365      # longest acceptable days-to-maturity
min_total_prec = 800    # lower bound on precipitation within the window
max_total_prec = 1300   # upper bound on precipitation within the window

export_folderpath = '/gpfs/data1/cmongp2/sasirajann/nh_crop_calendar/crop_calendar/data/era5_csu/malawi'
export_filepath = os.path.join(
    export_folderpath,
    f'sum-suitable-days_tbase={T_BASE}_reqgdd={REQUIRED_GDD}_maxtoltemp={MAX_TOLERABLE_TEMP}_mintoltemp={MIN_TOLERABLE_TEMP}_minreqprec={min_total_prec}_maxreqprec={max_total_prec}_maxduration={max_duration}.tif'
)
# Make sure the output folder exists before writing the raster.
os.makedirs(export_folderpath, exist_ok=True)

# 1 where the day matures within max_duration (-1 is the sentinel for an
# infinite days-to-maturity) and the window precipitation lies in
# [min_total_prec, max_total_prec]; 0 elsewhere.
suitable_days = (
    (valid_days_to_maturity <= max_duration)
    & (valid_days_to_maturity != -1)
    & (valid_total_prec_in_days_to_maturity >= min_total_prec)
    & (valid_total_prec_in_days_to_maturity <= max_total_prec)
).astype(np.uint8)

sum_suitable_days = xr.DataArray(
    suitable_days.sum(axis=(0)).astype(np.uint16),
    coords = {
        'latitude': malawi_temp_data.latitude,
        'longitude': malawi_temp_data.longitude,
    },
    dims = ('latitude', 'longitude'),
)

# `.rio` is the rioxarray accessor; presumably registered by an import made
# elsewhere — TODO confirm.
sum_suitable_days = sum_suitable_days.rio.set_spatial_dims('longitude', 'latitude')
sum_suitable_days = sum_suitable_days.rio.write_crs('epsg:4326')
sum_suitable_days.rio.to_raster(export_filepath)

In [28]:
# Per-year export: for each calendar year, count per grid cell how many days
# satisfy the duration and window-precipitation thresholds, and write one
# uint16 GeoTIFF per year.
max_duration = 365
min_total_prec = 800
max_total_prec = 1300

dates = malawi_temp_data.valid_time.values

yearwise_folderpath = '/gpfs/data1/cmongp2/sasirajann/nh_crop_calendar/crop_calendar/data/era5_csu/malawi/yearwise'
# Make sure the output folder exists before writing the rasters.
os.makedirs(yearwise_folderpath, exist_ok=True)

for year in range(1994, 2024):
    _start_index = np.where(np.array(dates) == pd.Timestamp(f'{year}-01-01'))[0][0]
    _end_index = np.where(np.array(dates) == pd.Timestamp(f'{year}-12-31'))[0][0]
    # Slices exclude their stop index; +1 keeps 31 December inside the year.
    _year_slice = slice(_start_index, _end_index + 1)

    export_filepath = os.path.join(
        yearwise_folderpath,
        f'sum-suitable-days_year={year}_tbase={T_BASE}_reqgdd={REQUIRED_GDD}_maxtoltemp={MAX_TOLERABLE_TEMP}_mintoltemp={MIN_TOLERABLE_TEMP}_minreqprec={min_total_prec}_maxreqprec={max_total_prec}_maxduration={max_duration}.tif'
    )

    year_days_to_maturity = days_to_maturity[_year_slice]
    year_total_prec_in_days_to_maturity = total_prec_in_days_to_maturity[_year_slice]

    # Year-scoped names so this loop does not overwrite the notebook-level
    # `suitable_days` / `sum_suitable_days` used by other cells.
    year_suitable_days = (
        (year_days_to_maturity <= max_duration)
        & (year_days_to_maturity != -1)
        & (year_total_prec_in_days_to_maturity >= min_total_prec)
        & (year_total_prec_in_days_to_maturity <= max_total_prec)
    ).astype(np.uint8)

    year_sum_suitable_days = xr.DataArray(
        year_suitable_days.sum(axis=(0)).astype(np.uint16),
        coords = {
            'latitude': malawi_temp_data.latitude,
            'longitude': malawi_temp_data.longitude,
        },
        dims = ('latitude', 'longitude'),
    )

    year_sum_suitable_days = year_sum_suitable_days.rio.set_spatial_dims('longitude', 'latitude')
    year_sum_suitable_days = year_sum_suitable_days.rio.write_crs('epsg:4326')
    year_sum_suitable_days.rio.to_raster(export_filepath)