In [1]:
import os
import leafmap
from datetime import date
from d2spy.workspace import Workspace
# import env_setting # environment setting of d2s workspace

import matplotlib.pyplot as plt
import time

from plot_boundary_extractor import PlotExtraction

ModuleNotFoundError: No module named 'plot_boundary_extractor'

# Connect to workspace and load data product

In [17]:
# Connect to the D2S server that hosts the project, flight and data-product records.
D2S_URL = "https://ps2.d2s.org"
workspace = Workspace.connect(D2S_URL)

In [None]:
# Change the search term in `.filter_by_title` to match your project
projects = list(workspace.get_projects().filter_by_title("cornell"))
if not projects:
    # Fail fast with an actionable message instead of prompting for an index
    # and then raising an opaque IndexError.
    raise ValueError("No projects matched the title filter; change the search term above.")
for i, proj in enumerate(projects):
    print(i, proj)

index = input("Choose project index: ")
project = projects[int(index)]

0 Project(title='2023 Cornell Wheat', description='USDA WheatCAP Project - Cornell University 2023 Winter Wheat Master Nursery trial at Ithaca, NY - McGowan', start_date=datetime.date(2022, 10, 11), end_date=datetime.date(2023, 6, 29))
1 Project(title='2024 Cornell Wheat', description='USDA WheatCAP Project - Cornell University 2024 Winter Wheat Master Nursery trial at Ithaca, NY - Helfer', start_date=datetime.date(2023, 10, 11), end_date=datetime.date(2024, 7, 8))
2 Project(title='2022 Cornell Wheat', description='USDA WheatCAP Project - Cornell University 2022 WWMASTER2022ACCT3 trial at Ithaca, NY - Helfer', start_date=datetime.date(2021, 10, 21))


In [None]:
# Change the date range in `filter_by_date` to match the acquisition date of the flight in your project
start_date = date(2022, 5, 1)
end_date = date(2022, 6, 30)
flights = project.get_flights().filter_by_date(start_date, end_date)
if not flights:
    # An empty date window is the common failure here; report it clearly
    # instead of prompting for an index that cannot be valid.
    raise ValueError(f"No flights found between {start_date} and {end_date}; adjust the date range above.")
for i, fli in enumerate(flights):
    print(i, fli)

index = input("Select the flight: ")
flight = flights[int(index)]

Flight(acquisition_date='2022-05-11', name=None, altitude=120.0, side_overlap=60.0, forward_overlap=75.0, sensor='RGB', platform='Phantom_4')
Flight(acquisition_date='2022-05-25', name='', altitude=120.0, side_overlap=60.0, forward_overlap=75.0, sensor='Multispectral', platform='Phantom_4')
Flight(acquisition_date='2022-05-11', name='', altitude=120.0, side_overlap=60.0, forward_overlap=75.0, sensor='Multispectral', platform='M300')


In [None]:
# Change the search term in `.filter_by_data_type` to match your COG's data type
data_products = flight.get_data_products().filter_by_data_type("ortho")
if len(data_products) == 1:
    data_product = data_products[0]
elif len(data_products) > 1:
    # Show the index next to each product so the prompt below can be answered.
    for i, dp in enumerate(data_products):
        print(i, dp)
    index = input("Select the data product: ")
    data_product = data_products[int(index)]
else:
    # Every later cell needs `data_product`; stop here rather than fail later
    # with a confusing NameError.
    raise ValueError("No data products found")

DataProduct(data_type='ortho', filepath='/static/projects/b4ab960b-9629-46d7-881c-5612fd5ee0dd/flights/0d73f050-6a66-4922-851c-9b98e2a45dab/data_products/58822a09-c290-41ac-8248-a306fe5299cf/7c782101-90a6-42a5-b39f-c5b70d0cfb4a.tif', original_filename='20220511_cn_mic_dry_mosaic_rgb.tif', is_active=True, public=True, stac_properties={'raster': [{'data_type': 'uint8', 'stats': {'minimum': 0.0, 'maximum': 255.0, 'mean': 112.381, 'stddev': 89.018}}, {'data_type': 'uint8', 'stats': {'minimum': 0.0, 'maximum': 255.0, 'mean': 115.057, 'stddev': 66.764}}, {'data_type': 'uint8', 'stats': {'minimum': 0.0, 'maximum': 255.0, 'mean': 73.811, 'stddev': 57.792}}, {'data_type': 'uint8', 'stats': {'minimum': 0.0, 'maximum': 255.0, 'mean': 93.907, 'stddev': 122.995}}], 'eo': [{'name': 'b1', 'description': 'Red'}, {'name': 'b2', 'description': 'Green'}, {'name': 'b3', 'description': 'Blue'}, {'name': 'b4', 'description': 'Alpha'}]}, status='SUCCESS', url='https://ps2.d2s.org/static/projects/b4ab960b-962

In [24]:
# Read the user's D2S API key and export it (when present) as D2S_API_KEY.
# NOTE(review): presumably consumed by downstream tile/data requests — confirm.
api_key = workspace.api_key
if api_key:
    os.environ["D2S_API_KEY"] = api_key
else:
    print("No API key. Please request one from the D2S profile page and re-run this cell.")

# TiTiler server used to render COG tiles (presumably read by leafmap's add_cog_layer).
os.environ["TITILER_ENDPOINT"] = "https://tt.d2s.org"

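# Visualize the data product

Optionally draw a polygon on this map; the most recently drawn shape is passed as the (optional) `clip_boundary` for plot detection.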

In [25]:
# Preview the ortho COG over a NAIP basemap. A shape drawn on this map is later
# read via `m.draw_control.last_draw` as the optional clip boundary.
# Append the API key only when one is set; otherwise the URL would carry the
# literal string "API_KEY=None".
cog_url = f"{data_product.url}?API_KEY={api_key}" if api_key else data_product.url
m = leafmap.Map()
m.clear_layers()
m.add_basemap("USGS NAIP Imagery")
m.add_cog_layer(cog_url, name="ortho", zoom_to_layer=True)
m

Map(center=[20, 0], controls=(ZoomControl(options=['position', 'zoom_in_text', 'zoom_in_title', 'zoom_out_text…

# Automatic plot detection

In [None]:
# Detection parameters for PlotExtraction. n_rows x n_cols is the expected plot
# grid (22 x 13 = 286; the run below reports 286 initial plots). plot_width and
# plot_height are presumably in map units (metres?) — TODO confirm.
args = dict(
    base_layer=data_product,
    api_key=api_key,
    clipped_filename='./test.tif',
    clip_boundary=m.draw_control.last_draw,  # optional: last shape drawn on the map
    n_rows=22,
    n_cols=13,
    plot_width=3.7,
    plot_height=1.0,
    # resize=(320, 640),  # optional
    points_per_side=64,  # optional
    iou_threshold=0.1,  # optional
    cc_coverage_thr=0,  # optional
    out_filename='./plot_boundary.geojson',
    sam_checkpoint="./sam_vit_h_4b8939.pth",  # path to sam checkpoint
)

# Run the full automatic pipeline in one call.
plot = PlotExtraction(**args)
plot_boundary = plot.automatic_detection(save=True)

# Redraw the map: basemap, ortho image, then the detected plot outlines.
m.clear_layers()
m.add_basemap("USGS NAIP Imagery")
m.add_cog_layer(plot.data_product_url, name="ortho", zoom_to_layer=True)
m.add_geojson(
    plot_boundary,
    layer_name="Plot boundary",
    style={"color": "cyan", "weight": 1, "fill": False},
    zoom_to_layer=True,
)
m

Loaded image: ./test.tif
Loaded SAM automatic maskgenerator: points per side=32, device=cuda
Resized image: (1060, 2121)
Estimated orientation angle: 25.02 degree
Loaded SAM automatic maskgenerator: points per side=64, device=cuda
Initial plots: 286
Loaded SAM predictor
Refined plots: 286
Assigned rows and columns
Process completed


Map(center=[42.44627427344795, -76.43950254217863], controls=(ZoomControl(options=['position', 'zoom_in_text',…

# Manual corrections

## Remove plot

In [27]:
# IDs of plots to delete from the detected set.
# (Named explicitly rather than `id`, which would shadow the Python built-in.)
plot_ids_to_remove = [226]
gdf_final = plot.delete(plot_ids_to_remove)
gdf_geojson = plot.to_geojson(gdf_final, rotation=False)

In [28]:
# Redraw the map with the plot set remaining after deletion.
outline_style = {"color": "cyan", "weight": 1, "fill": False}
m.clear_layers()
m.add_basemap("USGS NAIP Imagery")
m.add_cog_layer(plot.data_product_url, name="ortho", zoom_to_layer=True)
m.add_geojson(
    gdf_geojson,
    layer_name="Manually removed plot boundary",
    style=outline_style,
    zoom_to_layer=True,
)
m

Map(bottom=678.0, center=[np.float64(42.44628668985824), np.float64(-76.43950672526731)], controls=(ZoomContro…

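## Add plot

Draw the missing plot on the map, then run the next cell; the most recently drawn shape is passed to `plot.add`.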

In [43]:
# Add a plot from the shape most recently drawn on the map
# (the captured output shows this step loads the SAM predictor).
drawn_shape = m.draw_control.last_draw
gdf_final = plot.add(drawn_shape)
# Tag the geometries with the working CRS before converting for display.
gdf_final.set_crs(f"EPSG:{plot.epsg}", inplace=True)
gdf_geojson = plot.to_geojson(gdf_final, rotation=False)

Loaded SAM predictor


In [44]:
# Redraw the map with the plot set after the manual addition.
# The layer label previously said "removed" (copy-paste from the removal step).
m.clear_layers()
m.add_basemap("USGS NAIP Imagery")
m.add_cog_layer(plot.data_product_url, name="ortho", zoom_to_layer=True)
m.add_geojson(gdf_geojson, layer_name="Manually edited plot boundary",
              style={"color": "cyan", "weight": 1, "fill": False}, zoom_to_layer=True)
m

Map(bottom=99198775.0, center=[42.44625051817733, -76.43948011100294], controls=(ZoomControl(options=['positio…

# Export to GeoJSON file

In [None]:
# GeoJSON coordinates are expected in WGS84 (EPSG:4326); reproject before writing.
gdf_wgs84 = gdf_final.to_crs('EPSG:4326')
gdf_wgs84.to_file('plot_boundary.geojson', driver='GeoJSON')

# Step-by-step example

In [None]:
# Inputs for the step-by-step run, defined here so this section works on a
# fresh kernel. The values mirror the automatic-detection example above.
in_filename = './test.tif'  # clipped ortho image written by PlotExtraction
n_rows = 22                 # expected plot grid is n_rows x n_cols
n_cols = 13
plot_width = 3.7            # presumably map units (metres?) — TODO confirm
plot_height = 1.0

# define arguments
args = {
        'base_layer': data_product,
        'api_key': api_key,
        'clipped_filename': in_filename,
        # NOTE(review): reads the last shape drawn on the map `m` created earlier
        # in the notebook; draw the clip polygon there before running this cell.
        'clip_boundary': m.draw_control.last_draw, #optional
        'n_rows': n_rows,
        'n_cols': n_cols,
        'plot_width': plot_width,
        'plot_height': plot_height,
        'resize': (1024,1024), # optional
        'points_per_side': 64, # optional
        'iou_threshold': 0.1, # optional
        'cc_coverage_thr': 0, # optional
        'out_filename': 'plot_boundary.geojson',
        # SAM ViT-H weights (manual download); relative path, same as the
        # automatic example, instead of a user-specific absolute path.
        'sam_checkpoint': "./sam_vit_h_4b8939.pth"
        }

plot = PlotExtraction(**args)

# Visualize the base layer

In [None]:
# Fresh map for the step-by-step run: NAIP basemap plus the ortho COG.
# Append the API key only when one is set; otherwise the URL would carry the
# literal string "API_KEY=None".
cog_url = f"{data_product.url}?API_KEY={api_key}" if api_key else data_product.url
m = leafmap.Map()
m.clear_layers()
m.add_basemap("USGS NAIP Imagery")
m.add_cog_layer(cog_url, name="ortho", zoom_to_layer=True)
m

# Load image and rotate if needed

In [None]:
# load image to the plot object
plot.load_image()

# Reuse the checkpoint already passed to PlotExtraction via `args` rather than
# repeating a hard-coded, user-specific absolute path.
sam_checkpoint = args["sam_checkpoint"]

# Cumulative wall-clock time of the pipeline stages in this section.
processing_time = 0

# load sam model (coarse 16-point grid) and get initial masks
start = time.time()
plot.load_sam(sam_checkpoint, points_per_side=16)
masks = plot.get_masks()
processing_time += time.time() - start
print(f"Processing time: {processing_time:.2f} seconds")

# rotate plot if needed
start = time.time()
img_rotated = plot.rotate_plot()
processing_time += time.time() - start
print(f"Processing time: {processing_time:.2f} seconds")

# Optional visualization (requires scikit-image): uncomment to inspect the
# masks overlaid on the image and the rotated image.
# from skimage.color import label2rgb

# plt.figure(figsize=(5, 15))
# plt.imshow(plot.img_array)
# plt.imshow(label2rgb(masks), alpha=0.4)
# plt.xticks([])
# plt.yticks([])
# plt.show()

# plt.imshow(img_rotated)
# plt.xticks([])
# plt.yticks([])
# plt.show()

# Get initial plots

In [None]:
# Second SAM pass with a denser point grid (100 per side) to propose initial plots.
start = time.time()
plot.load_sam(sam_checkpoint, points_per_side=100)
initial_plots = plot.initial_plots()
elapsed = time.time() - start
processing_time += elapsed
print(f"Processing time: {processing_time:.2f} seconds")

# Tag with the working CRS, convert (undoing the rotation) and overlay in red.
plot_crs = f'EPSG:{plot.epsg}'
initial_plots.set_crs(plot_crs, inplace=True)
gdf_geojson = plot.to_geojson(initial_plots, rotation=True)
m.add_geojson(
    gdf_geojson,
    layer_name="Initial plot boundary",
    style={"color": "red", "weight": 1, "fill": False},
    zoom_to_layer=True,
)
m

# Grid filling

In [None]:
# Load SAM in 'manual' mode and fill gaps in the detected plot grid.
start = time.time()
plot.load_sam(plot.sam_checkpoint, type='manual')
gdf_filled = plot.grid_filling()
elapsed = time.time() - start
processing_time += elapsed
print(f"Processing time: {processing_time:.2f} seconds")

# Tag with the working CRS, convert (undoing the rotation) and overlay in yellow.
plot_crs = f'EPSG:{plot.epsg}'
gdf_filled.set_crs(plot_crs, inplace=True)
gdf_geojson = plot.to_geojson(gdf_filled, rotation=True)

m.add_geojson(
    gdf_geojson,
    layer_name="Filled plot boundary",
    style={"color": "yellow", "weight": 1, "fill": False},
    zoom_to_layer=True,
)
m

# Grid removal and row/column assignment

In [None]:
# Prune the filled grid (grid_remove), then assign row/column labels to each plot.
start = time.time()
gdf_removed = plot.grid_remove(gdf_filled)
gdf_final = plot.assign_row_col(gdf_removed)
processing_time += time.time() - start
print(f"Processing time: {processing_time:.2f} seconds")
gdf_final.set_crs(f'EPSG:{plot.epsg}', inplace=True)
# GeoJSON conversion for display is done in the next cell; computing it here
# too was a dead store.

In [None]:
# Overlay the final plots (thicker cyan outline) on the step-by-step map.
refined_style = {"color": "cyan", "weight": 2, "fill": False}
gdf_geojson = plot.to_geojson(gdf_final, rotation=True)
m.add_geojson(gdf_geojson, layer_name="Refined plot boundary",
              style=refined_style, zoom_to_layer=True)
m

# Export to GeoJSON file

In [None]:
# GeoJSON (RFC 7946) coordinates must be WGS84 lon/lat, so reproject from
# EPSG:{plot.epsg} before writing, matching the export in the automatic section.
gdf_final.to_crs('EPSG:4326').to_file('plot_boundary.geojson', driver='GeoJSON')