# Notebook for training the model

### Requirements

In [None]:
import os
import numpy as np
import rasterio
from rasterio.mask import mask
from rasterio import features
import geopandas as gpd
from shapely.geometry import box
from sklearn.preprocessing import StandardScaler
import joblib
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report


### Auxiliary Functions

##### **Here you will prepare, standardize, and analyze raster data for machine learning.**
The helpers below cover creating labeled masks, selecting and padding bands, generating time-series sequences, and building the corresponding label arrays.

It also performs standard scaling, calculates classification metrics, and offers visualization tools.

In [None]:
def create_labeled_mask(gdf, reference_raster, output_path, id_column='Species_id'):
    """
    Rasterize polygons onto the grid of a reference raster and save them as an int32 GeoTIFF.

    Every geometry in ``gdf`` is burned with the value of its ``id_column``
    attribute (``all_touched=True``: any pixel touched by a polygon is burned);
    pixels not covered by any polygon are 0. The output has the reference
    raster's height, width, transform and CRS, and a single band.

    Parameters
    ----------
    gdf : geopandas.GeoDataFrame
        Polygons to rasterize. When both ``gdf`` and ``reference_raster`` carry a
        CRS and they differ, the polygons are reprojected to the raster CRS first;
        otherwise they would be burned at the wrong location (or not at all).
    reference_raster : open rasterio dataset
        Supplies ``transform``, ``crs``, ``height`` and ``width``.
    output_path : str
        Destination GeoTIFF path (overwritten if it exists).
    id_column : str
        Attribute column whose values are burned into the mask.
    """
    transform = reference_raster.transform
    crs = reference_raster.crs
    out_shape = (reference_raster.height, reference_raster.width)

    # Bring the polygons into the raster's CRS so they land on the right pixels.
    if gdf.crs is not None and crs is not None and not gdf.crs.equals(crs.to_wkt()):
        gdf = gdf.to_crs(crs.to_wkt())

    shapes = ((geom, value) for geom, value in zip(gdf.geometry, gdf[id_column]))

    mask_array = features.rasterize(
        shapes=shapes,
        out_shape=out_shape,
        transform=transform,
        fill=0,
        all_touched=True,
        dtype='int32',
    )

    with rasterio.open(
        output_path,
        'w',
        driver='GTiff',
        height=mask_array.shape[0],
        width=mask_array.shape[1],
        count=1,
        dtype='int32',
        crs=crs,
        transform=transform,
    ) as dst:
        dst.write(mask_array, 1)


def save_masked_image(output_path, stacked_bands, transform, crs):
    """
    Write a band-first array (bands, rows, cols) to ``output_path`` as a float32 GeoTIFF.

    ``transform`` and ``crs`` georeference the output; the number of bands in
    the file equals ``stacked_bands.shape[0]``.
    """
    geotiff_profile = dict(
        driver='GTiff',
        height=stacked_bands.shape[1],
        width=stacked_bands.shape[2],
        count=stacked_bands.shape[0],
        dtype='float32',
        crs=crs,
        transform=transform,
    )
    with rasterio.open(output_path, 'w', **geotiff_profile) as dst:
        # A 3-D array with no explicit indexes is written band-by-band to bands 1..count.
        dst.write(stacked_bands)


def select_band_by_name(src, band_names, target_band_names):
    """
    Read the bands whose names appear in ``target_band_names`` from an open raster.

    Bands are read in the order of ``target_band_names``; names not present in
    ``band_names`` are skipped. Returns the array from ``src.read`` or, when
    no target band is found at all, prints a message and returns None.
    """
    positions = []
    for wanted in target_band_names:
        if wanted in band_names:
            # rasterio numbers bands from 1
            positions.append(band_names.index(wanted) + 1)

    if positions:
        return src.read(positions)

    print(f"None of the target bands found in {band_names}")
    return None


def pad_to_max_dimensions(image, max_height, max_width, fill_value=-9999):
    """
    Place a (bands, rows, cols) array in the top-left corner of a
    (bands, max_height, max_width) canvas filled with ``fill_value``.

    The canvas keeps ``image.dtype``. The image must not exceed the target size.
    """
    n_bands, rows, cols = image.shape
    canvas = np.full((n_bands, max_height, max_width), fill_value, dtype=image.dtype)
    canvas[:, :rows, :cols] = image  # real data top-left, fill at bottom/right
    return canvas


