In [None]:
# general
import time
import warnings
from importlib import reload

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

# modeling
import statsmodels
import statsmodels.api as sm
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.exceptions import NotFittedError
from sklearn.metrics import log_loss
from sklearn.experimental import enable_iterative_imputer  # must stay before the IterativeImputer import
from sklearn.impute import IterativeImputer
from sklearn.preprocessing import StandardScaler
import shap

# custom
from scripts import HyperparameterTuning as HT
from scripts import Plots, Metrics

# Load and pre-process data

In [None]:
# options
env_path = 'data/environmental_data.csv'
bat_path = 'data/bat-level_data.csv'
random_state = 1337
dataset = 'env' # 'env', 'bat'
target = 'shortage'
rolling_forecast = True

# fail early on a typo: the figure layout further down is only defined for
# these two values, and the feature selection below would silently fall back
# to the bat-level features for anything that is not 'env'
if dataset not in ('env', 'bat'):
    raise ValueError(f"dataset must be 'env' or 'bat', got {dataset!r}")


def _read_config_lines(path):
    """Return the stripped, non-blank lines of the text file at `path`.

    Blank lines (including a trailing newline at the end of the file) are
    skipped so they cannot turn into empty feature names or malformed
    rename entries.
    """
    with open(path, 'r') as f:
        stripped = (line.strip() for line in f)
        return [line for line in stripped if line]


# load configs
env_features = _read_config_lines('config/env_features.txt')
bat_features = _read_config_lines('config/bat_features.txt')

# display names: one "<column>\t<label>" pair per line
rename = {}
for entry in _read_config_lines('config/rename.tsv'):
    fields = entry.split('\t')
    if len(fields) < 2:
        raise ValueError(
            f'config/rename.tsv: expected two tab-separated fields, got {entry!r}')
    rename[fields[0]] = fields[1]
# the intercept is not a data column, so it maps to itself
rename['Intercept'] = 'Intercept'

# book keeping
features = env_features if dataset == 'env' else bat_features

# read each source table, order it chronologically, then outer-join on the
# calendar keys so that months present in only one of the files are kept
df_env = pd.read_csv(env_path)
df_env = df_env.sort_values(by=['cal_year', 'cal_month'], ascending=[True, True])
df_bat = pd.read_csv(bat_path)
df_bat = df_bat.sort_values(by=['cal_year', 'cal_month'], ascending=[True, True])
df = df_env.merge(df_bat, on=['cal_year', 'cal_month'], how='outer')

# encode the labels as 1 / 0 (the replacement is applied frame-wide), then
# keep only the rows that actually have a label
df = df.replace('shortage', 1)
df = df.replace('not_shortage', 0)
df = df.dropna(subset=[target])
df[target] = df[target].astype(int)

# season code derived from the month:
# 1 = Dec-Feb, 2 = Mar-May, 3 = Jun-Aug, 4 = Sep-Nov
to_season = {
    month: season
    for season, months in enumerate(
        [(12, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11)], start=1)
    for month in months
}
df['cal_season'] = df['cal_month'].apply(lambda month: to_season[month]).astype(float)

# keep rows in which at least half of the selected features are observed
df = df.dropna(subset=features, thresh=0.5*len(features))
df_missing = df.copy()  # snapshot taken before imputation

# count the remaining gaps, then fill them with the iterative imputer
# NOTE(review): the imputer is fit on every remaining row, i.e. before the
# train/validation/test split is made further down -- confirm this is intended
n_total = len(df) * len(features)
n_missing = df[features].isna().sum().sum()
imp = IterativeImputer(random_state=random_state, max_iter=100)
df[features] = imp.fit_transform(df[features])

# calendar columns back to integers ...
df['cal_year'] = df['cal_year'].astype(int)
df['cal_month'] = df['cal_month'].astype(int)
# ... then shift season and month so they start at 0 and store as categoricals
df['cal_season'] = (df['cal_season'] - df['cal_season'].min()).astype('category')
df['cal_month'] = (df['cal_month'] - df['cal_month'].min()).astype('category')

# positional index 0..n-1, relied on by the index arrays built later
df = df.reset_index(drop=True)

# summary
print('Missing proportion: {:.2f}% values'.format(n_missing / n_total * 100))
print('Data shape: {} months, {} features'.format(*df[features].shape))
print('Food shortages: {} / {} ({:.2f}%)'.format(
    df[target].sum(), len(df), df[target].sum() / len(df) * 100))

