# Modeling Crop Yield
## Python modules

In [9]:
## import warnings
import time
import math
import os
import glob
from pyhere import here
from datetime import date

import numpy as np
import pandas as pd
import geopandas
import pickle

import seaborn as sns
import matplotlib.pyplot as plt

import pyarrow
import itertools
import multiprocessing
import p_tqdm

from sklearn.linear_model import Ridge, RidgeCV
from sklearn.model_selection import train_test_split, KFold, LeaveOneGroupOut, cross_val_score, GridSearchCV, cross_val_predict
from sklearn.metrics import r2_score
from scipy.stats import spearmanr,  pearsonr

In [2]:
# District polygons, indexed by district name (used for the point-in-district join).
country_shp = (
    geopandas
    .read_file(here('data', 'geo_boundaries', 'gadm36_ZMB_2.shp'))
    .set_index('district')
)

# District-by-year maize yields; only the `yield_mt` column is kept.
crop_df = (
    pd.read_csv(here('data', 'crop_yield', 'cfs_maize_districts_zambia_2009_2022.csv'))
    .set_index(['district', 'year'])[['yield_mt']]
)

# Cropland-percentage tables, one per sampling density (4k, 15k and 20k points).
weights_4_fn = 'ZMB_cropland_percentage_4k-points.feather'
weights_15_fn = 'ZMB_cropland_percentage_15k-points.feather'
weights_20_fn = 'ZMB_cropland_percentage_20k-points.feather'

weights_4 = pd.read_feather(here("data", "land_cover", weights_4_fn))
weights_15 = pd.read_feather(here("data", "land_cover", weights_15_fn))
weights_20 = pd.read_feather(here("data", "land_cover", weights_20_fn))

# Round coordinates to 5 decimals; `impute_features` rounds the feature
# coordinates the same way before joining on (lon, lat).
weights_4[['lon', 'lat']] = weights_4[['lon', 'lat']].round(5)
weights_15[['lon', 'lat']] = weights_15[['lon', 'lat']].round(5)
weights_20[['lon', 'lat']] = weights_20[['lon', 'lat']].round(5)

In [3]:
def get_merged_files(flist, **kwargs):
    """Read every feather file in ``flist`` and stack them row-wise.

    Extra keyword arguments are forwarded to ``pd.read_feather``. The
    returned frame has a fresh 0..n-1 index.
    """
    frames = [pd.read_feather(path, **kwargs) for path in flist]
    stacked = pd.concat(frames, axis=0)
    return stacked.reset_index(drop=True)

def merge_tuple(x, bases = (tuple, list)):
    """Lazily flatten arbitrarily nested containers into one stream of items.

    Parameters
    ----------
    x : iterable
        The (possibly nested) iterable to flatten.
    bases : iterable of types
        Container types that are descended into. Anything else, including
        strings, is yielded unchanged.

    Yields
    ------
    The leaf items of ``x`` in depth-first, left-to-right order.
    """
    # isinstance needs a tuple of types; normalising also lets callers pass a list.
    container_types = tuple(bases)
    for item in x:
        # isinstance (rather than an exact type match) also flattens
        # subclasses of the listed containers, e.g. namedtuples.
        if isinstance(item, container_types):
            yield from merge_tuple(item, container_types)
        else:
            yield item

In [4]:
# Catalogue the random-feature files on disk. Each file name is expected to look like
#   <satellite>_bands-<bands>_<country>_<N>k-points_<M>-features_<rest>
# and files sharing the first five parts collapse into one row whose
# `pattern` is a glob matching all of them.
satellites = ["sentinel-2-l2a","landsat-8-c2-l2","landsat-c2-l2"]

# Collect plain dicts and build the frame once; concatenating inside the
# loop re-copies the growing frame on every file.
records = []
for satellite in satellites:

    directory = here("data", "random_features", satellite)
    files = sorted(
        name for name in os.listdir(directory)
        if name not in ('.gitkeep', '.ipynb_checkpoints')
    )

    for file in files:
        f = file.split(sep="_")
        records.append({
            'satellite'    : f[0],
            'bands'        : f[1].replace("bands-", ""),
            'country_code' : f[2],
            'points'       : int(f[3].replace("k-points", "")),
            'num_features' : f[4].replace("-features", ""),
            'pattern'      : f[0]+'_'+f[1]+'_'+f[2]+'_'+f[3]+'_'+f[4]+'_*'
        })

