In [43]:
%config IPCompleter.greedy=True

In [44]:
from pprint import pprint
import folium
import geopy.distance
import math
import numpy as np
import os
import time
import datetime
from IPython import embed
import rasterio as rio
import shutil
import fileinput
from rasterio.windows import Window
from matplotlib import pyplot
import subprocess
import logging

In [None]:
import ee
ee.Authenticate()

In [4]:
ee.Initialize()

In [5]:
# logging setup
# Only WARNING and above are recorded; the log file name is relative, so it is written to the working directory.
# The selection cell uses logging.warning to record tiles skipped for lack of reference images.
logging.basicConfig(filename='PSRFM_runner.log', level=logging.WARNING)

In [6]:
# constants for each satellite
# potentially think about adding date ranges fetched from the google earth engine info pages to help debugging
#
# Keys used in the per-satellite dictionaries below:
#   imagecollection_id : id passed to ee.ImageCollection to open the collection
#   pixel_size         : pixel edge length in metres (divided by 1000 to obtain km when sizing tiles)
#   bands              : bands selected when an image is exported
#   mask_band          : QA band name for the sensor
#                        NOTE(review): not read by any visible cell; getQABitsls hard-codes 'pixel_qa' / 'QA60'
#   mask_bits_clear    : QA bit position(s) handed to getQABitsls -- a single int for Landsat, a tuple of two bits for Sentinel-2
#   time_field         : image property used to sort a collection (only read for MODIS in the visible cells)
#   ee_id_date_format  : date pattern of the image id in Earth Engine notation
#                        NOTE(review): not read by any visible cell; the selection cell hard-codes 'YYYYMMdd'
#   id_date_format     : strptime pattern used to parse the date sliced out of an image id
#   id_date_indexes    : (start, end) slice bounds of the date inside the image id; an end of 0 means "slice to the end of the id"

# name of the band that getQABitsls adds to every fine resolution image
synthesized_mask_name = 'cloud_mask_pixelqa'
ls5 = {
    'imagecollection_id' : 'LANDSAT/LT05/C01/T1_SR',
    'pixel_size' : 30,
    'bands' : ['B1', 'B2', 'B3', 'B4', 'B5', 'B7'],
    'mask_band' : ['pixel_qa'],
    'mask_bits_clear' : 1,
    'time_field' : 'system:time_start',
    'ee_id_date_format' : 'YYYYMMdd',
    'id_date_format' : '%Y%m%d',
    'id_date_indexes': (-8, 0)
}
ls8 = {
    'imagecollection_id' : 'LANDSAT/LC08/C01/T1_SR',
    'pixel_size' : 30,
    'bands' : ['B2', 'B3', 'B4', 'B5', 'B6', 'B7'],
    'mask_band' : ['pixel_qa'], 
    'mask_bits_clear' : 1,
    'time_field' : 'system:time_start',
    'ee_id_date_format' : 'YYYYMMdd',
    'id_date_format' : '%Y%m%d',
    'id_date_indexes': (-8, 0)
}
s2 = {
    'imagecollection_id' : 'COPERNICUS/S2_SR',
    'pixel_size' : 20,
    'bands' : ['B2', 'B3', 'B4', 'B8A', 'B11', 'B12'], #check up msk and qa60 msk is 20m qa60 is 60m
    'mask_band' : ['QA60'],
    'mask_bits_clear' : (10, 11), #https://gis.stackexchange.com/questions/333883/removing-clouds-from-sentinel-2-surface-reflectance-in-google-earth-engine
    'time_field' : 'system:time_start',
    'ee_id_date_format' : 'YYYYMMdd',
    'id_date_format' : '%Y%m%d',
    'id_date_indexes': (16, 24)
}
# MODIS is the coarse resolution source; it carries no mask keys because no cloud mask is synthesized for it in this notebook
modis = {
    'imagecollection_id' : 'MODIS/006/MCD43A4',
    'pixel_size' : 500,
    'bands' : ['Nadir_Reflectance_Band3', 'Nadir_Reflectance_Band4', 'Nadir_Reflectance_Band1', 'Nadir_Reflectance_Band2', 'Nadir_Reflectance_Band6', 'Nadir_Reflectance_Band7'],
    'time_field' : 'system:time_start',
    'id_date_format' : '%Y_%m_%d',
    'id_date_indexes': (-10, 0)
}
# lookup table keyed by the short satellite name used for satellite_choice and export_tile_array
image_sets = {
    'ls5' : ls5,
    'ls8' : ls8,
    's2' : s2,
    'modis' : modis
}

In [36]:
# User inputted data
# Bounding box of the study area in decimal degrees. Alternative regions are kept below, commented out.

# Seattle area (frequently cloudy test case)
# west_long = -123.40369004277056 - 0.25
# east_long = -123.31133622196978
# north_lat = 47.254959278137015
# south_lat = 47.19573761464784 - 0.25

# west_long = 129.22395490745654
# east_long = 129.70735334495654 - 0.3
# north_lat = -20.984313762127464 
# south_lat = -21.39916978624304 + 0.25

# area south of kelowna
west_long = -118.45721879827498
east_long = -118.29671541082381
north_lat = 48.81204219953305
south_lat = 48.70980208174063

# Australia (mostly clear-sky test case)
# west_long = 133.84410607205484
# east_long = west_long+(134.7422383962736 - west_long)/8 #temporary just to make a smaller area
# north_lat = -22.638743081339985
# south_lat = north_lat - (north_lat - (-23.52560283638598))/10 #temporary just to make a smaller area

# dates
# date range must be within same calendar year
date_range = ('2019-07-01', '2019-12-12') 
# date_range = ('2017-08-25','2017-12-31')
# NOTE(review): '2014-09-01' lies outside date_range, so no reference images can be found for it;
# with skip_if_reference_not_found = True it is logged and skipped, otherwise the selection cell raises.
prediction_dates = ['2019-08-30', '2019-11-23', '2014-09-01']
# If reference images aren't found for a prediction date, skip that date for the tile and log a warning.
# If this is set to False, the selection cell raises an exception as soon as a reference image is missing.
skip_if_reference_not_found = True

# dataset selections
satellite_choice = 's2'  # key into image_sets selecting the fine resolution satellite
overlap_fraction = 1 # how much of the minimum block size to add for the overlap. Default is one.
percent_cloudy = 5 # Recommending 5 percent
drive_folder = 'test6KelownaS2Aug30Nov23'  # prefix of the Google Drive folders the exports are written to
# NOTE(review): absolute, user-specific Windows path; must be edited before running on another machine
drive_folder_location = 'C:\\Users\\karan\\Google Drive'

