In [1]:
# Dataset generation

In [2]:

import rasterio as rio
from rasterio.warp import calculate_default_transform, reproject, Resampling
from rasterio.crs import CRS
from rasterio.io import MemoryFile
from rasterio.enums import Resampling
import affine
import os
from shapely.geometry import box
from rasterio.coords import BoundingBox
from rasterio.mask import mask as masker
import pystac

# RAM check
import psutil# Import libraries
import numpy as np
from pathlib import Path

In [3]:
# RAM check
process = psutil.Process()
print(process.memory_info().rss / 10**6)

118.411264


In [4]:
# # input paths
# load_vv = Path("../data/input_old_nigeria/vv/vv.tif")
# load_mask = Path("../data/input_old_nigeria/mask/merged/mask.tif")
# load_hand = Path("../data/input_old_nigeria/hand/hand.tif")

# input paths
load_vv = Path("../data/input_scene_4/vv/corrected/vv.tif")
load_mask = Path("../data/input_scene_4/mask/corrected/mask.tif")
load_hand = Path("../data/input_scene_4/hand/hand.tif")

In [5]:
# Setting
OUTPUT_SIZE = (256, 256)
DROPNA = True

In [6]:
# Function to change CRS system
# REPROJECTION CRS WITHOUT SAVING
# KEEPS IN --> RAM <--


def reproject_crs(file_path, target_crs):
    """Function to load tiff file from path
    with desired crs.
    """

    # Open the input GeoTIFF file
    src = rio.open(file_path)
    print(type(src))

    # Read metadata
    src_crs = src.crs
    src_transform = src.transform
    src_width = src.width
    src_height = src.height

    # Calculate the transform for reprojecting
    transform, width, height = calculate_default_transform(
        src_crs, target_crs, src_width, src_height, *src.bounds
    )

    # Create options for the output file
    kwargs = src.meta.copy()
    kwargs.update(
        {"crs": target_crs, "transform": transform, "width": width, "height": height}
    )

    # Create an in-memory dataset
    memfile = MemoryFile()
    dst = memfile.open(**kwargs)

    # Reproject and write to the in-memory dataset
    reproject(
        source=rio.band(src, 1),
        destination=rio.band(dst, 1),
        src_transform=src_transform,
        src_crs=src_crs,
        dst_transform=transform,
        dst_crs=target_crs,
        resampling=Resampling.nearest,
    )
    src.close()
    return dst

In [7]:
# Function to change CRS system
# REPROJECTION CRS WITHOUT SAVING
# KEEPS IN --> RAM <--


def reproject_crs2(file_path, target_crs):
    """Function to load tiff file from path
    with desired crs.
    """

    # Open the input GeoTIFF file
    with rio.open(file_path) as src:
        print(type(src))

        # Read metadata
        src_crs = src.crs
        src_transform = src.transform
        src_width = src.width
        src_height = src.height

        # Calculate the transform for reprojecting
        transform, width, height = calculate_default_transform(
            src_crs, target_crs, src_width, src_height, *src.bounds
        )

        # Create options for the output file
        kwargs = src.meta.copy()
        kwargs.update(
            {
                "crs": target_crs,
                "transform": transform,
                "width": width,
                "height": height,
            }
        )

        # Create an in-memory dataset
        memfile = MemoryFile()
        dst = memfile.open(**kwargs)

        # Reproject and write to the in-memory dataset
        reproject(
            source=rio.band(src, 1),
            destination=rio.band(dst, 1),
            src_transform=src_transform,
            src_crs=src_crs,
            dst_transform=transform,
            dst_crs=target_crs,
            resampling=Resampling.nearest,
        )
    return dst

In [8]:
# Loading all files with CRS:32632 as example
target_crs = CRS.from_epsg(32632)
vv = reproject_crs2(load_vv, target_crs)
mask = reproject_crs2(load_mask, target_crs)
hand = reproject_crs2(load_hand, target_crs)

<class 'rasterio.io.DatasetReader'>
<class 'rasterio.io.DatasetReader'>
<class 'rasterio.io.DatasetReader'>


In [9]:
# Load check
# mask.read(1).shape

In [10]:
# Memory check
print(process.memory_info().rss / 10**6)

2906.329088


In [11]:
print(vv.meta, mask.meta, hand.meta, sep="\n")