# Split data into train/val/test sets

In [None]:
# options
window_years = 4  # years before the first fold; also the fixed-window training length
start_year = df['cal_year'].min() + window_years
test_year = 2018
split_size = 2

# convert dataframes to numpy arrays for modeling
X, y = df[features].values, df[target].values
years = df['cal_year'].values


def _rows_in_years(first_year, stop_year):
    """Return the row indices whose year lies in the half-open range [first_year, stop_year)."""
    return np.flatnonzero((years >= first_year) & (years < stop_year))


def _train_rows_before(stop_year):
    """Return the training rows for a model evaluated on data starting at `stop_year`.

    With `rolling_forecast` every earlier year is used; otherwise only the
    `window_years` years immediately before `stop_year`.
    """
    first_year = -np.inf if rolling_forecast else stop_year - window_years
    return _rows_in_years(first_year, stop_year)


# get indices of training and validation sets
split_years = np.arange(start_year, test_year, split_size)
train_idx = _train_rows_before(test_year)
train_idxs = [_train_rows_before(sy) for sy in split_years]
# cap each validation window at test_year: when (test_year - start_year) is not
# a multiple of split_size the last window would otherwise include test rows
val_idxs = [
    _rows_in_years(sy, min(sy + split_size, test_year)) for sy in split_years
]
test_idx = np.flatnonzero(years >= test_year)

# later cells index test_idx[0], so an empty test set must be reported here
if len(test_idx) == 0:
    raise ValueError(f'no rows with cal_year >= {test_year}; the test set is empty')

# compute train/val score weights
weights = [len(t_idx) for t_idx in train_idxs]

# plot train/val/test splits
Plots.cv_splits(
    years,
    cv_splits=[train_idxs, val_idxs, test_idx],
    dataset=dataset)

# Train GLM using hyperparameter grid search

In [None]:
# statsmodels emits PerfectSeparationWarning while fitting; hide it for the
# rest of the session
warnings.simplefilter(
    "ignore",
    category=statsmodels.tools.sm_exceptions.PerfectSeparationWarning)

# custom GLM classifier mimicking LightGBMClassifier
class GLMClassifier(BaseEstimator, ClassifierMixin):
    """Binomial GLM (logistic regression) behind a scikit-learn style interface.

    The model is fit with ``statsmodels.api.GLM`` and the validation log-loss
    of every ``fit`` call is recorded under
    ``evals_result_['valid_0']['binary_logloss']``.

    Parameters
    ----------
    add_intercept : bool, default True
        Prepend a column of ones to the design matrix.
    max_iter : int, default 100
        Passed to ``GLM.fit`` as ``maxiter``.
    random_state : int, default 42
        Stored only; nothing in this class reads it.
    """

    def __init__(
        self,
        add_intercept=True,
        max_iter=100,
        random_state=42
    ):

        # initialize
        self.add_intercept = add_intercept
        self.max_iter = max_iter
        self.random_state = random_state
        self.result_ = None
        # probabilities are clipped to [eps, 1 - eps] in predict_proba
        self.eps = np.sqrt(np.finfo(float).eps)

        # copy lgb evaluation metric for hyperparameter tuning script
        self.evals_result_ = {'valid_0': {'binary_logloss': []}}

    def _design_matrix(self, X):
        """Return `X` with the intercept column prepended when `add_intercept` is set.

        ``has_constant='add'`` makes the column unconditional. With the
        default ``'skip'`` statsmodels returns `X` unchanged whenever it
        already holds a constant non-zero column (e.g. any single-row input),
        which would leave the matrix one column short of the fitted model.
        """
        if not self.add_intercept:
            return X
        return sm.add_constant(X, prepend=True, has_constant='add')

    def fit(self, X, y, eval_set=None, **kwargs):
        """Fit the GLM and record the validation log-loss.

        Parameters
        ----------
        X, y : training features and binary labels.
        eval_set : tuple (X_val, y_val), optional
            Raw (intercept-free) validation data. When omitted, NaN is
            recorded as this fit's validation loss.
        **kwargs : forwarded to ``GLM.fit``.

        Returns
        -------
        self
        """
        # logistic regression GLM
        model = sm.GLM(
            y,
            self._design_matrix(X),
            family=sm.families.Binomial()
        )

        # fit GLM
        self.result_ = model.fit(maxiter=self.max_iter, **kwargs)

        # validation eval (if provided); predict_proba adds the intercept itself
        if eval_set is not None:
            X_val, y_val = eval_set
            val_loss = log_loss(y_val, self.predict_proba(X_val)[:, 1], labels=[0, 1])
        else:
            val_loss = np.nan
        self.evals_result_['valid_0']['binary_logloss'].append(val_loss)

        return self

    def predict_proba(self, X):
        """Return an (n, 2) array with columns [P(y=0), P(y=1)] for raw features `X`.

        Raises
        ------
        NotFittedError
            If called before ``fit``.
        """
        if self.result_ is None:
            raise NotFittedError(
                'This GLMClassifier instance is not fitted yet; call fit first.')

        # logistic regression returns p = P(y=1)
        p = self.result_.predict(self._design_matrix(X))
        p = np.clip(p, self.eps, 1 - self.eps)  # numeric stability
        return np.column_stack([1 - p, p])

    def predict(self, X, threshold=0.5):
        """Return 0/1 labels: 1 where P(y=1) >= `threshold`."""
        probs = self.predict_proba(X)[:, 1]
        return (probs >= threshold).astype(int)

    def evaluate(self, X, y):
        """Return the model's log-loss score.

        If any non-NaN validation loss was recorded by ``fit``, the smallest
        one is returned and `X`, `y` are ignored. Otherwise the log-loss on
        the given raw `X`, `y` is computed.
        """
        # check previously stored validation losses
        val_losses = self.evals_result_['valid_0']['binary_logloss']
        valid_numeric_losses = [v for v in val_losses if not np.isnan(v)]
        if valid_numeric_losses:
            return min(valid_numeric_losses)

        # no stored loss: score the given data (intercept added in predict_proba)
        p = self.predict_proba(X)[:, 1]
        return log_loss(y, p, labels=[0, 1])

