# Modeling Crop Yield
## Python modules

In [None]:
import warnings
import time
import os
import glob

import dask
from dask.distributed import Client
import multiprocessing

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import itertools

import geopandas

import pyarrow
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
from sklearn.impute import SimpleImputer
from scipy.stats import spearmanr
from scipy.linalg import LinAlgWarning
from scipy.stats import pearsonr

from pyhere import here

import math
import seaborn as sns

from pyhere import here

In [35]:
# District boundary polygons, indexed by district name.
country_shp = (
    geopandas
    .read_file(here('data', 'geo_boundaries', 'gadm36_ZMB_2.shp'))
    .set_index('district')
)

# Observed yields, keyed by (district, year); only the yield column is kept.
crop_df = (
    pd.read_csv(here('data', 'crop_yield', 'cfs_maize_districts_zambia_2009_2022.csv'))
    .set_index(['district', 'year'])[['yield_mt']]
)

# Cropland-percentage tables, one per point-grid size (4k, 15k, 20k points).
weights_4_fn = 'ZMB_cropland_percentage_4k-points.feather'
weights_15_fn = 'ZMB_cropland_percentage_15k-points.feather'
weights_20_fn = 'ZMB_cropland_percentage_20k-points.feather'

weights_4 = pd.read_feather(here("data", "land_cover", weights_4_fn))
weights_15 = pd.read_feather(here("data", "land_cover", weights_15_fn))
weights_20 = pd.read_feather(here("data", "land_cover", weights_20_fn))

# Round coordinates to 5 decimals so they can be matched against the
# (identically rounded) feature coordinates later on.
weights_4 = weights_4.assign(lon=weights_4.lon.round(5), lat=weights_4.lat.round(5))
weights_15 = weights_15.assign(lon=weights_15.lon.round(5), lat=weights_15.lat.round(5))
weights_20 = weights_20.assign(lon=weights_20.lon.round(5), lat=weights_20.lat.round(5))

In [36]:
def get_merged_files(flist, **kwargs):
    """Read every feather file in ``flist`` and stack the rows into one frame.

    Extra keyword arguments are forwarded to ``pd.read_feather``. The result
    carries a fresh 0..n-1 index; the per-file indices are discarded.
    """
    frames = []
    for path in flist:
        frames.append(pd.read_feather(path, **kwargs))
    return pd.concat(frames, axis=0, ignore_index=True)

def merge_tuple(x, bases = (tuple, list)):
    """Lazily flatten arbitrarily nested containers into a single stream.

    Parameters
    ----------
    x : iterable
        The (possibly nested) iterable to flatten.
    bases : iterable of types
        Container types that are descended into. Instances of these types,
        including instances of their subclasses, are flattened; every other
        item (strings included) is yielded unchanged.

    Yields
    ------
    The leaf items of ``x`` in left-to-right, depth-first order.
    """
    # isinstance() requires a tuple of types; callers may pass a list.
    container_types = tuple(bases)
    for item in x:
        if isinstance(item, container_types):
            yield from merge_tuple(item, container_types)
        else:
            yield item

In [40]:
# Catalogue the random-feature files found on disk. Each file name is expected
# to look like
#   <satellite>_bands-<bands>_<country>_<N>k-points_<M>-features_<rest>
# and is reduced to one row describing its group plus a glob pattern that
# matches every file of that group.
satellites = ["sentinel-2-l2a","landsat-8-c2-l2","landsat-c2-l2"]
ignored = ('.gitkeep', '.ipynb_checkpoints')
group_columns = ['satellite', 'bands', 'country_code', 'points', 'num_features', 'pattern']

# Collect plain dicts and build the frame once; concatenating inside the loop
# re-copies the accumulated frame on every iteration.
records = []
for satellite in satellites:

    directory = here("data", "random_features", satellite)
    files = sorted(f for f in os.listdir(directory) if f not in ignored)

    for file in files:
        f = file.split(sep="_")
        records.append({
            'satellite'    : f[0],
            'bands'        : f[1].replace("bands-", ""),
            'country_code' : f[2],
            'points'       : int(f[3].replace("k-points", "")),
            'num_features' : f[4].replace("-features", ""),
            'pattern'      : '_'.join(f[:5]) + '_*',
        })

# Explicit columns keep the sort below valid even when no files were found.
file_groups = pd.DataFrame(records, columns=group_columns)
file_groups = file_groups.sort_values(by=['points'], ascending=True)
# Many files share one group, so collapse to one row per distinct group.
file_groups = file_groups.drop_duplicates().reset_index(drop=True)
file_groups

