# In [None]:
# ===== Build, train and apply the STDFM-CNN =====
import numpy as np
import rasterio
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, ReLU, Add, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import pickle

# ----------------------------
# Input/output locations for the training run.
# Rasters are listed in acquisition order: t1, t2 (the fit target), t3.
# ----------------------------
base_path = r'D:\hagh\google drive/'
high_res_paths = [
    f'{base_path}LANDSATBCN/LST_Landsat9_{stamp}.tif'
    for stamp in ('20240203', '20240219', '20240322')
]
low_res_paths = [
    f'{base_path}MODISBCN/LST_MODIS_Aqua_{stamp}.tif'
    for stamp in ('2024_02_03', '2024_02_19', '2024_03_22')
]
output_path = f'{base_path}OUTPUTLST2024/STDFM_OUTPUT20240219.tif'
model_path = f'{base_path}OUTPUTLST2024/STDFM_model.pkl'

# ----------------------------
# Load raster
# ----------------------------
def load_raster(file_path):
    """Read band 1 of a raster as a float32 array.

    The cast matters because the caller subtracts rasters from each other.
    If a file stores an integer dtype, ``a - b`` would wrap around for
    unsigned types or overflow for narrow signed types. The prediction is
    also written out as float32.

    Args:
        file_path: Path to a raster readable by rasterio.

    Returns:
        2-D float32 array with the contents of band 1. Nodata pixels are
        returned as stored in the file (they are not masked).
    """
    with rasterio.open(file_path) as src:
        # copy=False avoids an extra copy when the band is already float32.
        return src.read(1).astype('float32', copy=False)

high_res = list(map(load_raster, high_res_paths))
low_res = list(map(load_raster, low_res_paths))

# Temporal change between the first (index 0) and third (index 2) dates,
# computed separately for the fine and the coarse sensor.
high_res_diff = np.subtract(high_res[0], high_res[2])
low_res_diff = np.subtract(low_res[0], low_res[2])

# Reshape for CNN input (samples, height, width, channels)
def reshape(x):
    """Add a leading batch axis and a trailing channel axis, e.g. (H, W) -> (1, H, W, 1)."""
    return np.expand_dims(x, axis=(0, -1))

# Give every raster the batch and channel axes the CNN expects.
high_res = list(map(reshape, high_res))
low_res = list(map(reshape, low_res))
high_res_diff, low_res_diff = reshape(high_res_diff), reshape(low_res_diff)

# ----------------------------
# Build STDFM-CNN
# ----------------------------
def conv_block(x, filters=64, kernel_size=(3,3)):
    """Apply Conv2D ('same' padding) -> BatchNormalization -> ReLU to ``x``.

    Args:
        x: Keras tensor to transform.
        filters: Number of convolution filters (output channels).
        kernel_size: Convolution window size.

    Returns:
        The activated Keras tensor; with the default stride of 1 and 'same'
        padding its spatial size equals that of ``x``.
    """
    stack = (
        Conv2D(filters, kernel_size, padding='same'),
        BatchNormalization(),
        ReLU(),
    )
    for layer in stack:
        x = layer(x)
    return x

def residual_block(x, filters=None):
    """Two-convolution residual unit: ``x + BN(Conv(ReLU(BN(Conv(x)))))``.

    The shortcut is an identity ``Add``, so the convolutions must output
    exactly as many channels as ``x`` has.

    Args:
        x: Keras tensor; assumes channels-last layout (the Keras default).
        filters: Convolution filter count. Defaults to the channel count of
            ``x`` (its last axis), which is what the identity shortcut needs.
            For the 64-channel tensor used in this file that is 64.

    Returns:
        Keras tensor with the same shape as ``x``.
    """
    if filters is None:
        filters = x.shape[-1]
    res = Conv2D(filters, (3,3), padding='same')(x)
    res = BatchNormalization()(res)
    res = ReLU()(res)
    res = Conv2D(filters, (3,3), padding='same')(res)
    res = BatchNormalization()(res)
    return Add()([x, res])

