In [None]:
# USE THE GPUS-T4 ENVIRONMENT (HIGH RAM)

## RAW DATASET EXPORT

In [None]:
!pip install geedim

In [None]:
import ee
from geedim.mask import MaskedImage
import geedim as gd
import geemap
import os

# Authenticate against Earth Engine and bind the session to the Cloud project.
ee.Authenticate()
ee.Initialize(project='mapbiomas-remap-sentinel')


# geedim needs its own initialisation before any download call.
gd.Initialize()

In [None]:
# Complete grid of the area containing the references.
# It is wrapped in an ee.FeatureCollection (not a plain Python list) because the
# following cells call .first(), .map() and .getInfo() on it, none of which exist
# on a list.

gridCitrus = ee.FeatureCollection([ee.Feature(
            ee.Geometry.Polygon(
                [[[-48.70588801192858, -20.035935384858842],
                  [-49.57930109786608, -20.350421721213056],
                  [-49.54634211349108, -20.931300558885997],
                  [-48.77729914474108, -21.090266628321324],
                  [-48.75532648849108, -21.56614118525384],
                  [-50.29890559005358, -22.223641598287102],
                  [-50.24946711349108, -22.629142663970672],
                  [-49.04646418380358, -23.14533155137643],
                  [-49.26619074630358, -23.55383721044761],
                  [-49.05745051192858, -23.68972531577456],
                  [-48.78279230880358, -23.37243440542318],
                  [-48.54658625411608, -23.558872615027653],
                  [-46.66792414474108, -21.89200397442101],
                  [-47.14033625411608, -21.243221782473224],
                  [-48.01374934005358, -21.437649544995725]]]),
            {
              "system:index": "0"
            })])

In [None]:
# Settings
# Date window applied to the merged Landsat collection via filterDate().
startDate = '2020-01-01'
endDate = '2020-12-31'


# chirp_scale is passed as `scale` to reduceRegion and to the download calls;
# chirp_size is the tile edge in pixels, so chirp_size_m is the tile edge
# expressed in the units of chirp_scale (used for coveringGrid).
chirp_scale, chirp_size = 30, 1024
chirp_size_m = chirp_scale * chirp_size

# Upper bound used in the CLOUD_COVER_LAND filter of filter_landsat().
cloudCoverValue = 80
# State code matched against the SIGLA_UF property to build the region of interest.
uf_code = 'SP'

# NOTE(review): this value is reassigned in the export cell before it is ever used.
output_folder = f'DATASET_CITRUS_PERC_{uf_code}'

# collection and input layers
ref_map = ee.FeatureCollection('projects/assets/reference_map')
estados = ee.FeatureCollection('regions/ibge_estados_2019')
# Projection of the first grid feature; reused by split_pol() to build the tile grid.
proj = gridCitrus.first().geometry().projection()

# Functions
def filter_landsat(path, roi, start, end, cloud_max):
    """Return the scenes of one Landsat collection restricted in time, space and cloudiness.

    Args:
        path: Earth Engine ImageCollection asset id.
        roi: Geometry or collection passed to ``filterBounds``.
        start: Start of the date range passed to ``filterDate``.
        end: End of the date range passed to ``filterDate``.
        cloud_max: Scenes are kept only when ``CLOUD_COVER_LAND`` is below this value.

    Returns:
        The filtered ``ee.ImageCollection``.
    """
    cloud_filter = ee.Filter.lt('CLOUD_COVER_LAND', cloud_max)

    scenes = ee.ImageCollection(path)
    scenes = scenes.filterDate(start, end)
    scenes = scenes.filterBounds(roi)
    return scenes.filter(cloud_filter)


def padronize_band_names(image):
    """Rename sensor-specific Landsat bands to one common set of names.

    The source band list is looked up from the image's ``SPACECRAFT_ID`` property
    and mapped, position by position, onto
    ``blue, green, red, nir, swir1, swir2, tir1, BQA``.

    Args:
        image: Landsat image carrying a ``SPACECRAFT_ID`` property.

    Returns:
        The result of ``ee.Algorithms.If``: the renamed image when a band list was
        found for the spacecraft, otherwise the input image. Callers wrap the
        result in ``ee.Image``.
    """
    new_band_names = ['blue', 'green', 'red', 'nir', 'swir1', 'swir2', 'tir1', 'BQA']

    bands_by_spacecraft = ee.Dictionary({
        'LANDSAT_5': ['B1', 'B2', 'B3', 'B4', 'B5', 'B7', 'B6', 'QA_PIXEL'],
        'LANDSAT_7': ['B1', 'B2', 'B3', 'B4', 'B5', 'B7', 'B6_VCID_1', 'QA_PIXEL'],
        'LANDSAT_8': ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B10', 'QA_PIXEL'],
        'LANDSAT_9': ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B10', 'QA_PIXEL']
    })
    old_band_names = bands_by_spacecraft.get(image.get('SPACECRAFT_ID'))

    renamed = image.select(ee.List(old_band_names), new_band_names)
    return ee.Algorithms.If(old_band_names, renamed, image)


def mask_clouds(image):
    """Mask out pixels flagged in the ``BQA`` band.

    A pixel is rejected when any of these bit tests on ``BQA`` is true:
    bit 3 together with bit 8 or 9; bit 1; bit 4 together with bit 10 or 11;
    bit 5; bit 7; bits 14 and 15 together.
    NOTE(review): the group names below assume the Landsat Collection 2 QA_PIXEL
    bit layout - confirm against the product guide.

    Args:
        image: Image that contains a ``BQA`` band.

    Returns:
        ``image`` with its mask updated so rejected pixels are masked.
    """
    qa = image.select('BQA')

    def flag(bit):
        # Non-zero wherever the given QA bit is set.
        return qa.bitwiseAnd(1 << bit)

    cloud = flag(3).And(flag(8).Or(flag(9)))
    dilated_cloud = flag(1)
    shadow = flag(4).And(flag(10).Or(flag(11)))
    snow = flag(5)
    water = flag(7)
    cirrus = flag(14).And(flag(15))

    rejected = (cloud
                .Or(dilated_cloud)
                .Or(shadow)
                .Or(snow)
                .Or(water)
                .Or(cirrus))
    return image.updateMask(rejected.Not())