{'driver': 'GTiff', 'dtype': 'float32', 'nodata': 0.0, 'width': 10013, 'height': 9231, 'count': 1, 'crs': CRS.from_epsg(32632), 'transform': Affine(42.74890388953347, 0.0, 9262106.148155086,
       0.0, -42.74890388953347, 8051513.363165323)}
{'driver': 'GTiff', 'dtype': 'float32', 'nodata': 255.0, 'width': 18604, 'height': 17312, 'count': 1, 'crs': CRS.from_epsg(32632), 'transform': Affine(22.838760654556694, 0.0, 9264867.293830955,
       0.0, -22.838760654556694, 8050302.032340328)}
{'driver': 'GTiff', 'dtype': 'float32', 'nodata': -9999.0, 'width': 15585, 'height': 14783, 'count': 1, 'crs': CRS.from_epsg(32632), 'transform': Affine(198.79103795581588, 0.0, 8288255.774400535,
       0.0, -198.79103795581588, 9236149.514951987)}


In [12]:
print(vv.bounds, mask.bounds, hand.bounds, sep="\n")

BoundingBox(left=9262106.148155086, bottom=7656898.231361039, right=9690150.922800984, top=8051513.363165323)
BoundingBox(left=9264867.293830955, bottom=7654917.4078886425, right=9689759.597048327, top=8050302.032340328)
BoundingBox(left=8288255.774400535, bottom=6297421.600851161, right=11386414.100941926, top=9236149.514951987)


In [13]:
# Getting overlap bounding box between files
# So to keep parts of data that are needed in the RAM

left = [vv.bounds.left, mask.bounds.left, hand.bounds.left]
bottom = [vv.bounds.bottom, mask.bounds.bottom, hand.bounds.bottom]
right = [vv.bounds.right, mask.bounds.right, hand.bounds.right]
top = [vv.bounds.top, mask.bounds.top, hand.bounds.top]

overlap_bounds = BoundingBox(
    left=max(left), bottom=max(bottom), right=min(right), top=min(top)
)

In [14]:
# Convert bounds to polygon
overlap_polygon = box(*overlap_bounds)
print(overlap_polygon)

POLYGON ((9689759.597048327 7656898.231361039, 9689759.597048327 8050302.032340328, 9264867.293830955 8050302.032340328, 9264867.293830955 7656898.231361039, 9689759.597048327 7656898.231361039))


In [15]:
crop_img, crop_transform = masker(hand, shapes=[overlap_polygon], crop=True)

In [16]:
# Cropped image shape
crop_img.shape

(1, 1980, 2139)

In [17]:
# Create MemoryFile() out of crop_img
profile = hand.profile.copy()
profile.update(
    driver="GTiff",
    height=crop_img.shape[1],
    width=crop_img.shape[2],
    transform=crop_transform,
)

memfile = MemoryFile()
cropped_hand = memfile.open(**profile)
cropped_hand.write(crop_img)

In [18]:
cropped_hand.meta

{'driver': 'GTiff',
 'dtype': 'float32',
 'nodata': -9999.0,
 'width': 2139,
 'height': 1980,
 'count': 1,
 'crs': CRS.from_epsg(32632),
 'transform': Affine(198.79103795581588, 0.0, 9264717.352839503,
        0.0, -198.79103795581588, 8050360.973545546)}

In [19]:
crop_transform

Affine(198.79103795581588, 0.0, 9264717.352839503,
       0.0, -198.79103795581588, 8050360.973545546)

In [20]:
crop_img.shape

(1, 1980, 2139)

In [21]:
hand.transform

Affine(198.79103795581588, 0.0, 8288255.774400535,
       0.0, -198.79103795581588, 9236149.514951987)

In [22]:
hand.shape

(14783, 15585)

In [23]:
bigger = hand.shape[0] * hand.shape[1]
smaller = crop_img.shape[1] * crop_img.shape[2]

bigger / smaller

54.39931219629677

In [24]:
# Function to change the resolution of files to desired one


def rescale_image(input_file, scale_factor):
    # Read the data from the source file
    src = input_file
    data = src.read(
        out_shape=(
            src.count,
            int(src.height * scale_factor),
            int(src.width * scale_factor),
        ),
        resampling=Resampling.bilinear,
    )

    # Update the metadata
    transform = src.transform * src.transform.scale(
        (src.width / data.shape[-1]), (src.height / data.shape[-2])
    )

    # Update the profile
    profile = src.profile
    profile.update(
        driver="GTiff",
        height=data.shape[1],
        width=data.shape[2],
        transform=transform,
    )

    memfile = MemoryFile()
    scaled_dataset = memfile.open(**profile)
    scaled_dataset.write(data)

    return scaled_dataset, profile 

In [25]:
# Set resolution to standard
# RES = 20
RES = vv.res[0]
vv_refactor = vv.res[0] / RES
mask_refactor = mask.res[0] / RES
hand_refactor = cropped_hand.res[0] / RES
print(vv_refactor, mask_refactor, hand_refactor)

1.0 0.5342537135823142 4.650201990430153


In [26]:
vv_scaled, _ = rescale_image(vv, vv_refactor)
mask_scaled, _ = rescale_image(mask, mask_refactor)
hand_scaled, _ = rescale_image(cropped_hand, hand_refactor)