def build_stdfm(input_shape):
    """Build the five-input STDFM-CNN fusion model (uncompiled).

    Each input passes through its own branch of two ``conv_block`` stages.
    The branch outputs are concatenated along the channel axis, refined by a
    128-filter then a 64-filter ``conv_block`` and a ``residual_block``, and
    mapped to one channel by a linear 1x1 convolution.

    Args:
        input_shape: Shape of one sample without the batch axis; shared by
            all five inputs.

    Returns:
        A ``Model`` whose inputs are, in order: Low_Res_Input2,
        Low_Res_Input3, High_Res_Input1, High_Res_Diff_Input,
        Low_Res_Diff_Input.
    """
    input_names = (
        "Low_Res_Input2",
        "Low_Res_Input3",
        "High_Res_Input1",
        "High_Res_Diff_Input",
        "Low_Res_Diff_Input",
    )
    inputs = [Input(shape=input_shape, name=label) for label in input_names]

    # One independent two-stage feature-extraction branch per input.
    branch_features = [conv_block(conv_block(tensor)) for tensor in inputs]

    # Fuse the branches, then refine the fused features.
    fused = concatenate(branch_features)
    for width in (128, 64):
        fused = conv_block(fused, filters=width)
    fused = residual_block(fused)

    output = Conv2D(1, (1,1), activation='linear')(fused)
    return Model(inputs=inputs, outputs=output)

input_shape = high_res[0].shape[1:]
best_loss = np.inf
best_model = None
best_params = {}

# ----------------------------
# Grid Search for hyperparameters
# ----------------------------
# NOTE(review): candidates are ranked by their final *training* loss on the
# single training sample (batch axis of size 1 from reshape). This picks the
# setting that fits the target date best, not one that generalizes; a
# held-out scene would be needed for a real validation score. With one
# sample, batch sizes 1 and 2 also produce identical batches.
learning_rates = [0.001, 0.0005]
batch_sizes = [1, 2]
epochs = 100  # Fixed as in article

# Model inputs in the order declared by build_stdfm; reused for prediction.
train_inputs = [low_res[1], low_res[2], high_res[0], high_res_diff, low_res_diff]

for lr in learning_rates:
    for bs in batch_sizes:
        model = build_stdfm(input_shape)
        model.compile(optimizer=Adam(learning_rate=lr), loss='mse', metrics=['mae'])
        history = model.fit(
            train_inputs,
            high_res[1],
            epochs=epochs,
            batch_size=bs,
            verbose=0
        )
        final_loss = history.history['loss'][-1]
        # A NaN loss (e.g. from NaN nodata pixels) compares False here, so
        # such a run is never selected.
        if final_loss < best_loss:
            best_loss = final_loss
            best_model = model
            best_params = {'learning_rate': lr, 'batch_size': bs}

print("Best Grid Search Params:", best_params, "with Loss:", best_loss)

if best_model is None:
    raise RuntimeError(
        "No grid-search run produced a finite training loss; "
        "check the input rasters for NaN/nodata values."
    )

# ----------------------------
# Save best model
# ----------------------------
# NOTE(review): pickling a Keras model depends on the installed Keras version
# supporting it (TODO confirm for the deployed version). model.save() to a
# .keras file is the portable format, but the prediction cell reads this
# .pkl, so the format is kept here.
with open(model_path, 'wb') as f:
    pickle.dump(best_model, f)

# Predict
predicted = best_model.predict(train_inputs)
predicted = predicted[0, ..., 0].astype('float32', copy=False)

# Copy the georeferencing of an input raster. The network preserves spatial
# size ('same' padding throughout), so the prediction lies on the same grid.
# Without crs/transform the GeoTIFF would carry no location at all.
with rasterio.open(high_res_paths[0]) as ref:
    ref_crs, ref_transform = ref.crs, ref.transform

