# Interactive grain detection/editing

```python
import matplotlib.pyplot as plt
import segment_anything
import segmenteverygrain.interactions as si
from tqdm import tqdm

# Use the Qt backend so the figure opens in its own interactive window
%matplotlib qt
```

### Load data

Images can be loaded in most common formats, such as JPEG and PNG.

SAM checkpoints can be downloaded [here](https://huggingface.co/ybelkada/segment-anything/blob/main/checkpoints/sam_vit_h_4b8939.pth). This notebook uses the ViT-H checkpoint, `sam_vit_h_4b8939.pth`.

You can start from scratch with only an image and SAM. To continue from previously detected grains, load a .geojson file generated by this notebook or by another script.
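Before loading a saved grain file, it can help to sanity-check it. The sketch below is an illustration, not part of segmenteverygrain: it assumes the file is a standard GeoJSON `FeatureCollection` whose features are 2D `Polygon` geometries in pixel coordinates, and uses only the standard library to count the grains and compute each outline's area with the shoelace formula. The helper names are hypothetical.

```python
import json


def ring_area(ring):
    """Area enclosed by a ring of (x, y) points, via the shoelace formula."""
    total = 0.0
    # Pair each vertex with the next one, wrapping around to the first
    for (x0, y0), (x1, y1) in zip(ring, ring[1:] + ring[:1]):
        total += x0 * y1 - x1 * y0
    return abs(total) / 2.0


def polygon_areas(geojson):
    """Area of each Polygon feature (outer ring minus any holes).

    Assumes a GeoJSON FeatureCollection; non-Polygon geometries are skipped.
    """
    areas = []
    for feature in geojson.get("features", []):
        geometry = feature.get("geometry") or {}
        if geometry.get("type") != "Polygon":
            continue
        outer, *holes = geometry["coordinates"]
        areas.append(ring_area(outer) - sum(ring_area(h) for h in holes))
    return areas


def load_polygon_areas(path):
    """Hypothetical helper: read a .geojson file and return its polygon areas."""
    with open(path) as f:
        return polygon_areas(json.load(f))
```

For example, `len(load_polygon_areas('examples/auto/torrey_pines_grains.geojson'))` would give the number of polygon grains in that file, if it follows the assumed layout.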

```python
# Load image
fn = 'examples/torrey_pines.jpeg'
image = si.load_image(fn)
```

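```python
# Load SAM ('default' is the ViT-H model, matching this checkpoint)
fn = 'sam_vit_h_4b8939.pth'
sam = segment_anything.sam_model_registry['default'](checkpoint=fn)
predictor = segment_anything.SamPredictor(sam)
# Compute the image embedding once; every later prompt reuses it
predictor.set_image(image)
```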

```python
# Load grains detected previously
fn = 'examples/auto/torrey_pines_grains.geojson'
grains = si.load_grains(fn)
# Use this instead to start from scratch
# grains = []
```

### Filter grains

Optionally, the cell below filters the loaded grains before display, based on any of the properties reported in the summary CSV. Available properties are:
-   area
-   centroid
-   major_axis_length
-   minor_axis_length 
-   orientation
-   perimeter
-   max_intensity
-   mean_intensity
-   min_intensity
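As a worked example of the unit conversion used in the cell below: an area threshold in m² becomes a pixel threshold by multiplying by `px_per_m` squared, so 1 cm² (1e-4 m²) at 3390 px/m is 1e-4 × 3390² ≈ 1149 px². The sketch below applies several thresholds at once. It works on plain dictionaries standing in for each grain's `g.data`, assumes lengths and areas there are in pixels, and its helpers (`passes`, `filter_grains`) are hypothetical, not part of segmenteverygrain. Its bounds are inclusive, whereas the cell below uses a strict `>`.

```python
def m2_to_px2(area_m2, px_per_m):
    """Convert an area threshold from square meters to square pixels."""
    return area_m2 * px_per_m ** 2


def m_to_px(length_m, px_per_m):
    """Convert a length threshold from meters to pixels."""
    return length_m * px_per_m


def passes(data, bounds):
    """True if every bounded property exists and lies in [lo, hi].

    bounds: {property_name: (lo, hi)}; None means no limit on that side.
    """
    for name, (lo, hi) in bounds.items():
        value = data.get(name)
        if value is None:
            return False
        if (lo is not None and value < lo) or (hi is not None and value > hi):
            return False
    return True


def filter_grains(grain_data, bounds):
    """Keep the property dicts (stand-ins for g.data) that pass all bounds."""
    return [data for data in grain_data if passes(data, bounds)]


# Example: keep grains of at least 1 cm^2 whose major axis is at most 5 cm
px_per_m = 3390.0
bounds = {
    "area": (m2_to_px2(1e-4, px_per_m), None),             # ~1149 px^2
    "major_axis_length": (None, m_to_px(0.05, px_per_m)),  # ~169.5 px
}
```

In the notebook, the filtering line would then become `grains = [g for g in grains if passes(g.data, bounds)]`, run after `g.measure(...)` has filled in `g.data`.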

```python
# Filter grains (optional)

# Measure each grain's properties (area, axes, intensity, ...)
for g in tqdm(grains):
    g.measure(image=image)

# Convert the area threshold from meters to pixels (also optional)
min_area = 0.                             # m^2
px_per_m = 3390.                          # px/m
min_area_px = min_area * px_per_m ** 2    # px^2

# Keep only grains larger than the threshold (areas are in px^2)
grains = [g for g in grains if g.data['area'] > min_area_px]
```


### Interactive editing

The editing interface itself is defined in `segmenteverygrain.interactions` (imported above as `si`).

Navigation within the interface is described in the [matplotlib documentation](https://matplotlib.org/stable/users/explain/figure/interactive.html#interactive-navigation). Additional controls are:

- `Left click`: Select/deselect an existing grain, or place a foreground prompt for grain detection
- `Shift + left click/drag`: Create or adjust a box prompt for grain detection
- `Right click`: Place a background prompt for grain detection
- `Middle click`: Display measurement information about the clicked grain
- `Middle click + drag`: Measure a scale bar to calibrate pixels per meter
- `Control`: Hold to temporarily hide selected grains
- `Escape`: Remove all prompts and deselect all grains
- `c`: Use the box prompt and/or foreground/background prompts to detect a grain
- `d`: Delete selected (highlighted) grains
- `m`: Merge selected grains (they must be touching)
- `z`: Delete the most recently created grain
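When you press `c`, the clicks and box become prompts for SAM. The sketch below shows one way such prompts map onto the `segment_anything` predictor API; it is an illustration under that assumption, not the interface's own code, and `build_sam_prompts` is a hypothetical helper. Foreground and background clicks become `point_coords`, labeled 1 and 0 in `point_labels`, and the box is an `[x0, y0, x1, y1]` array in full-resolution pixel coordinates.

```python
import numpy as np


def build_sam_prompts(foreground, background, box=None):
    """Hypothetical helper: keyword arguments for SamPredictor.predict.

    foreground, background: lists of (x, y) click positions in pixels.
    box: optional (x0, y0, x1, y1) in pixels, dragged in any direction.
    """
    points = list(foreground) + list(background)
    kwargs = {"multimask_output": False}  # ask SAM for a single mask
    if points:
        kwargs["point_coords"] = np.array(points, dtype=float)
        # SAM convention: 1 = foreground point, 0 = background point
        kwargs["point_labels"] = np.array([1] * len(foreground) + [0] * len(background))
    if box is not None:
        x0, y0, x1, y1 = box
        # Normalize corner order so the box is always [left, top, right, bottom]
        kwargs["box"] = np.array(
            [min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)], dtype=float
        )
    return kwargs
```

With the `predictor` created earlier, `masks, scores, logits = predictor.predict(**build_sam_prompts([(812, 455)], [], box=(780, 420, 850, 500)))` would return one mask, since `multimask_output=False`; the coordinates here are made up.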

`px_per_m`: The number of pixels per meter, if known. This value is replaced if you measure a scale bar in the interface with middle click + drag.

`scale_m`: The real-world length of a scale bar in the image, in meters. Measure the bar with middle click + drag to calibrate size and area measurements.
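The calibration reduces to a ratio: the scale bar's length in pixels divided by `scale_m`. For instance, if a 0.5 m bar measured 1695 px, `px_per_m` would be 3390, the value used in this notebook. A minimal sketch, assuming a straight bar measured between two endpoints in full-resolution pixel coordinates; the helper names are hypothetical, not segmenteverygrain functions.

```python
import math


def px_per_m_from_scale_bar(x0, y0, x1, y1, scale_m):
    """Hypothetical helper: pixels per meter from a scale bar's endpoints.

    (x0, y0), (x1, y1): bar endpoints in pixels; scale_m: true length in m.
    """
    length_px = math.hypot(x1 - x0, y1 - y0)
    if scale_m <= 0 or length_px == 0:
        raise ValueError("scale bar must have a positive length")
    return length_px / scale_m


def length_px_to_m(length_px, px_per_m):
    """Convert a length (e.g. major_axis_length) from pixels to meters."""
    return length_px / px_per_m


def area_px_to_m2(area_px, px_per_m):
    """Convert an area from square pixels to square meters."""
    return area_px / px_per_m ** 2
```

Lengths scale with `px_per_m` and areas with its square, which is why the filter cell above multiplies an area threshold by `px_per_m ** 2`.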

`image_max_size` (y, x): Images larger than this in either dimension are downscaled for display, as a performance optimization. Grain detection and other operations still run on the full-resolution image, but you won't be able to zoom in as far. Reduce this size for better performance; increase it for better visual quality.
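One way to apply such a cap is a single uniform scale factor, which keeps the aspect ratio intact. This is a sketch of the idea under that assumption, not necessarily how `GrainPlot` computes its display size, and the function names are hypothetical.

```python
def display_scale(image_shape, max_size):
    """Hypothetical helper: uniform downscale factor (<= 1) to fit max_size.

    image_shape, max_size: (rows, cols), i.e. (y, x), in pixels.
    """
    (h, w), (max_h, max_w) = image_shape[:2], max_size
    # Never upscale; otherwise shrink by whichever dimension is tighter
    return min(1.0, max_h / h, max_w / w)


def display_shape(image_shape, max_size):
    """Hypothetical helper: (rows, cols) of the downscaled display image."""
    s = display_scale(image_shape, max_size)
    return round(image_shape[0] * s), round(image_shape[1] * s)
```

With the `image_max_size=(2160, 4096)` used below, a 4320 × 8192 image would be shown at half resolution (2160 × 4096), while anything already within the limits is displayed unchanged.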

`image_alpha`: Set this to a value below 1 to fade the background image.

```python
# Display interactive interface
plot = si.GrainPlot(
    grains,
    image=image,
    predictor=predictor,
    blit=True,                    # blitting: faster redraws while editing
    figsize=(12, 8),              # in
    px_per_m=3390.,               # px/m
    scale_m=0.5,                  # m
    image_max_size=(2160, 4096),  # px
    image_alpha=1.)
plot.activate()
```


### Results

The following results are saved, using `out_fn` as the filename prefix:
- Grain outlines, for reuse in this or other scripts (GeoJSON)
- Image with colorized grains and their major/minor axes drawn in (JPG)
- Summary data with measurements for each detected grain (CSV)
- Summary histogram of the major/minor axis lengths of detected grains (JPG)
- Mask representations of the detected grains, in both machine-readable (PNG, values 0-1) and human-readable (JPG, values 0-255) formats
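Once saved, the summary CSV can be analyzed with any tool. The standard-library sketch below computes median major and minor axis lengths. It assumes the file has a header row with `major_axis_length` and `minor_axis_length` columns (names taken from the property list above; check them, and their units, against your own file), and its helper names are hypothetical.

```python
import csv
import statistics


def axis_medians(rows, columns=("major_axis_length", "minor_axis_length")):
    """Median of each named column over an iterable of CSV row dicts.

    Missing or empty cells are skipped; a column with no values gives None.
    """
    rows = list(rows)
    medians = {}
    for col in columns:
        values = [float(r[col]) for r in rows if r.get(col) not in (None, "")]
        medians[col] = statistics.median(values) if values else None
    return medians


def summarize_csv(path):
    """Hypothetical helper: read a summary CSV and return axis-length medians."""
    with open(path, newline="") as f:
        return axis_medians(csv.DictReader(f))
```

For this notebook, that would be `summarize_csv('examples/interactive/torrey_pines_summary.csv')`.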

```python
plot.deactivate()

# Get updated grains (with proper scale) and unit conversion (if changed)
grains = plot.grains
px_per_m = plot.px_per_m

# Add grain axes to plot
for grain in tqdm(grains):
    grain.draw_axes(plot.ax)
```


```python
# Save results
out_fn = 'examples/interactive/torrey_pines'
# Grain shapes
si.save_grains(out_fn + '_grains.geojson', grains)
# Grain image
plot.savefig(out_fn + '_grains.jpg')
# Summary data
si.save_summary(out_fn + '_summary.csv', grains, px_per_m=px_per_m)
# Summary histogram
si.save_histogram(out_fn + '_summary.jpg', grains, px_per_m=px_per_m)
# Masks: 0-1 PNG for training, 0-255 JPG for viewing
si.save_mask(out_fn + '_mask.png', grains, image, scale=False)
si.save_mask(out_fn + '_mask2.jpg', grains, image, scale=True)
```