# calculated baseline variables from user data
# closed ring of [longitude, latitude] corners (first point repeated at the end)
corner_coords = [[east_long, south_lat], [west_long, south_lat], [west_long, north_lat], [east_long, north_lat], [east_long, south_lat]]
# number of whole fine resolution pixels that fit along one MODIS pixel edge
block_size = math.floor(image_sets['modis']['pixel_size']/image_sets[satellite_choice]['pixel_size'])
# minimum tile size is 10 MODIS pixels, this following formula just accounts for how many pixels we need at minimum to have it as a multiple of block size which is a PSRFM req
min_tile_dim_px = block_size*10
min_tile_dim_km = (min_tile_dim_px * image_sets[satellite_choice]['pixel_size'])/1000
# NOTE(review): this bare tuple is not the last expression of the cell, so it is evaluated and discarded (nothing is displayed)
min_tile_dim_km, east_long, block_size, west_long, east_long, north_lat, south_lat, min_tile_dim_px

# paths for local PSRFM processing and conversions
# folders (inside the local Google Drive sync directory) that receive the exported GeoTIFFs
tif_paths = {
    'fine_res_path' : f'{drive_folder_location}\\{drive_folder}_{satellite_choice}',
    'mask_path' : f'{drive_folder_location}\\{drive_folder}_{satellite_choice}_mask',
    'coarse_res_path' : f'{drive_folder_location}\\{drive_folder}_modis'
}
# source folder of the PSRFM executable and the base folder of the local working tree
# NOTE(review): absolute, user-specific Windows paths
psrfm_info_dst = {
    'PSRFM_exe_path' : 'C:\\Users\\karan\\OneDrive\\Documents\\CCRS_2A_COOP\\PSRFM_Wrapper\\PSRM_exe_params',
    'dst_base_path' : 'C:\\Users\\karan\\OneDrive\\Documents\\CCRS_2A_COOP\\PSRFM_Testing\\SentinelKelownaTest6',
}
# sub-folders of the working tree, derived from dst_base_path
psrfm_info_dst['dst_input_path'] = psrfm_info_dst['dst_base_path'] + '\\input'
psrfm_info_dst['dst_param_path'] = psrfm_info_dst['dst_base_path'] + '\\params'
psrfm_info_dst['dst_output_path'] = psrfm_info_dst['dst_base_path'] + '\\output'
psrfm_info_dst['dst_output_gtif_path'] = psrfm_info_dst['dst_output_path'] + '\\geotiffs'
psrfm_info_dst['dst_temp_path'] = psrfm_info_dst['dst_base_path'] + '\\temp'



In [37]:
# Creating the geometry for tiles within the specified coordinates
# first calculate the side lengths of the region selected
# (x_dist is measured along the southern edge, y_dist along the eastern edge, both in km)
x_dist = geopy.distance.geodesic([south_lat, east_long], [south_lat, west_long]).km
y_dist = geopy.distance.geodesic([north_lat, east_long], [south_lat, east_long]).km

# determine the number of tile segments along each axis; the count is rounded DOWN so that
# every tile is at least min_tile_dim_km wide/tall
# NOTE(review): a region narrower or shorter than min_tile_dim_km yields 0 segments, and the
# increment calculation below then fails with ZeroDivisionError
# need to look at this later to ensure that each tile has enough of an overlap to ensure pixels are a multiple of block size (aka cropping)
x_tile_segments = math.floor(x_dist/min_tile_dim_km)
y_tile_segments = math.floor(y_dist/min_tile_dim_km)

# generate a list of ordered coordinates (west to east, north to south) based on the number of tiles
# creating an overlap of approx 1 block_size pixels to ensure that each tile after cropping for PSRFM will still retain some overlap
# NOTE(review): despite their names, km_long / km_lat hold degrees per km (degree span divided by km span);
# multiplying the overlap length in km by them converts it to degrees
km_long = abs(east_long - west_long)/x_dist
km_lat = abs(north_lat - south_lat)/y_dist
long_overlap = ((overlap_fraction * block_size * image_sets[satellite_choice]['pixel_size'])/1000) * km_long
lat_overlap = ((overlap_fraction * block_size * image_sets[satellite_choice]['pixel_size'])/1000) * km_lat

# determining the coordinate jump for each tile: an equal share of the region plus the overlap
x_coord_increment = abs(east_long - west_long)/x_tile_segments + long_overlap
y_coord_increment = abs(north_lat - south_lat)/y_tile_segments + lat_overlap

# creating the lists
# west edges are stepped back from the east boundary and east edges stepped forward from the west boundary
# (likewise for north/south); because each step includes the overlap, the outer tiles extend past the requested box
west_tile_coords = [east_long - (tile_no + 1) * x_coord_increment for tile_no in reversed(range(x_tile_segments))]
east_tile_coords = [west_long + (tile_no + 1) * x_coord_increment for tile_no in range(x_tile_segments)]

north_tile_coords = [south_lat + (tile_no + 1) * y_coord_increment for tile_no in reversed(range(y_tile_segments))]
south_tile_coords = [north_lat - (tile_no + 1) * y_coord_increment for tile_no in range(y_tile_segments)]

# object array indexed as tiles[col, row], holding one rectangle geometry per tile
tiles = np.empty((x_tile_segments, y_tile_segments), ee.Geometry)
for col in range(x_tile_segments):
    east_coord = east_tile_coords[col]
    west_coord = west_tile_coords[col]
    for row in range(y_tile_segments):
        north_coord = north_tile_coords[row]
        south_coord = south_tile_coords[row]
        tiles[col, row] = ee.Geometry.Rectangle([west_coord, south_coord, east_coord, north_coord])
        
# last expression of the cell: displayed as the cell output for inspection
x_dist, y_dist, min_tile_dim_km, x_coord_increment, y_coord_increment, west_tile_coords, east_tile_coords, north_tile_coords, south_tile_coords, tiles, x_tile_segments, y_tile_segments


