In [1]:
import json
import os

import geopandas as gpd
import numpy as np
import pandas as pd
import rasterio
import requests
from shapely.geometry import box

# Build GeoTiff Value Attribute Table (VAT) based on class definition files

Class definitions are stored as JSON in `.ecs` files on GitHub [here](https://github.com/SeaBee-no/annotation/tree/main/class_definitions). GeoTiffs produced by NR's machine learning workflow use integer class codes that are mapped to colours in the `.ecs` files.

In order for the maps to display properly in ArcGIS, we can build Value Attribute Tables (VATs) from the ECS files.

## 1. Define functions

In [2]:
def get_class_codes(url):
    """Build a dataframe of class labels from an ArcGIS Pro class definition file (.ecs) hosted
    on GitHub. Assumes a hierarchical class definition file originally created from Excel using
    'class_definition_from_df'. Returns a dataframe with the codes, names and colours for each
    level in the .ecs file.

    Args
        url: Str. Raw URL of class definition file (.ecs) on GitHub created using
            'class_definition_from_df'

    Returns
        Dataframe with columns 'code', 'name', 'desc' and 'colour', taken from the 'alias',
        'name', 'description' and 'color' keys of each class definition. There is one row per
        class at every level of the hierarchy, with each parent listed before its subclasses.

    Raises
        ValueError if URL does not end in '.ecs'.
        Exception if the server responds with a status code other than 200.
        requests.exceptions.Timeout if the server does not respond within 30 seconds.
    """
    if not url.endswith(".ecs"):
        raise ValueError("'url' must be a '.ecs' file.")

    # Always pass a timeout: without one, 'requests' waits indefinitely on a stalled
    # connection and the notebook cell never finishes
    response = requests.get(url, timeout=30)
    if response.status_code != 200:
        raise Exception(f"Failed to read URL: {response.status_code}")
    data = json.loads(response.text)

    # Column-wise accumulator. Built as a dict of lists so that the dataframe keeps all four
    # columns even if the file contains no classes
    class_dict = {
        "code": [],
        "name": [],
        "desc": [],
        "colour": [],
    }

    def process_subclasses(subclasses):
        """Append every class in 'subclasses' to 'class_dict', then descend into its children."""
        for subclass in subclasses:
            class_dict["code"].append(subclass["alias"])
            class_dict["name"].append(subclass["name"])
            class_dict["desc"].append(subclass["description"])
            class_dict["colour"].append(subclass["color"])
            if "subclasses" in subclass:
                process_subclasses(subclass["subclasses"])

    # Recursively process levels in the ECS file
    process_subclasses(data["classDefs"])

    df = pd.DataFrame(class_dict)

    return df


def hex_to_rgb(hex_color):
    """Convert a hex-encoded colour string to RGB.

    Args
        hex_color: Str. Colour as six hexadecimal digits, optionally prefixed with '#'
            (e.g. '#ff8000')

    Returns
        Series of three integers: the red, green and blue components, in that order.
    """
    digits = hex_color.lstrip("#")

    # Each channel occupies two consecutive hex digits: RR, GG, BB
    channels = []
    for start in range(0, 6, 2):
        channels.append(int(digits[start : start + 2], 16))

    return pd.Series(channels)


def build_vat_from_ecs(img_path, ecs_version, level):
    """Build a Value Attribute Table (VAT) for an integer raster based on a SeaBee annotation
    file (.ecs). If the VAT is placed in the same folder as the parent image, the raster should
    appear with the correct styling/symbology in ArcMap.

    Args
        img_path: Str. Path to image file
        ecs_version: Str. Version of the .ecs annotation file to use (e.g. '0-1'). Files are
            hosted here: https://github.com/SeaBee-no/annotation/tree/main/class_definitions
        level: Int. Level in class hierarchy to use (1, 2 or 3)

    Returns
        None. A new file named '<image file name>.vat.dbf' (e.g. 'xxx.tif.vat.dbf') is created
        in the image folder. ArcMap will consider these two components of a single file and
        should display the GeoTiff correctly as long as they are kept together. Only classes
        that actually occur in band 1 of the raster are written to the table.

    Raises
        ValueError if 'level' is not one of [1, 2, 3].
    """
    if level not in [1, 2, 3]:
        raise ValueError("'level' must be 1, 2 or 3.")

    # Column names used in ECS files
    code_column = "code"
    name_column = "name"
    hex_column = "colour"

    # Get class definitions
    url = f"https://raw.githubusercontent.com/SeaBee-no/annotation/main/class_definitions/seabee_class_definitions_v{ecs_version}.ecs"
    class_df = get_class_codes(url)

    # Filter to relevant codes for this level (codes have two characters per level). Take an
    # explicit copy so the type conversion below modifies an independent dataframe instead
    # of a slice of the original (avoids pandas' SettingWithCopyWarning)
    class_df = class_df[class_df[code_column].str.len() == 2 * level].copy()
    class_df[code_column] = class_df[code_column].astype(int)

    # Read geotiff (first band only)
    with rasterio.open(img_path) as src:
        raster = src.read(1)

    # Create empty df for VAT
    vat_df = pd.DataFrame(
        {
            "Value": class_df[code_column],
            "Count": [0] * len(class_df),
            "Class": class_df[name_column],
            "Colour": class_df[hex_column],
        }
    )

    # Count cells for each unique value in the raster. Raster values without a matching
    # class code at this level are ignored
    unique, counts = np.unique(raster, return_counts=True)
    for u, c in zip(unique, counts):
        vat_df.loc[vat_df["Value"] == u, "Count"] = c

    # Convert hex codes to RGB
    vat_df[["Red", "Green", "Blue"]] = vat_df["Colour"].apply(hex_to_rgb)
    vat_df = vat_df.drop(columns=["Colour"])
    vat_df = vat_df.query("Count > 0")

    # Convert to gdf so we can save as dbf. The geometry is a dummy placeholder; only the
    # attribute table is kept
    gdf = gpd.GeoDataFrame(vat_df, geometry=[box(0, 0, 1, 1)] * len(vat_df))

    # Force dbf cols to be written as ints
    schema = gpd.io.file.infer_schema(gdf)
    schema["properties"]["Value"] = "int:10"
    schema["properties"]["Count"] = "int:10"
    schema["properties"]["Red"] = "int:10"
    schema["properties"]["Green"] = "int:10"
    schema["properties"]["Blue"] = "int:10"

    # Save as '<image file name>.vat.dbf'. The name is derived from the full image file name
    # so that it matches the raster whatever its extension ('.tif', '.tiff', '.TIF'), and
    # 'os.path.join' keeps the output next to the image even when 'img_path' has no folder
    # component (an empty folder would otherwise resolve to the filesystem root)
    img_fold = os.path.dirname(img_path)
    vat_stem = os.path.join(img_fold, f"{os.path.basename(img_path)}.vat")
    gdf.to_file(f"{vat_stem}.dbf", schema=schema)

    # Remove unnecessary components written alongside the dbf. Not every component is
    # guaranteed to exist, so only delete the ones that were actually created
    for extension in ("cpg", "shp", "shx"):
        temp_file = f"{vat_stem}.{extension}"
        if os.path.exists(temp_file):
            os.remove(temp_file)

## 2. Process data

In [3]:
# Specify version of .ecs file used for annotation. This string is passed to
# 'build_vat_from_ecs', which inserts it into the class definition file name
# ('seabee_class_definitions_v{ecs_version}.ecs') when building the download URL
ecs_version = "0-1"

In [4]:
# Build one VAT per test image at every level of the class hierarchy. Images are
# expected in one sub-folder per level ('lev1', 'lev2', 'lev3')
levels = [1, 2, 3]
image_numbers = [1, 2]
for level in levels:
    for img_no in image_numbers:
        img_path = f"/home/notebook/kelpmap_test/Spectrofly_MSI/lev{level}/image_{img_no}.tif"
        build_vat_from_ecs(img_path, ecs_version, level)