In [2]:
import os

import joblib
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, StackingRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score, mean_absolute_error
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Load the master dataset and declare which columns feed the model.
df = pd.read_csv("../data/final/Ghana_Maize_Master_Dataset.csv")

feature_cols = [
    'Rainfall', 'Temperature', 'Humidity', 'Sunlight', 'Soil_Moisture',
    'Soil_Type', 'Pest_Risk', 'PFJ_Policy', 'Year', 'Yield_Lag1',
]
categorical_cols = ["Soil_Type"]
# Every non-categorical feature is scaled as numeric; feature order is kept.
numeric_cols = [col for col in feature_cols if col not in categorical_cols]

X = df[feature_cols]
y = df['Yield']

# Standardize the numeric columns and one-hot encode soil type. Soil types not
# seen during fitting are encoded as all zeros instead of raising.
preprocessor = ColumnTransformer(transformers=[
    ('num', StandardScaler(), numeric_cols),
    ('cat', OneHotEncoder(handle_unknown="ignore"), categorical_cols),
])

# Stacked ensemble: three tree-based base learners whose predictions, together
# with the preprocessed input features (passthrough=True), are combined by a
# LinearRegression meta-learner.
BASE_SEED = 299
N_TREES = 300

estimators = [
    ('rf', RandomForestRegressor(n_estimators=N_TREES, random_state=BASE_SEED, min_samples_split=2)),
    ('gb', GradientBoostingRegressor(n_estimators=N_TREES, learning_rate=0.05, random_state=BASE_SEED)),
    ('et', ExtraTreesRegressor(n_estimators=N_TREES, random_state=BASE_SEED)),
]

stacked = StackingRegressor(
    estimators=estimators,
    final_estimator=LinearRegression(),
    passthrough=True,
)

# Preprocessing and the ensemble travel together, so the saved model accepts
# raw feature columns.
model_pipeline = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('regressor', stacked),
])

# Hold out 20% of rows for evaluation.
# NOTE(review): this is a random row-level split. Because the features include
# 'Year' and 'Yield_Lag1', rows from the same year may appear in both train and
# test, so the held-out score may be optimistic for forecasting future seasons.
# A year-based split would be a stricter check — TODO confirm intended use.
# NOTE(review): keep this seed fixed; choosing random_state by comparing test
# scores across seeds overfits the test split.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=537
)
model_pipeline.fit(X_train, y_train)

# Evaluate on the held-out rows. R2 is the coefficient of determination (not a
# classification accuracy); MAE is in the same units as the 'Yield' column.
y_pred = model_pipeline.predict(X_test)
r2 = r2_score(y_test, y_pred)
mae = mean_absolute_error(y_test, y_pred)

print(f"Model Accuracy (R2): {r2:.3f}")
print(f"Error Margin (MAE): {mae:.3f} tons/ha")

# Persist the fitted pipeline. Create the target directory first so a fresh
# checkout without ../models does not fail after training has finished.
model_path = "../models/maize_model.pkl"
os.makedirs(os.path.dirname(model_path), exist_ok=True)
joblib.dump(model_pipeline, model_path)
print("Model saved.")


Model Accuracy (R2): 0.957
Error Margin (MAE): 0.056 tons/ha
Model saved.


In [3]:
# =============================================================================
# 6. VISUALIZATIONS: Save plots to ../reports
# =============================================================================
import os
import matplotlib.pyplot as plt
import seaborn as sns
# Output directory for every figure produced by this cell.
os.makedirs("../reports", exist_ok=True)
# Re-read the dataset from disk for the plots below.
df = pd.read_csv("../data/final/Ghana_Maize_Master_Dataset.csv")

# Apply seaborn's darkgrid theme directly. The previous code first tried the
# matplotlib style name 'seaborn-darkgrid', which was deprecated in matplotlib
# 3.6 and removed in 3.8. That made the look depend on the installed matplotlib
# version and relied on silently swallowed exceptions.
sns.set_theme(style='darkgrid')

# 1) National average yield trend: mean Yield across all rows of each Year.
yr_df = df.groupby('Year', as_index=False)['Yield'].mean()