In [27]:
vv.res

(42.74890388953347, 42.74890388953347)

In [28]:
print(vv_scaled.meta, mask_scaled.meta, hand_scaled.meta, sep="\n")

{'driver': 'GTiff', 'dtype': 'float32', 'nodata': 0.0, 'width': 10013, 'height': 9231, 'count': 1, 'crs': CRS.from_epsg(32632), 'transform': Affine(42.74890388953347, 0.0, 9262106.148155086,
       0.0, -42.74890388953347, 8051513.363165323)}
{'driver': 'GTiff', 'dtype': 'float32', 'nodata': 255.0, 'width': 9939, 'height': 9249, 'count': 1, 'crs': CRS.from_epsg(32632), 'transform': Affine(42.75000535439911, 0.0, 9264867.293830955,
       0.0, -42.74890522777441, 8050302.032340328)}
{'driver': 'GTiff', 'dtype': 'float32', 'nodata': -9999.0, 'width': 9946, 'height': 9207, 'count': 1, 'crs': CRS.from_epsg(32632), 'transform': Affine(42.75226525110498, 0.0, 9264717.352839503,
       0.0, -42.75076085071309, 8050360.973545546)}


In [29]:
import psutil

process = psutil.Process()
print(process.memory_info().rss / 10**6)

3993.51808