(11.812396800216147,
 11.36962287214396,
 5.0,
 0.0870455472578443,
 0.05561625542542795,
 [-118.4708065053395, -118.38376095808165],
 [-118.37017325101714, -118.28312770375929],
 [48.82103459259148, 48.765418337166054],
 [48.756425944107626, 48.7008096886822],
 array([[<ee.geometry.Geometry object at 0x000001AF05AE1F48>,
         <ee.geometry.Geometry object at 0x000001AF05AE13C8>],
        [<ee.geometry.Geometry object at 0x000001AF05AE1248>,
         <ee.geometry.Geometry object at 0x000001AF05AE19C8>]],
       dtype=object),
 2,
 2)

In [38]:
# functions to add bands
# source: https://gis.stackexchange.com/questions/277059/cloud-mask-for-landsat8-on-google-earth-engine/277151
# second source: https://mygeoblog.com/2019/07/25/working-with-bitmasks/

def getQABitsls(image_qa, clear_bit, new_band_name):
    """Add a synthesized mask band (1 = unusable pixel, 0 = usable pixel) to an image.

    Parameters:
        image_qa: image exposing the Earth Engine image methods used below
            (select, addBands, bitwiseAnd, ...).
        clear_bit: an int for Landsat (position of the "clear" bit in 'pixel_qa'),
            or a tuple whose first two entries are the cloud bit positions in
            'QA60' for Sentinel-2.
        new_band_name: name given to the band that is added.

    Returns:
        The image with the extra band, or None when clear_bit is neither an int
        nor a tuple.
    """
    bit_kind = type(clear_bit)

    # Landsat: a single "clear" bit; the mask is 1 wherever that bit is NOT set
    if bit_kind == int:
        clear_flag = 2 ** clear_bit
        qa_band = image_qa.select(['pixel_qa'], [new_band_name])
        not_clear = qa_band.bitwiseAnd(clear_flag).rightShift(clear_bit).eq(0)
        return image_qa.addBands(not_clear)

    # Sentinel-2: two cloud bits, plus pixels whose B2 value is 0
    if bit_kind == tuple:
        first_flag = 2 ** clear_bit[0]
        second_flag = 2 ** clear_bit[1]
        first_bit_set = image_qa.select(['QA60'], [new_band_name]).bitwiseAnd(first_flag)
        second_bit_set = image_qa.select(['QA60'], [new_band_name]).bitwiseAnd(second_flag)
        empty_b2 = image_qa.select(['B2']).eq(0)
        # eq(0) twice collapses every non-zero value to 1 and keeps 0 as 0
        unusable = first_bit_set.bitwiseOr(second_bit_set).bitwiseOr(empty_b2).eq(0).eq(0)
        return image_qa.addBands(unusable)

def getPercentageAreaInvalid(imgwithmask, mask_name, tile_geometry):
    """Attach a 'calculated_invalid' property estimating how much of a tile is unusable.

    The value is the mean of the mask band over the image plus the share of the
    tile that the image geometry does not cover (1 - image area / tile area).

    Parameters:
        imgwithmask: image that already carries the band named mask_name.
        mask_name: name of the mask band whose mean is taken.
        tile_geometry: geometry of the tile the image was clipped to.

    Returns:
        imgwithmask with the 'calculated_invalid' property set.
    """
    band_means = imgwithmask.reduceRegion(reducer = ee.Reducer.mean())
    masked_fraction = ee.Number(band_means.get(mask_name))

    # share of the tile that lies outside the image geometry
    tile_area = tile_geometry.area()
    image_area = imgwithmask.geometry().area()
    uncovered_fraction = ee.Number(1).subtract(image_area.divide(tile_area))

    return imgwithmask.set('calculated_invalid', masked_fraction.add(uncovered_fraction))
    
    
# def getPercentageClear(imgwithmask, mask_name):
#     reducedimg = imgwithmask.reduceRegion(reducer = ee.Reducer.mean())
#     return imgwithmask.set('calculated_cloud', reducedimg.get(mask_name))

In [39]:
# pprint(tiles[0,0].getInfo())
# pprint(tiles[0,1].getInfo())
# pprint(tiles[1,1].getInfo())
# pprint(tiles[1,0].getInfo())

In [40]:
# generating arrays of the fine and coarse res imagecollections
# fine_res_tiles[col, row] ends up holding the collection of reference images chosen for that tile
fine_res_tiles = np.empty((x_tile_segments, y_tile_segments), ee.ImageCollection)
# array to track which dates of fine res images are used to get the MODIS images corresponding to the dates
# each cell becomes a list of [prediction date, date of image before, date of image after] triplets (YYYYMMDD strings)
fine_res_dates = np.empty((x_tile_segments, y_tile_segments), dtype = list)

for col in range(x_tile_segments):
    for row in range(y_tile_segments):
        # get a collection for all images within the date range, clipped to the tile, with the synthesized
        # mask band added and the 'calculated_invalid' property set; only images whose invalid fraction
        # is below percent_cloudy/100 are kept
        initial_collection = ee.ImageCollection(image_sets[satellite_choice]['imagecollection_id'])\
                                .filterBounds(tiles[col, row])\
                                .filterDate(*date_range)\
                                .map(lambda image: image.clip(tiles[col, row]))\
                                .map(lambda image: getQABitsls(image, image_sets[satellite_choice]['mask_bits_clear'], synthesized_mask_name))\
                                .map(lambda image: getPercentageAreaInvalid(image.clip(tiles[col, row]), synthesized_mask_name, tiles[col, row]))\
                                .filterMetadata('calculated_invalid', 'less_than', percent_cloudy/100)
        # ids of every reference image picked for this tile (may contain duplicates until de-duplicated below)
        selected_images = []
        fine_res_dates[col][row] = []
        for date in prediction_dates:
#           create a field calculating the distance of images from each prediction date, and find the lowest two for each date to use for PSRFM (having to use the ID as the system start time is when the picture was taken rather than the date, which PSRFM needs them to be different)
#           the property 'dateDist<date>' is (image date - prediction date) in days: negative = image taken before the prediction date
            id_date_start_index = image_sets[satellite_choice]['id_date_indexes'][0]
            id_date_end_index = image_sets[satellite_choice]['id_date_indexes'][1]
