In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import ee
import geemap

from src.gee.utils import init_gee
# Project helper called once before any `ee` call below (presumably initialises the
# Earth Engine session — see src.gee.utils to confirm).
init_gee()

# Classification in GEE

## 1. Load precomputed dataset

In [None]:
from omegaconf import OmegaConf
from src.gee.data.datasets import load_dataset, get_all_start_dates
from src.gee.data.datasets import get_name_asset
from src.gee.data.utils import get_base_asset_folder

# --- Dataset parameters (shared by every chunk and by the exported asset name) ---
n_tiles = 32
extract_window = 30
split = "train"
fold = None
keep_damage = [1, 2]
random_loc = 0.1
first_start_date = "2020-06-01"
last_start_date = "2022-06-01"
every_n_months = 1
chunk_size = 4  # number of start dates passed to each `load_dataset` call

start_dates = get_all_start_dates(first_start_date, last_start_date, every_n_months=every_n_months)

# Load the dataset chunk by chunk and merge the pieces. Striding over the list (instead of
# iterating over `len(start_dates) // 4` full chunks) keeps the trailing partial chunk, so the
# merged collection covers every date that is encoded in the asset name below.
ds_train = None
for chunk_start in range(0, len(start_dates), chunk_size):
    _start_dates = start_dates[chunk_start:chunk_start + chunk_size]
    cfg_train = OmegaConf.create(
        dict(
            split=split,
            fold=fold,
            random_loc=random_loc,
            keep_damage=keep_damage,
            n_tiles=n_tiles,
            extract_window=extract_window,
            start_dates=_start_dates,
            save_if_doesnt_exist=True,
        )
    )
    _ds_train = load_dataset(**cfg_train)
    ds_train = _ds_train if ds_train is None else ds_train.merge(_ds_train)

if ds_train is None:
    raise ValueError(
        f"No start dates between {first_start_date} and {last_start_date}; nothing to export."
    )

# Destination asset: <base folder>/Final/<name derived from the dataset settings>
base_path = get_base_asset_folder(n_tiles, extract_window)
asset_name = get_name_asset(split, fold, keep_damage, random_loc, start_dates)
asset_path = base_path + f"Final/{asset_name}"

task = ee.batch.Export.table.toAsset(
    collection=ds_train,
    description=asset_name,
    assetId=asset_path,
)
# The export is prepared but deliberately not launched; uncomment to start it.
#task.start()
#print(f"Task {asset_name} started")

In [None]:
from src.gee.data.datasets import load_dataset, get_all_start_dates
from omegaconf import OmegaConf

# Date range for the (currently disabled) monthly sweep; also read by the export loop further down.
first_start_date = "2020-06-01"
last_start_date = "2022-05-01"
every_n_months = 1
#start_dates = get_all_start_dates(first_start_date, last_start_date, every_n_months=every_n_months)

# start_dates = ["2020-06-01", "2020-10-01", "2021-06-01", "2021-10-01"]
start_dates = ["2020-10-01", "2021-10-01"]
print(f"Start dates: {start_dates}")


def _split_cfg(split_name, dates):
    """Build the dataset config for one split; only the split name and dates vary."""
    return OmegaConf.create(
        {
            "split": split_name,
            "fold": None,
            "random_loc": 0,
            "keep_damage": [1, 2],
            "n_tiles": 32,
            "extract_window": 30,
            "start_dates": dates,
            "save_if_doesnt_exist": True,
        }
    )


cfg_train = _split_cfg("train", start_dates)
cfg_test = _split_cfg("test", ["2020-10-01", "2021-10-01"])

In [None]:
# Load both splits; each config's keys are forwarded to `load_dataset` as keyword arguments.
ds_train = load_dataset(**cfg_train)
ds_test = load_dataset(**cfg_test)

In [None]:
# Number of elements in the training collection (displayed as the cell output).
ds_train.size()

In [None]:
# Number of elements in the test collection (displayed as the cell output).
ds_test.size()

## 2. Train Classifier

In [None]:
from src.gee.classification.model import train_classifier

# Classifier settings, forwarded to `train_classifier` as keyword arguments.
cfg_model = OmegaConf.create(
    {
        "model_name": "randomForest",
        "n_trees": 50,
        "output_mode": "CLASSIFICATION",
    }
)
trained_clf = train_classifier(ds_train, **cfg_model)

In [None]:
from src.gee.classification.model import export_classifier
from src.gee.data.datasets import get_name_asset
from src.gee.data.utils import get_base_asset_folder