def generate_sequences_from_folders(folders, target_band_names, gdf, scaler=None, scaler_path=None,
                                    timesteps=13, apply_mask=True):
    """
    Build standardized time-series tiles from folders of rasters.

    Each folder must contain a multiple of ``timesteps`` GeoTIFFs (.tif/.tiff).
    After sorting the file names, every consecutive group of ``timesteps`` files
    forms one tile (e.g. 13 monthly mosaics). For every image the bands named in
    ``target_band_names`` are extracted *in that order*. If ``apply_mask`` is set,
    the image is masked with the polygons of ``gdf``. It is then cast to float32
    and padded with -9999 at the bottom/right to the largest height/width found
    across all folders. Finally, the data are standardized band-wise with a
    StandardScaler, using only pixels that have no -9999 in any band. Those
    -9999 pixels keep the value -9999 in the output.

    Parameters
    ----------
    folders : list of str
        Folders to read, processed in the given order.
    target_band_names : list of str
        Band descriptions to keep; the output band axis follows this order.
    gdf : geopandas.GeoDataFrame
        Polygons passed to ``rasterio.mask.mask`` when ``apply_mask`` is True.
    scaler : StandardScaler, optional
        Already-fitted scaler (e.g. the training one) to apply. If None, a new
        scaler is fitted on this data.
    scaler_path : str, optional
        If given and a new scaler is fitted, it is saved there with joblib.
    timesteps : int, default 13
        Number of images per tile.
    apply_mask : bool, default True
        Mask pixels outside the ``gdf`` polygons (original behaviour).
        NOTE(review): when ``gdf`` holds the label polygons, masking the inputs
        with them exposes the labels to the model, and the inference helpers in
        this notebook do not mask. Consider ``apply_mask=False`` for training.

    Returns
    -------
    (np.ndarray, StandardScaler)
        float32 array of shape (n_tiles, timesteps, n_bands, max_height, max_width)
        and the fitted or supplied scaler.

    Raises
    ------
    ValueError
        If a folder's image count is not a multiple of ``timesteps``, a target
        band is missing from an image, or no image is found at all.
    """
    def list_rasters(folder):
        # Sorted file names; presumably chronological when names embed YYYY-MM dates.
        return sorted(
            os.path.join(folder, f)
            for f in os.listdir(folder)
            if f.endswith(('.tif', '.tiff'))
        )

    # Pass 1: largest footprint over every raster, so all tiles share one shape.
    max_height, max_width = 0, 0
    for folder in folders:
        for image_path in list_rasters(folder):
            with rasterio.open(image_path) as src:
                max_height = max(max_height, src.height)
                max_width = max(max_width, src.width)

    # Pass 2: read, (optionally) mask, select bands in the requested order, pad.
    all_images = []
    for folder in folders:
        image_paths = list_rasters(folder)
        if len(image_paths) % timesteps != 0:
            raise ValueError(f"Number of images is not a multiple of {timesteps}.")

        for start in range(0, len(image_paths), timesteps):
            tile_images = []
            for image_path in image_paths[start:start + timesteps]:
                with rasterio.open(image_path) as src:
                    band_names = list(src.descriptions)
                    missing = [b for b in target_band_names if b not in band_names]
                    if missing:
                        raise ValueError(f"Bands {missing} not found in {band_names}.")
                    # 0-based positions in the caller's order (not the file's order), so the
                    # band axis matches target_band_names and identify_band_indices().
                    positions = [band_names.index(b) for b in target_band_names]

                    if apply_mask:
                        # rasterio.mask.mask fills pixels outside the polygons with the
                        # dataset nodata value (0 if none is set); it returns all bands.
                        out_image, _ = mask(src, gdf.geometry, crop=False, all_touched=True)
                        out_image = out_image[positions]
                    else:
                        out_image = src.read([p + 1 for p in positions])

                # float32 so the -9999 fill and the standardized values are representable
                # whatever the source dtype (integer rasters would overflow / truncate).
                tile_images.append(
                    pad_to_max_dimensions(out_image.astype(np.float32), max_height, max_width,
                                          fill_value=-9999)
                )
            all_images.append(np.stack(tile_images, axis=0))

    if not all_images:
        raise ValueError(f"No GeoTIFF images found in {folders}.")

    all_images = np.stack(all_images, axis=0)  # (n_tiles, timesteps, bands, H, W)
    n_tiles, n_steps, n_bands, height, width = all_images.shape

    # One row per (tile, timestep, pixel), one column per band, as StandardScaler expects.
    reshaped_images = all_images.transpose(0, 1, 3, 4, 2).reshape(-1, n_bands)

    is_training = scaler is None
    if is_training:
        scaler = StandardScaler()

    # Rows containing the -9999 fill in any band are excluded from fitting/transforming.
    valid_mask = (reshaped_images != -9999).all(axis=1)
    valid_data = reshaped_images[valid_mask]

    standardized_data = reshaped_images.copy()
    if is_training:
        standardized_data[valid_mask] = scaler.fit_transform(valid_data)
        if scaler_path:
            joblib.dump(scaler, scaler_path)
    else:
        standardized_data[valid_mask] = scaler.transform(valid_data)

    # Back to (n_tiles, timesteps, bands, H, W)
    standardized_images = (
        standardized_data.reshape(n_tiles, n_steps, height, width, n_bands)
        .transpose(0, 1, 4, 2, 3)
    )

    # Per-band sanity check (mean ~0 / std ~1 when the scaler was fitted here).
    for band_idx, band_name in enumerate(target_band_names):
        band = standardized_images[:, :, band_idx]
        valid_band_data = band[band != -9999]
        print(f"\nBand {band_idx} ({band_name}) stats:")
        print(f"Mean: {np.mean(valid_band_data):.8f}")
        print(f"Std:  {np.std(valid_band_data):.8f}")

    return standardized_images, scaler