file_groups = (
    pd.DataFrame(records)
    .sort_values(by=['points'], ascending=True)
    .drop_duplicates()
    .reset_index(drop=True)
)
file_groups

Unnamed: 0,satellite,bands,country_code,points,num_features,pattern
0,sentinel-2-l2a,2-3-4,ZMB,4,1000,sentinel-2-l2a_bands-2-3-4_ZMB_4k-points_1000-...
1,landsat-8-c2-l2,1-2-3-4-5-6-7,ZMB,15,1000,landsat-8-c2-l2_bands-1-2-3-4-5-6-7_ZMB_15k-po...
2,sentinel-2-l2a,2-3-4-8,ZMB,15,1000,sentinel-2-l2a_bands-2-3-4-8_ZMB_15k-points_10...
3,sentinel-2-l2a,2-3-4,ZMB,15,1000,sentinel-2-l2a_bands-2-3-4_ZMB_15k-points_1000...
4,landsat-c2-l2,r-g-b-nir-swir16-swir22,ZMB,20,1024,landsat-c2-l2_bands-r-g-b-nir-swir16-swir22_ZM...
5,landsat-8-c2-l2,1-2-3-4-5-6-7,ZMB,20,1000,landsat-8-c2-l2_bands-1-2-3-4-5-6-7_ZMB_20k-po...
6,sentinel-2-l2a,2-3-4,ZMB,20,1000,sentinel-2-l2a_bands-2-3-4_ZMB_20k-points_1000...


In [5]:
# Narrow the catalogue to the Landsat 8 feature set sampled at 20k points.
file_groups = file_groups[
    (file_groups.satellite == "landsat-8-c2-l2") & (file_groups.points == 20)
]
file_groups

Unnamed: 0,satellite,bands,country_code,points,num_features,pattern
5,landsat-8-c2-l2,1-2-3-4-5-6-7,ZMB,20,1000,landsat-8-c2-l2_bands-1-2-3-4-5-6-7_ZMB_20k-po...


In [6]:
# Build every (pattern, limit_months, crop_mask) combination, then keep the
# first one in which both flags are True. The result is a single tuple,
# not a list, despite the name.
names = 'limit_months crop_mask'.split()
paramlist = list(itertools.product([False,True], repeat = len(names)))
paramlist = [
    # product() yields (pattern, (flag, flag)); flatten to (pattern, flag, flag).
    tuple(merge_tuple(combo))
    for combo in itertools.product(file_groups.pattern.to_list(), paramlist)
]
paramlist = [t for t in paramlist if t[1] and t[2]][0]
paramlist

('landsat-8-c2-l2_bands-1-2-3-4-5-6-7_ZMB_20k-points_1000-features_*',
 True,
 True)

