# In [None]:
## MSc_Thesis
## Jacotte Monroe 
## 19/03/24

## Script containing all Python functions necessary in 3_RetrieveMODISImages.ipynb 


# function to extract and combine relevant QA band bits/flags
# source: https://mygeoblog.com/2017/09/08/modis-cloud-masking/
def getQABits(image, start_bit, end_bit, band_name):
    """Extract a contiguous run of bits from a packed QA band.

    Bits ``start_bit`` up to but NOT including ``end_bit`` are kept
    (Python ``range`` semantics), e.g. ``getQABits(qa, 0, 3, ...)`` keeps
    bits 0, 1 and 2. The kept bits are shifted down so that ``start_bit``
    becomes bit 0 of the result.

    Note: the JavaScript original cited above treats the end bit as
    inclusive; this port uses an exclusive end, so callers must pass
    (last bit + 1).

    Parameters
    ----------
    image : ee.Image
        Image whose first band holds the packed QA flags.
    start_bit : int
        Index of the lowest bit to extract (inclusive).
    end_bit : int
        Index one past the highest bit to extract (exclusive). If it is
        not greater than ``start_bit`` the mask is empty and every pixel
        of the result is 0.
    band_name : str
        Name given to the single output band.

    Returns
    -------
    ee.Image
        Single-band image holding the extracted bit field.
    """
    # Integer mask with ones at bit positions start_bit .. end_bit - 1.
    # Built with integer shifts instead of math.pow, which returns a float
    # (e.g. 3.0) rather than an integer bitwise operand.
    pattern = 0
    for bit in range(start_bit, end_bit):
        pattern |= 1 << bit

    # keep only the requested bits, then shift them down to start at bit 0
    # source: https://developers.google.com/earth-engine/apidocs/ee-image-bitwiseand
    return (
        image.select([0], [band_name])
        .bitwiseAnd(pattern)
        .rightShift(start_bit)
    )



# function to mask out cloudy pixels
# source: https://gis.stackexchange.com/questions/308456/get-a-mask-for-modis-250m-mod09gq-using-modis-500m-mod09ga-in-google-earth
# source: https://mygeoblog.com/2017/09/08/modis-cloud-masking/
def maskClouds(image):
    """Mask pixels flagged as cloudy or as cloud shadow.

    Reads the ``state_1km`` QA band and keeps only pixels whose
    'cloud state' bits (0-1) and 'cloud shadow' bit (2) are all zero.
    Pixels with cloud state 3 ('not set, assumed clear' in the MODIS
    state_1km layout) are therefore masked as well.

    Parameters
    ----------
    image : ee.Image
        MODIS image containing a ``state_1km`` band (presumably MOD09GA;
        confirm against the collection used in the notebook).

    Returns
    -------
    ee.Image
        ``image`` with cloudy and shadowed pixels masked out.
    """
    # select the MODIS 1 km state QA band
    QA = image.select('state_1km')

    # bits 0-1 = cloud state, bit 2 = cloud shadow. getQABits uses an
    # exclusive end bit (Python range), so the end index must be 3: with 2,
    # only the cloud-state bits were extracted and shadows were never masked.
    pixelQuality = getQABits(QA, 0, 3, 'cloud_and_shadow_quality_flag')

    # keep only pixels where none of the cloud/shadow bits are set
    return image.updateMask(pixelQuality.eq(0))



# function to add timestamp band to image
# source: https://spatialthoughts.com/2021/11/08/temporal-interpolation-gee/
def addTimestamp(image):
    """Append a 'timestamp' band holding the image's acquisition time.

    Every pixel of the new band holds the value of the image's
    ``system:time_start`` property. The band is masked with the mask of the
    image's first band, so pixels already masked (e.g. by cloud masking)
    carry no timestamp.

    Returns the input image with the extra 'timestamp' band appended.
    """
    # mask of the first band: pixels hidden there must also be hidden in
    # the timestamp band
    first_band_mask = image.mask().select(0)

    acquisition_time = image.metadata('system:time_start').rename('timestamp')
    return image.addBands(acquisition_time.updateMask(first_band_mask))



# function that takes image and replaces masked pixels with linearly interpolated values from bef/aft images
def interpolateImage(image):
    """Fill masked pixels of *image* by linear interpolation in time.

    *image* must carry two list properties, ``before`` and ``after``,
    holding images acquired before and after it (presumably attached by an
    ee.Join earlier in the notebook; confirm). Each of those images needs a
    ``timestamp`` band such as the one added by addTimestamp.

    Each list is mosaicked into a single image, and a per-pixel time ratio
    ``(t0 - t_before) / (t_after - t_before)`` is computed. Masked pixels of
    *image* are then replaced by
    ``before + (after - before) * ratio``.

    Returns the gap-filled image with ``system:time_start`` copied from
    *image*.
    """
    image = ee.Image(image)

    def _mosaic_of(property_name):
        # mosaic() overlays images in collection order: the LAST image in
        # the list ends up on top, and earlier images only fill its gaps.
        # NOTE(review): if the list is empty the mosaic has no 'timestamp'
        # band and the select() below will fail; confirm the join always
        # supplies at least one image.
        neighbours = ee.List(image.get(property_name))
        return ee.ImageCollection.fromImages(neighbours).mosaic()

    before_mosaic = _mosaic_of('before')
    after_mosaic = _mosaic_of('after')

    # stack the three time sources as named bands for the expression below
    time_stack = ee.Image.cat([
        before_mosaic.select('timestamp').rename('time_bef'),
        after_mosaic.select('timestamp').rename('time_aft'),
        image.metadata('system:time_start').rename('time0'),
    ])

    # fraction of the before->after interval at which this image was taken
    # (0 = at the before image, 1 = at the after image)
    ratio = time_stack.expression(
        '(time0 - time_bef) / (time_aft - time_bef)',
        {
            'time0': time_stack.select('time0'),
            'time_bef': time_stack.select('time_bef'),
            'time_aft': time_stack.select('time_aft'),
        },
    )

    # linear interpolation, band by band
    estimate = before_mosaic.add((after_mosaic.subtract(before_mosaic).multiply(ratio)))

    # only pixels masked in the original image take the interpolated value
    filled = image.unmask(estimate)
    return filled.copyProperties(image, ['system:time_start'])



# function to reproject images to the CRS used for the elephant fixes
# (the fixes were reprojected from EPSG:4326 to EPSG:32733, so the images must match)
def reprojectModis(image):
    """Reproject *image* to EPSG:32733 at a 250 m scale.

    No explicit CRS transform is passed; the pixel grid is defined by the
    scale alone.
    """
    target_crs = 'EPSG:32733'
    target_scale_m = 250
    return image.reproject(target_crs, None, target_scale_m)



# function to clip image to study area
def clipToAOI(image):
    """Clip *image* to the study-area geometry ``bbox``.

    ``bbox`` is not defined in this file; it is expected to exist at module
    level (presumably set in the notebook before this is called).
    The ``system:id`` property of the input image is copied onto the result.
    """
    # for testing, a larger area can be used instead: image.clip(large_region)
    return image.clip(bbox).copyProperties(image, ['system:id'])



# function to calculate NDVI
def addNDVI(image):
    """Return *image* with an extra 'NDVI' band appended.

    NDVI is computed with normalizedDifference on the bands named 'nir' and
    'red', i.e. (nir - red) / (nir + red). The image must already contain
    bands with exactly those names (presumably renamed earlier in the
    notebook; confirm).
    """
    band_pair = ['nir', 'red']
    return image.addBands(image.normalizedDifference(band_pair).rename('NDVI'))