In [None]:
!pip install cloud-tpu-client https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.13-cp38-cp38m-linux_x86_64.whl


In [3]:
!pip install pyrosm openeo lightning rasterio plotly

Collecting pyrosm
  Downloading pyrosm-0.6.2.tar.gz (2.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting openeo
  Downloading openeo-0.30.0-py3-none-any.whl.metadata (7.3 kB)
Collecting lightning
  Downloading lightning-2.3.2-py3-none-any.whl.metadata (54 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.1/54.1 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
Collecting python-rapidjson (from pyrosm)
  Downloading python_rapidjson-1.18-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (22 kB)
Collecting shapely>=2.0.1 (from pyrosm)
  Downloading shapely-2.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.0 kB)
Collecting cykhash (from pyrosm)
  Downl

# Utils

In [4]:
import logging
from pyrosm.data import sources
import numpy as np

def setup_logger(level: int = logging.INFO):
    """
    Set up a logger for the pipeline. 
    """
    logger = logging.getLogger()
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s - %(message)s"
    )
    
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    
    file_handler = logging.FileHandler("main.log")
    file_handler.setFormatter(formatter)

    logger.setLevel(level)
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

    return logger


def get_available_cities():
    """
    Return all available cities from pyrosm 
    """
    return sources.cities.available


def stretch_hist(band):
    """
    Apply histogram stretching"""
    p2, p98 = np.percentile(band, (0.5, 99.5))
    return np.clip((band - p2) * 255.0 / (p98 - p2), 0, 255).astype(np.uint8)

# data_acquisition

In [5]:
#basics
import os
import time
import json
import pickle
import openeo
import numpy as np

# geography
import geopandas as gpd
import rasterio
from rasterio.features import geometry_mask


#download
import pyrosm as pyr
from openeo.rest import OpenEoApiError
from openeo.processes import ProcessBuilder, if_, is_nan



# plotting 
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go





