In [None]:
from tree_detection_framework.preprocessing.preprocessing import create_dataloader, visualize_dataloader
from tree_detection_framework.detection.detector import DeepForestDetector

In [2]:
# Inputs for chipping the orthomosaic. NOTE(review): with use_units_meters=True
# the chip size/stride are presumably in meters and OUTPUT_RESOLUTION in
# meters per pixel — confirm against create_dataloader.
ORTHO_FOLDER = '/ofo-share/scratch-amritha/emerald_point_dtree2/dataset/emerald-point-ortho'
CHIP_SIZE = 100
CHIP_STRIDE = 50
OUTPUT_RESOLUTION = 0.2

dataloader = create_dataloader(
    raster_folder_path=ORTHO_FOLDER,
    chip_size=CHIP_SIZE,
    chip_stride=CHIP_STRIDE,
    use_units_meters=True,
    output_resolution=OUTPUT_RESOLUTION,
)

In [None]:
# Size of the dataloader. NOTE(review): for a PyTorch-style DataLoader this is
# the number of batches — confirm the type create_dataloader returns.
len(dataloader)

In [None]:
# Count passed to visualize_dataloader — presumably how many chips to plot; TODO confirm.
N_PREVIEW = 5
visualize_dataloader(dataloader, N_PREVIEW)

In [None]:
# Instantiate the detector with its default constructor arguments.
model = DeepForestDetector()

In [None]:
# Run the detector over the whole dataloader. Later cells index this result
# (predictions[0]) and zip it against the dataloader, so it is expected to hold
# one entry per dataloader item, in the same order — NOTE(review): confirm.
predictions = model.predict(dataloader)

In [None]:
# Inspect the first prediction; the conversion cell below reads its
# xmin/ymin/xmax/ymax columns.
predictions[0]

In [None]:
import shapely
import pandas as pd
from tree_detection_framework.detection.region_detections import RegionDetections, RegionDetectionsSet

def get_pixel_bounds_box_from_sample(sample):
    """Return one pixel-space bounding box per item in a batched sample.

    The box spans (0, 0) to (width, height), where height and width are the
    last two axes of ``sample["image"]``. One box is produced per entry of
    the first axis (presumably the batch axis).

    Args:
        sample: Mapping with an ``"image"`` entry exposing ``.shape``.

    Returns:
        list: ``sample["image"].shape[0]`` references to the same shapely box.
    """
    image = sample["image"]
    spatial_dims = image.shape[-2:]
    # Array shape is (rows, cols) = (i, j): x extent is the column count,
    # y extent is the row count.
    pixel_box = shapely.box(0, 0, spatial_dims[1], spatial_dims[0])
    n_items = image.shape[0]
    return [pixel_box for _ in range(n_items)]

def get_geospatial_bounds_box_from_sample(sample):
    """Return one geospatial bounding box per tile in a batched sample.

    Args:
        sample: Mapping whose ``"bounds"`` entry is an iterable of objects
            exposing ``minx``, ``miny``, ``maxx`` and ``maxy`` (one per tile).

    Returns:
        list: One shapely box per entry of ``sample["bounds"]``, in order.
    """
    boxes = []
    for bounds in sample["bounds"]:
        # Each tile covers a different area, so each gets its own box.
        boxes.append(shapely.box(bounds.minx, bounds.miny, bounds.maxx, bounds.maxy))
    return boxes

def parse_deepforest_output(prediction: pd.DataFrame):
    """Convert a DeepForest prediction table into an array of shapely boxes.

    Args:
        prediction: DataFrame with ``xmin``, ``ymin``, ``xmax`` and ``ymax``
            columns, one row per detection. ``None`` is also accepted and
            treated as "no detections". NOTE(review): DeepForest's
            ``predict_image`` can return ``None`` when nothing is detected;
            confirm whether DeepForestDetector passes that through.

    Returns:
        Array of shapely box geometries, one per row (empty when there are
        no detections).
    """
    if prediction is None:
        prediction = pd.DataFrame(columns=["xmin", "ymin", "xmax", "ymax"])
    # Force float dtype. An empty or object-typed column would otherwise yield
    # an object array, which shapely's vectorized box may reject.
    xmin = prediction["xmin"].to_numpy(dtype=float)
    ymin = prediction["ymin"].to_numpy(dtype=float)
    xmax = prediction["xmax"].to_numpy(dtype=float)
    ymax = prediction["ymax"].to_numpy(dtype=float)
    # shapely.box is vectorized: array inputs produce an array of geometries.
    prediction_geometry = shapely.box(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax)
    return prediction_geometry

# Output location for the merged predictions.
# NOTE(review): absolute path on a shared drive — adjust for your environment.
PREDICTIONS_PATH = "/ofo-share/repos-david/tree-detection-framework/data/predictions.geojson"

# Columns parse_deepforest_output reads; used to build an empty, float-typed
# table for chips that produced no detections.
BOX_COLUMNS = ("xmin", "ymin", "xmax", "ymax")

# zip() below silently stops at the shorter input, so check that there is
# exactly one prediction per dataloader item before pairing them up.
assert len(predictions) == len(dataloader), (
    f"Expected one prediction per dataloader item, got {len(predictions)} "
    f"predictions for {len(dataloader)} items"
)

# Create a list of RegionDetection objects
region_detections = []
for sample, prediction in zip(dataloader, predictions):
    # The one-to-one pairing above, and taking element [0] below, only hold
    # when each batch contains a single image.
    batch_size = sample["image"].shape[0]
    assert batch_size == 1, f"Expected batch size 1, got {batch_size}"

    # NOTE(review): DeepForest may yield None for a chip with no detections.
    # Normalize it to an empty table so the geometry parsing and RegionDetections
    # receive consistent inputs. Confirm that RegionDetections accepts zero rows.
    if prediction is None:
        prediction = pd.DataFrame({col: pd.Series(dtype=float) for col in BOX_COLUMNS})

    # Extract the derived attributes from the sample and prediction.
    # The helpers return one entry per batch element; take the only one.
    image_bounds = get_pixel_bounds_box_from_sample(sample)[0]
    geospatial_bounds = get_geospatial_bounds_box_from_sample(sample)[0]
    prediction_geometry = parse_deepforest_output(prediction)

    # Extract the CRS of the first (only) element in the batch
    CRS = sample["crs"][0]

    # Create the region detection
    region_detection = RegionDetections(
        detection_geometries=prediction_geometry,
        data=prediction,
        pixel_prediction_bounds=image_bounds,
        geospatial_prediction_bounds=geospatial_bounds,
        input_in_pixels=True,
        CRS=CRS,
    )
    region_detections.append(region_detection)

# Combine the per-chip detections into one set and write it to disk
region_detection_set = RegionDetectionsSet(region_detections)
region_detection_set.save(PREDICTIONS_PATH)