def normalize_band(band_name, image, p1, p99):
    """Rescale one band with ``unitScale(p1, p99)``, clamp it to [0, 1] and rename it.

    Args:
        band_name: Name of the band to select from ``image``.
        image: Source image.
        p1: Lower bound passed to ``unitScale``.
        p99: Upper bound passed to ``unitScale``.

    Returns:
        A single-band image named ``<band_name>_norm``.
    """
    band = image.select(band_name)
    stretched = band.unitScale(p1, p99)
    bounded = stretched.clamp(0, 1)
    return bounded.rename(f'{band_name}_norm')

def get_evi2(image):
    """Append an ``evi2`` band computed from ``nir_norm`` and ``red_norm``.

    Args:
        image: Image containing the ``nir_norm`` and ``red_norm`` bands.

    Returns:
        ``image`` with the extra ``evi2`` band added.
    """
    formula = '2.5 * (NIR - RED) / (NIR + 2.4 * RED + 1)'
    operands = {
        'NIR': image.select('nir_norm'),
        'RED': image.select('red_norm')
    }
    evi2_band = image.expression(formula, operands).rename('evi2')
    return image.addBands(evi2_band)

def get_ndwi(image):
    """Append an ``ndwi`` band computed from ``nir_norm`` and ``swir1_norm``.

    Args:
        image: Image containing the ``nir_norm`` and ``swir1_norm`` bands.

    Returns:
        ``image`` with the extra ``ndwi`` band added.
    """
    formula = '(NIR - SWIR1) / (NIR + SWIR1)'
    operands = {
        'NIR': image.select('nir_norm'),
        'SWIR1': image.select('swir1_norm')
    }
    ndwi_band = image.expression(formula, operands).rename('ndwi')
    return image.addBands(ndwi_band)

# Region of interest: the state features whose SIGLA_UF property equals uf_code.
roi = estados.filter(ee.Filter.eq('SIGLA_UF', uf_code))


# Polygon division function (split_pol)
def split_pol(ft):
    """Split one grid feature into covering-grid cells, each tagged with a derived ``id``.

    The parent id is the feature's ``id`` property when present, otherwise
    ``grid_<system:index>``. Each cell id is ``<parent id>_<cell index>``, where
    the cell index is the cell's ``system:index`` with commas turned into
    underscores and ``replace('-', '1')`` applied.
    NOTE(review): whether ``replace`` affects every '-' or only the first one
    depends on ee.String.replace semantics - confirm if ids must be unique.

    Args:
        ft: Feature to split. Uses the module-level ``proj``, ``chirp_size`` and
            ``chirp_size_m``.

    Returns:
        The mapped covering-grid collection; every cell carries the properties
        copied from ``ft`` plus its own ``id``.
    """
    source_id = ft.get('id')

    # Fall back to a synthetic id when the feature has no 'id' property.
    parent_id = ee.String(ee.Algorithms.If(
        source_id,
        source_id,
        ee.String('grid_').cat(ee.String(ft.get('system:index')))
    ))

    reprojected = ft.transform(proj.atScale(chirp_size), 1)

    def tag_cell(cell):
        cell = ee.Feature(cell)
        raw_index = ee.String(cell.get('system:index'))
        cell_suffix = raw_index.split(',').join('_').replace('-', '1')
        cell_id = parent_id.cat('_').cat(cell_suffix)
        return cell.copyProperties(ft).set('id', cell_id)

    grid_cells = reprojected.geometry().coveringGrid(proj, chirp_size_m)
    return grid_cells.map(tag_cell)


# Split every grid feature into cells and flatten the per-feature results
# into a single collection.
bigs_splitted = gridCitrus.map(split_pol).flatten()


# Debug output: fetches the grid from the server and prints it.
print(gridCitrus.getInfo())

# Reference
# Binary label image: 1 where ref_map features were painted, 0 elsewhere, clipped to roi.
reference = ee.Image(0).paint(ref_map, 1).rename('reference').clip(roi)


# Collection Landsat
# One filtered collection per sensor, each with its own date range.
l5 = filter_landsat("LANDSAT/LT05/C02/T1_TOA", roi, "2000-01-01", "2011-10-01", cloudCoverValue)
l7a = filter_landsat("LANDSAT/LE07/C02/T1_TOA", roi, "2000-01-01", "2003-05-31", cloudCoverValue)
l7b = filter_landsat("LANDSAT/LE07/C02/T1_TOA", roi, "2011-10-01", "2013-03-01", cloudCoverValue)
l8 = filter_landsat("LANDSAT/LC08/C02/T1_TOA", roi, "2013-03-01", "2030-01-01", cloudCoverValue)
l9 = filter_landsat("LANDSAT/LC09/C02/T1_TOA", roi, "2019-03-01", "2030-01-01", cloudCoverValue)

# Assemble and process the collection
# Band names are harmonised and flagged pixels masked before the startDate/endDate
# window is applied.
collection = l8.merge(l9).merge(l7a).merge(l7b).merge(l5) \
    .map(lambda img: ee.Image(padronize_band_names(img))) \
    .map(mask_clouds) \
    .filterDate(startDate, endDate)

# Per-pixel median composite over the filtered scenes.
median = collection.median()

# Calculates percentiles in the reference area
# Only pixels where `reference` is non-zero contribute to the statistics.
masked = median.updateMask(reference)
bands = ['blue', 'green', 'red', 'nir', 'swir1', 'swir2']
percentiles = masked.select(bands).reduceRegion(
    reducer=ee.Reducer.percentile([1, 99]),
    geometry=roi.geometry(),
    scale=chirp_scale,
    maxPixels=1e13
)


# Normalized bands
# Each band of the (unmasked) median is stretched between its 1st and 99th percentile.
norm_bands = []
for b in bands:
    p1 = ee.Number(percentiles.get(f'{b}_p1'))
    p99 = ee.Number(percentiles.get(f'{b}_p99'))
    norm_bands.append(normalize_band(b, median, p1, p99))

normalized = ee.Image(norm_bands).toFloat()

# Three-band input used for the U-Net patches.
mosaic_unet = normalized.select(['red_norm', 'nir_norm', 'swir1_norm'])


# Masked pixels are unmasked before the [0, 1] values are scaled to 0-255 and cast to uint8.
image_to_export = mosaic_unet.unmask().multiply(255).uint8()
label_to_export = reference.unmask().byte()


