<h1>
  <b>[STEP 1]</b> SAVE BOTH OF THESE FOLDERS AS <u>SHORTCUTS</u> TO YOUR "MY DRIVE":
  <a href="https://drive.google.com/drive/folders/1695PT8xzD3LyDRcd3fTocM7uEBVMm-gW?usp=sharing" target="_blank">Specimen Images</a> &
  <a href="https://drive.google.com/drive/folders/1hV1xIqXvEzKdtaawIy-H4K-9SbmWZwoy?usp=drive_link" target="_blank">Segmented Images</a>
</h1>


**CLONE REPO:**

In [None]:
%git clone https://github.com/jescuti/deepplant.git

In [None]:
# Change the working directory to the `preprocessing` folder inside the
# `deepplant` checkout; later relative paths are resolved from here.
%cd deepplant
%cd preprocessing

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive so the Drive paths used by later cells resolve.
drive.mount('/content/drive')

# Segment Images

**DOWNLOAD SAM:**

In [None]:
# Install Segment Anything straight from its GitHub repository, then the helper
# packages (supervision is pinned to 0.23.0). `-q` keeps pip's output short.
%pip install -q 'git+https://github.com/facebookresearch/segment-anything.git'
%pip install -q jupyter_bbox_widget roboflow dataclasses-json supervision==0.23.0

In [None]:
%mkdir -p {HOME}/weights
%wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P {HOME}/weights

**OPEN IMAGES:**

In [None]:
from google.colab import drive
import os
import zipfile

# --- Google Drive ----------------------------------------------------------
drive.mount('/content/drive')

zip_file_path = '/content/drive/MyDrive/herbarium_images.zip'
extract_path = '/content/drive/MyDrive/herbarium_images'

# --- Extract the archive once ----------------------------------------------
# The archive counts as extracted when the target folder exists and holds at
# least one entry; otherwise it is (re-)extracted into that folder.
already_extracted = os.path.exists(extract_path) and len(os.listdir(extract_path)) != 0

if already_extracted:
    print(f"Already extracted at {extract_path}")
else:
    print("Extracting zip file")
    os.makedirs(extract_path, exist_ok=True)
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(extract_path)
    print(f"Unzipped to {extract_path}")

**RUN SAM LABEL SEGMENTATION:**

In [None]:
%pip install opencv-python torch segment-anything glob tqdm

In [None]:
import torch
import cv2
import os
import numpy as np
import gc
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from glob import glob
from tqdm import tqdm

# clear cache
torch.cuda.empty_cache()

# save memory w/ lower precision (can go up to 'highest' or 'high')
torch.set_float32_matmul_precision('medium')

##############
# SET UP SAM #
##############
# NOTE(review): "{HOME}" is a plain string here (not an f-string), so this is the
# relative path "{HOME}/weights/sam_vit_h_4b8939.pth" under the current working
# directory. Confirm it matches where the download cell stored the checkpoint.
sam_checkpoint = os.path.join("{HOME}", "weights", "sam_vit_h_4b8939.pth")
model_type = "vit_h"
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print("Setting up SAM!")
# Build the model selected by `model_type` from the checkpoint and move it to `device`.
sam = sam_model_registry[model_type](checkpoint = sam_checkpoint)
sam.to(device = device)

# Automatic mask generator shared by the processing loop further down.
mask_generator = SamAutomaticMaskGenerator(
    model = sam,
    points_per_side = 32,
    pred_iou_thresh = 0.7,
    # NOTE(review): SamAutomaticMaskGenerator also accepts `stability_score_thresh`;
    # confirm that `stability_score_offset` is the parameter intended here.
    stability_score_offset = 0.7,
    crop_n_layers = 1,
    crop_n_points_downscale_factor = 2,
    min_mask_region_area = 50,
    output_mode = "binary_mask"
)

image_folder = "/content/drive/MyDrive/herbarium_images"
# image_folder = "/content/drive/MyDrive/webscraped_images"
# Recursively collect every *.jpg under `image_folder`.
# NOTE(review): the result is not sorted here, so positions in `image_paths`
# may differ between runs — relevant to any code that indexes into it.
image_paths = glob(os.path.join(image_folder, '**', '*.jpg'), recursive = True)

In [None]:
##########################################
# BARCODE AND HEADER DETECTION FUNCTIONS #
##########################################

