# Open TODOs
- [ ] create separate configs for the different cases (countries)
- [ ] incorporate PyDrive again
- [ ] stitch the quadkey predictions back together / link them with the UNOSAT labels in order to evaluate them

In [None]:
# 1. create config with all the stuff necessary below
from omegaconf import DictConfig, OmegaConf

In [None]:
# Ukraine
# Only inference-time options are active; the commented-out entries are
# training-time options kept for reference.
cfg_ukr = OmegaConf.create(
    {
        # "aggregation_method": "mean",
        # "model_name": "random_forest",
        # "model_kwargs": {"numberOfTrees": 100, "minLeafPopulation": 3, "maxNodes": 1e4},
        "data": {
            # "aois_test": [f"UKR{i}" for i in range(1, 19) if i not in [1, 2, 3, 4]],
            # "damages_to_keep": [1, 2],
            # "time_periods": {"pre": ("2020-02-24", "2021-02-23"), "post": "3months"},  # training
            "extract_winds": ["3x3"],  # available: '1x1', '3x3', '5x5'
        },
        "inference": {
            "country": "Ukraine",  # country identifier, used to build the quadkey grid
            "time_periods": {
                "pre": ("2020-02-24", "2021-02-23"),  # always exactly one pre period
                # consecutive 3-month post windows
                "post": [
                    ("2021-02-24", "2021-05-23"),
                    ("2021-05-24", "2021-08-23"),
                    ("2021-08-24", "2021-11-23"),
                    ("2021-11-24", "2022-02-23"),
                    ("2022-02-24", "2022-05-23"),  # actual start of invasion: 24.02.2022
                    ("2022-05-24", "2022-08-23"),
                    ("2022-08-24", "2022-11-23"),
                    ("2022-11-24", "2023-02-23"),
                ],
            },
            "quadkey_zoom": 8,
        },
        "reducer_names": ["mean", "stdDev", "median", "min", "max", "skew", "kurtosis"],
        # "train_on_all": False,  # train on all damages (train + test split)
        "verbose": 0,
        # "export_as_trees": False,
        # "seed": 123,
        "run_name": "240307",  # must be a string
    }
)

In [1]:
# Syria
# Only inference-time options are active; the commented-out entries are
# training-time options kept for reference.
cfg_syria = OmegaConf.create(
    {
        # "aggregation_method": "mean",
        # "model_name": "random_forest",
        # "model_kwargs": {"numberOfTrees": 100, "minLeafPopulation": 3, "maxNodes": 1e4},
        "data": {
            # "aois_test": [f"UKR{i}" for i in range(1, 19) if i not in [1, 2, 3, 4]],
            # "damages_to_keep": [1, 2],
            # "time_periods": {"pre": ("2020-02-24", "2021-02-23"), "post": "3months"},  # training
            "extract_winds": ["3x3"],  # available: '1x1', '3x3', '5x5'
        },
        "inference": {
            "country": "Syria",  # country identifier, used to build the quadkey grid
            # Periods should follow the UNOSAT before/after dates where possible.
            "time_periods": {
                "pre": ("2020-02-24", "2021-02-23"),  # always exactly one pre period
                "post": [
                    ("2021-02-24", "2021-05-23"),
                    ("2021-05-24", "2021-08-23"),
                    # further candidate windows (currently disabled):
                    # ("2021-08-24", "2021-11-23"),
                    # ("2021-11-24", "2022-02-23"),
                    # ("2022-02-24", "2022-05-23"),
                    # ("2022-05-24", "2022-08-23"),
                    # ("2022-08-24", "2022-11-23"),
                    # ("2022-11-24", "2023-02-23"),
                ],
            },
            "quadkey_zoom": 8,
        },
        "reducer_names": ["mean", "stdDev", "median", "min", "max", "skew", "kurtosis"],
        # "train_on_all": False,  # train on all damages (train + test split)
        "verbose": 0,
        # "export_as_trees": False,
        # "seed": 123,
        "run_name": "240307",  # must be a string
    }
)

#TODO adapt for other countries: create different configs for the different cases

In [9]:
# 2. Load the trained classifier (trained by Olivier) from the GEE assets.
from src.gee.classification.model import load_classifier

