#### Imports & Setup

In [None]:
import numpy as np
import seaborn as sns
import pandas as pd
import ee
import geemap
import matplotlib.pyplot as plt
import geopandas as gpd
from google.colab import drive
from datetime import datetime, timedelta
import concurrent.futures
import torch
from scipy import stats

In [None]:
drive.mount(mountpoint='/content/drive')

In [None]:
# Authenticate and initialise Google Earth Engine (GEE).
ee.Authenticate()
# NOTE(review): `project` is an empty placeholder here; set it to the Google
# Cloud project registered for Earth Engine before running this cell.
ee.Initialize(project='')

### Retrieve the filtered year data

In [None]:
# Load the merged, scaled station table and drop the merge suffixes from the
# coordinate columns.
df = pd.read_csv("/content/drive/2018_LA_merged_scaled.csv", delimiter=",").rename(
    columns={'latitude_x': 'latitude', 'longitude_x': 'longitude'}
)
print(df.shape)
df.head()

In [None]:
def get_bounding_box(df):
  '''
  Create an axis-aligned bounding box enclosing all weather stations in `df`.

  Args:
    df: DataFrame with `longitude` and `latitude` columns.

  Returns:
    ee.Geometry.Rectangle spanning [min_lon, min_lat, max_lon, max_lat].

  Raises:
    ValueError: if `df` contains no valid (non-NaN) coordinates, e.g. it is
      empty. Without this check an all-NaN table would yield a NaN rectangle.
  '''
  lon_lat = df[['longitude', 'latitude']]
  # pandas min/max skip NaN, so a station with a missing coordinate does not
  # poison the whole box
  mins, maxs = lon_lat.min(), lon_lat.max()
  if mins.isna().any() or maxs.isna().any():
      raise ValueError('get_bounding_box: no valid longitude/latitude values in df')
  min_lon, min_lat = mins
  max_lon, max_lat = maxs
  bounding_box = ee.Geometry.Rectangle([min_lon, min_lat, max_lon, max_lat])
  print('min lon,lat:', min_lon, min_lat)
  print('max lon, lat:', max_lon, max_lat)
  return bounding_box


def get_landsat_tiles(landsat_path, bounding_box, start_date, end_date, overlap):
  '''
  Load a Landsat collection and keep only scenes that cover enough of the box.

  Args:
    landsat_path: Earth Engine ImageCollection id.
    bounding_box: ee.Geometry the scenes must intersect.
    start_date, end_date: date range handed to ImageCollection.filterDate.
    overlap: minimum fraction of the bounding-box area that a scene's
      footprint must cover for the scene to be kept.

  Returns:
    (landsat, landsat_filt): the bounds/date-filtered collection, and its
    subset whose `overlap_fraction` property is >= `overlap`.
  '''
  box_area = bounding_box.area()

  def _tag_overlap(img):
      # share of the bounding box covered by this scene's footprint
      covered = img.geometry().intersection(bounding_box).area()
      return img.set('overlap_fraction', covered.divide(box_area))

  landsat = ee.ImageCollection(landsat_path).filterBounds(bounding_box).filterDate(start_date, end_date)
  landsat_filt = landsat.map(_tag_overlap).filter(ee.Filter.gte('overlap_fraction', overlap))
  print('number of days:', landsat_filt.size().getInfo())
  return landsat, landsat_filt


def get_avail_landsat_dates(landsat_filt):
  '''
  Return the distinct acquisition timestamps in `landsat_filt` as
  'YYYY-MM-dd HH:mm:ss' strings expressed in America/Los_Angeles local time
  (the timezone matters, so it is passed explicitly).
  '''
  def _to_local(millis):
      return ee.Date(millis).format('YYYY-MM-dd HH:mm:ss', 'America/Los_Angeles')

  timestamps = landsat_filt.aggregate_array('system:time_start')
  local_dates = timestamps.map(_to_local).distinct()
  return local_dates.getInfo()

In [None]:
start_date = '2018-01-01'
# ee.ImageCollection.filterDate treats the end date as exclusive, so use the
# first day of the following year to keep scenes acquired on 2018-12-31.
end_date = '2019-01-01'

### Patch Extraction

