In [None]:
!pip install geedim

In [None]:
import ee
from geedim.mask import MaskedImage
import geedim as gd
import geemap
import os

# Authenticate against Earth Engine, then initialise the client.
# 'your-project' is a placeholder and must be replaced before running.
ee.Authenticate()
ee.Initialize(project='your-project')


# geedim keeps its own initialisation entry point; it is called after the
# Earth Engine client has been initialised above.
gd.Initialize()

In [None]:
# Area that contains the reference data, stored as a FeatureCollection holding
# a single polygon feature. This polygon is later split into fixed-size cells
# by `split_pol`.
# NOTE(review): vertices look like (longitude, latitude) pairs in degrees - confirm.

gridRice = ee.FeatureCollection(
    [ee.Feature(
        ee.Geometry.Polygon(
            [[[-57.755403520658675, -30.273888236002545],
              [-57.414827348783675, -30.48240355814265],
              [-56.887483598783675, -30.16946399348635],
              [-53.196077348783675, -32.746274998519404],
              [-53.701448442533675, -33.69291864564506],
              [-53.382844926908675, -33.802539093056104],
              [-52.789583208158675, -33.27141183337489],
              [-52.163362505033675, -32.37589711132447],
              [-51.778841020658675, -31.92943245126368],
              [-51.009798051908675, -31.452677123648254],
              [-50.449495317533675, -30.784902414022813],
              [-50.075960161283675, -30.178961686316303],
              [-49.856233598783675, -29.731583524555113],
              [-50.570344926908675, -29.578826421806646],
              [-51.086702348783675, -29.617037430497966],
              [-51.866731645658675, -29.531042314547424],
              [-52.602815630033675, -29.511922346282788],
              [-53.251008989408675, -29.444974019819423],
              [-53.954133989408675, -29.416268349104133],
              [-54.657258989408675, -29.406697989424885],
              [-55.162630083158675, -29.50236100743873],
              [-55.492219926908675, -29.215101811421185],
              [-55.305452348783675, -28.89818403427908],
              [-55.316438676908675, -28.251766454405896],
              [-55.865755083158675, -28.164632555056674]]]),
        {
          # No 'id' property is set, so `split_pol` falls back to
          # 'grid_' + system:index when naming the cells.
          "system:index": "0"
        })])

In [None]:
# Settings
# Target season: the composite that gets exported runs from December of the
# previous year to March of `year` (see startDate_aim / endDate_aim).
year = 2024

# Fixed window of the reference composite from which the per-band
# normalisation percentiles are computed.
startDate = '2019-12-01'
endDate = '2020-03-01'
startDate_aim = f'{year - 1}-12-01'
endDate_aim = f'{year}-03-01'

# chirp_scale is passed as `scale` to reduceRegion and to the tile download;
# chirp_size * chirp_scale is passed as the cell size to coveringGrid.
# NOTE(review): presumably 30 m pixels and 1024-pixel tiles - confirm units.
chirp_scale, chirp_size = 30, 1024
chirp_size_m = chirp_scale * chirp_size

# Scenes are kept only when CLOUD_COVER_LAND is below this value.
cloudCoverValue = 80
# Placeholder: replace with the value to match against the 'SIGLA_UF' property.
uf_code = 'UF do estado'

# NOTE(review): this value is overwritten by the export cell before it is used.
output_folder = f'DATASET_RICE_PERC_{uf_code}_{year}'

# Collection and input layers
# Both asset ids below are placeholders for project-specific assets.
ref_map = ee.FeatureCollection('projects/assets/reference_map')
estados = ee.FeatureCollection('regions/ibge_estados_2019')
# Projection of the grid polygon; reused for reprojection and grid generation.
proj = gridRice.first().geometry().projection()

# Function
def filter_landsat(path, roi, start, end, cloud_max):
    """
    Builds an image collection restricted in time, space and cloud cover.

    Args:
        path: Earth Engine collection id handed to ``ee.ImageCollection``.
        roi: Region passed to ``filterBounds``.
        start: Start of the date window passed to ``filterDate``.
        end: End of the date window passed to ``filterDate``.
        cloud_max: Scenes are kept when ``CLOUD_COVER_LAND`` is lower than this.

    Returns:
        The filtered collection.
    """
    cloud_filter = ee.Filter.lt('CLOUD_COVER_LAND', cloud_max)

    scenes = ee.ImageCollection(path)
    scenes = scenes.filterDate(start, end)
    scenes = scenes.filterBounds(roi)
    return scenes.filter(cloud_filter)