# NOTE: the asset lives under the Ukraine project folder; the same model is
# applied to every country config below.
asset_id = (
    "projects/rmac-ethz/assets/s1tsdd_Ukraine/240307/classifier_3months_100trees"
)
classifier = load_classifier(asset_id)

In [10]:
# 3. ####### NEW - PREDICT AND EXPORT FUNCTION TO WORK FOR OTHER REGIONS FLEXIBLY ######

from typing import List
from tqdm import tqdm

# for creating the quadkey grid
from src.data.quadkeys import load_country_quadkeys_grid
from shapely.geometry import mapping
import ee
from src.utils.gee import init_gee
init_gee()

# from src.gee.constants import ASSETS_PATH
from src.gee.classification.inference import predict_geo
# from src.utils.gdrive import get_files_in_folder #TODO had to remove this bc linked to pydrive

def predict_and_export_all_grids(
    classifier: ee.Classifier,
    cfg: DictConfig,
    folder: str,
    verbose: int = 0,
    ids: List[str] = None,
    n_limit: int = None,
):
    """
    Predict and export all quadkey grids covering ``cfg.inference.country``.

    The country grid comes from ``load_country_quadkeys_grid`` at zoom
    ``cfg.inference.quadkey_zoom``. For every selected grid an Earth Engine
    prediction is built with ``predict_geo``, scaled to uint8 (x255) and an
    ``Export.image.toDrive`` task is started.

    Args:
        classifier: trained Earth Engine classifier passed to ``predict_geo``.
        cfg: config providing ``run_name``, ``data.extract_winds``,
            ``reducer_names`` and ``inference.{country, quadkey_zoom, time_periods}``.
            ``inference.time_periods.post`` must be a single period (a sequence of
            date strings) because it is joined into the task description.
        folder: name of the Google Drive folder to export into.
        verbose: verbosity passed to ``predict_geo``.
        ids: if given, only predict the grids whose quadkey (``qk``) is in ``ids``.
        n_limit: if given (and ``ids`` is None), only predict the first
            ``n_limit`` grids (for debugging).
    """
    country = cfg.inference.country
    zoom = cfg.inference.quadkey_zoom
    print(f"Predicting for quadkey grid for {country} with zoom {zoom}")

    # Country quadkey grid; the helper creates it if it does not exist yet.
    grids = load_country_quadkeys_grid(zoom=zoom, country=country)

    # Select the grids to predict before building any ee objects.
    if ids is not None:
        # compare as strings so int and str quadkeys both match
        wanted = {str(id_) for id_ in ids}
        grids = grids[grids["qk"].astype(str).isin(wanted)]
    elif n_limit:
        grids = grids.head(n_limit)

    # predict_geo and the export need ee.Geometry objects, not shapely geometries.
    grids = grids.assign(
        geomee=[ee.Geometry(mapping(geom)) for geom in grids["geometry"]]
    )

    # NOTE: skipping already exported files / still-running tasks is disabled
    # because it relied on PyDrive (see the legacy implementation).

    def get_description(id_):
        return f"{cfg.run_name}_qk{id_}_{'_'.join(cfg.inference.time_periods.post)}"

    print(f"Predicting and exporting {len(grids)} grids")
    for _, grid in tqdm(grids.iterrows(), total=len(grids)):
        preds = predict_geo(
            grid.geomee,
            classifier,
            cfg.inference.time_periods,
            cfg.data.extract_winds,
            cfg.reducer_names,
            orbits=None,
            verbose=verbose,
        )
        preds = preds.set("qk", grid.qk)

        task = ee.batch.Export.image.toDrive(
            image=preds.multiply(2**8 - 1).toUint8(),  # multiply by 255 and convert to uint8
            description=get_description(grid.qk),
            folder=folder,
            fileNamePrefix=f"qk_{grid.qk}",
            region=grid.geomee,
            scale=10,
            maxPixels=1e13,
        )
        task.start()

In [None]:
####### OLD / ADAPTED FROM OLIVIER / ONLY WORKS FOR UKRAINE ######

from typing import List
from tqdm import tqdm