#           an end index of 0 means the date runs to the end of the id, so the slice is left open-ended
            if id_date_end_index == 0:
                initial_collection = initial_collection.map(
                    lambda image: image.set(f'dateDist{date}', 
                                            ee.Date.difference(
                                                ee.Date.parse('YYYYMMdd', ee.String.slice(ee.String(image.id()), id_date_start_index)),
                                                ee.Date(date), 'day')
                                           )
                )
            else:
                initial_collection = initial_collection.map(
                    lambda image: image.set(f'dateDist{date}', 
                                            ee.Date.difference(
                                                ee.Date.parse('YYYYMMdd', ee.String.slice(ee.String(image.id()), id_date_start_index, id_date_end_index)),
                                                ee.Date(date), 'day')
                                           )
                )

#             pprint(initial_collection.filterMetadata(f'dateDist{date}', 'less_than', 0).sort(f'dateDist{date}', False).getInfo()['features'])
#             initial_collection = initial_collection.map(
#                 lambda image: image.set(f'dateDist{date}', 
#                                         ee.Number(image.get('system:time_start'))
#                                         .subtract(ee.Date.millis(ee.Date(date)))
#                                        )
#             )
#           aggregate dates then filter an imagecollection
#             for feature in (initial_collection.filterMetadata(f'dateDist{date}', 'less_than', 0).sort(f'dateDist{date}', False).getInfo()['features']):
#                 pprint(feature['id'])
            
            
#           closest image BEFORE the prediction date: negative distances, sorted descending, first feature taken.
#           An empty result makes the [0] lookup raise IndexError, which is handled below.
#           Images dated exactly on the prediction date (distance 0) match neither the "before" nor the "after" filter.
            try:
                first_img_before = initial_collection.filterMetadata(f'dateDist{date}', 'less_than', 0).sort(f'dateDist{date}', False).getInfo()['features'][0]
# debugging lines 2
#                 if col == row == 1:
#                     pprint(first_img_before['id'])
                    
            except IndexError:
                if skip_if_reference_not_found:
                    first_img_before = None
                    logging.warning(f"For Col {col} Row {row} on {date}, reference image before {date} not found and hence skipped")
                else:
                    raise Exception(f"Reference image before {date} not found")
                    
#           closest image AFTER the prediction date: positive distances, sorted ascending, first feature taken
            try:
                first_img_after = initial_collection.filterMetadata(f'dateDist{date}', 'greater_than', 0).sort(f'dateDist{date}', True).getInfo()['features'][0]

            except IndexError:
                if skip_if_reference_not_found:
                    first_img_after = None
                    logging.warning(f"For Col {col} Row {row} on {date}, reference image after {date} not found and hence skipped")
                else:
                    raise Exception(f"Reference image after {date} not found")

            
#           if there isn't a valid pair of images don't add them anywhere
            if first_img_before != None and first_img_after != None:
                selected_images.append(first_img_before['id'])
                selected_images.append(first_img_after['id'])

    #           Insert the dates selected for PSRFM to track which MODIS images to extract later
    #           NOTE(review): the +1 shift in the else branch presumably compensates for the collection path
    #           prefix present in the id returned by getInfo() -- TODO confirm for every supported satellite
                if id_date_end_index == 0:
                    first_image_date = first_img_before['id'][id_date_start_index:]
                    second_image_date = first_img_after['id'][id_date_start_index:]
                else:
                    first_image_date = first_img_before['id'][id_date_start_index+1:id_date_end_index+1]
                    second_image_date = first_img_after['id'][id_date_start_index+1:id_date_end_index+1]
    #             make this a triplet
                fine_res_dates[col][row].append([date.replace('-', ''), first_image_date, second_image_date])
#         shouldn't need this bit anymore, ignored tiles are logged
#             else:
# #                 append null to fine res dates
#                 fine_res_dates[col][row].append(None)
#note: removes duplicate images (dict.fromkeys keeps the first occurrence and preserves order)
        selected_images = list(dict.fromkeys(selected_images))
#       rebuild a collection from the chosen ids, clipped to the tile and with the mask band re-added
        fine_res_tiles[col, row] = ee.ImageCollection(selected_images).map(lambda image: image.clip(tiles[col, row])) \
                                     .map(lambda image: getQABitsls(image, image_sets[satellite_choice]['mask_bits_clear'], synthesized_mask_name))

# pprint((fine_res_tiles[0,0].getInfo()['features']))
# for dates in fine_res_dates[0,0]:
#     for date in dates:
#         datetime
#     print(date)
#     print('pls')
# pprint(fine_res_dates)
# pprint((coarse_res_tiles[0,0].getInfo()['features']))

'COPERNICUS/S2_SR/20190828T184921_20190828T185739_T11ULP'
'COPERNICUS/S2_SR/20191121T185659_20191121T185702_T11ULP'


In [41]:
# coarse_res_tiles[col, row] holds the MODIS images matching every date recorded in fine_res_dates[col, row]
coarse_res_tiles = np.empty((x_tile_segments, y_tile_segments), ee.ImageCollection)
for col in range(x_tile_segments):
    for row in range(y_tile_segments):
        # this collection is only used to read one image id and derive the id prefix from it
        collection = (ee.ImageCollection(image_sets['modis']['imagecollection_id'])
                        .filterBounds(tiles[col, row])
                        .filterDate(*date_range)
                        .sort(image_sets['modis']['time_field'])
                        .map(lambda image: image.clip(tiles[col, row])))
        # id of the first image with its trailing date (last 10 characters) removed
        modis_name_format = collection.first().getInfo()['id'][:image_sets['modis']['id_date_indexes'][0]]
        selected_modis_images = []
        
        # every date of every triplet (prediction date, image before, image after) needs a MODIS image;
        # YYYYMMDD is rewritten as YYYY_MM_DD and appended to the id prefix
        for date_set in fine_res_dates[col, row]:
            for date in date_set:
                selected_modis_images.append(modis_name_format + date[:4] + '_' + date[4:6] + '_' + date[6:])
#       ensures unique dates so no duplicates
        selected_modis_images = list(dict.fromkeys(selected_modis_images))
        coarse_res_tiles[col, row] = ee.ImageCollection(selected_modis_images).map(lambda image: image.clip(tiles[col, row]))
# NOTE(review): runs after the loops, so only the id list of the last tile processed is printed
pprint(selected_modis_images)

['MODIS/006/MCD43A4/2019_08_30',
 'MODIS/006/MCD43A4/2019_08_28',
 'MODIS/006/MCD43A4/2019_08_31',
 'MODIS/006/MCD43A4/2019_11_23',
 'MODIS/006/MCD43A4/2019_11_21',
 'MODIS/006/MCD43A4/2019_11_24']