def padronize_band_names(image):
    """
    Renames the sensor-specific Landsat bands of ``image`` to common names.

    The band list is looked up from the image's ``SPACECRAFT_ID`` property and
    mapped, position by position, onto
    blue, green, red, nir, swir1, swir2, tir1 and BQA.

    Args:
        image: Landsat image carrying a ``SPACECRAFT_ID`` property.

    Returns:
        The result of ``ee.Algorithms.If``: the renamed image when a band list
        was found for the spacecraft, otherwise the untouched image. Callers
        wrap the result in ``ee.Image``.
    """
    standard_names = ['blue', 'green', 'red', 'nir', 'swir1', 'swir2', 'tir1', 'BQA']

    sensor_bands = ee.Dictionary({
        'LANDSAT_5': ['B1', 'B2', 'B3', 'B4', 'B5', 'B7', 'B6', 'QA_PIXEL'],
        'LANDSAT_7': ['B1', 'B2', 'B3', 'B4', 'B5', 'B7', 'B6_VCID_1', 'QA_PIXEL'],
        'LANDSAT_8': ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B10', 'QA_PIXEL'],
        'LANDSAT_9': ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B10', 'QA_PIXEL']
    })

    # NOTE(review): the lookup has no default value; confirm how Earth Engine
    # treats a spacecraft id that is not a key before relying on the fallback.
    native_names = sensor_bands.get(image.get('SPACECRAFT_ID'))

    renamed = image.select(ee.List(native_names), standard_names)
    return ee.Algorithms.If(native_names, renamed, image)


def mask_clouds(image):
    """
    Masks the pixels of ``image`` that its ``BQA`` band flags as contaminated.

    A pixel is masked when any of the following holds:
      * bit 3 is set together with bit 8 or bit 9,
      * bit 1 is set,
      * bit 4 is set together with bit 10 or bit 11,
      * bit 5 is set,
      * bit 7 is set,
      * bits 14 and 15 are both set.

    NOTE(review): these positions presumably follow the Landsat Collection 2
    QA_PIXEL layout (cloud, dilated cloud, shadow, snow, water, cirrus and
    their confidence bits) - confirm against the dataset documentation.

    Args:
        image: Image with a ``BQA`` band (see ``padronize_band_names``).

    Returns:
        ``image`` with the flagged pixels masked out.
    """
    qa = image.select('BQA')

    def bit(position):
        # Non-zero wherever the given QA bit is set.
        return qa.bitwiseAnd(1 << position)

    flag_3_confirmed = bit(3).And(bit(8).Or(bit(9)))
    flag_4_confirmed = bit(4).And(bit(10).Or(bit(11)))
    flags_14_and_15 = bit(14).And(bit(15))

    contaminated = (
        flag_3_confirmed
        .Or(bit(1))
        .Or(flag_4_confirmed)
        .Or(bit(5))
        .Or(bit(7))
        .Or(flags_14_and_15)
    )
    return image.updateMask(contaminated.Not())

def normalize_band(band_name, image, p1, p99):
    """
    Rescales one band of ``image`` to the [0, 1] interval.

    Args:
        band_name: Name of the band to normalise.
        image: Image that contains ``band_name``.
        p1: Value mapped to 0 by ``unitScale``.
        p99: Value mapped to 1 by ``unitScale``.

    Returns:
        A single band named ``<band_name>_norm``, clamped to [0, 1].
    """
    band = image.select(band_name)
    scaled = band.unitScale(p1, p99)
    clipped = scaled.clamp(0, 1)
    return clipped.rename(f'{band_name}_norm')

def get_evi2(image):
    """
    Appends an ``evi2`` band computed from the normalised NIR and red bands.

    Args:
        image: Image containing ``nir_norm`` and ``red_norm`` bands.

    Returns:
        ``image`` with the additional ``evi2`` band.
    """
    inputs = {
        'NIR': image.select('nir_norm'),
        'RED': image.select('red_norm')
    }
    formula = '2.5 * (NIR - RED) / (NIR + 2.4 * RED + 1)'
    index_band = image.expression(formula, inputs).rename('evi2')
    return image.addBands(index_band)

def get_ndwi(image):
    """
    Appends an ``ndwi`` band: the normalised difference of NIR and SWIR1.

    Args:
        image: Image containing ``nir_norm`` and ``swir1_norm`` bands.

    Returns:
        ``image`` with the additional ``ndwi`` band.
    """
    inputs = {
        'NIR': image.select('nir_norm'),
        'SWIR1': image.select('swir1_norm')
    }
    formula = '(NIR - SWIR1) / (NIR + SWIR1)'
    index_band = image.expression(formula, inputs).rename('ndwi')
    return image.addBands(index_band)