from src.gee.constants import ASSETS_PATH
from src.gee.classification.inference import predict_geo
# from src.utils.gdrive import get_files_in_folder #TODO had to remove this bc linked to pydrive

def predict_and_export_all_grids(
    classifier: ee.Classifier,
    cfg: DictConfig,
    folder: str,
    ids: List[str] = None,
    n_limit: int = None,
    verbose: int = 0,
):
    """
    Predict and export for all grids (quadkeys) in Ukraine (legacy version).

    Grids are read from the pre-computed Ukraine asset
    ``ASSETS_PATH + "s1tsdd_Ukraine/quadkeys_grid_zoom{zoom}"``, so this version
    only works for Ukraine. Running this cell redefines (shadows) any earlier
    ``predict_and_export_all_grids`` definition in the notebook.

    If ids is not None, predict only these grids. If n_limit is given (and ids is
    None), only predict on n_limit grids. Grids whose export task (matched by
    description) is already PENDING or RUNNING in Earth Engine are skipped.

    Args:
        classifier: trained Earth Engine classifier passed to ``predict_geo``.
        cfg: config; ``cfg.inference.time_periods.post`` must be a single period
            (a sequence of date strings) because it is joined into the task
            description.
        folder: Google Drive folder name the images are exported to.
        ids: optional quadkey ids to restrict the predictions to.
        n_limit: optional maximum number of grids (for debugging).
        verbose: verbosity passed to ``predict_geo``.
    """

    # Get all grids from the Ukraine-only quadkey asset.
    # TODO: how to adapt this for other countries? Could the grid come from elsewhere?
    # TODO: turn quadkeys for other regions (clipped to country or AOI borders) into an
    #   ee.FeatureCollection (see birke_test.ipynb for this).
    #   OR: it does not have to be quadkeys - an ee.FeatureCollection of the areas to
    #   cover (e.g. all AOIs) that can be looped over later would also work.
    print(f"Predicting for quadkey grid with zoom {cfg.inference.quadkey_zoom}")
    grids = ee.FeatureCollection(ASSETS_PATH + f"s1tsdd_Ukraine/quadkeys_grid_zoom{cfg.inference.quadkey_zoom}")

    if ids is None:
        # No IDs were given, we predict on all (or n_limit if given)
        if n_limit:
            # For debugging
            grids = grids.limit(n_limit)
        # fetch all quadkey ids client-side (getInfo)
        ids = grids.aggregate_array("qk").getInfo()
    else:
        # make sure ids are strings
        ids = [str(id_) for id_ in ids]

    # Filter IDs that have already been predicted (names are qk_12345678.tif for instance)
    # TODO: disabled because it relies on PyDrive, which does not work at the moment
    # files = get_files_in_folder(folder, return_names=True)
    # existing_names = [f.split(".")[0] for f in files if f.startswith("qk_")]
    # ids = [id_ for id_ in ids if id_ not in existing_names]

    # Task description: run name + quadkey + the (single) post period's dates.
    def get_description(id_):
        return f"{cfg.run_name}_qk{id_}_{'_'.join(cfg.inference.time_periods.post)}"

    # Skip quadkeys whose export task is still PENDING or RUNNING (matched by description).
    ops = [o for o in ee.data.listOperations() if o["metadata"]["state"] in ["PENDING", "RUNNING"]]
    ids_running = [o["metadata"]["description"] for o in ops]
    ids = [id_ for id_ in ids if get_description(id_) not in ids_running]

    print(f"Predicting and exporting {len(ids)} grids")
    for id_ in tqdm(ids):

        grid = grids.filter(ee.Filter.eq("qk", id_))
        preds = predict_geo(
            grid.geometry(),
            classifier,
            cfg.inference.time_periods,
            cfg.data.extract_winds,
            cfg.reducer_names,
            orbits=None,
            verbose=verbose,
        )
        preds = preds.set("qk", id_)

        name = f"qk_{id_}"
        task = ee.batch.Export.image.toDrive(
            image=preds.multiply(2**8 - 1).toUint8(),  # multiply by 255 and convert to uint8
            description=get_description(id_),
            folder=folder,
            fileNamePrefix=name,
            region=grid.geometry(),
            scale=10,
            maxPixels=1e13,
        )
        task.start()

