# Crop Mask Inference
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googlecolab/colabtools/blob/master/notebooks/crop_mask_inference.ipynb)

**Author:** Ivan Zvonkov (izvonkov@umd.edu)

**Description:** This notebook provides all the code to create a crop mask using the deployed Google Cloud architecture.


In [None]:
import os
import re
import requests

from collections import defaultdict
from datetime import datetime
from dateutil.relativedelta import relativedelta
from glob import glob
from google.colab import auth
from google.cloud import storage
from pathlib import Path
from tqdm.notebook import tqdm

earthengine_user = "izvonkov" # Update to be your username

# Google Cloud project that owns the storage client used by this notebook
gcloud_project_id = "bsos-geog-harvest1"
# Bucket whose tifs are counted as "Data available" and renamed when predictions are missing
tifs_bucket_name = "crop-mask-earthengine"
# Bucket whose files are counted as predictions and downloaded for merging
preds_bucket_name = "crop-mask-preds"
# Bucket that receives the final merged map
preds_merged_bucket_name = "crop-mask-preds-merged"
# GET endpoint; its JSON response carries an 'amount' field with the Earth Engine task count
ee_status_url = "https://us-central1-bsos-geog-harvest1.cloudfunctions.net/ee-status"
# GET endpoint; its JSON response carries a "models" list with a "modelName" per entry
models_url = "https://crop-mask-management-api-grxg7bzh2a-uc.a.run.app/models"
# POST endpoint that starts inference. Uses https like the other endpoints: a plain-http
# POST that gets redirected to https can be replayed without its JSON body.
start_inference_url = "https://us-central1-bsos-geog-harvest1.cloudfunctions.net/export-region"


# Functions

In [None]:
#######################################################
# Status functions
#######################################################
def get_ee_task_amount():
    """Return the 'amount' field from the Earth Engine status endpoint.

    Raises:
        AssertionError: if the endpoint does not answer with HTTP 200.
    """
    status_response = requests.get(ee_status_url)
    code = status_response.status_code
    assert code == 200, f"Got {code}. Either the url is incorrect or gcloud is not authenticated."
    return status_response.json()['amount']

def get_gcs_file_dict_and_amount(bucket_name, prefix):
    """List the blobs of `bucket_name` under `prefix`, grouped by parent path.

    Returns:
        A tuple (files_dict, amount). files_dict maps each blob's parent path
        (as a string) to the list of its file stems with every "pred_"
        substring removed; amount is the number of blobs listed.
    """
    stems_by_parent = defaultdict(list)
    total = 0
    listing = client.list_blobs(bucket_name, prefix=prefix)
    for total, blob in enumerate(tqdm(listing, desc=f"From {bucket_name}"), start=1):
        blob_path = Path(blob.name)
        # Stripping "pred_" lets prediction stems line up with the tif stems
        stem = blob_path.stem.replace("pred_", "")
        stems_by_parent[str(blob_path.parent)].append(stem)
    return stems_by_parent, total

def get_gcs_file_amount(bucket_name, prefix):
    """Return how many blobs `client.list_blobs` yields for `bucket_name` and `prefix`."""
    matching_blobs = client.list_blobs(bucket_name, prefix=prefix)
    return sum(1 for _ in matching_blobs)

def get_status(model_name, version):
    """Print and return the pipeline counts for `model_name`/`version`.

    Returns:
        A tuple (ee_task_amount, tifs_amount, predictions_amount): the Earth
        Engine task count, and the number of blobs under the
        "<model_name>/<version>" prefix in the tifs and predictions buckets.
    """
    prefix = f"{model_name}/{version}"
    rule = "------------------------------------------------------------------------------"
    print(rule)
    print(prefix)
    print(rule)
    # Gather all three counts before printing so the report comes out in one piece
    counts = (
        get_ee_task_amount(),
        get_gcs_file_amount(tifs_bucket_name, prefix=prefix),
        get_gcs_file_amount(preds_bucket_name, prefix=prefix),
    )
    labels = ("Earth Engine tasks", "Data available", "Predictions")
    for label, count in zip(labels, counts):
        print(f"{label}: {count}")
    return counts