def create_sequence_labels(mask_path, sequence_shape):
    """
    Create binary per-pixel labels for every sequence from a single mask file.

    Pixels with a mask value > 0 become 1 (tree); all others become 0 (non-tree).
    The images are padded at the bottom/right (``pad_to_max_dimensions`` keeps
    the data in the top-left corner). A mask smaller than the sequence grid is
    therefore zero-padded the same way, so labels stay aligned with pixels. Only
    a mask larger than the grid in some dimension is resized, with
    nearest-neighbour interpolation.

    Parameters
    ----------
    mask_path : str
        Raster whose band 1 is the label mask.
    sequence_shape : tuple
        (n_sequences, timesteps, bands, height, width), e.g. ``test_sequences.shape``.

    Returns
    -------
    np.ndarray of int32, shape (n_sequences, timesteps, height, width)
        The same mask repeated for every sequence and timestep.
    """
    import rasterio
    from skimage.transform import resize

    with rasterio.open(mask_path) as src:
        binary_mask = (src.read(1) > 0).astype(np.int32)

    n_sequences, timesteps, _, height, width = sequence_shape

    mask_height, mask_width = binary_mask.shape
    if (mask_height, mask_width) != (height, width):
        if mask_height <= height and mask_width <= width:
            print(f"Padding mask from {binary_mask.shape} to {(height, width)}")
            binary_mask = np.pad(binary_mask, ((0, height - mask_height), (0, width - mask_width)))
        else:
            print(f"Resizing mask from {binary_mask.shape} to {(height, width)}")
            binary_mask = resize(binary_mask, (height, width), order=0, preserve_range=True).astype(np.int32)

    # Materialized copy (not a broadcast view), like the original np.repeat result.
    return np.broadcast_to(binary_mask, (n_sequences, timesteps, height, width)).copy()


def create_multi_sequence_labels(mask_paths, sequence_shape):
    """
    Create binary labels for sequences coming from several folders.

    The i-th mask labels the i-th consecutive block of sequences, so pass the
    masks in the same order as the folders used to build the sequences. Every
    folder is assumed to contribute the same number of sequences. Pixels with a
    mask value > 0 become 1 (tree), others 0. The images are padded at the
    bottom/right (``pad_to_max_dimensions``), so a mask smaller than the grid is
    zero-padded the same way. Only a mask larger than the grid is resized, with
    nearest-neighbour interpolation.

    Parameters
    ----------
    mask_paths : list of str
        One label raster per folder (band 1 is read).
    sequence_shape : tuple
        (n_sequences, timesteps, bands, height, width), e.g. ``train_sequences.shape``.

    Returns
    -------
    np.ndarray of int32, shape (n_sequences, 1, height, width)

    Raises
    ------
    ValueError
        If ``mask_paths`` is empty or ``n_sequences`` is not divisible by its length.
    """
    import rasterio
    from skimage.transform import resize

    n_sequences, _, _, target_height, target_width = sequence_shape
    if not mask_paths:
        raise ValueError("At least one mask path is required.")
    if n_sequences % len(mask_paths) != 0:
        raise ValueError(
            f"{n_sequences} sequences cannot be split evenly across {len(mask_paths)} masks."
        )
    sequences_per_folder = n_sequences // len(mask_paths)

    all_labels = []
    for mask_path in mask_paths:
        with rasterio.open(mask_path) as src:
            binary_mask = (src.read(1) > 0).astype(np.int32)

        mask_height, mask_width = binary_mask.shape
        if (mask_height, mask_width) != (target_height, target_width):
            if mask_height <= target_height and mask_width <= target_width:
                print(f"Padding mask from {binary_mask.shape} to {(target_height, target_width)}")
                binary_mask = np.pad(
                    binary_mask, ((0, target_height - mask_height), (0, target_width - mask_width))
                )
            else:
                print(f"Resizing mask from {binary_mask.shape} to {(target_height, target_width)}")
                binary_mask = resize(binary_mask, (target_height, target_width),
                                     order=0, preserve_range=True).astype(np.int32)

        folder_labels = np.repeat(binary_mask[np.newaxis, np.newaxis, :, :], sequences_per_folder, axis=0)
        all_labels.append(folder_labels)

    return np.concatenate(all_labels, axis=0)


def calculate_statistics(true_labels, predicted_labels):
    """
    Compute and print binary (0 = Non-Tree, 1 = Tree) classification metrics.

    Both inputs are flattened, so arrays of any matching shape (or 1-D
    selections of valid pixels) are accepted. Prints the 2x2 confusion matrix
    (rows = reference class, columns = predicted class) and sklearn's
    classification report.

    Returns
    -------
    dict
        "Overall Accuracy": fraction of correctly classified pixels.
        "Producers Accuracy": per-class [0, 1] correct / reference pixels.
        "Users Accuracy": per-class [0, 1] correct / predicted pixels.
        A per-class value is NaN when that class never occurs in the reference
        (producer's) or is never predicted (user's).
    """
    true_labels = np.asarray(true_labels).flatten()
    predicted_labels = np.asarray(predicted_labels).flatten()

    cm = confusion_matrix(true_labels, predicted_labels, labels=[0, 1])
    print("Confusion Matrix:\n", cm)

    # Empty rows/columns produce NaN on purpose; silence the divide-by-zero warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        # Overall Accuracy
        oa = np.trace(cm) / np.sum(cm)
        # Producer's Accuracy (row sums = reference pixels per class)
        pa = cm.diagonal() / cm.sum(axis=1)
        # User's Accuracy (column sums = predicted pixels per class)
        pu = cm.diagonal() / cm.sum(axis=0)

    # labels=[0, 1] keeps the report aligned with both target names even when only
    # one class occurs; without it sklearn raises a ValueError in that case.
    report = classification_report(
        true_labels,
        predicted_labels,
        labels=[0, 1],
        target_names=["Non-Tree", "Tree"],
        zero_division=0,
    )
    print("\nClassification Report:\n", report)

    return {
        "Overall Accuracy": oa,
        "Producers Accuracy": pa,
        "Users Accuracy": pu
    }