fig, ax = plt.subplots(figsize=(8, 4))
sns.lineplot(data=yr_df, x='Year', y='Yield', marker='o', ax=ax)
ax.set(
    title='National Average Maize Yield by Year',
    xlabel='Year',
    ylabel='Yield (tons/ha)',
)
fig.savefig('../reports/national_yield_trend.png', bbox_inches='tight')
plt.close(fig)

# 2) Yield distribution per soil type as a boxplot; tick labels are slanted so
#    long soil-type names stay readable.
fig, ax = plt.subplots(figsize=(10, 6))
sns.boxplot(data=df, x='Soil_Type', y='Yield', ax=ax)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.set(
    title='Yield Distribution by Soil Type',
    xlabel='Soil Type',
    ylabel='Yield (tons/ha)',
)
fig.savefig('../reports/yield_by_soil_type.png', bbox_inches='tight')
plt.close(fig)

# 3) Rainfall (Apr-Aug) vs Yield: small scatter markers with a red fitted
#    regression line.
fig, ax = plt.subplots(figsize=(8, 6))
sns.regplot(
    data=df, x='Rainfall', y='Yield',
    scatter_kws={'s': 10}, line_kws={'color': 'red'},
    ax=ax,
)
ax.set(
    title='Rainfall (Apr-Aug) vs Yield',
    xlabel='Rainfall (mm)',
    ylabel='Yield (tons/ha)',
)
fig.savefig('../reports/rainfall_vs_yield.png', bbox_inches='tight')
plt.close(fig)

# 4) Pairwise correlation (pandas default, Pearson) between the target and key
#    numeric drivers, annotated to two decimals.
num_cols = ['Yield', 'Rainfall', 'Temperature', 'Humidity', 'Sunlight', 'Soil_Moisture', 'Yield_Lag1']
corr = df[num_cols].corr()

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(corr, annot=True, fmt='.2f', cmap='coolwarm', ax=ax)
ax.set_title('Feature Correlation Matrix')
fig.savefig('../reports/feature_correlation.png', bbox_inches='tight')
plt.close(fig)

# 5) Feature importances of the Random Forest base learner inside the stack.
# Best-effort: any failure is reported and the cell still finishes.
try:
    stack = model_pipeline.named_steps['regressor']
    # Look the forest up by the name it was registered under ('rf') rather than
    # by position, so reordering the base learners cannot silently plot the
    # wrong model. named_estimators_ only exists on a fitted stack.
    fitted_by_name = getattr(stack, 'named_estimators_', None)
    rf = fitted_by_name.get('rf') if fitted_by_name is not None else None

    if rf is not None and hasattr(rf, 'feature_importances_'):
        importances = rf.feature_importances_
        pre = model_pipeline.named_steps['preprocessor']
        # Ask the fitted ColumnTransformer for its output column names, so the
        # labels follow the exact order the base learners were trained on.
        # Names come back as '<transformer>__<column>'; strip that prefix.
        try:
            feature_names = [name.split('__', 1)[-1] for name in pre.get_feature_names_out()]
        except Exception:
            feature_names = []
        if len(importances) != len(feature_names):
            # Fall back to generic labels rather than mislabeling bars.
            feature_names = [f'F{i}' for i in range(len(importances))]

        feat_imp_df = (
            pd.DataFrame({'feature': feature_names, 'importance': importances})
            .sort_values('importance', ascending=False)
        )

        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            # hue='feature' with a palette avoids seaborn's palette-without-hue
            # deprecation; the resulting redundant legend is removed.
            sns.barplot(
                data=feat_imp_df, x='importance', y='feature',
                hue='feature', dodge=False, palette='viridis', ax=ax,
            )
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()
            ax.set(
                title='Feature Importances (Random Forest)',
                xlabel='Importance',
                ylabel='Feature',
            )
            fig.savefig('../reports/feature_importances.png', bbox_inches='tight')
        finally:
            # Close the figure even if plotting/saving fails, so it does not leak.
            plt.close(fig)
    else:
        print('RandomForest feature importances not available.')
except Exception as e:
    print('Could not compute feature importances:', e)

print('Visualizations saved to ../reports/')

Visualizations saved to ../reports/
