<a href="https://colab.research.google.com/github/jollygoodjacob/STF/blob/main/STF_starfm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# STARFM for Spatiotemporal Fusion

This Google Colab notebook implements the Spatial and Temporal Adaptive Reflectance Fusion Model (STARFM), a spatiotemporal fusion method, adapted for fusing fine-resolution UAV imagery with coarser-resolution Planet imagery.

## Install required packages

First, install the required packages with pip: rasterio (raster I/O), zarr (on-disk storage of the temporary image tiles) and dask's array support. Not all of these come preinstalled on Google Colab.

In [4]:
!pip install rasterio
!pip install zarr
!pip install dask.array

Collecting zarr
  Downloading zarr-3.0.6-py3-none-any.whl.metadata (9.7 kB)
Collecting donfig>=0.8 (from zarr)
  Downloading donfig-0.8.1.post1-py3-none-any.whl.metadata (5.0 kB)
Collecting numcodecs>=0.14 (from numcodecs[crc32c]>=0.14->zarr)
  Downloading numcodecs-0.16.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)
Collecting crc32c>=2.7 (from numcodecs[crc32c]>=0.14->zarr)
  Downloading crc32c-2.7.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.3 kB)
Downloading zarr-3.0.6-py3-none-any.whl (196 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m196.4/196.4 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading donfig-0.8.1.post1-py3-none-any.whl (21 kB)
Downloading numcodecs-0.16.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.8/8.8 MB[0m [31m64.8 MB/s[0m eta [36m0:00:00[0m


## Mount Google Drive

Next, mount Google Drive so that this notebook can read files stored there: the `starfm4py.py` and `parameters.py` modules required to run STARFM, as well as the UAV and Planet imagery.

In [2]:
from google.colab import drive

# Attach Google Drive at /content/drive; force_remount=True re-establishes
# the mount even if one is already present.
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [14]:
!cat /content/drive/MyDrive/STF/parameters.py

import numpy as np



# Set the size of the moving window in which the search for similar pixels 
# is performed
windowSize = 31

# Set the path where the results should be stored
path = 'STARFM_demo/'

# Set to True if you want to decrease the sensitivity to the spectral distance
logWeight = False

# If more than one training pairs are used, set to True
temp = False

# The spatial impact factor is a constant defining the relative importance of 
# spatial distance (in meters)
# Take a smaller value of the spatial impact factor for heterogeneous regions 
# (e.g. A = 150 m)
spatImp = 150 

# increasing the number of classes limits the number of similar pixels
numberClass = 4 

# Set the uncertainty value for the fine resolution sensor
# https://earth.esa.int/web/sentinel/technical-guides/sentinel-2-msi/performance 
uncertaintyFineRes = 0.03

# Set the uncertainty value for the coarse resolution sensor
# https://sentinels.copernicus.eu/web/sentinel/technical

## Load the STARFM module and parameters

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Add your module directory to the Python path
import sys
sys.path.append('/content/drive/MyDrive/STF')

# (Optional) Set working directory for reading TIFFs or output
import os
os.chdir('/content/drive/MyDrive/STF')

# Enable auto-reload of modules
%load_ext autoreload
%autoreload 2

# Try importing your values
from parameters import path, sizeSlices
print("Imported values:", path, sizeSlices)

# Try importing the STARFM module
import starfm4py as stp
print("Successfully imported starfm4py")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Apply STARFM to our data

In [None]:
# Import required packages
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import rasterio

# Drive modules: the previous cell must already have added the STF folder to
# sys.path and made it the working directory.
import starfm4py as stp
from parameters import path, sizeSlices

# Input scenes: fine-resolution UAV image at t0, coarse-resolution Planet images at t0 and t1
UAV_T0_FILE = '/content/drive/MyDrive/STF/20220802_RGB_UAV.tif'
PLANET_T0_FILE = '/content/drive/MyDrive/STF/20220802_RGB_Planet.tif'
PLANET_T1_FILE = '/content/drive/MyDrive/STF/20220812_RGB_Planet.tif'

# Folders (relative to the working directory) for the temporary moving-window tiles
path_fineRes_t0 = 'Temporary/Tiles_fineRes_t0/'
path_coarseRes_t0 = 'Temporary/Tiles_coarseRes_t0/'
path_coarseRes_t1 = 'Temporary/Tiles_coarseRes_t1/'


def read_first_band(file_name):
    """Read band 1 of a raster as float64 and return it with the dataset profile.

    The dataset is closed before returning. Casting to float64 means any
    image differences computed downstream cannot wrap around, as they would
    if the GeoTIFFs store unsigned integers (e.g. 8-bit RGB).
    """
    with rasterio.open(file_name) as src:
        return src.read(1).astype(np.float64), src.profile


start = time.perf_counter()

UAVt0, profile = read_first_band(UAV_T0_FILE)
Planett0, _ = read_first_band(PLANET_T0_FILE)
Planett1, _ = read_first_band(PLANET_T1_FILE)

# The patches of all three images are sliced with the same pixel indices,
# so the Planet scenes must already be resampled onto the UAV grid.
for label, image in (('Planet t0', Planett0), ('Planet t1', Planett1)):
    if image.shape != UAVt0.shape:
        raise ValueError(
            f'{label} image has shape {image.shape}, but the UAV image has shape '
            f'{UAVt0.shape}; resample the Planet scenes to the UAV grid first.')

# The prediction is computed in strips of sizeSlices rows. A remainder strip
# would be silently dropped, and the prediction would no longer match the
# output raster size, so fail early instead.
n_rows, n_cols = UAVt0.shape
if n_rows == 0 or n_rows % sizeSlices:
    raise ValueError(
        f'The image height ({n_rows} rows) must be a positive multiple of '
        f'sizeSlices ({sizeSlices}); adjust sizeSlices in parameters.py.')

# Flatten and store the moving-window patches
stp.partition(UAVt0, path_fineRes_t0)
stp.partition(Planett0, path_coarseRes_t0)
stp.partition(Planett1, path_coarseRes_t1)
print("Done partitioning!")

# Stack the moving-window patches as dask arrays
S2_t0 = stp.da_stack(path_fineRes_t0, UAVt0.shape)
S3_t0 = stp.da_stack(path_coarseRes_t0, Planett0.shape)
S3_t1 = stp.da_stack(path_coarseRes_t1, Planett1.shape)
print("Done stacking!")

# Perform the prediction with STARFM, one strip at a time: a strip of
# sizeSlices rows spans sizeSlices * n_cols consecutive flattened pixels.
shape = (sizeSlices, n_cols)
pixels_per_slice = sizeSlices * n_cols
slice_predictions = []
for first_px in range(0, UAVt0.size, pixels_per_slice):
    px = slice(first_px, first_px + pixels_per_slice)
    slice_predictions.append(
        stp.starfm(S2_t0[px], S3_t0[px], S3_t1[px], profile, shape))

# Join the strips once (np.append inside the loop re-copies the array each time)
predictions = np.concatenate(slice_predictions, axis=0)

# Write the results to a .tif file
print('Writing product...')
profile.update(dtype='float64', count=1)  # single float64 prediction band
# An RGB photometric tag copied from the 3-band UAV profile does not apply to a 1-band output
profile.pop('photometric', None)
os.makedirs(path, exist_ok=True)
file_name = os.path.join(path, 'prediction.tif')
with rasterio.open(file_name, 'w', **profile) as result:
    result.write(predictions, 1)

end = time.perf_counter()
print("Done in", (end - start) / 60.0, "minutes!")

# Display the inputs next to the fused prediction
panels = [
    ('UAV t0 (fine)', UAVt0),
    ('Planet t0 (coarse)', Planett0),
    ('Planet t1 (coarse)', Planett1),
    ('STARFM prediction t1', predictions),
]
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
for ax, (title, image) in zip(axes.flat, panels):
    ax.imshow(image, cmap='gray')
    ax.set_title(title)
    ax.axis('off')
fig.tight_layout()
plt.show()