In [2]:
import ee
import geopandas as gpd
import geemap
import os
from datetime import datetime

from fwi_predict.geo.ee import get_gfs

In [3]:
# Initialise the Earth Engine client against this pipeline's Cloud project.
EE_PROJECT = 'fwi-water-quality-sensing'
ee.Initialize(project=EE_PROJECT)

In [4]:
# Cleaned measurements with their sample metadata (GeoJSON).
# `fp` is also used at export time to derive the output file name.
fp = "../data/clean/measurements_with_metadata.geojson"
measurements = gpd.read_file(fp)

In [5]:
# GFS forecast ImageCollection from the project helper (fwi_predict.geo.ee.get_gfs).
gfs: ee.ImageCollection = get_gfs()

In [6]:
# Load measurements into an EE FeatureCollection, then turn each feature's
# `sample_dt` property into an ee.Date server-side.
# NOTE(review): assumes `sample_dt` survives gdf_to_ee as a value ee.Date
# accepts (e.g. an ISO-8601 string or epoch millis) — confirm.
fc = geemap.gdf_to_ee(measurements)
fc = fc.map(lambda f: f.set('sample_dt', ee.Date(f.get('sample_dt'))))

In [7]:
# Hour offsets, applied to midnight of the day before each sample, at which
# forecasts are extracted (negative values reach further back in time).
forecast_times = [
  3, 9, 15, 36,
  0, -12, -36, -60,
]

In [8]:
# Per-sample GFS forecast extraction (mapped over every measurement in the next cell)
from typing import List