# Save predicted raster
with rasterio.open(output_path, 'w', driver='GTiff',
                   height=predicted.shape[0], width=predicted.shape[1],
                   count=1, dtype='float32',
                   crs=ref_crs, transform=ref_transform) as dst:
    dst.write(predicted, 1)

print(f"Best model saved to {model_path} and predicted image saved to {output_path}")


# In [None]:
# ===== Predict with the trained STDFM-CNN =====
import numpy as np
import rasterio
import pickle
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, ReLU, Add, concatenate
from tensorflow.keras.models import Model

# ----------------------------
# Input/output locations for applying the trained model to the second scene
# (folders suffixed ACN). Landsat rasters are used for t1 and t3 only;
# t2 (2024-02-19) is the date being predicted.
# ----------------------------
base_path = r'D:\hagh\google drive/'
high_res1_path, high_res3_path = (
    f'{base_path}LANDSATACN/LST_Landsat9_{stamp}.tif'
    for stamp in ('20240203', '20240322')
)
low_res1_path, low_res2_path, low_res3_path = (
    f'{base_path}MODISACN/LST_MODIS_Aqua_{stamp}.tif'
    for stamp in ('2024_02_03', '2024_02_19', '2024_03_22')
)
output_path = f'{base_path}OUTPUTLST2024/STDFM_OutputA_predict_20240219.tif'
model_path = f'{base_path}OUTPUTLST2024/STDFM_model.pkl'

# ----------------------------
# Load raster
# ----------------------------
def load_raster(file_path):
    """Read band 1 of a raster as a float32 array.

    The cast matters because the caller subtracts rasters from each other.
    If a file stores an integer dtype, ``a - b`` would wrap around for
    unsigned types or overflow for narrow signed types. The prediction is
    also written out as float32.

    Args:
        file_path: Path to a raster readable by rasterio.

    Returns:
        2-D float32 array with the contents of band 1. Nodata pixels are
        returned as stored in the file (they are not masked).
    """
    with rasterio.open(file_path) as src:
        # copy=False avoids an extra copy when the band is already float32.
        return src.read(1).astype('float32', copy=False)

high_res1, high_res3 = map(load_raster, (high_res1_path, high_res3_path))
low_res1, low_res2, low_res3 = map(
    load_raster, (low_res1_path, low_res2_path, low_res3_path)
)

# Temporal change between t1 and t3 for the fine and the coarse sensor.
high_res_diff = np.subtract(high_res1, high_res3)
low_res_diff = np.subtract(low_res1, low_res3)

# Reshape for CNN
def reshape(x):
    """Return ``x`` with a new leading batch axis and trailing channel axis, e.g. (H, W) -> (1, H, W, 1)."""
    return np.expand_dims(x, axis=(0, -1))

# Give every raster the batch and channel axes the CNN expects.
high_res1, high_res3 = reshape(high_res1), reshape(high_res3)
low_res1, low_res2, low_res3 = map(reshape, (low_res1, low_res2, low_res3))
high_res_diff, low_res_diff = map(reshape, (high_res_diff, low_res_diff))