In [7]:
def impute_features(params):
    """Build, impute and save the point-level feature table for one file group.

    Parameters
    ----------
    params : sequence
        ``(file, limit_months, crop_mask)``:

        * ``file`` - glob pattern of the feature files, in the form
          ``<satellite>_bands-..._<country>_<N>k-points_<M>-features_*``.
        * ``limit_months`` - if truthy keep only months 4-9, otherwise all 12.
        * ``crop_mask`` - if truthy keep only points with ``crop_perc > 0``.

    Returns
    -------
    None. The result is written as a feather file under
    ``data/random_features/full_files``.

    Raises
    ------
    ValueError
        If the point count in ``file`` is not 4, 15 or 20, or if no file
        matches the pattern.

    Notes
    -----
    Reads the module-level ``weights_4`` / ``weights_15`` / ``weights_20``
    tables and ``country_shp``. Prints coarse progress messages.
    """
    file         = params[0]
    limit_months = params[1]
    crop_mask    = params[2]
    # weighted_avg = params[3]
    name_parts   = file.split(sep="_")
    satellite    = name_parts[0]
    points       = int(name_parts[3].replace("k-points", ""))

    # Pick the cropland weights up front so an unsupported point count fails
    # before any feature file is read (previously `weights` was left unbound).
    if points == 4:
        weights = weights_4
    elif points == 15:
        weights = weights_15
    elif points == 20:
        weights = weights_20
    else:
        raise ValueError(
            f'No cropland weights loaded for {points}k points; expected 4, 15 or 20'
        )

    path = str(here("data", "random_features", satellite, file))
    files = glob.glob(pathname=path)

    print('Opening')

    if not files:
        raise ValueError(f'No feature files match {path}')
    features = get_merged_files(files)

    # Last calendar year present in the data, taken before years are relabelled.
    year_end = max(features.year)

    if satellite == "landsat-c2-l2":
        year_start = 2008
    elif satellite == "landsat-8-c2-l2":
        year_start = 2013
    else:
        year_start = 2015

    month_range = range(4, 10) if limit_months else range(1, 13)

    if satellite == "landsat-8-c2-l2" and limit_months:
        month_start = 4
    else:
        month_start = 10

    # Rows to keep, decided on the original calendar year/month: everything
    # from (year_start, month_start) onward.
    in_window = (
        ((features.year == year_start) & (features.month >= month_start))
        | (features.year > year_start)
    )

    # October-December rows are relabelled with the following year
    # (presumably so a season starting in October is grouped under the year
    # it is harvested - confirm). This and the rounding are applied to the
    # freshly loaded frame, before filtering, so no column is ever assigned
    # on a filtered slice.
    features['year'] = np.where(
        features['month'].isin([10, 11, 12]),
        features['year'] + 1,
        features['year']
    )
    # Round coordinates like the weights tables so the (lon, lat) join matches.
    features['lon'] = features['lon'].round(5)
    features['lat'] = features['lat'].round(5)

    # One filtering pass: start window, drop relabelled years beyond the last
    # observed year, and keep only the requested months.
    features = features[
        in_window
        & (features.year <= year_end)
        & features.month.isin(month_range)
    ]

    # Long -> wide: one row per (lon, lat, year), one column per feature/month.
    features = features.set_index(['lon','lat', "year", 'month']).unstack()
    features.columns = features.columns.map(lambda x: '{}_{}'.format(*x))

    # Treat infinities as missing so they are imputed below.
    features.replace([np.inf, -np.inf], np.nan, inplace=True)
    features.reset_index(inplace = True)

    features = features.join(weights.set_index(['lon', 'lat']), on = ['lon', 'lat'])

    if crop_mask:
        features = features[features.crop_perc > 0]

    features = geopandas.GeoDataFrame(
        features,
        geometry = geopandas.points_from_xy(x = features.lon, y = features.lat),
        crs='EPSG:4326'
    )

    # Attach the containing district to every point; points outside all
    # districts are dropped.
    features = (
        features
        .sjoin(country_shp, how = 'left', predicate = 'within')
        .drop(['geometry'], axis = 1)
        .rename(columns = {"index_right": "district"})
        .dropna(subset=['district'])
        .reset_index(drop = True)
    )

    print('Imputing')

    # Fill gaps with the district-year mean, then the district mean over all
    # years; rows that still contain a gap are dropped.
    features.fillna(features.groupby(['year', 'district'], as_index=False).transform('mean'), inplace=True)
    features.fillna(features.groupby(['district'], as_index=False).transform('mean'), inplace=True)
    features = features.dropna(axis=0)

    min_yr = min(features.year); max_yr = max(features.year)
    min_mn = min(month_range);   max_mn = max(month_range)

    # file[:-1] drops the trailing '*' of the glob pattern.
    out_name = f'{file[:-1]}yr-{min_yr}-{max_yr}_mn-{min_mn}-{max_mn}_lm-{limit_months}'+\
        f'_cm-{crop_mask}_full.feather'
    full_file = here('data', 'random_features', 'full_files', out_name)

    print('Saving')

    features.reset_index(drop=True).to_feather(full_file)


In [8]:
%%time
# Build and save the imputed feature file for the single parameter tuple selected above.
impute_features(paramlist)

Opening
Imputing
Saving
CPU times: user 7min 54s, sys: 3min 7s, total: 11min 2s
Wall time: 9min 25s


In [10]:
# When True, the `district` column is one-hot encoded into dummy columns.
hot_encode = True
# When True, the feature columns are renamed to consecutive integer strings.
# NOTE(review): the name is misspelled here and where it is read, and no
# weighted averaging is visible in this notebook - confirm the intended meaning.
weighhted_avg = True

In [11]:
# Load the imputed feature table from data/random_features/full_files.
# The name follows the pattern written by `impute_features`: Landsat 8,
# 20k points, years 2013-2021, months 4-9, limit_months and crop_mask True.
f = 'landsat-8-c2-l2_bands-1-2-3-4-5-6-7_ZMB_20k-points_1000-features_yr-2013-2021_mn-4-9_lm-True_cm-True_full.feather'
fn = here('data', 'random_features', 'full_files', f)
features = pd.read_feather(fn)

In [12]:
# Identifier columns that are not model inputs; they are carried into
# `predictions` and dropped from the design matrix.
drop_cols = ['year', 'lon', 'lat', 'crop_perc', 'district']