# The model asset is named from the training-set settings, with 'rf' passed in place of a split.
model_name = get_name_asset(
    split='rf',
    fold=cfg_train.fold,
    keep_damage=cfg_train.keep_damage,
    random_perc=cfg_train.random_loc,
    start_dates=cfg_train.start_dates,
)
# Trained models are stored in a sub-folder of the dataset's base asset folder.
dataset_asset_folder = get_base_asset_folder(cfg_train.n_tiles, cfg_train.extract_window)
base_folder = dataset_asset_folder + 'Models_trained/'

In [None]:
# Display the generated asset name before exporting.
model_name

In [None]:
# Export the trained classifier under `base_folder` with the name computed above
# (presumably launches an Earth Engine task — the next cell polls the task list).
export_classifier(trained_clf, model_name, base_folder)

In [None]:
import time

# Earth Engine tasks are queued in state 'READY' before they switch to 'RUNNING'. Polling for
# 'RUNNING' alone can fall through immediately while the export has not even started yet.
PENDING_STATES = ('READY', 'RUNNING')

while True:
    task_list = ee.data.getTaskList()
    # NOTE(review): entry 0 is assumed to be the most recently submitted task (the export above).
    if not task_list or task_list[0]['state'] not in PENDING_STATES:
        break
    print(f"{task_list[0]['state'].capitalize()}...")
    time.sleep(5)

## 3. Check metrics on test set

In [None]:
from src.gee.classification.model import load_classifier
from src.gee.classification.utils import compute_metrics

#trained_clf = load_classifier(model_name, base_folder)

In [None]:
# Classify the test set with `trained_clf` in whatever output mode it currently has, then score it.
preds = ds_test.classify(trained_clf)
compute_metrics(preds)

In [None]:
# NOTE(review): this rebinds `trained_clf`, so every later `classify(trained_clf)` call in a
# top-to-bottom run uses the PROBABILITY-mode classifier.
trained_clf = trained_clf.setOutputMode('PROBABILITY')
preds = ds_test.classify(trained_clf)
# TODO confirm that `compute_metrics` is meant to receive probability outputs here.
compute_metrics(preds)

In [None]:
# NOTE(review): these statements also appear in an earlier cell of this section; what they
# compute depends on the output mode `trained_clf` carries when this cell is executed.
preds = ds_test.classify(trained_clf)
compute_metrics(preds)

In [None]:
# Probability-mode predictions for the per-point aggregation below. The probability classifier
# is derived inline instead of rebinding `trained_clf`, so re-running this cell does not change
# which classifier other cells pick up (`inference` sets the PROBABILITY mode itself).
preds_proba = ds_test.classify(trained_clf.setOutputMode('PROBABILITY'))

In [None]:
def aggregate_predictions(preds):
    """Collapse predictions to one feature per (startDate, unosat_id) pair.

    Args:
        preds: Earth Engine collection of predictions; the code reads the properties
            `startDate`, `unosat_id`, `label` and `classification` from it.

    Returns:
        A flattened ee.FeatureCollection. Each feature takes the geometry and `label` of the
        first matching prediction, stores the mean of `classification` over all matching
        predictions, and carries the date under `start_date` and the id under `unosat_id`.
    """

    def _features_for_date(date):
        on_date = preds.filter(ee.Filter.eq("startDate", date))

        def _merge_point(point_id):
            matches = on_date.filter(ee.Filter.eq("unosat_id", point_id))
            first_match = matches.first()
            return ee.Feature(
                ee.Geometry(first_match.geometry()),
                {
                    "label": ee.String(first_match.get("label")),
                    "unosat_id": ee.String(point_id),
                    "start_date": ee.String(date),
                    "classification": matches.aggregate_mean("classification"),
                },
            )

        point_ids = on_date.aggregate_array("unosat_id").distinct()
        return ee.FeatureCollection(point_ids.map(_merge_point))

    dates = preds.aggregate_array("startDate").distinct()
    # Mapping over the dates yields one collection per date; flatten merges them into one.
    return ee.FeatureCollection(dates.map(_features_for_date)).flatten()

In [None]:
# Average the probabilities per point and date, binarise the mean at 0.5, then score the result.
agg_preds = aggregate_predictions(preds_proba).map(
    lambda feat: feat.set(
        "classification_bin",
        ee.Number(feat.get("classification")).gte(0.5),
    )
)
compute_metrics(agg_preds, preds_name="classification_bin")

## 4. Large-scale (country-wide) predictions

In [None]:
from src.utils.geometry import load_country_boundaries
# NOTE(review): the other cells import GEE helpers from `src.gee.*`; confirm that
# `src.utils.gee` is the intended module path.
from src.utils.gee import shapely_to_gee
# Country outline, plus its converted form for Earth Engine calls
# (presumably shapely -> ee geometry, going by the helper's name).
ukraine_geo = load_country_boundaries('Ukraine')
ukraine_geo_ee = shapely_to_gee(ukraine_geo)

