# Automatic grain detection for large images

```python
# Pip imports
import keras
import matplotlib.pyplot as plt
import numpy as np
import segment_anything
import segmenteverygrain
import segmenteverygrain.interactions as si
from tqdm import tqdm

# Disable Pillow's decompression-bomb pixel limit so very large images can be opened
from PIL import Image
Image.MAX_IMAGE_PIXELS = None

# Display plots in notebook
%matplotlib inline
```

### Load models

Images can be loaded in most common formats. Avoid lossy compression (e.g. JPEG) for best results.

The U-Net model (`seg_model.keras`) is expected in the `segmenteverygrain` folder.

The SAM checkpoint used below (`sam_vit_h_4b8939.pth`, ViT-H) can be downloaded [here](https://huggingface.co/ybelkada/segment-anything/blob/main/checkpoints/sam_vit_h_4b8939.pth).

```python
# Load U-Net model
fn = './segmenteverygrain/seg_model.keras'
unet = keras.saving.load_model(fn, custom_objects={
    'weighted_crossentropy': segmenteverygrain.weighted_crossentropy})
```


```python
# Load SAM ('default' selects the ViT-H architecture, matching this checkpoint)
fn = 'sam_vit_h_4b8939.pth'
sam = segment_anything.sam_model_registry['default'](checkpoint=fn)
```

### Grain segmentation
`predict_large_image` loads the image itself, so it takes a filename rather than an already-loaded image.

The Segment Anything Model (SAM) is prompted with each point derived from the U-Net prediction and outlines the segment of the image that corresponds to that grain. The resulting polygons are then cleaned up to remove duplicates and keep the best fit for each grain.
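
The notebook does not spell out how duplicates are resolved. As a hedged illustration of the general idea only, the sketch below removes near-duplicate masks by greedy non-maximum suppression on intersection over union (IoU). Ranking masks by pixel area and the 0.5 threshold are assumptions chosen for the example; this is not `segmenteverygrain`'s own selection logic.

```python
import numpy as np


def mask_iou(a: np.ndarray, b: np.ndarray) -> float:
    """Intersection over union of two boolean masks with the same shape."""
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 0.0
    return float(np.logical_and(a, b).sum() / union)


def deduplicate_masks(masks, iou_threshold=0.5):
    """Greedy non-maximum suppression (illustrative, not the library's method).

    Masks are visited from largest to smallest area (an assumed ranking); a mask
    is kept only if its IoU with every already-kept mask is below the threshold.
    Returns the indices of kept masks in ascending order.
    """
    order = sorted(range(len(masks)), key=lambda i: int(masks[i].sum()), reverse=True)
    kept = []
    for i in order:
        if all(mask_iou(masks[i], masks[j]) < iou_threshold for j in kept):
            kept.append(i)
    return sorted(kept)


# Toy example: b is a near-duplicate of a; c is a separate grain.
a = np.zeros((10, 10), dtype=bool)
a[0:5, 0:5] = True    # 25 px
b = np.zeros((10, 10), dtype=bool)
b[0:5, 0:4] = True    # 20 px, IoU with a = 20 / 25 = 0.8
c = np.zeros((10, 10), dtype=bool)
c[6:10, 6:10] = True  # 16 px, no overlap with a or b
kept = deduplicate_masks([a, b, c])
print(kept)  # [0, 2]
```

In practice a model-provided quality score is another common ranking key; area is used here only because it needs no extra inputs.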

Check that `min_area` is a reasonable lower bound for the area, in pixels, of the grains you expect to detect.
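
To make the threshold concrete: the area of a polygon in pixel units can be computed with the shoelace formula, and a polygon below `min_area` would fall under the bound. The plain-Python sketch below assumes each polygon is a list of `(x, y)` vertices; `polygon_area` and `filter_by_area` are hypothetical helpers, not part of `segmenteverygrain`. If the returned polygons are shapely objects, their `.area` attribute gives the same value.

```python
def polygon_area(coords):
    """Area of a simple polygon via the shoelace formula.

    coords: sequence of (x, y) vertices in pixels; repeating the first vertex
    at the end is allowed (that extra closing edge contributes zero).
    """
    n = len(coords)
    twice_area = 0.0
    for i in range(n):
        x1, y1 = coords[i]
        x2, y2 = coords[(i + 1) % n]
        twice_area += x1 * y2 - x2 * y1
    return abs(twice_area) / 2.0


def filter_by_area(polygons, min_area):
    """Keep polygons whose area is at least min_area (>= is an assumption)."""
    return [p for p in polygons if polygon_area(p) >= min_area]


big = [(0, 0), (30, 0), (30, 30), (0, 30)]    # 900 px^2
small = [(0, 0), (15, 0), (15, 15), (0, 15)]  # 225 px^2
print(polygon_area(big), polygon_area(small))             # 900.0 225.0
print(len(filter_by_area([big, small], min_area=400.0)))  # 1
```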

TODO: add guidance on choosing `patch_size` and `overlap`.
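
As a generic illustration of what the two parameters control in an overlapping-tile scheme: tiles of `patch_size` pixels advance by `patch_size - overlap`, so neighbouring tiles share a strip at least `overlap` pixels wide. The edge rule (shifting the last tile back so it ends at the image border) and the image dimensions are assumptions for illustration; this sketch is not how `predict_large_image` is implemented.

```python
def tile_origins(length, patch_size, overlap):
    """Start offsets of overlapping tiles along one image axis.

    Consecutive tiles advance by patch_size - overlap pixels. The last tile is
    shifted back so it ends exactly at the image edge (an assumed edge rule;
    other schemes pad the image instead).
    """
    if not 0 <= overlap < patch_size:
        raise ValueError("overlap must satisfy 0 <= overlap < patch_size")
    if length <= patch_size:
        return [0]
    step = patch_size - overlap
    origins = list(range(0, length - patch_size, step))
    origins.append(length - patch_size)
    return origins


def tile_boxes(height, width, patch_size, overlap):
    """(row0, col0, row1, col1) boxes covering a height x width image."""
    return [
        (r, c, min(r + patch_size, height), min(c + patch_size, width))
        for r in tile_origins(height, patch_size, overlap)
        for c in tile_origins(width, patch_size, overlap)
    ]


# Hypothetical 3800 x 5000 pixel image with the settings used below.
print(tile_origins(5000, patch_size=2000, overlap=200))           # [0, 1800, 3000]
print(len(tile_boxes(3800, 5000, patch_size=2000, overlap=200)))  # 6
```

With this edge rule the last two tiles along an axis can overlap by more than `overlap` (800 px between the tiles starting at 1800 and 3000 above).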

```python
fn = 'examples/torrey_pines.jpg'
polygons, image_pred, all_coords = segmenteverygrain.predict_large_image(
    fn, unet, sam, min_area=400.0, patch_size=2000, overlap=200)
print(type(image_pred))
```

```
segmenting image tiles...
100%|██████████| 7/7 [00:02<00:00,  2.82it/s]
100%|██████████| 6/6 [00:01<00:00,  3.12it/s]
creating masks using SAM...
100%|██████████| 1041/1041 [01:02<00:00, 16.53it/s]
finding overlapping polygons...
872it [00:05, 159.39it/s]
finding best polygons...
100%|██████████| 333/333 [00:09<00:00, 36.23it/s]
creating labeled image...
processed patch #1 out of 1 patches
400it [00:00, 641.77it/s]
0it [00:00, ?it/s]
<class 'numpy.ndarray'>
```

### Results

Polygons detected by SAM are converted into Grain objects, which provide methods for measuring grains and extracting data.
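
The measurement code inside `Grain.measure()` is not shown in this notebook. For orientation, one widely used definition of a region's major and minor axis lengths (the one scikit-image's `regionprops` uses) is the axes of the ellipse with the same second central moments as the region: 4 * sqrt(lambda) for each eigenvalue lambda of the pixel-coordinate covariance matrix. The NumPy sketch below implements that definition on a boolean mask; whether `Grain.measure()` uses the same definition is an assumption.

```python
import numpy as np


def axis_lengths(mask: np.ndarray) -> tuple[float, float]:
    """Major and minor axis lengths, in pixels, of the ellipse whose second
    central moments match those of the True pixels in a 2-D boolean mask."""
    rows, cols = np.nonzero(mask)
    coords = np.stack([rows, cols]).astype(float)  # shape (2, n_pixels)
    cov = np.cov(coords, bias=True)                # population covariance (divide by n)
    eigvals = np.linalg.eigvalsh(cov)              # ascending order
    minor, major = 4.0 * np.sqrt(np.clip(eigvals, 0.0, None))
    return float(major), float(minor)


# A 20 x 100 pixel rectangle: its moment-equivalent ellipse is longer than the
# rectangle itself.
mask = np.zeros((40, 140), dtype=bool)
mask[10:30, 20:120] = True
major, minor = axis_lengths(mask)
print(round(major, 1), round(minor, 1))  # 115.5 23.1
```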

```python
# Reload the image for measurement; predict_large_image reads it internally
# and does not return it
fn = 'examples/torrey_pines.jpg'
image = si.load_image(fn)

# Convert polygons to Grain objects and measure each one
grains = si.polygons_to_grains(polygons, image=image)
for g in tqdm(grains):
    g.measure()
```

```
100%|██████████| 400/400 [00:00<00:00, 939.74it/s]
```


The following results are then saved to files whose names begin with the prefix `out_fn`:
- Grain outlines, for use in other software (GeoJSON)
- Summary table of measurements for each detected grain (CSV)
- Summary histogram of the major and minor axis lengths of the detected grains (JPG)
- Grain masks, in both machine-readable (PNG, values 0-1) and human-viewable (JPG, values 0-255) form
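
Because GeoJSON is plain JSON, the grain outlines can be read back without any GIS library. The sketch below assumes the file is a standard FeatureCollection whose features carry `Polygon` geometries in pixel coordinates; that layout is an assumption about `si.save_grains`, not something this notebook shows. `read_polygon_rings` is a hypothetical helper.

```python
import json


def read_polygon_rings(geojson_text):
    """Exterior ring of every Polygon feature in a GeoJSON FeatureCollection,
    returned as lists of (x, y) tuples. Other geometry types are skipped."""
    data = json.loads(geojson_text)
    rings = []
    for feature in data.get("features", []):
        geometry = feature.get("geometry") or {}
        if geometry.get("type") == "Polygon":
            rings.append([tuple(point) for point in geometry["coordinates"][0]])
    return rings


# Minimal in-memory example; for the saved file, pass the text of
# out_fn + '_grains.geojson' instead.
example = json.dumps({
    "type": "FeatureCollection",
    "features": [{
        "type": "Feature",
        "properties": {},
        "geometry": {"type": "Polygon",
                     "coordinates": [[[0, 0], [30, 0], [30, 30], [0, 30], [0, 0]]]},
    }],
})
rings = read_polygon_rings(example)
print(len(rings), rings[0][:2])  # 1 [(0, 0), (30, 0)]
```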

If `px_per_m` is 1, the summary data and histogram are reported in pixels. If the image scale is known, set `px_per_m` to the number of pixels per meter to report them in meters.
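
Since `px_per_m` is a linear scale, lengths convert by dividing by `px_per_m` and areas by its square. The helpers below are hypothetical and only illustrate the arithmetic with the value used in the next cell (3390 pixels per meter, about 0.295 mm per pixel); that `save_summary` converts every column exactly this way is an assumption.

```python
def px_to_m(length_px, px_per_m):
    """Length in pixels -> meters."""
    return length_px / px_per_m


def px2_to_m2(area_px, px_per_m):
    """Area in square pixels -> square meters (divide by the scale squared)."""
    return area_px / px_per_m ** 2


px_per_m = 3390.0
print(px_to_m(100, px_per_m) * 1000)   # 100 px -> ~29.5 mm
print(px2_to_m2(400, px_per_m) * 1e6)  # min_area of 400 px^2 -> ~34.8 mm^2
```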

```python
import os

# Save results
px_per_m = 3390.    # Arbitrary; change this if known!
out_dir = 'examples/auto_large'
os.makedirs(out_dir, exist_ok=True)   # make sure the output folder exists
out_fn = os.path.join(out_dir, 'torrey_pines')
# Grain shapes
si.save_grains(out_fn + '_grains.geojson', grains)
# Summary data
summary = si.save_summary(out_fn + '_summary.csv', grains, px_per_m=px_per_m)
# Summary histogram
si.save_histogram(out_fn + '_summary.jpg', summary=summary)
# Training masks: PNG with values 0-1, JPG scaled to 0-255 for viewing
si.save_mask(out_fn + '_mask.png', grains, image, scale=False)
si.save_mask(out_fn + '_mask2.jpg', grains, image, scale=True)
```