In [30]:
class DatasetGenerator:

    # OK
    def __init__(self, collection_id):
        """Initiates an empty list. To be fed with different images data"""
        self.images = []
        self.ref_flag = False
        self.clipped_addresses = []
        self.row = 0
        self.col = 0
        self.collection_id = collection_id

    # OK
    # THE WAT THAT NO DATA ADDED?
    def add(self, image, name: str, set_nodata: int = 0):
        """Creates a dictionary of infromation about the given image.
        Then adds that dict to the images_data list.
        """

        # Correcting Nodata
        if image.nodata == None:
            image.nodata = set_nodata

        image_dict = dict()
        image_dict["name"] = name
        image_dict["image"] = image
        image_dict["band"] = image.read(1)

        self.images.append(image_dict)

    # OK
    def set_ref_image(self, name: str = "vv"):
        """Sets the image file which all dataset should be cliped
        with respect to that.
        Input:str -> name of the image
        Output:Bool -> True if done, False name not in the list
        """
        # Checks if the name is in added images list
        check_availibility = [True for image in self.images if image["name"] == name]
        if not check_availibility:
            return False

        for image in self.images:
            if image["name"] == name:
                self.ref_name = name
                self.ref_image = image["image"]
                self.ref_crs = image["image"].crs
                self.ref_res = image["image"].res
                self.ref_shape = image["image"].shape

        self.refrence_flag = True
        return True

    # OUTPUT IS NOT IMAGE, IS A LIST
    # NEED CHANGE?
    # Self.band needs to change
    def _create_clipped_image(self, image, band, row, col, height, width, name):

        ### need change self.band
        # band data array

        clipped_band = band[
            row : row + height,
            col : col + width,
        ]
        clipped_band = np.array(clipped_band)

        # Positioning
        tcol, trow = image.transform * (col, row)
        new_transform = affine.Affine(
            image.transform[0],
            image.transform[1],
            tcol,
            image.transform[3],
            image.transform[4],
            trow,
        )

        # creating clipped_image
        return_image = [
            clipped_band,
            image.crs,
            new_transform,
            clipped_band.shape[0],
            clipped_band.shape[1],
            image.dtypes[0],
            image.nodata,
            name,
        ]
        return return_image

    def _check_complete(self, images, height, width):
        """
        If any image in the same coordination has
        nodata value returns False, otherwise True.
        """
        # for image in images:
        #     with image.read(1) as band:
        #         if sum(sum(band == image.nodata)):
        #             return False
        # return True
        print("images sent to complete_check method:")
        for img in images:
            print(img[7])

        for image in images:

            print(image[7], ":")
            print("type(image[0])", type(image[0]))
            print("type(image[6])", type(image[6]))
            print("type sum", type(sum(image[0] == image[6])))

            if image[0].shape != (height, width):
                print("shape was not complete -> shape:", image[0].shape)
                return False

            # if type(sum(image[0]==image[6])) == int:
            #     return False
            if sum(sum(image[0] == image[6])):
                print("none found")
                return False
        return True

    def _save_image(self, save_path_format, image, col, row, mask_coverage):
        """
        Saves Clipped image into file
        """
        name = image[7]

        file_name = save_path_format.format(
            name=name, col=col, row=row, mask_coverage=mask_coverage
        )
        file_name = Path(file_name)

        os.makedirs(os.path.split(file_name)[0], exist_ok=True)

        print("SAVE CALLED", file_name)

        with rio.open(
            file_name,
            "w",
            driver="GTiff",
            height=image[3],
            width=image[4],
            count=1,
            dtype=image[5],
            crs=image[1],
            transform=image[2],
        ) as dst:
            dst.write(image[0], 1)
        self.clipped_addresses.append(file_name)
        return file_name

    def _xy_from_row_col(self, image, row, col):
        """Returns coordinate of a pixel in one image from it's (row,col)"""
        x, y = image.xy(row, col)
        return x, y

    def _row_col_from_xy(self, image, x, y):
        """Returns (row,col) position of a pixel from it's coordinate"""
        row, col = image.index(x, y)
        return row, col

    def run(self, height: int = 256, width: int = 256, only_complete: bool = True):

        save_path_format = "../data/dataset/scene4/x{row}_y{col}_{mask_coverage}/{name}.tif"

        row = self.row
        while row < self.ref_shape[0]:
            col = self.col
            while col < self.ref_shape[1]:

                mask_coverage = 0
                clipped_images = []
                for img in self.images:

                    name = img["name"]
                    image = img["image"]
                    band = img["band"]

                    # Coverting row,col of refrence image to row,col of the current image
                    x, y = self._xy_from_row_col(self.ref_image, row=row, col=col)
                    trow, tcol = self._row_col_from_xy(image=image, x=x, y=y)

                    # Creating cropped image
                    clipped_image = self._create_clipped_image(
                        image, band, trow, tcol, height, width, name
                    )

                    # Append to list of images on same location
                    clipped_images.append(clipped_image)

                    # Calculate mask_coverage for the area of interest
                    if name == "mask":
                        mask_coverage = self._get_mask_coverage(clipped_image)
                        print("MASK COVERAGE:", mask_coverage)

                ### _check_complete
                print("ROW:", row, "COL:", col)

                if mask_coverage == -1:
                    complete_check = False
                else:
                    complete_check = self._check_complete(clipped_images, height, width)

                print("complete check:", complete_check)
                ### save images
                if complete_check:
                    # Calculate water coverage in the image
                    # Save images into path
                    for image in clipped_images:
                        self._save_image(
                            save_path_format, image, col, row, mask_coverage
                        )
                ### add to STAC
                ### append saved images into a list?
                # ->## Done inside _save_image method

                # Update column position
                col = col + width

            # Update row position
            row = row + height

        print("Tiles saved successfully")
        return self.clipped_addresses

    def set_row_col_for_generator(self, row, col):
        """This method is used when a half bulit collection is loaded.
        To continue generating clips from the given row, col
        """
        self.row = row
        self.col = col

    def _create_collection(self):

        self.collection = pystac.Collection(
            id="",
            description="",
            extent="",
            title="",
            href="",
            extra_fields={},
            catalog_type="",
            license="",
        )

    # -------------- later ------------
    def _get_mask_coverage(self, mask):
        """Returns the mask coverage in the scene
        0 is considered as non covered area in input file
        1 is considered as covered area in input file
        Returns integer in range (0, 100)
        """

        data = mask[0]
        image_count = data.size
        mask_count = (data == 1).sum()

        # print(mask)
        # print(data)

        if image_count == 0:
            return -1

        return int((mask_count / image_count) * 100)

    def _create_asset(self, file_href):
        return pystac.Asset(
            href=file_href,
            media_type=pystac.MediaType.GEOTIFF,
        )

    def _create_item(
        self,
        id,
        geometry,
        bbox,
    ):

        item = pystac.Item()
        item.id = id
        item.geometry = geometry
        item.bbox = bbox

    pystac.Item(
        id="",  # ?
        geometry="",  #
        bbox="",  #
        datetime="",  #
        start_datetime="",  #
        end_datetime="",  #
        href="",  #
        collection="",  #
        # for 1 or 2:
        # "water_coverage": in percent%
        # "minimum MASK" value
        properties={},  # 1 -> 1 or 2
        extra_fields={},  # 2 -> 1 or 2
        assets={},  #
    )

    def _add_to_stac(self):
        pass

In [None]:
ds_generator = DatasetGenerator(collection_id="randomId")
ds_generator.add(vv_scaled, name="vv", set_nodata=0.0)
ds_generator.add(mask_scaled, name="mask")
ds_generator.add(hand_scaled, name="HAND")

ds_generator.set_ref_image("HAND")
paths = ds_generator.run(height=256, width=256)

In [32]:
# Memory usage
print(process.memory_info().rss / 10**6)

5358.592


In [33]:
!ls ../data/

dataset		     input_scene_1  input_scene_4
dataset_old_nigeria  input_scene_2  trials_output
input_old_nigeria    input_scene_3  trials_output.zip