#######################################################
# Inference functions
#######################################################
def start_inference(request_config):
    """Start inference by POSTing `request_config` as JSON to `start_inference_url`.

    Args:
        request_config: JSON-serializable dict sent as the request body.

    Returns:
        The `requests` response object from the endpoint.
    """
    print("Starting inference")
    print("Depending on the size of the bounding box this may take multiple hours.")
    response = requests.post(
        url=start_inference_url,
        json=request_config
    )
    try:
        print(response.json())
    except ValueError:
        # An error response may not be JSON; show the raw body instead of
        # hiding the real failure behind a decode error.
        print(f"Got {response.status_code}: {response.text}")
    return response

def find_missing_predictions(model_name, version, verbose=False):
    """Find tifs under "<model_name>/<version>" that have no matching prediction.

    Files are matched per batch (the blob's parent path) by file stem, with the
    "pred_" substring already removed from prediction stems by
    get_gcs_file_dict_and_amount.

    Args:
        model_name: First component of the bucket prefix.
        version: Second component of the bucket prefix.
        verbose: When True, also print the prefix header and every missing file.

    Returns:
        A dict mapping each batch path with missing predictions to the list of
        tif stems that have no prediction.
    """
    print("Addressing missing files")
    prefix = f"{model_name}/{version}"
    tif_files, _ = get_gcs_file_dict_and_amount(tifs_bucket_name, prefix=prefix)
    pred_files, _ = get_gcs_file_dict_and_amount(preds_bucket_name, prefix=prefix)
    missing = {}
    for full_k in tqdm(tif_files, desc="Missing files"):
        if full_k not in pred_files:
            diffs = tif_files[full_k]
        else:
            diffs = list(set(tif_files[full_k]) - set(pred_files[full_k]))
        if len(diffs) > 0:
            missing[full_k] = diffs

    batches_with_issues = len(missing)
    # Count what is actually missing: subtracting the bucket totals is wrong when
    # the predictions bucket holds files that match no tif.
    total_missing = sum(len(files) for files in missing.values())
    if verbose:
        print("------------------------------------------------------------------------------")
        print(prefix)
        print("------------------------------------------------------------------------------")
    if batches_with_issues > 0:
        print(f"\u2716 {batches_with_issues}/{len(tif_files)} batches have a total {total_missing} missing predictions")
        if verbose:
            for batch, files in missing.items():
                print("\t--------------------------------------------------")
                print(f"\t{Path(batch).stem}: {len(files)}")
                print("\t--------------------------------------------------")
                for f in files:
                    print(f"\t{f}")
    else:
        print(f"\u2714 all files in each batch match")
    return missing

def make_new_predictions(missing):
    """Rename every tif listed in `missing` from "<stem>.tif" to "<stem>-retry1.tif".

    Args:
        missing: dict mapping a batch path to a list of file stems inside the
            tifs bucket. Blobs that do not exist are reported and skipped.
    """
    # NOTE(review): presumably the rename makes the deployed pipeline predict
    # the file again - confirm against the cloud setup.
    tifs_bucket = client.bucket(tifs_bucket_name)
    for batch, stems in tqdm(missing.items(), desc="Going through batches"):
        for stem in tqdm(stems, desc="Renaming files", leave=False):
            source_name = f"{batch}/{stem}.tif"
            source_blob = tifs_bucket.blob(source_name)
            if not source_blob.exists():
                print(f"Could not find: {source_name}")
                continue
            tifs_bucket.rename_blob(source_blob, f"{batch}/{stem}-retry1.tif")

#######################################################
# Map making functions
#######################################################
def gdal_cmd(cmd_type: str, in_file: str, out_file: str, msg = None, print_cmd=False):
    """Assemble a GDAL command line and run it through the shell.

    Args:
        cmd_type: "gdalbuildvrt" or "gdal_translate".
        in_file: Input path; it may contain a wildcard, since the command is
            handed to the shell as one string.
        out_file: Output path.
        msg: Optional message printed before the command runs.
        print_cmd: When True, print the assembled command before running it.

    Raises:
        NotImplementedError: if `cmd_type` is neither supported command.
    """
    if cmd_type not in ("gdalbuildvrt", "gdal_translate"):
        raise NotImplementedError(f"{cmd_type} not implemented.")

    if cmd_type == "gdalbuildvrt":
        command = f"gdalbuildvrt {out_file} {in_file}"
    else:
        command = f"gdal_translate -a_srs EPSG:4326 -of GTiff {in_file} {out_file}"

    if msg:
        print(msg)
    if print_cmd:
        print(command)
    # The exit status is not inspected
    os.system(command)