In [42]:
# export imagecollection tile array function
# Module-level accumulators filled by export_tile_array with the description given to each export task it starts;
# the wait-for-files cell later compares their lengths with the number of files in the Drive folders.
fine_res_filenames = []
fine_res_mask_filenames = []
coarse_res_filenames = []

def date_to_day_of_year(date, format='%Y%m%d'):
    """Return the 1-based day of the year for a date string.

    Parameters:
        date: date string to parse.
        format: strptime pattern describing `date` (default 'YYYYMMDD' as '%Y%m%d').

    Returns:
        int: 1 for 1 January, 365 or 366 for 31 December.

    Raises:
        ValueError: if `date` does not match `format`.
    """
    # strptime takes its arguments positionally only; passing format= as a keyword raises TypeError
    parsed_date = datetime.datetime.strptime(date, format)
    return parsed_date.timetuple().tm_yday
# pprint(ee.Image(fine_res_tiles[0, 0].toList(fine_res_tiles[0, 0].size()).get(0)).getInfo())


def export_tile_array(tile_array, satellite_name):
    """Start Google Drive export tasks for every image in a 2-D array of tile collections.

    For each image a GeoTIFF export of the satellite's configured bands is started.
    For every satellite other than MODIS a second export of the synthesized mask
    band is started into a sibling '<folder>_mask' Drive folder. The description
    of each started task is appended to the module-level lists
    fine_res_filenames / fine_res_mask_filenames (non-MODIS) or
    coarse_res_filenames (MODIS).

    Parameters:
        tile_array: 2-D numpy object array indexed as [col, row] whose cells are
            ee.ImageCollection objects.
        satellite_name: key into image_sets ('ls5', 'ls8', 's2' or 'modis').

    Returns:
        [task_list, mask_task_list]: two lists with one inner list per column,
        holding the started image tasks and mask tasks respectively.

    Raises:
        KeyError: if satellite_name is not a key of image_sets.
    """
    sat_info = image_sets[satellite_name]
    id_date_start, id_date_end = sat_info['id_date_indexes']
    folder_into = f'{drive_folder}_{satellite_name}'
    # the MODIS collections built in this notebook carry no synthesized mask band
    extract_mask = satellite_name.lower().strip() != 'modis'
    # NOTE(review): the pixel size of the globally selected fine resolution satellite is used for
    # every export, MODIS included -- presumably so coarse images land on the fine grid; confirm
    export_scale = image_sets[satellite_choice]['pixel_size']

    # Used to keep track of tasks for polling status
    task_list = []
    mask_task_list = []

    for col in range(len(tile_array)):
        task_list.append([])
        mask_task_list.append([])
        for row in range(len(tile_array[col])):
            tile_collection = tile_array[col, row]
            # build the list once per tile instead of once per image
            tile_images = tile_collection.toList(tile_collection.size())
            for elem in range(tile_collection.size().getInfo()):
                elem_image = ee.Image(tile_images.get(elem))
                # getInfo() is a blocking round trip to Earth Engine: fetch the metadata once and reuse it
                elem_info = elem_image.getInfo()

                # an end index of 0 means the date runs to the end of the id
                if id_date_end == 0:
                    id_date = elem_info['id'][id_date_start:]
                else:
                    id_date = elem_info['id'][id_date_start + 1:id_date_end + 1]
                img_date = datetime.datetime.strptime(id_date, sat_info['id_date_format'])
                img_day = img_date.timetuple().tm_yday  # 1-based day of the year
                img_date_str = img_date.strftime('%d-%b-%Y')
                img_filename = f'{satellite_name}C{col}R{row}_{img_day}_{img_date_str}'
                export_region = elem_info['properties']['system:footprint']['coordinates']

                # NOTE(review): the CRS is hard-coded to EPSG:32611 here while the mask below is
                # exported in EPSG:4326 -- confirm both are intended before using another region
                task = ee.batch.Export.image.toDrive(
                                        image = elem_image.select(sat_info['bands']).toInt16(),
                                        region = export_region,
                                        crs = 'EPSG:32611',
                                        scale = export_scale,
                                        folder = folder_into,
                                        description = img_filename)

                task.start()
                task_list[col].append(task)

                if extract_mask:
                    mask_filename = f'{satellite_name}C{col}R{row}_mask_{img_day}_{img_date_str}'
                    mask_task = ee.batch.Export.image.toDrive(
                                        image = elem_image.select(synthesized_mask_name).toUint8(),
                                        region = export_region,
                                        crs = 'EPSG:4326',
                                        scale = export_scale,
                                        folder = f"{folder_into}_mask",
                                        description = mask_filename)
                    mask_task.start()
                    mask_task_list[col].append(mask_task)
                    fine_res_filenames.append(img_filename)
                    fine_res_mask_filenames.append(mask_filename)
                else:
                    coarse_res_filenames.append(img_filename)
    return [task_list, mask_task_list]

# Start the exports: each call returns [image tasks per column, mask tasks per column]
modis_tasks = export_tile_array(coarse_res_tiles, 'modis')
fine_res_tasks = export_tile_array(fine_res_tiles, satellite_choice)
# drop repeated export names while keeping first-seen order
fine_res_filenames = list(dict.fromkeys(fine_res_filenames))
fine_res_mask_filenames = list(dict.fromkeys(fine_res_mask_filenames))
coarse_res_filenames = list(dict.fromkeys(coarse_res_filenames))

KeyboardInterrupt: 

In [None]:
# Block until the first MODIS image export task of column 0 leaves the READY/RUNNING states,
# then show its final status. modis_tasks[0] is the image task list, [0] the first column,
# [0] the first task started for that column.
taskstat = modis_tasks[0][0][0].status()['state']
pprint(taskstat)
while taskstat == 'READY' or taskstat == 'RUNNING':
    # pause between polls: every status() call is a request to Earth Engine, so a tight loop floods the API
    time.sleep(5)
    taskstat = modis_tasks[0][0][0].status()['state']
pprint(modis_tasks[0][0][0].status())

In [None]:
# # task based status checking

# def check_task_status(task_tile_array)
# # loop through task list then mask task exports
#     for task_array in task_tile_array:
#         for col in range(len(task_array)):
#             for row in range(len(task_array[col])):
#                 taskstatus = task_array[col][row].status()['state']
#                 while taskstatus == 'READY' or taskstatus == 'RUNNING':
#                     time.sleep(1)
#                     taskstatus = task_array[col][row].status()['state']
#                 logging.warning(f'{task_array[col][row].status()["description"]} at (col, row): ({col}, {row}) exported with final state: {taskstatus}')
        