# Image bands followed by the 'reference' label band.
final_image_to_export = image_to_export.addBands(label_to_export)

In [None]:
# Export the tiles with geedim

# Ids of all grid cells (one round-trip to the server).
grid_list = bigs_splitted.aggregate_array('id').getInfo()

output_folder = '/path/tile_export'
os.makedirs(output_folder, exist_ok=True)

# Wrap the Earth Engine image so geedim can download it.
gd_image = gd.MaskedImage(final_image_to_export)

# Two rasters are written per grid cell: the three normalised bands and the label.
# The tag is used both as filename prefix and suffix, which is the naming the
# windowing cell expects (mosaic_<id>_mosaic.tif / label_<id>_label.tif).
export_specs = [
    ('mosaic', ['red_norm', 'nir_norm', 'swir1_norm']),
    ('label', ['reference']),
]

print(f"Export started: {len(grid_list)} grid cells will be written to '{output_folder}'.")

# Download arguments:
#    - region: geometry of the grid cell with the given 'id'.
#    - scale / crs: output pixel size and projection.
#    - dtype: uint8, the image was already scaled to 0-255 before export.
for grid_id in grid_list:
    # Build the cell geometry once and reuse it for both rasters.
    region = bigs_splitted.filter(ee.Filter.eq('id', grid_id)).geometry()

    for tag, band_names in export_specs:
        gd_image.download(
            filename=f'{output_folder}/{tag}_{grid_id}_{tag}.tif',
            region=region,
            scale=chirp_scale,
            crs='EPSG:3857',
            overwrite=True,
            bands=band_names,
            resampling='near',
            dtype='uint8',
            scale_offset=None
        )

print(f"\nExport finished. The images were saved at path '{output_folder}'.")
print(f"A mosaic and a label .tif were written for each of the {len(grid_list)} geometries of your grid.")

# WINDOWED DATASET


In [None]:
"""
This script processes the large GeoTIFF files exported from GEE. It uses a sliding
window approach to create smaller, fixed-size patches (e.g., 256x256 pixels).
These patches are saved as individual .npy files and will serve as the input for
the U-Net model.
"""

import os
import numpy as np
from tqdm import tqdm
import rasterio
import re


# Folder holding the exported GeoTIFFs, and the folders that receive the patches.
input_dir = '/path/tile_export'
output_img_dir = '/path/tile_export/image'
output_label_dir = '/path/tile_export/gt'


os.makedirs(output_img_dir, exist_ok=True)
os.makedirs(output_label_dir, exist_ok=True)

# windowing parameters
patch_size = 256
# Step between consecutive windows; 0.5 * patch_size means neighbouring patches
# overlap by half. The commented-out lines are alternative overlaps.
#stride = int(patch_size * 0.875)
stride = int(patch_size * 0.5)
#stride = int(patch_size * 0.125)
print(stride)

# windowing function
def sliding_window(image, label, patch_size, stride):
    """Yield aligned (image, label) patches that contain at least one non-zero label.

    Windows start at multiples of ``stride`` and must fit entirely inside the
    array, so a trailing border narrower than ``patch_size`` is not visited.

    Args:
        image: Array whose first two axes are (rows, cols).
        label: Array indexed with the same (rows, cols) windows.
        patch_size: Edge length of the square window, in pixels.
        stride: Step between consecutive window origins, in pixels.

    Yields:
        ``(img_patch, label_patch)`` for every window where ``np.any(label_patch)``.
    """
    height, width = image.shape[:2]
    row_starts = range(0, height - patch_size + 1, stride)
    col_starts = range(0, width - patch_size + 1, stride)

    for top in row_starts:
        rows = slice(top, top + patch_size)
        for left in col_starts:
            cols = slice(left, left + patch_size)
            label_patch = label[rows, cols]

            # Windows without any positive label are skipped.
            if not np.any(label_patch):
                continue
            yield image[rows, cols], label_patch

# Walk the export folder and cut every mosaic/label pair into patches.
for file in tqdm(os.listdir(input_dir)):
    # Only files named mosaic_<id>_mosaic.tif are processed.
    if file.endswith('mosaic.tif') and file.startswith('mosaic_'):
        match = re.match(r'mosaic_(.*)_mosaic\.tif', file)
        if not match:
            print(f"Nome inválido: {file}, pulando.")
            continue
        # <id> part of the file name, shared by the mosaic and its label.
        nome_base = match.group(1)

        label_file = f"label_{nome_base}_label.tif"

        img_path = os.path.join(input_dir, file)
        label_path = os.path.join(input_dir, label_file)

        print(img_path)
        print(label_path)

        # A mosaic without a matching label file is skipped.
        if not os.path.exists(label_path):
            print(f"Rótulo não encontrado para {file}, pulando.")
            continue

        with rasterio.open(img_path) as src:
            image = src.read()
        # Move the first axis returned by src.read() to the end, so the first
        # two axes are the ones sliding_window() slices.
        image = image.transpose(1, 2, 0)

        with rasterio.open(label_path) as src:
            # Only the first band of the label raster is read.
            label = src.read(1)

        print(f"Shape original da imagem: {image.shape}")
        print(f"Shape original do label: {label.shape}")


        # Apply windowing
        # Image and label patches are saved under the same file name in their
        # respective folders.
        for idx, (img_patch, label_patch) in enumerate(sliding_window(image, label, patch_size, stride)):
            img_save_path = os.path.join(output_img_dir, f"{nome_base}_{idx:04d}.npy")
            label_save_path = os.path.join(output_label_dir, f"{nome_base}_{idx:04d}.npy")

            np.save(img_save_path, img_patch)
            np.save(label_save_path, label_patch)


In [None]:
import matplotlib.pyplot as plt
import random
import os
import numpy as np


# Names of the saved patches in the image and ground-truth folders.
image_files = [f for f in os.listdir(output_img_dir) if f.endswith('.npy')]
label_files = [f for f in os.listdir(output_label_dir) if f.endswith('.npy')]


# Sorted so both lists follow the same file-name order.
image_files.sort()
label_files.sort()

if len(image_files) != len(label_files):
    print("Warning: The number of image and label files does not match.")