def visualize_raster_predictions(raster_path, predictions, save_path=None):
    """
    Plot a 2-D prediction map and optionally save it as a uint8 GeoTIFF.

    The georeferencing (and other profile settings) come from ``raster_path``.
    ``predictions`` is cropped to that raster's height/width first, because the
    model's outputs are computed on inputs padded at the bottom/right to a common
    size. The saved file is single-band uint8 with no nodata value.

    Parameters
    ----------
    raster_path : str
        Raster whose profile/grid the predictions belong to.
    predictions : array-like, shape (H, W) with H/W >= the raster's height/width
        Class labels to plot / save.
    save_path : str, optional
        If given, the (cropped) predictions are written there.
    """
    import matplotlib.pyplot as plt
    import rasterio

    with rasterio.open(raster_path) as src:
        profile = src.profile
        height, width = src.height, src.width

    # Drop the padding region so the array matches the raster footprint.
    predictions = np.asarray(predictions)[:height, :width]

    if save_path:
        # nodata=None: a source nodata such as -9999 is outside the uint8 range and
        # would make rasterio refuse to create the file.
        profile.update(dtype=rasterio.uint8, count=1, nodata=None)
        with rasterio.open(save_path, 'w', **profile) as dst:
            dst.write(predictions.astype(rasterio.uint8), 1)

    plt.figure(figsize=(10, 8))
    plt.imshow(predictions, cmap='viridis', interpolation='none')
    plt.colorbar(label="Predicted Labels")
    plt.title("Predicted Labels Raster")
    plt.show()


### Label File Setup (Shapefiles or Polygons Generated from Points)

##### This section loads a polygon shapefile (e.g., generated from points), reads a reference raster, and creates a labeled mask (GeoTIFF) from a specific attribute column (Species_id). The pixels of each polygon are burned with the integer value of that polygon's Species_id (background pixels are 0), on the same grid (shape, transform, CRS) as the reference raster.

In [None]:
# Example path to shapefile (polygons)
# This file should be the shapefile containing your polygons
shapefile_path = 'example/file/path/Merged_polygons.shp'
gdf = gpd.read_file(shapefile_path)

# Reference raster: the label mask is created on its grid (shape, transform, CRS)
# This file should be the reference raster you want to open
reference_raster_path = 'example/file/path/Quercus_11_Monthly_2023-05.tif'

# Output mask path
# This is where the output mask (label) will be saved
output_mask_path = 'example/file/path/quercus_11_label.tif'

# Create the mask. The `with` block closes the reference raster even if
# rasterization or writing raises.
with rasterio.open(reference_raster_path) as reference_raster:
    create_labeled_mask(
        gdf=gdf,
        reference_raster=reference_raster,
        output_path=output_mask_path,
        id_column='Species_id'  # Shapefile column containing species ID
    )


### Generating Sequences and Training/Test Tensors

##### Here we generate input (images) and label (mask) tensors from multiple image folders. Each folder must contain a multiple of 13 images; every consecutive group of 13 (e.g., 13 monthly mosaics) forms one tile sequence.

In [None]:
# Define folders
# Each folder should contain your monthly mosaic files for different Quercus instances
folders = [
    'example/file/path/Quercus_1',
    'example/file/path/Quercus_2',
    'example/file/path/Quercus_4',
    'example/file/path/Quercus_5',
    'example/file/path/Quercus_6',
    'example/file/path/Quercus_7',
    'example/file/path/Quercus_8',
    'example/file/path/Quercus_9',
    'example/file/path/Quercus_10',
    'example/file/path/Quercus_11'
]
test_folder = 'example/file/path/Quercus_3'

# Re-load the shapefile for masking
# This file should be the shapefile containing your polygons
shapefile_path = 'example/file/path/Merged_polygons.shp'
gdf = gpd.read_file(shapefile_path)

# Target band names
band_names = ['NDWI', 'VARI', 'NDVI', 'rededge']

# Generate training sequences (fits the StandardScaler on the training data)
train_sequences, scaler_train = generate_sequences_from_folders(
    folders,
    target_band_names=band_names,
    gdf=gdf
)

# Generate test sequences standardized with the *training* scaler. Fitting a
# separate scaler on the test set would leak test statistics and put train and
# test inputs on different scales. (scaler_test is the same object as scaler_train.)
test_sequences, scaler_test = generate_sequences_from_folders(
    [test_folder],
    target_band_names=band_names,
    gdf=gdf,
    scaler=scaler_train
)