# ----------------------------
# Rebuild STDFM architecture for prediction
# ----------------------------
def conv_block(x, filters=64, kernel_size=(3,3)):
    """Apply Conv2D ('same' padding) -> BatchNormalization -> ReLU to ``x``.

    Takes the same parameters as the training cell's ``conv_block``, so both
    architecture builders are written against one signature. ``kernel_size``
    defaults to the 3x3 window previously hard-coded here.

    Args:
        x: Keras tensor to transform.
        filters: Number of convolution filters (output channels).
        kernel_size: Convolution window size.

    Returns:
        The activated Keras tensor; with the default stride of 1 and 'same'
        padding its spatial size equals that of ``x``.
    """
    x = Conv2D(filters, kernel_size, padding='same')(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    return x

def residual_block(x, filters=None):
    """Two-convolution residual unit: ``x + BN(Conv(ReLU(BN(Conv(x)))))``.

    The shortcut is an identity ``Add``, so the convolutions must output
    exactly as many channels as ``x`` has.

    Args:
        x: Keras tensor; assumes channels-last layout (the Keras default).
        filters: Convolution filter count. Defaults to the channel count of
            ``x`` (its last axis), which is what the identity shortcut needs.
            For the 64-channel tensor used in this file that is 64.

    Returns:
        Keras tensor with the same shape as ``x``.
    """
    if filters is None:
        filters = x.shape[-1]
    res = Conv2D(filters, (3,3), padding='same')(x)
    res = BatchNormalization()(res)
    res = ReLU()(res)
    res = Conv2D(filters, (3,3), padding='same')(res)
    res = BatchNormalization()(res)
    return Add()([x, res])

def build_stdfm_predict(input_shape):
    """Rebuild the five-input STDFM-CNN architecture (uncompiled).

    The layer sequence mirrors the training cell's ``build_stdfm``. Each
    input gets a two-stage ``conv_block`` branch. The branches are
    concatenated along the channel axis, refined by a 128-filter then a
    64-filter ``conv_block`` and a ``residual_block``, and mapped to one
    channel by a linear 1x1 convolution.

    Args:
        input_shape: Shape of one sample without the batch axis; shared by
            all five inputs.

    Returns:
        A ``Model`` whose inputs are, in order: Low_Res_Input2,
        Low_Res_Input3, High_Res_Input1, High_Res_Diff_Input,
        Low_Res_Diff_Input.
    """
    input_names = (
        "Low_Res_Input2",
        "Low_Res_Input3",
        "High_Res_Input1",
        "High_Res_Diff_Input",
        "Low_Res_Diff_Input",
    )
    inputs = [Input(shape=input_shape, name=label) for label in input_names]

    # Independent feature-extraction branch per input.
    branch_features = [conv_block(conv_block(tensor)) for tensor in inputs]

    # Feature fusion followed by deeper refinement.
    fused = concatenate(branch_features)
    for width in (128, 64):
        fused = conv_block(fused, filters=width)
    fused = residual_block(fused)

    output = Conv2D(1, (1,1), activation='linear')(fused)
    return Model(inputs=inputs, outputs=output)

# Build the architecture for *this* scene's raster size. Every layer is a
# convolution, normalization, activation or merge, so the trained weights
# do not depend on the spatial size H x W.
input_shape = high_res1.shape[1:]
stdfm_model = build_stdfm_predict(input_shape)

# ----------------------------
# Load trained model and transfer its weights
# ----------------------------
# NOTE(review): pickle.load can execute arbitrary code. Only load a .pkl
# produced by the training cell, never one from an untrusted source.
with open(model_path, 'rb') as f:
    trained_model = pickle.load(f)

# Copy the trained weights into the model rebuilt for this scene's size, so
# prediction does not rely on this scene matching the training scene's
# dimensions. This presumes build_stdfm_predict stays layer-for-layer
# identical to the training builder (so the weight lists line up).
# set_weights raises if any weight shape disagrees.
stdfm_model.set_weights(trained_model.get_weights())

# ----------------------------
# Prediction
# ----------------------------
predicted = stdfm_model.predict([low_res2, low_res3, high_res1, high_res_diff, low_res_diff])
predicted = predicted[0, ..., 0].astype('float32', copy=False)

# Copy the georeferencing of an input raster. The network preserves spatial
# size ('same' padding throughout), so the prediction lies on the same grid.
# Without crs/transform the GeoTIFF would carry no location at all.
with rasterio.open(high_res1_path) as ref:
    ref_crs, ref_transform = ref.crs, ref.transform

# Save predicted raster
with rasterio.open(output_path, 'w', driver='GTiff',
                   height=predicted.shape[0], width=predicted.shape[1],
                   count=1, dtype='float32',
                   crs=ref_crs, transform=ref_transform) as dst:
    dst.write(predicted, 1)

print(f"Prediction done. Output saved at {output_path}")