class DataHandler: 
    def __init__(self, logger, path_to_data_directory = "data"):
        """
        Initialize the DataHandler class and define openeo params.
        """
        self.logger = logger
        self.openeo_temporal_extent = ["2023-05-01", "2023-09-30"]
        self.openeo_bands = ["B04", "B03", "B02", "B08", "B12", "B11", "SCL"]
        self.openeo_max_cloud_cover = 30
        self.openeo_spatial_resolution = 10
        self.openeo_connection = None
        self.openeo_collections = None
        self.openeo_jobs = None#
        self.path_to_data_directory = path_to_data_directory
        
        
        if not os.path.exists(self.path_to_data_directory):
            os.makedirs(self.path_to_data_directory)
            logger.info("Created data directory")
        else:
            logger.info("Data directory already exists")


    def create_directory(self, city: str):
        """
        Create a directory for each city.
        """
        os.makedirs(os.path.join(self.path_to_data_directory, city), exist_ok=True)
        self.logger.info(f"{city}: Directory available")


    def get_buildings(self, city: str):
        """
        Return buildings for a given city
        """
        self.create_directory(city)
        
        # Check if local data for city is available
        if "buildings.geojson" in os.listdir(os.path.join(self.path_to_data_directory, city)):
            self.logger.info(f"{city}: Using local building data")
            return gpd.read_file(os.path.join(self.path_to_data_directory, city,"buildings.geojson"))

        # Download data for city
        fp = pyr.get_data(city, directory=os.path.join(self.path_to_data_directory, city))
        osm = pyr.OSM(fp)
        self.logger.info(f"{city}: Downloaded data to {self.path_to_data_directory}/{city}")

        # Get bounding box for city
        boundingbox = self.get_boundingbox(city, osm)

        # Get the buildings of the city
        buildings_geodf = osm.get_buildings()

        # Remove buildings outside of the bounding box of the city
        buildings_geodf = buildings_geodf.cx[boundingbox[0] : boundingbox[2], boundingbox[1] : boundingbox[3]]

        # Save the data of the city
        buildings_path = os.path.join(self.path_to_data_directory, city,"buildings.geojson")
        buildings_geodf.to_file(buildings_path, driver="GeoJSON")
        self.logger.info(f"{city}: Stored data to {buildings_path}")

        return buildings_geodf


    def get_boundingbox(self, city: str, osm = None):
        """
        Get the bounding box for a city.
        """

        # Return bounding box for Berlin as specified in exercise sheet to ensure correct testing results
        if city == "Berlin":
            return [13.294333, 52.454927, 13.500205, 52.574409]

        # Check if local bounds are available
        bounds_path = os.path.join(self.path_to_data_directory, city,"bounds.pkl")
        if os.path.exists(bounds_path):
            with open(bounds_path, "rb") as f:
                boundingbox = pickle.load(f)
            return boundingbox
        
        # Ensure OSM data is available 
        if osm is None:
            self.get_buildings(city=city)

        # Get the boundaries
        geoframe_bounds = osm.get_boundaries()
        boundingbox = geoframe_bounds[geoframe_bounds["name"] == city].total_bounds

        # Check if bounding box is None
        if np.isnan(boundingbox[0]) or np.isnan(boundingbox[1]) or np.isnan(boundingbox[2]) or np.isnan(boundingbox[3]):
            self.logger.info(f"{city}: Bounding box is None. Using total bounds instead")
            boundingbox = geoframe_bounds.total_bounds
        self.logger.info(f"{city}: Bounding box is {boundingbox}")     

        # Save total bounds to pickle file
        with open(bounds_path, "wb") as f:
            pickle.dump(boundingbox, f)
        self.logger.info(f"{city}: Saved bounds to {bounds_path}")

        return boundingbox
    

    def get_satellite_image(self, city: str, return_rasterio_dataset = False): 
        """
        Get satellite images for a city. Use local data if available. Returns an Array with (H, W, C) shape
        """
        if os.path.exists(os.path.join(self.path_to_data_directory, city,"openEO.tif")):
            self.logger.info(f"{city}: Using local satellite image")
            ds = rasterio.open(os.path.join(self.path_to_data_directory, city,"openEO.tif"))
            if return_rasterio_dataset:
                return ds
            
            # Read all channels
            sat_data = ds.read()

            # Transpose to (H, W, C)
            sat_data = np.transpose(sat_data, (1, 2, 0))
            return sat_data
        else:
            self.download_satellite_image(city)
            return self.get_satellite_image(city)
    

    def connect_to_openeo(self):
        """
        Connect to the openEO backend and 
        """
        if self.openeo_connection is None:
            connection = openeo.connect("openeo.dataspace.copernicus.eu")
            connection.authenticate_oidc()
            self.openeo_connection = connection

            self.logger.info("Connected to openEO")
        else:
            self.logger.info("Already connected to openEO")


    def download_satellite_image(self, city: str):
        """
        Download satellite images for a city. Retry for 3 times if the job fails or takes longer than 30 min per job.
        """
        self.connect_to_openeo()
        
        # Log the currently running jobs
        self.logger.info("Current jobs:")
        for idx, job in enumerate(self.openeo_connection.list_jobs()):
            self.logger.info(f"{idx} {job['id']} {job['status']}")

        # Retry job up to 3 times. Raise exception after 3 retries.
        job_finished = False
        job_number_of_retries = 0
        while not job_finished : 
            if job_number_of_retries > 3:
                self.logger.error(f"{city}: Job failed after 3 retries")
                raise Exception(f"{city}: Job failed after 3 retries")
            job = self.create_and_start_openeo_job(city)    
            job_finished = self.await_job(city, job)
            job_number_of_retries += 1

        # Get job results and store in data/city
        job_results = self.openeo_connection.job(job.job_id).get_results()
        job_results.download_files(os.path.join(self.path_to_data_directory, city))
        self.logger.info(f"{city}: Downloaded job results to {os.path.join(self.path_to_data_directory, city)}")


    def delete_jobs(self):
        """
        Delete all jobs on the openEO backend. Use only for debugging. 
        """
        self.connect_to_openeo()

        for idx, job in enumerate(self.openeo_connection.list_jobs()):
            self.logger.info(f"Deleting job {idx}, {job['id']}, {job['status']}")
            self.openeo_connection.job(job["id"]).delete_job()


    def create_and_start_openeo_job(self, city: str, collection_id: str = "SENTINEL2_L2A"):
        """
        Creates an openeo processing job for a city and starts it.
        """
        # Transform order in boundingbox to dict
        boundingbox = self.get_boundingbox(city)
        boundingbox = {"west": boundingbox[0], "south": boundingbox[1], "east": boundingbox[2], "north": boundingbox[3]}
        
        # Create datacube
        datacube = self.openeo_connection.load_collection(
            collection_id=collection_id,
            spatial_extent=boundingbox,
            temporal_extent=self.openeo_temporal_extent,
            bands=self.openeo_bands,
            max_cloud_cover=self.openeo_max_cloud_cover,
        ).resample_spatial(self.openeo_spatial_resolution)

        # Create cloud mask
        scl = datacube.band("SCL")

        # Filter out cloud median probability, cloud high probability, and snow/ice
        mask = (scl == 8) | (scl == 9) | (scl == 11)

        # Resample mask to the spatial resolution of the datacube
        mask = mask.resample_cube_spatial(datacube.band("B04"))
        
        # Create the RGB image
        datacube_rgbFU = datacube.filter_bands(self.openeo_bands[:-1])
        
        # Apply cloud mask
        datacube_rgb_masked = datacube_rgbFU.mask(mask)
        
        # Reduce temporal to median 
        datacube_rgb_masked_reduced_t = datacube_rgb_masked.reduce_temporal("median")

        # Define image format 
        datacube_for_submission = datacube_rgb_masked_reduced_t.save_result(format="GTiff")
        
        # Create openEO job with datacube
        job = datacube_for_submission.create_job(title=f"{city}__pic")
        self.logger.info(f"{city}: Created openEO job")

        # Start openEO job
        job.start_job()
        self.logger.info(f"{city}: Started openEO job with ID: {job.job_id}")        

        return job


    def await_job(self, city, job):
        """
        Awaits the processing of a openeo job. 
        Returns when the job is finished or raises an exception if the job failed.
        """

        for i in range(30):
            status = self.openeo_connection.job(job.job_id).status()
            self.logger.debug(f"{city}: Job {job.job_id} status: {status}")
          
            if status == "finished":
                self.logger.info(f"{city}: Job {job.job_id} finished")
                return True
            
            elif status == "error":
                self.logger.warning(f"{city}: Job {job.job_id} failed. Trying again.")
                return False            
            
            time.sleep(60)
        self.logger.error(f"{city}: Job {job.job_id} did not finish in time")
        return False

    def get_building_mask(self, city: str, loaded_buildings = None, all_touched: bool = False):  
        """
        Get the local building mask for buildings in a city.
        """
        if all_touched:
            filename = "building_mask_dense"
        else:
            filename = "building_mask_sparse"
        # Check if the building mask is already available
        if os.path.exists(os.path.join(self.path_to_data_directory, city,f"{filename}.tif")):
            self.logger.info(f"{city}: Using local building mask")
            return rasterio.open(os.path.join(self.path_to_data_directory, city,f"{filename}.tif")).read(1)

        # Create new building mask 
        satellite_image = self.get_satellite_image(city, return_rasterio_dataset=True)

        # Get satellite image metadata
        transform = satellite_image.transform
        out_shape = (satellite_image.height, satellite_image.width)
        crs = satellite_image.crs

        # Read the GeoJSON file with building polygons
        if loaded_buildings is not None:
            buildings = loaded_buildings
        else:
            buildings = self.get_buildings(city)
            buildings = buildings.to_crs(crs)  # Ensure the CRS matches the GeoTIFF

        # Create a mask where pixels inside buildings are True, others are False
        # TODO all_touched paramer nutzen für zweite Maske
        mask = geometry_mask(
            buildings.geometry, transform=transform, invert=True, out_shape=out_shape, all_touched=all_touched,
        )
        
        # Store the mask as a GeoTIFF file
        
        out_meta = satellite_image.meta
        out_meta.update(
            {
                "driver": "GTiff",
                "height": mask.shape[0],
                "width": mask.shape[1],
                # "transform": transform,
                "count": 1,
            }
        )

        # boolmask is automatically being saved as int16 [0,1]
  
        with rasterio.open(os.path.join(self.path_to_data_directory, city,f"{filename}.tif"), "w", **out_meta) as dest:
            dest.write(mask, indexes=1)

        return mask



    def plot(self, city: str = "BerlinTest", 

             backend: str = "matplotlib",
             figure_size: tuple = (10, 10),
             brightness: int = 5,
             image_directory: str = "img/",
             show_plot: bool = False,
             slice_to_be_plotted = None
             ):
        """
        Plot the data for a city either with matplotlib or plotly.
        """

    
    
        if backend != "plotly" and backend != "matplotlib":            
            raise NotImplementedError("Only matplotlib and plotly is supported at the moment")
        
        satellite_data = self.get_satellite_image(city)        
        mask = self.get_building_mask(city)
        # Take out slice if only a slice is to be plotted
        if slice_to_be_plotted is not None:
            satellite_data = satellite_data[slice(*slice_to_be_plotted)]
            mask = mask[slice(*slice_to_be_plotted)]
        
        if backend =="matplotlib":
            #load buildings
            buildings = self.get_buildings(city)

            # create image out path
            image_path_out = os.path.join(image_directory, city)
             # make the output directory if not exists
            os.makedirs(image_path_out, exist_ok=True)

            # Design plots
            fig, ax = plt.subplots(figsize=figure_size)
            buildings.plot(ax=ax, color="black")
            plt.title(f"{city} buildings")
            plt.axis("off")

        # RGB Bands from Sentinel 2
        red = satellite_data[...,0]
        green = satellite_data[...,1]
        blue = satellite_data[...,2]

        # Apply histogram stretching
        red_stretched = stretch_hist(red)
        green_stretched = stretch_hist(green)
        blue_stretched = stretch_hist(blue)

        # Stack the bands after stretching
        rgb_stretched = np.dstack((red_stretched, green_stretched, blue_stretched))

        

        if backend =="matplotlib":
            # Plot the histogram-stretched RGB image
            plt.figure(figsize=figure_size)
            plt.imshow(rgb_stretched)
            # plt.title("Histogram Stretched RGB Composite Image")
            plt.title(f"{city} RGB Bands from Sentinel-2 L2A")
            plt.axis("off")
            # plt.show()
            plt.savefig(os.path.join(image_path_out, f"{city}_RGB.png"))
            if show_plot:
                plt.show()
            plt.close()


        # RGB image with higher brightness
        red_norm = (red - np.min(red)) / (np.max(red) - np.min(red))
        green_norm = (green - np.min(green)) / (np.max(green) - np.min(green))
        blue_norm = (blue - np.min(blue)) / (np.max(blue) - np.min(blue))
        pseudo_RGB_image = np.dstack((red_norm, green_norm, blue_norm))

        pseudo_RGB_image_normalized = (pseudo_RGB_image - np.min(pseudo_RGB_image)) / (
            pseudo_RGB_image.max() - pseudo_RGB_image.min()
        )


        pseudo_RGB_image_brighter = pseudo_RGB_image_normalized * brightness
        pseudo_RGB_image_brighter = np.clip(pseudo_RGB_image_brighter, 0, 1)

        if backend =="matplotlib":
            plt.figure(figsize=figure_size)
            plt.imshow(pseudo_RGB_image_brighter)
            plt.title(f"{city} RGB Image")
            plt.axis("off")
            # plt.show()
            plt.savefig(os.path.join(image_path_out, f"{city}_RGB_Brighter.png"))
            if show_plot:
                plt.show()
            plt.close()

            # single band img
            # single_band = satellite_image.read(1)
            single_band_stretched = stretch_hist(red)
            plt.figure(figsize=figure_size)
            plt.imshow(single_band_stretched, cmap="gray")
            plt.title(f"{city} Single Band Image")
            plt.axis("off")
            # plt.show()
            plt.savefig(os.path.join(image_path_out, f"{city}_SingleBand.png"))
            if show_plot:
                plt.show()
            plt.close()
        elif backend == "plotly":

            # plot the mask
            fig = px.imshow(mask.astype(np.uint8), binary_string=True)

            # Overlay the mask with the image
            fig.add_trace(go.Image(z=(pseudo_RGB_image_brighter * 255).astype(np.uint8), opacity=1))


            # Update layout with a button to toggle mask visibility
            fig.update_layout(
                updatemenus=[
                    dict(
                        type="buttons",
                        direction="left",
                        buttons=list([
                            dict(
                                args=[{"opacity": [0,1]}],
                                label="Hide Mask",
                                method="restyle"
                            ),
                            dict(
                                args=[{"opacity": [0.5, 0.5]}],
                                label="Show Mask",
                                method="restyle"
                            )
                        ]),
                    ),
                ],
                xaxis=dict(
                    scaleanchor="y",
                    scaleratio=1
                ),
                yaxis=dict(
                    scaleanchor="x",
                    scaleratio=1
                )
            )

            # Enable zooming and panning
            fig.update_xaxes(constrain='domain')
            fig.update_yaxes(scaleanchor='x', scaleratio=1)
            fig.update_layout(height=1000, width=1000)

            # Display the figure
            return fig


        # B8 B4 B3 -> False Color
        b8 = satellite_data[...,3]
        b8_stretched = stretch_hist(b8)
        b4 = red_stretched
        b3 = green_stretched

        false_color = np.dstack((b8_stretched, b4, b3))
        plt.figure(figsize=figure_size)
        plt.imshow(false_color)
        plt.title(f"{city} False Color Image")
        plt.axis("off")
        # plt.show()
        plt.savefig(os.path.join(image_path_out, f"{city}_FalseColor.png"))
        if show_plot:
            plt.show()
        plt.close()

        # params["bands"] = ["B04", "B03", "B02", "B08", "B12", "B11", "SCL"] # scl must be last

        # B12, B11, B4 -> False Color Urban
        b12 = satellite_data[...,4]
        b11 = satellite_data[...,5]
        b04 = red
        b12_norm = (b12 - np.min(b12)) / (np.max(b12) - np.min(b12))
        b11_norm = (b11 - np.min(b11)) / (np.max(b11) - np.min(b11))
        b04_norm = (b04 - np.min(b04)) / (np.max(b04) - np.min(b04))


        false_color_urban = np.dstack((b12_norm, b11_norm, b04_norm)) * brightness
        false_color_urban = np.clip(false_color_urban, 0, 1)

        plt.figure(figsize=figure_size)
        plt.imshow(false_color_urban)
        plt.title(f"{city} False Color Urban Image")
        plt.axis("off")
        # plt.show()
        plt.savefig(os.path.join(image_path_out, f"{city}_FalseColorUrban.png"))
        if show_plot:
            plt.show()
        plt.close()


        # get vegetation_index
        def vegetation_index(band1, band2):
            return (band1 - band2) / (band1 + band2)


        ndvi = vegetation_index(satellite_data[...,3], satellite_data[...,2])
        plt.figure(figsize=figure_size)
        plt.imshow(ndvi, cmap="RdYlGn")
        plt.title(f"{city} NDVI Image")
        plt.axis("off")
        # plt.show()
        plt.savefig(os.path.join(image_path_out, f"{city}_NDVI.png"))
        if show_plot:
            plt.show()
        plt.close()

        # Visualize the mask
        plt.figure(figsize=(10, 10))
        plt.imshow(mask, cmap="Blues")
        plt.title(f"{city} Building Mask")
        plt.axis("off")
        # plt.show()
        plt.savefig(os.path.join(image_path_out, f"{city}_BuildingMask.png"))
        if show_plot:
            plt.show()
        plt.close()


        # Load the image
        img = single_band_stretched  # Assuming `blue_stretched` is the single band image
        blue_cmap = plt.cm.Blues
        blue_building_mask = blue_cmap(mask / mask.max())
        blue_building_mask[..., 2] = mask * 0.8

        # Plot the image
        plt.figure(figsize=(10, 10))
        plt.imshow(img, cmap="gray", alpha=1)

        plt.imshow(blue_building_mask)

        # Set the title and axis labels
        plt.title(f"{city} Image with Buildings Mask")
        plt.axis("off")

        # Show the plot
        # plt.show()
        plt.savefig(os.path.join(image_path_out, f"{city}_BuildingMaskOverlay.png"))
        if show_plot:
            plt.show()
        plt.close()