if weighhted_avg:
    # Replace the feature column names with positional labels "0", "1", ...
    # while leaving the identifier columns untouched.
    features = features.set_index(drop_cols)
    features.columns = [str(i) for i in range(len(features.columns))]
    features = features.reset_index()

if hot_encode:
    # `district` becomes a set of dummy columns, so it is no longer an
    # identifier to drop or carry over.
    drop_cols.remove('district')
    features = pd.get_dummies(features, columns=["district"], drop_first=False)

# Select the identifier columns first, then copy; copying the whole wide
# frame only to keep a handful of columns is wasteful.
predictions = features[drop_cols].copy()

In [13]:
# Model file names are derived from the feature file name plus the two
# preprocessing flags, so the loaded model always matches how `features`
# was prepared (with both flags True this is the same name as before).
model_fn_suffix = f.replace('_full.feather', '') + f'_wa-{weighhted_avg}_he-{hot_encode}'

k_model_fn = f'k-fold-cv_rr-model_{model_fn_suffix}.pkl'
logo_model_fn = f'logo-cv_rr-model_{model_fn_suffix}.pkl'

# with open(here('models', k_model_fn),'wb') as model_file:
#     pickle.dump(best_kfold_model, model_file)

# NOTE: unpickling can execute arbitrary code; only load model files produced
# by this project. The handle is deliberately not named `f`: that would
# overwrite the feature file name and break re-running this cell.
with open(here('models', k_model_fn), 'rb') as model_file:
    best_kfold_model = pickle.load(model_file)

# Everything except the identifier columns is a model input.
x_all = features.drop(drop_cols, axis = 1)
predictions['prediction'] = best_kfold_model.predict(x_all)

In [149]:
# Write the point-level predictions to data/results as a feather file,
# named after the model suffix.
f_pred = f'high-res-pred_k-fold-cv_{model_fn_suffix}.feather'
fn = here('data', 'results', f_pred)
predictions.to_feather(str(fn))

In [150]:
predictions

Unnamed: 0,year,lon,lat,crop_perc,prediction
0,2013,22.07488,-14.86423,0.12790,0.132330
1,2014,22.07488,-14.86423,0.12790,0.158144
2,2015,22.07488,-14.86423,0.12790,0.149651
3,2016,22.07488,-14.86423,0.12790,0.164380
4,2017,22.07488,-14.86423,0.12790,0.225569
...,...,...,...,...,...
176377,2017,33.52488,-10.32423,0.30577,0.737005
176378,2018,33.52488,-10.32423,0.30577,0.641154
176379,2019,33.52488,-10.32423,0.30577,0.595801
176380,2020,33.52488,-10.32423,0.30577,0.597986


In [152]:
# `crop_df` only carries `yield_mt` at this point; `log_yield` is created
# only in the commented-out cells below. Derive it here with the same
# formula, log10(yield_mt + 1), so the cell runs on a clean kernel.
np.log10(crop_df.yield_mt + 1).max()

1.0804510383588988

In [125]:
# fig, ax = plt.subplots(figsize=(12, 10))
# country_shp.boundary.plot(ax = ax, edgecolor = "black")
# plt.scatter(predictions.lon, predictions.lat,  c=predictions.crop_perc, s=.3, marker = ',')

In [126]:
# predictions_gdf = geopandas.GeoDataFrame(predictions, geometry=geopandas.points_from_xy(predictions.lon, predictions.lat))
# predictions_gdf[predictions_gdf.year == 2013].plot(column = 'predictions', markersize=1)

In [127]:
# crop_df = crop_df.join(country_shp).reset_index()
# crop_df['log_yield'] = np.log10(crop_df.yield_mt.to_numpy() + 1)
# crop_df = geopandas.GeoDataFrame(crop_df)

In [128]:
# crop_df['log_yield'] = np.log10(crop_df.yield_mt.to_numpy() + 1)
# crop_df

In [None]:
# min_yield, max_yield = min(crop_df.log_yield), max(crop_df.log_yield)

In [123]:
# def scatter(x, y, c, **kwargs):
#     del kwargs["color"]
#     fig, ax = plt.subplots()
#     country_shp.boundary.plot(ax = ax, edgecolor = "black")
#     plt.scatter(x, y, c = c, **kwargs)
    
# def yield_plot(**kwargs):
#     del kwargs["color"]
#     geopandas.GeoDataFrame.plot('yield_mt',**kwargs)