In [None]:
# ---------------------------------------------------------------------------
# tune hyperparameters
# ---------------------------------------------------------------------------

# hide "RuntimeWarning: overflow encountered in exp"; note that this silences
# every RuntimeWarning for the rest of the session
warnings.simplefilter("ignore", category=RuntimeWarning)

model_class = GLMClassifier

# searched values: with / without intercept, and max_iter = 0, 1, ..., 99
model_hyper = dict(
    add_intercept=[True, False],
    max_iter=list(np.arange(100)),
)
# NOTE(review): this seed differs from the notebook-level `random_state`
# option, and GLMClassifier only stores it -- confirm which value is intended
model_fixed = dict(random_state=42)

# no fit-time hyperparameters are searched or fixed
train_hyper = dict()
train_fixed = dict()

# NOTE(review): the fold `weights` computed with the splits are not passed
# here (weights=None) -- confirm that unweighted scoring is intended
model_best, train_best, model_list, score_list = HT.hyperparameter_tuning(
    x=X,
    y=y,
    train_indices=train_idxs,
    val_indices=val_idxs,
    model_class=model_class,
    model_hyper=model_hyper,
    model_fixed=model_fixed,
    train_hyper=train_hyper,
    train_fixed=train_fixed,
    weights=None,
    normalize=True,
    verbose=True,
)

print("Best model hyperparameters:", model_best)
print("Best training hyperparameters:", train_best)

# Plot model predictions

In [None]:
# re-import the project helper modules so that edits made to them on disk are
# picked up without restarting the kernel
for _helper_module in (Plots, Metrics):
    reload(_helper_module)

def predict(train_idx, val_idx):
    """Fit the tuned model on one split and return its predicted probabilities.

    Features are standardised with statistics taken from the training rows
    only; the same transform is then applied to the evaluation rows.

    Parameters
    ----------
    train_idx, val_idx : index arrays into the notebook-level `X` and `y`.

    Returns
    -------
    (y_pred_train, y_pred_val, model) : P(y=1) for the training rows, P(y=1)
        for the evaluation rows, and the fitted estimator.
    """
    x_train_raw, x_val_raw = X[train_idx], X[val_idx]

    # center and scale using the training rows only
    scaler = StandardScaler()
    scaler.fit(x_train_raw)
    x_train = scaler.transform(x_train_raw)
    x_val = scaler.transform(x_val_raw)

    # build the estimator with the tuned + fixed settings, then fit it
    estimator = model_class(**model_best, **model_fixed)
    model = estimator.fit(
        X=x_train,
        y=y[train_idx],
        eval_set=(x_val, y[val_idx]),
        **train_best,
        **train_fixed,
    )

    # second column of predict_proba is P(y=1)
    proba_train = model.predict_proba(x_train)
    proba_val = model.predict_proba(x_val)

    return proba_train[:, 1], proba_val[:, 1], model

