## Lettuce Detection with Detecto

This notebook is an example of how to use the Detecto Python library to detect lettuce in aerial drone images. Detecto is a PyTorch-based library that simplifies training object detection models and using them to predict objects in imagery.

This notebook accomplishes the following tasks:
1. Connects to the Data to Science (D2S) platform (https://ps2.d2s.org) to access the drone imagery. The dataset is an orthomosaic of a lettuce field near Yuma, Arizona.
2. Downloads (from Hugging Face) an object detection model that has been fine-tuned to detect lettuce. The model is based on the Faster R-CNN architecture and was trained by PhytoOracle, a research group at the University of Arizona, on very high-resolution (millimeter-scale) gantry images from the Maricopa Ag Center. It takes RGB images as input and puts bounding boxes around lettuce plants.
3. Outputs the detected lettuce bounding boxes as a polygon shapefile layer with the same geographic coordinate system as the input drone image.
4. Demonstrates how to fine-tune the model on a small dataset of annotated lettuce images.

In [None]:
#Import libraries for D2S and leafmap

from datetime import date

from d2spy.workspace import Workspace

import os

import leafmap

In [None]:
#Import necessary modules from Detecto python library
from detecto.core import Model  # wrapper around the Faster R-CNN ResNet-50 FPN model
from detecto.utils import read_image
from detecto.visualize import show_labeled_image

In [None]:
#Import more useful libraries

import os
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import geopandas as gpd
from shapely.geometry import box
import rasterio

In [None]:
# Connect to Purdue hosted D2S instance. It will prompt for your D2S password.
workspace = Workspace.connect("https://ps2.d2s.org", "jgillan@arizona.edu")

In [None]:
# Get list of all your projects in D2S
projects = workspace.get_projects()

# Check if there are any projects
if len(projects) > 0:
    # Loop through all projects and print each one
    for project in projects:
        print(f"ID: {project.id}")
        print(f"Title: {project.title}")
        print(f"Description: {project.description}\n")
else:
    print("Please create a project before proceeding with this guide.")


In [None]:
###Choose a project (from the previous print) and list all flights for that project

# Define the project ID you're looking for
#project_id = "46669dd1-3c9a-487e-adaa-92dd50bf0420"
project_id = "cfb77a4a-a065-400e-ae12-2cf11bf7c25b"


# Find the project by ID
selected_project = None
for project in projects:
    if project.id == project_id:
        selected_project = project
        break

# Check if the project was found
if selected_project:
    # Get list of all flights for the selected project
    flights = selected_project.get_flights()

    # Check if there are any flights
    if len(flights) > 0:
        # Loop through all flights and print each one
        for flight in flights:
            print(flight)
    else:
        print("No flights found for this project.")
else:
    print(f"Project with ID '{project_id}' not found.")

In [None]:
# Get list of data products from a flight. Index 0 in this case is the first flight listed.
data_products = flights[0].get_data_products()

# Check if there are any data products
if len(data_products) > 0:
    # Loop through all data products and print their URLs
    for product in data_products:
        print(product.url)
else:
    print("No data products found for this flight.")


In [None]:
## Display Interactive leafmap Map

# Set the TITILER_ENDPOINT environment variable to the D2S hosted Titiler endpoint. Titiler is a cloud optimized GeoTIFF tile server.
os.environ["TITILER_ENDPOINT"] = "https://tt.d2s.org"

m = leafmap.Map()

# URL for a D2S hosted GeoTIFF data product
ortho_url = "https://ps2.d2s.org/static/projects/cfb77a4a-a065-400e-ae12-2cf11bf7c25b/flights/20812f2d-4b31-45b0-9b2f-66595aac16fa/data_products/ed454526-6e85-49f8-a274-fafda388abdd/b1cdf543-0dfd-4da7-b258-edb2fad2f5bd.tif"

# Add a publicly available data product to the map
m.add_cog_layer(ortho_url, name="Orthomosaic")

# If you want to display a private data product, comment out the previous line and uncomment the m.add_cog_layer line below
# Add a private data product to the map
# m.add_cog_layer(f"{ortho_url}?API_KEY={api_key}", name="DSM", colormap_name="rainbow")

# Display the map
m

In [None]:
#Download the image from D2S to the local directory
#!wget https://ps2.d2s.org/static/projects/46669dd1-3c9a-487e-adaa-92dd50bf0420/flights/72e1b4a1-68ea-48bf-85d1-4f6315cd78bd/data_products/fcb09a00-181c-4e3c-ba89-e58c7e7a7223/3ac72e63-64fe-4713-af0b-c332b3851032.tif

#!wget https://ps2.d2s.org/static/projects/cfb77a4a-a065-400e-ae12-2cf11bf7c25b/flights/20812f2d-4b31-45b0-9b2f-66595aac16fa/data_products/ed454526-6e85-49f8-a274-fafda388abdd/b1cdf543-0dfd-4da7-b258-edb2fad2f5bd.tif

#Set the image path as a variable 
image_path = "b1cdf543-0dfd-4da7-b258-edb2fad2f5bd.tif"

# Disable the image size limit of PIL
Image.MAX_IMAGE_PIXELS = None
image = Image.open(image_path)



In [None]:
# Get the size of the image
original_width, original_height = image.size
print(f"Original Image Size: {original_width}x{original_height}")

In [None]:
#Download the fine-tuned model (lettuce detection) from Hugging Face
!wget https://huggingface.co/jgillan/phytooracle_lettuce_2/resolve/main/model_weights.pth


In [None]:
#declare the labels for the fine-tuned model
labels = [
    'lettuce'
]

In [None]:
#Load the fine-tuned model
model = Model.load('model_weights.pth', labels) 

The following code defines a series of functions to:
1. Crop the drone aerial image into smaller tiles. This is necessary because the Faster R-CNN model expects images within a certain size range. By default, the long side of an image is limited to about 1333 pixels, and drone orthomosaics are far larger (e.g., 13569x39850). The default size limit can be raised, but large images can cause memory issues and long processing times. If you run Faster R-CNN on an image larger than the default, the image is automatically downsampled (e.g., to 480x1344), which can cause the model to miss small objects. The better solution is to crop the image into smaller tiles and process each tile separately.

2. Detect lettuce in the cropped tiles. Detecto's `predict` method takes an image and returns three things for each detected object: the class label, the bounding box in the format (x_min, y_min, x_max, y_max), and a confidence score between 0 and 1. The higher the score, the more confident the model is that the object is a lettuce plant. The label is a string naming the type of object detected; in this case it is always "lettuce".

3. Display the predicted bounding boxes on the original large image. 
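
To make the sizes above concrete, the following sketch works through the arithmetic for the example orthomosaic. It assumes the torchvision Faster R-CNN defaults of `min_size=800` and `max_size=1333`, and the helper names `resize_scale` and `tile_grid` are hypothetical; treat the numbers as illustrative rather than exact model behavior.

```python
import math

def resize_scale(width, height, min_size=800, max_size=1333):
    """Scale factor applied when an image is resized to fit the size limits.

    Assumes the short side is scaled toward min_size unless that would push
    the long side past max_size.
    """
    short_side, long_side = min(width, height), max(width, height)
    return min(min_size / short_side, max_size / long_side)

def tile_grid(width, height, tile_size):
    """Number of tile columns and rows needed to cover the image."""
    return math.ceil(width / tile_size), math.ceil(height / tile_size)

ortho_w, ortho_h = 13569, 39850  # example orthomosaic size from the text

ortho_scale = resize_scale(ortho_w, ortho_h)
print(f"Whole orthomosaic would be shrunk by a factor of about {1 / ortho_scale:.0f}")

cols, rows = tile_grid(ortho_w, ortho_h, tile_size=1024)
print(f"{cols} x {rows} = {cols * rows} tiles of up to 1024 px")

tile_scale = resize_scale(1024, 1024)
print(f"A 1024 px tile is scaled by {tile_scale:.2f}")
```

Under these assumptions each tile is resized only slightly, so a lettuce plant keeps far more pixels than it would if the whole orthomosaic were passed to the model at once.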

In [None]:
##Define a series of functions: 


# Step 1: Crop the large image into jpeg tiles
def crop_image_into_tiles(image_path, tile_size, output_folder):
    image = Image.open(image_path)
    img_width, img_height = image.size
    tiles_with_coordinates = []

    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
    
    tile_num = 0
    for y in range(0, img_height, tile_size):
        for x in range(0, img_width, tile_size):
            box = (x, y, min(x + tile_size, img_width), min(y + tile_size, img_height))
            tile = image.crop(box)

            if tile.mode == 'RGBA':  # Convert to RGB if needed
                tile = tile.convert("RGB")

            tile_path = os.path.join(output_folder, f'tile_{tile_num}.jpg')
            tile.save(tile_path, 'JPEG')
            tiles_with_coordinates.append((tile_path, (x, y)))
            tile_num += 1

    return tiles_with_coordinates

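# Step 2: Run object detection predictions on each tile
# Note: `threshold` is not applied here; all predictions are returned and
# filtered later by the display and shapefile functions.
def get_predictions_for_tiles(tiles_with_coordinates, model, threshold=0.3):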
    all_predictions = []

    for tile_path, (tile_x, tile_y) in tiles_with_coordinates:
        tile = Image.open(tile_path)
        
        # Run prediction on tile
        labels, boxes, scores = model.predict(tile)
        
        # Store the predictions with the tile's coordinates
        all_predictions.append((labels, boxes, scores, tile_x, tile_y))

    return all_predictions

# Step 3: Adjust and display predictions on the original image
def adjust_boxes(boxes, tile_x, tile_y):
    adjusted_boxes = []
    for box in boxes:
        x_min, y_min, x_max, y_max = box
        adjusted_box = [x_min + tile_x, y_min + tile_y, x_max + tile_x, y_max + tile_y]
        adjusted_boxes.append(adjusted_box)
    return adjusted_boxes

def display_original_image_with_boxes(image_path, all_predictions, threshold=0.3):
    image = Image.open(image_path)
    fig, ax = plt.subplots(1, figsize=(100, 100))
    ax.imshow(image)

    for labels, boxes, scores, tile_x, tile_y in all_predictions:
        adjusted_boxes = adjust_boxes(boxes, tile_x, tile_y)

        for i, box in enumerate(adjusted_boxes):
            if scores[i] > threshold:
                x_min, y_min, x_max, y_max = box
                rect = patches.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, linewidth=3, edgecolor='r', facecolor='none')
                ax.add_patch(rect)
                #label_text = f'{labels[i]}: {scores[i]:.2f}' # Uncomment these lines to display labels and scores
                #plt.text(x_min, y_min - 10, label_text, color='white', fontsize=12, bbox=dict(facecolor='red', alpha=0.5))

    plt.axis('off')
    plt.show()
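
As a small worked example of the offset step in `adjust_boxes`, the sketch below uses plain Python numbers and a hypothetical helper `to_full_image` to shift one tile-local box into full-image pixel coordinates. The box and tile origin values are made up for illustration.

```python
def to_full_image(tile_box, tile_x, tile_y):
    """Shift a tile-local (x_min, y_min, x_max, y_max) box by the tile origin."""
    x_min, y_min, x_max, y_max = tile_box
    return (x_min + tile_x, y_min + tile_y, x_max + tile_x, y_max + tile_y)

# Hypothetical detection inside the tile whose top-left corner is at (2048, 1024)
local_box = (10, 20, 110, 140)
full_box = to_full_image(local_box, tile_x=2048, tile_y=1024)
print(full_box)  # (2058, 1044, 2158, 1164)
```

The width and height of the box are unchanged; only its position moves by the tile's top-left corner.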



In [None]:
##Input some parameters and run the functions

# Define the output folder for the cropped images and tile size
output_folder = 'cropped_tiles'
tile_size = 1024  # Adjust as needed

# Run the tiling function
tiles_with_coordinates = crop_image_into_tiles(image_path, tile_size, output_folder)

# Run the prediction function on the tiles (this will take some time to run)
all_predictions = get_predictions_for_tiles(tiles_with_coordinates, model, threshold=0.15)
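
The predictions collected per tile contain every candidate box with its score, so a confidence threshold still has to be applied afterward (the display and shapefile functions in this notebook do that). The sketch below shows that filtering step on its own, with plain Python lists standing in for the tensors Detecto returns; `filter_by_score` and the sample values are hypothetical.

```python
def filter_by_score(labels, boxes, scores, threshold):
    """Keep only detections whose confidence is at or above the threshold."""
    return [
        (label, box_coords, score)
        for label, box_coords, score in zip(labels, boxes, scores)
        if score >= threshold
    ]

# Made-up detections for one tile
sample_labels = ["lettuce", "lettuce", "lettuce"]
sample_boxes = [(5, 5, 60, 70), (200, 180, 260, 250), (400, 90, 430, 120)]
sample_scores = [0.92, 0.41, 0.08]

for sample_threshold in (0.15, 0.3, 0.5):
    kept = filter_by_score(sample_labels, sample_boxes, sample_scores, sample_threshold)
    print(f"threshold={sample_threshold}: kept {len(kept)} of {len(sample_scores)}")
```

A lower threshold keeps more boxes at the cost of more false positives, which is why it is worth comparing a few values visually before writing the shapefile.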


In [None]:
# Display all predictions on the original large image
display_original_image_with_boxes(image_path, all_predictions, threshold=0.3)

In [None]:
#Define function to create a georeferenced shapefile from the predictions

def create_georeferenced_shapefile(all_predictions, image_path, output_shapefile, threshold):
    # Open the original image to retrieve CRS and affine transform
    with rasterio.open(image_path) as src:
        crs = src.crs  # CRS of the original image (e.g., EPSG:32611 for this dataset)
        transform = src.transform  # Affine transformation matrix for the full image

    # List to hold each bounding box with its label and score
    polygons = []

    # Collect bounding boxes in terms of full image coordinates (not tile-based)
    for labels, boxes, scores, tile_x, tile_y in all_predictions:
        for i, box_coords in enumerate(boxes):
            if scores[i] >= threshold:  # Apply confidence threshold
                # Calculate full-image coordinates for each bounding box
                x_min, y_min, x_max, y_max = box_coords
                x_min += tile_x
                y_min += tile_y
                x_max += tile_x
                y_max += tile_y

                # Create a Polygon in image (pixel) coordinates
                polygon = box(x_min, y_min, x_max, y_max)
                polygons.append({
                    'geometry': polygon,
                    'label': labels[i],
                    'score': float(scores[i])  # plain float so it can be written to the shapefile
                })

    # Create a GeoDataFrame that still holds pixel-based coordinates (no CRS yet)
    gdf_pixel = gpd.GeoDataFrame(polygons, geometry="geometry")

    # Apply affine transform to convert pixel coordinates to geographic coordinates
    gdf_pixel['geometry'] = gdf_pixel['geometry'].apply(
        lambda geom: transform_polygon(geom, transform)
    )

    # Assign the CRS read from the raster and save to shapefile
    gdf_pixel = gdf_pixel.set_crs(crs)
    gdf_pixel.to_file(output_shapefile, driver="ESRI Shapefile")

def transform_polygon(geometry, transform):
    # Convert each corner of the pixel-space box to geographic coordinates using the affine transform
    transformed_coords = [(transform * (x, y)) for x, y in geometry.exterior.coords]
    # Rebuild an axis-aligned box from two opposite corners (assumes a north-up, unrotated raster)
    return box(*transformed_coords[0], *transformed_coords[2])
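
The pixel-to-map conversion used above is an affine transform. The following pure-Python sketch spells out the arithmetic that `transform * (x, y)` performs for a north-up raster; the six coefficients are made-up values (1 cm pixels and an arbitrary UTM-style origin), and `pixel_to_map` is a hypothetical helper, not part of rasterio.

```python
def pixel_to_map(col, row, coeffs):
    """Apply affine coefficients (a, b, c, d, e, f) to a pixel position.

    x = a * col + b * row + c
    y = d * col + e * row + f
    """
    a, b, c, d, e, f = coeffs
    return (a * col + b * row + c, d * col + e * row + f)

# Assumed north-up raster: 0.01 m pixels, top-left corner at (720000.0, 3620000.0).
# e is negative because row numbers increase downward while northing increases upward.
example_coeffs = (0.01, 0.0, 720000.0, 0.0, -0.01, 3620000.0)

# Opposite corners of a hypothetical bounding box in pixel coordinates
west, north = pixel_to_map(1000, 2000, example_coeffs)  # top-left pixel corner
east, south = pixel_to_map(1100, 2150, example_coeffs)  # bottom-right pixel corner
print(west, south, east, north)
```

Note that the top pixel row maps to the larger northing, so the pixel-space minimum y becomes the map-space maximum y.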




In [None]:
#Execute the function to create the georeferenced shapefile
output_shapefile = "predicted_please_work.shp"
create_georeferenced_shapefile(all_predictions, image_path, output_shapefile, threshold=0.15)

In [None]:
# Removes access token from future requests to D2S
workspace.logout()

## Fine-tune a model with training data

Detecto can also fine-tune a model with a small dataset of annotated images. For image labeling, I recommend an open-source program called [Label Studio](https://labelstud.io/). In Label Studio, you can import the tile images (non-georeferenced), draw bounding boxes around the lettuce plants, and export the annotations in several formats. Detecto reads annotations in the Pascal VOC format.
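
For reference, a Pascal VOC annotation is an XML file with one `object` element per bounding box. The sketch below parses a small hand-written example with Python's standard library to show where the label and box corners live; the file name, image size, and coordinates are invented, `read_voc_boxes` is a hypothetical helper, and Detecto does its own parsing internally.

```python
import xml.etree.ElementTree as ET

# Hand-written example annotation (values are invented)
voc_xml = """
<annotation>
  <filename>tile_246.jpg</filename>
  <size><width>1024</width><height>1024</height><depth>3</depth></size>
  <object>
    <name>lettuce</name>
    <bndbox><xmin>112</xmin><ymin>87</ymin><xmax>190</xmax><ymax>171</ymax></bndbox>
  </object>
  <object>
    <name>lettuce</name>
    <bndbox><xmin>540</xmin><ymin>602</ymin><xmax>618</xmax><ymax>689</ymax></bndbox>
  </object>
</annotation>
"""

def read_voc_boxes(xml_text):
    """Return (label, (xmin, ymin, xmax, ymax)) for every object in the annotation."""
    root = ET.fromstring(xml_text)
    results = []
    for obj in root.iter("object"):
        bndbox = obj.find("bndbox")
        corners = tuple(
            int(bndbox.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax")
        )
        results.append((obj.find("name").text, corners))
    return results

for voc_label, voc_corners in read_voc_boxes(voc_xml):
    print(voc_label, voc_corners)
```

The box corners use the same (x_min, y_min, x_max, y_max) pixel convention as the predictions earlier in this notebook.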

In [None]:
# Import the necessary classes
from detecto.core import Dataset, DataLoader, Model

In [None]:
# create the datasets from the training annotations and images
dataset = Dataset(label_data='./lettuce_training/Annotations/', image_folder='./lettuce_training/images/')

In [None]:
# Visualize the first training image and its bounding boxes
image, target = dataset[0]
show_labeled_image(image, target['boxes'], target['labels'])

In [None]:


# wrap the training set in a DataLoader
train_loader = DataLoader(dataset, batch_size=2, shuffle=True)

# specify all unique labels you're trying to predict
labels = [
    'lettuce'
]

### Fine-tuning choices

We are going to fine-tune the existing lettuce detection model. You can train the entire model or only its final layers. Training the entire model takes longer and requires more data; training only the final layers is faster and requires less data. The final layers (the Region Proposal Network and ROI heads) are the ones that propose candidate regions and make the final lettuce predictions. To adapt the model to your specific lettuce, it may make more sense to fine-tune just these final layers.

 

In [None]:

## You can skip this step if you want to train the entire model

# Set up existing model to train just the Region Proposal Network (RPN) and ROI Heads (last layers)
torch_model = model.get_internal_model()

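for name, p in torch_model.named_parameters():
    # Freeze everything except the RPN and ROI heads
    if 'roi_heads' not in name and 'rpn' not in name:
        p.requires_grad = False

    # Print each parameter and whether it will be trained
    print(name, p.requires_grad)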
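
To see what the name filter above selects without loading a model, here is a pure-Python sketch of the same rule applied to a handful of parameter names. The names are illustrative examples in the style of torchvision's Faster R-CNN and may not match your installed version exactly; `is_trainable` is a hypothetical helper.

```python
def is_trainable(param_name):
    """Mirror the notebook's rule: only RPN and ROI-head parameters stay trainable."""
    return "roi_heads" in param_name or "rpn" in param_name

# Illustrative parameter names (assumed, not read from a real model)
example_names = [
    "backbone.body.conv1.weight",
    "backbone.body.layer4.2.conv3.weight",
    "backbone.fpn.inner_blocks.0.weight",
    "rpn.head.conv.weight",
    "roi_heads.box_head.fc6.weight",
    "roi_heads.box_predictor.cls_score.weight",
]

for example_name in example_names:
    status = "train" if is_trainable(example_name) else "frozen"
    print(f"{status:6s} {example_name}")
```

Everything under `backbone` ends up frozen, so only the region proposal and box classification parts of the network are updated during fine-tuning.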

In [None]:
##TRAINING RUN!!
# Fit (fine-tune) the already-loaded model on the training data
# This step will take some time and is accelerated with GPU
losses = model.fit(train_loader,
                   epochs=8,
                   lr_step_size=5,
                   learning_rate=0.001,
                   verbose=True)
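
The `lr_step_size` argument controls when the learning rate is reduced during training. Assuming a step schedule that multiplies the rate by `gamma` every `lr_step_size` epochs, with `gamma=0.1` (which I believe is Detecto's default, but check your installed version), the run above would use the rates printed below. `step_lr` is a hypothetical helper written for this illustration.

```python
def step_lr(base_lr, epoch, step_size, gamma=0.1):
    """Learning rate at a given epoch under a step decay schedule."""
    return base_lr * gamma ** (epoch // step_size)

# Same settings as the training call: 8 epochs, learning_rate=0.001, lr_step_size=5
for epoch in range(8):
    print(epoch, f"{step_lr(0.001, epoch, step_size=5):.6f}")
```

Under these assumptions the first five epochs train at 0.001 and the last three at 0.0001.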

In [None]:
# Test the newly fine-tuned model to predict lettuce in a single image tile

image = read_image('./cropped_tiles/tile_246.jpg')
labels, boxes, scores = model.predict(image)
show_labeled_image(image, boxes, labels)