In [1]:
#############
## imports ##
#############

# libraries 
import geopandas as gpd
import numpy as np 
import pandas as pd
import rasterio
import matplotlib.pyplot as plt
from rasterio.plot import show
from rasterio.mask import mask
import os
import json
from shapely.geometry import box, Polygon
import logging
import random
import torchvision.transforms.functional as TF
import torch

##############
## dataprep ##
##############

def get_RGB_match(DEM_name, original_tif_dir):

    # Find the base name matching the DEM file
    DEM_name = DEM_name[:-9] # remove _year.tif
    north_kms = 0 if int(DEM_name[7:9]) < 50 else 5
    east_kms = 0 if int(DEM_name[9:11]) < 50 else 5
    RGB_name = f"{DEM_name[:7]}{north_kms}{east_kms}"

    all_RGB_files = os.listdir(original_tif_dir)

    newest_match = f"{RGB_name}_0000"
    for file in all_RGB_files: 
        if str(file[:9]) == str(RGB_name) and int(file[10:14]) > int(newest_match[10:14]):
            newest_match = file
    return newest_match

def tif_from_ruta(ruta_geometry):
    minx_ruta = ruta_geometry.bounds[0]
    miny_ruta = ruta_geometry.bounds[1]

    miny = str(miny_ruta)[0:3]
    minx = str(minx_ruta)[0:2]

    if 0 <= int(str(miny_ruta)[3:5]) < 25:
        km_siffran_y = '00'
    elif 25 <= int(str(miny_ruta)[3:5]) < 50:
        km_siffran_y = '25'
    elif 50 <= int(str(miny_ruta)[3:5]) < 75:
        km_siffran_y = '50'
    elif 75 <= int(str(miny_ruta)[3:5]) < 100:
        km_siffran_y = '75'

    if 0 <= int(str(minx_ruta)[3:5]) < 25:
        km_siffran_x = '00'
    elif 25 <= int(str(minx_ruta)[3:5]) < 50:
        km_siffran_x = '25'
    elif 50 <= int(str(minx_ruta)[3:5]) < 75:
        km_siffran_x = '50'
    elif 75 <= int(str(minx_ruta)[3:5]) < 100:
        km_siffran_x = '75'

    year = 2018 # WHICH YEAR SHOULD IT BE??

    filename = f"{miny}_{minx}_{km_siffran_y}{km_siffran_x}_{year}.tif"
    return filename


def filter_imgs(all_rutor_path, original_tif_dir):
    all_rutor = gpd.read_file(all_rutor_path)
    all_rutor['in_tif'] = all_rutor['geometry'].map(tif_from_ruta)
    uniques = all_rutor.in_tif.unique()

    dir_files = os.listdir(original_tif_dir)
    only_tifs = [filename for filename in dir_files if filename[-4:] == ".tif"]

    # compare such only the part without the year. 
    only_tifs_noyear = [filename[:-8] for filename in only_tifs]
    uniques_noyear = [filename[:-8] for filename in list(uniques)]

    # check that all uniques are in only tifs
    if not (set(list(uniques_noyear)).issubset(set(only_tifs_noyear))):
        # logger.WARN(f"at least one tif name generated from all_rutor was not found in the directory: {original_tif_dir}")
        print(f"at least one tif name generated from all_rutor was not found in the directory")
        items_not_in_dir = [item for item in uniques_noyear if item not in only_tifs_noyear]
        print(f"items not in directory are: \n {items_not_in_dir}")

    intersection = list(set(uniques_noyear) & set(only_tifs_noyear))

    tifs_to_use = [filename for filename in only_tifs if filename[:-8] in intersection]

    return tifs_to_use


##############
## cropping ##
##############