# fit every CV fold once and keep both its train and validation probabilities
fold_preds = [predict(t, v) for t, v in zip(train_idxs, val_idxs)]

# final model for the test set: trained on `train_idx`, i.e. the same training
# rows (expanding or fixed window, per `rolling_forecast`) that are shown for
# it in the train panel below
final_train_prob, y_prob_test, _ = predict(train_idx, test_idx)

# compute predictions for each set
y_prob_train = [fold[0] for fold in fold_preds] + [final_train_prob]
y_prob_val = np.concatenate([fold[1] for fold in fold_preds])
y_prob = np.concatenate([y_prob_train[0], y_prob_val, y_prob_test])

# y_prob is indexed with row positions (val_idxs, test_idx) below, so it must
# line up with y row for row
if len(y_prob) != len(y):
    raise ValueError(
        f'y_prob has {len(y_prob)} entries but y has {len(y)}: the first '
        'training fold, the validation folds and the test set must tile the data')

# compute thresholds for each set
thresholds = [Metrics.get_threshold(y[v], y_prob[v]) for v in val_idxs]

# create dates from years and months
# NOTE(review): cal_month was shifted to start at 0 during preprocessing, so
# these labels presumably read 'YYYY-0' .. 'YYYY-11' -- confirm this is what
# the plotting and metrics helpers expect
dates = df['cal_year'].astype(str) + '-' + df['cal_month'].astype(int).astype(str)
dates = pd.Series(dates.values, index=np.arange(len(dates)))

# compute optimal probability threshold based on f1 score
val_idx = np.concatenate(val_idxs)
threshold = Metrics.get_threshold(y[val_idx], y_prob_val)
y_pred = (y_prob > threshold).astype(int)

# append final threshold
thresholds = thresholds + [threshold]

# figure panel suffix: 'b' for the environmental dataset, 'd' for the bat-level one
bd = 'd' if dataset == 'bat' else 'b'

# plot predictions over time
Plots.predictions(
    y_true=y,
    y_probs=[y_prob_train[0], y_prob_val, y_prob_test],
    cv_splits=[train_idxs, val_idxs, test_idx],
    dates=dates,
    threshold=threshold,
    dataset=dataset,
    save_name=f'figures/Fig1{bd}.pdf' if rolling_forecast else None,
)

# plot train/val/test splits separately
fig = 'SI_Fig6' if rolling_forecast else 'SI_Fig7'
Plots.predictions_subplots(
    y_true=y,
    y_probs=[y_prob_train, y_prob_val, y_prob_test],
    cv_splits=[train_idxs+[train_idx], val_idxs, test_idx],
    dates=dates,
    thresholds=thresholds,
    dataset=dataset,
    save_name=f'figures/{fig}{bd}.pdf',
)

# Print model performance metrics

In [None]:
reload(Metrics)

# banner, then the metrics over the pooled validation folds and the test set
_rule = '------------------------------------------------------------------'
print(_rule)
print('All validation folds and test set')
print(_rule)
print()
Metrics.print_metrics(
    y, y_prob, dates.values,
    val_idx=val_idx, test_idx=test_idx, decimal=3)

In [None]:
reload(Metrics)