def build_vrt(prefix):
    """Build one vrt per batch directory, then a single vrt over all of them.

    Scans "<prefix>_preds/*/*/" for directories whose path contains "batch",
    writes "<prefix>_vrts/<i>.vrt" for each one that does not exist yet (where
    i is the integer parsed from "batch_<i>/"), and finally builds
    "<prefix>_final.vrt" from every vrt in "<prefix>_vrts".

    Raises:
        ValueError: if a batch directory path has no "batch_<i>/" component,
            or if the captured text is not an integer.
    """
    print("Building vrt for each batch")
    batch_pattern = re.compile("batch_(.*?)/")
    for batch_dir in tqdm(glob(f"{prefix}_preds/*/*/")):
        if "batch" not in batch_dir:
            continue

        found = batch_pattern.search(batch_dir)
        if found is None:
            raise ValueError(f"Cannot parse i from {batch_dir}")
        batch_index = int(found.group(1))

        batch_vrt = Path(f"{prefix}_vrts/{batch_index}.vrt")
        if batch_vrt.exists():
            # Already built on a previous run
            continue
        gdal_cmd(cmd_type="gdalbuildvrt", in_file=f"{batch_dir}*", out_file=str(batch_vrt))

    gdal_cmd(
        cmd_type="gdalbuildvrt",
        in_file=f"{prefix}_vrts/*.vrt",
        out_file=f"{prefix}_final.vrt",
        msg="Building full vrt",
    )


# 1. Setup
**Prerequisite**: Access to the `bsos-geog-harvest1` Google Cloud project.

In [None]:
# Authenticate this Colab session with Google before any Google Cloud call below
auth.authenticate_user()

In [None]:
# Module-level storage client; the GCS helper functions defined above read this global
client = storage.Client(project=gcloud_project_id)

In [None]:
# Fetch the model list from the management API; anything other than HTTP 200 stops the notebook here
response = requests.get(models_url)
assert response.status_code == 200, f"Got {response.status_code}. Either the url is incorrect or gcloud is not authenticated."
# Keep only the model names; the configuration cell checks `model_name` against this list
available_models = [item["modelName"] for item in response.json()["models"]]
# Last expression of the cell, so the notebook displays the list
available_models

# 2. Inference configuration

In [None]:
##################################################################
# START: Configuration (edit below code)
##################################################################

model_name = "Ethiopia_Tigray_2020" # Name of model to use
version = "v2" # Version of map being made

# Coordinates for map
# Keep these as plain numbers: a trailing comma would turn each value into a 1-tuple
min_lon = 36.45
max_lon = 40.00
min_lat = 12.25
max_lat = 14.895

# Time range for data (YYYY-MM-DD)
start_date = "2020-02-01"
end_date = "2021-02-01"

##################################################################
# END: Configuration
##################################################################

# Verify configuration
# Every failed check sets config_issue, so the final assert stops the notebook
# before any inference request is sent.
config_issue = False
if model_name in available_models:
    print(f"\u2714 {model_name} is an available model")
else:
    print(f"\u2716 {model_name} not in available models")
    config_issue = True

if isinstance(version, str) and version != "":
    print(f"\u2714 Version: {version} specified.")
else:
    print(f"\u2716 {version} must be a string that's not blank")
    config_issue = True

if min_lon < max_lon and min_lat < max_lat:
    print(f"\u2714 Map coordinates are consistent")
else:
    print(f"\u2716 Check that min longitude/latitude is smaller than max longitude/latitude")
    config_issue = True