Unnamed: 0,satellite,bands,country_code,points,num_features,pattern
0,sentinel-2-l2a,2-3-4,ZMB,4,1000,sentinel-2-l2a_bands-2-3-4_ZMB_4k-points_1000-...
1,landsat-8-c2-l2,1-2-3-4-5-6-7,ZMB,15,1000,landsat-8-c2-l2_bands-1-2-3-4-5-6-7_ZMB_15k-po...
2,sentinel-2-l2a,2-3-4-8,ZMB,15,1000,sentinel-2-l2a_bands-2-3-4-8_ZMB_15k-points_10...
3,sentinel-2-l2a,2-3-4,ZMB,15,1000,sentinel-2-l2a_bands-2-3-4_ZMB_15k-points_1000...
4,landsat-c2-l2,r-g-b-nir-swir16-swir22,ZMB,20,1024,landsat-c2-l2_bands-r-g-b-nir-swir16-swir22_ZM...
5,landsat-8-c2-l2,1-2-3-4-5-6-7,ZMB,20,1000,landsat-8-c2-l2_bands-1-2-3-4-5-6-7_ZMB_20k-po...
6,sentinel-2-l2a,2-3-4,ZMB,20,1000,sentinel-2-l2a_bands-2-3-4_ZMB_20k-points_1000...


In [41]:
# file_groups = file_groups[file_groups.index == 0]
# file_groups

In [42]:
# Every file-group pattern is paired with every True/False combination of the
# three switches below, flattened to (pattern, limit_months, crop_mask, weighted_avg).
names = 'limit_months crop_mask weighted_avg'.split()
paramlist = [
    tuple(merge_tuple(combo))
    for combo in itertools.product(
        file_groups.pattern.to_list(),
        list(itertools.product([False, True], repeat=len(names))),
    )
]
len(paramlist)

56

