In [39]:
%load_ext autoreload
%autoreload 2

from datetime import datetime, timezone
from typing import List

import ee
import geemap
import geopandas as gpd
import pandas as pd
import pytz

from fwi_predict.constants import FORECAST_TIMES
from fwi_predict.geo.ee import get_gfs

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [40]:
# Initialize the Earth Engine client library; the `ee` calls in the cells below depend on it.
ee.Initialize()

In [41]:
def get_sample_gfs_forecast(sample: ee.Feature,
                            forecast_times: List,
                            gfs: ee.ImageCollection = None) -> ee.ImageCollection:
	"""Collect the latest eligible GFS forecast image for each requested forecast time.

	A reference instant, `day_prior`, is derived from the sample's 'sample_dt'
	property: the timestamp is advanced 5.5 hours, moved back one day, and its
	hour/minute/second fields are set to zero. Each entry of `forecast_times`
	is an hour offset that is added to `day_prior` and then shifted back a
	further 6 hours, giving a target forecast time in epoch milliseconds.

	For every target time, the image with the greatest 'creation_time' is
	selected among images that pass the date pre-filter, were created before
	the target time, and whose 'forecast_time' equals the target time.

	Args:
		sample: Feature carrying a 'sample_dt' property readable by `ee.Date`.
		forecast_times: Hour offsets relative to `day_prior`.
		gfs: GFS image collection to search. When None, `get_gfs()` is used.

	Returns:
		An `ee.ImageCollection` built from one selected image per entry of
		`forecast_times`.

	Note:
		This function returns the forecast images themselves; it does not
		sample them at the sample geometry or attach per-sample metadata.
	"""
	if gfs is None:
		gfs = get_gfs()

	# Reference instant for the forecasts. The 5.5 hour shift presumably maps
	# UTC onto IST (UTC+5:30) before truncating to midnight -- TODO confirm.
	sample_dt = ee.Date(sample.get('sample_dt'))
	day_prior = (
		sample_dt
		.advance(5.5, 'hour')
		.advance(-1, 'day')
		.update(hour=0, minute=0, second=0)
	)

	# Target forecast times (epoch ms). The extra -6 hours is again a timezone
	# adjustment so forecasts don't overlap with the sample time.
	forecast_time_list = ee.List(forecast_times).map(
		lambda hours: day_prior.advance(hours, 'hour').advance(-6, 'hour').millis()
	)

	# Pre-filter GFS to reduce computation: from one day before the earliest
	# target time up to (but excluding) `day_prior`.
	forecast_subset = gfs.filterDate(
		ee.Date(forecast_time_list.sort().getNumber(0)).advance(-1, 'day'),
		day_prior
	)

	def get_latest_forecast_for_time(forecast_time: ee.Number) -> ee.Image:
		"""Get the most recently created forecast valid at `forecast_time`."""
		# Forecasts for the specific time of interest, created before that time.
		subset = forecast_subset \
			.filter(ee.Filter.lt('creation_time', forecast_time)) \
			.filter(ee.Filter.eq('forecast_time', forecast_time))

		# Greatest creation time among the candidates.
		latest_init_time = subset.aggregate_array('creation_time').sort().get(-1)

		return subset.filter(ee.Filter.eq('creation_time', latest_init_time)).first()

	# One image per requested forecast time.
	return ee.ImageCollection(
		forecast_time_list.map(get_latest_forecast_for_time)
	)

In [42]:
# Read the cleaned measurements GeoJSON into a GeoDataFrame.
ds = gpd.read_file("../data/clean/measurements_with_metadata.geojson")
# Take a small positional slice (at most three rows) to experiment with.
temp = ds.iloc[3380:3383]
# Convert the slice to an Earth Engine object, passing 'sample_dt' as the date column
# together with the pattern used to format it.
ee_frame = geemap.gdf_to_ee(temp, date='sample_dt', date_format="yyyy-MM-dd'T'HH:mm:ssZ")