def get_sample_gfs_forecast(sample: ee.Feature,
                            forecast_times: List,
                            gfs: ee.ImageCollection = None) -> ee.FeatureCollection:
  """Extract GFS forecast values at one sample's location.

  Builds an ``ee.FeatureCollection`` containing:

  * one feature per image selected for ``forecast_times``. For each valid
    time, this is the image with the latest ``creation_time`` strictly
    earlier than that valid time. Candidates are limited to images whose
    start date falls before ``sample_dt`` minus one day.
  * one feature for the sample time rounded to the nearest hour
    (``forecast_time == 'sample'``).
  * two features with per-band sums of daily 09:00 UTC values
    (``forecast_time`` of ``'three_day_cum'`` and ``'seven_day_cum'``).

  Args:
    sample: Feature whose geometry is sampled. Must carry ``sample_dt``
      (anything ``ee.Date`` accepts) and ``sample_idx`` properties.
    forecast_times: Non-empty list of hour offsets. Each offset is applied to
      midnight of the (Asia/Kolkata) day before the sample and then shifted
      back 6 hours to give a forecast valid time.
    gfs: GFS ImageCollection to read from. Fetched with ``get_gfs()`` when
      omitted.

  Returns:
    ee.FeatureCollection of sampled band values. Every feature is tagged
    with ``sample_idx`` and ``forecast_time``; the non-cumulative rows also
    carry ``forecast_creation_dt``.

  Raises:
    ValueError: if ``forecast_times`` is empty.
  """
  if not forecast_times:
    raise ValueError('forecast_times must contain at least one hour offset.')

  # Get GFS data
  if gfs is None:
    gfs = get_gfs()

  sample_idx = sample.get('sample_idx')
  geometry = sample.geometry()

  # Midnight of the day before the sample. Shift to Asia/Kolkata (UTC+5:30)
  # first so that "the day before" is the local calendar day.
  sample_dt = ee.Date(sample.get('sample_dt'))
  day_prior = (
    sample_dt
    .advance(5.5, 'hour')
    .advance(-1, 'day')
    .update(hour=0, minute=0, second=0)
  )

  def to_valid_time(hours):
    """Hour offset -> ee.Date valid time.

    The extra -6 h is a timezone adjustment so forecasts don't overlap the
    sample time.
    """
    return day_prior.advance(hours, 'hour').advance(-6, 'hour')

  forecast_time_list = ee.List(forecast_times).map(
    lambda hours: to_valid_time(hours).millis()
  )

  # Pre-filter GFS to reduce computation. The window starts search_window
  # hours before the EARLIEST requested valid time. It ends one day before the
  # sample, so only runs initialised at least a day ahead are used (assumes
  # system:time_start is the GFS initialisation time — confirm).
  search_window = max(forecast_times) - min(forecast_times) + 24
  forecast_subset = gfs.filterDate(
    to_valid_time(min(forecast_times)).advance(-search_window, 'hour'),
    sample_dt.advance(-1, 'day')
  )

  def get_latest_forecast_for_time(forecast_time: ee.Number) -> ee.Image:
    """Most recent image in forecast_subset valid at ``forecast_time`` (epoch ms).

    Only runs created strictly before the valid time are considered. Yields a
    null image when no run matches.
    """
    return (
      forecast_subset
      .filter(ee.Filter.lt('creation_time', forecast_time))
      .filter(ee.Filter.eq('forecast_time', forecast_time))
      .sort('creation_time', False)  # newest run first
      .first()
    )

  # Sample each selected forecast at the sample location and tag it.
  # NOTE(review): the id slices assume sampled feature ids shaped like
  # 'YYYYMMDDHHFhhh_<n>' (GFS image index + sample suffix) — confirm.
  forecasts_for_times = ee.ImageCollection(
    forecast_time_list.map(get_latest_forecast_for_time)
  )
  forecast_values = (
    forecasts_for_times
    .map(lambda img: img.sample(geometry))
    .flatten()
    .map(lambda f: f
         .set('forecast_creation_dt', f.id().slice(0, 10))
         .set('forecast_time', f.id().slice(11, 14))
         .set('sample_idx', sample_idx))
  )

  # Forecast valid at the sample time, rounded to the nearest hour.
  ms_per_hour = 1000 * 60 * 60
  sample_dt_rounded = (
    sample_dt.millis().divide(ms_per_hour).round().multiply(ms_per_hour)
  )
  sample_time_image = ee.Image(get_latest_forecast_for_time(sample_dt_rounded))
  # Image id = third '/'-separated part of system:id (assumes a two-level
  # collection path such as 'NOAA/GFS0P25' — confirm).
  sample_forecast_id = (
    sample_time_image.getString('system:id').split('/').getString(2)
  )
  sample_time_forecast = (
    sample_time_image
    .sample(geometry)
    .first()
    .set('forecast_creation_dt', sample_forecast_id.slice(0, 10))
    .set('forecast_time', 'sample')
    .set('sample_idx', sample_idx)
  )

  # Cumulative values for the days before the sample, using each day's
  # 09:00 UTC forecast (9 AM UTC is 3:30 PM IST).
  def get_cumulative_history(lookback_days):
    """Per-band sum of daily 09:00 UTC GFS values, sampled at the sample location.

    The days summed are day_prior plus the ``lookback_days`` days before it.
    Each day uses the most recently created image valid at 09:00 UTC.

    NOTE(review): ee.List.sequence(0, -lookback_days, step=-1) has
    lookback_days + 1 entries, so lookback_days=3 sums four days. Confirm this
    matches the intended 'three_day_cum' / 'seven_day_cum' semantics.
    """
    cum_days = ee.List.sequence(0, ee.Number(lookback_days).multiply(-1), step=-1)
    gfs_subset = gfs.filterDate(
      # From one day before the earliest summed day up to the sample time.
      day_prior.advance(ee.Number(lookback_days).add(1).multiply(-1), 'day'),
      sample_dt
    )

    daily_images = cum_days.map(
      lambda day: gfs_subset
      .filter(ee.Filter.eq('forecast_time',
                           day_prior.advance(day, 'day').update(hour=9).millis()))
      .sort('creation_time', False)
      .first()
    )
    summed = ee.ImageCollection(daily_images).reduce(ee.Reducer.sum())
    # reduce() suffixes every band name with '_sum'; strip those 4 characters.
    summed = summed.rename(
      summed.bandNames().map(lambda name: ee.String(name).slice(0, -4))
    )
    return summed.sample(geometry).first()

  three_day_history = (
    get_cumulative_history(3)
    .set('sample_idx', sample_idx)
    .set('forecast_time', 'three_day_cum')
  )
  week_history = (
    get_cumulative_history(7)
    .set('sample_idx', sample_idx)
    .set('forecast_time', 'seven_day_cum')
  )

  # Merge and return
  return forecast_values.merge(
    ee.FeatureCollection([sample_time_forecast, three_day_history, week_history])
  )

In [9]:
# Extract forecasts for every sample and flatten to one row per (sample,
# forecast) pair. Pass the GFS collection loaded above explicitly; otherwise
# get_sample_gfs_forecast falls back to calling get_gfs() again.
result = fc.map(lambda f: get_sample_gfs_forecast(f, forecast_times, gfs=gfs)).flatten()

### TODO: write a config that records which function was used to download this file.

In [10]:
# Timestamp the output so each run writes to its own Cloud Storage prefix.
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
source_stem = os.path.splitext(os.path.basename(fp))[0]
filename = f"{source_stem}_{current_time}"

export_kwargs = dict(
  collection=result,
  description='gfs_forecast_export',
  bucket='fwi-predict',
  fileNamePrefix=f'train/gfs/{filename}',
  fileFormat='CSV',
)
task = ee.batch.Export.table.toCloudStorage(**export_kwargs)
task.start()