In [77]:
from pathlib import Path
import json
import pandas as pd
from numpy import ndarray
from typing import Tuple

from sklearn.metrics import mean_absolute_error, root_mean_squared_error, r2_score, mean_absolute_percentage_error
from PyGRF import PyGRF

from src.data_utils import LabelledTraitData

In [78]:
Array = ndarray

In [92]:
BANDS = [
    'B2_real', # Blue band, 490 nm
    'B3_real', # Green band, 560 nm
    'B4_real', # Red band, 665 nm
    'B5_real', # Red edge band, 705 nm
    'B6_real', # Red edge band, 740 nm
    'B7_real', # Red edge band, 783 nm
    'B8_real', # NIR band, 842 nm
    'B11_real', # SWIR band, 1610 nm
    'B12_real' # SWIR band, 2190 nm
    # 'B8a_real' # NIR band, 865 nm
]

VARS = [
    'N.Percent',
    'P.Percent',
    'K.Percent',
    'Ca.Percent',
    'Mg.Percent',
    'C.Percent',
    'Amax',
    'Asat',
    'Area.cm2',
    'Dry.mass.g',
    'Fresh.mass.g',
    'Thickness.mm',
    'SLA.g.m2'
]

In [80]:
def preprocess_data(X: pd.DataFrame, y: Array) -> Tuple[Array, Array]:
    """Configure the data for the model."""
    return X[BANDS], y

def _unstandardise(x: Array, mean: float, std: float):
    return (x * std) + mean

def unstandardise(preds, targets, stats):
    """Rescale targets and predictions for sensible metrics."""
    preds = _unstandardise(preds, stats['mean'], stats['std'])
    targets = _unstandardise(targets, stats['mean'], stats['std'])
    return preds, targets

In [81]:
dpath = Path('/Users/campbelli/Documents/geofm-plant-traits/data')
var = 'Asat'

In [82]:
metadata_path = dpath / 'metadata'
with open(metadata_path / 'trait_stats.json', 'r') as f:
    trait_stats = json.load(f)

pixel_coords = pd.read_csv(metadata_path / 'pixel_coords.csv', index_col=0)

In [83]:
# Load model, data and trait stats.
dataset = LabelledTraitData(dpath, var)

X_train, y_train = dataset.train_data, dataset.train_labels
X_test, y_test = dataset.test_data, dataset.test_labels
X_val, y_val = dataset.val_data, dataset.val_labels

X_train, y_train = preprocess_data(X_train, y_train)
X_test, y_test = preprocess_data(X_test, y_test)
X_val, y_val = preprocess_data(X_val, y_val)

In [84]:
train_coords = pixel_coords.loc[X_train.index]
test_coords = pixel_coords.loc[X_test.index]
val_coords = pixel_coords.loc[X_val.index]

In [94]:
def get_trait_data_and_coords(dpath, var):
    dataset = LabelledTraitData(dpath, var)
    y = dataset.train_labels
    coords = pixel_coords.loc[y.index]
    return y, coords

In [None]:
for var in VARS:
    print(var)

    with open(metadata_path / 'grf_hparams.json', 'r') as f:
        hparams = json.load(f)

    if var not in hparams:
        # Get trait values and coordinates.
        y_train, train_coords = get_trait_data_and_coords(dpath, var)

        # Find optimal params with auto-correlation method.
        bandwidth, local_weight, p_value = PyGRF.search_bw_lw_ISA(
            y_train["TraitValue"],
            train_coords[['Lon', 'Lat']]
        )
        hparams[var] = {"bandwidth": bandwidth, "local_weight": local_weight}

        # Write the updated hyperparameters to file.
        with open(metadata_path / 'grf_hparams.json', 'w') as f:
            json.dump(hparams, f, indent=4)
    else:
        pass

In [86]:
pygrf = PyGRF.PyGRFBuilder(
    n_estimators=100,
    max_features=0.5,
    band_width=bandwidth,
    train_weighted=True,
    predict_weighted=True,
    bootstrap=False,
    resampled=True,
    random_state=42
)

In [None]:
# Fit and predict with the PyGRF model.
pygrf.fit(X_train, y_train, train_coords[['Lon', 'Lat']])
predict_combined, predict_global, predict_local = pygrf.predict(X_test, test_coords[['Lon', 'Lat']], local_weight=local_weight)

In [None]:
trait_stats[var]

{'mean': 13.087212276355176, 'std': 5.9540228989359205}

In [None]:
predict_combined = pd.Series(predict_combined, index=X_test.index)

In [None]:
unstandardised_preds, unstandardised_targets = unstandardise(predict_combined, y_test, trait_stats[var])

In [None]:
# Evaluate the model.
mae = mean_absolute_error(unstandardised_targets, unstandardised_preds)
rmse = root_mean_squared_error(unstandardised_targets, unstandardised_preds)
r2 = r2_score(unstandardised_targets, unstandardised_preds)
mape = mean_absolute_percentage_error(unstandardised_targets, unstandardised_preds)
print(f'MAE: {mae:.3f}')
print(f'RMSE: {rmse:.3f}')
print(f'R2: {r2:.3f}')
print(f'MAPE: {mape:.3f}')

MAE: 2.524
RMSE: 3.584
R2: 0.642
MAPE: 0.277