In [None]:
# 4. apply function to work flexibly with different country configs
# TODO: also create the separate Drive folders (and save the config) automatically. This is currently blocked because the client ID cannot be downloaded, so for now the folders are created manually.
# from src.utils.gdrive import create_drive_folder, create_yaml_file_in_drive_from_config_dict

def apply_function(
    classifier: ee.Classifier,
    cfg: DictConfig
):
    """
    Launch quadkey predictions for every post period of ``cfg``.

    For each period in ``cfg.inference.time_periods.post``, a config restricted
    to that single period is passed to ``predict_and_export_all_grids``, which
    exports into a Drive folder named after the period
    (e.g. ``2021-02-24_2021-05-23``).

    ``post`` may be a list of periods or a single period (a pair of date
    strings). The caller's ``cfg`` is left unmodified, so the function can be
    re-run with the same config.

    Args:
        classifier: trained Earth Engine classifier.
        cfg: country config (see ``cfg_ukr`` / ``cfg_syria``).
    """
    import copy

    base_folder_name = f"{cfg.run_name}_quadkeys_predictions"

    # try:
    #     # Create drive folder and save config
    #     create_drive_folder(base_folder_name)
    #     create_yaml_file_in_drive_from_config_dict(cfg, base_folder_name)
    # except Exception:
    #     # get input from user to be sure they want to continue
    #     print("Folder already exists. Continue? (y/n)")
    #     user_input = input()
    #     if user_input != "y":
    #         raise ValueError("Interrupted")

    # Work on a private copy: the loop overwrites ``time_periods.post``. Mutating
    # the caller's config would make a second call iterate over date strings.
    run_cfg = copy.deepcopy(cfg)
    post_periods = OmegaConf.to_container(run_cfg.inference.time_periods.post)
    if post_periods and all(isinstance(p, str) for p in post_periods):
        # a single period was given, e.g. ["2021-02-24", "2021-05-23"]
        post_periods = [post_periods]

    for post_period in post_periods:

        folder_name = f"{base_folder_name}/{'_'.join(post_period)}"
        # predict_and_export_all_grids expects exactly one post period
        run_cfg.inference.time_periods.post = post_period

        # try:
        #     # Create drive folder and save config
        #     create_drive_folder(folder_name)
        # except Exception:
        #     # get input from user to be sure they want to continue
        #     print("Folder already exists. Continue? (y/n)")
        #     user_input = input()
        #     if user_input != "y":
        #         raise ValueError("Interrupted")

        # Launch predictions. Export.image.toDrive takes a single folder name
        # (a "a/b" string is not treated as a path), hence only the last part.
        # ids / n_limit are left at their defaults (predict all grids).
        predict_and_export_all_grids(
            classifier=classifier,
            cfg=run_cfg,
            folder=folder_name.split("/")[-1],
            verbose=run_cfg.verbose,
        )
    

In [None]:
# 5. Launch the predictions with the config of the country of interest
#    (e.g. cfg_ukr or cfg_syria).
apply_function(classifier, cfg_syria)

In [None]:
#TODO: stitch quadkeys back together / link them with UNOSAT labels in order to evaluate them 
# It might make sense to store them in a geodatabase, with separate tables linking them to admin regions, AOIs and possibly buildings. => ON HOLD

In [None]:
from src.postprocessing.utils import read_fp_within_geo 
from src.constants import PREDS_PATH
from src.utils.geometry import load_country_boundaries

# Inspect one exported quadkey of the Ukraine run, read within the Ukraine
# boundaries. ``cfg`` only exists as a function parameter in this notebook,
# so the run name and country are taken from the Ukraine config explicitly.
ukr_geo = load_country_boundaries(cfg_ukr.inference.country)
fp = PREDS_PATH / cfg_ukr.run_name / "qk_12021333.tif"
preds_arr = read_fp_within_geo(fp, ukr_geo)
preds_arr

In [None]:
# Quick visual check of the predictions read from the quadkey GeoTIFF above.
preds_arr.plot()