class Crop_tif_varsize():
    """
    In: tif image to be cropped, and whole extent of 100x100 rutor
    Returns: directory of one cropped tif per 100x100 ruta.
    """

    def __init__(self, RGB_name_code, RGB_img_path, hs_name_code, hs_img_path,
                 DEM_name_code, DEM_img_path, rutor_path, destination_path, dims, logger, groundtruth_shapefile_path):

        self.RGB_name_code = RGB_name_code
        self.hs_name_code = hs_name_code
        self.DEM_name_code = DEM_name_code
        self.dimensions = dims
        self.destination_path = destination_path
        self.logger = logger
        self.RGB_img_path = RGB_img_path
        self.hs_img_path = hs_img_path
        self.DEM_img_path = DEM_img_path
        self.rutor_path = rutor_path
        self.groundtruth_polygs = gpd.read_file(groundtruth_shapefile_path)
        self.RGB_img = rasterio.open(RGB_img_path)
        self.hs_img = rasterio.open(hs_img_path)
        self.DEM_img = rasterio.open(DEM_img_path)
        self.filtered_rutor = self.filter_rutor()

    def forward(self):

        # generate all possible polygons in the image of dim x dim 
        generated_polygons_all = self.generate_geoseries(self.hs_img.bounds, self.hs_img.crs, self.dimensions) 
        generated_polygons_palsa = self.palsa_polygons(generated_polygons_all)
        positive_labels = self.crop_palsa_imgs(generated_polygons_palsa)
        # aug_labels = self.crop_aug_imgs(generated_polygons_palsa)
        negative_labels = self.crop_negatives(generated_polygons_all, generated_polygons_palsa)
        all_labels = positive_labels | negative_labels 
        self.hs_img.close()
        self.RGB_img.close()
        self.DEM_img.close()
        return all_labels

    def filter_rutor(self):
        # Find which 100x100 squares overlap with the current TIF
        rutor = gpd.read_file(self.rutor_path)
        image_polygon = box(*self.hs_img.bounds)
        cropped_polygons = rutor[rutor.geometry.apply(lambda x: x.intersection(image_polygon).equals(x))]

        return cropped_polygons
    
    def new_palsa_percentage(self, big_ruta, joined_df):
        contained_rutor = joined_df.loc[joined_df['name'] == big_ruta]
        total_pals_percentage = contained_rutor['PALS'].sum()
        percentage_factor = self.dimensions **2 / 10000 # TODO check this part. was 100x100 = 10000 so should now be the same still. 
        palsa_percentage = total_pals_percentage / percentage_factor
        return palsa_percentage
    
    def palsa_polygons(self, generated_polygons_all):

        # if 100x100 meter is used, the original rutor are used
        if self.dimensions == 100:
            return self.filtered_rutor

        # if not 100x100, find which polygons have palsa rutor in them
        d = {'name': [i for i in range(len(generated_polygons_all))]}
        generated_polygons_all_df = gpd.GeoDataFrame(d, geometry = generated_polygons_all, crs=generated_polygons_all.crs)

        # Perform a spatial join between generated_polygons_all and filtered_rutor 
        joined_df = gpd.sjoin(generated_polygons_all_df, self.filtered_rutor, how='inner', op = 'contains')
        covering_polygons_index = joined_df.index.unique() # find uniques
        result_df = generated_polygons_all_df.loc[covering_polygons_index] # select polygons that cover a smaller polygon

        # Generate palsa column in the resulting big ruta dataframe
        result_df['PALS'] = result_df['name'].apply(lambda x: self.new_palsa_percentage(x, joined_df))

        return result_df

    def generate_geoseries(self, bounds, crs, dims):

        """
        Generates all dim x dim polygons present in the hillshade TIF.
        """

        # height and width of new squares 
        square_dims = dims # 100x100 meters

        # Calculate the number of segments in each dimension (tif width // desired width in pixels!)
        segments_x = 2500 // square_dims # for depth data its 2500
        segments_y = 2500 // square_dims

        # Create an empty list to store the polygons
        polygons = []

        # Iterate over the segments
        for i in range(segments_y):
            for j in range(segments_x):
                # Calculate the coordinates of the segment
                left = bounds.left + j * square_dims
                bottom = bounds.bottom + i * square_dims
                right = left + square_dims
                top = bottom + square_dims

                # Create a polygon for the segment
                polygon = Polygon([(right, bottom), (left, bottom), (left, top), (right, top), (right, bottom)])

                # Append the polygon to the list
                polygons.append(polygon)

        # Create a GeoSeries from the list of polygons
        return gpd.GeoSeries(polygons, crs=crs)

    def make_crop(self, img, polygon, output_path):
        # Crop the TIF file using the polygon
        cropped_data, cropped_transform = mask(img, [polygon], crop=True)

        # Update the metadata for the cropped TIF
        cropped_meta = img.meta.copy()
        cropped_meta.update({"driver": "GTiff",
                            "height": cropped_data.shape[1],
                            "width": cropped_data.shape[2],
                            "transform": cropped_transform})

        # Save the cropped TIF file with a unique name
        with rasterio.open(output_path, "w", **cropped_meta) as dest:
            dest.write(cropped_data)

    def make_ground_truth(self, intersections, img_crop_path, gt_path):

        # use previously created cropped image
        cropped_img = rasterio.open(img_crop_path)
        masked_data, _ = mask(cropped_img, [polyg for polyg in intersections.geometry])

        with rasterio.open(gt_path, "w", **cropped_img.meta.copy()) as dest:
            dest.write(masked_data)


    def crop_palsa_imgs(self, palsa_rutor):

        """
        Crop TIF according to the polygons containing palsa. 
        """

        cropped_tifs_percentages = {}
        # Iterate over each polygon in the GeoDataFrame
        for idx, percentage, polygon in zip(palsa_rutor.index, palsa_rutor.PALS, palsa_rutor.geometry):

            # see if there is an overlap with ground truths:
            polyg_df = palsa_rutor.iloc[[idx]]
            intersections = gpd.overlay(self.groundtruth_polygs, polyg_df, how='intersection')

            # only crop if theres a ground truth overlap
            if not intersections.empty:

                hs_path = f'{self.destination_path}/hs/{self.hs_name_code}_crop_{idx}.tif'
                RGB_path = f'{self.destination_path}/rgb/{self.hs_name_code}_crop_{idx}.tif'
                DEM_path = f'{self.destination_path}/dem/{self.hs_name_code}_crop_{idx}.tif'
                gt_path = f'{self.destination_path}/groundtruth_mask/{self.hs_name_code}_crop_{idx}.tif'

                # crop hillshade and RGB according to same polygons
                self.make_crop(self.hs_img, polygon, hs_path) 
                self.make_crop(self.RGB_img, polygon, RGB_path)
                self.make_crop(self.DEM_img, polygon, DEM_path)

                # generate ground truth
                self.make_ground_truth(intersections, hs_path, gt_path)
                
                # Write the corresponding percentage to a dictionary as label 
                cropped_tifs_percentages[f"{self.hs_name_code}_crop_{idx}"] = percentage

        return cropped_tifs_percentages

    def crop_negatives(self, generated_polygons_all, generated_polygons_palsa):

        """
        Generates negative samples. Equal amount of negative as positive samples are
        taken from each image such that the final dataset is 50/50 positive and negative. 

            1) split the whole TIF into 100x100m polygons.
            2) filter out the areas containing palsa (positive samples)
            3) randomly sample as many negative samples as positive samples from that image
            4) crop the TIF according to the sampled areas and write locally

        """

        # filter out the squares with palsa 
        positives_mask = ~generated_polygons_all.isin(generated_polygons_palsa.geometry)
        all_negatives = generated_polygons_all[positives_mask]

        # randomly sample 
        sample_size = int(len(generated_polygons_palsa)) # based on number of positive samples 
        if sample_size <= len(all_negatives): # default case
            negative_samples = all_negatives.sample(n=sample_size) # sample randomly
        else:
            self.logger.info('Exception occurred! Number of positive samples > 1/2 image. Training set now contains fewer negative than positive samples.')
            negative_samples = all_negatives

        cropped_tifs_percentages = {}
        # Iterate over each polygon in the GeoDataFrame
        for idx, polygon in enumerate(negative_samples.geometry):
            # Crop the TIF file using the polygon
            hs_path = f'{self.destination_path}/hs/{self.hs_name_code}_negcrop_{idx}.tif'
            RGB_path = f'{self.destination_path}/rgb/{self.hs_name_code}_negcrop_{idx}.tif'
            DEM_path = f'{self.destination_path}/dem/{self.hs_name_code}_negcrop_{idx}.tif'
            gt_path = f'{self.destination_path}/groundtruth_mask/{self.hs_name_code}_crop_{idx}.tif'

            # crop hillshade and RGB according to same polygons
            self.make_crop(self.hs_img, polygon, hs_path) 
            self.make_crop(self.RGB_img, polygon, RGB_path)
            self.make_crop(self.DEM_img, polygon, DEM_path)