In [None]:
# wait for a full set of images to appear locally to begin psrfm processing
# for now just check that all of the images exist in the folder
# NOTE(review): every loop below polls every 30 seconds with no timeout, so a failed export blocks this cell forever

# same three folders as tif_paths, rebuilt here from the user settings
fine_res_path = f'{drive_folder_location}\\{drive_folder}_{satellite_choice}'
coarse_res_path = f'{drive_folder_location}\\{drive_folder}_modis'
fine_mask_path = f'{drive_folder_location}\\{drive_folder}_{satellite_choice}_mask'
print('started')
# step 1: wait until each export folder has been synced into the local Drive directory
while not (f'{drive_folder}_modis' in os.listdir(drive_folder_location)):
    time.sleep(30)
print('coarse res image folder exists')

while not (f'{drive_folder}_{satellite_choice}' in os.listdir(f'{drive_folder_location}')):
    time.sleep(30)
print('fine res image folder exists')

while not (f'{drive_folder}_{satellite_choice}_mask' in os.listdir(f'{drive_folder_location}')):
    time.sleep(30)
print('fine res mask folder exists')

# wait for all the files to arrive
# step 2: compare the number of entries in each folder with the number of export names recorded
# NOTE(review): the comparison is "!=", so a folder holding more entries than expected (leftovers from an
# earlier run, or any extra file) never satisfies it; file names themselves are not checked
while len(os.listdir(fine_mask_path)) != len(fine_res_mask_filenames):
    time.sleep(30)
print('all fine res mask images present')
    
while len(os.listdir(coarse_res_path)) != len(coarse_res_filenames):
    time.sleep(30)
print('all coarse res images present')

while len(os.listdir(fine_res_path)) != len(fine_res_filenames):
    time.sleep(30)
print('all fine res images present')

print('Ready for PSRFM with following files:')
pprint(os.listdir(fine_res_path))
pprint(os.listdir(coarse_res_path))
pprint(os.listdir(fine_mask_path))

In [None]:
# function to extract date as datetime for conversion and param file
def get_date_from_filename(filename):
    """Parse the date embedded at the end of an image filename.

    The 11 characters immediately before the last 4 characters of the name
    (i.e. before a 4-character extension such as '.tif' or '.dat') are read
    as a 'dd-Mon-YYYY' date, for example '05-Jul-2019'.

    Returns the date as a datetime.datetime.
    Raises ValueError when those characters do not match that format.
    """
    date_token = filename[-15:-4]
    return datetime.datetime.strptime(date_token, '%d-%b-%Y')

In [None]:
# Ensuring all local paths exist for PSRFM to run and for TIF files to be exported into.
# makedirs(exist_ok=True) makes the cell safe to re-run and also creates missing parent folders;
# the base path is listed first so it exists before the other folders are created.
for path_key in ('dst_base_path', 'dst_input_path', 'dst_param_path', 'dst_output_path', 'dst_temp_path'):
    os.makedirs(psrfm_info_dst[path_key], exist_ok=True)

# PSRFM_Main.exe is copied into (and later launched from) dst_param_path, so that is the
# folder whose contents decide whether a copy is needed. The previous check looked in
# dst_base_path, which could skip the copy while the executable was still missing here.
if not os.path.exists(f'{psrfm_info_dst["dst_param_path"]}\\PSRFM_Main.exe'):
    shutil.copy(f"{psrfm_info_dst['PSRFM_exe_path']}\\PSRFM_Main.exe", psrfm_info_dst['dst_param_path'])
# if not os.path.exists(f'{drive_folder_location}\\{drive_folder}'):
#     os.mkdir(f'{drive_folder_location}\\{drive_folder}')

In [None]:
# Cropping, converting to ENVI, and moving all the TIF files from GDrive to a local location with some small band removals or other processing to make files ready for PSRFM to be run on them
filenames = []
for path in tif_paths:
    for filename in os.listdir(tif_paths[path]):
        filenames.append(filename)
        with rio.open(f'{tif_paths[path]}\\{filename}') as image_to_crop:
            # trim each dimension down to the largest whole multiple of block_size
            finalx = image_to_crop.width - image_to_crop.width % block_size
            finaly = image_to_crop.height - image_to_crop.height % block_size
            # centre the crop inside the source image. Floor division keeps the offsets on
            # whole pixels: with true division an odd remainder produced a x.5 offset, i.e. a
            # fractional read window whose pixels and georeferencing transform can disagree.
            col_offset = (image_to_crop.width - finalx) // 2
            row_offset = (image_to_crop.height - finaly) // 2

            subset_window = Window(col_offset, row_offset, finalx, finaly)
            newargs = image_to_crop.meta.copy()
            newargs.update({
                'height': subset_window.height,
                'width': subset_window.width,
                'transform': rio.windows.transform(subset_window, image_to_crop.transform),
                'driver': 'ENVI'
            })
            # 7-band sources are declared as 6-band outputs
            # NOTE(review): presumably the 7th band is the QA/mask band of the ls8 export - confirm
            if newargs['count'] == 7:
                newargs.update({'count' : 6})
            # ls8 images (not their masks) keep only their first six bands; everything else is copied whole
            if 'ls8' in filename and 'mask' not in filename:
                band_indexes = [1, 2, 3, 4, 5, 6]
            else:
                band_indexes = None
            # the output keeps the source name with the 4-character extension swapped for .dat
            with rio.open(f'{psrfm_info_dst["dst_input_path"]}\\{filename[:-4]}.dat', 'w', **newargs) as dst:
                dst.write(image_to_crop.read(indexes=band_indexes, window=subset_window))
# filenames

In [None]:
# finding the correct images to put into the param files and place all their metadata into the param files

def _psrfm_set(coarse_res_set, fine_pairs, mask_pairs):
    """Bundle the inputs of one PSRFM run.

    coarse_res_set is a date-ordered list of (filename, date) tuples; its first and last
    entries are the start and end dates of the run. The fine resolution images and masks
    that fall on exactly those two dates are picked out of fine_pairs / mask_pairs.
    Returns a dict with keys 'coarse_res_filenames', 'fine_res_filenames', 'mask_filenames'.
    """
    boundary_dates = (coarse_res_set[0][1], coarse_res_set[-1][1])
    return {
        'coarse_res_filenames' : coarse_res_set,
        'fine_res_filenames': [pair for pair in fine_pairs if pair[1] in boundary_dates],
        'mask_filenames' : [pair for pair in mask_pairs if pair[1] in boundary_dates]
    }