# one diagnostic plot and one metrics table per validation fold
for fold_idx in val_idxs:
    start_date = dates.values[fold_idx[0]]
    end_date = dates.values[fold_idx[-1]]
    y_true_fold = y[fold_idx]
    y_prob_fold = y_prob[fold_idx]

    print('------------------------------------------------------------------')
    print(f'Validation period: {start_date} to {end_date}')
    print('------------------------------------------------------------------')
    print()

    # fold-specific threshold; fall back to 0.5 when none can be determined
    # (fold-local names only, so the notebook-level `thresholds` list built
    # with the predictions is left untouched)
    threshold_fold = Metrics.get_threshold(y_true_fold, y_prob_fold)
    if np.isnan(threshold_fold):
        threshold_fold = 0.5
    y_pred_fold = (y_prob_fold >= threshold_fold).astype(int)

    # black: predicted probability, blue: true label, orange: predicted label
    months = np.arange(len(fold_idx))
    fig_fold, ax_fold = plt.subplots()
    ax_fold.plot(y_prob_fold, 'k-o')
    ax_fold.plot(y_true_fold, '-o', color='tab:blue', alpha=0.5)
    ax_fold.plot(y_pred_fold, '-o', color='tab:orange', alpha=0.5)
    ax_fold.fill_between(months, 0, y_true_fold, color='tab:blue', alpha=0.5)
    ax_fold.fill_between(months, 0, y_pred_fold, color='tab:orange', alpha=0.5)
    ax_fold.axhline(threshold_fold, color='red', linestyle='--')
    ax_fold.set_ylim([-0.05, 1.05])
    ax_fold.set_xlim([-1, len(fold_idx)])
    ax_fold.set_ylabel('Predicted probability')
    ax_fold.set_xlabel('Month')
    ax_fold.set_title(f'Validation period: {start_date} to {end_date}')
    plt.show()
    plt.close(fig_fold)  # one figure per fold: release it once displayed

    print()
    Metrics.print_metrics(
        y,
        y_prob,
        dates.values,
        val_idx=fold_idx,
        decimal=3,
        # a fold without positives gets the fixed 0.5 threshold
        threshold=0.5 if y_true_fold.sum() == 0 else None,
    )

# Plot precision-recall curves

In [None]:
reload(Metrics)


def _plot_pr_vs_threshold(ax, y_true, y_score):
    """Draw precision, recall and F1 against the decision threshold on `ax`.

    The threshold with the highest F1 is marked with a large dot that sits on
    the F1 curve. `Metrics.precision_recall` presumably follows the
    `sklearn.metrics.precision_recall_curve` convention (one more
    precision/recall value than thresholds), which is why the last
    precision/recall/F1 entry is dropped before plotting -- TODO confirm.
    """
    p, r, t = Metrics.precision_recall(y_true, y_score)
    f = 2 * (p * r) / (p + r).clip(1e-8)  # clip guards against 0 / 0

    ax.plot(t, p[:-1], 'k', label='Precision')
    ax.plot(t, r[:-1], 'k', linestyle='--', label='Recall')
    ax.plot(t, f[:-1], 'k', linestyle=':', label='F1 Score')
    best = np.argmax(f[:-1])
    ax.scatter(t[best], f[best],
               s=200, color='k', zorder=np.inf, label='Optimal')
    ax.legend()


# one panel per validation fold; the last panel pools all folds
n_rows = {'env': 2, 'bat': 1}[dataset]
fig, axs = plt.subplots(n_rows, 5, figsize=(4*6.4, n_rows*4.8))
axs = axs.flatten()
if len(val_idxs) >= len(axs):
    raise ValueError(
        f'{len(val_idxs)} validation folds do not fit into {len(axs)} panels '
        '(the last panel is reserved for the pooled curve)')

for i, fold_idx in enumerate(val_idxs):

    # only the year part of the date labels is shown in the title
    start_date = dates.values[fold_idx[0]].split('-')[0]
    end_date = dates.values[fold_idx[-1]].split('-')[0]

    if i % 5 == 0:
        axs[i].set_ylabel('Precision', fontsize=20)
    if i >= 5 or dataset == 'bat':
        axs[i].set_xlabel('Threshold', fontsize=20)
    axs[i].set_title(f'Fold {i+1} ({start_date} - {end_date})')
    axs[i].set_xlim([-0.03, 1.03])
    axs[i].set_ylim([-0.03, 1.03])

    # a fold without positive labels gets an empty, titled panel
    if y[fold_idx].sum() == 0:
        continue

    _plot_pr_vs_threshold(axs[i], y[fold_idx], y_prob[fold_idx])

# pooled curve over all validation folds
start_date = dates.values[val_idx[0]].split('-')[0]
end_date = dates.values[val_idx[-1]].split('-')[0]
_plot_pr_vs_threshold(axs[-1], y[val_idx], y_prob[val_idx])

axs[-1].set_xlabel('Threshold')
axs[-1].set_title(f'All folds ({start_date} - {end_date})')
axs[-1].set_xlim([-0.03, 1.03])
axs[-1].set_ylim([-0.03, 1.03])

plt.tight_layout()
# only the rolling-forecast configuration is written to disk
if rolling_forecast:
    bd = 'd' if dataset == 'bat' else 'b'
    plt.savefig(
        f'figures/SI_Fig5{bd}.pdf',
        format='pdf', bbox_inches='tight')