# Region of interest: the features of `estados` whose 'SIGLA_UF' property
# equals `uf_code`. It is used to filter scenes, to clip the reference raster
# and as the geometry of the region statistics.
roi = estados.filter(ee.Filter.eq('SIGLA_UF', uf_code))


# Polygon division function (split_pol)
def split_pol(ft):
    """
    Splits one grid feature into square cells and gives every cell a unique id.

    The feature is reprojected to ``proj`` and covered with a grid whose cell
    size is ``chirp_size_m``. Each cell inherits the properties of ``ft`` and
    gets an ``id`` of the form ``<parent id>_<cell index>``.

    Args:
        ft: Feature to split. Its ``id`` property is used as the parent id;
            when it is missing, ``'grid_' + system:index`` is used instead.

    Returns:
        A FeatureCollection with one feature per grid cell.
    """
    declared_id = ft.get('id')
    fallback_id = ee.String('grid_').cat(ee.String(ft.get('system:index')))
    parent_id = ee.String(ee.Algorithms.If(declared_id, declared_id, fallback_id))

    reprojected = ft.transform(proj.atScale(chirp_size), 1)

    def tag_cell(cell):
        cell = ee.Feature(cell)
        raw_index = ee.String(cell.get('system:index'))
        # The comma separated cell index becomes an underscore separated one.
        # NOTE(review): confirm whether replace() rewrites every '-' or only
        # the first match, and that substituting '1' cannot make ids collide.
        cell_suffix = raw_index.split(',').join('_').replace('-', '1')
        cell_id = parent_id.cat('_').cat(cell_suffix)
        return cell.copyProperties(ft).set('id', cell_id)

    cells = reprojected.geometry().coveringGrid(proj, chirp_size_m)
    return cells.map(tag_cell)

# Split every grid feature into cells and merge the per-feature collections
# into one flat collection of cells.
bigs_splitted = gridRice.map(split_pol).flatten()


# Prints the source polygon collection (not the split cells).
print(gridRice.getInfo())

# Reference raster: 1 where `ref_map` features were painted, 0 elsewhere,
# clipped to the region of interest. Used below as a mask.
reference = ee.Image(0).paint(ref_map, 1).rename('reference').clip(roi)


# Landsat collections
# Every sensor is limited to its own date window before the merge. Landsat 7
# contributes two separate windows (until 2003-05-31 and from 2011-10-01 to
# 2013-03-01).
l5 = filter_landsat("LANDSAT/LT05/C02/T1_TOA", roi, "2000-01-01", "2011-10-01", cloudCoverValue)
l7a = filter_landsat("LANDSAT/LE07/C02/T1_TOA", roi, "2000-01-01", "2003-05-31", cloudCoverValue)
l7b = filter_landsat("LANDSAT/LE07/C02/T1_TOA", roi, "2011-10-01", "2013-03-01", cloudCoverValue)
l8 = filter_landsat("LANDSAT/LC08/C02/T1_TOA", roi, "2013-03-01", "2030-01-01", cloudCoverValue)
l9 = filter_landsat("LANDSAT/LC09/C02/T1_TOA", roi, "2019-03-01", "2030-01-01", cloudCoverValue)


def build_masked_collection(start, end):
    """
    Merges all sensors and prepares the scenes of one date window.

    The date filter is applied before the per-image maps so that band
    harmonisation and cloud masking are only requested for the scenes that
    end up in the composite.

    Args:
        start: Start of the date window passed to ``filterDate``.
        end: End of the date window passed to ``filterDate``.

    Returns:
        Collection with harmonised band names and contaminated pixels masked.
    """
    merged = l8.merge(l9).merge(l7a).merge(l7b).merge(l5)
    return (
        merged
        .filterDate(start, end)
        .map(lambda img: ee.Image(padronize_band_names(img)))
        .map(mask_clouds)
    )


# Reference composite: only used to derive the normalisation percentiles.
collection = build_masked_collection(startDate, endDate)

median_ref = collection.median()


# 1st and 99th percentile of every optical band, computed over the reference
# pixels only (everything outside `reference` is masked).
masked = median_ref.updateMask(reference)
bands = ['blue', 'green', 'red', 'nir', 'swir1', 'swir2']
percentiles = masked.select(bands).reduceRegion(
    reducer=ee.Reducer.percentile([1, 99]),
    geometry=roi.geometry(),
    scale=chirp_scale,
    maxPixels=1e13
)



# Target composite: the season that is normalised and exported.
aim_collection = build_masked_collection(startDate_aim, endDate_aim)

median = aim_collection.median()



