In [None]:
from collections import defaultdict
import json
import os

from dask.distributed import Client
import fsspec
import numpy as np
import pandas as pd
from shapely.ops import cascaded_union
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import RadiusNeighborsClassifier

## And a bunch of carbonplan dependencies
from carbonplan_data import cat as core_cat

from carbonplan_retro.data import cat
from carbonplan_retro.analysis.assign_project_fldtypcd import load_classification_data
from carbonplan_retro.load.geometry import (
    get_overlapping_states,
    load_supersections,
)
from carbonplan_retro.load.project_db import load_project_db

In [None]:
def species_array_to_d(species_array):
    """Turn a list of species records into a ``{code: fraction}`` dictionary.

    Each record is indexed by ``"code"`` and ``"fraction"``. The code is
    converted to a string and used as the key; the fraction is rounded to four
    decimal places and used as the value. If a code occurs more than once, the
    last record wins.
    """
    fractions_by_code = {}
    for record in species_array:
        code = str(record["code"])
        fractions_by_code[code] = round(record["fraction"], 4)
    return fractions_by_code


def prepare_regional_classifier(ss_ids):
    """Fit a radius-neighbors classifier on data surrounding the given supersections.

    The supersections whose ``ss_id`` is in ``ss_ids`` are dissolved into one
    geometry, buffered, and used to select the states (and, in most cases, the
    bounding box) handed to ``load_classification_data``. The ``radius``
    hyperparameter is then chosen by a 5-fold grid search, and the best model is
    refit on all of the loaded data.

    Parameters
    ----------
    ss_ids : sequence
        Supersection ids. Must support ``len()`` and integer indexing.

    Returns
    -------
    tuple
        ``(clf, dictvectorizer)`` -- the fitted ``GridSearchCV`` object and the
        ``"dictvectorizer"`` entry of the loaded data, which callers use to
        encode feature dictionaries before predicting. The pair can be reused
        for every project that shares the same supersections.
    """
    # The raster is opened only to read its CRS, so the buffer below is applied
    # in that coordinate system.
    da = core_cat.nlcd.raster(region="conus").to_dask()
    crs = da.attrs["crs"]

    supersections = load_supersections().to_crs(crs)

    # Give every selected row the same key so that dissolve() merges them into a
    # single geometry.
    subset_supersection = supersections[supersections["ss_id"].isin(ss_ids)].copy()
    subset_supersection.loc[:, "dissolve_all"] = 1

    # Buffer distance is in units of the raster CRS (presumably meters, i.e.
    # 150 km -- TODO confirm). The result is converted to lon/lat, and .item()
    # pulls out the one remaining geometry.
    aoi = subset_supersection.dissolve(by="dissolve_all").buffer(150_000).to_crs("epsg:4326").item()

    postal_codes = get_overlapping_states(aoi)
    print(f"preparing to load: {list(postal_codes)}")

    # NOTE(review): a lone supersection id above 200 skips the bounding-box
    # filter; the reason is not visible in this notebook -- confirm.
    # `and` (rather than bitwise `&`) short-circuits, so ss_ids[0] is only read
    # once the length check has passed.
    if len(ss_ids) == 1 and ss_ids[0] > 200:
        data = load_classification_data(postal_codes)
    else:
        data = load_classification_data(postal_codes, bounds=aoi.bounds)

    print("fitting classifier ")
    # outlier_label is what scikit-learn assigns to samples that have no
    # training neighbours inside the radius.
    base_clf = RadiusNeighborsClassifier(weights="distance", algorithm="brute", outlier_label=-999)
    param_grid = [
        {"radius": np.arange(0.15, 0.651, 0.025)}
    ]  # initial testing never yielded a case where we went above 0.5

    # Use half of the available cores, but never fewer than one:
    # os.cpu_count() may return None, and on a single-core machine
    # int(1 / 2) == 0, which joblib rejects as an n_jobs value.
    n_jobs = max(1, (os.cpu_count() or 1) // 2)

    clf = GridSearchCV(base_clf, param_grid, n_jobs=n_jobs, cv=5, refit=True, verbose=10)
    clf.fit(data["features"], data["targets"])
    return clf, data["dictvectorizer"]

In [None]:
# Start a dask.distributed client with default settings. Ending the cell with
# the bare name shows the client's summary as the cell output.
# NOTE(review): no visible cell hands work to this client explicitly -- confirm
# that it is still needed (e.g. by code inside the carbonplan loaders).
client = Client()
client

In [None]:
# Load the project database and narrow it to the projects that get classified.
# NOTE(review): this is an absolute path on one particular machine -- consider
# making it configurable.
project_db = load_project_db("/home/jovyan/lost+found/Forest-Offset-Projects-v0.3.json")

# Drop projects whose `early_action` value starts with "CAR".
projects = project_db[~project_db["project"]["early_action"].str.startswith("CAR")]
# Drop projects without species data. The mask is built from `projects` itself
# (not the unfiltered `project_db`) so that its index matches the frame being
# filtered and pandas does not have to reindex a mismatched boolean key.
projects = projects[projects["project"]["species"].notnull()]
# Drop projects whose species entry contains "all".
projects = projects[~projects["project"]["species"].apply(lambda x: "all" in x)]

In [None]:
# Cache of (classifier, encoder) pairs, keyed by the stringified supersection
# ids of a project and filled in by the classification loop in the next cell.
# It lives in its own cell so the loop can be re-run after a failure without
# discarding classifiers that have already been fit.
clf_cache = {}

In [None]:
# Results, laid out as classifications[opr_id][aa_id] -> {class label: probability}.
classifications = defaultdict(dict)

for opr_id, project in projects.iterrows():
    if opr_id in ["CAR1094", "CAR1032"]:
        print(f"skipping {opr_id} -- discuss w group")
        continue

    print(opr_id)

    supersection_ids = project["project"]["supersection_ids"]
    # Projects that share the same supersections share one classifier; the
    # stringified ids serve as the cache key.
    cache_key = supersection_ids.astype(str).item()

    # Explicit membership test rather than a bare `except:` around the lookup,
    # so that unrelated errors (or a keyboard interrupt) surface immediately
    # instead of silently triggering an expensive refit.
    if cache_key not in clf_cache:
        clf_cache[cache_key] = prepare_regional_classifier(supersection_ids.item())
    clf, data_encoder = clf_cache[cache_key]

    for aa_id, species_arr in project_db["project"]["species"][opr_id].items():
        # Encode the species composition with the encoder that was returned
        # alongside this classifier.
        feat_dict = species_array_to_d(species_arr)
        feats = data_encoder.transform(feat_dict)

        # One probability per class known to the classifier.
        classification = pd.Series(clf.predict_proba(feats).flatten(), index=clf.classes_)
        # Keep only classes with non-zero probability, in ascending order.
        classifications[opr_id][aa_id] = classification[classification > 0].sort_values().to_dict()

## Store some outputs

Store the radius parameter selected by 5-fold cross-validation for each classifier, along with the
project classifications. We are unlikely to ever need the radii, but they are fairly expensive to compute, so we store them for good measure.


In [None]:
# Radius picked by the grid search for each cached classifier; the first item
# of every cache entry is the fitted search object.
fit_radii = {
    cache_key: cached[0].best_params_["radius"] for cache_key, cached in clf_cache.items()
}

# Both results are written as indented JSON to the scratch container, using
# the same storage account and credentials.
outputs = (
    ("az://carbonplan-scratch/radius_neighbor_params.json", fit_radii),
    ("az://carbonplan-scratch/project_radius_classification.json", classifications),
)

for url, payload in outputs:
    with fsspec.open(
        url,
        account_name="carbonplan",
        mode="w",
        account_key=os.environ["BLOB_ACCOUNT_KEY"],
    ) as f:
        json.dump(payload, f, indent=2)