plt.show()

# Plot feature importance

In [None]:
# add intercept to features if it was optimal
if model_best['add_intercept']:
    _features = ['Intercept'] + features
else:
    _features = features

# refit the final model on `train_idx`, so the training window follows the
# `rolling_forecast` option
model = predict(train_idx, test_idx)[-1]

# unpack GLM coefficients
params = np.array(model.result_.params)
if len(params) != len(_features):
    raise ValueError(
        f'model has {len(params)} coefficients but {len(_features)} feature '
        'names were built; labels would be misaligned')
feats = np.array([rename[feature] for feature in _features])[::-1]
coeffs = params[::-1]

# sort ascending: most negative coefficient first, most positive last
order = np.argsort(coeffs)
feats, coeffs = feats[order], coeffs[order]

# keep only coefficients whose magnitude exceeds the cut-off (this is a size
# filter on the log-odds, not a test of statistical significance)
min_abs_coeff = 0.2
keep = np.abs(coeffs) > min_abs_coeff
feats, coeffs = feats[keep], coeffs[keep]

def plot_glm_coeffs(features, coeffs):
    """Draw a horizontal lollipop chart of GLM coefficients.

    Parameters
    ----------
    features : sequence of str
        Tick labels, one per coefficient.
    coeffs : array-like of float
        Coefficients in ascending order; row ``i`` is drawn at height ``i``.
        The dashed sign-switch line is only meaningful for sorted input.

    Returns
    -------
    (fig, ax) : the created matplotlib figure and axes.

    Raises
    ------
    ValueError
        If `coeffs` is empty.

    Notes
    -----
    Reads the notebook-level `dataset` option for layout tweaks. Calls
    ``plt.style.use('fivethirtyeight')``, which changes the style of every
    figure created afterwards in the same kernel.
    """
    coeffs = np.asarray(coeffs, dtype=float)
    if coeffs.size == 0:
        raise ValueError('plot_glm_coeffs needs at least one coefficient')

    # book keeping
    y_pos = np.arange(len(features))
    largest = np.max(np.abs(coeffs))
    if largest > 0:
        coeffs_normalized = (coeffs / largest + 1) / 2 # to [0, 1]
    else:
        # all coefficients are zero: use the colormap midpoint
        coeffs_normalized = np.full(coeffs.shape, 0.5)
    colors = matplotlib.colormaps['coolwarm'](coeffs_normalized)
    # rows are ascending, so the negative coefficients occupy heights
    # 0 .. n_negative - 1 and the sign change lies half a row above them
    switch = (coeffs < 0).sum() - 0.5

    # initialize figure
    plt.style.use('fivethirtyeight')
    height = 0.3 * len(features)
    height = height + 1 if dataset == 'bat' else height
    fig, ax = plt.subplots(figsize=(10, height))

    # coefficients
    ax.hlines(y=y_pos, xmin=0, xmax=coeffs, color="gray", lw=1)
    ax.scatter(coeffs, y_pos, s=200, c=colors, edgecolor="black", zorder=np.inf)

    # zero line
    ax.axvline(0, color="gray", lw=3, zorder=10)

    # positive to negative switch line
    ax.axhline(switch, color="gray", lw=1, zorder=10, linestyle='--')

    # format axes
    ax.set_yticks(y_pos)
    ax.set_yticklabels(features)
    ax.grid(True, axis="both", color="0.9")
    for spine in ["top", "right", "left", "bottom"]:
        ax.spines[spine].set_visible(False)
    ax.tick_params(axis="both", length=0)
    ax.set_xlabel("Coefficient (log-odds)")
    if dataset == 'bat':
        ax.set_ylim([-1, len(features)])
        ax.set_xlim([-2.5, 2.5])
    else:
        ax.set_xlim([-6.5, 6.5])

    plt.tight_layout()
    return fig, ax

plot_glm_coeffs(feats, coeffs)

# panel suffix: 'a' for the environmental dataset, 'b' for the bat-level one;
# the file is only written for the rolling-forecast configuration
if rolling_forecast:
    ab = 'b' if dataset == 'bat' else 'a'
    plt.savefig(f'figures/SI_Fig8{ab}.pdf', format='pdf', bbox_inches='tight')
plt.show()

In [None]:
# statsmodels summary table for the GLM refit in the feature-importance cell
model.result_.summary()