In [None]:
# Landsat 8 Collection 2 Level-2 scenes over the station box, keeping only
# scenes whose footprint covers at least 70% of the box.
bounding_box = get_bounding_box(df)
landsat, landsat_filt = get_landsat_tiles(
    "LANDSAT/LC08/C02/T1_L2", bounding_box, start_date, end_date, overlap=0.7
)
avail_dates = get_avail_landsat_dates(landsat_filt=landsat_filt)
print(avail_dates)

In [None]:
def extract_bands(landsat):
  '''
  Extract, mask, and scale Landsat bands, and derive spectral indices.

  Args:
    landsat: a single Landsat Collection 2 Level-2 ee.Image (SR_B*, ST_B10,
      QA_PIXEL and QA_RADSAT bands plus scale/offset metadata).

  Returns:
    ee.Image with 12 bands in this order: LST (deg C), Red, Green, Blue,
    SWIR1, SWIR2, NDVI, EVI, SAVI, MNDWI, NBAI, NDBI. The band order is
    relied on downstream, so keep it stable.
  '''
  # Develop masks for unwanted pixels: keep pixels whose QA_PIXEL bits 0-4
  # are all zero (per the USGS C2 QA layout: fill, dilated cloud, cirrus,
  # cloud, cloud shadow), and pixels with no radiometric saturation.
  qa_mask = landsat.select('QA_PIXEL').bitwiseAnd(0b11111).eq(0)
  saturation_mask = landsat.select('QA_RADSAT').eq(0)

  # Apply the scaling factors to the appropriate bands.
  # Builds a constant multi-band image from the metadata values whose keys
  # match `factor_names` (regex).
  def _get_factor_img(factor_names):
      factor_list = landsat.toDictionary().select(factor_names).values()
      return ee.Image.constant(factor_list)

  # NOTE: assumes the selected factor keys come back in the same order as the
  # bands selected by 'SR_B.|ST_B10' below (mirrors the Earth Engine Landsat
  # C2 scaling example); a mismatch would silently mis-scale bands.
  scale_img = _get_factor_img([
      'REFLECTANCE_MULT_BAND_.|TEMPERATURE_MULT_BAND_ST_B10'])
  offset_img = _get_factor_img([
      'REFLECTANCE_ADD_BAND_.|TEMPERATURE_ADD_BAND_ST_B10'])
  scaled = landsat.select('SR_B.|ST_B10').multiply(scale_img).add(offset_img)

  # overwrite the raw bands with the scaled ones, then apply both masks
  landsat = landsat.addBands(scaled, None, True).updateMask(
      qa_mask).updateMask(saturation_mask)

  red = landsat.select('SR_B4').rename('Red')
  green = landsat.select('SR_B3').rename('Green')
  blue = landsat.select('SR_B2').rename('Blue')
  swir1 = landsat.select('SR_B6').rename('SWIR1')
  swir2 = landsat.select('SR_B7').rename('SWIR2')
  # scaled ST_B10 is in Kelvin; subtract 273.15 to get degrees Celsius
  lst = landsat.select('ST_B10').rename('LST').subtract(273.15)

  # compute indices (normalizedDifference(a, b) = (a - b) / (a + b))
  ndvi = landsat.normalizedDifference(['SR_B5', 'SR_B4']).rename('NDVI')
  ndbi = landsat.normalizedDifference(['SR_B6', 'SR_B5']).rename('NDBI')
  # NOTE(review): this "NBAI" is the normalized difference of SR_B7 and SR_B5;
  # the NBAI usually cited uses SWIR2 and SWIR1/Green — confirm intended.
  nbai = landsat.normalizedDifference(['SR_B7', 'SR_B5']).rename('NBAI')
  mndwi = landsat.normalizedDifference(['SR_B3', 'SR_B6']).rename('MNDWI')

  band_dict = {
      'NIR': landsat.select('SR_B5'),
      'RED': landsat.select('SR_B4'),
      'BLUE': landsat.select('SR_B2')
  }

  evi = landsat.expression('2.5 * ((NIR - RED) / (NIR + 6 * RED - 7.5 * BLUE + 1))', band_dict).rename('EVI')
  # SAVI with soil-brightness factor L = 0.5
  savi = landsat.expression('((NIR - RED) / (NIR + RED + 0.5)) * 1.5', band_dict).rename('SAVI')

  return lst.addBands([red, green, blue, swir1, swir2, ndvi, evi, savi, mndwi, nbai, ndbi])