# Normalized bands
# Every optical band of the target composite is rescaled with the percentiles
# of the reference composite and clamped to [0, 1] (see normalize_band).

norm_bands = []
for b in bands:
    # Keys follow the '<band>_p<percentile>' naming of the percentile reducer output.
    p1 = ee.Number(percentiles.get(f'{b}_p1'))
    p99 = ee.Number(percentiles.get(f'{b}_p99'))
    norm_bands.append(normalize_band(b, median, p1, p99))

# Stack the single-band images into one multi-band float image, then append
# the evi2 and ndwi index bands.
normalized = ee.Image(norm_bands).toFloat()
with_indices = get_ndwi(get_evi2(normalized))

# Normalized EVI2
# EVI2 is stretched to [0, 1] using its own minimum and maximum over the
# region of interest.
evi2 = with_indices.select('evi2')
evi2_stats = evi2.reduceRegion(
    reducer=ee.Reducer.minMax(),
    geometry=roi.geometry(),
    scale=chirp_scale,
    maxPixels=1e13
)
evi2_norm = evi2.unitScale(evi2_stats.get('evi2_min'), evi2_stats.get('evi2_max')).rename('evi2_norm')
# Three-band input expected by the U-Net: evi2_norm, swir1_norm, swir2_norm.
mosaic_unet = with_indices.addBands(evi2_norm).select(['evi2_norm', 'swir1_norm', 'swir2_norm'])


# Masked pixels are filled by unmask() and the [0, 1] bands are rescaled to
# 8-bit; the inference step divides by 255 again ('div255').
image_to_export = mosaic_unet.unmask().multiply(255).uint8()


In [None]:
# Export in tiles with geedim

# Ids of all grid cells, fetched once; the list also provides the tile count
# reported at the end, so no second request is needed.
grid_list = bigs_splitted.aggregate_array('id').getInfo()

output_folder = '/path/tile_export'
os.makedirs(output_folder, exist_ok=True)

gd_image = gd.MaskedImage(image_to_export)

# One GeoTIFF per grid cell; the cell geometry is used as the download region.
for grid_id in grid_list:
    gd_image.download(
        filename=f'{output_folder}/mosaic_{grid_id}_mosaic.tif',
        region=bigs_splitted.filter(ee.Filter.eq('id', grid_id)).geometry(),
        scale=chirp_scale,
        crs='EPSG:3857',
        overwrite=True,
        bands=['evi2_norm', 'swir1_norm', 'swir2_norm'],
        resampling='near',
        dtype='uint8',
        scale_offset=None
    )

# Reported after the loop, i.e. once every download call has returned.
print(f"\nExport finished. The images were saved at path '{output_folder}'.")
print(f"A .tif file was generated for each of the {len(grid_list)} geometries of your grid.")

# INFERENCE

## UNET

In [None]:
import torch
import torch.nn as nn