# This cell only writes .txt parameter files and output folders, so the list of .dat inputs
# can be read once instead of three times per tile.
psrfm_input_path = psrfm_info_dst['dst_input_path']
psrfm_input_dat_filenames = [filename for filename in os.listdir(psrfm_input_path) if filename.endswith('.dat')]
psrfm_template_path = f'{psrfm_info_dst["PSRFM_exe_path"]}\\psrfm_template.txt'

for col in range(x_tile_segments):
    for row in range(y_tile_segments):
        # NOTE(review): these are plain prefix matches, so tile C1R1 would also pick up files of
        # C1R10, C1R11, ... on grids with 10 or more rows/columns - confirm the character that
        # follows the row number in the exported filenames before relying on larger grids.
        fine_prefix = f'{satellite_choice}C{col}R{row}'
        mask_prefix = f'{satellite_choice}C{col}R{row}_mask'
        coarse_prefix = f'modisC{col}R{row}'

        # (filename, date) tuples for the current tile, sorted by date (second element) to
        # help create "sets" of images for PSRFM processing
        fine_res_filenames_dates = sorted(
            [(filename, get_date_from_filename(filename)) for filename in psrfm_input_dat_filenames
             if filename.startswith(fine_prefix) and not filename.startswith(mask_prefix)],
            key = lambda pair: pair[1])
        mask_filenames_dates = sorted(
            [(filename, get_date_from_filename(filename)) for filename in psrfm_input_dat_filenames
             if filename.startswith(mask_prefix)],
            key = lambda pair: pair[1])
        coarse_res_filenames_dates = sorted(
            [(filename, get_date_from_filename(filename)) for filename in psrfm_input_dat_filenames
             if filename.startswith(coarse_prefix)],
            key = lambda pair: pair[1])

        # finding the indexes in coarse res images where both a fine res and coarse res image exist,
        # which will then create pairs of start and end dates, and any coarse images inbetween will
        # be used for predictions
        coarse_index_by_date = {}
        for coarse_index, coarse_pair in enumerate(coarse_res_filenames_dates):
            # setdefault keeps the first coarse image when several share one date
            coarse_index_by_date.setdefault(coarse_pair[1], coarse_index)
        ref_coarse_indexes = []
        for name_date_pair in fine_res_filenames_dates:
            if name_date_pair[1] in coarse_index_by_date:
                ref_coarse_indexes.append(coarse_index_by_date[name_date_pair[1]])
            else:
                # a fine image without a coarse image on the same date cannot serve as a
                # reference; record it and carry on instead of failing with an IndexError
                logging.warning(f'no coarse res image matches the date of {name_date_pair[0]}; it is not used as a reference image')

        # creating "sets" for each start and end date pairing that's ready for PSRFM
        PSRFM_sets = []
        for start_index, end_index in zip(ref_coarse_indexes, ref_coarse_indexes[1:]):
            # every coarse image from one reference date up to and including the next one
            coarse_res_set = coarse_res_filenames_dates[start_index:end_index + 1]
            # only trigger a PSRFM run if there is a prediction date between the start and
            # end dates, i.e. the set holds more than 2 images
            if len(coarse_res_set) > 2:
                PSRFM_sets.append(_psrfm_set(coarse_res_set, fine_res_filenames_dates, mask_filenames_dates))

        # adding prediction date PSRFM runs: three consecutive reference dates, where the outer
        # two form the input pair and the middle one is the date being predicted
        PSRFM_ref_date_sets = []
        for pred_img_index in range(len(ref_coarse_indexes) - 2):
            coarse_res_set = [coarse_res_filenames_dates[ref_index] for ref_index in ref_coarse_indexes[pred_img_index:pred_img_index + 3]]
            PSRFM_ref_date_sets.append(_psrfm_set(coarse_res_set, fine_res_filenames_dates, mask_filenames_dates))

        all_PSRFM_sets = [PSRFM_sets, PSRFM_ref_date_sets]
        for set_idx in range(len(all_PSRFM_sets)):
            # the second group is labelled "reference" in folder and parameter file names
            reference_precursor = 'reference' if set_idx == 1 else ""
            PSRFM_instances = len(all_PSRFM_sets[set_idx])
            for instance in range(PSRFM_instances):
                # Files and folders are numbered instance + 1, so the matching set is [instance].
                # The previous [instance - 1] lookup paired file 1 with the LAST set and shifted
                # every other file by one.
                current_set = all_PSRFM_sets[set_idx][instance]
                coarse_pairs = current_set['coarse_res_filenames']
                fine_pairs = current_set['fine_res_filenames']
                mask_pairs = current_set['mask_filenames']

                # creating an output folder for each each instance
                output_dir = f'{psrfm_info_dst["dst_output_path"]}\\{reference_precursor}instance_{instance+1}'
                if not os.path.exists(output_dir):
                    os.mkdir(output_dir)

                # creating the parameter file for the current tile and instance of PSRFM.
                # Always start from a fresh template: rewriting an existing file in place would
                # append the prediction filename lines again on every re-run of this cell.
                param_path = f'{psrfm_info_dst["dst_param_path"]}\\{col}_{row}_params{reference_precursor}_{instance + 1}.txt'
                shutil.copy(psrfm_template_path, param_path)

                # image dimensions are taken from the first fine resolution image of the set
                with rio.open(f"{psrfm_input_path}\\{fine_pairs[0][0]}") as img:
                    nrows = img.meta['height']
                    ncols = img.meta['width']
                    nbands = img.meta['count']

                # now that file exists, insert the parameters.
                # With inplace = True everything print()ed inside the block is written to the file.
                with fileinput.FileInput(param_path, inplace = True) as paramfile:
                    for line in paramfile:
                        setting = line.strip()
                        if setting.startswith('IN_PAIR_COARSE_FNAME'):
                            print(f"IN_PAIR_COARSE_FNAME = {psrfm_input_path}\\{coarse_pairs[0][0]} {psrfm_input_path}\\{coarse_pairs[-1][0]}")
                        elif setting.startswith('IN_PAIR_FINE_FNAME'):
                            print(f"IN_PAIR_FINE_FNAME = {psrfm_input_path}\\{fine_pairs[0][0]} {psrfm_input_path}\\{fine_pairs[1][0]}")
                        elif setting.startswith('IN_PAIR_FINE_MASK_FNAME'):
                            print(f"IN_PAIR_FINE_MASK_FNAME = {psrfm_input_path}\\{mask_pairs[0][0]} {psrfm_input_path}\\{mask_pairs[1][0]}")
                        elif setting.startswith('IN_PDAY_COARSE_NO'):
                            # every coarse image strictly between the start and end dates is a prediction date
                            prediction_filenames = ''.join(f" {psrfm_input_path}\\{pair[0]} \n" for pair in coarse_pairs[1:-1])
                            print(f"IN_PDAY_COARSE_NO = {len(coarse_pairs) - 2} \n {prediction_filenames}")
                        elif setting.startswith('OUT_PREDICTION_DIR'):
                            print(f"OUT_PREDICTION_DIR = {output_dir}")
                        elif setting.startswith('OUT_TEMP_DIR'):
                            print(f"OUT_TEMP_DIR = {psrfm_info_dst['dst_temp_path']}")
                        elif setting.startswith('OUT_ENVI_HDR'):
                            print(f"OUT_ENVI_HDR = {psrfm_input_path}\\{fine_pairs[0][0][:-4]}.hdr")
                        elif setting.startswith('NROWS'):
                            print(f"NROWS = {nrows}")
                        elif setting.startswith('COARSE_ROWS'):
                            # the coarse dimensions are written with the fine image's dimensions
                            print(f"COARSE_ROWS = {nrows}")
                        elif setting.startswith('NCOLS'):
                            print(f"NCOLS = {ncols}")
                        elif setting.startswith('COARSE_COLS'):
                            print(f"COARSE_COLS = {ncols}")
                        elif setting.startswith('NBANDS'):
                            print(f"NBANDS = {nbands}")
                        elif setting.startswith('RESOLUTION'):
                            print(f"RESOLUTION = {image_sets[satellite_choice]['pixel_size']}")
                        elif setting.startswith('BLOCK_SIZE'):
                            print(f"BLOCK_SIZE = {block_size}")
                        else:
                            # any other template line is copied through unchanged
                            print(line, end = '')
