In [None]:
# Import and initialize the Earth Engine client. No authentication happens in
# this cell; run ee.Authenticate() once beforehand if credentials are not cached.
# NOTE(review): newer earthengine-api releases may require a Cloud project,
# i.e. ee.Initialize(project=...) — confirm for the installed version.
import ee
ee.Initialize()

In [None]:
# for S1_ARD correction
# make sure all the python-api files from the GitHub repo below are on the Python path
# https://github.com/adugnag/gee_s1_ard
import wrapper as wp
import border_noise_correction as bnc
import speckle_filter as sf
import terrain_flattening as trf
import helper

In [None]:
# Earth Engine asset folder that export_image() writes the per-state images into.
ee_path = "users/balakumaran247/swSentinel/state"

In [None]:
# State abbreviation -> state name spelled exactly as in the "STATE" property of
# the SOI states FeatureCollection (filter_geometry matches it with ee.Filter.eq).
# The '>' characters are part of that dataset's spelling (NOTE(review): they
# appear to stand in for a long 'A'), so keep them verbatim or nothing matches.
states = dict(
    KA="KARN>TAKA",
    TN="TAMIL N>DU",
    AP="ANDHRA PRADESH",
)

In [None]:
def filter_geometry(state_name: str) -> ee.Geometry:
  """Return the geometry of one state taken from the SOI states asset.

  Args:
    state_name: Value to match against the collection's "STATE" property
      (must be spelled exactly as in the asset).

  Returns:
    The union geometry of every feature whose "STATE" equals ``state_name``.
  """
  soi_states = ee.FeatureCollection('users/balakumaran247/swSentinel/SOI_States')
  matching = soi_states.filter(ee.Filter.eq("STATE", state_name))
  return matching.geometry()

In [None]:
def fetch_dw(geometry: ee.Geometry, start_date=None, end_date=None) -> ee.ImageCollection:
  """Return Dynamic World label images that intersect ``geometry``.

  Args:
    geometry: Area of interest; only images intersecting it are kept.
    start_date: Optional inclusive start of a date window (anything filterDate
      accepts, e.g. 'YYYY-MM-DD'). Must be given together with ``end_date``.
    end_date: Optional exclusive end of the date window.

  Returns:
    The filtered DW_LABELS ImageCollection. With no dates, the whole archive
    over ``geometry`` is returned (original behaviour), which can be very large.

  Raises:
    ValueError: if only one of ``start_date`` / ``end_date`` is given.
  """
  collection = ee.ImageCollection(
      'projects/wri-datalab/dynamic_world/v1/DW_LABELS'
  ).filterBounds(geometry)
  # Require both bounds: filterDate(start) without an end silently selects only
  # a 1 ms range starting at ``start``.
  if (start_date is None) != (end_date is None):
    raise ValueError('start_date and end_date must be given together')
  if start_date is not None:
    collection = collection.filterDate(start_date, end_date)
  return collection

In [None]:
# Function to fetch the Sentinel-2 Surface Reflectance image matching a DW image.
def fetch_matching_s2_image(image):
    """Return the S2_SR image whose 'PRODUCT_ID' equals the image's 'S2_PRODUCT_ID'.

    Args:
        image: ee.Image expected to carry an 'S2_PRODUCT_ID' property (here, a
            Dynamic World label image).

    Returns:
        The first matching COPERNICUS/S2_SR image, or null if nothing matches.
        Callers mapping this over a collection should drop nulls.
    """
    product_id = image.get('S2_PRODUCT_ID')
    # ee.Filter.eq replaces the deprecated filterMetadata(..., 'equals', ...).
    # NOTE(review): COPERNICUS/S2_SR is deprecated in the EE catalog in favour of
    # COPERNICUS/S2_SR_HARMONIZED; switching would change pixel values, so it is kept.
    return (
        ee.ImageCollection('COPERNICUS/S2_SR')
        .filter(ee.Filter.eq('PRODUCT_ID', product_id))
        .first()
    )

In [None]:
def fetch_s2(icoll: ee.ImageCollection) -> ee.ImageCollection:
  """Replace every image of ``icoll`` with its matching Sentinel-2 SR image.

  fetch_matching_s2_image returns null for images without an S2_SR counterpart.
  Passing dropNulls=True (positionally, for compatibility across earthengine-api
  versions) drops those elements instead of failing the whole map.
  """
  return icoll.map(fetch_matching_s2_image, True)