# data_preperation

In [6]:
import rasterio
import numpy as np
import torch
from torch.utils.data import random_split, DataLoader, Dataset

import matplotlib.pyplot as plt
from scipy.special import kl_div
import pandas as pd
import logging

def create_tensor_of_windows(image, mask, patch_size=None):
    """
    Create tensor with dimensions [N, H, W, C+1] from the satellite image of the city.
    image should be of shape (H, W, C)
    mask should be of shape (H, W, 1)
    """

    # Merge Mask onto Image
    image_with_mask = np.dstack((image, mask))

    # cut of edges so image shape is divisible by patch size
    reduced_image = image_with_mask[:-(image_with_mask.shape[0]%patch_size), :-(image_with_mask.shape[1]%patch_size)]

    # get reduced image size
    orig_img_x_size = reduced_image.shape[0]
    orig_img_y_size = reduced_image.shape[1]

    # get factors to which each dimension is reduced
    down_scale_x = orig_img_x_size//patch_size
    down_scale_y = orig_img_y_size//patch_size

    # create array for pathched images
    patched_images = np.zeros([down_scale_x*down_scale_y, patch_size, patch_size, reduced_image.shape[2]], dtype=np.uint16)

    # fill array with patches
    for i in range(down_scale_x):
        for j in range(down_scale_y):
            patched_images[i*down_scale_x+j::down_scale_x*down_scale_y] = reduced_image[patch_size*i:patch_size*(i+1), patch_size*j:patch_size*(j+1)]

    # return array
    return patched_images

   