In [None]:
from src.gee.data.unosat import get_unosat_geo
from src.gee.data.collections import get_s1_collection
from src.gee.classification.features_extractor import manual_stats_from_s1


def inference(geo, start_date, trained_clf, window_days=30, max_images=32):
    """Classify Sentinel-1 statistics over `geo` and average the result across orbits.

    Args:
        geo: Geometry passed to `get_s1_collection` to select the imagery.
        start_date: Start date (e.g. "2021-10-01"); passed to `get_s1_collection` and used
            as the beginning of the window in which orbits are looked up.
        trained_clf: Trained ee.Classifier. It is switched to PROBABILITY output here, so
            the caller's output mode does not matter.
        window_days: Length in days of the window after `start_date` that is scanned for
            relative orbit numbers. Defaults to the previously hard-coded 30.
        max_images: Maximum number of images kept per orbit before computing statistics.
            Defaults to the previously hard-coded 32.

    Returns:
        ee.Image holding the mean of the per-orbit classification outputs.
    """
    # Make sure output is probability
    trained_clf = trained_clf.setOutputMode("PROBABILITY")

    s1 = get_s1_collection(geo, start=start_date)
    window_end = ee.Date(start_date).advance(window_days, "day")
    orbits = (
        s1.filterDate(start_date, window_end)
        .aggregate_array("relativeOrbitNumber_start")
        .distinct()
    )

    def inference_one_orbit(orbit):
        # `filter(ee.Filter.eq(...))` replaces the deprecated `filterMetadata(..., "equals", ...)`.
        s1_orbit = s1.filter(ee.Filter.eq("relativeOrbitNumber_start", orbit)).limit(max_images)
        s1_orbit_stats = manual_stats_from_s1(s1_orbit)
        return s1_orbit_stats.classify(trained_clf)

    results = orbits.map(inference_one_orbit)  # List of ee.Image
    return ee.ImageCollection(results).mean()

# Run the inference once for a single start date over one UNOSAT area.
start_date = "2021-10-01"
geo = get_unosat_geo("UKR1")  # can be arbitrary geometry
preds = inference(geo, start_date, trained_clf)

In [None]:
# Cut-off applied to the predictions before display; the palette is stretched over [threshold, 1].
threshold = 0.65
vis_params = {"min": threshold, "max": 1, "palette": ["yellow", "orange", "red"]}

def postprocessing(preds, threshold=0.5, smoothen=False, only_urban=True):
    """Mask out low predictions and optionally smooth what remains.

    Args:
        preds: Prediction image (anything exposing `gte`, `updateMask` and `convolve`).
        threshold: Pixels with a value below this are masked out.
        smoothen: When true, the masked image is convolved with a Gaussian kernel
            (radius 30, sigma 10, units in meters).
        only_urban: Accepted for interface compatibility but currently ignored; no urban
            masking is applied.

    Returns:
        The masked (and, if requested, smoothed) image.
    """
    keep = preds.gte(threshold)
    masked = preds.updateMask(keep)

    # TODO: urban masking is not implemented. The earlier draft, left disabled, masked with
    # ee.ImageCollection("JRC/GHSL/P2016/SMOD_POP_GLOBE_V1").mosaic().select("smod_code").eq(11).

    if not smoothen:
        return masked

    gaussian = ee.Kernel.gaussian(radius=30, sigma=10, units='meters')
    return masked.convolve(gaussian)

# GHSL built-up layer and its palette, prepared for the optional overlay that is commented out below.
dataset = ee.Image('JRC/GHSL/P2016/BUILT_LDSMT_GLOBE_V1')
builtUpMultitemporal = dataset.select('built')
vis_params_urban = {
  'min': 1.0,
  'max': 6.0,
  'palette': ['0c1d60', '000000', '448564', '70daa4', '83ffbf', 'ffffff'],
}

# Interactive map of the thresholded predictions, centred on `geo` at zoom level 12.
Map = geemap.Map()
Map.addLayer(postprocessing(preds, threshold=threshold), vis_params, "Predictions")
#Map.addLayer(builtUpMultitemporal, vis_params_urban, 'Built-Up Multitemporal')
Map.centerObject(geo, 12)
Map

In [None]:
# Export one prediction raster per start date to Google Drive.
# `Task.start()` returns None, so each task object is stored *before* it is started; chaining
# `.start()` onto the constructor call would leave no handle to check the export status.
# Loop-local names (`_start_date`, `_preds`) keep the notebook-level `start_date` / `preds`
# used by the map cell intact.
export_tasks = []
for _start_date in get_all_start_dates(first_start_date, last_start_date, every_n_months=every_n_months):
    _preds = inference(geo, _start_date, trained_clf)
    task = ee.batch.Export.image.toDrive(
        _preds,
        description=f'Preds_Mariupol_{_start_date}_2dates_random10.tif',
        scale=10,
        region=geo,
        crs='EPSG:4326',
        folder='TestPreds'
    )
    task.start()
    export_tasks.append(task)