# fine_res_filenames_dates, mask_filenames_dates, coarse_res_filenames_dates, PSRFM_instances
# coarse_res_filenames_dates, matching_date_indexes, PSRFM_sets, len(matching_date_indexes)

In [None]:
# Run PSRFM_Main.exe once for every file in the parameter folder (everything except the
# executable itself is treated as a parameter file).
psrfm_exe_path = f'{psrfm_info_dst["dst_param_path"]}\\PSRFM_Main.exe'
for param_file in [filename for filename in os.listdir(psrfm_info_dst['dst_param_path']) if filename != 'PSRFM_Main.exe']:
    param_file_path = f'{psrfm_info_dst["dst_param_path"]}\\{param_file}'
    print(param_file_path)
    completed_run = subprocess.run([psrfm_exe_path, param_file_path])
    # a failed run used to go unnoticed; record it and keep processing the remaining files
    if completed_run.returncode != 0:
        logging.warning(f'PSRFM_Main.exe exited with code {completed_run.returncode} for {param_file_path}')

In [None]:
# Convert every PSRFM output image (ENVI .dat) into a GeoTIFF, mirroring the per-instance
# folder layout of dst_output_path under dst_output_gtif_path.
os.makedirs(psrfm_info_dst['dst_output_gtif_path'], exist_ok=True)

# only real directories qualify: a plain file with "instance" in its name cannot be listed
instance_dirs = [
    directoryname for directoryname in os.listdir(psrfm_info_dst['dst_output_path'])
    if "instance" in directoryname
    and os.path.isdir(f'{psrfm_info_dst["dst_output_path"]}\\{directoryname}')
]

for instance_dir in instance_dirs:
    envi_dir = f'{psrfm_info_dst["dst_output_path"]}\\{instance_dir}'
    gtif_dir = f'{psrfm_info_dst["dst_output_gtif_path"]}\\{instance_dir}'
    os.makedirs(gtif_dir, exist_ok=True)
    for filename in os.listdir(envi_dir):
        # Require a real .dat suffix: the output name is built with filename[:-4], and a
        # substring test would also pick up sidecar files such as 'x.dat.aux.xml'.
        # Names containing '_mask' or '_Q' are skipped.
        if not filename.endswith('.dat') or '_mask' in filename or '_Q' in filename:
            continue
        with rio.open(f'{envi_dir}\\{filename}') as image_to_convert:
            # identical metadata, only the output driver changes
            newargs = image_to_convert.meta.copy()
            newargs.update({
                'driver': 'GTiff'
            })

            with rio.open(f'{gtif_dir}\\{filename[:-4]}.tif', 'w', **newargs) as dst:
                dst.write(image_to_convert.read())
                    
# Reference code for converting dat files to gtif
# filenames = []
# for path in tif_paths:
#     for filename in os.listdir(tif_paths[path]):
#         filenames.append(filename)
#         with rio.open(f'{tif_paths[path]}\\{filename}') as image_to_crop:
#             finalx = image_to_crop.meta["width"] - image_to_crop.meta["width"] % block_size
#             finaly = image_to_crop.meta["height"] - image_to_crop.meta["height"] % block_size
#             col_offset = (image_to_crop.width - finalx)/2
#             row_offset = (image_to_crop.height - finaly)/2

#             subset_window = Window(col_offset, row_offset, finalx, finaly)
#             newargs = image_to_crop.meta.copy()
#             newargs.update({
#                 'height': subset_window.height,
#                 'width': subset_window.width,
#                 'transform': rio.windows.transform(subset_window, image_to_crop.transform),
#                 'driver': 'ENVI'
#             })
#             if newargs['count'] == 7:
#                 newargs.update({'count' : 6})
#             with rio.open(f'{psrfm_info_dst["dst_input_path"]}\\{filename[:-4]}.dat', 'w', **newargs) as dst:
#                 if 'ls8' in filename and 'mask' not in filename:
#                     dst.write(image_to_crop.read(indexes=[1, 2, 3, 4, 5, 6], window=subset_window))
#                 else:
#                     dst.write(image_to_crop.read(window=subset_window))