def divide_into_test_training(data, test_ratio=0.2, validation_ratio=0, seed=42):
    """
    Divide the data into test and training split with seed.
    """
    
    # Define the split rati
    train_ratio = 1 - test_ratio - validation_ratio
    if train_ratio < 0:
        raise ValueError("The train ratio is negative. Please check the split ratios.")

    # # Calculate the sizes for training and test sets
    # train_size = int(train_ratio * len(data))
    # test_size = int(test_ratio * len(data))
    # validation_size = int(validation_ratio * len(data))

    # Split the dataset with seed
    generator = torch.Generator().manual_seed(seed)
    train_dataset, test_dataset, validation_dataset = random_split(data, [train_ratio, test_ratio, validation_ratio], generator=generator)

    return train_dataset, test_dataset, validation_dataset


def validate_test_training_validation_split(train_dataset, test_dataset, validation_dataset, city_names=None):
    """
    Validate the train to the test split and the train to the validation split.
    """
    logger = logging.getLogger()
    # take out dataset for better readabiltiy
    dataset = train_dataset.dataset[:,:-2]

    # calculate mean, std, min and max for each image
    means = dataset.mean(axis=(2,3))
    stds = dataset.std(axis=(2,3))
    mins = dataset.min(axis=(2,3))
    maxs = dataset.max(axis=(2,3))

    # create dataframes for train, test and validation set
    train_means = pd.DataFrame({
        "mean":means[train_dataset.indices].mean(axis=0), 
        "std":stds[train_dataset.indices].mean(axis=0),
        "min":mins[train_dataset.indices].mean(axis=0),
        "max":maxs[train_dataset.indices].mean(axis=0),
        }, index=["R", "G", "B","B08", "B12", "B11"])

    test_means = pd.DataFrame({
        "mean":means[test_dataset.indices].mean(axis=0), 
        "std":stds[test_dataset.indices].mean(axis=0),
        "min":mins[test_dataset.indices].mean(axis=0),
        "max":maxs[test_dataset.indices].mean(axis=0),
        }, index=["R", "G", "B","B08", "B12", "B11"])

    # test if differences between train and test set are below 10%
    if (((train_means-test_means)/train_means)<0.1).all().all():
        print(u'\u2713',"Differences of train and test set is below 10% on mean, std, min and max across all input bands",)
    else:
        # if not show differences and give out Warning
        temp_df =(train_means-test_means)/train_means
        logger.warning("Differences of train and test set is above 10% on one of mean, std, min and max across all input bands. This might be too big of a difference between train and test set. Please choose another seed for splitting.")
        print("!!!There might be large diffferecenes between train and test set. Please choose another seed for splitting. For more detail see the differences below",)
        print(temp_df[temp_df>0.1].dropna(axis=1, how='all').dropna(axis=0, how='all'))
        
    if validation_dataset.indices:
        validation_means = pd.DataFrame({
            "mean":means[validation_dataset.indices].mean(axis=0), 
            "std":stds[validation_dataset.indices].mean(axis=0),
            "min":mins[validation_dataset.indices].mean(axis=0),
            "max":maxs[validation_dataset.indices].mean(axis=0),
            }, index=["R", "G", "B","B08", "B12", "B11"])
        
        # test if differences between train and validation set are below 10%
        if (((train_means-validation_means)/train_means)<0.1).all().all():
            print(u'\u2713',"Differences of train and validation set is below 10% on mean, std, min and max across all input bands",)
        else:
            # if not show differences and give out Warning
            temp_df =(train_means-validation_means)/train_means
            logger.warning("Differences of train and validation set is above 10% on one of mean, std, min and max across all input bands. This might be too big of a difference between train and test set. Please choose another seed for splitting.")
            print("!!!There might be large diffferecenes between train and validation set. Please choose another seed for splitting. For more detail see the differences below",)
            print(temp_df[temp_df>0.1].dropna(axis=1, how='all').dropna(axis=0, how='all'))

    print()
    # Look at distribution of masks
    masks  = train_dataset.dataset[:,[-2]]

    # sum the pixels of building up over each image
    sum = masks.sum(axis=(2,3))

    # create dataframes for train, test and validation set with descriptive statistics
    train_means_labels = pd.DataFrame({
        "mean":sum[train_dataset.indices].mean(axis=0), 
        "median":np.median(sum[train_dataset.indices],axis=0), 
        "std":sum[train_dataset.indices].std(axis=0),
        "10th percentile":np.percentile(sum[train_dataset.indices], q=10,axis=0),
        "90th percentile":np.percentile(sum[train_dataset.indices], q=90,axis=0),
        }, index=["Train"])
    test_means_labels = pd.DataFrame({
        "mean":sum[test_dataset.indices].mean(axis=0),
        "median":np.median(sum[test_dataset.indices],axis=0), 
        "std":sum[test_dataset.indices].std(axis=0),
        "10th percentile":np.percentile(sum[test_dataset.indices], q=10,axis=0),
        "90th percentile":np.percentile(sum[test_dataset.indices], q=90,axis=0),
        }, index=["Test"])
    if validation_dataset.indices:
        validation_means_labels = pd.DataFrame({
            "mean":sum[validation_dataset.indices].mean(axis=0), 
            "median":np.median(sum[validation_dataset.indices],axis=0), 
            "std":sum[validation_dataset.indices].std(axis=0),
            "10th percentile":np.percentile(sum[validation_dataset.indices], q=10,axis=0),
            "90th percentile":np.percentile(sum[validation_dataset.indices], q=90,axis=0),
            }, index=["Validation"])
        # print concatenated dataframes
        print("Comparison of distribution of masks:\n",pd.concat([train_means_labels, test_means_labels, validation_means_labels]))
    else:
        print("Comparison of distribution of masks:\n", pd.concat([train_means_labels, test_means_labels] ))

    print()

    # look at the distribution of the data according to the different cities
    if city_names is not None:  
        # take out the city name for each label
        train_cities = train_dataset.dataset[:,-1,0,0][train_dataset.indices]
        test_cities = test_dataset.dataset[:,-1,0,0][test_dataset.indices]
        validation_cities = validation_dataset.dataset[:,-1,0,0][validation_dataset.indices]

        # lookup how often each city occured in the different sets
        train_city_counts = np.unique(train_cities, return_counts=True)
        test_city_counts = np.unique(test_cities, return_counts=True)
        validation_citiy_counts = np.unique(validation_cities, return_counts=True)

        # create dataframes for better readability
        df = pd.DataFrame({
            "Train":pd.Series(train_city_counts[1]/train_cities.shape[0], index=train_city_counts[0], name='train'),
            "Test":pd.Series(test_city_counts[1]/test_cities.shape[0], index=test_city_counts[0], name='test'),
            "Validation":pd.Series(validation_citiy_counts[1]/validation_cities.shape[0], index=validation_citiy_counts[0], name='validation')})
        print(df.index)
        df.index = df.index.map({i:c for i ,c in enumerate(city_names)})
        print("Comparison of cities the data in the differen sets originates from:\n",df.T)