# Helper that displays one image patch next to its ground truth
def visualize_random_pair():
    """Plot a randomly chosen image patch next to its ground-truth mask.

    The mask is looked up by the image's own file name (the windowing step saves
    both under the same name), so the pairing stays correct even when the two
    folders do not hold the same number of files. Reads the module-level
    ``image_files``, ``output_img_dir`` and ``output_label_dir``.

    Returns:
        None. Prints a message and returns early when there is nothing to show
        or when either file cannot be loaded.
    """
    if not image_files:
        print("No .npy files found in the specified directories.")
        return

    # Pick a random patch; its mask has the same file name in the label folder.
    img_filename = random.choice(image_files)
    label_filename = img_filename

    img_path = os.path.join(output_img_dir, img_filename)
    label_path = os.path.join(output_label_dir, label_filename)

    # Load both arrays
    try:
        image_data = np.load(img_path)
        label_data = np.load(label_path)
        print(image_data.shape)
        print(label_data.shape)
        print(np.unique(label_data))
    except Exception as e:
        print(f"Error loading {img_path} or {label_path} files: {e}")
        return

    # Plot side by side
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    # With three or more channels the first three are shown as an RGB composite
    # (values are assumed to be in 0-255); otherwise the first channel is shown
    # in grayscale. Adjust the channel selection to the bands that were exported.
    if image_data.shape[-1] >= 3:
        axes[0].imshow(image_data[:, :, :3].astype(np.uint8))
    else:
        axes[0].imshow(image_data[:, :, 0], cmap='gray')

    axes[0].set_title(f'Image\n({img_filename})')
    axes[0].axis('off')

    # Ground truth, displayed as a binary (0 or 1) mask
    axes[1].imshow(label_data, cmap='gray', vmin=0, vmax=1)
    axes[1].set_title(f'Ground Truth\n({label_filename})')
    axes[1].axis('off')

    plt.tight_layout()
    plt.show()

# Example: show 5 random pairs
for _ in range(5):
    visualize_random_pair()


# MODEL TRAINING

## UNET

In [None]:
# Install necessary libraries quietly
!pip install -q git+https://github.com/qubvel/segmentation_models.pytorch
!pip install -q rasterio
!pip install -q torchmetrics

In [None]:
import torch
import torch.nn as nn