# Create training masks (multi)
# These paths should point to the label TIFF files corresponding to each Quercus folder above,
# in the same order as `folders`
train_labels = create_multi_sequence_labels(
    [
        'example/file/path/quercus_1_label.tif',
        'example/file/path/quercus_2_label.tif',
        'example/file/path/quercus_4_label.tif',
        'example/file/path/quercus_5_label.tif',
        'example/file/path/quercus_6_label.tif',
        'example/file/path/quercus_7_label.tif',
        'example/file/path/quercus_8_label.tif',
        'example/file/path/quercus_9_label.tif',
        'example/file/path/quercus_10_label.tif',
        'example/file/path/quercus_11_label.tif'
    ],
    train_sequences.shape
)

# Convert to PyTorch tensors
train_tensor = torch.from_numpy(train_sequences).float()
train_labels_tensor = torch.from_numpy(train_labels).long()

# Generate test labels (if needed, single mask example):
# This path should be the label TIFF file for the test folder (Quercus_3)
test_mask_path = 'example/file/path/quercus_3_label.tif'
test_labels = create_sequence_labels(test_mask_path, test_sequences.shape)
test_tensor = torch.from_numpy(test_sequences).float()
test_labels_tensor = torch.from_numpy(test_labels).long()

# Check shapes
print(f"Train Tensor Shape: {train_tensor.shape}, dtype: {train_tensor.dtype}")
print(f"Train Labels Tensor Shape: {train_labels_tensor.shape}, dtype: {train_labels_tensor.dtype}")
print(f"Test Tensor Shape: {test_tensor.shape}, dtype: {test_tensor.dtype}")
print(f"Test Labels Tensor Shape: {test_labels_tensor.shape}, dtype: {test_labels_tensor.dtype}")



##### If you wish to save these tensors for future use:

In [None]:
# Save tensors to disk
# This file will store the training and testing tensors along with their labels
# (a single .pt file holding a dict with the four tensors under the keys below)
torch.save({
    'train_data': train_tensor,
    'train_labels': train_labels_tensor,
    'test_data': test_tensor,
    'test_labels': test_labels_tensor
}, 'example/file/path/q1_11train_q3test_4_bands_tensors.pt')

# Load saved tensors
# This path should point to the saved tensors file
# Reloading replaces the in-memory tensors with the saved copies, so the cells
# below can also run in a fresh kernel without regenerating the sequences.
loaded_tensors = torch.load('example/file/path/q1_11train_q3test_4_bands_tensors.pt')
train_tensor = loaded_tensors['train_data']
train_labels_tensor = loaded_tensors['train_labels']
test_tensor = loaded_tensors['test_data']
test_labels_tensor = loaded_tensors['test_labels']


### Defining and Training the Model

##### **Now you will define a spatio-temporal CNN model for tree crown (cork oak) classification and a corresponding training pipeline.**

The ModelConfig dataclass specifies default hyperparameters (e.g., number of bands, channel sizes, dropout rate). The SpatioTemporalModel processes time-series inputs through convolutional layers, producing both a pixel-wise crown prediction and a rough estimate of tree count. The CorkOakTrainer runs the training loop with a pixel-wise cross-entropy loss; its evaluation step additionally reports overall pixel accuracy and the per-tile tree-count estimates.

In [None]:
from dataclasses import dataclass

@dataclass
class ModelConfig:
    """
    Hyperparameters shared by SpatioTemporalModel and CorkOakTrainer.

    ``cnn_channels`` lists the output width of each convolutional stage. Leaving
    it as None selects [32, 64, 128]; each instance gets its own list.
    """
    n_bands: int = 17              # input bands per timestep (in_channels of the first conv)
    cnn_channels: list = None      # None -> [32, 64, 128], filled in by __post_init__
    cnn_kernel_size: int = 3       # square conv kernel; the model pads by kernel_size // 2
    cnn_dropout: float = 0.1       # Dropout2d probability after each conv stage
    learning_rate: float = 0.001   # Adam learning rate used by the trainer
    avg_tree_pixels: int = 10      # crown pixels per tree, used for tree count estimation

    def __post_init__(self):
        # A fresh list literal per instance, instead of a shared mutable default.
        self.cnn_channels = [32, 64, 128] if self.cnn_channels is None else self.cnn_channels