try:
    start_date_obj = datetime.strptime(start_date, "%Y-%m-%d")
    end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
    # strptime rejects anything that is not YYYY-MM-DD; say so instead of failing silently
    print(f"\u2716 Dates: {start_date} and {end_date} must have the format YYYY-MM-DD")
    config_issue = True
else:
    if start_date_obj >= end_date_obj:
        print(f"\u2716 {start_date} should be before {end_date}")
        # A reversed date range is a configuration error like any other
        config_issue = True
    else:
        print(f"\u2714 Dates: {start_date} and {end_date} have correct format")

    # Latest accepted end date: the first day of the current month, minus three months
    max_date_obj = datetime.now().replace(day=1) + relativedelta(months=-3)
    if end_date_obj > max_date_obj:
        print(f"\u2716 End date {end_date} must be before {max_date_obj}")
        config_issue = True

assert config_issue is False, "Issue in config, check logs."

# 3. Run inference

In [None]:
# Decide the next step from the current pipeline counts
ee_task_amount, tifs_amount, predictions_amount = get_status(model_name, version)
if ee_task_amount != 0:
    # Exports are still running; nothing can be done until they finish
    print(
        f"Please wait for all Earth Engine {ee_task_amount} tasks to complete and rerun this cell."
        "\nLarger area of interest means more Earth Engine tasks means longer processing time."
    )
elif tifs_amount == 0 and predictions_amount == 0:
    # Nothing exported or predicted yet: kick off a fresh run
    start_inference(
        dict(
            model_name=model_name,
            version=version,
            min_lon=min_lon,
            max_lon=max_lon,
            min_lat=min_lat,
            max_lat=max_lat,
            start_date=start_date,
            end_date=end_date,
        )
    )
elif tifs_amount > predictions_amount:
    # Some tifs have no prediction: find them and rename them for a retry
    missing = find_missing_predictions(model_name, version)
    make_new_predictions(missing)
else:
    print("Inference complete! Time to merge predictions into a map.")
    

# 4. Merge predictions into a map


In [None]:
# Warning only: this check does not stop the cell, so the directories below are created either way
if ee_task_amount > 0:
    print(f"Please wait for all Earth Engine {ee_task_amount} tasks to complete and rerun the above cell before moving on.")

# Local file prefix shared by every artifact of this map
prefix = f"{model_name}_{version}"
# Working directories: _preds receives the downloaded predictions and _vrts the per-batch vrts
# NOTE(review): no cell in this notebook writes to {prefix}_tifs - confirm it is still needed
!mkdir {prefix}_preds
!mkdir {prefix}_vrts
!mkdir {prefix}_tifs

In [None]:
# Copy every object under gs://<preds bucket>/<model_name>/<version>* into {prefix}_preds
print("Download predictions as nc files (will take several minutes)")
# gsutil flags: -m runs the copy in parallel, -n skips files that already exist locally, -r recurses
!gsutil -m cp -n -r gs://{preds_bucket_name}/{model_name}/{version}* {prefix}_preds

In [None]:
# Build one vrt per downloaded batch directory, then {prefix}_final.vrt over all of them
build_vrt(prefix)

In [None]:
# Translate vrt for all predictions into a tif file
# -a_srs EPSG:4326 assigns the spatial reference, -of GTiff selects GeoTIFF output;
# the result is written to {prefix}_final.tif
!gdal_translate -a_srs EPSG:4326 -of GTiff {prefix}_final.vrt {prefix}_final.tif

# 5. Upload map to Earth Engine

In [None]:
# Destination object in the merged-predictions bucket: <model_name>/<version>_<start_date>_<end_date>
dest = f"gs://{preds_merged_bucket_name}/{model_name}/{version}_{start_date}_{end_date}"

In [None]:
!gsutil cp final.tif {dest}

In [None]:
# Authorize the earthengine command line tool used by the upload cell below
!earthengine authenticate

In [None]:
# Upload the object at {dest} as image asset users/<earthengine_user>/<prefix>_1,
# passing the configured start and end dates as -ts / -te
!earthengine upload image --asset_id users/{earthengine_user}/{prefix}_1 -ts {start_date} -te {end_date} {dest}

See map upload here: https://code.earthengine.google.com/