class UNet(nn.Module):
    """
    U-Net for semantic segmentation: four encoder stages, a bottleneck and four
    decoder stages joined by skip connections, followed by a 1x1 convolution.

    Args:
        in_channels (int): Number of input channels (e.g., 3 for RGB).
        out_channels (int): Number of output channels (e.g., 1 for binary segmentation).
        init_features (int): Number of feature maps produced by the first encoder
            stage; it doubles at every deeper stage.
        no_drop (bool): If True, every dropout slot holds ``nn.Identity`` instead
            of ``nn.Dropout``.
    """
    def __init__(self, in_channels=3, out_channels=1, init_features=64, no_drop=True):
        super().__init__()
        self.no_drop = no_drop  # Global dropout switch

        f = init_features

        def regularizer(p):
            # Dropout with probability p, or a pass-through when dropout is disabled.
            return nn.Identity() if no_drop else nn.Dropout(p)

        def downsample():
            return nn.MaxPool2d(kernel_size=2, stride=2)

        def upsample(c_in, c_out):
            return nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2)

        # Encoder stages (attribute names and order are kept so that saved
        # state dicts remain loadable).
        self.encoder1 = UNet._block(in_channels, f, name="enc1")
        self.pool1 = downsample()
        self.drop1 = regularizer(0.25)

        self.encoder2 = UNet._block(f, f * 2, name="enc2")
        self.pool2 = downsample()
        self.drop2 = regularizer(0.25)

        self.encoder3 = UNet._block(f * 2, f * 4, name="enc3")
        self.pool3 = downsample()
        self.drop3 = regularizer(0.5)

        self.encoder4 = UNet._block(f * 4, f * 8, name="enc4")
        self.pool4 = downsample()
        self.drop4 = regularizer(0.5)

        self.bottleneck = UNet._block(f * 8, f * 16, name="bottleneck")
        self.drop5 = regularizer(0.5)

        # Decoder stages: each block receives the upsampled features concatenated
        # with the skip connection, hence twice the upsampled channel count.
        self.upconv4 = upsample(f * 16, f * 8)
        self.decoder4 = UNet._block(f * 16, f * 8, name="dec4")

        self.upconv3 = upsample(f * 8, f * 4)
        self.decoder3 = UNet._block(f * 8, f * 4, name="dec3")

        self.upconv2 = upsample(f * 4, f * 2)
        self.decoder2 = UNet._block(f * 4, f * 2, name="dec2")

        self.upconv1 = upsample(f * 2, f)
        self.decoder1 = UNet._block(f * 2, f, name="dec1")

        # Final 1x1 projection to the output channels
        self.conv = nn.Conv2d(in_channels=f, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network; the output has ``out_channels`` channels and no final activation."""
        # Contracting path (dropout is applied after each encoder block)
        skip1 = self.drop1(self.encoder1(x))
        skip2 = self.drop2(self.encoder2(self.pool1(skip1)))
        skip3 = self.drop3(self.encoder3(self.pool2(skip2)))
        skip4 = self.drop4(self.encoder4(self.pool3(skip3)))

        deepest = self.drop5(self.bottleneck(self.pool4(skip4)))

        # Expanding path (no dropout): upsample, concatenate the skip, convolve
        up = self.decoder4(torch.cat((self.upconv4(deepest), skip4), dim=1))
        up = self.decoder3(torch.cat((self.upconv3(up), skip3), dim=1))
        up = self.decoder2(torch.cat((self.upconv2(up), skip2), dim=1))
        up = self.decoder1(torch.cat((self.upconv1(up), skip1), dim=1))

        return self.conv(up)

    @staticmethod
    def _block(in_channels, features, name):
        """
        Build one U-Net block: Conv -> BN -> ReLU -> Conv -> BN -> ReLU.
        ``name`` is accepted for call-site readability and is not used.
        """
        layers = []
        channels = in_channels
        for _ in range(2):
            layers.append(nn.Conv2d(channels, features, kernel_size=3, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(features))
            layers.append(nn.ReLU(inplace=True))
            channels = features
        return nn.Sequential(*layers)

## UTIL FUNCTIONS

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
from torchvision import transforms
import random

def set_seed(seed=42):
    """
    Seed Python's ``random``, NumPy and PyTorch with the same value so that the
    train/validation split is reproducible.
    """
    # Order: Python RNG, NumPy RNG, PyTorch RNG
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

# Seed everything BEFORE random_split is called
set_seed(42)

class UniqueDataset(Dataset):
    """Dataset of image patches and their masks, stored as same-named ``.npy`` files.

    Image patches are expected in (H, W, C) layout, which is how the windowing
    step of this notebook writes them; masks are (H, W).
    """

    def __init__(self, img_dir, gt_dir, transform=None, scale='mm'):
        """
        :param img_dir: Directory with input image patches (.npy files).
        :param gt_dir: Directory with ground truth masks (.npy files).
        :param transform: Optional callable applied to the scaled (H, W, C) float32
            array (e.g. a torchvision pipeline starting with ToTensor).
        :param scale: Normalization method ('mm', 'ss', 'mmpc', or 'div255').
            Any other value leaves the pixel values unchanged.
        """
        self.img_dir = img_dir
        self.gt_dir = gt_dir
        self.transform = transform
        self.scale = scale
        # Sorted so the index -> file mapping is the same on every run (os.listdir
        # order is arbitrary), and restricted to .npy so stray entries such as
        # '.ipynb_checkpoints' are not treated as samples.
        self.img_names = sorted(
            name for name in os.listdir(img_dir) if name.endswith('.npy')
        )

    def __len__(self):
        return len(self.img_names)

    @staticmethod
    def _scale_channels(image, scale):
        """Normalize a float (H, W, C) array with the requested method.

        'mm', 'ss' and 'mmpc' work on each channel (last axis) independently and
        write the result back into ``image``; a channel with no spread is set to 0
        instead of dividing by zero. 'div255' divides everything by 255.
        """
        if scale == 'div255':  # Simple division by 255
            return image / 255.0
        if scale not in ('mm', 'ss', 'mmpc'):
            return image

        for c in range(image.shape[-1]):
            channel = image[..., c]
            if scale == 'mm':  # Min-Max scaling
                low, high = channel.min(), channel.max()
                image[..., c] = (channel - low) / (high - low) if high > low else 0
            elif scale == 'ss':  # Standard score scaling
                std_val = channel.std()
                image[..., c] = (channel - channel.mean()) / std_val if std_val > 0 else 0
            else:  # 'mmpc': Min-Max scaling with percentile clipping
                low = np.percentile(channel, 2)
                high = np.percentile(channel, 98)
                image[..., c] = (np.clip(channel, low, high) - low) / (high - low) if high > low else 0
        return image

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        :return: ``image`` as produced by ``transform`` or, without a transform,
            a float32 tensor of shape (C, H, W); ``mask`` as a uint8 tensor of
            shape (1, H, W).
        :raises FileNotFoundError: If the image or its mask does not exist.
        :raises ValueError: Re-raised from ``np.load`` when the image cannot be read.
        """
        name = self.img_names[idx]
        img_path = os.path.join(self.img_dir, name)
        gt_path = os.path.join(self.gt_dir, name)

        # Fail with a precise message instead of printing and crashing later.
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")
        if not os.path.exists(gt_path):
            raise FileNotFoundError(f"Mask file not found: {gt_path}")

        try:
            image = np.load(img_path).astype(np.float32)
        except ValueError:
            print(f"[ERROR] Failed to load: {img_path}")
            raise

        mask = np.load(gt_path).astype(np.uint8)

        # Normalization
        image = self._scale_channels(image, self.scale)

        if self.transform:
            # The array is already (H, W, C), the layout ToTensor-style
            # transforms expect, so it is passed through untouched.
            image = self.transform(image)
        else:
            # (H, W, C) -> (C, H, W) for PyTorch
            image = torch.from_numpy(image.transpose(2, 0, 1))

        # Convert mask to tensor and add channel dimension
        mask = torch.from_numpy(mask).unsqueeze(0)

        return image, mask


class AugmentDataset(Dataset):
    """Wraps a dataset and appends one deterministic transformed copy per transform.

    Index layout: the first ``len(dataset)`` items are the untouched originals,
    followed by one full pass over the dataset for every entry of
    ``self.transforms``. Image and mask of a sample always receive the same
    transform.
    """

    def __init__(self, dataset, augmentation="both"):
        """
        :param dataset: The original dataset (train or validation).
        :param augmentation: Augmentation type ('rotation', 'flip', 'both', or None).
        :raises ValueError: If ``augmentation`` is a non-empty value other than
            the ones listed above (previously such values were silently ignored).
        """
        if augmentation and augmentation not in ("rotation", "flip", "both"):
            raise ValueError(
                f"Unknown augmentation {augmentation!r}; expected 'rotation', 'flip', 'both' or None."
            )

        self.dataset = dataset
        self.augmentation = augmentation
        self.rotations = [0, 90, 180, 270]
        self.transforms = []

        if augmentation in ["rotation", "both"]:
            # The 0-degree entry is skipped: the untouched originals are already
            # part of the dataset, so it would only duplicate every sample.
            self.transforms.extend(
                T.RandomRotation(degrees=[angle, angle])
                for angle in self.rotations
                if angle % 360 != 0
            )

        if augmentation in ["flip", "both"]:
            self.transforms.append(T.RandomHorizontalFlip(p=1.0))
            self.transforms.append(T.RandomVerticalFlip(p=1.0))

    def __len__(self):
        return len(self.dataset) * (len(self.transforms) + 1)

    def __getitem__(self, idx):
        """Return the (possibly transformed) ``(image, mask)`` pair at ``idx``.

        :raises IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            idx += total  # support negative indexing like a regular sequence
        if not 0 <= idx < total:
            raise IndexError(f"index {idx} out of range for AugmentDataset of size {total}")

        base_size = len(self.dataset)
        image, mask = self.dataset[idx % base_size]

        # Pass 0 returns the originals; pass k > 0 applies transforms[k - 1].
        transform_idx = idx // base_size
        if transform_idx > 0:
            transform = self.transforms[transform_idx - 1]
            image = transform(image)
            mask = transform(mask)

        return image, mask


def get_dataloaders(train_img_dir, train_gt_dir, val_img_dir=None, val_gt_dir=None, batch_size=8, n_workers=4,
                    transforms_flag=False, scale='mm', split_data=False, split_ratio=0.8, aug_t=None, aug_v=None,
                    seed=42):
    """
    Creates dataloaders for training and validation.

    :param train_img_dir: Directory for training images.
    :param train_gt_dir: Directory for training masks.
    :param val_img_dir: Directory for validation images (None if split_data=True).
    :param val_gt_dir: Directory for validation masks (None if split_data=True).
    :param batch_size: Batch size.
    :param n_workers: Number of workers for data loading.
    :param transforms_flag: If True, applies ToTensor plus a fixed mean/std normalization.
    :param scale: Normalization method ('mm', 'ss', 'mmpc', or 'div255').
    :param split_data: If True, automatically splits the training set into train/validation.
    :param split_ratio: Ratio of the training set used for training (e.g., 0.8 = 80% train, 20% val).
    :param aug_t: Augmentation type for training ('rotation', 'flip', 'both', or None).
    :param aug_v: Augmentation type for validation.
    :param seed: Seed of the generator used by the automatic split, so the same
        samples end up in train/validation every time this function is called
        (e.g. when the training cell is re-run to resume from a checkpoint).
        Pass None to draw the split from the global PyTorch RNG instead.
    :return: Dataloaders for training and validation.
    :raises ValueError: If split_data is False and no validation directories are given.
    """

    # Fixed per-channel statistics used only when transforms_flag is True.
    # NOTE(review): six values are listed, so Normalize only fits 6-channel
    # inputs - confirm before enabling it for patches with another band count.
    means = [9575.1400, 9333.1825, 8279.4832, 17096.3165, 14481.1129, 11567.9320]
    stds = [2269.6487, 1782.7761, 1505.7824, 3603.5337, 3851.3550, 3271.9373]

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=means, std=stds),
    ]) if transforms_flag else None

    full_dataset = UniqueDataset(train_img_dir, train_gt_dir, transform=transform, scale=scale)

    if split_data:
        train_size = int(split_ratio * len(full_dataset))
        val_size = len(full_dataset) - train_size
        if seed is None:
            train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
        else:
            # A dedicated generator makes the split independent of how much of
            # the global RNG was consumed before this call.
            split_generator = torch.Generator().manual_seed(seed)
            train_dataset, val_dataset = random_split(
                full_dataset, [train_size, val_size], generator=split_generator
            )
    else:
        train_dataset = full_dataset
        if val_img_dir and val_gt_dir:
            val_dataset = UniqueDataset(val_img_dir, val_gt_dir, transform=transform, scale=scale)
        else:
            raise ValueError("If 'split_data' is False, val_img_dir and val_gt_dir must be provided!")

    # Augmentation is applied to each split only when requested for that split
    if aug_t:
        train_dataset = AugmentDataset(train_dataset, aug_t)
    if aug_v:
        val_dataset = AugmentDataset(val_dataset, aug_v)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=n_workers)
    return train_loader, val_loader

In [None]:
import torch.nn as nn
import os
import torch


class JointLoss(nn.Module):
    """Weighted sum of two loss functions evaluated on the same outputs and targets."""

    def __init__(self, loss1, loss2, weight1=0.5, weight2=0.5):
        """
        Args:
            loss1 (nn.Module): First loss function (e.g., FocalLoss).
            loss2 (nn.Module): Second loss function (e.g., DiceLoss).
            weight1 (float): Multiplier applied to the first loss.
            weight2 (float): Multiplier applied to the second loss.
        """
        super().__init__()
        self.loss1, self.loss2 = loss1, loss2
        self.weight1, self.weight2 = weight1, weight2

    def forward(self, outputs, targets):
        """Return ``weight1 * loss1(outputs, targets) + weight2 * loss2(outputs, targets)``."""
        first_term = self.weight1 * self.loss1(outputs, targets)
        second_term = self.weight2 * self.loss2(outputs, targets)
        return first_term + second_term


class ModelCheckpoint:
    """Callback that keeps a single checkpoint file: the one with the lowest validation loss."""

    def __init__(self, checkpoint_dir, max_saves=1, finetune = False):
        """
        Args:
            checkpoint_dir (str): Directory to save checkpoints; created if missing.
            max_saves (int): Stored on the instance; the saving logic always keeps
                exactly one file.
            finetune (bool): If True, adds a '_finetune' suffix to the checkpoint name.
        """
        self.checkpoint_dir = checkpoint_dir
        self.max_saves = max_saves
        self.finetune = finetune
        self.best_losses = []  # Holds at most one (val_loss, checkpoint_path) tuple
        os.makedirs(checkpoint_dir, exist_ok=True)

    def __call__(self, model, avg_val_acc, val_loss, epoch):
        """
        Save ``model.state_dict()`` when it is the first call or ``val_loss`` beats
        the stored best, deleting the previously saved file in the latter case.

        Args:
            model (torch.nn.Module): The model to save.
            avg_val_acc (float): The average validation accuracy (IoU); only used in the file name.
            val_loss (float): The validation loss value.
            epoch (int): The current epoch number.
        """
        suffix = "_finetune" if self.finetune else ""
        filename = f"checkpoint_epoch_{epoch}_acc_{avg_val_acc:.4f}_loss_{val_loss:.4f}{suffix}.pth"
        checkpoint_path = os.path.join(self.checkpoint_dir, filename)

        is_first_save = not self.best_losses
        if not is_first_save:
            previous_loss, previous_path = self.best_losses[0]

            # Nothing to do unless the new loss is strictly lower.
            if not (val_loss < previous_loss):
                return

            # Drop the file of the previous best before writing the new one.
            if os.path.exists(previous_path):
                os.remove(previous_path)

        self.best_losses = [(val_loss, checkpoint_path)]
        torch.save(model.state_dict(), checkpoint_path)

        # The very first save is silent; later improvements are announced.
        if not is_first_save:
            print(f"Saving new best model: {os.path.basename(checkpoint_path)}")

## MODEL TRAINING

In [None]:
"""
This script orchestrates the training process, bringing together the model, data,
and utilities. It features enhanced progress tracking and a robust training loop.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import os
import time
from tqdm import tqdm
import segmentation_models_pytorch as smp
from torchmetrics.classification import BinaryJaccardIndex, BinaryF1Score
import torch
import gc

# Free up memory before starting training
gc.collect()
torch.cuda.empty_cache()

# --- Global Configurations ---
BATCH_SIZE = 8
LEARNING_RATE = 0.0001
NUM_EPOCHS = 100
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EARLY_STOP_PATIENCE = 15  # Number of epochs with no improvement to wait before stopping
LR_REDUCTION_PATIENCE = 5 # Number of epochs with no improvement to wait before reducing LR

# --- Path Configurations ---
# IMPORTANT: Replace these paths with your own directories.
# The image/gt folders are the ones written by the windowing cell.
CHECKPOINT_DIR = "/path/tile_export/checkpoints"
TRAIN_IMG_DIR = "/path/tile_export/image/"
TRAIN_GT_DIR = "/path/tile_export/gt/"

# --- Initial Setup ---
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# --- Utility Functions ---
def save_last_checkpoint(model, optimizer, epoch, checkpoint_dir):
    """Persist the resumable training state as ``checkpoint_last.pth``.

    The state is first written to a temporary sibling file and then moved into
    place with ``os.replace``. An interruption during the write (killed kernel,
    disconnected session) therefore cannot leave a truncated
    ``checkpoint_last.pth`` behind that would break the next resume.

    Args:
        model: Object exposing ``state_dict()``; saved under ``'model_state_dict'``.
        optimizer: Object exposing ``state_dict()``; saved under
            ``'optimizer_state_dict'``.
        epoch: Value saved under ``'epoch'``. The training loop passes the
            number of completed epochs (``epoch + 1``).
        checkpoint_dir: Directory that receives the file; created when missing.
    """
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    final_path = os.path.join(checkpoint_dir, "checkpoint_last.pth")
    tmp_path = final_path + ".tmp"

    try:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, tmp_path)
        # Replaces the previous checkpoint only once the new one is fully written.
        os.replace(tmp_path, final_path)
    finally:
        # Only reached with the temp file still present when saving failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

