##### The reference raster must have the same size (width / height) as the imagery, and presumably the same resolution as well

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import rasterio as rio
import numpy as np
import os
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import Subset
from rasterio.windows import Window
from rasterio.transform import Affine
import random
import matplotlib.pyplot as plt
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.autograd import Variable
from torchvision.transforms import v2 as transforms

# Seed every random number generator used in this notebook with the same value
# (0) so that repeated runs draw the same random numbers.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(0)

# Turn off the cuDNN auto-tuner.
torch.backends.cudnn.benchmark = False

%matplotlib inline
# Set the default figure size for plots created after this point.
plt.rcParams['figure.figsize'] = [15, 10]

In [2]:
# Tile the image only (no reference raster) into 32x32 patches with stride 1
class PatchDataset(Dataset):
    """Sliding-window reader that yields square patches from one raster image.

    This variant returns features only (no labels). Patch positions are
    computed from the index on demand instead of being stored, so the memory
    footprint stays constant even when a stride of 1 produces tens of
    millions of patches.

    The raster file stays open for the lifetime of the object; call
    ``close()`` (or use the object in a ``with`` statement) to release it.
    """

    def __init__(self, image_path, patch_size, stride, offset_left=0, offset_top=0):
        """
        Open the image and work out the grid of patch positions.

        Args:
            image_path (str): Path to the Sentinel image (features).
            patch_size (int): Size of the square patches (in pixels).
            stride (int): Step size for sliding the window.
            offset_left (int): Number of pixels to ignore from the left edge of the image.
            offset_top (int): Number of pixels to ignore from the top edge of the image.

        Raises:
            ValueError: If ``patch_size`` or ``stride`` is not positive, or an
                offset is negative.
        """
        # Validate before opening the file so a bad call leaves no open handle behind.
        if patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {patch_size}")
        if stride <= 0:
            raise ValueError(f"stride must be positive, got {stride}")
        if offset_left < 0 or offset_top < 0:
            raise ValueError(
                f"offsets must be non-negative, got left={offset_left}, top={offset_top}"
            )

        self.image_path = image_path
        self.patch_size = patch_size
        self.stride = stride
        self.offset_left = offset_left
        self.offset_top = offset_top

        # Open the feature raster
        self.src_features = rio.open(image_path)

        # Usable dimensions once the ignored left/top margins are removed
        self.width = self.src_features.width - offset_left
        self.height = self.src_features.height - offset_top
        self.num_bands = self.src_features.count

        # Top-left coordinates (relative to the offset origin) of every patch
        # that fits completely inside the usable area. Kept as ranges so that
        # no per-patch objects are allocated; an image smaller than one patch
        # simply yields empty ranges.
        self._patch_rows = range(0, self.height - patch_size + 1, stride)
        self._patch_cols = range(0, self.width - patch_size + 1, stride)

    @property
    def patches(self):
        """List of ``(row, col)`` patch origins, relative to the offsets, in row-major order.

        Retained for backward compatibility. The list is rebuilt on every
        access and can be very large; prefer ``len()`` and indexing.
        """
        return [(row, col) for row in self._patch_rows for col in self._patch_cols]

    def __len__(self):
        """Return the total number of patches."""
        return len(self._patch_rows) * len(self._patch_cols)

    def _patch_origin(self, idx):
        """Return the ``(row, col)`` top-left pixel of patch ``idx`` in full-image coordinates.

        Raises:
            IndexError: If ``idx`` is out of range. Plain ``for`` loops over
                the dataset rely on this to know when to stop.
        """
        # Indexing a range gives list-like semantics: negative indices count
        # from the end, out-of-range indices raise IndexError.
        position = range(len(self))[idx]
        row_number, col_number = divmod(position, len(self._patch_cols))
        row = self._patch_rows[row_number] + self.offset_top
        col = self._patch_cols[col_number] + self.offset_left
        return row, col

    def __getitem__(self, idx):
        """Read patch ``idx`` and return it as a float32 tensor."""
        row, col = self._patch_origin(idx)

        # Extract the image patch
        window = rio.windows.Window(col, row, self.patch_size, self.patch_size)
        patch_features = self.src_features.read(window=window)  # Shape: (bands, patch_size, patch_size)

        return torch.tensor(patch_features, dtype=torch.float32)

    def close(self):
        """Close the underlying raster file."""
        self.src_features.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()


start = time.time()

# SETTINGS
image_path = r"composite_subset_with_height.tif"  # Path to Sentinel image (features)
patch_size = 32
stride = 1
offset_left = 0
offset_top = 0

# Maximum number of tiles written into one output directory; the tiles are
# spread over several directories so that no single one grows without bound.
tiles_per_dir = 6500000

# Initialize the dataset
dataset = PatchDataset(
    image_path=image_path,
    patch_size=patch_size,
    stride=stride,
    offset_left=offset_left,
    offset_top=offset_top
)