In [32]:
# g = sns.FacetGrid(
#     crop_df,
#     col='year',
#     col_wrap=3,
#     height=4, 
#     aspect=1
# )
# g.map(yield_plot) 

In [72]:
# import matplotlib.colors as colors

In [129]:
# years = range(min(predictions.year), max(predictions.year)+1)

# fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(12, 12),
#                         # constrained_layout=True, 
#                         sharex=True, sharey=True, 
#                         subplot_kw=dict(aspect='equal'))
# plt.subplots_adjust(hspace=0.5)
# fig.suptitle("Log Yield", fontsize=18, y=0.95, x=.35)
# fig.tight_layout()

# for year, ax in zip(years, axs.ravel()):
#     crop_df[crop_df["year"] == year].plot(
#         ax=ax, 
#         column = "log_yield", 
#         norm=colors.Normalize(vmin= min_yield, vmax=max_yield)
#     )
#     ax.set_title(year)
#     ax.set_xlabel("")
    
# axs = axs.ravel()
# patch_col = axs[0].collections[0]
# cb = fig.colorbar(patch_col, ax=axs, shrink=0.5)

In [130]:
# g = sns.FacetGrid(
#     predictions,
#     col='year',
#     col_wrap=3,
#     height=4, 
#     aspect=1
# )
# g.map_dataframe(scatter, 'lon', 'lat', 'predictions', s = .25)  

In [131]:
# country_shp.boundary.plot(edgecolor = "black")

In [77]:
# fig, ((ax1,ax2,ax3) ,(ax4,ax5,ax6),(ax7,ax8,ax9)) = plt.subplots(nrows=3, ncols=3, figsize=(15, 20))
# ax1 = (crop_df[crop_df.year == 2013]
#        .plot(ax = ax1, column = "log_yield", legend = True, norm=colors.Normalize(vmin= min_yield, vmax=max_yield))
#        .set_title("2013 log_yields"))
# ax2 = (crop_df[crop_df.year == 2014]
#        .plot(ax = ax2, column = "log_yield", legend = True, norm=colors.Normalize(vmin= min_yield, vmax=max_yield))
#        .set_title("2014 log_yields"))
# ax3 = (crop_df[crop_df.year == 2015]
#        .plot(ax = ax3, column = "log_yield", legend = True, norm=colors.Normalize(vmin= min_yield, vmax=max_yield))
#        .set_title("2015 log_yields"))
# ax4 = (crop_df[crop_df.year == 2016]
#        .plot(ax = ax4, column = "log_yield", legend = True, norm=colors.Normalize(vmin= min_yield, vmax=max_yield))
#        .set_title("2016 log_yields"))
# ax5 = (crop_df[crop_df.year == 2017]
#        .plot(ax = ax5, column = "log_yield", legend = True, norm=colors.Normalize(vmin= min_yield, vmax=max_yield))
#        .set_title("2017 log_yields"))
# ax6 = (crop_df[crop_df.year == 2018]
#        .plot(ax = ax6, column = "log_yield", legend = True, norm=colors.Normalize(vmin= min_yield, vmax=max_yield))
#        .set_title("2018 log_yields"))
# ax7 = (crop_df[crop_df.year == 2019]
#        .plot(ax = ax7, column = "log_yield", legend = True, norm=colors.Normalize(vmin= min_yield, vmax=max_yield))
#        .set_title("2019 log_yields"))
# ax8 = (crop_df[crop_df.year == 2020]
#        .plot(ax = ax8, column = "log_yield", legend = True, norm=colors.Normalize(vmin= min_yield, vmax=max_yield))
#        .set_title("2020 log_yields"))
# ax9 = (crop_df[crop_df.year == 2021]
#        .plot(ax = ax9, column = "log_yield", legend = True, norm=colors.Normalize(vmin= min_yield, vmax=max_yield))
#        .set_title("2021 log_yields"))

# caption = "A positive value is an underestimated prediction (the prediction is lower than the actual yield), a negative value is an over estimated prediction"
# plt.figtext(0.5, 0.01, caption, wrap=True, horizontalalignment='center', fontsize=12)
# fig.tight_layout()
# handles, labels = ax1.get_legend_handles_labels()
# fig.legend(handles, labels, loc='upper center')
# plt.figlegend(loc = 'lower center', ncol=5, labelspacing=0.)

# from matplotlib.legend import _get_legend_handles_labels
# ...
# fig.legend(*_get_legend_handles_and_labels(fig.axes), ...)