class SpatioTemporalModel(nn.Module):
    """
    Shared per-timestep CNN, temporal averaging, and a 1x1 pixel classifier.

    The same convolutional stack encodes every timestep. The feature maps are
    averaged over time, and a 1x1 convolution yields two logits (non-tree / tree)
    per pixel. A tree count is derived from the argmax map by dividing the
    number of crown pixels by ``config.avg_tree_pixels``.
    """
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        kernel = config.cnn_kernel_size
        # Conv -> BatchNorm -> ReLU -> Dropout2d per entry of cnn_channels;
        # padding = kernel // 2 keeps H and W unchanged for odd kernel sizes.
        stages = []
        prev_channels = config.n_bands
        for next_channels in config.cnn_channels:
            stages += [
                nn.Conv2d(prev_channels, next_channels, kernel_size=kernel, padding=kernel // 2),
                nn.BatchNorm2d(next_channels),
                nn.ReLU(),
                nn.Dropout2d(config.cnn_dropout),
            ]
            prev_channels = next_channels
        self.cnn = nn.Sequential(*stages)

        # 1x1 conv mapping the last feature width to 2 class logits per pixel
        self.pixel_classifier = nn.Conv2d(config.cnn_channels[-1], 2, kernel_size=1)

    def estimate_tree_count(self, crown_predictions):
        """
        Turn a (batch, H, W) map of 0/1 crown predictions into a per-sample tree estimate.

        Returns (number of crown pixels) / config.avg_tree_pixels, one value per sample.
        """
        return crown_predictions.sum(dim=(1, 2)) / self.config.avg_tree_pixels

    def forward(self, x):
        """
        x: tensor of shape (batch_size, timesteps, bands, height, width).

        Returns (crown_pred, tree_count): crown_pred holds the (batch, 2, H, W)
        class logits (for odd kernel sizes), tree_count the per-sample estimate
        computed from their argmax.
        """
        _, n_steps, _, _, _ = x.shape  # also rejects inputs that are not 5-D
        per_step_features = [self.cnn(x[:, step]) for step in range(n_steps)]

        # Temporal fusion: plain mean of the per-timestep feature maps
        fused = torch.stack(per_step_features, dim=1).mean(dim=1)
        crown_pred = self.pixel_classifier(fused)

        tree_count = self.estimate_tree_count(crown_pred.argmax(dim=1))
        return crown_pred, tree_count


class CorkOakTrainer:
    """
    Training / evaluation loop for SpatioTemporalModel.

    Optimizes pixel-wise cross-entropy between the crown logits and the binary
    labels with Adam (lr = config.learning_rate). The model lives on ``device``
    (CUDA when available, unless a device is given).
    """
    def __init__(self, config: ModelConfig, device=None):
        self.config = config
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = SpatioTemporalModel(config).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=config.learning_rate)
        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _prepare_target(target):
        """
        Return labels as (batch, H, W), the shape CrossEntropyLoss expects for (batch, 2, H, W) logits.

        4-D label batches repeat one mask along axis 1: (batch, 1, H, W) from
        create_multi_sequence_labels, and (batch, timesteps, H, W) from
        create_sequence_labels. Keeping the first slice is therefore lossless;
        the former ``squeeze(1)`` only handled the singleton case. Targets of
        any other rank are returned unchanged.
        """
        return target[:, 0] if target.dim() == 4 else target

    def train_epoch(self, train_loader):
        """
        Train the model for one pass over ``train_loader``.

        Returns the mean of the per-batch losses.
        """
        self.model.train()
        total_loss = 0

        for data, target in train_loader:
            data = data.to(self.device)
            target = self._prepare_target(target).to(self.device)

            self.optimizer.zero_grad()
            crown_pred, _ = self.model(data)
            loss = self.criterion(crown_pred, target)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

        return total_loss / len(train_loader)

    def evaluate(self, val_loader):
        """
        Evaluate the model on ``val_loader`` without gradient tracking.

        Returns (mean per-batch loss, pixel accuracy over all label pixels,
        list of per-sample tree-count estimates).
        """
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0
        tree_counts = []

        with torch.no_grad():
            for data, target in val_loader:
                data = data.to(self.device)
                target = self._prepare_target(target).to(self.device)

                crown_pred, tree_count = self.model(data)
                loss = self.criterion(crown_pred, target)
                total_loss += loss.item()

                pred = crown_pred.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.numel()

                tree_counts.extend(tree_count.cpu().numpy())

        return (
            total_loss / len(val_loader),
            correct / total,
            tree_counts
        )


##### **Training Example**

##### **How to set up a training loop for the cork oak model.**

First, it creates a PyTorch TensorDataset and DataLoader from the training tensors, specifying a batch_size and shuffling. Next, it initializes the model configuration (ModelConfig) using the number of input bands from the training data. The CorkOakTrainer is then used to train the model for a specified number of epochs, printing the training loss each time. Finally, the trained model’s parameters are saved to a .pth file.

In [None]:
# Create Dataset and DataLoader
# Each sample is one tile sequence (timesteps, bands, H, W) paired with its label mask.
train_dataset = TensorDataset(train_tensor, train_labels_tensor)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

# train_tensor is (tiles, timesteps, bands, H, W), so shape[2] is the number of input bands.
config = ModelConfig(n_bands=train_tensor.shape[2])
trainer = CorkOakTrainer(config)

num_epochs = 2  # example
for epoch in range(num_epochs):
    # train_epoch returns the mean per-batch cross-entropy loss of the epoch.
    train_loss = trainer.train_epoch(train_loader)
    print(f"Epoch {epoch+1}: Training Loss = {train_loss:.4f}")

print("Training complete.")

# Save model
# This path should point to where you want to save the trained model
# Only the state_dict (weights) is saved; reloading requires rebuilding
# SpatioTemporalModel with the same ModelConfig.
model_save_path = 'example/file/path/q1_11train_q3test_model.pth'
torch.save(trainer.model.state_dict(), model_save_path)
print(f"Model saved at: {model_save_path}")


### Evaluation and Visualization of Results