## 5. Dummy classification

In [None]:
from src.gee.data.unosat import get_unosat_labels
# One year of Sentinel-1 GRD scenes that carry both VV and VH, acquired in IW mode by platform "A".
start, end = '2021-01-01', '2022-01-01'
s1 = ee.ImageCollection("COPERNICUS/S1_GRD")
for _s1_filter in (
    ee.Filter.listContains("transmitterReceiverPolarisation", "VV"),
    ee.Filter.listContains("transmitterReceiverPolarisation", "VH"),
    ee.Filter.eq("instrumentMode", "IW"),
    ee.Filter.eq("platform_number", "A"),
):
    s1 = s1.filter(_s1_filter)
s1 = s1.filterDate(ee.Date(start), ee.Date(end))

# Labelled points from the Earth Engine demo assets.
points = ee.FeatureCollection('GOOGLE/EE/DEMOS/demo_landcover_labels')

In [None]:
# Restrict the Sentinel-1 collection to the demo points and to the Ukraine geometry respectively.
s1_sf = s1.filterBounds(points)
s1_ukraine = s1.filterBounds(ukraine_geo_ee)

In [None]:
# Visual check: temporal mean of each filtered collection, with default visualisation parameters.
Map = geemap.Map()
Map.addLayer(s1_sf.mean(), {}, 'S1')
Map.addLayer(s1_ukraine.mean(), {}, 'S1 Ukraine')
Map

In [None]:
# Combined reducer: mean first, then each further statistic chained on with shared inputs.
stats_reducers = ee.Reducer.mean()
for _extra_reducer in (
    ee.Reducer.stdDev(),
    ee.Reducer.median(),
    ee.Reducer.max(),
    ee.Reducer.min(),
    ee.Reducer.skew(),
    ee.Reducer.kurtosis(),
    ee.Reducer.variance(),
):
    stats_reducers = stats_reducers.combine(reducer2=_extra_reducer, sharedInputs=True)

# Reduce the VV / VH bands of each collection over time with the combined reducer.
polarisations = ['VV', 'VH']
stats_sf = s1_sf.select(polarisations).reduce(stats_reducers)
stats_ukraine = s1_ukraine.select(polarisations).reduce(stats_reducers)

# get names bands
bands = stats_sf.bandNames().getInfo()

In [None]:
# Sample the statistics image at the demo points, keeping their `landcover` property.
# NOTE(review): `training` is not read by any later cell shown here (the classifier below is
# trained on `ds`) — confirm whether this cell is still needed.
label = 'landcover'
training = stats_sf.sampleRegions(collection=points, properties=[label], scale=10, tileScale=2)
#training.first().getInfo()

In [None]:
# Keep only the statistic bands listed in `bands` plus the target `label` from the training set.
ds = ds_train.select(bands + ['label'])

In [None]:
# Train a 50-tree random forest on `ds`, predicting `label` from the statistic bands.
classifier_trained = ee.Classifier.smileRandomForest(50).train(ds, 'label', bands)

# Classify the image with the same bands used for training.
# NOTE: this rebinds `preds`, replacing the result of the `inference` call in section 4.
preds = stats_ukraine.select(bands).classify(classifier_trained)

In [None]:
# Show the classified image; values are stretched over [0, 1] with a two-colour palette.
Map = geemap.Map()
Map.addLayer(preds,
             {'min': 0, 'max': 1, 'palette': ['orange', 'green']},
             'classification')
# Map.addLayer(s1.mean(), {'bands': ['VV' ,'VH', 'VV'], 'min': -10, 'max': 0}, 'image')
# Map.addLayer(points, {}, 'points')
Map

In [None]:
# Sample the classified image at the test features, carrying over their `label` property.
ds_test_inferred = preds.sampleRegions(collection=ds_test, properties=['label'], scale=10, tileScale=2)

In [None]:
# Fetch the first sampled feature to the client to inspect its properties.
ds_test_inferred.first().getInfo()

In [None]:
# Score the predictions sampled from the classified image.
compute_metrics(ds_test_inferred)

In [None]:
# Classify the test features directly, using the same band properties the classifier was trained on.
preds_test = ds_test.select(bands + ['label']).classify(classifier_trained)

In [None]:
# Score the feature-level predictions on the test set.
compute_metrics(preds_test)