In [1]:
import numpy as np
import random
import time 
import os 
import geopandas as gpd 
import rioxarray
from msmla50 import MSMLA50
import glob
import gc
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import v2 as transforms
import cnn_utils
import utils
# Reproducibility: seed every RNG in use and force deterministic (cu)DNN/cuBLAS kernels.
SEED = 0

# cuBLAS reads this variable when its handle is created, so it is set before any other
# torch/CUDA call in this cell (PyTorch reproducibility notes).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

random.seed(SEED)
np.random.seed(SEED)  # legacy global NumPy RNG
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True, warn_only=True)

# Generator-based NumPy API. The result must be kept: calling default_rng() without
# binding it (as before) creates a generator and throws it away, seeding nothing.
rng = np.random.default_rng(seed=SEED)

Generator(PCG64) at 0x222CE648BA0

In [4]:
# settings

# --- patch extraction (forwarded to the cnn_utils patch loaders) ---------------
patch_size = 32
# NOTE(review): `stride` is not used in the cells visible here; the loaders receive gt_stride.
stride = 10
gt_stride = 32
background_label = 0
offset_left = 'best'
offset_top = 'best'
batch_size = 128

# --- augmentation / sampling ---------------------------------------------------
# forwarded as `alpha` to cnn_utils.generate_augmented_loader_with_sampler
alpha = 0.5
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        # a scalar `degrees` samples an angle from [-90, 90], not only multiples of 90
        transforms.RandomRotation(degrees=90),
    ]
)

# --- optimisation ----------------------------------------------------------------
num_epochs = 10
learning_rate = 1e-4

In [5]:
## inputs
# One entry per study site: (image raster, reference polygons).
# The polygon layer must carry the "fold" and "gridcode" columns used by the training loop.
# Switch sites by changing `location`; paths are relative to the working directory.
DATASETS = {
    'berlin': (
        os.path.join('imagery', 'berlin_20170519.tif'),
        os.path.join('ref_data', 'berlin_ref_splitS2S3S4.gpkg'),
    ),
    'hongkong': (
        os.path.join('imagery', 'hongkong_20180321.tif'),
        os.path.join('ref_data', 'hongkong_ref_splitS2S3S4.gpkg'),
    ),
    'paris': (
        os.path.join('imagery', 'paris_20170526.tif'),
        os.path.join('ref_data', 'splited_S234', 'paris_ref_splitS2S3S4.gpkg'),
    ),
    'rome': (
        os.path.join('imagery', 'rome_20170620.tif'),
        os.path.join('ref_data', 'splited_S234', 'rome_ref_splitS2S3S4.gpkg'),
    ),
    'saopaulo': (
        os.path.join('imagery', 'sao_paulo_20170726.tif'),
        os.path.join('ref_data', 'splited_S234', 'saopaulo_ref_splitS2S3S4.gpkg'),
    ),
}

location = 'berlin'
image, reference_path = DATASETS[location]
splited_ref_data = gpd.read_file(reference_path)

# Training

In [None]:
## Cross-validation over the "fold" column: each fold is held out for testing once,
## the model is trained on the remaining folds, and only the best epoch checkpoint is kept.
folds = [0, 1, 2, 3, 4, 5]
model_dir = 's2_cnn_models'
# cnn_training writes checkpoints under model_dir, so make sure it exists
os.makedirs(model_dir, exist_ok=True)


def rasterize_fold_polygons(polygons, image_path, output_path):
    """Rasterize ``polygons`` onto the grid of ``image_path`` and save them as ``output_path``.

    The polygons are first written to a temporary ``*_temp.tif`` with
    ``utils.rasterize_reference_polygons``, matched to ``image_path`` with
    ``utils.match_rasters`` and saved as an LZW-compressed GeoTIFF. The temporary
    file is removed even when matching or writing fails.
    """
    temp_path = output_path.replace(".tif", "_temp.tif")
    try:
        utils.rasterize_reference_polygons(polygons, image_path, temp_path)
        matched = utils.match_rasters(temp_path, image_path)
        try:
            matched.rio.to_raster(output_path, driver="GTiff", compress="LZW")
        finally:
            matched.close()
            del matched
    finally:
        # drop lingering references so the temp file handle is released before deletion
        # (presumably needed on Windows, where open files cannot be removed)
        gc.collect()
        if os.path.exists(temp_path):
            os.remove(temp_path)