##### Load a trained spatio-temporal model and apply it to the test set, generating predicted labels and estimating the number of trees. It calculates various classification metrics (overall, producer’s, and user’s accuracy) and visualizes the predictions over a reference raster. Finally, it compares the model’s estimated tree count to the actual number of trees derived from the shapefile.

In [None]:
# Load trained model (if needed)
config = ModelConfig(n_bands=train_tensor.shape[2])
model = SpatioTemporalModel(config)
# This path should point to the saved trained model
model_save_path = 'example/file/path/q1_11train_q3test_model.pth'
# map_location='cpu' lets a checkpoint saved from a GPU run load here (the model stays on the CPU).
model.load_state_dict(torch.load(model_save_path, map_location='cpu'))
model.eval()

# Inference on the whole test set in a single forward pass (no DataLoader batching)
with torch.no_grad():
    crown_pred, tree_count = model(test_tensor)

# Predicted class per pixel: (n_tiles, H, W)
test_predictions = crown_pred.argmax(dim=1).cpu().numpy()

# True labels. Label tensors repeat one mask along axis 1 ((n, timesteps, H, W) from
# create_sequence_labels, (n, 1, H, W) from create_multi_sequence_labels), so the first
# slice is (n_tiles, H, W) in both cases; squeeze(1) only handled the singleton case.
true_labels_test = test_labels_tensor[:, 0].cpu().numpy()

# Pixels that are -9999 in every timestep and band are padding, not data; keep them
# out of the metrics so they do not inflate the Non-Tree scores.
valid_pixels = (test_tensor != -9999).flatten(1, 2).any(dim=1).cpu().numpy()

# Statistics
stats = calculate_statistics(true_labels_test[valid_pixels], test_predictions[valid_pixels])
print("\nMetrics:")
print(f"OA (Overall Accuracy): {stats['Overall Accuracy']:.2f}")
print(f"PA (Producer's Accuracy): {stats['Producers Accuracy']}")
print(f"UA (User's Accuracy): {stats['Users Accuracy']}")

# Visualization (example for the first tile/time)
# This path should point to the raster file you want to visualize
quercus_3_raster_path = 'example/file/path/Quercus_3_Monthly_2023-05.tif'
# This path should point to where you want to save the prediction visualization
visualize_raster_predictions(
    raster_path=quercus_3_raster_path,
    predictions=test_predictions[0],
    save_path='example/file/path/example_prediction.tif'
)

# Tree count comparison
# This path should point to the shapefile containing your polygons
shapefile_path = 'example/file/path/Merged_polygons.shp'
gdf = gpd.read_file(shapefile_path)
with rasterio.open(quercus_3_raster_path) as src:
    bounds = src.bounds
    raster_bbox = box(*bounds)
    # Compare in the raster's CRS, otherwise the bounding-box test is meaningless.
    gdf_raster_crs = gdf
    if gdf.crs is not None and src.crs is not None and not gdf.crs.equals(src.crs.to_wkt()):
        gdf_raster_crs = gdf.to_crs(src.crs.to_wkt())
    gdf_filtered = gdf_raster_crs[gdf_raster_crs.geometry.intersects(raster_bbox)]
    actual_tree_count = len(gdf_filtered)

print("\nTree Count Comparison:")
print(f"Actual trees in shapefile: {actual_tree_count}")
print(f"Estimated trees by the model (first tile): {int(tree_count[0].item())}")


### Validation in a Test Area / Inference

##### **How to validate the model on a different test area (folder).**

First, it identifies the indices of relevant bands in a reference raster (e.g., NDVI, NDWI, VARI). It then generates standardized sequences by reading and padding the images in the folder to consistent dimensions. It standardizes the data with the scaler fitted on the training set and converts the result to a PyTorch tensor for inference. Finally, it analyzes the model’s predictions to compute basic statistics on the proportion of tree pixels.

In [None]:
def identify_band_indices(file_path, target_band_names):
    """
    Return the 1-based rasterio band indices of ``target_band_names`` in ``file_path``.

    Indices follow the order of ``target_band_names``, matched against the band
    descriptions. Raises ValueError naming the first requested band that is missing.
    """
    with rasterio.open(file_path) as src:
        descriptions = list(src.descriptions)
        indices = []
        for name in target_band_names:
            if name not in descriptions:
                raise ValueError(f"Band {name} not found.")
            indices.append(descriptions.index(name) + 1)
    return indices