##################################################################################

            # GENERATE NEGATIVE GROUND TRUTHS

##################################################################################

            # Write the corresponding percentage to a dictionary as label 
            cropped_tifs_percentages[f"{self.hs_name_code}_negcrop_{idx}"] = 0

        return cropped_tifs_percentages
    


In [2]:
#############
## imports ##
#############

# libraries 
import geopandas as gpd
import numpy as np 
import pandas as pd
import rasterio
import matplotlib.pyplot as plt
from rasterio.plot import show
from rasterio.mask import mask
import os
import json
import logging

# functions 
# from functions import get_RGB_match, Crop_tif_varsize, filter_imgs

##################
## setup logger ##
##################

logger = logging.getLogger('my_logger')
logger.setLevel(logging.DEBUG)

# Setup logger
ch = logging.StreamHandler() # create console handler
ch.setLevel(logging.DEBUG) # set level to debug
formatter = logging.Formatter("%(asctime)s - %(message)s \n", "%Y-%m-%d %H:%M:%S") # create formatter
ch.setFormatter(formatter) # add formatter to ch
logger.addHandler(ch) # add ch to logger

logger.info('Imports successful')

##################
## load configs ##
##################

config_path = os.path.join(os.getcwd(), 'configs.json')
with open(config_path, 'r') as config_file:
    configs = json.load(config_file)