def create_data_loaders(train_dataset, test_dataset, validation_dataset, batch_size = 64):
    """
    Create DataLoaders.
    """
    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader, validation_loader


def apply_preprocessing_pipeline(images, masks, patch_size = 128, test_ratio = 0.2,validation_ratio=0, batch_size = 64, show_validation_of_split=True, city_names=None, minimum_number_of_true_pixels_per_image=0):
    """
    applies windowing, deviding into train and test and creating data loaders.
    """

    # for each city create patched images
    patched_images = []
    for i,(image, mask ) in enumerate(zip(images, masks)):
        patched_image = create_tensor_of_windows(image, mask, patch_size=patch_size)
        city = np.ones(shape=list(patched_image.shape[:-1])+[1])*i
        patched_image_with_city = np.concatenate([patched_image, city], axis=-1)
        patched_images.append(patched_image_with_city)


    # concatenate all patched images
    patched_images_merged = np.concatenate(patched_images, axis=0)

    # reorder axis to [N, C, H, W] for torch
    patched_images_merged = np.transpose(patched_images_merged, (0,3,1,2))
    
    # discard images with less than minimum_number_of_true_pixels_per_image
    sums = patched_images_merged[:,-2].sum(axis=(1,2))
    patched_images_merged = patched_images_merged[sums>=minimum_number_of_true_pixels_per_image]

    # devide into train and test
    train_dataset, test_dataset, validation_dataset = divide_into_test_training(patched_images_merged, test_ratio=test_ratio, validation_ratio=validation_ratio)

    if show_validation_of_split:
        validate_test_training_validation_split(train_dataset, test_dataset, validation_dataset, city_names=city_names)

    dataset = train_dataset.dataset[:,:-1]
    train_dataset.dataset = dataset
    test_dataset.dataset = dataset
    validation_dataset.dataset = dataset    
    # create data loaders
    train_loader, test_loader , validation_loader= create_data_loaders(train_dataset, test_dataset,validation_dataset, batch_size=batch_size)

    # TODO fix error
    # TODO remove city names

    return train_loader, test_loader, validation_loader