# --- Data Loading ---
# A single call builds both loaders: the training folders are split into a
# training subset and a validation subset.
train_loader, val_loader = get_dataloaders(
    train_img_dir=TRAIN_IMG_DIR,
    train_gt_dir=TRAIN_GT_DIR,
    batch_size=BATCH_SIZE,
    split_data=True,      # derive the validation set from the training data
    split_ratio=0.8,      # 80% training / 20% validation
    scale='div255',       # normalise by dividing by 255
    aug_t='rotation',     # augmentation mode applied to the training subset
    aug_v='rotation',     # the same mode is applied to the validation subset
)

# --- Model, Optimizer, and Loss Initialization ---
model = UNet(in_channels=3)
model = model.to(DEVICE)
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=5e-3,
)

# Evaluation metrics, kept on the same device as the model outputs.
metric_iou = BinaryJaccardIndex().to(DEVICE)   # Jaccard Index / IoU
metric_dice = BinaryF1Score().to(DEVICE)       # F1 Score / Dice Coefficient

# Training objective: a single focal loss in binary mode.
criterion = smp.losses.FocalLoss(mode="binary")
# Alternative (kept for reference): combine focal and Jaccard losses.
# focal_loss = smp.losses.FocalLoss(mode="binary")
# iou_loss = smp.losses.JaccardLoss(mode="binary")
# criterion = JointLoss(focal_loss, iou_loss, weight1=0.5, weight2=0.5).to(DEVICE)