In [None]:
# Function to calculate indices for Sentinel-2 Surface Reflectance Images
def calculate_indices(image):
    """Append spectral index bands to a Sentinel-2 Surface Reflectance image.

    Adds NDVI, MNDWI, EVI, NDTI, NDBI and DI (in that order). The original
    bands are left unchanged.

    Args:
        image: ee.Image with S2 SR bands B2, B3, B4, B8, B11 and B12.

    Returns:
        ``image`` with the six index bands added.
    """
    # S2 SR reflectance is stored as integers scaled by 10000. The EVI formula's
    # "+ 1" term assumes reflectance in [0, 1], and dividing integer images
    # yields integer results (which would truncate DI). So both expressions
    # are fed float reflectance.
    scale = 0.0001

    def reflectance(band):
        return image.select(band).multiply(scale)

    # Normalized differences are scale-invariant and already return floats.
    ndvi = image.normalizedDifference(['B8', 'B4']).rename('NDVI')
    mndwi = image.normalizedDifference(['B3', 'B11']).rename('MNDWI')
    ndti = image.normalizedDifference(['B3', 'B12']).rename('NDTI')
    ndbi = image.normalizedDifference(['B11', 'B8']).rename('NDBI')

    evi = image.expression(
        '2.5 * ((NIR - RED) / (NIR + 6 * RED - 7.5 * BLUE + 1))',
        {
            'NIR': reflectance('B8'),
            'RED': reflectance('B4'),
            'BLUE': reflectance('B2'),
        },
    ).rename('EVI')

    di = image.expression(
        '((SWIR1 + NIR) - (GREEN + SWIR2)) / ((SWIR1 + NIR) + (GREEN + SWIR2))',
        {
            'SWIR1': reflectance('B11'),
            'NIR': reflectance('B8'),
            'GREEN': reflectance('B3'),
            'SWIR2': reflectance('B12'),
        },
    ).rename('DI')

    return image.addBands([ndvi, mndwi, evi, ndti, ndbi, di])

In [None]:
def find_closest_s1(image):
    """Return the Sentinel-1 GRD scenes acquired around ``image``.

    Despite the name, this does not pick a single closest scene. It returns
    every COPERNICUS/S1_GRD_FLOAT image that meets all of these conditions:
    - acquired in [t - 15 days, t + 15 days), where t is ``image``'s
      'system:time_start';
    - intersects ``image``'s footprint;
    - IW mode, 10 m resolution, with a VH channel.
    Only the 'VV', 'VH' and 'angle' bands are kept.
    """
    window_days = 15
    acquired = ee.Date(image.get('system:time_start'))
    acquisition_filter = ee.Filter.And(
        ee.Filter.eq('instrumentMode', 'IW'),
        ee.Filter.eq('resolution_meters', 10),
        ee.Filter.listContains('transmitterReceiverPolarisation', 'VH'),
    )
    candidates = (
        ee.ImageCollection('COPERNICUS/S1_GRD_FLOAT')
        .filterDate(acquired.advance(-window_days, 'day'),
                    acquired.advance(window_days, 'day'))
        .filterBounds(image.geometry())
        .filter(acquisition_filter)
    )
    return candidates.select(['VV', 'VH', 'angle'])

In [None]:
def fetch_s1(icoll: ee.ImageCollection) -> ee.ImageCollection:
  """Collect the Sentinel-1 scenes found around every image of ``icoll``.

  find_closest_s1 returns all scenes within a +/-15-day window of each input
  image. The windows of images acquired close together overlap, so the
  flattened result repeats the same S1 scene. Duplicates would be
  double-counted by the later mean (and presumably by the multi-temporal
  speckle filter). They are removed by acquisition time ('system:time_start').

  Returns:
    An ImageCollection of unique S1 scenes with 'VV', 'VH' and 'angle' bands.
  """
  scenes = icoll.map(find_closest_s1).flatten()
  return ee.ImageCollection(scenes.distinct('system:time_start'))

In [None]:
def s1_correction(
    icoll: ee.ImageCollection,
    *,
    speckle_filter: str = 'GAMMA MAP',
    speckle_filter_kernel_size: int = 9,
    speckle_filter_nr_of_images: int = 10,
    terrain_flattening_model: str = 'VOLUME',
    dem=None,
    layover_shadow_buffer: int = 0,
) -> ee.ImageCollection:
  """Apply the gee_s1_ard correction chain to a Sentinel-1 collection.

  The steps, in order:
  1. bnc.f_mask_edges on every image (border-noise masking).
  2. sf.MultiTemporal_Filter (speckle filtering).
  3. trf.slope_correction (terrain flattening).
  4. helper.lin_to_db on every image.
  The keyword arguments default to the values previously hard-coded here.

  Args:
    icoll: Sentinel-1 images (here, from COPERNICUS/S1_GRD_FLOAT).
    speckle_filter: Filter name passed to sf.MultiTemporal_Filter.
    speckle_filter_kernel_size: Kernel size passed to sf.MultiTemporal_Filter.
    speckle_filter_nr_of_images: Number of images for the multi-temporal filter.
    terrain_flattening_model: Model name passed to trf.slope_correction.
    dem: DEM ee.Image for terrain flattening; defaults to USGS/SRTMGL1_003.
    layover_shadow_buffer: Additional layover/shadow buffer for slope_correction.

  Returns:
    The corrected ImageCollection.
  """
  if dem is None:
    # Built lazily so that importing/defining this cell needs no EE session.
    dem = ee.Image('USGS/SRTMGL1_003')
  corrected = icoll.map(bnc.f_mask_edges)
  corrected = ee.ImageCollection(sf.MultiTemporal_Filter(
      corrected, speckle_filter_kernel_size, speckle_filter,
      speckle_filter_nr_of_images))
  corrected = ee.ImageCollection(trf.slope_correction(
      corrected, terrain_flattening_model, dem, layover_shadow_buffer))
  return corrected.map(helper.lin_to_db)

