In [1]:
import math
from pathlib import Path

import geopandas as gpd

from obia.handlers.geotif import open_geotiff, open_binary_geotiff_as_mask
from obia.segmentation.segment import segment
from obia.classification.classify import classify

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Input locations, relative to the notebook's working directory.
image_path = "data/image.tif"
mask_path = "data/mask.tif"
training_segments = "data/training_segments.gpkg"
class_regions = "data/class_regions.gpkg"

# Raster inputs: the multi-band image and the binary mask used to restrict segmentation.
image = open_geotiff(image_path)
mask = open_binary_geotiff_as_mask(mask_path)

# Vector inputs: labelled training segments and the regions constraining allowed classes.
training, regions = (gpd.read_file(path) for path in (training_segments, class_regions))

In [3]:
# Segmentation is expensive, so its result is cached on disk: re-use the existing
# GeoPackage when present, otherwise run SLIC on the masked image and write it out.
# This keeps Restart & Run All working on a fresh checkout with no cached file.
SEGMENTS_PATH = Path("data/segments.gpkg")
PIXEL_AREA = 0.5 ** 2            # area of one pixel — assumes 0.5 (m) resolution; TODO confirm from raster metadata
CROWN_AREA = math.pi * (5 ** 2)  # area of one circular crown of radius 5 (same units as PIXEL_AREA)

if not SEGMENTS_PATH.exists():
    # Target segment count = masked area / expected crown area (at least 1 segment).
    tree_area = mask.sum() * PIXEL_AREA
    n_crowns = max(1, int(round(tree_area / CROWN_AREA)))
    print(n_crowns)

    segmented_image = segment(
        image, segmentation_bands=[4, 5, 2],
        method="slic", n_segments=n_crowns, convert2lab=False, slic_zero=True, mask=mask,
        calc_mean=True, calc_variance=True, calc_contrast=True, calc_correlation=True,
        calc_skewness=False, calc_kurtosis=False, calc_dissimilarity=False, calc_homogeneity=False,
        calc_ASM=False, calc_energy=False,
    )
    segmented_image.write_segments(str(SEGMENTS_PATH))

segments = gpd.read_file(SEGMENTS_PATH)

In [4]:
def preprocess_acceptable_classes(acceptable_classes_gdf, column="acceptable_classes"):
    """Return a copy of ``acceptable_classes_gdf`` with ``column`` parsed into lists of ints.

    Each cell may be:

    * a comma-separated string such as ``"1,2"`` — surrounding whitespace and
      empty tokens (e.g. from a trailing comma) are ignored;
    * a single integer, or a float holding an integral value (e.g. from a
      numeric column that also contains nulls);
    * any iterable of values accepted by ``int``.

    Parameters
    ----------
    acceptable_classes_gdf : DataFrame or GeoDataFrame
        Table containing ``column``. It is not modified.
    column : str, default "acceptable_classes"
        Name of the column to normalise.

    Returns
    -------
    DataFrame or GeoDataFrame
        A copy of the input whose ``column`` holds ``list[int]`` values.

    Raises
    ------
    KeyError
        If ``column`` is not present.
    ValueError
        If a cell is missing (``None``/NaN), is a non-integral number, or
        contains a token that is not an integer.
    """
    import numbers

    def _to_class_list(value):
        if isinstance(value, str):
            return [int(token) for token in value.split(",") if token.strip()]
        # Checked before Real: numpy integer scalars are registered as Integral.
        if isinstance(value, numbers.Integral):
            return [int(value)]
        if isinstance(value, numbers.Real):
            # NaN and inf are not integral, so missing values land here too.
            if not float(value).is_integer():
                raise ValueError(f"invalid acceptable class value: {value!r}")
            return [int(value)]
        if value is None:
            raise ValueError("missing acceptable class value")
        return [int(cls) for cls in value]

    # Work on a copy so the caller's frame (and any earlier displayed view) is untouched.
    result = acceptable_classes_gdf.copy()
    result[column] = result[column].apply(_to_class_list)
    return result
    
# Normalise the acceptable-classes column of the region layer into lists of ints.
regions = regions.pipe(preprocess_acceptable_classes)

In [5]:
# MLP hyper-parameters, forwarded to classify() as keyword arguments.
MLP_PARAMS = {'hidden_layer_sizes': (100,), 'solver': 'adam', 'max_iter': 10000}

# NOTE(review): no random seed is passed here. If classify() shuffles the 80/20
# train/test split or the MLP initialisation is random, the report below may vary
# between runs — consider passing a seed if classify() supports one.
classified = classify(
    segments,
    training,
    acceptable_classes_gdf=regions,
    method='mlp',
    **MLP_PARAMS,
    test_size=0.2,
    compute_reports=True,
    compute_shap=False,
)

In [6]:
# The report is a preformatted multi-line text table, so print it instead of showing its repr.
report_text = classified.report
print(report_text)

              precision    recall  f1-score   support

           1       0.95      0.95      0.95       201
           2       0.87      0.91      0.89       113
           3       0.88      0.96      0.92        24
           4       0.95      0.95      0.95        98
           5       1.00      1.00      1.00       189
           6       0.36      0.22      0.28        18

    accuracy                           0.94       643
   macro avg       0.84      0.83      0.83       643
weighted avg       0.93      0.94      0.93       643



In [7]:
# Persist the classified segments. The OGR driver is given explicitly so the output is a
# GeoPackage regardless of whether the installed geopandas infers it from the extension.
classified.classified.to_file("data/classified.gpkg", driver="GPKG")