# --- Resume from Checkpoint if available ---
# `checkpoint_last.pth` holds the model/optimizer state plus the number of
# epochs already completed, so the loop below continues with the next epoch.
# NOTE(review): best_val_loss and the patience counters are not stored in the
# checkpoint, so they start fresh after a resume.
start_epoch = 0
last_checkpoint_path = os.path.join(CHECKPOINT_DIR, "checkpoint_last.pth")
if os.path.exists(last_checkpoint_path):
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU still loads on a CPU-only runtime.
    checkpoint = torch.load(last_checkpoint_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    print(f"\n▶ Resuming training from epoch {start_epoch+1}")

# --- Training Setup ---
# Callback that writes the best-model checkpoint into CHECKPOINT_DIR.
checkpoint_callback = ModelCheckpoint(CHECKPOINT_DIR, max_saves=1)

# Wall-clock reference for the elapsed-time / ETA report.
total_start_time = time.time()

# Lowest validation loss seen so far; infinity guarantees the first epoch counts
# as an improvement.
best_val_loss = float("inf")

# Epochs since the last improvement: one counter drives early stopping, the
# other drives learning-rate reduction.
patience_counter = lr_patience_counter = 0

# --- Informative Header ---
# One-off banner summarising the run configuration.
print(f"\n{'='*65}")
print(f"Training Started | Device: {DEVICE}")
print(f"Batch Size: {BATCH_SIZE} | Initial LR: {LEARNING_RATE:.0e}")
print(f"Epochs: {NUM_EPOCHS} | Checkpoints Dir: {CHECKPOINT_DIR}")
print(f"Early Stopping Patience: {EARLY_STOP_PATIENCE} epochs | LR Reduction Patience: {LR_REDUCTION_PATIENCE} epochs")
print(f"{'='*65}\n")

# --- Main Training Loop ---
# Each epoch: train on every batch, validate, update the early-stopping and
# learning-rate patience counters, print a summary, and write
# `checkpoint_last.pth` so the run can be resumed.
for epoch in range(start_epoch, NUM_EPOCHS):
    epoch_start_time = time.time()
    model.train()

    # Per-batch statistics; only batches that were actually optimised are recorded.
    batch_times, batch_losses, batch_ious, batch_dices = [], [], [], []

    # Training progress bar
    train_bar = tqdm(
        train_loader,
        desc=f"Epoch {epoch+1:03d}/{NUM_EPOCHS:03d} [Train]",
        bar_format="{l_bar}{bar:20}{r_bar}{bar:-20b}",
        unit="batch"
    )

    for images, masks in train_bar:
        batch_start_time = time.time()
        images, masks = images.to(DEVICE), masks.to(DEVICE)

        # A batch containing NaN/Inf would corrupt the weights through the
        # backward pass. Skip only that batch instead of abandoning the rest
        # of the epoch.
        if not torch.isfinite(images).all():
            print("[Erro] Tensor contém valores inválidos (NaN ou Inf)!")
            continue

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Calculate metrics without tracking gradients
        with torch.no_grad():
            iou = metric_iou(outputs, masks)
            dice = metric_dice(outputs, masks)

        # Update statistics
        batch_times.append(time.time() - batch_start_time)
        batch_losses.append(loss.item())
        batch_ious.append(iou.item())
        batch_dices.append(dice.item())

        # Update progress bar postfix
        train_bar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'iou': f"{iou.item():.4f}",
            'dice': f"{dice.item():.4f}",
            'b_time': f"{batch_times[-1]:.2f}s",
            'lr': f"{optimizer.param_groups[0]['lr']:.1e}"
        })

    # --- End of Training Epoch ---
    # Average over the batches that were really processed (skipped batches must
    # not dilute the mean); NaN marks an epoch in which nothing was trained.
    n_trained = len(batch_losses)
    avg_train_loss = sum(batch_losses) / n_trained if n_trained else float('nan')
    avg_train_iou = sum(batch_ious) / n_trained if n_trained else float('nan')
    avg_train_dice = sum(batch_dices) / n_trained if n_trained else float('nan')
    epoch_time = time.time() - epoch_start_time

    # --- Validation Phase ---
    model.eval()
    val_loss, val_iou, val_dice = 0.0, 0.0, 0.0
    val_bar = tqdm(
        val_loader,
        desc=f"Epoch {epoch+1:03d}/{NUM_EPOCHS:03d} [Validate]",
        bar_format="{l_bar}{bar:20}{r_bar}{bar:-20b}",
        unit="batch"
    )

    with torch.no_grad():
        for images, masks in val_bar:
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            outputs = model(images)

            loss = criterion(outputs, masks)
            iou = metric_iou(outputs, masks)
            dice = metric_dice(outputs, masks)

            val_loss += loss.item()
            val_iou += iou.item()
            val_dice += dice.item()

            val_bar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'iou': f"{iou.item():.4f}",
                'dice': f"{dice.item():.4f}"
            })

    avg_val_loss = val_loss / len(val_loader)
    avg_val_iou = val_iou / len(val_loader)
    avg_val_dice = val_dice / len(val_loader)

    # --- Early Stopping and Learning Rate Scheduling Logic ---
    # The actual `break` is deferred to the end of the iteration so that the
    # final epoch still gets its summary printed and its state checkpointed.
    stop_training = False
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        lr_patience_counter = 0  # Reset both counters on improvement
        checkpoint_callback(model, avg_val_iou, avg_val_loss, epoch + 1)
    else:
        patience_counter += 1
        lr_patience_counter += 1

        if patience_counter >= EARLY_STOP_PATIENCE:
            # Early stopping takes precedence over a learning-rate reduction
            stop_training = True
        elif lr_patience_counter >= LR_REDUCTION_PATIENCE:
            current_lr = optimizer.param_groups[0]['lr']
            new_lr = current_lr * 0.5  # Halve the learning rate
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lr
            print(f"\nReducing learning rate to {new_lr:.1e} after {LR_REDUCTION_PATIENCE} epochs without improvement.")
            lr_patience_counter = 0  # Reset the LR counter after reduction

    # --- Time and Memory Statistics ---
    total_time = time.time() - total_start_time
    avg_batch_time = sum(batch_times) / len(batch_times) if batch_times else 0.0
    epochs_left = NUM_EPOCHS - (epoch + 1)
    # Estimated Time of Arrival (ETA): mean duration of the epochs run in this
    # session (always >= 1 epoch at this point) times the epochs remaining.
    epochs_done = epoch + 1 - start_epoch
    eta = epochs_left * (total_time / epochs_done)

    if torch.cuda.is_available():
        gpu_mem = torch.cuda.max_memory_allocated() / (1024 ** 3)  # in GB
        torch.cuda.reset_peak_memory_stats()  # so the next epoch reports its own peak
    else:
        gpu_mem = 0

    # --- Print Epoch Summary ---
    print(f"\nEPOCH {epoch+1:03d}/{NUM_EPOCHS:03d} SUMMARY [{epoch_time:.1f}s]")
    print(f"  Train => Loss: {avg_train_loss:.4f} | IoU: {avg_train_iou:.4f} | Dice: {avg_train_dice:.4f}")
    print(f"  Valid => Loss: {avg_val_loss:.4f} | IoU: {avg_val_iou:.4f} | Dice: {avg_val_dice:.4f}")
    print(f"  Patience: {patience_counter}/{EARLY_STOP_PATIENCE} | LR Patience: {lr_patience_counter}/{LR_REDUCTION_PATIENCE}")
    if gpu_mem > 0:
        print(f"  Avg Batch Time: {avg_batch_time:.2f}s | Peak GPU Memory: {gpu_mem:.2f}GB")
    print(f"  Elapsed Time: {time.strftime('%H:%M:%S', time.gmtime(total_time))} | ETA: {time.strftime('%H:%M:%S', time.gmtime(eta))}\n")

    # Save the last checkpoint for resuming (stores the number of completed epochs)
    save_last_checkpoint(model, optimizer, epoch + 1, CHECKPOINT_DIR)

    if stop_training:
        print(f"\nEarly Stopping triggered at epoch {epoch+1}! No improvement for {EARLY_STOP_PATIENCE} consecutive epochs.")
        break

# --- End of Training ---
# Closing banner: total wall-clock time and the file name of the best checkpoint.
print(f"\n{'='*65}")
print(f"Training Complete!")
print(f"Total Time: {time.strftime('%H:%M:%S', time.gmtime(time.time()-total_start_time))}")
if not checkpoint_callback.best_losses:
    # The callback never stored a checkpoint during this run.
    print("No best model was saved (training may have been interrupted early).")
else:
    # Entries are (val_loss, path) tuples; the lowest loss sorts first.
    ranked_checkpoints = sorted(checkpoint_callback.best_losses)
    best_model_path = ranked_checkpoints[0][1]
    print(f"Best Model Saved: {os.path.basename(best_model_path)}")
print(f"{'='*65}")