def plot_sub_image( image_data):
    """
    Plot sub image.
    """
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    ax[0].imshow(stretch_hist(image_data[:,:,:3]))
    ax[1].imshow(stretch_hist(image_data[:,:,-1]))
    return fig


# Main

In [7]:
# basics
import os

import numpy as np
from tqdm.notebook import tqdm 


# torch
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import lightning as L




# custom modules




# Configure logging for the pipeline
logger = setup_logger(level='ERROR')

2024-07-05 07:42:24.217182: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-05 07:42:24.217292: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-05 07:42:24.363666: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [8]:
cities = ['London', 'CapeTown', 'Hamburg', 'Johannesburg', 'London', 'Montreal', 'Paris', 'Seoul', 'Singapore', 'Sydney']

datahandler = DataHandler(logger, '/kaggle/input/building-prediction/')


In [9]:
# load images and mask for all specified cites

import os
images = []
sparse_masks=[]
dense_masks=[]

for city in cities:
    buildings = None
    if not os.path.exists(os.path.join(datahandler.path_to_data_directory,city,'building_mask_dense.tif')):
        print("loading local buildings")
        buildings = datahandler.get_buildings(city)
    images.append(datahandler.get_satellite_image(city))
    sparse_masks.append(datahandler.get_building_mask(city, all_touched=False, loaded_buildings=buildings))
    dense_masks.append(datahandler.get_building_mask(city, all_touched=True, loaded_buildings=buildings))