In [43]:
# Fetch the first feature's 'sample_dt' from Earth Engine as a client-side dict.
out = ee.Date(ee_frame.first().get('sample_dt')).getInfo()
# out['value'] is divided by 1000 (milliseconds -> seconds) and rendered as a UTC timestamp string.
datetime.fromtimestamp(out['value']/1000, tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

'2022-04-23 01:13:00'

In [44]:
# Apply get_sample_gfs_forecast to every feature in ee_frame using the imported FORECAST_TIMES.
result = ee_frame.map(lambda f: get_sample_gfs_forecast(f, FORECAST_TIMES))

In [45]:
# Use the first feature as the single sample for the step-by-step cells below.
sample = ee_frame.first()

In [46]:
# GFS collection from the project helper, reused by the notebook-level cells below.
gfs = get_gfs()

In [47]:
# Override the entry at index 6 with -36 for this debugging session.
# NOTE(review): this is an in-place item assignment on the object imported from
# fwi_predict.constants, so anything else holding a reference to it sees the change too.
FORECAST_TIMES[6] = -36

In [48]:
# Notebook-level rebuild of the intermediate values, for inspection.
sample_idx = sample.get('sample_idx')  # Sample index property
sample_dt = ee.Date(sample.get('sample_dt'))

# Advance 5.5 hours, step back one day, then zero the time-of-day fields.
day_prior = (
	sample_dt
	.advance(5.5, 'hour')
	.advance(-1, 'day')
	.update(hour=0, minute=0, second=0)
)

# Target forecast times in epoch milliseconds. The extra -6 hours again adjusts
# for timezone so forecasts don't overlap with the sample time.
forecast_time_list = ee.List(FORECAST_TIMES).map(
	lambda hours: day_prior.advance(hours, 'hour').advance(-6, 'hour').millis()
)

# Window start: two days before the earliest target time.
# Window end: day_prior, since we want forecasts initialized one day before
# the sample was taken (5:30am IST).
forecast_subset = gfs.filterDate(
	ee.Date(forecast_time_list.sort().getNumber(0)).advance(-2, 'day'),
	day_prior
)

In [49]:
# Get latest forecast for each forecast (that is at least one day older than sample time)
def get_latest_forecast_for_time(forecast_time: ee.Number) -> ee.Image:
	"""Return the most recently created forecast image for `forecast_time`.

	Searches the notebook-level `forecast_subset` for images whose
	'forecast_time' equals `forecast_time` and whose 'creation_time' is
	strictly earlier, then returns the first image carrying the greatest
	'creation_time' among them.
	"""
	created_before = ee.Filter.lt('creation_time', forecast_time)
	valid_at = ee.Filter.eq('forecast_time', forecast_time)
	candidates = forecast_subset.filter(created_before).filter(valid_at)

	# Last element of the ascending-sorted creation times is the newest one.
	newest_creation = candidates.aggregate_array('creation_time').sort().get(-1)
	is_newest = ee.Filter.eq('creation_time', newest_creation)

	return candidates.filter(is_newest).first()


In [62]:
# Pull the sorted target times (epoch ms) client-side and show them in Asia/Kolkata time.
times = forecast_time_list.sort().getInfo()
kolkata = pytz.timezone('Asia/Kolkata')
[datetime.fromtimestamp(millis / 1000, tz=kolkata) for millis in times]

[datetime.datetime(2022, 4, 20, 11, 30, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>),
 datetime.datetime(2022, 4, 21, 14, 30, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>),
 datetime.datetime(2022, 4, 22, 8, 30, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>),
 datetime.datetime(2022, 4, 22, 14, 30, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>),
 datetime.datetime(2022, 4, 22, 20, 30, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>),
 datetime.datetime(2022, 4, 23, 8, 30, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>),
 datetime.datetime(2022, 4, 23, 14, 30, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>)]

In [50]:
# Treat `result` as an image collection and fetch the 'creation_time' of every image client-side.
creation_times = ee.ImageCollection(result).aggregate_array('creation_time').getInfo()

In [46]:
# Display the sample timestamps of the working slice.
# NOTE(review): the recorded output lists 10 rows, whereas the `ds.iloc[3380:3383]` slice
# selects at most 3 -- the output appears to come from a different definition of `temp`.
temp['sample_dt']

0   2023-10-10 06:34:00+05:30
1   2023-10-10 16:40:00+05:30
2   2023-11-16 07:01:00+05:30
3   2023-11-16 16:11:00+05:30
4   2023-12-16 07:21:00+05:30
5   2023-12-16 16:37:00+05:30
6   2024-01-05 07:38:00+05:30
7   2024-01-05 16:36:00+05:30
8   2024-02-06 07:15:00+05:30
9   2024-02-06 16:03:00+05:30
Name: sample_dt, dtype: datetime64[ms, UTC+05:30]

In [64]:
# Render each epoch-millisecond value in `forecast_times` as an Asia/Kolkata datetime.
# NOTE(review): `forecast_times` is not assigned at notebook scope in any visible cell
# (it only appears as a function parameter) -- this cell likely fails on a fresh kernel.
[datetime.fromtimestamp(d/1000, tz=pytz.timezone('Asia/Kolkata')) for d in forecast_times]

[datetime.datetime(2023, 10, 10, 6, 30, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>),
 datetime.datetime(2023, 10, 10, 16, 30, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>),
 datetime.datetime(2023, 11, 16, 7, 30, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>),
 datetime.datetime(2023, 11, 16, 16, 30, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>),
 datetime.datetime(2023, 12, 16, 7, 30, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>),
 datetime.datetime(2023, 12, 16, 16, 30, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>),
 datetime.datetime(2024, 1, 5, 7, 30, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>),
 datetime.datetime(2024, 1, 5, 16, 30, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>),
 datetime.datetime(2024, 2, 6, 7, 30, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>),
 datetime.datetime(2024, 2, 6, 16, 30, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>)]

In [67]:
# Display the sample timestamps again for side-by-side comparison with the values above.
# NOTE(review): the recorded output lists 10 rows, whereas the `ds.iloc[3380:3383]` slice
# selects at most 3 -- the output appears to come from a different definition of `temp`.
temp['sample_dt']

0   2023-10-10 06:34:00+05:30
1   2023-10-10 16:40:00+05:30
2   2023-11-16 07:01:00+05:30
3   2023-11-16 16:11:00+05:30
4   2023-12-16 07:21:00+05:30
5   2023-12-16 16:37:00+05:30
6   2024-01-05 07:38:00+05:30
7   2024-01-05 16:36:00+05:30
8   2024-02-06 07:15:00+05:30
9   2024-02-06 16:03:00+05:30
Name: sample_dt, dtype: datetime64[ms, UTC+05:30]

In [80]:
# Convert every sample timestamp to a UTC `datetime`. Iterating the Series yields one
# Timestamp per row, which avoids `Series.dt.to_pydatetime()` (its ndarray return is
# deprecated and emits a FutureWarning).
[ts.to_pydatetime().astimezone(timezone.utc) for ts in temp['sample_dt']]

  list(map(lambda d: d.astimezone(timezone.utc), temp['sample_dt'].dt.to_pydatetime()))


[datetime.datetime(2023, 10, 10, 1, 4, tzinfo=datetime.timezone.utc),
 datetime.datetime(2023, 10, 10, 11, 10, tzinfo=datetime.timezone.utc),
 datetime.datetime(2023, 11, 16, 1, 31, tzinfo=datetime.timezone.utc),
 datetime.datetime(2023, 11, 16, 10, 41, tzinfo=datetime.timezone.utc),
 datetime.datetime(2023, 12, 16, 1, 51, tzinfo=datetime.timezone.utc),
 datetime.datetime(2023, 12, 16, 11, 7, tzinfo=datetime.timezone.utc),
 datetime.datetime(2024, 1, 5, 2, 8, tzinfo=datetime.timezone.utc),
 datetime.datetime(2024, 1, 5, 11, 6, tzinfo=datetime.timezone.utc),
 datetime.datetime(2024, 2, 6, 1, 45, tzinfo=datetime.timezone.utc),
 datetime.datetime(2024, 2, 6, 10, 33, tzinfo=datetime.timezone.utc)]

### Check date consistency in geo dataset

In [105]:
# Load the exported GFS extraction, discarding the 'system:index' and '.geo' columns.
gfs_data = (
	pd.read_csv("../data/gcs/train/gfs/measurements_with_metadata.csv")
	.drop(columns=['system:index', '.geo'])
)

In [110]:
# Stick to numeric forecast times. Take an explicit copy so the column assignment below
# writes to an independent frame instead of a boolean-filtered slice (which triggers
# pandas' SettingWithCopyWarning).
gfs_data = gfs_data[gfs_data['forecast_hour'].notna()].copy()

# Parse the forecast creation stamp (cast to int, then read as %Y%m%d%H in UTC)
# and convert it to Asia/Kolkata time.
gfs_data['creation_time'] = (
	pd.to_datetime(
		gfs_data['forecast_creation_dt'].astype(int).astype(str),
		format='%Y%m%d%H',
		utc=True,
	)
	.dt.tz_convert(pytz.timezone('Asia/Kolkata'))
)
gfs_data[['creation_time', 'forecast_creation_dt']].head()

Unnamed: 0,creation_time,forecast_creation_dt
0,2023-10-08 23:30:00+05:30,2023101000.0
1,2023-10-08 23:30:00+05:30,2023101000.0
2,2023-10-08 23:30:00+05:30,2023101000.0
3,2023-10-08 23:30:00+05:30,2023101000.0
4,2023-10-08 23:30:00+05:30,2023101000.0


In [None]:
# Look up each row's sample timestamp through its sample_idx.
sample_dt_by_idx = ds.set_index('sample_idx')['sample_dt']
gfs_data['sample_dt'] = gfs_data['sample_idx'].map(sample_dt_by_idx)

# Recompute the forecast time as creation time plus the forecast hour, so it can be
# compared against the exported 'forecast_time' column.
forecast_lead = pd.to_timedelta(gfs_data['forecast_hour'], 'hour')
gfs_data['forecast_time_check'] = gfs_data['creation_time'] + forecast_lead

check_cols = ['sample_idx', 'creation_time', 'sample_dt', 'forecast_time_check', 'forecast_hour', 'forecast_time']
gfs_data[check_cols].head(40)

All times appear correct to me. Next, check the subsequent cleaning steps.

In [113]:
# Every measurement must contribute the same number of forecast rows.
observations_per_measurement = gfs_data.groupby('sample_idx').size()
expected_count = observations_per_measurement.iloc[0]
assert observations_per_measurement.eq(expected_count).all(), (
	"Number of observations per measurement varies."
)

# Identifier columns go first; all remaining columns are the values to spread.
front_cols = ['sample_idx', 'forecast_time', 'forecast_creation_dt', 'forecast_hour']
value_cols = gfs_data.columns[~gfs_data.columns.isin(front_cols)].tolist()

# Reorder columns, then order rows by sample and forecast time.
gfs_data = gfs_data[front_cols + value_cols].sort_values(['sample_idx', 'forecast_time'])

# Pivot wide: one row per measurement, one column per (value, forecast_time) pair.
gfs_data_wide = gfs_data.pivot(index='sample_idx', columns='forecast_time', values=value_cols)

This also looks correct to me.