# load paths from configs 
config_paths = configs.get('paths', {}) 
palsa_shapefile_path = config_paths.get('palsa_shapefile_path') # load shapefile path
groundtruth_shapefile_path = config_paths.get('groundtruth_shapefile_path') # load shapefile path
save_crops_dir = config_paths.get('save_crops_dir') # load directory with all tifs
RGB_tif_dir = config_paths.get('RGB_tif_dir') # load directory with all tifs
hillshade_tif_dir = config_paths.get('hillshade_tif_dir') # load directory with all tifs
DEM_tif_dir = config_paths.get('DEM_tif_dir') # load directory with all tifs

config_img = configs.get('image_info', {}) 
dims = int(config_img.get('meters_per_axis')) 

logger.info('Configurations were loaded')

##########
## code ##
##########

# Filter hillshade data so only those containing a ground truth polygon remain
hillshade_filenames = filter_imgs(groundtruth_shapefile_path, hillshade_tif_dir) 
logger.info(f'{len(hillshade_filenames)} TIF paths have been loaded!')

# Loop over hillshade images to generate the crops. 
logger.info('Starting to generate training samples from TIFs..')
labels = {}
not_found = []
for idx, hs_img_name in enumerate(hillshade_filenames):
    # grab corresponding RGB image (matching the hillshade)
    try:
        RGB_tif_name = get_RGB_match(hs_img_name, RGB_tif_dir) 
        RGB_img_name_code = RGB_tif_name.split('.')[0]
        RGB_img_path = os.path.join(RGB_tif_dir, RGB_tif_name)

        hs_img_name_code = hs_img_name.split('.')[0]
        hs_img_path = os.path.join(hillshade_tif_dir, hs_img_name)

        DEM_img_name_code = hs_img_name.split('.')[0]
        DEM_img_path = os.path.join(DEM_tif_dir, hs_img_name)

        cropping = Crop_tif_varsize(RGB_img_name_code, RGB_img_path, hs_img_name_code, 
                                    hs_img_path, DEM_img_name_code, DEM_img_path, 
                                    palsa_shapefile_path, save_crops_dir, dims, logger, groundtruth_shapefile_path)
        
        # Run the cropping script
        new_labels = cropping.forward()
        labels = labels | new_labels
        logger.info(f'Generated training samples from image {idx+1}/{len(hillshade_filenames)}')
    except: 
        logger.info(f'RGB or DEM match for {hs_img_name} not found')
        not_found.append(hs_img_name)

print(f'The following images had no rgb match: \n {not_found}')
print(f'number of images where script failed: \n {len(not_found)}')

label_df = pd.DataFrame.from_dict(labels, orient='index', columns = ['palsa_percentage'])
label_df.to_csv(os.path.join(save_crops_dir, "palsa_labels.csv"))


2024-06-04 16:51:51 - Imports successful 

2024-06-04 16:51:51 - Configurations were loaded 

2024-06-04 16:51:51 - 8 TIF paths have been loaded! 

2024-06-04 16:51:51 - Starting to generate training samples from TIFs.. 

  new_labels = cropping.forward()
2024-06-04 16:51:58 - Generated training samples from image 1/8 

  new_labels = cropping.forward()
2024-06-04 16:52:05 - RGB or DEM match for 746_57_5050_2015.tif not found 

  new_labels = cropping.forward()
2024-06-04 16:52:13 - Generated training samples from image 3/8 

  new_labels = cropping.forward()
2024-06-04 16:52:20 - RGB or DEM match for 746_57_5075_2015.tif not found 

  new_labels = cropping.forward()
2024-06-04 16:52:27 - Generated training samples from image 5/8 

  new_labels = cropping.forward()
2024-06-04 16:52:34 - RGB or DEM match for 746_57_5025_2015.tif not found 

  new_labels = cropping.forward()
2024-06-04 16:52:42 - RGB or DEM match for 749_68_2550_2014.tif not found 

  new_labels = cropping.forward()
2024

The following images had no rgb match: 
 ['746_57_5050_2015.tif', '746_57_5075_2015.tif', '746_57_5025_2015.tif', '749_68_2550_2014.tif', '746_57_2525_2015.tif']
number of images where script failed: 
 5