In [10]:
masks = [sparse_masks[i]+dense_masks[i]for i in range(len(sparse_masks))]

In [11]:
patch_size = 128
test_ratio= 0.2
validation_ratio=0.2
batch_size = 64
show_validation_of_split=False
city_names=cities
minimum_number_of_true_pixels_per_image=1

In [12]:
# for each city create patched images
patched_images = []
for i,(image, mask ) in enumerate(zip(images, masks)):
    patched_image = create_tensor_of_windows(image, mask, patch_size=patch_size)
    city = np.ones(shape=list(patched_image.shape[:-1])+[1])*i
    patched_image_with_city = np.concatenate([patched_image, city], axis=-1)
    patched_images.append(patched_image_with_city.astype(np.int16))



In [13]:
import sys
def sizeof_fmt(num, suffix='B'):
    ''' by Fred Cirera,  https://stackoverflow.com/a/1094933/1870254, modified'''
    for unit in ['','Ki','Mi','Gi','Ti','Pi','Ei','Zi']:
        if abs(num) < 1024.0:
            return "%3.1f %s%s" % (num, unit, suffix)
        num /= 1024.0
    return "%.1f %s%s" % (num, 'Yi', suffix)

for name, size in sorted(((name, sys.getsizeof(value)) for name, value in list(
                          locals().items())), key= lambda x: -x[1])[:10]:
    print("{:>30}: {:>8}".format(name, sizeof_fmt(size)))

       patched_image_with_city:  1.0 GiB
                 patched_image: 224.0 MiB
                          city: 128.0 MiB
                          mask: 33.2 MiB
                           _i2: 40.7 KiB
                           _i5: 40.7 KiB
                           _i6: 11.4 KiB
                   DataHandler:  1.2 KiB
                           _i1:  1.1 KiB
                OpenEoApiError:  1.0 KiB


In [14]:
sys.getsizeof(patched_images[0])/1e6

557.842592

In [15]:

# concatenate all patched images
patched_images_merged = np.concatenate(patched_images, axis=0)

# reorder axis to [N, C, H, W] for torch
patched_images_merged = np.transpose(patched_images_merged, (0,3,1,2))

# discard images with less than minimum_number_of_true_pixels_per_image
sums = patched_images_merged[:,-2].sum(axis=(1,2))
patched_images_merged = patched_images_merged[sums>=minimum_number_of_true_pixels_per_image]

# devide into train and test
train_dataset, test_dataset, validation_dataset = divide_into_test_training(patched_images_merged, test_ratio=test_ratio, validation_ratio=validation_ratio)

if show_validation_of_split:
    validate_test_training_validation_split(train_dataset, test_dataset, validation_dataset, city_names=city_names)

dataset = train_dataset.dataset[:,:-1]
train_dataset.dataset = dataset
test_dataset.dataset = dataset
validation_dataset.dataset = dataset    
# create data loaders
train_loader, test_loader , validation_loader= create_data_loaders(train_dataset, test_dataset,validation_dataset, batch_size=batch_size)


In [16]:
# apply training pipeline
# TODO make train test split consistent so we can train with multiple sizes, dont know if there is an advantage though
# train_loader, test_loader , validation_loader= apply_preprocessing_pipeline(images, masks, patch_size = 128, test_ratio= 0.2,validation_ratio=0.2, batch_size = 64, show_validation_of_split=False,city_names=cities, minimum_number_of_true_pixels_per_image=1)

In [17]:
class convNetSimple(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
                nn.Conv2d(6, 32, kernel_size=3, padding=1), nn.ReLU(),
                nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(),
                nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(),
                nn.Conv2d(128, 1, kernel_size=1, padding=0),
        )#nn.Sigmoid())
    
    def forward(self, x):
        return self.model(x)
    