class UNet(nn.Module):
    """
    U-Net for semantic segmentation: four encoder stages, a bottleneck and
    four decoder stages joined to the encoder through skip connections.

    Args:
        in_channels (int): Number of input channels (e.g., 3 for RGB).
        out_channels (int): Number of output channels (e.g., 1 for binary segmentation).
        init_features (int): Channel count of the first stage; it doubles at
            every encoder stage.
        no_drop (bool): If True, every dropout layer is replaced by Identity.
    """
    def __init__(self, in_channels=3, out_channels=1, init_features=64, no_drop=True):
        super().__init__()
        self.no_drop = no_drop

        width = init_features

        def make_drop(rate):
            # Identity keeps the attribute layout the same with dropout off.
            return nn.Identity() if no_drop else nn.Dropout(rate)

        def make_pool():
            return nn.MaxPool2d(kernel_size=2, stride=2)

        def make_up(channels_in, channels_out):
            return nn.ConvTranspose2d(channels_in, channels_out, kernel_size=2, stride=2)

        # Contracting path: block -> 2x2 max-pool, channels double per stage.
        self.encoder1 = UNet._block(in_channels, width, name="enc1")
        self.pool1 = make_pool()
        self.drop1 = make_drop(0.25)

        self.encoder2 = UNet._block(width, width * 2, name="enc2")
        self.pool2 = make_pool()
        self.drop2 = make_drop(0.25)

        self.encoder3 = UNet._block(width * 2, width * 4, name="enc3")
        self.pool3 = make_pool()
        self.drop3 = make_drop(0.5)

        self.encoder4 = UNet._block(width * 4, width * 8, name="enc4")
        self.pool4 = make_pool()
        self.drop4 = make_drop(0.5)

        self.bottleneck = UNet._block(width * 8, width * 16, name="bottleneck")
        self.drop5 = make_drop(0.5)

        # Expanding path: each decoder block receives the upsampled features
        # concatenated with the matching skip, hence twice the channels.
        self.upconv4 = make_up(width * 16, width * 8)
        self.decoder4 = UNet._block(width * 16, width * 8, name="dec4")

        self.upconv3 = make_up(width * 8, width * 4)
        self.decoder3 = UNet._block(width * 8, width * 4, name="dec3")

        self.upconv2 = make_up(width * 4, width * 2)
        self.decoder2 = UNet._block(width * 4, width * 2, name="dec2")

        self.upconv1 = make_up(width * 2, width)
        self.decoder1 = UNet._block(width * 2, width, name="dec1")

        # Final layer: 1x1 convolution to the requested number of outputs.
        self.conv = nn.Conv2d(in_channels=width, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Returns the raw (pre-activation) output for the batch ``x``."""
        # Encoder: the (optionally dropped-out) stage outputs double as skips.
        skip1 = self.drop1(self.encoder1(x))
        skip2 = self.drop2(self.encoder2(self.pool1(skip1)))
        skip3 = self.drop3(self.encoder3(self.pool2(skip2)))
        skip4 = self.drop4(self.encoder4(self.pool3(skip3)))

        # Bottleneck
        out = self.drop5(self.bottleneck(self.pool4(skip4)))

        # Decoder: upsample, concatenate the skip on the channel axis, refine.
        stages = (
            (self.upconv4, self.decoder4, skip4),
            (self.upconv3, self.decoder3, skip3),
            (self.upconv2, self.decoder2, skip2),
            (self.upconv1, self.decoder1, skip1),
        )
        for upsample, refine, skip in stages:
            out = refine(torch.cat((upsample(out), skip), dim=1))

        return self.conv(out)

    @staticmethod
    def _block(in_channels, features, name):
        """
        Builds one stage: Conv -> BN -> ReLU -> Conv -> BN -> ReLU.

        Both convolutions are 3x3 with padding 1 and no bias. ``name`` is
        accepted but not used by the layers.
        """
        layers = [
            nn.Conv2d(in_channels, features, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(features),
            nn.ReLU(inplace=True),
            nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(features),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

In [None]:
"""
This section defines the `GeoTIFFPredictor` class, which uses the trained model
to perform segmentation on large, new GeoTIFF images. It employs a sliding
window approach with blending to produce seamless prediction maps.
"""

import rasterio
import numpy as np
import torch
import torch.nn.functional as F
from rasterio.windows import Window
from pathlib import Path
from tqdm import tqdm
import os

class GeoTIFFPredictor:
    """
    Performs windowed inference on GeoTIFF files using a trained PyTorch model.

    The raster is normalised band by band, padded so that a whole number of
    overlapping windows covers it, and every window is passed through the
    model. Window predictions are accumulated with spatial blending weights
    and divided by the accumulated weights to produce one seamless map.

    The predictor never changes the model's train/eval mode; callers are
    expected to call ``model.eval()`` themselves.
    """
    def __init__(self, model, device, window_size=256, overlap=64):
        """
        Initializes the predictor.
        Args:
            model (torch.nn.Module): Trained segmentation model.
            device (str or torch.device): Device to run the model on ('cpu' or 'cuda').
            window_size (int): The size of the processing window (patch).
            overlap (int): Overlap between windows for smooth blending.
        Raises:
            ValueError: If ``overlap`` is negative or not smaller than
                ``window_size`` (the window stride would not be positive).
        """
        if not 0 <= overlap < window_size:
            raise ValueError("overlap must satisfy 0 <= overlap < window_size.")
        self.model = model.to(device)
        self.device = device
        self.window_size = window_size
        self.overlap = overlap
        self.stride = window_size - overlap
        # A ramp for smooth blending at the edges
        self.ramp = np.linspace(0, 1, overlap // 2)
        self.band_stats = {}

    def normalize_band(self, band, scale='mm'):
        """
        Normalizes a single image band according to the selected method.
        Args:
            band (np.ndarray): Band data.
            scale (str): Normalization method: 'mm' (min-max), 'ss' (standard score), or 'div255'.
        Returns:
            np.ndarray: Normalized band. Constant bands yield zeros for 'mm' and 'ss'.
        Raises:
            ValueError: If ``scale`` is not one of the supported methods.
        """
        if scale == 'mm':
            min_val, max_val = band.min(), band.max()
            return (band - min_val) / (max_val - min_val) if max_val > min_val else np.zeros_like(band, dtype=np.float32)
        elif scale == 'ss':
            mean_val, std_val = band.mean(), band.std()
            return (band - mean_val) / std_val if std_val > 0 else np.zeros_like(band, dtype=np.float32)
        elif scale == 'div255':
            return band.astype(np.float32) / 255.0
        else:
            raise ValueError("Invalid normalization option. Use 'mm', 'ss', or 'div255'.")

    def get_blend_weights(self, h, w, y_start, x_start, src_h, src_w):
        """
        Generates spatial blending weights for a patch to ensure smooth transitions.

        Weights are 1 in the centre and ramp down to 0 over ``overlap // 2``
        pixels on every side that has a neighbouring window. Sides that touch
        the border of the (padded) source keep a weight of 1.
        Args:
            h (int): Patch height.
            w (int): Patch width.
            y_start (int): Row offset of the patch in the source.
            x_start (int): Column offset of the patch in the source.
            src_h (int): Height of the source the patch was cut from.
            src_w (int): Width of the source the patch was cut from.
        Returns:
            torch.Tensor: Float32 weights of shape (h, w) on ``self.device``.
        """
        weights = np.ones((h, w), dtype=np.float32)
        half_overlap = self.overlap // 2

        # Without a ramp (overlap < 2) there is nothing to blend. The guard is
        # also required because ``weights[-0:]`` selects the whole patch, not
        # an empty slice, which would fail to broadcast with the empty ramp.
        if half_overlap > 0:
            if y_start > 0:
                weights[:half_overlap, :] *= self.ramp[:, np.newaxis]
            if y_start + h < src_h:
                weights[-half_overlap:, :] *= self.ramp[::-1, np.newaxis]
            if x_start > 0:
                weights[:, :half_overlap] *= self.ramp[np.newaxis, :]
            if x_start + w < src_w:
                weights[:, -half_overlap:] *= self.ramp[np.newaxis, ::-1]

        return torch.from_numpy(weights).to(self.device)

    def _padded_size(self, size):
        """Returns the smallest length >= ``size`` that the windows tile exactly."""
        # Ceil division; at least one window even when size <= overlap.
        n_windows = max(1, -(-(size - self.overlap) // self.stride))
        return n_windows * self.stride + self.overlap

    def _predict_array(self, image, progress=None):
        """
        Runs blended sliding-window inference on an in-memory image.
        Args:
            image (np.ndarray): Normalised image of shape (bands, height, width).
            progress (callable, optional): Wraps the list of window offsets,
                e.g. to display a progress bar.
        Returns:
            np.ndarray: Float32 probability map of shape (height, width).
        """
        orig_height, orig_width = image.shape[1], image.shape[2]

        # Pad bottom/right so that the windows cover the image exactly.
        pad_h = self._padded_size(orig_height) - orig_height
        pad_w = self._padded_size(orig_width) - orig_width
        padded_image = np.pad(image, ((0, 0), (0, pad_h), (0, pad_w)), mode='reflect')
        padded_height, padded_width = padded_image.shape[1], padded_image.shape[2]

        # Accumulators for the weighted predictions and for the weights.
        full_pred = np.zeros((padded_height, padded_width), dtype=np.float32)
        full_count = np.zeros((padded_height, padded_width), dtype=np.float32)

        offsets = [
            (y, x) for y in range(0, padded_height - self.overlap, self.stride)
                   for x in range(0, padded_width - self.overlap, self.stride)
        ]
        if progress is not None:
            offsets = progress(offsets)

        for y_start, x_start in offsets:
            y_end, x_end = y_start + self.window_size, x_start + self.window_size
            chip = padded_image[:, y_start:y_end, x_start:x_end]
            input_tensor = torch.from_numpy(chip).unsqueeze(0).to(self.device)

            with torch.no_grad():
                output = self.model(input_tensor)
                pred = torch.sigmoid(output).squeeze().cpu().numpy()

            h, w = pred.shape
            # Converted to NumPy once and reused for both accumulators.
            weights = self.get_blend_weights(
                h, w, y_start, x_start, padded_height, padded_width
            ).cpu().numpy()

            full_pred[y_start:y_end, x_start:x_end] += pred * weights
            full_count[y_start:y_end, x_start:x_end] += weights

        # Weighted average. ``out`` is given explicitly so that pixels with a
        # zero weight sum are 0 instead of uninitialised memory.
        blended = np.divide(
            full_pred, full_count,
            out=np.zeros_like(full_pred),
            where=full_count > 0,
        )

        # Remove padding to return to original dimensions
        return blended[:orig_height, :orig_width]

    def predict_geotiff(self, input_path, output_path, return_probs=True, scaler='ss'):
        """
        Runs prediction over a GeoTIFF file using sliding window inference.
        Args:
            input_path (str): Path to the input GeoTIFF file.
            output_path (str): Path to save the predicted output GeoTIFF.
            return_probs (bool): If True, saves a probability map; otherwise, saves a binary mask.
            scaler (str): Normalization method ('mm', 'ss', or 'div255').
        """
        # The whole raster is read into memory, so the input can be closed
        # before inference and before the output is written.
        with rasterio.open(input_path) as src:
            full_image = src.read().astype(np.float32)
            profile = src.profile

        print(full_image.shape)
        if full_image.shape[0] == 4:
            # Four-band inputs lose their last band.
            # NOTE(review): presumably a mask/alpha band - confirm.
            full_image = full_image[:-1, :, :]
        print(full_image.shape)

        # Normalize each band of the image
        normalized_image = np.array(
            [self.normalize_band(full_image[b], scaler) for b in range(full_image.shape[0])]
        )

        final_pred = self._predict_array(
            normalized_image,
            progress=lambda items: tqdm(items, desc="Processing windows"),
        )

        if not return_probs:
            final_pred = (final_pred > 0.5).astype(np.uint8)

        self.save_geotiff(output_path, final_pred, profile, return_probs)

    def save_geotiff(self, output_path, data, profile, return_probs):
        """
        Saves the prediction output as a single-band, LZW-compressed GeoTIFF.
        Args:
            output_path (str): Destination file.
            data (np.ndarray): 2-D array to write as band 1.
            profile (Mapping): Raster profile of the input; it is copied, the
                caller's object is left untouched.
            return_probs (bool): Selects the output dtype: float32 when True,
                uint8 otherwise.
        """
        out_profile = dict(profile)
        out_profile.update({
            'driver': 'GTiff',
            'height': data.shape[0],
            'width': data.shape[1],
            'count': 1,
            'dtype': 'float32' if return_probs else 'uint8',
            'nodata': None,
            'compress': 'lzw'
        })
        with rasterio.open(output_path, 'w', **out_profile) as dst:
            dst.write(data, 1)

def process_directory(input_dir, output_dir, predictor, return_probs=True, scaler='ss'):
    """
    Processes all .tif files in a directory using the GeoTIFFPredictor.

    Only files directly inside ``input_dir`` are considered (no recursion).
    Every input ``<name>.tif`` produces ``<output_dir>/<name>_pred.tif``.

    Args:
        input_dir (str): Directory holding the input GeoTIFFs.
        output_dir (str): Directory for the predictions; created if missing.
        predictor: Object exposing ``predict_geotiff(input, output, return_probs, scaler)``.
        return_probs (bool): Forwarded to the predictor.
        scaler (str): Forwarded to the predictor.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Sorted so that every run handles the files in the same order; the order
    # returned by glob depends on the filesystem.
    input_files = sorted(Path(input_dir).glob("*.tif"))

    for input_file in tqdm(input_files, desc="Processing files"):
        output_file = Path(output_dir) / f"{input_file.stem}_pred.tif"
        print(f"\nPredicting on {input_file.name} -> {output_file.name}")
        predictor.predict_geotiff(str(input_file), str(output_file), return_probs, scaler)


In [None]:
# State codes of the model and of the area being predicted.
# NOTE(review): neither constant is referenced further down in this cell.
UF_MODEL = "RS"
UF_PREDICT = "RS"

MODEL_CHECKPOINT_NAME = "trained_model_name" # placeholder: checkpoint file name

# Placeholder paths; predictions are written to a sub-folder of the inputs.
CHECKPOINT_DIR = f"/path/trained_model/checkpoints/" # NOTE(review): original note here read "SC" - confirm it matches UF_MODEL
INPUT_TIFF_DIR = f"/path/tiff_input"
OUTPUT_PRED_DIR = f"/path/tiff_input/pred"

# Construct the full path to the model checkpoint
checkpoint_path = os.path.join(CHECKPOINT_DIR, MODEL_CHECKPOINT_NAME)

# Ensure output directory exists
os.makedirs(OUTPUT_PRED_DIR, exist_ok=True)

# --- Model and Predictor Setup ---
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the model and load the trained weights
# The checkpoint is expected to be a state dict matching UNet(in_channels=3).
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
model = UNet(in_channels=3).to(device)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval() # Evaluation mode: batch-norm uses its stored statistics

# Initialize the predictor
# 256-pixel windows that overlap by 32 pixels.
predictor = GeoTIFFPredictor(model, device, window_size=256, overlap=32)

# Process all GeoTIFF files in the specified directory
print(f"Starting inference on files in: {INPUT_TIFF_DIR}")
process_directory(
    input_dir=INPUT_TIFF_DIR,
    output_dir=OUTPUT_PRED_DIR,
    predictor=predictor,
    return_probs=True,      # Save as a probability map (float32)
    scaler='div255'         # Inputs are 8-bit tiles, so divide by 255
)
print("Inference complete.")

In [None]:
import matplotlib.pyplot as plt
import rasterio
import os

def _prepare_display_image(input_img):
    """
    Converts a (bands, height, width) raster into something imshow can draw.

    Args:
        input_img: Array read from the input raster.

    Returns:
        ``(array, cmap)``: for three or more bands, the first three bands as a
        (height, width, 3) float array stretched to [0, 1] and ``None``; for a
        single band, that band and ``'gray'``. ``None`` when the image has a
        band count that cannot be displayed (0 or 2).
    """
    band_count = input_img.shape[0]
    if band_count >= 3:
        # Cast first: the range of a narrow integer dtype can overflow.
        rgb = input_img[:3].transpose(1, 2, 0).astype('float32')
        low, high = rgb.min(), rgb.max()
        if high > low:
            rgb = (rgb - low) / (high - low)
        else:
            # Constant image: avoid dividing by zero, show it as black.
            rgb = rgb * 0.0
        return rgb, None
    if band_count == 1:
        return input_img[0], 'gray'
    return None


def visualize_predictions(input_dir, prediction_dir):
    """
    Displays every input image next to its prediction.

    Inputs and predictions are paired by file name: the prediction of
    ``<name>.tif`` is expected to be called ``<name>_pred.tif``. Inputs
    without a prediction are reported and skipped; errors while reading or
    plotting one pair are printed and the remaining pairs are still shown.

    Args:
        input_dir (str): Directory containing the input images (.tif).
        prediction_dir (str): Directory containing the prediction images (.tif).
    """
    input_files = sorted([f for f in os.listdir(input_dir) if f.endswith('.tif')])
    prediction_files = sorted([f for f in os.listdir(prediction_dir) if f.endswith('.tif')])

    # Input file name -> prediction file name.
    prediction_map = {f.replace('_pred.tif', '.tif'): f for f in prediction_files}

    if not input_files:
        print(f"Nenhum arquivo .tif encontrado em {input_dir}")
        return
    if not prediction_files:
        print(f"Nenhum arquivo .tif encontrado em {prediction_dir}")
        return

    print(f"Encontrados {len(input_files)} arquivos de entrada e {len(prediction_files)} arquivos de predição.")

    for input_filename in input_files:
        if input_filename not in prediction_map:
            print(f"Predição não encontrada para a imagem de entrada: {input_filename}")
            continue

        prediction_filename = prediction_map[input_filename]
        input_path = os.path.join(input_dir, input_filename)
        prediction_path = os.path.join(prediction_dir, prediction_filename)

        try:
            with rasterio.open(input_path) as src_input, \
                 rasterio.open(prediction_path) as src_pred:
                input_img = src_input.read()
                pred_img = src_pred.read(1)  # Predictions have a single band
                pred_dtype = src_pred.profile['dtype']

            prepared = _prepare_display_image(input_img)
            if prepared is None:
                print(f"Warning: Could not display input image {input_filename}. Needs 1 or at least 3 bands.")
                continue
            display_input, input_cmap = prepared

            fig, axes = plt.subplots(1, 2, figsize=(12, 6))

            axes[0].imshow(display_input, cmap=input_cmap)
            axes[0].set_title(f'Imagem de Entrada: {input_filename}')
            axes[0].axis('off')

            # Probability maps (float32) get a colour map, binary masks gray.
            cmap = 'viridis' if pred_dtype == 'float32' else 'gray'
            axes[1].imshow(pred_img, cmap=cmap)
            axes[1].set_title(f'Predição: {prediction_filename}')
            axes[1].axis('off')

            plt.tight_layout()
            plt.show()
            # Release the figure so long runs do not accumulate open figures.
            plt.close(fig)

        except Exception as e:
            # Best effort: one unreadable pair must not stop the others.
            print(f"Erro ao processar {input_filename} ou {prediction_filename}: {e}")

# Directories of the original tiles and of their predictions; they reuse the
# constants of the inference cell so both steps look at the same files.
input_image_directory = INPUT_TIFF_DIR # Folder scanned for the input .tif files
predicted_image_directory = OUTPUT_PRED_DIR

# Show every input next to its prediction.
visualize_predictions(input_image_directory, predicted_image_directory)