def keep_best_checkpoint(search_pattern, metric='test_accuracy'):
    """Keep the checkpoint with the highest ``metric`` among files matching ``search_pattern``.

    Each matching file is loaded on the CPU with ``torch.load`` and indexed with
    ``metric``. Every other matching file is deleted; failed deletions are reported,
    not raised.

    Returns:
        ``(best_path, best_score)``, or ``(None, None)`` when no file matches.
    """
    checkpoint_files = sorted(glob.glob(search_pattern))
    if not checkpoint_files:
        return None, None
    print(f"Found {len(checkpoint_files)} checkpoints.")

    # keep only the score so each full checkpoint can be freed immediately
    scores = {path: torch.load(path, map_location='cpu')[metric] for path in checkpoint_files}
    best_file = max(scores, key=scores.get)  # first file wins on ties

    for path in checkpoint_files:
        if path == best_file:
            continue
        try:
            os.remove(path)
        except OSError as err:
            # best effort: a leftover checkpoint is not fatal, but do not hide it
            print(f"Could not delete {path}: {err}")
    return best_file, scores[best_file]


for fold in folds:
    ## prepare train and test polygons (the current fold is held out)
    test_polygons = splited_ref_data[splited_ref_data["fold"] == fold]
    train_polygons = splited_ref_data[splited_ref_data["fold"] != fold]

    ## rasterize both splits, matched to the image grid
    train_polygons_raster = f"{location}_train_f{fold}.tif"
    test_polygons_raster = f"{location}_test_f{fold}.tif"
    rasterize_fold_polygons(train_polygons, image, train_polygons_raster)
    rasterize_fold_polygons(test_polygons, image, test_polygons_raster)

    ## get train and test patches (shared loader settings)
    patch_kwargs = dict(
        image_path=image,
        patch_size=patch_size,
        stride=gt_stride,
        batch_size=batch_size,
        offset_left=offset_left,
        offset_top=offset_top,
        background_label=background_label,
    )
    train_patches = cnn_utils.generate_labeled_patches_loader(reference_path=train_polygons_raster, **patch_kwargs)
    test_patches = cnn_utils.generate_labeled_patches_loader(reference_path=test_polygons_raster, **patch_kwargs)

    ## remapping labels (mapping derived from the training split only)
    label_mapping = cnn_utils.compute_label_mapping(train_patches)
    train_patches = cnn_utils.label_remapping(train_patches, label_mapping)
    test_patches = cnn_utils.label_remapping(test_patches, label_mapping)

    ## normalize, augment, over/under-sample (statistics from the training split only)
    mean, std = cnn_utils.get_normalization_parameters(train_patches)
    train_patches_norm = cnn_utils.generate_augmented_loader_with_sampler(
        reference_path=train_polygons_raster,
        transform=train_transform,
        label_mapping=label_mapping,
        mean=mean,
        std=std,
        alpha=alpha,
        **patch_kwargs,
    )
    test_patches_norm = cnn_utils.normalize_loader(test_patches, mean, std)

    ## initialize model
    # NOTE(review): num_classes counts raw gridcodes in the training polygons; confirm it
    # equals the number of labels produced by label_mapping.
    model = MSMLA50(input_channels=10, depth=[16, 32, 48], num_classes=len(train_polygons["gridcode"].unique()))
    model = model.cuda()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    output = os.path.join(model_dir, f'{location}_S2_fold{fold}.pth')
    cnn_utils.cnn_training(model, train_patches_norm, test_patches_norm, num_epochs, criterion, learning_rate, optimizer, output)

    ## keep the best model out of the checkpointed models
    best_file, best_score = keep_best_checkpoint(os.path.join(model_dir, f'{location}_S2_fold{fold}_epoch*.pth'))
    if best_file is None:
        print(f"No checkpoints found for Fold {fold} in {location}")
    else:
        print(f"Fold {fold}: {os.path.basename(best_file)} (Score: {best_score:.4f})")