In [None]:
def calculate_combinations(image):
    """Append VV/VH band combinations to a Sentinel-1 image.

    Adds, in this order:
    - VH_VV_Ratio (VH / VV)
    - VV_VH_Ratio (VV / VH)
    - VV_Plus_VH
    - VV_Times_VH
    - VV_Times_VV
    - VH_Times_VH
    Each band's content matches its name; the two ratio bands used to be swapped.

    Args:
        image: ee.Image with 'VV' and 'VH' bands. In this notebook these come
            out of s1_correction (after helper.lin_to_db, presumably dB).

    Returns:
        ``image`` with the six combination bands added.
    """
    vv = image.select('VV')
    vh = image.select('VH')
    combinations = [
        vh.divide(vv).rename('VH_VV_Ratio'),
        vv.divide(vh).rename('VV_VH_Ratio'),
        vv.add(vh).rename('VV_Plus_VH'),
        vv.multiply(vh).rename('VV_Times_VH'),
        vv.multiply(vv).rename('VV_Times_VV'),
        vh.multiply(vh).rename('VH_Times_VH'),
    ]
    return image.addBands(combinations)

In [None]:
def preprocess_mosaic_s1(s1_coll: ee.ImageCollection) -> ee.Image:
  """Add VV/VH combination bands to every S1 image, then average the stack.

  Args:
    s1_coll: Sentinel-1 images with 'VV' and 'VH' bands.

  Returns:
    A single ee.Image holding the per-pixel mean of every band across
    ``s1_coll`` after calculate_combinations has been applied.
  """
  enriched = s1_coll.map(calculate_combinations)
  return enriched.mean()

In [None]:
def export_image(image, geometry, state_abb, satellite_name,
                 crs='EPSG:4326', scale=10, max_pixels=1e13):
  """Start an Earth Engine export of ``image`` to an asset under ``ee_path``.

  Both the task description and the asset name are
  '<state_abb>_<satellite_name>_image'.

  Args:
    image: ee.Image to export.
    geometry: Export region.
    state_abb: State abbreviation used in the asset name (e.g. 'KA').
    satellite_name: Sensor tag used in the asset name (e.g. 's1', 's2').
    crs: Output projection (default 'EPSG:4326', as before).
    scale: Output resolution in metres per pixel (default 10, as before).
    max_pixels: Export pixel cap (default 1e13, as before).

  Returns:
    The started ee.batch.Task. It was previously discarded; returning it lets
    callers poll ``task.status()``.
  """
  name = f'{state_abb}_{satellite_name}_image'
  task = ee.batch.Export.image.toAsset(
    image=image,
    description=name,
    assetId=f'{ee_path}/{name}',
    crs=crs,
    scale=scale,
    region=geometry,
    maxPixels=max_pixels,
  )
  task.start()
  return task

In [None]:
# Bands kept after feature selection. Set EXPORT_ALL_BANDS = True to export
# every computed band instead (the pre-feature-selection exports).
EXPORT_ALL_BANDS = False
s2_bands = ['B8', 'MNDWI', 'NDTI', 'DI', 'NDVI']
s1_bands = ['VV_Plus_VH', 'VV_Times_VH', 'VH_Times_VH', 'VV_Times_VV']

for state_abb, state_name in states.items():
  geometry = filter_geometry(state_name)
  dw_filtered = fetch_dw(geometry)

  # Optical: S2 SR scenes matched to the DW images, plus spectral indices.
  # qualityMosaic keeps, per pixel, the observation with the highest MNDWI.
  s2_collection = fetch_s2(dw_filtered)
  s2_final = s2_collection.map(calculate_indices).qualityMosaic('MNDWI')

  # SAR: S1 scenes around the DW dates, corrected, combined and averaged.
  s1_collection = fetch_s1(dw_filtered)
  s1_corrected = s1_correction(s1_collection)
  s1_final = preprocess_mosaic_s1(s1_corrected)

  if EXPORT_ALL_BANDS:
    export_image(s2_final, geometry, state_abb, 's2')
    export_image(s1_final, geometry, state_abb, 's1')
  else:
    export_image(s2_final.select(s2_bands), geometry, state_abb, 's2')
    export_image(s1_final.select(s1_bands), geometry, state_abb, 's1')