def analyze_background_color(image):
    """
    Decide whether an image is dominated by a dark background.

    Most specimen headers have a dark background, and this check exists so
    that they are not falsely detected as labels.

    Args:
        image: BGR image array.

    Returns:
        True when more than 40% of the pixels have a grayscale intensity
        below 50, otherwise False.
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # 256-bin intensity histogram; the first 50 bins are the "very dark" pixels
    histogram = cv2.calcHist([grayscale], [0], None, [256], [0, 256])
    very_dark_count = np.sum(histogram[:50])

    n_rows, n_cols = grayscale.shape[0], grayscale.shape[1]
    dark_fraction = very_dark_count / (n_rows * n_cols)

    # more than 40% very dark pixels => treat as a dark background
    return dark_fraction > 0.4


#############################################
# CITATION: Rosebrock, A. (2014)            #
# "The Ultimate Guide to Barcode Detection" #
#############################################
def detect_barcode(image):
    """
    Look for a barcode-like pattern using x-direction gradients.

    Args:
        image: BGR image array.

    Returns:
        True when, after gradient / closing / Otsu thresholding, at least one
        row has foreground pixels across more than half of the image width;
        False otherwise (including when the gradient image is constant).
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # absolute x-gradient (Rosebrock pg. 4); barcodes respond strongly here
    grad_x = np.absolute(
        cv2.Sobel(grayscale, ddepth = cv2.CV_32F, dx = 1, dy = 0, ksize = -1)
    )

    lowest, highest = np.min(grad_x), np.max(grad_x)
    if highest == lowest:
        # flat gradient: nothing to rescale, nothing to detect
        return False

    # rescale to the 0..255 range and store as 8-bit
    grad_8bit = (255 * ((grad_x - lowest) / (highest - lowest))).astype("uint8")

    # close gaps between bars, then binarize with Otsu (Rosebrock pg. 6)
    closing_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (21, 7))
    closed = cv2.morphologyEx(grad_8bit, cv2.MORPH_CLOSE, closing_kernel)
    _, binary = cv2.threshold(closed, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

    # foreground pixels per row; keep the densest row
    row_counts = np.sum(binary > 0, axis = 1)
    densest_row = np.max(row_counts) if len(row_counts) > 0 else 0

    return densest_row > image.shape[1] * 0.5

def is_likely_header_or_barcode(cropped_img, y_position, img_height):
    """
    Decide whether a cropped region is likely a header or a barcode.

    Checks, in order:
    1. position in the image (headers sit in the top 20%)
    2. barcode detection (top regions only)
    3. wide region with uniform row density (top regions only)
    4. dark background with a moderate share of bright pixels (lower regions)

    Args:
        cropped_img: BGR crop of the candidate region.
        y_position: y coordinate of the crop's top edge in the full image.
        img_height: height of the full image in pixels.

    Returns:
        True if the region looks like a header / barcode, else False.
    """
    position_ratio = y_position / img_height
    in_top_region = position_ratio < 0.2

    gray = cv2.cvtColor(cropped_img, cv2.COLOR_BGR2GRAY)

    # dark background = likely a header
    has_dark_background = np.mean(gray) < 100

    if in_top_region:
        if has_dark_background:
            return True

        # barcode-like gradients; evaluated once, since a second call on the
        # same crop can only repeat this answer
        if detect_barcode(cropped_img):
            return True

        ##########################################################
        # CITATION: Murzova, A. (2020)                           #
        # https://learnopencv.com/otsu-thresholding-with-opencv/ #
        ##########################################################
        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

        h, w = cropped_img.shape[:2]
        aspect_ratio = w / h if h > 0 else 0

        # coefficient of variation of the per-row foreground counts;
        # headers have uniform density across rows
        pixels_per_row = np.sum(binary > 128, axis = 1)
        uniformity = np.std(pixels_per_row) / (np.mean(pixels_per_row) + 1e-5)

        # wide strip with low row-to-row variation
        return bool(aspect_ratio > 3.5 and uniformity < 0.5)

    if has_dark_background:
        # text-on-dark-background below the top region
        bright_pixel_count = np.sum(gray > 200)
        bright_pixel_ratio = bright_pixel_count / (cropped_img.shape[0] * cropped_img.shape[1])
        return bool(0.05 < bright_pixel_ratio < 0.4)

    return False

############################
# TEXT DETECTION FUNCTIONS #
############################
def detect_text_regions(image):
    """
    Find bounding boxes of potential text regions (usually rectangular).

    Args:
        image: BGR image array.

    Returns:
        List of (x, y, w, h) boxes whose width/height ratio exceeds 1.5 and
        whose box area lies strictly between 1000 and 100000 pixels.
    """
    ######################################################################
    # CITATIONS: Rosebrock, A. (2021)                                    #
    # https://learnopencv.com/otsu-thresholding-with-opencv/             #
    # Yadav, A. (2024)                                                   #
    # https://medium.com/%40amit25173/opencv-text-detection-8e298e2b5218 #
    ######################################################################

    grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    inverted_binary = cv2.adaptiveThreshold(
        grayscale, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
    )

    # a wide, flat kernel merges neighbouring characters into one blob
    merge_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 3))
    merged = cv2.dilate(inverted_binary, merge_kernel, iterations=1)

    outlines, _ = cv2.findContours(merged, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    def _label_box_or_none(outline):
        """Return the outline's (x, y, w, h) box if it is label-shaped, else None."""
        x, y, w, h = cv2.boundingRect(outline)
        wider_than_tall = w / float(h) > 1.5
        plausible_area = 1000 < w * h < 100000
        return (x, y, w, h) if (wider_than_tall and plausible_area) else None

    candidate_boxes = (_label_box_or_none(outline) for outline in outlines)
    return [box for box in candidate_boxes if box is not None]

def is_rectangular(mask, threshold = 0.75):
    """
    Check if a mask is approximately a wide rectangle by comparing its area
    with the area of its bounding box.

    The mask area is the number of non-zero pixels, so boolean masks and
    0/255 style masks are treated the same way.

    Args:
        mask: 2D array; non-zero entries belong to the mask.
        threshold: minimum mask-area / bounding-box-area ratio.

    Returns:
        True if the fullness ratio exceeds `threshold` and the bounding box
        is more than 1.5 times wider than tall; False otherwise (including
        for an empty mask).
    """
    row_idx, col_idx = np.nonzero(mask)

    # one index pair per mask pixel => pixel count is the mask area
    mask_area = row_idx.size
    if mask_area == 0:
        return False

    box_width = int(col_idx.max() - col_idx.min()) + 1
    box_height = int(row_idx.max() - row_idx.min()) + 1

    # fullness ratio: share of the bounding box covered by the mask
    fullness = mask_area / (box_width * box_height)

    # aspect ratio of the bounding rectangle (box_height is always >= 1 here)
    aspect_ratio = box_width / box_height

    return bool(fullness > threshold and aspect_ratio > 1.5)

def has_text_characteristics(img):
    """
    Judge whether a crop looks like a text label rather than a barcode/header.

    Args:
        img: BGR crop.

    Returns:
        False for dark crops (mean intensity < 100) whose share of bright
        pixels (> 200) lies between 5% and 40%. Otherwise truthy only when
        the grayscale variance exceeds 500 and at least one of these holds:
        more than 5 connected-component labels, more than 3 components with
        an area between 10 and 500 pixels, or a Canny edge density between
        3% and 15%.
    """
    grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    pixel_total = img.shape[0] * img.shape[1]

    # dark crop carrying a moderate amount of bright content => not a label
    if np.mean(grayscale) < 100:
        bright_share = np.sum(grayscale > 200) / pixel_total
        if 0.05 < bright_share < 0.4:
            return False

    # spread of pixel values (text regions = higher variance)
    intensity_variance = np.var(grayscale)

    # share of pixels flagged by Canny edge detection
    edge_map = cv2.Canny(grayscale, 100, 200)
    edge_share = np.sum(edge_map > 0) / pixel_total

    # connected components of the inverted Otsu binarization
    _, inverted = cv2.threshold(grayscale, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    label_count, _, component_stats, _ = cv2.connectedComponentsWithStats(inverted, connectivity=8)

    # label 0 is skipped; count the remaining components of moderate size
    mid_sized_components = sum(
        1
        for label in range(1, label_count)
        if 10 < component_stats[label, cv2.CC_STAT_AREA] < 500
    )

    varied_enough = intensity_variance > 500
    many_components = label_count > 5
    several_mid_sized = mid_sized_components > 3
    moderate_edges = 0.03 < edge_share < 0.15

    return varied_enough and (many_components or several_mid_sized or moderate_edges)

In [None]:
########################
# MAIN PROCESSING LOOP #
########################
# Folder that receives the cropped regions; created if it does not exist yet.
# Read as a module-level name by process_single_image below.
output_folder = "/content/drive/MyDrive/segmented_images"
os.makedirs(output_folder, exist_ok = True)

def resize_if_needed(image, max_dim = 1500):
    """
    Shrink an image whose longest side exceeds `max_dim`, keeping its aspect ratio.

    Args:
        image: image array with shape (height, width, ...).
        max_dim: maximum allowed length of the longer side, in pixels.

    Returns:
        The input image itself when it already fits, otherwise a resized copy
        whose longer side is `max_dim`. Each side is kept at >= 1 pixel so
        that very elongated images do not produce an invalid 0-pixel target.
    """
    h, w = image.shape[:2]
    longest_side = max(h, w)
    if longest_side <= max_dim:
        return image

    scale = max_dim / longest_side
    # int() truncates, so the short side of an extreme aspect ratio could hit 0
    new_h = max(1, int(h * scale))
    new_w = max(1, int(w * scale))
    return cv2.resize(image, (new_w, new_h))

def process_single_image(img_path):
    """
    Segment one image with SAM and save the crops that look like text labels.

    The image id is the part of the file name (extension removed) after the
    last underscore. Only the final extension is stripped, so names that
    contain dots (e.g. an author abbreviation such as "L.") keep their full
    stem and therefore their real id. If crops named "<id>_*.jpg" already
    exist in `output_folder`, the image is skipped.

    Kept crops come from masks that are rectangular, have an area between
    1000 and 100000 pixels, are not classified as header / barcode / dark
    background, and pass `has_text_characteristics`. Errors raised while
    processing are printed, not propagated.

    Args:
        img_path: path of the image file to process.

    Returns:
        None. Crops are written to `output_folder` as "<id>_<mask index>.jpg".
    """
    base_name = os.path.splitext(os.path.basename(img_path))[0]
    image_id = base_name.split('_')[-1]

    # SKIP if image already cropped
    existing_crops = glob(os.path.join(output_folder, f"{image_id}_*.jpg"))
    if existing_crops:
        print(f"Skipping: {img_path}")
        return

    try:
        print(f"Processing image: {img_path}")
        img = cv2.imread(img_path)
        if img is None:
            print(f"Failed to load: {img_path}")
            return

        ##############################################
        # RESIZE LARGE IMAGES BC OUT OF MEMORY ERROR #
        ##############################################
        img = resize_if_needed(img)
        image_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_height, img_width = img.shape[0], img.shape[1]

        with torch.no_grad():
            masks = mask_generator.generate(image_rgb)

        saved_count = 0
        skipped_count = 0
        padding = 10

        for i, mask in enumerate(masks):
            # focus on rectangular shapes of a reasonable size for text labels
            if not is_rectangular(mask["segmentation"]):
                continue
            if not (1000 < mask["area"] < 100000):
                continue

            # grow the box by `padding`, clamped to the image borders
            x, y, w, h = mask["bbox"]
            x_pad = max(0, x - padding)
            y_pad = max(0, y - padding)
            w_pad = min(img_width - x_pad, w + 2*padding)
            h_pad = min(img_height - y_pad, h + 2*padding)

            cropped_img = img[y_pad:y_pad+h_pad, x_pad:x_pad+w_pad]

            # check if header / barcode
            if is_likely_header_or_barcode(cropped_img, y_pad, img_height) or analyze_background_color(cropped_img):
                skipped_count += 1
                continue

            # check for text-like characteristics
            if not has_text_characteristics(cropped_img):
                skipped_count += 1
                continue

            #######################
            # SAVE CROPPED IMAGES #
            #######################
            save_path = os.path.join(output_folder, f"{image_id}_{i}.jpg")
            cv2.imwrite(save_path, cropped_img)
            saved_count += 1

        print(f"Total regions saved: {saved_count}; Skipped: {skipped_count}")

    except Exception as e:
        print(f"Error processing {img_path}: {str(e)}")

    # clean up to free memory
    gc.collect()
    torch.cuda.empty_cache()

# Position in `image_paths` at which processing starts; everything before it
# is not visited at all. Set to 0 to walk the whole list (images that already
# have crops are skipped inside process_single_image).
# NOTE(review): the value depends on the order of `image_paths`; confirm that
# order is stable between runs before relying on a non-zero start.
START_INDEX = 4754

# prevent memory accumulation => small batches
batch_size = 1

if START_INDEX > 0:
    print(f"Starting at index {START_INDEX} of {len(image_paths)} images; earlier entries are not visited")

for i in tqdm(range(START_INDEX, len(image_paths), batch_size)):
    batch = image_paths[i:i+batch_size]
    for img_path in batch:
        process_single_image(img_path)

    # garbage collection between batches
    gc.collect()
    torch.cuda.empty_cache()

# Image Clustering

In [None]:
# Install or upgrade (-U) scikit-learn and FiftyOne for the clustering section.
%pip install -U scikit-learn fiftyone

In [None]:
%fiftyone plugins download https://github.com/jacobmarks/clustering-plugin

In [None]:
# Install umap-learn and OpenAI's CLIP package (the latter from its GitHub repository).
%pip install umap-learn git+https://github.com/openai/CLIP.git

In [None]:
import fiftyone as fo
import fiftyone.brain as fob
from fiftyone import ViewField as F
import os
from PIL import Image
import clip
import torch
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import numpy as np

def cluster_dataset(segmented_labels_dir):
    """
    Build a FiftyOne dataset from the segmented label images and attach a
    UMAP visualization of their CLIP embeddings.

    Args:
        segmented_labels_dir: directory containing the segmented label images
            (.jpg, .jpeg or .png files).

    Returns:
        The dataset (named 'labels_test'), with embeddings stored in the
        "clip_embeddings" field and the visualization points stored in the
        "clip_umap" field.
    """
    image_suffixes = ('.jpg', '.jpeg', '.png')
    image_files = []
    for entry in os.listdir(segmented_labels_dir):
        if entry.endswith(image_suffixes):
            image_files.append(os.path.join(segmented_labels_dir, entry))

    dataset = fo.Dataset.from_images(image_files, name='labels_test', overwrite=True)
    session = fo.launch_app(dataset)

    # compute features
    visualization = fob.compute_visualization(
        dataset,
        model="clip-vit-base32-torch",
        embeddings="clip_embeddings",
        method="umap",
        brain_key="clip_vis",
        batch_size=10,
    )
    dataset.set_values("clip_umap", visualization.current_points)
    return dataset

def query_image(image_path, dataset, k):
    """
    Find the k dataset images most similar to a search image.

    The search image is embedded with CLIP ("ViT-B/32") and compared by
    cosine similarity against the dataset's "clip_embeddings" field. Each
    match is also displayed with matplotlib.

    Args:
        image_path: path to the search image.
        dataset: dataset providing the "clip_embeddings" and "id" values and
            lookup of samples by id.
        k: number of results to return.

    Returns:
        (top_similar_images, top_similarity_scores): the k most similar
        images, best match first, and the similarity score of each one.
    """
    model_name = "ViT-B/32"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # get image embeddings
    model, preprocess = clip.load(model_name, device=device)
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        image_embedding = model.encode_image(image).cpu().numpy()

    # get dataset embeddings
    dataset_embeddings = dataset.values("clip_embeddings")
    sample_ids = dataset.values("id")

    # calculate similarity of the query against every dataset image
    similarity_scores = cosine_similarity(image_embedding, np.array(dataset_embeddings))[0]

    # get top k most similar images (highest score first)
    top_k_indices = np.argsort(similarity_scores)[::-1][:k]

    top_similar_images = []
    top_similarity_scores = []
    for i in top_k_indices:
        sample = dataset[sample_ids[i]]
        img = Image.open(sample.filepath)
        plt.imshow(img)
        plt.show()
        top_similar_images.append(img)
        top_similarity_scores.append(similarity_scores[i])

    # return the scores that belong to the returned images,
    # not the scores of the whole dataset
    return top_similar_images, top_similarity_scores

In [None]:
# Install the remaining packages in one place: torch/torchvision, flask and
# flask-cors, pillow, numpy, scikit-learn, fiftyone and umap-learn; CLIP from
# its GitHub repository; and the helper packages (supervision pinned to 0.23.0).
%pip install torch torchvision flask flask-cors pillow numpy scikit-learn fiftyone umap-learn
%pip install git+https://github.com/openai/CLIP.git
%pip install -q jupyter_bbox_widget roboflow dataclasses-json supervision==0.23.0

# Web Scrape Images (Optional)

**Collects 500 Random Images from the Collection**

In [None]:
import os
import requests
import re
import time
import random
from tqdm import tqdm
from bs4 import BeautifulSoup
import shutil

# Drive folder that receives the scraped images; created if it does not exist yet.
save_path = '/content/drive/MyDrive/webscraped_images'
os.makedirs(save_path, exist_ok = True)

def create_filename(name):
    """
    Turn an arbitrary string into a valid file name.

    Alphanumeric characters and the characters space, '.', '_' and '-' are
    kept; every other character becomes '_'. Leading and trailing whitespace
    is stripped from the result.
    """
    kept_punctuation = " ._-"
    cleaned = []
    for ch in name:
        if ch.isalnum() or ch in kept_punctuation:
            cleaned.append(ch)
        else:
            cleaned.append("_")
    return "".join(cleaned).strip()

def extract_image_url_from_html(html_content):
    """
    Extract the actual image URL from the viewer page's HTML.

    Scans the <script> tags for one that mentions 'tileSources' and contains
    a quoted IIIF info.json URL, then converts that URL into the full-size
    JPEG URL.

    Args:
        html_content: HTML text of the viewer page.

    Returns:
        The image URL, or None when no matching script is found or when
        parsing raises (the error is printed).
    """
    try:
        soup = BeautifulSoup(html_content, 'html.parser')

        # IIIF => web address with the image, metadata,
        # and the image "web viewer"
        for script in soup.find_all('script'):
            script_text = script.string
            if not script_text or 'tileSources' not in script_text:
                continue

            # pull the quoted info.json URL out of the script text
            found = re.search(r'"(https://repository\.library\.brown\.edu/iiif/image/bdr:[^/]+/info\.json)"', script_text)
            if found is None:
                continue

            # convert info.json URL to full image URL
            return found.group(1).replace('/info.json', '/full/full/0/default.jpg')

        return None
    except Exception as e:
        print(f"Error parsing HTML: {e}")
        return None

def download_image(url, filepath, timeout = 30):
    """
    Download an image from `url` and save it to `filepath`.

    The data is first written to "<filepath>.part" and only moved to
    `filepath` after a complete, non-empty download. A failed or empty
    download therefore never leaves a file at `filepath`, which matters
    because callers may treat an existing file as "already downloaded".

    Args:
        url: address of the image.
        filepath: destination path of the image file.
        timeout: seconds passed to `requests.get` as its timeout, so a stalled
            server cannot block the run forever.

    Returns:
        True when the file was downloaded and is not empty, otherwise False
        (errors are printed, not raised).
    """
    partial_path = f"{filepath}.part"
    try:
        # the context manager releases the streamed connection in every case
        with requests.get(url, stream = True, timeout = timeout) as response:
            response.raise_for_status()

            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(8192):
                    f.write(chunk)

        if os.path.getsize(partial_path) == 0:
            print(f"WARNING: Downloaded file is empty: {filepath}")
            return False

        # publish the finished download under its final name
        os.replace(partial_path, filepath)
        return True
    except Exception as e:
        print(f"Error downloading image: {e}")
        return False
    finally:
        # after a successful move the partial file is gone; otherwise discard it
        if os.path.exists(partial_path):
            try:
                os.remove(partial_path)
            except OSError as cleanup_error:
                print(f"Error removing partial download {partial_path}: {cleanup_error}")

def get_total_item_count(timeout = 30):
    """
    Get the total number of items in the collection.

    Asks the repository search API for a single row and reads the
    "numFound" value from the "response" section of the JSON answer.

    Args:
        timeout: seconds passed to `requests.get` as its timeout, so a stalled
            server cannot block the run forever.

    Returns:
        The item count, or 0 when the value is missing or the request fails
        (HTTP errors included; the error is printed).
    """
    api_url = (
        "https://repository.library.brown.edu/api/search/"
        "?q=rel_is_member_of_collection_ssim:bdr:nz9qn2kb"
        "&rows=1&wt=json"
    )
    try:
        response = requests.get(api_url, timeout = timeout)
        # treat HTTP error statuses as failures instead of parsing their body
        response.raise_for_status()
        return response.json().get("response", {}).get("numFound", 0)
    except Exception as e:
        print(f"Error getting total item count: {e}")
        return 0

def fetch_random_sample(sample_size = 100, min_index = 4401, base_dir = "test_images"):
    """
    Fetch a random sample of herbarium images and save them into `base_dir`.

    Random positions are drawn from [min_index, total items). For each one the
    item is looked up through the search API, a file name is built from its
    scientific name (or its title when there is none) plus its pid, and the
    image is downloaded unless a file with that name already exists.

    Starts at 4,401 by default because all the images before that index were
    already downloaded.

    Args:
        sample_size: number of random items to try; reduced to the number of
            available items when it exceeds them.
        min_index: first collection index that may be sampled.
        base_dir: directory that receives the images (created if missing).

    Returns:
        (total_processed, total_errors): number of images downloaded and
        number of items that failed. (0, 0) when the collection size cannot
        be determined or nothing is left to sample.
    """
    # seconds to wait on the API / viewer before giving up on an item
    request_timeout = 30

    os.makedirs(base_dir, exist_ok = True)

    total_items = get_total_item_count()
    if total_items == 0:
        print("ERROR: couldn't determine the total number of items in the collection.")
        return 0, 0

    print(f"collection has {total_items} TOTAL items")

    # calculate available range for random sampling
    available_range = total_items - min_index
    if available_range <= 0:
        print("no more items available for sampling")
        return 0, 0

    if sample_size > available_range:
        print(f"Requested sample size {sample_size} exceeds available items {available_range}.")
        sample_size = available_range

    random_indices = random.sample(range(min_index, total_items), sample_size)

    total_processed = 0
    total_errors = 0

    for index in tqdm(random_indices, desc = "Processing random samples"):
        # fetch single item at the random index
        api_url = (
            "https://repository.library.brown.edu/api/search/"
            "?q=rel_is_member_of_collection_ssim:bdr:nz9qn2kb"
            f"&start={index}&rows=1&wt=json"
        )

        try:
            api_response = requests.get(api_url, timeout = request_timeout)
            api_response.raise_for_status()
            docs = api_response.json().get("response", {}).get("docs", [])

            if not docs:
                print(f"no item found at INDEX: {index}")
                continue

            item = docs[0]
            pid = item.get("pid")
            if not pid:
                continue

            # file name stem: sanitized scientific name if available,
            # otherwise the sanitized title
            scientific_name = item.get("dwc_scientific_name_ssi", "")
            name_stem = create_filename(scientific_name) if scientific_name else ""
            if not name_stem:
                name_stem = create_filename(item.get("primary_title", "Untitled"))

            filename = f"{name_stem}_{pid.replace(':', '_')}.jpg"
            filepath = os.path.join(base_dir, filename)

            # SKIP if file exists
            if os.path.exists(filepath):
                print(f"File already exists: {filepath}")
                continue

            # create URL to open the image view => then, fetch the HTML
            viewer_url = f"https://repository.library.brown.edu/viewers/image/zoom/{pid}"
            try:
                print(f"Fetching HTML from {viewer_url}")
                html_response = requests.get(viewer_url, timeout = request_timeout)
                html_response.raise_for_status()

                # extract image URL from HTML
                image_url = extract_image_url_from_html(html_response.text)
                if not image_url:
                    # fall back to the direct IIIF URL built from the pid
                    image_url = f"https://repository.library.brown.edu/iiif/image/{pid}/full/full/0/default.jpg"
                    print(f"Using alternative URL: {image_url}")

                # download actual image as JPG
                if download_image(image_url, filepath):
                    # report the directory actually written to
                    print(f"Successfully downloaded {filename} to {base_dir}")
                    total_processed += 1
                else:
                    print(f"Failed to download image for {pid}")
                    total_errors += 1

            except Exception as e:
                print(f"Error processing {pid}: {e}")
                total_errors += 1

            # don't overwhelm server
            time.sleep(0.5)

        except Exception as e:
            print(f"Error fetching item at index {index}: {e}")
            total_errors += 1
            # back off longer after an API failure
            time.sleep(5)

    print(f"Successfully processed {total_processed} images with ERRORS: {total_errors}")
    return total_processed, total_errors

# Entry point of the scraping cell: runs the download when the cell is
# executed as the main module.
if __name__ == "__main__":
    # download 500 random samples from the remaining items (after index 4400)
    fetch_random_sample(sample_size = 500, min_index = 4401, base_dir = save_path)