In [None]:
def _center_fit(band_patch, patch_size):
  '''
  Center-crop and/or zero-pad a 2-D array to exactly (patch_size, patch_size).

  An axis longer than patch_size keeps its central window. An axis shorter
  than patch_size is zero-padded on both sides; for an odd shortfall the
  extra row/column goes to the bottom/right, so the result is always full size.
  '''
  band_patch = np.asarray(band_patch)
  height, width = band_patch.shape
  top = max(0, (height - patch_size) // 2)
  left = max(0, (width - patch_size) // 2)
  band_patch = band_patch[top:top + patch_size, left:left + patch_size]

  pad_h = patch_size - band_patch.shape[0]
  pad_w = patch_size - band_patch.shape[1]
  if pad_h > 0 or pad_w > 0:
      band_patch = np.pad(
          band_patch,
          ((pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)),
          mode='constant', constant_values=0)
  return band_patch


def extract_patch(fcd_feature, bands_date, patch_size):
  '''
  Sample a square patch of every band of `bands_date` centred on a station,
  forcing all patches to the same size.

  Args:
    fcd_feature: feature dict as returned by FeatureCollection.getInfo();
      fcd_feature['geometry']['coordinates'] is the [lon, lat] point.
    bands_date: multi-band ee.Image to sample.
    patch_size: side length, in pixels, of the returned patch.

  Returns:
    np.ndarray of shape (patch_size, patch_size, n_bands), with one channel
    per band property returned by sampleRectangle, divided by 10000.
  '''
  # Buffer radius = patch_size/2 pixels at 31 m (1 m more than the 30 m
  # pixel), so the sampled rectangle comes out slightly too big and is
  # cropped rather than padded.
  center = ee.Geometry.Point(fcd_feature['geometry']['coordinates'])
  patch_square = center.buffer(patch_size / 2 * 31).bounds()
  # masked pixels are filled with 0
  patch = bands_date.sampleRectangle(region=patch_square, defaultValue=0, defaultArrayValue=0)
  patch_data = patch.getInfo()

  # sampleRectangle is sometimes off by a few pixels; fit every band to
  # patch_size x patch_size so patches from different stations can be stacked
  patch_image = np.stack(
      [_center_fit(values, patch_size) for values in patch_data['properties'].values()],
      axis=-1)

  # NOTE(review): in this notebook the bands were already converted to
  # physical units in extract_bands (reflectance, deg C, indices); this uniform
  # /10000 is kept for compatibility with previously saved tensors — confirm.
  return patch_image / 10000

In [None]:
def _encode_day_of_year(date_str):
  '''
  Encode a 'YYYY-MM-DD' date as [sin, cos] of its day of year on a 365-day
  cycle, so that late December and early January map to nearby points.
  '''
  dayofyear = datetime.strptime(date_str, '%Y-%m-%d').timetuple().tm_yday
  return [np.sin(2 * np.pi * dayofyear / 365), np.cos(2 * np.pi * dayofyear / 365)]


def sample_bands_via_target_location_as_patches(landsat_filt, df, patch_size, region=None):
  '''
  For each scene in `landsat_filt`, build the per-date band stack and sample
  one patch per station row of `df` whose `date` equals the scene date.

  The band stack is the Landsat bands from extract_bands plus elevation,
  slope and distance-to-water.

  Args:
    landsat_filt: ee.ImageCollection of scenes to iterate over.
    df: station table with `date` ('YYYY-MM-DD' strings), `longitude`,
      `latitude`, `temperature` and `sid` columns.
    patch_size: side length, in pixels, of every patch.
    region: ee.Geometry used to clip the elevation/slope layers. Defaults to
      the module-level `bounding_box` (the previous implicit behaviour).

  Returns:
    (X_image, X_date, Y_temp) float32 tensors, aligned row-for-row:
      - X_image stacks the patches from extract_patch;
      - X_date holds the [sin, cos] day-of-year encodings;
      - Y_temp holds the station temperatures.
    The second value used to be the last date string rather than X_date.
  '''
  if region is None:
      region = bounding_box

  arr_temp = []
  arr_date = []
  arr_image = []

  # Date-independent layers: built once rather than once per scene.
  srtm = ee.Image("USGS/SRTMGL1_003").select("elevation").rename("Elevation")
  elevation = srtm.clip(region)
  slope = ee.Terrain.slope(srtm).rename("Slope").clip(region)
  # distance to water using ESA WorldCover (water class = 80)
  water = ee.Image("ESA/WorldCover/v200/2021").select("Map").eq(80).selfMask()
  water_dist_full = water.fastDistanceTransform().rename("Water_dist")

  for i in range(landsat_filt.size().getInfo()):
      image = ee.Image(landsat_filt.toList(1, i).get(0))
      datet = ee.Date(image.get('system:time_start')).format('YYYY-MM-dd').getInfo()
      print(datet)

      # compares as strings; if df.date holds date objects instead, compare
      # against datetime.strptime(datet, '%Y-%m-%d').date()
      rows = df[df.date == datet]
      if rows.empty:
          continue  # no stations on this date, nothing to sample

      fc_date = ee.FeatureCollection([
          ee.Feature(ee.Geometry.Point([row['longitude'], row['latitude']]), {
              "temperature": row["temperature"],
              "sid": row["sid"],
              "date": str(row["date"]),
          })
          for _, row in rows.iterrows()
      ])

      bands_date = extract_bands(image)
      # resample the water distance onto this scene's grid at 30 m
      water_dist = water_dist_full.reduceResolution(
          reducer=ee.Reducer.mean(), bestEffort=True
      ).reproject(crs=bands_date.projection(), scale=30)
      bands_date = bands_date.addBands([elevation, slope, water_dist])

      station_features = fc_date.getInfo()['features']

      # each extract_patch does a blocking getInfo() call, so run them in threads
      with concurrent.futures.ThreadPoolExecutor() as executor:
          futures = [
              executor.submit(extract_patch, feature, bands_date, patch_size)
              for feature in station_features
          ]
          for feature in station_features:
              props = feature['properties']
              arr_temp.append(props['temperature'])
              arr_date.append(_encode_day_of_year(props['date']))

          # Collect in submission order (NOT as_completed order) so every image
          # stays aligned with its temperature and date entries.
          for future in futures:
              arr_image.append(future.result())
              print('.', end='', flush=True)

  X_image = torch.tensor(np.array(arr_image), dtype=torch.float32)
  X_date = torch.tensor(np.array(arr_date), dtype=torch.float32)
  Y_temp = torch.tensor(np.array(arr_temp), dtype=torch.float32)

  return X_image, X_date, Y_temp

In [None]:
patch_size = 64  # side length of every square patch, in pixels
X_image, X_date, Y_temp = sample_bands_via_target_location_as_patches(
    landsat_filt, df, patch_size=patch_size
)

#### Size & visualisation checks

In [None]:
# patch tensor shape, then target shape, one per line
print(X_image.shape, Y_temp.shape, sep='\n')

In [None]:
# Plot every band of a single sample patch as a grid of images (no titles).
sample_idx = min(28, len(X_image) - 1)  # 28 is an arbitrary example; clamp for small datasets
patch_bands = X_image[sample_idx]
n_bands = patch_bands.shape[-1]
fig, axs = plt.subplots(3, 5, figsize=(15, 8))
axs = axs.flatten()

for i, ax in enumerate(axs):
  if i < n_bands:
    ax.imshow(patch_bands[:, :, i])
  else:
    ax.axis('off')  # fewer bands than grid cells: leave the cell blank

plt.show()

In [None]:
torch.save(X_image, '/content/drive/64/X_image_2018_64.pt')
# X_date is deliberately not saved: the date features are rebuilt in the
# downstream notebook that consumes these tensors. Note that the commented-out
# path below also carries a stale 2017 year.
# torch.save(X_date, '/content/drive/64/X_date_2017_64.pt')
torch.save(Y_temp, '/content/drive/64/Y_temp_2018_64.pt')