# SHP to XML & GTIFF Clipper

In [None]:
# Install if needed
!pip install shapely
!pip install elementpath
!pip install pandas
!conda install -c conda-forge gdal rasterio fiona -y

In [1]:
import pandas as pd
import rasterio
import rasterio.mask
import fiona
import xml.etree.ElementTree as ET
from pathlib import Path
import numpy as np

In [2]:
class SHP2XML(object):
    """
    Tiles an orthomosaic and writes one Pascal VOC style XML annotation per tile.

    The bounding boxes come from the polygons of a shapefile: every polygon is
    converted to pixel coordinates of the mosaic and reduced to its bounding
    box. ``clipper`` then cuts the mosaic into overlapping tiles and, for each
    tile that fully contains at least one box, writes the tile image and an
    XML file describing the boxes inside it.

    The instance keeps the raster and the shapefile open; call ``close`` (or
    use the instance as a context manager) when you are done with it.
    """

    # Supported output drivers and the file extension used for each of them.
    _EXTENSIONS = {"JPEG": "jpg", "PNG": "png", "GTiff": "tif"}
    # Column order of the bounding-box table built by ``annotations``.
    _BBOX_COLUMNS = ["xmin", "xmax", "ymin", "ymax", "class"]

    def __init__(
        self,
        mosaic_name=None,
        shape_name=None,
        class_field_name: str = "class",
        output_folder="annotations",
        image_format="JPEG",
    ):
        """
        Constructor of SHP2XML
        Args:
            mosaic_name: orthomosaic path
            shape_name: shape path
            class_field_name: column class name that contains the classes
            output_folder: str() folder name, created next to the orthomosaic
            image_format: GTiff, JPEG, PNG
        Raises:
            ValueError: if image_format is not one of the supported drivers or
                a feature of the shapefile lacks the class field
        """
        # Validate the format first so that a bad value does not leave an
        # output folder and two open datasets behind.
        if image_format not in self._EXTENSIONS:
            raise ValueError(
                "File extension format unknown image_format='{}', please try with JPEG, PNG, GTiff".format(
                    image_format
                )
            )

        self.folder_path = Path(mosaic_name)
        self.shape_path = Path(shape_name)
        self.class_field_name = class_field_name
        self.output_folder_name = output_folder

        # check path and directories
        self.check_directories()

        self.open_datasets()

        self.file_extension = self._EXTENSIONS[image_format]
        self.raster_config = {"driver": image_format}
        if image_format == "GTiff":
            # Only GeoTIFF tiles keep their georeference and get compressed.
            self.raster_config["crs"] = self.dataset.crs
            self.raster_config["compress"] = "lzw"

        # Create annotations based on the shapefile
        self.annotations()

    def __enter__(self):
        """
        Allows ``with SHP2XML(...) as xml:`` so the datasets get closed
        """
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Closes the datasets when leaving the ``with`` block
        """
        self.close()

    def close(self):
        """
        Closes the raster and the shapefile if they were opened
        """
        for handle in (getattr(self, "dataset", None), getattr(self, "shapes", None)):
            if handle is not None:
                handle.close()

    def open_datasets(self):
        """
        Opens the datasets using rasterio and fiona
        """
        self.dataset = rasterio.open(self.folder_path)
        self.bands = self.dataset.count
        self.shapes = fiona.open(self.shape_path)

    def check_directories(self, enable_abosulute=False):
        """
        Checks folders and paths and creates the output folder
        Args:
            enable_abosulute: also turns the orthomosaic path into an absolute
                path (the shapefile path is always made absolute)
        """
        if enable_abosulute and not self.folder_path.is_absolute():
            self.folder_path = self.folder_path.resolve()
        if not self.shape_path.is_absolute():
            self.shape_path = self.shape_path.resolve()

        mosaic_folder = self.folder_path.parent
        self.file_name = self.folder_path.name
        self.folder_name = mosaic_folder.name

        # The output lives next to the mosaic. Using the full parent path
        # (instead of only its last component) also works for nested paths
        # and for a mosaic located in the current directory.
        self.output_folder = mosaic_folder / self.output_folder_name
        self.output_folder.mkdir(parents=True, exist_ok=True)

    def create_tree(self):
        """
        Creates root node of the XML and keeps a handle on every child that
        is filled in later by file_information, add_tags and write_xml
        """
        annotation = ET.Element("annotation")
        self.folder = ET.SubElement(annotation, "folder")
        self.filename = ET.SubElement(annotation, "filename")
        self.path = ET.SubElement(annotation, "path")
        source = ET.SubElement(annotation, "source")
        self.database = ET.SubElement(source, "database")
        size = ET.SubElement(annotation, "size")
        self.width = ET.SubElement(size, "width")
        self.height = ET.SubElement(size, "height")
        self.depth = ET.SubElement(size, "depth")
        self.segmented = ET.SubElement(annotation, "segmented")
        self.annotation = annotation

    def add_tags(self, tags, origin=(0, 0)):
        """
        Add tags and classes to an XML tree
        Args:
            tags: numpy array (or any iterable) of dicts that include the bbox
                and class name of each tag
            origin: (column, row) of the tile's upper-left pixel inside the
                mosaic; it is subtracted from the bbox so the coordinates are
                relative to the tile. The default keeps mosaic coordinates.
        """
        records = tags.tolist() if hasattr(tags, "tolist") else list(tags)
        col_off, row_off = int(origin[0]), int(origin[1])
        shifts = {"xmin": col_off, "xmax": col_off, "ymin": row_off, "ymax": row_off}

        for tag_bbox in records:
            xml_object = ET.Element("object")
            name = ET.SubElement(xml_object, "name")
            pose = ET.SubElement(xml_object, "pose")
            truncated = ET.SubElement(xml_object, "truncated")
            difficult = ET.SubElement(xml_object, "difficult")
            bndbox = ET.SubElement(xml_object, "bndbox")
            for key, value in tag_bbox.items():
                if key == "class":
                    # The class name goes in <name>, not inside <bndbox>
                    name.text = str(value)
                    continue
                if key in shifts:
                    value = value - shifts[key]
                ET.SubElement(bndbox, str(key)).text = str(value)
            pose.text = "Unspecified"
            truncated.text = "0"
            difficult.text = "0"
            self.annotation.append(xml_object)

    def file_information(self, width, height, output_file_name):
        """
        Adds the file information of each .xml file
        Args:
            width: len(x) of the tile in pixels
            height: len(y) of the tile in pixels
            output_file_name: path of the tile image, stored in <path>
        """
        self.folder.text = self.output_folder_name
        self.path.text = output_file_name
        self.database.text = "Unknown"
        self.segmented.text = "0"
        ## add photo size; the depth is the number of bands of the mosaic
        self.width.text = str(width)
        self.height.text = str(height)
        self.depth.text = str(self.bands)

    def write_xml(self, sequence):
        """
        Writes the XML file using a sequence number
        Args:
            sequence: string or int used to write the file name
        """
        # Same base name as the tile image: mosaic name up to the first dot
        base_name = self.file_name.split(".")[0]
        # <filename> must point to the tile that is actually written, so it
        # uses the extension of the selected image format.
        extension = getattr(self, "file_extension", "tif")
        self.filename.text = "{}_{}.{}".format(base_name, sequence, extension)
        # Write xml file
        xml_name = "{}_{}.xml".format(base_name, sequence)
        ET.ElementTree(self.annotation).write(str(self.output_folder / xml_name))

    def annotations(self):
        """
        Captures the annotations from a shapefile and transforms the geo coordinates in xy coordinates
        and finally generates a list with the bbox of each polygon inside the shapefile.

        Fills ``self.tags`` (numpy array of dicts) and ``self.xml_bbox``
        (DataFrame with the same rows). Features are assumed to be Polygons;
        only the exterior ring is used.
        Raises:
            ValueError: if a feature does not have the class field
        """
        self.tags = []
        bwd = ~self.dataset.transform  # this transforms coordinates to rows and cols
        for feature in self.shapes:
            try:
                class_value = feature["properties"][self.class_field_name]
            except KeyError as err:
                raise ValueError(
                    "Field class name not found in shapefile. Please check your shapefile CLASS column names"
                ) from err

            # Exterior ring without its closing vertex (a repeat of the first)
            ring = feature["geometry"]["coordinates"][0][:-1]
            # Only x and y are used so that 3D shapefiles also work
            pixels = [bwd * (vertex[0], vertex[1]) for vertex in ring]
            if not pixels:
                # A degenerate ring has no extent and cannot be annotated
                continue
            cols = [pixel[0] for pixel in pixels]
            rows = [pixel[1] for pixel in pixels]
            self.tags.append(
                {
                    "xmin": int(min(cols)),
                    "xmax": int(max(cols)),
                    "ymin": int(min(rows)),
                    "ymax": int(max(rows)),
                    "class": class_value,
                }
            )
        # Explicit columns keep the table usable when the shapefile is empty
        self.xml_bbox = pd.DataFrame(self.tags, columns=self._BBOX_COLUMNS)
        self.tags = np.array(self.tags)

    @staticmethod
    def _tile_spans(length, step, offset):
        """
        Splits one raster axis in overlapping (start, stop) pixel spans
        Args:
            length: size of the axis in pixels
            step: nominal tile size in pixels
            offset: overlap added on each side that has a neighbour tile
        Returns:
            list of (start, stop) tuples; the last tile absorbs the remaining
            pixels and no span leaves the [0, length] range
        """
        edges = list(range(0, length, step))
        if not edges:
            return []
        if len(edges) == 1:
            # The axis is not longer than one step: a single tile covers it
            edges.append(length)
        else:
            # Give the remaining pixels to the last tile
            edges[-1] = length

        last = len(edges) - 2
        spans = []
        for index, (start, stop) in enumerate(zip(edges, edges[1:])):
            low = start - offset if index > 0 else start
            high = stop + offset if index < last else stop
            spans.append((low, high))
        return spans

    @staticmethod
    def _boxes_inside(bboxes, xmin, ymin, xmax, ymax):
        """
        Returns the rows of bboxes that lie strictly inside the given window
        """
        inside = (
            (bboxes["xmin"] > xmin)
            & (bboxes["xmax"] < xmax)
            & (bboxes["ymin"] > ymin)
            & (bboxes["ymax"] < ymax)
        )
        return bboxes[inside]

    def clipper(self, base_divider=1800):
        """
        Clips each GTiff using a scale of ~1.8K pixels with an overlap of 5% in X and Y to match a square of ~2K pixels.
        Only tiles that fully contain at least one bounding box are written,
        each one together with its XML file; both share a sequence number.
        Args:
            base_divider: number of pixels to clip in X and Y directions, modify this if needed
        Raises:
            ValueError: if base_divider is smaller than 1
        """
        if base_divider < 1:
            raise ValueError("base_divider must be a positive number of pixels")

        fwd = self.dataset.transform  # this transform rows and cols numbers to coordinates
        bwd = ~fwd  # this transforms coordinates to rows and cols

        offset = int(base_divider * 0.05)  # this should be around 5% to match 2k
        col_spans = self._tile_spans(self.dataset.width, base_divider, offset)
        row_spans = self._tile_spans(self.dataset.height, base_divider, offset)

        base_name = self.file_name.split(".")[0]
        sequence = 0
        for xmin, xmax in col_spans:
            for ymin, ymax in row_spans:
                window = self._boxes_inside(self.xml_bbox, xmin, ymin, xmax, ymax)
                if window.empty:
                    continue

                # Closed ring with the corners of the tile in map coordinates
                ring = [
                    fwd * (xmin, ymin),
                    fwd * (xmax, ymin),
                    fwd * (xmax, ymax),
                    fwd * (xmin, ymax),
                    fwd * (xmin, ymin),
                ]
                # this will clip the raster based on the polygon
                tile, tile_transform = rasterio.mask.mask(
                    self.dataset,
                    [{"type": "Polygon", "coordinates": [ring]}],
                    crop=True,
                    all_touched=True,
                )

                # Read the upper-left pixel and the size from what was really
                # cropped, so the XML always agrees with the written image.
                col_off, row_off = bwd * (tile_transform * (0, 0))
                origin = (int(round(col_off)), int(round(row_off)))
                tile_height, tile_width = tile.shape[1], tile.shape[2]

                output_file_name = str(
                    self.output_folder
                    / "{}_{}.{}".format(base_name, sequence, self.file_extension)
                )

                self.create_tree()
                self.add_tags(self.tags[window.index.to_numpy()], origin=origin)
                self.file_information(tile_width, tile_height, output_file_name)
                self.write_xml(sequence)

                profile = dict(self.raster_config)
                if self.file_extension == "tif":
                    profile["transform"] = tile_transform

                # The context manager closes the file even if the write fails
                with rasterio.open(
                    output_file_name,
                    "w",
                    height=tile_height,
                    width=tile_width,
                    count=self.bands,
                    dtype=tile.dtype,
                    **profile
                ) as new_dataset:
                    new_dataset.write(tile)
                sequence += 1


In [3]:
# Build the annotations for the RHBV mosaic, then write the JPEG tiles and
# their XML files using the default tile size.
xml = SHP2XML(
    mosaic_name="RHBV/RGB.tif",
    shape_name="RHBV/PLOTS.shp",
    class_field_name="class",
    image_format="JPEG",
)
xml.clipper()

  s = get_writer_for_driver(driver)(path, mode, driver=driver,