def generate_sequences_with_band_indices(folders, band_indices, scaler=None, timesteps=13):
    """
    Generate standardized tile sequences using already identified band indices.

    Each folder must contain a multiple of ``timesteps`` GeoTIFFs (.tif/.tiff).
    After sorting the file names, every consecutive group of ``timesteps`` files
    forms one tile. The bands in ``band_indices`` (1-based, in that order) are
    read, cast to float32 and padded with -9999 at the bottom/right to the
    largest height/width across all folders. Pixels without -9999 in any band
    are standardized with ``scaler``; padded pixels keep -9999.

    Parameters
    ----------
    folders : list of str
    band_indices : list of int
        1-based band indices, e.g. from identify_band_indices().
    scaler : StandardScaler, optional
        Fitted scaler to apply (e.g. the training one). If None, a new one is
        fitted on this data.
    timesteps : int, default 13
        Number of images per tile.

    Returns
    -------
    np.ndarray
        float32 array of shape (n_tiles, timesteps, n_bands, max_height, max_width).

    Raises
    ------
    ValueError
        If a folder's image count is not a multiple of ``timesteps`` or no image is found.
    """
    def list_rasters(folder):
        return sorted(
            os.path.join(folder, f)
            for f in os.listdir(folder)
            if f.endswith(('.tif', '.tiff'))
        )

    # Determine max dimensions over every raster
    max_height, max_width = 0, 0
    for folder in folders:
        for image_path in list_rasters(folder):
            with rasterio.open(image_path) as src:
                max_height = max(max_height, src.height)
                max_width = max(max_width, src.width)

    # Read, pad and stack images
    sequences = []
    for folder in folders:
        image_paths = list_rasters(folder)
        if len(image_paths) % timesteps != 0:
            raise ValueError(f"Number of images is not a multiple of {timesteps}.")

        for start in range(0, len(image_paths), timesteps):
            tile_images = []
            for image_path in image_paths[start:start + timesteps]:
                with rasterio.open(image_path) as src:
                    # float32 so -9999 and standardized values fit whatever the source dtype.
                    selected_bands = src.read(band_indices).astype(np.float32)
                padded_image = np.full(
                    (selected_bands.shape[0], max_height, max_width),
                    -9999,
                    dtype=np.float32
                )
                padded_image[:, :selected_bands.shape[1], :selected_bands.shape[2]] = selected_bands
                tile_images.append(padded_image)
            sequences.append(np.stack(tile_images, axis=0))

    if not sequences:
        raise ValueError(f"No GeoTIFF images found in {folders}.")

    sequences = np.stack(sequences, axis=0)  # shape: (tiles, timesteps, bands, H, W)
    n_tiles, n_steps, n_bands, height, width = sequences.shape

    # Standardization: one row per (tile, timestep, pixel), one column per band
    reshaped_sequences = sequences.transpose(0, 1, 3, 4, 2).reshape(-1, n_bands)

    is_training = scaler is None
    if is_training:
        scaler = StandardScaler()

    valid_mask = (reshaped_sequences != -9999).all(axis=1)
    valid_data = reshaped_sequences[valid_mask]

    standardized_data = reshaped_sequences.copy()
    if is_training:
        standardized_data[valid_mask] = scaler.fit_transform(valid_data)
    else:
        standardized_data[valid_mask] = scaler.transform(valid_data)

    standardized_sequences = (
        standardized_data.reshape(n_tiles, n_steps, height, width, n_bands)
        .transpose(0, 1, 4, 2, 3)
    )

    return standardized_sequences


# Example usage in validation
# This path should point to the validation folder containing the raster files
validation_folder = "example/file/path/Validacion_holm_oak"
# This path should point to the reference raster within the validation folder
reference_raster_path = os.path.join(validation_folder, "Validacion_holm_oak_Monthly_2023-05.tif")
band_names = ['NDWI', 'VARI', 'NDVI', 'rededge']

band_indices = identify_band_indices(reference_raster_path, band_names)
# Standardize with the training scaler so validation inputs match the training scale.
validation_sequences = generate_sequences_with_band_indices([validation_folder], band_indices, scaler=scaler_train)

# We take only 1 timestep if we want to test a single month. Slicing 0:1 (instead of
# indexing with 0) keeps the timestep axis: the model expects (batch, timesteps, bands, H, W).
validation_tensor = torch.tensor(validation_sequences[:, 0:1, :, :, :], dtype=torch.float32)

# Inference on whatever device the model's parameters live on
model.eval()
model_device = next(model.parameters()).device
with torch.no_grad():
    crown_pred_val, tree_count_val = model(validation_tensor.to(model_device))
    test_predictions_val = crown_pred_val.argmax(dim=1).cpu().numpy()

# Mark padding pixels (-9999 in every band of the selected timestep) as nodata (-9999),
# so the statistics below skip them instead of counting them as non-tree.
padded_pixels = (validation_sequences[:, 0] == -9999).all(axis=1)
test_predictions_val = np.where(padded_pixels, -9999, test_predictions_val)

# Basic stats
def calculate_tree_stats_from_predictions(predictions, nodata_value=-9999):
    """
    Summarize a prediction map: tree-pixel count, valid-pixel count, tree fraction.

    Pixels equal to ``nodata_value`` are ignored; a valid pixel counts as tree
    when its value is 1. ``tree_fraction`` is 0 when no pixel is valid.
    """
    is_valid = predictions != nodata_value
    n_valid = np.sum(is_valid)
    n_tree = np.sum(predictions[is_valid] == 1)

    if n_valid > 0:
        fraction = n_tree / n_valid
    else:
        fraction = 0

    return {
        'tree_pixels': n_tree,
        'total_pixels': n_valid,
        'tree_fraction': fraction,
    }

# Tree-pixel statistics for the first validation tile; pixels equal to -9999
# (the function's default nodata value) are skipped.
stats_val = calculate_tree_stats_from_predictions(test_predictions_val[0])
print("Validation image statistics:", stats_val)
