# Interactive grain detection/editing

This notebook demonstrates how to manually edit the grains detected in an image, starting from the image and (optionally) a list of previously detected grains. It uses the interface defined in the "interactions" update, which replaces the interactive method described in [Segment_every_grain.ipynb](Segment_every_grain.ipynb). The script can be run by executing each cell sequentially, but using "Run All" 

By default, this script operates on the [Torrey Pines](examples/torrey_pines.jpg) example image and a list of grains output by [auto_detection.ipynb](auto_detection.ipynb), saved as [torrey_pines_grains.geojson](examples/auto_detection/torrey_pines_grains.geojson). It saves results in [examples/interactive_edit](examples/interactive_edit).

In [1]:
import segment_anything
from segmenteverygrain import interactions as si

# Display interactive plot in separate window
%matplotlib qt

# Init optional parameters
grains = []
image = None
predictor = None



### Load data

Images can be loaded in most common formats.

SAM checkpoints can be downloaded [here](https://huggingface.co/ybelkada/segment-anything/blob/main/checkpoints/sam_vit_h_4b8939.pth).

It is also possible to start from scratch with only an image and SAM. To continue from previously detected grains, load a .geojson file generated by this notebook or by another script, such as [auto_detection.ipynb](auto_detection.ipynb).
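For a quick look at a saved grain file outside the interface, a .geojson file can be read with the standard library alone. The sketch below assumes the grains are stored as a standard GeoJSON `FeatureCollection` of `Polygon` features in pixel coordinates; that layout is an assumption, not something documented here for `si.save_grains`. It computes each polygon's area with the shoelace formula. `GEOJSON_TEXT`, `ring_area`, and `polygon_areas` are hypothetical names used only for illustration.

```python
import json

# Hypothetical two-grain FeatureCollection in pixel coordinates (assumed layout;
# the properties written by si.save_grains may differ).
GEOJSON_TEXT = """
{"type": "FeatureCollection",
 "features": [
  {"type": "Feature", "properties": {},
   "geometry": {"type": "Polygon",
                "coordinates": [[[0, 0], [10, 0], [10, 10], [0, 10], [0, 0]]]}},
  {"type": "Feature", "properties": {},
   "geometry": {"type": "Polygon",
                "coordinates": [[[20, 20], [26, 20], [20, 28], [20, 20]]]}}
 ]}
"""


def ring_area(ring):
    """Shoelace formula for a closed ring of (x, y) vertices, in px^2."""
    total = 0.0
    for (x0, y0), (x1, y1) in zip(ring, ring[1:]):
        total += x0 * y1 - x1 * y0
    return abs(total) / 2.0


def polygon_areas(geojson):
    """Area of every Polygon feature: exterior ring minus any holes."""
    areas = []
    for feature in geojson["features"]:
        geom = feature["geometry"]
        if geom["type"] != "Polygon":
            continue  # skip anything that is not a simple polygon
        exterior, *holes = geom["coordinates"]
        areas.append(ring_area(exterior) - sum(ring_area(h) for h in holes))
    return areas


collection = json.loads(GEOJSON_TEXT)
areas = polygon_areas(collection)
print(len(areas), areas)  # 2 [100.0, 24.0]
```

To inspect a real file instead, replace the `json.loads(GEOJSON_TEXT)` call with `with open(fn) as f: collection = json.load(f)`, keeping in mind that the areas are in square pixels.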

In [2]:
# Load image
fn = 'examples/torrey_pines.jpg'
image = si.load_image(fn)

In [3]:
# Load SAM for grain detection
fn = 'models/sam_vit_h_4b8939.pth'
sam = segment_anything.sam_model_registry['default'](checkpoint=fn)
predictor = segment_anything.SamPredictor(sam)
predictor.set_image(image)

In [None]:
# Load grains
fn = 'examples/auto_detection/torrey_pines_grains.geojson'
grains = si.load_grains(fn, image=image)

### Filter grains

Optionally, parts of the cell below can be used to filter the loaded grains on any of the measured properties reported in the summary CSV. Note that all spatial units are in pixels until they are converted manually. As provided, the cell applies only an eccentricity filter; the other filters are commented-out examples to modify.

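Acceptable parameters are listed below. Intensity properties are measured per color channel, so their keys carry a channel suffix (for example, `mean_intensity-0`, `mean_intensity-1`, `mean_intensity-2`):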
-   area
-   centroid
-   major_axis_length
-   minor_axis_length 
-   orientation
-   perimeter
-   max_intensity
-   mean_intensity
-   min_intensity
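The filter cell below repeats one pattern: keep the grains whose `data` values pass a test, then report how many were removed. A minimal sketch of that pattern as a reusable helper follows. `StandInGrain` is a hypothetical stand-in that models only the `.data` dictionary of a measured grain, and `apply_filters` is likewise a hypothetical helper, not part of `segmenteverygrain`.

```python
from dataclasses import dataclass, field


@dataclass
class StandInGrain:
    """Hypothetical stand-in for a measured grain; only .data is modelled."""
    data: dict = field(default_factory=dict)


def apply_filters(grains, filters):
    """Keep grains that pass every (name, predicate) filter, reporting removals."""
    for name, keep in filters:
        before = len(grains)
        grains = [g for g in grains if keep(g.data)]
        print(f"{name} filter: Removed {before - len(grains)} grains.")
    return grains


grains = [
    StandInGrain({"area": 500, "major_axis_length": 30,
                  "minor_axis_length": 20, "mean_intensity-0": 120}),
    StandInGrain({"area": 40, "major_axis_length": 9,
                  "minor_axis_length": 6, "mean_intensity-0": 90}),
    StandInGrain({"area": 800, "major_axis_length": 60,
                  "minor_axis_length": 15, "mean_intensity-0": 230}),
]

filters = [
    # Thresholds are in pixels (px^2 for area) until units are converted
    ("Minimum area", lambda d: d["area"] >= 100),
    ("Eccentricity", lambda d: d["major_axis_length"] < 2.5 * d["minor_axis_length"]),
    ("Brightness", lambda d: d["mean_intensity-0"] < 200),
]
kept = apply_filters(grains, filters)
print(f"{len(kept)} grains remaining after filters.")
```

Collecting the predicates in one list keeps the order of filters explicit, which matters because each removal count is relative to the grains left by the previous filter.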

In [None]:
# Filter grains (optional)

# Measure grains (in pixels) to calculate grain properties before filtering
for g in grains:
    g.measure()

len_before = len(grains)
print(f'{len_before} grains loaded.')

# # Remove white grains
# grains = [g for g in grains if g.data['mean_intensity-0'] < 200]
# grains = [g for g in grains if g.data['mean_intensity-1'] < 200]
# grains = [g for g in grains if g.data['mean_intensity-2'] < 200]
# print(f'Brightness filter: Removed {len_before - len(grains)} grains.')

# # Remove grains with strong color
# len_before = len(grains)
# grains = [g for g in grains if abs(g.data['mean_intensity-0'] - g.data['mean_intensity-1']) < 20]
# print(f'Color filter: Removed {len_before - len(grains)} grains.')

# Remove eccentric grains
len_before = len(grains)
grains = [g for g in grains if g.data['major_axis_length'] < 2.5 * g.data['minor_axis_length']]
print(f'Eccentricity filter: Removed {len_before - len(grains)} grains.')

print(f'{len(grains)} grains remaining after filters.')

408 grains loaded.
Eccentricity filter: Removed 29 grains.
379 grains remaining after filters.
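The eccentricity filter above compares axis lengths directly (major < 2.5 × minor). For an ellipse with semi-axes $a \ge b$, eccentricity is $e = \sqrt{1 - (b/a)^2}$, so the same cutoff can be expressed as an eccentricity threshold. The sketch below shows the conversion in both directions; the function names are hypothetical, and it assumes the measured axes behave like those of an ellipse, as they do for region-property fits.

```python
import math


def eccentricity_from_axes(major_axis_length, minor_axis_length):
    """Eccentricity of an ellipse given its full major and minor axis lengths."""
    ratio = minor_axis_length / major_axis_length
    return math.sqrt(1.0 - ratio ** 2)


def axis_ratio_from_eccentricity(e):
    """Inverse: the major/minor ratio that produces eccentricity e (0 <= e < 1)."""
    return 1.0 / math.sqrt(1.0 - e ** 2)


# A 2.5:1 axis ratio corresponds to e = sqrt(1 - 1/6.25) = sqrt(0.84)
cutoff = eccentricity_from_axes(2.5, 1.0)
print(round(cutoff, 4))                                # 0.9165
print(round(axis_ratio_from_eccentricity(cutoff), 4))  # 2.5
```

A ratio cutoff is often easier to reason about than an eccentricity cutoff, because eccentricity rises steeply as the ratio grows (a 2:1 grain already has e ≈ 0.866).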


### Interactive editing

The editing interface itself is defined in `segmenteverygrain.interactions` (imported above as `si`).

Navigation within the interface is described in the [matplotlib documentation](https://matplotlib.org/stable/users/explain/figure/interactive.html#interactive-navigation). Additional controls are:

- `Left click`: Select/unselect existing grain or place foreground prompt for grain detection
- `Shift + left click/drag`: Create or adjust box prompt for grain detection
- `Right click`: Place background prompt for grain detection
- `Middle click`: Display measurement information about the indicated grain
- `Middle click + drag`: Measure scale bar to calibrate pixels per meter
- `Control`: Hold to temporarily hide selected grains
- `Escape`: Remove all prompts and unselect all grains
- `c`: Use selection box and/or foreground/background prompts to detect a grain
- `d`: Delete selected (highlighted) grains
- `m`: Merge selected grains (must be touching)
- `z`: Delete the most recently created grain
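The click and drag controls correspond to the three kinds of prompt that `segment_anything.SamPredictor.predict` accepts: `point_coords` as an (N, 2) array of (x, y) pixel positions, `point_labels` with 1 for foreground and 0 for background, and `box` as `[x_min, y_min, x_max, y_max]`. The sketch below packs clicks and a drag into those arrays. Whether `GrainPlot` builds its prompts exactly this way is an assumption, and `build_sam_prompts` is a hypothetical helper.

```python
import numpy as np


def build_sam_prompts(foreground, background, box_corners=None):
    """Pack click/drag prompts into the arrays SamPredictor.predict expects.

    foreground / background: lists of (x, y) pixel positions.
    box_corners: optional ((x0, y0), (x1, y1)) from a drag in any direction.
    """
    points = list(foreground) + list(background)
    point_coords = np.array(points, dtype=float).reshape(-1, 2) if points else None
    # SAM convention: label 1 = foreground point, 0 = background point
    point_labels = (np.array([1] * len(foreground) + [0] * len(background))
                    if points else None)
    box = None
    if box_corners is not None:
        (x0, y0), (x1, y1) = box_corners
        # Normalize so the box is [x_min, y_min, x_max, y_max]
        box = np.array([min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)],
                       dtype=float)
    return point_coords, point_labels, box


coords, labels, box = build_sam_prompts(
    foreground=[(120, 80)], background=[(150, 95)],
    box_corners=((200, 140), (90, 60)))
print(coords.tolist(), labels.tolist(), box.tolist())
# With a predictor whose set_image() has been called, one could then run:
# masks, scores, logits = predictor.predict(
#     point_coords=coords, point_labels=labels, box=box, multimask_output=False)
```

Normalizing the box corners matters because a shift + drag can start from any corner of the rectangle.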

`px_per_m`: The ratio of pixels to meters, if known. This will be overwritten if a scale bar is measured in the interface using middle click & drag.

`scale_m`: The length in meters of a reference object. Once the reference object is measured using middle click & drag, size/area values will be converted to meters. The diagonal of the selection box will be taken to represent `scale_m` meters.
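Following the description above, the scale bar calibration reduces to simple arithmetic: the drag box diagonal in pixels, divided by `scale_m`, gives `px_per_m`. Lengths then convert by dividing by `px_per_m`, and areas by dividing by `px_per_m` squared. The sketch below illustrates this with hypothetical helper names; how the interface applies the conversion internally is an assumption.

```python
import math


def px_per_m_from_drag(start, end, scale_m):
    """Pixels per meter, treating the drag box diagonal as scale_m meters."""
    (x0, y0), (x1, y1) = start, end
    diagonal_px = math.hypot(x1 - x0, y1 - y0)
    return diagonal_px / scale_m


def to_meters(length_px=None, area_px=None, px_per_m=1.0):
    """Lengths scale by 1/px_per_m; areas scale by 1/px_per_m**2."""
    length_m = None if length_px is None else length_px / px_per_m
    area_m2 = None if area_px is None else area_px / px_per_m ** 2
    return length_m, area_m2


# A 300 x 400 px drag has a 500 px diagonal; if that spans 0.5 m, px_per_m = 1000
ppm = px_per_m_from_drag((100, 100), (400, 500), scale_m=0.5)
print(ppm)                                                   # 1000.0
print(to_meters(length_px=50, area_px=2500, px_per_m=ppm))   # (0.05, 0.0025)
```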

`image_max_size` (y, x): Images larger than this in either dimension are downscaled for display. Grain detection and other operations are still performed on the full-resolution image, but the display will not show full detail when zoomed in. This is a performance optimization: reduce this size for better responsiveness, or increase it for better visual quality when zoomed.
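One common way to fit an image inside a (y, x) limit while preserving its aspect ratio is to take the smallest of the two per-axis ratios, capped at 1 so small images are never enlarged. The sketch below shows that calculation; it is an assumption about how such downscaling typically works, not a description of the `GrainPlot` internals, and `display_scale` and `display_shape` are hypothetical names. Display coordinates map back to the full image by dividing by the returned scale.

```python
def display_scale(image_shape, max_size):
    """Factor (<= 1) that fits an image of (rows, cols) inside max_size (y, x)."""
    rows, cols = image_shape[:2]
    max_rows, max_cols = max_size
    return min(1.0, max_rows / rows, max_cols / cols)


def display_shape(image_shape, max_size):
    """Downscaled (rows, cols) for display, plus the scale factor used."""
    scale = display_scale(image_shape, max_size)
    rows, cols = image_shape[:2]
    return max(1, round(rows * scale)), max(1, round(cols * scale)), scale


print(display_shape((3000, 4000, 3), (240, 320)))  # (240, 320, 0.08)
print(display_shape((200, 300, 3), (240, 320)))    # (200, 300, 1.0)
```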

`image_alpha`: Set this to a value lower than 1 to apply a fade effect to the background image.

In [None]:
# Display interactive interface
plot = si.GrainPlot(
    grains, 
    image = image, 
    predictor = predictor,
    blit = True,
    figsize = (12, 8),              # in
    px_per_m = 3390.,               # px/m
    scale_m = 0.5,                  # m
    # image_max_size = (240, 320),  # px
    # image_alpha = 1.
)
plot.activate()

Measuring and drawing grains: 100%|██████████| 379/379 [00:02<00:00, 157.19it/s]


### Results

The following results are saved to files whose names begin with the path prefix `out_fn`:
- Grain shapes, for use elsewhere (geojson)
- Image with colorized grains and major/minor axes drawn in (jpg)
- Summary data, presenting measurements for each detected grain (csv)
- Summary histogram, representing major/minor axes of detected grains (jpg)
- Mask representations of the detected grains, in both computer-readable (png, 0-1) and human-readable (jpg, 0-255) formats
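The two mask files differ mainly in value range and compression. The sketch below assumes a binary mask (1 = grain pixel) and that the human-readable version is simply stretched to 0-255; both are assumptions about what `scale=True` does. Because JPEG is lossy, values read back from the .jpg drift slightly, so a mid-range threshold is needed to recover 0/1 labels, which is one reason the lossless .png is the better choice for computation.

```python
import numpy as np

# Hypothetical 0/1 grain mask (1 = grain pixel), as assumed for the PNG output
mask = np.array([[0, 1, 1],
                 [0, 0, 1]], dtype=np.uint8)

# Stretch to 0-255 so grains appear white in an ordinary image viewer
viewable = mask * np.uint8(255)

# JPEG is lossy, so saved values drift slightly; simulate that drift here
noise = np.array([[3, -4, 0],
                  [7, 0, -10]])
jpeg_like = np.clip(viewable.astype(int) + noise, 0, 255)

# Thresholding at mid-range recovers the original 0/1 mask
recovered = (jpeg_like >= 128).astype(np.uint8)
print(recovered.tolist())  # [[0, 1, 1], [0, 0, 1]]
```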

In [7]:
# Turn off interactive features
plot.deactivate()

# Draw the major and minor axes of each grain
plot.draw_axes()

# Retrieve unit conversion factor if scale bar selected in image
px_per_m = plot.px_per_m

In [None]:
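# Save results
import os
out_fn = './examples/interactive_edit/torrey_pines'
# Make sure the output directory exists
os.makedirs(os.path.dirname(out_fn), exist_ok=True)
# Grain shapes
si.save_grains(out_fn + '_grains.geojson', grains)
# Grain image
plot.savefig(out_fn + '_grains.jpg')
# Summary data (display() shows the preview even though it is not the last line of the cell)
summary = si.save_summary(
    out_fn + '_summary.csv', grains, px_per_m=px_per_m)
display(summary.head())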
# Summary histogram
si.save_histogram(out_fn + '_summary.jpg', summary=summary)
# Training mask
si.save_mask(out_fn + '_mask.png', grains, image, scale=False)
si.save_mask(out_fn + '_mask2.jpg', grains, image, scale=True)