# Notebook for running the model

# 1 - Access to Data


To access the **PlanetScope** satellite data, we use an academic account, which is free but limited. Students and researchers can apply for access by following [this link](https://www.planet.com/industries/education-and-research/). It is important to review the access limitations that apply to your user account.

A complete guide to download the satellite images using **QGIS** is available [here](https://github.com/domwelsh/quercus-detection/blob/main/documentation/satellite_images_download/satellite_images_download_doc.md).


# 2 - Data Preparation

### Data Preparation Requirements

To prepare the data for processing, follow these steps after downloading the satellite images:

1. **Upload Satellite Images to a Google Cloud Bucket**  
   - You can find a detailed description of the upload process in [this guide](https://github.com/domwelsh/quercus-detection/blob/main/documentation/indexes_calculation/indexes_calculation_doc.md).  
   - Alternatively, you can upload your images manually by following the instructions in [this video tutorial](https://www.youtube.com/watch?v=4MnQnYMTEU0&ab_channel=Codible).

2. **Transfer Images from Google Cloud to a GEE Asset**  
   - Use a notebook to facilitate this process (you can also see an example here: [example](https://github.com/domwelsh/quercus-detection/blob/main/notebooks/Index_calculation_planet_scope/1.Upload_images_from_GC_to_GEE.ipynb)). Ensure your images are properly organized within Google Cloud before transferring them.

3. **Execute a Notebook for GEE Integration**  
   - This will enable you to transfer the images to Google Drive and run the notebook using Google Colab.

By following these steps, your satellite images will be ready for further analysis and processing.


### 2.1 Code to Transfer Images from Google Cloud to a GEE Asset

NOTE:

To do this you need access to your Google Cloud credentials. To obtain the service-account key file referenced by `service_account_path`, see [Access here](https://developers.google.com/workspace/guides/create-credentials?hl=es-419).

In [None]:
%pip install rasterio earthengine-api google-cloud-storage google-auth google-auth-oauthlib

In [None]:
from google.oauth2 import service_account
from google.cloud import storage

# Replace with the path to your service-account JSON key file
service_account_path = 'C:/Users/your-user/ee-your-user-02041da0749c.json'

# Build Google credentials from the key file
credentials = service_account.Credentials.from_service_account_file(service_account_path)

# Create the Cloud Storage client with those credentials
storage_client = storage.Client(credentials=credentials)

# Fetch the bucket listing and print one bucket name per line; if the names
# appear, the key file and the client are working
buckets = list(storage_client.list_buckets())
for bucket in buckets:
    print(bucket.name)


# NOTE: for this you need access to your Google Cloud credentials

In [None]:
import ee
# Run the Earth Engine authentication flow first, then initialise the client
# library; the order matters because initialisation relies on the credentials.
ee.Authenticate()
ee.Initialize()

In [None]:
import os
import re
import pandas as pd
from google.oauth2 import service_account
from google.cloud import storage
import ee
from datetime import datetime

# Configure credentials file and parameters (replace every placeholder below)
# Path of the service-account JSON key used to create the Cloud Storage client
service_account_path = 'C:/Users/user-name/ee-your-user-02041da0749c.json'
# Bucket that is listed and that appears in the gs:// URI of every image
BUCKET_NAME = 'name_of_your_bucket'
# Prefix used to restrict the blob listing inside the bucket
BASE_PATH = 'Folder_with_your_images_of_planet'
# Prepended to each asset name to build the full Earth Engine asset id
COLLECTION = 'projects/your_asset_in_GEE'
# Band ids, assigned in this order to tileset band indices 0..7 of each image
BAND_NAMES = ['coastal_blue', 'blue', 'green_i', 'green', 'yellow', 'red', 'rededge', 'nir']
EXCEL_OUTPUT = 'gee_output.xlsx'  # Output file with a list of the uploaded files

def initialize_ee():
    """
    Initialize the Earth Engine client and report the outcome on stdout.

    Raises
    ------
    Exception
        Any error raised while initializing is printed and then re-raised
        unchanged, so the caller still sees the original exception.
    """
    try:
        ee.Initialize()
        print("Google Earth Engine initialized.")
    except Exception as init_error:
        print(f"Error initializing Earth Engine: {init_error}")
        raise

def initialize_storage_client():
    """
    Create a Cloud Storage client from the configured service-account key.

    Returns
    -------
    google.cloud.storage.Client
        Client built with the credentials loaded from ``service_account_path``.
    """
    return storage.Client(
        credentials=service_account.Credentials.from_service_account_file(service_account_path)
    )

def list_files(bucket_name, base_path):
    """
    List the composite GeoTIFFs stored under a prefix of a bucket.

    Parameters
    ----------
    bucket_name : str
        Name of the bucket to list.
    base_path : str
        Prefix that restricts the listing.

    Returns
    -------
    list of str
        Blob names ending in ``YYYY-MM-DD_strip_<digits>_composite.tif``.
    """
    print("Listing files in the bucket...")
    client = initialize_storage_client()
    # The pattern is anchored at the end only, so any folder prefix is accepted
    composite_name = re.compile(r'\d{4}-\d{2}-\d{2}_strip_\d+_composite\.tif$')
    files = []
    for blob in client.list_blobs(bucket_name, prefix=base_path):
        if composite_name.search(blob.name):
            files.append(blob.name)
    print("Files listed:", files)
    return files

def extract_date_from_filename(filename):
    """
    Return the first ``YYYY-MM-DD`` substring found in a file name.

    Parameters
    ----------
    filename : str
        File name (or path) to search.

    Returns
    -------
    str or None
        The matched date text, or None when the name contains no such pattern.
        Only the digit layout is checked, not whether the date is valid.
    """
    print(f"Extracting date from file: {filename}")
    found = re.search(r'(\d{4}-\d{2}-\d{2})', filename)  # Date in format YYYY-MM-DD
    if found is None:
        print(f"No date found in filename: {filename}")
        return None
    formatted_date = found.group(1)
    print(f"Date extracted: {formatted_date}")
    return formatted_date

def upload_file(uri_gcs, asset_name, formatted_date):
    """
    Submit an Earth Engine ingestion request for one Cloud Storage image.

    Parameters
    ----------
    uri_gcs : str
        ``gs://`` URI of the source image.
    asset_name : str
        Name of the asset, appended to ``COLLECTION`` to form the asset id.
    formatted_date : str
        Date as ``YYYY-MM-DD``; the asset is stamped from 00:00:00Z to
        23:59:59Z of that day.

    Returns
    -------
    bool
        True when the ingestion request was submitted without raising, False
        when any exception occurred (the error is printed, not re-raised).
    """
    print(f"Uploading file to GEE: {asset_name} with date {formatted_date}")

    # One band entry per configured name, in the order the bands are stored
    band_specs = []
    for position, band_id in enumerate(BAND_NAMES):
        band_specs.append({'id': band_id, 'tileset_band_index': position})

    manifest = {
        'name': f"{COLLECTION}/{asset_name}",
        'tilesets': [{'sources': [{'uris': [uri_gcs]}]}],
        'bands': band_specs,
        'startTime': f"{formatted_date}T00:00:00Z",
        'endTime': f"{formatted_date}T23:59:59Z",
    }

    try:
        ee.data.startIngestion(ee.data.newTaskId()[0], manifest)
        print(f"Uploaded: {asset_name}")
        return True
    except Exception as upload_error:
        print(f"Error uploading {asset_name}: {upload_error}")
        return False

def main():
    """
    Ingest every matching image of the bucket into Earth Engine and log it.

    Initializes Earth Engine, lists the composite images, submits one
    ingestion request per image whose name contains a date, and writes the
    successfully submitted files to ``EXCEL_OUTPUT``. An error on one file is
    printed and does not stop the remaining files.
    """
    initialize_ee()
    files = list_files(BUCKET_NAME, BASE_PATH)
    print("Files found:", files)

    uploaded_files = []  # Rows of the Excel log

    for file in files:
        try:
            file_name = file.rsplit('/', 1)[-1]
            print("Processing file:", file_name)
            formatted_date = extract_date_from_filename(file_name)
            print("Extracted date:", formatted_date)
            if not formatted_date:
                # Without a date the asset cannot be time-stamped; skip it
                continue
            # The asset name is the file name up to its first dot
            asset_name = file_name.partition('.')[0]
            if upload_file(f"gs://{BUCKET_NAME}/{file}", asset_name, formatted_date):
                uploaded_files.append({'File Name': file_name, 'Extracted Date': formatted_date})
        except Exception as e:
            print(f"Error processing file {file}: {e}")

    # Save the log to an Excel file
    pd.DataFrame(uploaded_files).to_excel(EXCEL_OUTPUT, index=False)
    print(f"Log saved to {EXCEL_OUTPUT}")

if __name__ == "__main__":
    main()

### 2.2 Execute a Notebook for GEE Integration: compute the vegetation indices and make your satellite images available in Google Drive




In [None]:
# geemap provides the interactive map widget; it is imported here because no
# other cell of the notebook imports it
import geemap

# Add a map to display the image and index
Map = geemap.Map(center=(39.22, -8.97), zoom=6)  # You can change the coordinates to your study area
Map

# Please draw a point for your study area on the displayed map

In [None]:
# Define the geometry of the study area from what was drawn on the map
geometry = Map.user_roi
# NOTE(review): presumably user_roi is None until something is drawn - confirm
# against the geemap version in use
if geometry is None:
    raise ValueError("No geometry found: draw a point on the map before running this cell.")
# Print explicitly: a notebook only displays the last expression of a cell, so
# a bare getInfo() call above an assignment would show nothing.
print(geometry.getInfo())  # Please copy the result into the next cell

In [None]:
# Paste here the coordinates shown by the previous cell
geometry = ee.Geometry.Point([-8.607292, 38.134876]) # This is an example

In [None]:
# Call the image collection (replace the placeholder with your own asset path)
collection = ee.ImageCollection("projects/your-user/assets/location_of_the_images_in_the_asset_of_gee")

# Filter by date (in this case the same month as the soil sampling) and keep
# only the images that intersect the study-area geometry.
# Replace both 'YYYY-MM-DD' placeholders with the real start and end dates.
collection = collection.filterDate('YYYY-MM-DD', 'YYYY-MM-DD').filterBounds(geometry)

# Mapping function that appends spectral-index bands to a PlanetScope image
def addIndices(img):
    """
    Add NDVI, NDWI and VARI bands to an image.

    Parameters
    ----------
    img : ee.Image
        Image exposing the bands 'nir', 'red', 'green' and 'blue'.

    Returns
    -------
    ee.Image
        The input image with three extra bands named 'NDVI', 'NDWI', 'VARI'.
    """
    nir = img.select('nir')
    red = img.select('red')
    green = img.select('green')
    blue = img.select('blue')

    # NDVI: normalized difference of near-infrared and red
    ndvi_band = img.normalizedDifference(['nir', 'red']).rename('NDVI')

    # NDWI (Normalized Difference Water Index)
    ndwi_band = img.expression(
        '(GREEN - NIR) / (GREEN + NIR)',
        {'NIR': nir, 'GREEN': green},
    ).rename('NDWI')

    # VARI (Visible Atmospherically Resistant Index)
    vari_band = img.expression(
        '(Green - Red) / (Green + Red - Blue)',
        {'Blue': blue, 'Red': red, 'Green': green},
    ).rename('VARI')

    return img.addBands([ndvi_band, ndwi_band, vari_band])

# Example usage: attach the index bands to every image of the collection,
# then average all images into a single composite
ps = collection.map(addIndices)
composite = ps.mean()

# True-colour display of the composite, stretched between 201 and 2464
rgb_vis = {'bands': ['red',  'green',  'blue'], 'min': 201, 'max': 2464}
Map.addLayer(composite, rgb_vis, 'RGB')

### 2.3 Export the monthly mosaics to Google Drive

In [None]:
# Define the date range for the collection (replace both placeholders; the
# end date must be later than the start date)
start_date = ee.Date('YYYY-MM-DD')
end_date = ee.Date('YYYY-MM-DD')

# Add the index bands to every image; addIndices is the mapping function
# defined in the previous section of this notebook
collection_with_indices = collection.map(addIndices)

# First day of each monthly window between start_date and end_date. A partial
# final month still gets its own window, hence the ceil().
n_months = end_date.difference(start_date, 'month').ceil()
months = ee.List.sequence(0, n_months.subtract(1)).map(
    lambda offset: start_date.advance(offset, 'month')
)

# Function to create a monthly mosaic
def average_by_month(start):
    """Return the mean of the images in the month starting at ``start``.

    The mosaic is stamped with the start of the month as 'system:time_start'.
    """
    start = ee.Date(start)
    end = start.advance(1, 'month')
    monthly_images = collection_with_indices.filterDate(start, end)
    mosaic = monthly_images.mean()
    return mosaic.set('system:time_start', start.millis())

# Create monthly mosaics
monthly_mosaics = ee.ImageCollection.fromImages(
    months.map(lambda m: ee.Image(average_by_month(m)))
)

# Resolve the size and build the server-side list once instead of on every
# loop iteration
n_mosaics = monthly_mosaics.size().getInfo()
mosaic_list = monthly_mosaics.toList(n_mosaics)

# Export each mosaic to Google Drive
for i in range(n_mosaics):
    img = ee.Image(mosaic_list.get(i))
    date_str = ee.Date(img.get('system:time_start')).format('YYYY-MM').getInfo()
    # Please replace 'Add_a_name' with your own name, e.g. 'Add_a_name_Monthly_YYYY-MM'.
    # The prefix carries no extension because the Drive export adds it.
    file_prefix = f"Add_a_name_{date_str}"
    file_name = f"{file_prefix}.tif"

    # Export all bands as a single image
    # NOTE(review): no region is passed to the export - confirm the default
    # footprint is acceptable for a mean composite
    task = ee.batch.Export.image.toDrive(
        image=img.toFloat(),
        description=f"Export_Add_a_name_{date_str}",
        folder='GEE_Exports',
        fileNamePrefix=file_prefix,
        scale=3,
        crs="EPSG:3763",
        maxPixels=1e13
    )
    task.start()
    print(f"Exporting {file_name} to Google Drive")

# 3 - Model

### 3.1 Define the Needed Functions


##### This section defines essential functions for preprocessing raster data before model training. It includes:

1. **identify_band_indices**: Finds the indices of specified spectral bands in a raster file.
2. **generate_sequences_with_band_indices**: Creates sequences of images by selecting and padding the identified bands.
3. **replace_nan_in_images**: Replaces NaN values in images with -9999 to ensure data consistency.
4. **verify_band_statistics**: Checks the minimum, maximum, and mean values of each band to validate data quality.

In [None]:
#Define the needed functions
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import rasterio

# Step 1: Identify the indices of the necessary bands
def identify_band_indices(file_path, target_band_names):
    """
    Look up the 1-based positions of named bands in a raster file.

    Parameters
    ----------
    file_path : str
        Path of the raster whose band descriptions are inspected.
    target_band_names : list of str
        Band descriptions to locate.

    Returns
    -------
    list of int
        1-based band index for each requested name, in the requested order.

    Raises
    ------
    ValueError
        If a requested band is not among the raster's band descriptions.
    """
    with rasterio.open(file_path) as src:
        band_names = list(src.descriptions)

    # Remember the first position of every description (band numbers start at 1)
    first_position = {}
    for position, description in enumerate(band_names, start=1):
        first_position.setdefault(description, position)

    band_indices = [first_position.get(band) for band in target_band_names]
    for band, index in zip(target_band_names, band_indices):
        if index is None:
            raise ValueError(f"The band {band} was not found in the file.")
        print(f"Band {band}: Index {index}")
    return band_indices

# Step 2: Generate sequences using band indices
def _list_tiff_paths(folder):
    """Return the sorted paths of the .tif/.tiff files directly inside folder."""
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.endswith(('.tif', '.tiff'))
    )


def generate_sequences_with_band_indices(folders, band_indices, sequence_length=13, fill_value=-9999):
    """
    Group the rasters of each folder into fixed-length, equally sized sequences.

    The files of a folder are sorted by name and every run of
    ``sequence_length`` consecutive files forms one sequence. Each image is
    padded at the bottom and right with ``fill_value`` up to the largest
    height and width found across all folders.

    Parameters
    ----------
    folders : iterable of str
        Folders containing the .tif/.tiff images.
    band_indices : list of int
        1-based indices of the bands to read from every image.
    sequence_length : int, optional
        Number of images per sequence. Default is 13.
    fill_value : int or float, optional
        Value written into the padded area. Default is -9999. It must be
        representable in the dtype of the rasters.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_sequences, sequence_length, n_bands, max_height,
        max_width); an empty array when no image is found.

    Raises
    ------
    ValueError
        If ``sequence_length`` is not positive or the number of images in a
        folder is not a multiple of ``sequence_length``.
    """
    if sequence_length < 1:
        raise ValueError("sequence_length must be a positive integer.")

    # List every folder once; the same listing serves both passes below
    paths_per_folder = [_list_tiff_paths(folder) for folder in folders]

    # Fail before opening any raster when a folder cannot be split evenly
    for image_paths in paths_per_folder:
        if len(image_paths) % sequence_length != 0:
            raise ValueError(f"The number of images is not a multiple of {sequence_length}.")

    # Calculate the maximum dimensions
    max_height, max_width = 0, 0
    for image_paths in paths_per_folder:
        for image_path in image_paths:
            with rasterio.open(image_path) as src:
                max_height = max(max_height, src.height)
                max_width = max(max_width, src.width)

    # Process and group images into sequences
    sequences = []
    for image_paths in paths_per_folder:
        for first in range(0, len(image_paths), sequence_length):
            tile_images = []
            for image_path in image_paths[first:first + sequence_length]:
                with rasterio.open(image_path) as src:
                    # Read the bands selected by index
                    selected_bands = src.read(band_indices)
                padded_image = np.full(
                    (selected_bands.shape[0], max_height, max_width),
                    fill_value, dtype=selected_bands.dtype
                )
                # The data occupies the top-left corner; the rest stays padding
                padded_image[:, :selected_bands.shape[1], :selected_bands.shape[2]] = selected_bands
                tile_images.append(padded_image)

            sequences.append(np.stack(tile_images, axis=0))

    return np.array(sequences)

# Step 3: Replace NaN with -9999 in the images.
def replace_nan_in_images(folder):
    """
    Write NaN-free copies of a folder's rasters into a 'processed' subfolder.

    Parameters
    ----------
    folder : str
        Folder containing the .tif/.tiff images.

    Returns
    -------
    str
        Path of the 'processed' subfolder holding the rewritten images.
    """
    processed_folder = os.path.join(folder, "processed")
    os.makedirs(processed_folder, exist_ok=True)

    for file in sorted(os.listdir(folder)):
        if not file.endswith(('.tif', '.tiff')):
            continue
        input_path = os.path.join(folder, file)
        output_path = os.path.join(processed_folder, file)

        with rasterio.open(input_path) as src:
            # The output reuses the profile of the source raster
            profile = src.profile
            # NOTE(review): nan_to_num also converts +/-inf to very large
            # finite values; only NaN becomes -9999. Confirm this is intended.
            data = np.nan_to_num(src.read(), nan=-9999)
            # Save the processed image
            with rasterio.open(output_path, 'w', **profile) as dst:
                dst.write(data)
        print(f"Processed {file} and saved to {output_path}")
    return processed_folder

# Step 4: Verify the statistics of the bands
def verify_band_statistics(folder, band_indices):
    """
    Print the minimum, maximum and mean of selected bands for every raster.

    Parameters
    ----------
    folder : str
        Folder containing the .tif/.tiff images, visited in sorted name order.
    band_indices : list of int
        1-based indices of the bands to summarize.
    """
    tiff_names = [
        name for name in sorted(os.listdir(folder))
        if name.endswith(('.tif', '.tiff'))
    ]
    for name in tiff_names:
        with rasterio.open(os.path.join(folder, name)) as src:
            for idx in band_indices:
                band_data = src.read(idx)
                # Statistics cover every pixel of the band as stored in the file
                lowest, highest, average = np.min(band_data), np.max(band_data), np.mean(band_data)
                print(f"Band {idx} - Min: {lowest}, Max: {highest}, Mean: {average}")

### 3.2 Define the model

##### **This section defines the spatio-temporal CNN model and its training pipeline.**

The ModelConfig dataclass sets the model's hyperparameters. The SpatioTemporalModel class builds the CNN architecture for processing raster sequences. The CorkOakTrainer class manages the training and evaluation processes, including loss computation and accuracy measurement.

In [None]:
#Define the model
from dataclasses import dataclass
import torch.nn as nn
import torch

# Model Configuration and Definition
@dataclass
class ModelConfig:
    """
    Configuration class for the SpatioTemporalModel.

    Attributes
    ----------
    n_bands : int, optional
        Number of input spectral bands. Default is 17.
    cnn_channels : list of int, optional
        List of output channels for the CNN layers. Default is [32, 64, 128].
    cnn_kernel_size : int, optional
        Kernel size for the CNN convolutions. Default is 3.
    cnn_dropout : float, optional
        Dropout rate for the CNN layers. Default is 0.1.
    learning_rate : float, optional
        Learning rate for the optimizer. Default is 0.001.
    """
    n_bands: int = 17
    # None is a placeholder: __post_init__ swaps it for a fresh list so that
    # instances never share one mutable default
    cnn_channels: list = None
    cnn_kernel_size: int = 3
    cnn_dropout: float = 0.1
    learning_rate: float = 0.001

    def __post_init__(self):
        """
        Initializes default values for cnn_channels if not provided.
        """
        if self.cnn_channels is None:
            self.cnn_channels = [32, 64, 128]


class SpatioTemporalModel(nn.Module):
    """
    Spatio-temporal CNN model for processing raster sequences.

    The same CNN is applied to every timestep of a sequence, the resulting
    feature maps are averaged over time, and a 1x1 convolution turns the
    averaged features into two per-pixel class scores.

    Parameters
    ----------
    config : ModelConfig
        Configuration object containing model hyperparameters.

    Methods
    -------
    forward(x)
        Perform a forward pass of the model.
    """

    def __init__(self, config: ModelConfig):
        """
        Initializes the CNN layers and pixel classifier based on the configuration.

        Parameters
        ----------
        config : ModelConfig
            Configuration object containing model hyperparameters.
        """
        super().__init__()
        self.config = config

        # Build CNN layers dynamically: one Conv-BatchNorm-ReLU-Dropout group
        # per entry of cnn_channels
        cnn_layers = []
        in_channels = config.n_bands
        for out_channels in config.cnn_channels:
            cnn_layers.extend([
                # padding = kernel_size // 2 keeps height and width unchanged
                # for odd kernel sizes
                nn.Conv2d(in_channels, out_channels, kernel_size=config.cnn_kernel_size, padding=config.cnn_kernel_size // 2),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.Dropout2d(config.cnn_dropout)
            ])
            in_channels = out_channels
        self.cnn = nn.Sequential(*cnn_layers)

        # Pixel classifier layer. Its input width is the tracked channel count,
        # which equals cnn_channels[-1] and falls back to n_bands when
        # cnn_channels is empty (the CNN is then an identity).
        self.pixel_classifier = nn.Conv2d(in_channels=in_channels, out_channels=2, kernel_size=1)

    def forward(self, x):
        """
        Forward pass of the model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (batch_size, timesteps, bands, height, width).

        Returns
        -------
        torch.Tensor
            Output tensor with shape (batch_size, 2, height, width), where 2 is the number of classes.

        Raises
        ------
        ValueError
            If ``x`` does not have exactly five dimensions.
        """
        if x.dim() != 5:
            raise ValueError(
                "Expected input of shape (batch_size, timesteps, bands, height, width), "
                f"got {tuple(x.shape)}."
            )
        timesteps = x.shape[1]
        # Run the shared CNN on each timestep separately, then average over time
        per_step_features = [self.cnn(x[:, t]) for t in range(timesteps)]
        temporal_mean = torch.stack(per_step_features, dim=1).mean(dim=1)
        return self.pixel_classifier(temporal_mean)


class CorkOakTrainer:
    """
    Training and evaluation helper built around a SpatioTemporalModel.

    Parameters
    ----------
    config : ModelConfig
        Configuration object containing model hyperparameters.
    device : str, optional
        Device to run the model on ('cuda' or 'cpu'). When omitted, 'cuda' is
        used if available, otherwise 'cpu'.

    Methods
    -------
    train_epoch(train_loader)
        Run one optimization pass over a training loader.
    evaluate(val_loader)
        Compute loss and pixel accuracy over a validation loader.
    """

    def __init__(self, config: ModelConfig, device=None):
        """
        Build the model, the Adam optimizer and the cross-entropy loss.

        Parameters
        ----------
        config : ModelConfig
            Configuration object containing model hyperparameters.
        device : str, optional
            Device to run the model on ('cuda' or 'cpu'). When omitted, 'cuda'
            is used if available, otherwise 'cpu'.
        """
        self.config = config
        if device:
            self.device = device
        else:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = SpatioTemporalModel(config).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=config.learning_rate)
        self.criterion = nn.CrossEntropyLoss()

    def train_epoch(self, train_loader):
        """
        Run one training pass over the loader.

        Parameters
        ----------
        train_loader : torch.utils.data.DataLoader
            Loader yielding (data, target) batches; dimension 1 of the target
            is squeezed away before the loss is computed.

        Returns
        -------
        float
            Sum of the batch losses divided by the number of batches.
        """
        self.model.train()
        batch_losses = []
        for data, target in train_loader:
            inputs = data.to(self.device)
            labels = target.squeeze(1).to(self.device)
            # Clear the gradients of the previous batch before the new pass
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(inputs), labels)
            loss.backward()
            self.optimizer.step()
            batch_losses.append(loss.item())
        return sum(batch_losses) / len(train_loader)

    def evaluate(self, val_loader):
        """
        Measure loss and accuracy on a validation loader without gradients.

        Parameters
        ----------
        val_loader : torch.utils.data.DataLoader
            Loader yielding (data, target) batches; dimension 1 of the target
            is squeezed away before the loss is computed.

        Returns
        -------
        tuple
            A tuple containing:
            - float: Sum of the batch losses divided by the number of batches.
            - float: Fraction of target elements predicted correctly (0 to 1).
        """
        self.model.eval()
        batch_losses = []
        n_correct = 0
        n_elements = 0
        with torch.no_grad():
            for data, target in val_loader:
                inputs = data.to(self.device)
                labels = target.squeeze(1).to(self.device)
                scores = self.model(inputs)
                batch_losses.append(self.criterion(scores, labels).item())
                # The predicted class is the one with the highest score
                predicted = scores.argmax(dim=1)
                n_correct += (predicted == labels).sum().item()
                n_elements += labels.numel()
        return sum(batch_losses) / len(val_loader), n_correct / n_elements

##### This code initializes the model configuration with the specified number of spectral bands and creates an instance of CorkOakTrainer. It then loads the pre-trained model weights from the designated .pth file and sets the model to evaluation mode. This prepares the model for making predictions on validation or test datasets.

In [None]:
# Load the model
print("\nLoading the model...")
# n_bands must match the number of bands selected when building the input sequences
config = ModelConfig(n_bands=4)
trainer = CorkOakTrainer(config)
# Saved model can be downloaded at https://github.com/domwelsh/quercus-detection/tree/main/trained_model
# map_location moves the stored tensors onto the trainer's device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only runtime.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
state_dict = torch.load('/content/drive/MyDrive/Your_Folder/model.pth', map_location=trainer.device)
trainer.model.load_state_dict(state_dict)
# Evaluation mode: dropout is disabled and batch-norm uses its stored statistics
trainer.model.eval()

### 3.3 Create input sequences based on satellite images


In [None]:
# Configuration of the workflow (replace the placeholder paths with your own)
# Folder holding the monthly mosaics exported to Google Drive
validation_folder = "/content/drive/MyDrive/Your_Folder"
# Select a file from the validation_folder; its band descriptions are used to
# locate the bands below
reference_raster_path = '/content/drive/MyDrive/Your_Folder/Add_a_name_Monthly_YYYY-MM.tif'
# Bands fed to the model, in this order
band_names = ['NDWI', 'VARI', 'NDVI', 'rededge']

# Identify the indices of the required bands
print("Identifying band indices...")
band_indices = identify_band_indices(reference_raster_path, band_names)

# Replace NaN with -9999 in the validation images; the rewritten copies are
# saved in a 'processed' subfolder whose path is returned
print("Replacing NaN with -9999...")
processed_validation_folder = replace_nan_in_images(validation_folder)

# Check statistics after replacing NaN
print("\nChecking band statistics...")
verify_band_statistics(processed_validation_folder, band_indices)

# Generate sequences with the band indices, reading from the processed copies
print("\nGenerating sequences...")
validation_sequences = generate_sequences_with_band_indices([processed_validation_folder], band_indices)
print(f"Shape of the generated sequences: {validation_sequences.shape}")



### 3.4 Run the model


##### **Apply the Model**
This code converts the validation sequences into a PyTorch tensor and feeds them into the trained model. It performs inference without computing gradients, generating the raw output predictions. The predictions are then processed to obtain the final class labels as NumPy arrays for further analysis.

In [None]:
# Run the trained model on the validation sequences
print("\nApplying the model...")
validation_tensor = torch.tensor(validation_sequences, dtype=torch.float32)
# Inference only: no gradients are needed for the forward pass
with torch.no_grad():
    test_output = trainer.model(validation_tensor.to(trainer.device))
# Highest-scoring class per pixel, brought back to the CPU as a NumPy array
test_predictions = test_output.argmax(dim=1).cpu().numpy()

### 3.5 Output Calculation and Visualization

##### **Validation and Visualization of Results**
This code removes padding values (-9999) from the model's predictions, saves the filtered predictions as a GeoTIFF file, and visualizes both the original data and the model's predictions side by side. It ensures that only valid prediction data is retained, facilitates spatial analysis by saving predictions in GeoTIFF format, and provides visual confirmation of the model's performance through plotted images.

In [None]:
# Exclude padding values (-9999): a pixel counts as valid when the first band
# of the first image of the first sequence is not the padding value
valid_mask = (validation_sequences[0, 0, 0] != -9999)
# Keep the prediction on valid pixels and write -9999 everywhere else
filtered_predictions = np.where(valid_mask, test_predictions[0], -9999).astype(np.int32)

# Save predictions as GeoTIFF
def save_predictions_as_geotiff(predictions, reference_raster_path, output_path):
    """
    Write a 2-D prediction map as a single-band int32 GeoTIFF.

    The georeferencing is copied from the reference raster. Predictions that
    are larger than the reference are cropped to its size from the top-left
    corner, which is where the image data sits when the padding was added at
    the bottom and right.

    Parameters
    ----------
    predictions : array-like
        2-D array of class labels, at least as large as the reference raster.
    reference_raster_path : str
        Raster whose profile (CRS, transform, size) is reused for the output.
    output_path : str
        Path of the GeoTIFF to create.

    Raises
    ------
    ValueError
        If ``predictions`` is not 2-D or is smaller than the reference raster.
    """
    predictions = np.asarray(predictions)
    if predictions.ndim != 2:
        raise ValueError(f"predictions must be 2-D, got shape {predictions.shape}.")

    with rasterio.open(reference_raster_path) as ref_raster:
        profile = ref_raster.profile
        ref_height, ref_width = ref_raster.height, ref_raster.width

    if predictions.shape[0] < ref_height or predictions.shape[1] < ref_width:
        raise ValueError(
            f"predictions of shape {predictions.shape} are smaller than the "
            f"reference raster ({ref_height}, {ref_width})."
        )

    # The written array must match the height and width stored in the profile
    cropped = predictions[:ref_height, :ref_width].astype(np.int32)
    profile.update(dtype=rasterio.int32, count=1, nodata=-9999)

    with rasterio.open(output_path, 'w', **profile) as dst:
        dst.write(cropped, 1)

# Destination of the prediction GeoTIFF (replace the placeholder folder)
# NOTE(review): if this is the same folder the input images are read from,
# the saved file would be listed as an input on a re-run - confirm
output_raster_path = '/content/drive/MyDrive/Your_Folder/predictions.tif'
# Georeferencing is taken from the reference raster selected earlier
save_predictions_as_geotiff(filtered_predictions, reference_raster_path, output_raster_path)
print(f"Predictions saved to: {output_raster_path}")

# Visualize the predictions
# Pixels equal to -9999 are masked so they are left blank instead of stretching
# the colour scale (left panel) or being drawn as class 0 (right panel)
original_band = np.ma.masked_equal(validation_sequences[0, 0, 0], -9999)
prediction_map = np.ma.masked_equal(filtered_predictions, -9999)

plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
# First band of the first image of the first sequence
plt.imshow(original_band, cmap="gray")
plt.title("Original Data")
plt.colorbar()

plt.subplot(1, 2, 2)
plt.imshow(prediction_map, cmap="viridis", vmin=0, vmax=1)
plt.title("Model Predictions")
plt.colorbar()
plt.tight_layout()
plt.show()