class LitNet(L.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.loss = nn.MSELoss()

    def training_step(self, batch, batch_idx):
        x, y = batch[:,:-1], batch[:,-1]
        outs = self.model(x.float())
        loss = self.loss(outs, y.unsqueeze(1).float())
        self.log("train_loss", value=loss, on_step=True, on_epoch=True, logger=True, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        x, y = batch[:,:-1], batch[:,-1]
        outs = self.model(x.float())
        loss = self.loss(outs, y.unsqueeze(1).float())
        
        values = {
            "test_loss": loss,
        }
        self.log_dict(values, on_epoch=True, on_step=True, prog_bar=True, logger=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    def forward(self, x):
        return self.model(x)


### Load Model

In [18]:
# load trained model
# convmodel = LitNet(convNetSimple())
# model_dict = torch.load("models/lightning_logs/version_1/checkpoints/epoch=39-step=7000.ckpt", map_location=torch.device('cpu'))
# convmodel.load_state_dict(model_dict['state_dict'])

In [19]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

import pytorch_lightning as pl

L.seed_everything(42)
convmodel = LitNet(convNetSimple())
trainer = L.Trainer(
    default_root_dir="models",
    # callbacks=[
    #     EarlyStopping(
    #         monitor="val_loss",
    #         mode="min",
    #         patience=10,
    #     )
    #     ModelCheckpoint(
    #         monitor="val_loss",
    #         mode="min",
    #         save_top_k=2,
    #         dirpath="models",
    #         filename="best_model"
    #     )
    # ]
    # val_check_interval=1,
    fast_dev_run=False,
    # num_sanity_val_steps=2,
    max_epochs=300,
    log_every_n_steps=20,
    accelerator="gpu", devices=2, strategy="ddp_notebook"
)


INFO: Seed set to 42
2024-07-05 07:43:40,318 - lightning.fabric.utilities.seed - INFO - seed_everything - Seed set to 42
2024-07-05 07:43:40,603 - pytorch_lightning.utilities.rank_zero - INFO - _info - GPU available: True (cuda), used: True
2024-07-05 07:43:40,694 - pytorch_lightning.utilities.rank_zero - INFO - _info - TPU available: False, using: 0 TPU cores
2024-07-05 07:43:40,695 - pytorch_lightning.utilities.rank_zero - INFO - _info - HPU available: False, using: 0 HPUs


In [20]:
train_loader.dataset.dataset.shape

(8762, 7, 128, 128)

In [None]:

# training
trainer.fit(convmodel, 
    train_dataloaders=train_loader,
    val_dataloaders=validation_loader
)


# testing

# hier könnte man noch das beste model laden, wenn wir ein Val dataset haben.
# best_model = LitModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
# trainer.test(
#     best_model,
#     dataloaders=test_loader
# )


trainer.test(
    convmodel,
    dataloaders=test_loader
)


  self.pid = os.fork()
  self.pid = os.fork()
/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:68: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:68: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
INFO: [rank: 0] Seed set to 42
2024-07-05 07:43:40,866 - lightning.fabric.utilities.seed - INFO - seed_everything - [rank: 0] Seed set to 42
INFO: Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
2024-07-05 07:43:40,869 - lightning.fabric.utilities.distributed - INFO - _init_dist_connection - Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
INFO: [rank: 1] Seed set to 42
2024-07-05 07:43:41,034 - lightning.fabric.utilities.seed - INFO - seed_everything - [rank: 1] Seed set to 42
INFO: Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
2024-07-05 07:43:41,038 - lightning.fabr

Training: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:439: It is recommended to use `self.log('train_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


In [None]:
# # Instantiate the model, loss function, and optimizer
# criterion = nn.BCELoss()
# optimizer = optim.Adam(model.parameters(), lr=0.01)

# # Training loop
# num_epochs = 50

# model.train()
# for epoch in tqdm(range(num_epochs)):
#     for batch in train_loader:
#         # splid in inputs and labels
#         inputs = batch[:,:-1].to(torch.float32)
#         labels = batch[:,-1, np.newaxis].to(torch.float32)

#         # zero the parameter gradients
#         optimizer.zero_grad()

#         # forward pass
#         outputs = model(inputs)

#         # calculate loss
#         loss = criterion(outputs, labels)

#         # write to tensorboard
#         writer.add_scalar("Loss/train", loss, epoch)

#         # backward pass
#         loss.backward()

#         # optimizer step
#         optimizer.step()
    
# writer.flush()

## Save Model

In [None]:
# import os

# os.makedirs("saved_models", exist_ok=True)
# torch.save(model.state_dict(), "saved_models/model1")

# Evaluation

In [41]:
convmodel(inputs).detach().shape

torch.Size([64, 1, 128, 128])

In [43]:
prediction = np.ones((len(test_loader.dataset.indices), 1,128, 128))*-1
true_values = np.ones((len(test_loader.dataset.indices), 1, 128,128))*-1
for i,batch in enumerate(test_loader):
    inputs = batch[:,:-1].to(torch.float32)
    labels = batch[:,-1, np.newaxis].to(torch.float32)
    prediction[i*64:(i+1)*64]=convmodel(inputs).detach()
    true_values[i*64:(i+1)*64]=labels

In [45]:
true_values.sum()

1048021.0

In [None]:
# predict on test set


t  = torch.Tensor(test_loader.dataset)

# splid in inputs and labels
test_inputs = t[:,:-1]#.to(torch.float32)
test_labels = t[:,-1, np.newaxis]#.to(torch.float32)

test_results = convmodel(test_inputs).detach()

# Look at sums, to check if model only predicts zeros
print("Sum of test results: ", test_results.sum())
print("But it should be closer to: ", test_labels.sum())


# # see how many percnet where predicted right
threshold = 0.5
((test_results>threshold)==test_labels).sum()/np.prod(test_labels.shape)

  t  = torch.Tensor(test_loader.dataset)


In [None]:
# from sklearn.metrics import RocCurveDisplay

# RocCurveDisplay.from_predictions(
#    test_labels.flatten(), test_results.flatten())

# Evaluation

In [None]:
# t  = torch.Tensor(test_loader.dataset)

# # splid in inputs and labels
# test_inputs = t[:,:-1]#.to(torch.float32)
# test_labels = t[:,-1, np.newaxis]#.to(torch.float32)

# test_results = model(test_inputs).detach()





# Download

In [None]:


buildings = []
sat_images = []
building_masks = []

for city in cities: 
    buildings.append(datahandler.get_buildings(city))
    sat_images.append(datahandler.get_satellite_image(city))
    building_masks.append(datahandler.get_building_mask(city))

# Plot the expected results for the first city 
datahandler.plot(city[0])

In [None]:
import data_preparation

for city in cities:
    data_preparation.create_tensor(city)

# Download

In [None]:
# Download 

for city in cities: 
    sat_image = datahandler.get_satellite_image(city)
    mask = datahandler.get_building_mask(city)

# Plot the expected results for the first city 
datahandler.plot(city[0])