# One directory name per started block of `tiles_per_dir` tiles (ceiling
# division). At least the ten original names are always listed, and more are
# generated when the image produces more tiles than ten directories can hold.
num_dirs = max(10, -(-len(dataset) // tiles_per_dir))
dirs = [f"32x32_stride1_left0_top0_all{number}" for number in range(1, num_dirs + 1)]

# Create the first output directory; later ones are created when first needed
dir_index = 0
os.makedirs(dirs[dir_index], exist_ok=True)

try:
    for idx, features in enumerate(dataset):
        # Pick the directory from the tile index *before* saving, so every
        # directory receives exactly `tiles_per_dir` tiles (the last one may
        # receive fewer).
        target_dir_index = idx // tiles_per_dir
        if target_dir_index != dir_index:
            dir_index = target_dir_index
            os.makedirs(dirs[dir_index], exist_ok=True)

        # Save every generated tile
        all_tile_path = os.path.join(dirs[dir_index], f"tile_{str(idx).zfill(20)}.pt")
        torch.save(features, all_tile_path)

        # Log progress
        if idx % 10000 == 0:
            print(f"Processed tile {idx}")
finally:
    # Release the raster file even if tiling is interrupted.
    dataset.src_features.close()

print('---')
time_elapsed = time.time() - start
print('Tile creation took {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

Processed tile 0
Processed tile 10000
Processed tile 20000
Processed tile 30000
Processed tile 40000
Processed tile 50000
Processed tile 60000
Processed tile 70000
Processed tile 80000
Processed tile 90000
Processed tile 100000
Processed tile 110000
Processed tile 120000
Processed tile 130000
Processed tile 140000
Processed tile 150000
Processed tile 160000
Processed tile 170000
Processed tile 180000
Processed tile 190000
Processed tile 200000
Processed tile 210000
Processed tile 220000
Processed tile 230000
Processed tile 240000
Processed tile 250000
Processed tile 260000
Processed tile 270000
Processed tile 280000
Processed tile 290000
Processed tile 300000
Processed tile 310000
Processed tile 320000
Processed tile 330000
Processed tile 340000
Processed tile 350000
Processed tile 360000
Processed tile 370000
Processed tile 380000
Processed tile 390000
Processed tile 400000
Processed tile 410000
Processed tile 420000
Processed tile 430000
Processed tile 440000
Processed tile 450000
Pr

In [3]:
# Tile the image + reference raster into 32x32 patches with stride 32, for training and validating the model

# SETTINGS
image_path = r"composite_subset_with_height.tif"  # Path to Sentinel image (features)
reference_path = r"berlin_lcz_GT_fullres.tif"  # Path to reference raster (labels)
patch_size = 32
stride = 32
offset_left = 2
offset_top = 2

# The output directory names encode the tiling parameters. They are derived
# from the settings above so the names cannot drift out of sync when a
# parameter is changed.
_tiling_tag = f"{patch_size}x{patch_size}_stride{stride}_left{offset_left}_top{offset_top}"
output_dir_all = f"{_tiling_tag}_all"
output_dir_valid = f"{_tiling_tag}_GT"

# Start of the wall-clock timing reported at the end of this cell
start = time.time()
class PatchDataset(Dataset):
    """Sliding-window reader yielding ``(patch, label)`` pairs from an image and a reference raster.

    The label of a patch is the reference-raster value at the patch's central
    pixel. Patch positions are computed from the index on demand instead of
    being stored, so the memory footprint does not grow with the number of
    patches.

    Both raster files stay open for the lifetime of the object; call
    ``close()`` (or use the object in a ``with`` statement) to release them.
    """

    def __init__(self, image_path, reference_path, patch_size, stride, offset_left=0, offset_top=0):
        """
        Open both rasters and work out the grid of patch positions.

        Args:
            image_path (str): Path to the Sentinel image (features).
            reference_path (str): Path to the reference raster (labels).
            patch_size (int): Size of the square patches (in pixels).
            stride (int): Step size for sliding the window.
            offset_left (int): Number of pixels to ignore from the left edge of the image.
            offset_top (int): Number of pixels to ignore from the top edge of the image.

        Raises:
            ValueError: If ``patch_size`` or ``stride`` is not positive, an
                offset is negative, or the two rasters differ in width or
                height.
        """
        # Validate before opening any file so a bad call leaves no open handle behind.
        if patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {patch_size}")
        if stride <= 0:
            raise ValueError(f"stride must be positive, got {stride}")
        if offset_left < 0 or offset_top < 0:
            raise ValueError(
                f"offsets must be non-negative, got left={offset_left}, top={offset_top}"
            )

        self.image_path = image_path
        self.reference_path = reference_path
        self.patch_size = patch_size
        self.stride = stride
        self.offset_left = offset_left
        self.offset_top = offset_top

        # Open the feature and reference rasters
        self.src_features = rio.open(image_path)
        try:
            self.src_labels = rio.open(reference_path)
        except Exception:
            # Do not leak the already opened feature raster.
            self.src_features.close()
            raise

        # Ensure the dimensions match. Raised explicitly rather than asserted
        # so the check is still performed when Python runs with -O.
        if (
            self.src_features.width != self.src_labels.width
            or self.src_features.height != self.src_labels.height
        ):
            self.close()
            raise ValueError("Image and reference dimensions do not match!")

        # Usable dimensions once the ignored left/top margins are removed
        self.width = self.src_features.width - offset_left
        self.height = self.src_features.height - offset_top
        self.num_bands = self.src_features.count

        # Top-left coordinates (relative to the offset origin) of every patch
        # that fits completely inside the usable area. Kept as ranges so that
        # no per-patch objects are allocated; an image smaller than one patch
        # simply yields empty ranges.
        self._patch_rows = range(0, self.height - patch_size + 1, stride)
        self._patch_cols = range(0, self.width - patch_size + 1, stride)

    @property
    def patches(self):
        """List of ``(row, col)`` patch origins, relative to the offsets, in row-major order.

        Retained for backward compatibility. The list is rebuilt on every
        access and can be very large; prefer ``len()`` and indexing.
        """
        return [(row, col) for row in self._patch_rows for col in self._patch_cols]

    def __len__(self):
        """Return the total number of patches."""
        return len(self._patch_rows) * len(self._patch_cols)

    def _patch_origin(self, idx):
        """Return the ``(row, col)`` top-left pixel of patch ``idx`` in full-image coordinates.

        Raises:
            IndexError: If ``idx`` is out of range. Plain ``for`` loops over
                the dataset rely on this to know when to stop.
        """
        # Indexing a range gives list-like semantics: negative indices count
        # from the end, out-of-range indices raise IndexError.
        position = range(len(self))[idx]
        row_number, col_number = divmod(position, len(self._patch_cols))
        row = self._patch_rows[row_number] + self.offset_top
        col = self._patch_cols[col_number] + self.offset_left
        return row, col

    def __getitem__(self, idx):
        """Read patch ``idx`` and return ``(features, label)``.

        ``features`` is a float32 tensor holding the image patch and ``label``
        is the scalar reference value at the patch's central pixel.
        """
        row, col = self._patch_origin(idx)

        # Extract the image patch
        window = rio.windows.Window(col, row, self.patch_size, self.patch_size)
        patch_features = self.src_features.read(window=window)  # Shape: (bands, patch_size, patch_size)

        # Extract the central pixel from the reference raster
        center_row = row + self.patch_size // 2
        center_col = col + self.patch_size // 2
        central_window = rio.windows.Window(center_col, center_row, 1, 1)
        label = self.src_labels.read(1, window=central_window).item()  # Read as scalar

        # Return the patch and its label
        return torch.tensor(patch_features, dtype=torch.float32), label

    def close(self):
        """Close both underlying raster files."""
        self.src_features.close()
        self.src_labels.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

# Make sure both output folders exist before any tile is written
for out_dir in (output_dir_all, output_dir_valid):
    os.makedirs(out_dir, exist_ok=True)

# Build the tiling dataset over the image and its reference raster
dataset = PatchDataset(
    image_path,
    reference_path,
    patch_size,
    stride,
    offset_left=offset_left,
    offset_top=offset_top,
)

# Bookkeeping for the tiles whose label is non-zero ("valid" tiles)
valid_tile_paths = []  # file paths of the valid tiles
labels_list = []       # label of each valid tile, in the same order

for idx, (features, label) in enumerate(dataset):
    # The same zero-padded file name is used in both output folders
    tile_name = f"tile_{str(idx).zfill(20)}.pt"

    # Every tile goes into the "all" folder
    all_tile_path = os.path.join(output_dir_all, tile_name)
    torch.save(features, all_tile_path)

    # Tiles with a non-zero label are additionally stored in the "valid" folder
    if label != 0:
        valid_tile_path = os.path.join(output_dir_valid, tile_name)
        torch.save(features, valid_tile_path)
        valid_tile_paths.append(valid_tile_path)
        labels_list.append(label)

    # Report progress every 10000 tiles
    if idx % 10000 == 0:
        print(f"Processed tile {idx}")

# Store the labels of the valid tiles next to the tiles themselves
labels_tensor = torch.tensor(labels_list, dtype=torch.long)
labels_file = os.path.join(output_dir_valid, "labels.pt")
torch.save(labels_tensor, labels_file)

print(f"Saved {len(valid_tile_paths)} valid tiles to '{output_dir_valid}'.")
print(f"Saved all tiles to '{output_dir_all}'.")

print('---')
time_elapsed = time.time() - start
minutes, seconds = time_elapsed // 60, time_elapsed % 60
print('Tile creation took {:.0f}m {:.0f}s'.format(minutes, seconds))

Processed tile 0
Processed tile 10000
Processed tile 20000
Processed tile 30000
Processed tile 40000
Saved 2404 valid tiles to '32x32_stride32_left2_top2_GT'.
Saved all tiles to '32x32_stride32_left2_top2_all'.
---
Tile creation took 4m 33s