In [30]:
def summarize_features(params):
    """Summarise one random-feature file group per district/year and save it.

    Parameters
    ----------
    params : sequence
        ``(file, limit_months, crop_mask, weighted_avg)``:

        * ``file`` -- glob pattern of the form
          ``<satellite>_bands-<..>_<country>_<N>k-points_<M>-features_*``.
        * ``limit_months`` -- keep only months 4-9 instead of all twelve.
        * ``crop_mask`` -- keep only points whose ``crop_perc`` is above zero.
        * ``weighted_avg`` -- aggregate with a ``crop_perc``-weighted mean
          instead of a plain mean.

    Returns
    -------
    None
        The summary is written as a feather file into
        ``data/random_features/summary``.

    Raises
    ------
    ValueError
        If the point count encoded in ``file`` is not 4, 15 or 20 (no
        cropland-weight table is loaded for other grids).
    """
    file         = params[0]
    limit_months = params[1]
    crop_mask    = params[2]
    weighted_avg = params[3]

    name_parts = file.split(sep="_")
    satellite  = name_parts[0]
    points     = int(name_parts[3].replace("k-points", ""))

    # Pick the weight table before any heavy loading so an unsupported grid
    # size fails immediately with a clear message.
    if points == 4:
        weights = weights_4
    elif points == 15:
        weights = weights_15
    elif points == 20:
        weights = weights_20
    else:
        raise ValueError(
            f"No cropland weights available for {points}k points (file pattern: {file})"
        )

    path = str(here("data", "random_features", satellite, file))
    files = glob.glob(pathname=path)

    features = get_merged_files(files)

    # Remember the last observed year before years are shifted below.
    year_end = max(features.year)

    # First usable year per satellite.
    if satellite == "landsat-c2-l2":
        year_start = 2008
    elif satellite == "landsat-8-c2-l2":
        year_start = 2013
    else:
        year_start = 2015

    month_range = range(4, 10) if limit_months else range(1, 13)

    # First usable month within ``year_start``.
    if satellite == "landsat-8-c2-l2" and limit_months:
        month_start = 4
    else:
        month_start = 10

    # Keep everything from (year_start, month_start) onwards.
    keep = (
        ((features.year == year_start) & (features.month >= month_start))
        | (features.year > year_start)
    )
    features = features[keep]

    # Months 10-12 are attributed to the following year
    # (presumably the harvest year they feed into -- TODO confirm).
    features['year'] = np.where(
        features['month'].isin([10, 11, 12]),
        features['year'] + 1,
        features['year']
    )
    # The shift can create a year past the last observed one; drop it.
    features = features[features.year <= year_end]

    # Same rounding as the weight tables so the coordinate join matches.
    features.lon, features.lat = round(features.lon, 5), round(features.lat, 5)

    features = features[features.month.isin(month_range)]

    # Long -> wide: one row per (lon, lat, year), one column per feature/month.
    features = features.set_index(['lon','lat', "year", 'month']).unstack()
    features.columns = features.columns.map(lambda x: '{}_{}'.format(*x))

    features.replace([np.inf, -np.inf], np.nan, inplace=True)
    features.reset_index(inplace = True)

    # set_index returns a new frame, so the shared weight table is not modified.
    features = features.join(weights.set_index(['lon', 'lat']), on = ['lon', 'lat'])

    if crop_mask:
        features = features[features.crop_perc > 0]

    features = geopandas.GeoDataFrame(
        features,
        geometry = geopandas.points_from_xy(x = features.lon, y = features.lat),
        crs='EPSG:4326'
    )

    # Attach the containing district to each point; points outside every
    # district polygon are dropped.
    features = (
        features
        .sjoin(country_shp, how = 'left', predicate = 'within')
        .drop(['geometry', 'lon', 'lat'], axis = 1)
        .rename(columns = {"index_right": "district"})
        .dropna(subset=['district'])
        .reset_index(drop = True)
    )

    # Impute in two passes: first with the district mean of the same year,
    # then with the district mean over all years; rows still incomplete are dropped.
    features.fillna(features.groupby(['year', 'district'], as_index=False).transform('mean'), inplace=True)
    features.fillna(features.groupby(['district'], as_index=False).transform('mean'), inplace=True)
    features = features.dropna(axis=0)

    if weighted_avg:
        # Assumes the first column is ``year`` and the last two are
        # ``crop_perc`` and ``district`` -- TODO confirm against the inputs.
        feature_cols = features.columns[1:-2].values.tolist()
        # The Series is built from a plain list, so the summarised feature
        # columns are not labelled with the original feature names.
        features_summary = (
            features
            .groupby(['year', 'district'], as_index=False)
            .apply(lambda x: pd.Series([sum(x[feature] * x.crop_perc) / sum(x.crop_perc) for feature in feature_cols]))
        )
    else:
        features_summary = features.groupby(['district',"year"], as_index = False).mean()

    # Attach observed yields.
    features_summary = features_summary.set_index(["district", "year"]).join(other = crop_df).reset_index()

    # Feather requires string column names.
    features_summary.columns = features_summary.columns.astype(str)

    # Drop rows with any missing value (e.g. no yield for that district/year).
    features_summary = features_summary[~features_summary.isna().any(axis = 1)]

    min_yr = min(features_summary.year); max_yr = max(features_summary.year)
    min_mn = min(month_range); max_mn = max(month_range)

    # ``file[:-1]`` strips the trailing ``*`` of the glob pattern.
    out_name = f'{file[:-1]}yr-{min_yr}-{max_yr}_mn-{min_mn}-{max_mn}_lm-{limit_months}'+\
        f'_cm-{crop_mask}_wa-{weighted_avg}_summary.feather'
    full_file = here('data', 'random_features', 'summary', out_name)

    features_summary.reset_index(drop=True).to_feather(full_file)


In [32]:
%%time
# Cap the pool at 24 processes (a pool sized to os.cpu_count() used too much
# memory), and never start more workers than the machine has CPUs.
workers = min(24, os.cpu_count() or 1)
if __name__ == "__main__":
    with multiprocessing.Pool(processes=workers) as pool:
        pool.map(summarize_features, paramlist)


file pattern: sentinel-2-l2a_bands-2-3-4_ZMB_4k-points_1000-features_*
limit_months: False
crop_mask:    False
weighted_avg: False

file pattern: sentinel-2-l2a_bands-2-3-4_ZMB_4k-points_1000-features_*
limit_months: True
crop_mask:    False
weighted_avg: True

file pattern: sentinel-2-l2a_bands-2-3-4_ZMB_4k-points_1000-features_*
limit_months: True
crop_mask:    True
weighted_avg: True

file pattern: sentinel-2-l2a_bands-2-3-4_ZMB_4k-points_1000-features_*
limit_months: False
crop_mask:    False
weighted_avg: True

file pattern: sentinel-2-l2a_bands-2-3-4_ZMB_4k-points_1000-features_*
limit_months: False
crop_mask:    True
weighted_avg: False

file pattern: sentinel-2-l2a_bands-2-3-4_ZMB_4k-points_1000-features_*
limit_months: False
crop_mask:    True
weighted_avg: True



file pattern: sentinel-2-l2a_bands-2-3-4_ZMB_4k-points_1000-features_*
limit_months: True
crop_mask:    True
weighted_avg: False

file pattern: sentinel-2-l2a_bands-2-3-4_ZMB_4k-points_1000-features_*
limit_months: