<h1>SegmentEveryGrain -- Hawaii project</h1>

Details go here

<h3>To-do</h3>

- Break interactive elements into segmenteverygrain.interactions
- Create tests for interactive elements?
- Sort out grain_data, all_grains, new_grain_data, etc.
- Performance improvements?

In [None]:
# Setup ---

# Imports
import cv2
import keras.saving
import keras.utils
import matplotlib.pyplot as plt
import numpy as np
import segment_anything
import segmenteverygrain
import segmenteverygrain.interactions as si

# import tensorflow as tf
# from tensorflow.python.platform.build_info import build_info
# for k, v in build_info.items():
#     print(f'{k}:\t{v}')

# Plotting
%matplotlib widget
FIGSIZE = (12, 8)

In [None]:
# Unet segmentation ---

# Load Unet model
model = keras.saving.load_model(
    "pretrained_model.keras",
    custom_objects={'weighted_crossentropy': segmenteverygrain.weighted_crossentropy},
)

# Load image for analysis
fname = './DJI_0605010.jpg'
image = np.array(keras.utils.load_img(fname))

# Generate Unet prompts
image_pred = segmenteverygrain.predict_image(image, model, I=256)
labels, coords = segmenteverygrain.label_grains(image, image_pred, dbs_max_dist=20.0)

# Display Unet prompts for verification
fig, ax = plt.subplots(figsize=FIGSIZE)
ax.set_aspect('equal')
ax.imshow(image_pred)
prompt_xy = np.array(coords)  # (N, 2) array of (x, y) prompt locations
ax.scatter(prompt_xy[:, 0], prompt_xy[:, 1], c='k')
ax.set_xticks([])
ax.set_yticks([])
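
segmenteverygrain.label_grains turns the Unet output into the point prompts plotted above; its dbs_max_dist parameter suggests a DBSCAN clustering step. For rough intuition only, here is a simplified stand-in that thresholds a grain-probability map, labels connected blobs, and places one prompt at each blob's centroid. This is not the library's algorithm: centroid_prompts, threshold, and min_pixels are hypothetical, and which channel of image_pred holds the grain-interior probability is an assumption you would need to check.

```python
import numpy as np
from scipy import ndimage


def centroid_prompts(grain_prob, threshold=0.5, min_pixels=20):
    """Return an (N, 2) array of (x, y) prompts, one per grain-interior blob.

    Hypothetical, simplified stand-in; NOT how segmenteverygrain.label_grains works.
    grain_prob is a 2-D array of grain-interior probabilities.
    """
    binary = np.asarray(grain_prob) > threshold
    labeled, n_blobs = ndimage.label(binary)  # 4-connected components by default
    if n_blobs == 0:
        return np.empty((0, 2))
    ids = np.arange(1, n_blobs + 1)
    sizes = ndimage.sum(binary, labeled, ids)  # pixel count per blob
    keep = ids[sizes >= min_pixels]  # drop tiny specks
    if keep.size == 0:
        return np.empty((0, 2))
    centers = ndimage.center_of_mass(binary, labeled, keep)  # list of (row, col)
    # Swap to (x, y) = (col, row), the order used by the scatter plot above
    return np.array([(col, row) for row, col in centers])
```

The centroid of a curved or non-convex blob can land outside the blob itself, which is one reason a real prompt generator needs something more careful than this.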

In [None]:
# SAM segmentation ---

# Close Unet figure
plt.close()

# Load and apply Segment Anything model
sam = segment_anything.sam_model_registry["default"](checkpoint="./sam_vit_h_4b8939.pth")
# TODO: Figure out TensorRT to accelerate this w/GPU; currently it's actually slower on CUDA
# sam.to(device='cuda')
# TODO: Separate this function into smaller chunks (plotting, mask, etc)
# TODO: Choose min_area by image size? Do unit conversion from pixels first?
all_grains, labels, mask_all, grain_data, fig, ax = segmenteverygrain.sam_segmentation(
    sam, image, image_pred, coords, labels, 
    min_area=400.0, plot_image=False, remove_edge_grains=False, remove_large_objects=False
)

# Set up predictor for interaction plot
# TODO: sam_segmentation already builds a SamPredictor and calls set_image internally; reuse it instead of recomputing the image embedding
predictor = segment_anything.SamPredictor(sam)
predictor.set_image(image)
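
The predictor set up above is what turns a single point into a mask. Below is a minimal sketch of one point-prompted query through segment_anything's SamPredictor.predict, assuming predictor.set_image has already been called as in the cell above. mask_from_point is a hypothetical helper, and whether GrainPlot queries the predictor exactly this way is not confirmed here.

```python
import numpy as np


def mask_from_point(predictor, x, y):
    """Return the highest-scoring SAM mask for one foreground point (hypothetical helper).

    predictor: a segment_anything.SamPredictor on which set_image() was already called.
    """
    masks, scores, _low_res_logits = predictor.predict(
        point_coords=np.array([[x, y]], dtype=float),  # (N, 2) array in (x, y) pixel order
        point_labels=np.array([1]),  # 1 marks a foreground point, 0 a background point
        multimask_output=True,  # ask for several candidate masks for an ambiguous click
    )
    best = int(np.argmax(scores))  # scores are SAM's predicted IoU for each candidate
    return masks[best], float(scores[best])
```

Requesting several masks and keeping the best-scoring one is a common way to handle a click that could mean either a whole grain or part of it.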

In [None]:
# Manual editing ---

# Create and display interactive grain plot
# Pair each SAM polygon with its row of grain measurements
grains = [si.Grain(p.exterior.xy, row) for p, (_, row) in zip(all_grains, grain_data.iterrows())]
grain_plot = si.GrainPlot(grains, image=image, predictor=predictor, figsize=FIGSIZE)
grain_plot.activate()
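
The list comprehension pairs polygons with table rows using zip, which silently stops at the shorter input, so a length mismatch between all_grains and grain_data would not raise an error. A quick consistency check, sketched below, compares each polygon's shoelace area with the tabulated area. It assumes grain_data['area'] is in square pixels; check_alignment and the 20% tolerance are hypothetical choices, since a rasterized pixel count and a polygon area will never match exactly.

```python
import numpy as np


def shoelace_area(x, y):
    """Area of a simple polygon from vertex coordinates (shoelace formula).

    Works for open rings and for closed rings that repeat the first vertex.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # np.roll(..., -1) pairs each vertex with the next one, wrapping around
    return 0.5 * abs(np.dot(x, np.roll(y, -1)) - np.dot(y, np.roll(x, -1)))


def check_alignment(polygon_xys, table_areas, rtol=0.2):
    """Return indices whose polygon area differs from the table area by more than rtol.

    Hypothetical helper; raises if the two inputs differ in length.
    """
    if len(polygon_xys) != len(table_areas):
        raise ValueError(f"{len(polygon_xys)} polygons but {len(table_areas)} table rows")
    mismatched = []
    for i, ((x, y), area) in enumerate(zip(polygon_xys, table_areas)):
        if abs(shoelace_area(x, y) - area) > rtol * abs(area):
            mismatched.append(i)
    return mismatched
```

In the notebook this would be called as check_alignment([p.exterior.xy for p in all_grains], grain_data['area'].to_numpy()); an empty list suggests the rows line up.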

In [None]:
# Process manual edits ---

# Disable further interactions
grain_plot.deactivate()

# Get grain data as pd.DataFrame
new_grain_data = grain_plot.get_data()
# print(new_grain_data.head())

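# Convert from pixels to real units (TODO: scale bar is hard-coded for this image)
scale_bar_units = 1000  # length of scale bar in real units
scale_bar_pixels = 1552.77  # length of scale bar in pixels
units_per_pixel = scale_bar_units / scale_bar_pixels
for col in ['major_axis_length', 'minor_axis_length', 'perimeter']:
    new_grain_data[col] *= units_per_pixel
new_grain_data['area'] *= units_per_pixel**2  # area scales with length squared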

# Save csv, show histogram
new_grain_data.to_csv(fname[:-4] + '.csv')
fig, ax = segmenteverygrain.plot_histogram_of_axis_lengths(new_grain_data['major_axis_length']/1000, new_grain_data['minor_axis_length']/1000)

# Save mask with original image for training
import os
dirname = './images/'
os.makedirs(dirname, exist_ok=True)  # cv2.imwrite does not create missing folders
outname = dirname + fname.split('/')[-1].split('.')[-2]
rasterized_image, mask = grain_plot.get_mask()
# TODO: Remove opencv dependency?
cv2.imwrite(outname + '_mask2.png', mask)
cv2.imwrite(outname + '_verify2.png', mask*127)  # stretch class labels so the mask is visible
cv2.imwrite(outname + '_image2.png', cv2.cvtColor(image, cv2.COLOR_RGB2BGR))  # image is RGB; OpenCV writes BGR
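
Lengths (axis lengths, perimeter) scale linearly with the pixel size, but areas scale with its square. Below is a small sketch of that conversion as a function, plus the inverse that the min_area TODO in the SAM cell would need (a physical minimum area turned into a pixel threshold). to_real_units and min_area_pixels are hypothetical helpers; the column names are taken from the conversion loop above.

```python
import pandas as pd

LENGTH_COLUMNS = ['major_axis_length', 'minor_axis_length', 'perimeter']


def to_real_units(df: pd.DataFrame, units_per_pixel: float) -> pd.DataFrame:
    """Return a copy of df with pixel lengths and areas converted to real units."""
    out = df.copy()
    for col in LENGTH_COLUMNS:
        if col in out.columns:
            out[col] = out[col] * units_per_pixel  # lengths scale linearly
    if 'area' in out.columns:
        out['area'] = out['area'] * units_per_pixel ** 2  # areas scale quadratically
    return out


def min_area_pixels(min_area_units: float, units_per_pixel: float) -> float:
    """Convert a minimum area in (real units)^2 into a threshold in square pixels."""
    return min_area_units / units_per_pixel ** 2
```

Applying to_real_units to a frame that was already converted in place would scale it twice, so it is meant as a replacement for the in-place loop rather than an addition to it.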

<h2>Results</h2>

At the end of this notebook, you should have the following outputs, where "fname" is the name of the input image (without file extension):

- Cell output in this notebook (can be saved manually):
    - A plot of the Unet prediction overlaid with the point prompts generated for Segment Anything. If the prompts look messy, consider fine-tuning the Unet model.
    - A colorized plot of detected grains, after manual edits.
    - A histogram of grain axis lengths for quick verification.
- ./:
    - fname.csv: A table of identified grains and their measurements (after manual edits), with lengths and areas converted from pixels to real units
- ./images/:
    - fname_image2.png: A copy of the original input image
    - fname_mask2.png: A map of grains, boundaries, and background for fine-tuning the Unet model
    - fname_verify2.png: A human-readable version of fname_mask2.png for easy verification
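
The PNG files above are written with OpenCV. Because keras.utils.load_img already relies on Pillow, one possible answer to the "Remove opencv dependency?" TODO is to save them with Pillow instead, sketched below. save_training_pngs and its suffix argument are hypothetical; the sketch assumes mask is a 2-D array of small class labels (0, 1, 2) and image is the RGB uint8 array returned by load_img.

```python
import os

import numpy as np
from PIL import Image


def save_training_pngs(outname, image, mask, suffix=''):
    """Write <outname>_image/_mask/_verify PNGs with Pillow instead of OpenCV.

    Hypothetical helper. Assumes image is an RGB uint8 array (H, W, 3) and mask is
    a 2-D array of class labels 0, 1, 2.
    """
    os.makedirs(os.path.dirname(outname) or '.', exist_ok=True)
    mask = np.asarray(mask, dtype=np.uint8)
    # Pillow interprets 3-channel arrays as RGB, so unlike cv2.imwrite no BGR swap is needed
    Image.fromarray(np.asarray(image, dtype=np.uint8)).save(f'{outname}_image{suffix}.png')
    Image.fromarray(mask).save(f'{outname}_mask{suffix}.png')
    # Stretch labels 0/1/2 to 0/127/254 so the classes are distinguishable by eye
    Image.fromarray(mask * 127).save(f'{outname}_verify{suffix}.png')
```

Calling save_training_pngs(outname, image, mask, suffix='2') would reproduce the file names listed above.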
