<a href="https://colab.research.google.com/github/neyhartj/BerryBox/blob/master/deploy_BerryBox_FCNSegmentationModel_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BerryBox Image Analysis Pipeline

Use this notebook for production applications of a trained fully convolutional network (FCN) to measure quality parameters on individual berries in images.

## Setup

Prior to running the pipeline, make sure you have completed these setup steps:

1. Clone the [`BerryBox` GitHub repository](https://github.com/neyhartj/BerryBox) and place it on Google Drive.
2. Place a copy of the trained model in BerryBox/fcn_model_training/productionModel. A trained model is freely available [here]().
3. Add images to be fed into the pipeline to BerryBox/imagesToSegment.

In [None]:
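# Mount Google Drive - this will ask you to authorize.
from google.colab import drive
drive.mount('/content/drive/')
# Root folder of "My Drive" (this rebinds the name `drive` from the module to a path string)
drive = "/content/drive/MyDrive"

# Show the GPU attached to this runtime, if any
!nvidia-smi

# Install PlantCV
!pip install plantcv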

Mounted at /content/drive/
Wed Jul  6 13:45:04 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   59C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+------------------------------------------------------------


# **Pipeline Settings**

Edit the settings below before running the production pipeline.

## Materials

The **materials** object will be used to store information about the features in the images. For now, the only materials are berries and non-berries (i.e. background). For each material, add the following (in order):



> **name** (str) - The name for the material. This is pretty arbitrary, but it will be
  used to label output folders and images.  
  **input_rgb_vals** (list) - The RGB values of the material in the input mask image.  
  **output_val** (int) - The greyscale value of the mask when you output the images.
  This is arbitrary, but every material should have its own output value
  so they can be differentiated.  
  **confidence_threshold** (float) - The minimum confidence for a pixel to be labeled as this material. The model outputs a confidence value between 0 and 1 for every pixel and every material, and pixels at or above the threshold are assigned to that material. The materials below use 0.5; the lower this number, the more pixels are labeled as the material.
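
A minimal NumPy sketch of how a confidence threshold turns per-pixel confidences into a mask. It assumes the confidences are already in [0, 1] (the pipeline applies a sigmoid to the model output first); `threshold_mask` is a hypothetical helper and the confidence values are made up for illustration.

```python
import numpy as np

def threshold_mask(confidence, threshold, output_val=255):
    """Return a uint8 mask: output_val where confidence >= threshold, else 0.

    Hypothetical helper for illustration; np.where builds a new array,
    so the input confidences are left unchanged.
    """
    return np.where(confidence >= threshold, output_val, 0).astype(np.uint8)

# Made-up confidences for a 2 x 3 image (not real model output)
conf = np.array([[0.10, 0.45, 0.55],
                 [0.30, 0.50, 0.90]])

default_mask = threshold_mask(conf, 0.5)
lenient_mask = threshold_mask(conf, 0.3)

# Lowering the threshold labels more pixels as the material
print(int((default_mask > 0).sum()), int((lenient_mask > 0).sum()))  # 3 5
```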

## Other settings

**materials_toprint**: Names of the materials (as defined in **materials**) to measure and save.

**proj_dir**: The path to the project directory (i.e., "BerryBox").

**inference_dir**: Folder with images to run through the pipeline. Defaults to "imagesToSegment".

**model_path**: The path to the trained FCN model. It should be located in the BerryBox/fcn_model_training/productionModel directory.

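**normalization_path**: The path to the normalization data file that was saved during model training.

**cc_img_path**: The path to the reference image of the color card; its colors are the target for color correction.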

**region_properties**: A list of region properties to extract for the berries. See the [regionprops documentation](https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.regionprops) for details.

**max_area**: Maximum object area (in pixels); larger objects are discarded.

**min_area**: Minimum object area (in pixels); smaller objects are discarded.

**save_segmentation_image** (boolean): Whether the pipeline should save a copy of each image with the detected objects overlaid and boxed.
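
The area limits are compared against pixel counts, before any conversion to physical units. The pipeline then scales areas by the squared cm-per-pixel factor it derives from the color card, treating each card chip as a square about 1.1581 cm on a side. A minimal sketch of that order of operations; `cm_per_pixel_from_chips` and `filter_and_convert` are hypothetical helpers, and the chip sizes and areas are made up.

```python
import numpy as np

CHIP_SIDE_CM = 1.1581  # approximate color-card chip side length used by the pipeline

def cm_per_pixel_from_chips(chip_widths_px, chip_heights_px):
    """Average chip width and height (pixels) and convert to cm per pixel."""
    pixel_per_cm = np.mean([np.mean(chip_widths_px) / CHIP_SIDE_CM,
                            np.mean(chip_heights_px) / CHIP_SIDE_CM])
    return 1.0 / pixel_per_cm

def filter_and_convert(areas_px, min_area, max_area, cm_per_pixel):
    """Keep areas within [min_area, max_area] pixels, then convert them to cm^2."""
    areas_px = np.asarray(areas_px, dtype=float)
    keep = (areas_px >= min_area) & (areas_px <= max_area)
    return areas_px[keep] * cm_per_pixel ** 2

# Illustrative numbers only: chips measuring 50 px, three candidate objects
scale = cm_per_pixel_from_chips([50, 50], [50, 50])
kept_cm2 = filter_and_convert([150, 2000, 16000], min_area=200, max_area=15000,
                              cm_per_pixel=scale)
print(len(kept_cm2), round(float(kept_cm2[0]), 3))  # 1 1.073
```

Filtering first keeps `min_area` and `max_area` meaningful even for images where no color card is found and no scale is available.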

In [None]:
#############################
#### Set parameters #########
#############################

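class Material:
  """A class of pixels the model predicts (e.g. berry or background)."""

  def __init__(self, name, input_rgb_vals, output_val, confidence_threshold=0.5):
    self.name = name
    self.input_rgb_vals = input_rgb_vals
    self.output_val = output_val
    # Minimum sigmoid confidence (0 to 1) for a pixel to be labeled as this material
    self.confidence_threshold = confidence_threshold

# Create a list of materials so we can iterate through it.
# The order must match the output channels of the trained model.
materials = [
             Material("background", [0,0,0], 0, 0.5),
             Material("berry", [255,255,255], 255, 0.5),
             ]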


# What material would you like to make inferences for?
materials_toprint = ["berry"]

# Project directory
# IMPORTANT - ALL DIRECTORIES NEED TO END IN A /
proj_dir = drive + "/ImageAnalysis/BerryBox/"

# Directory of images to segment
inference_dir = proj_dir + "imagesToSegment/"


# Path to the trained FCN model
# This file should end in ".pth"
# model_path = proj_dir + "productionModel/berryBox_fcn_production_v1_model3.pth"
model_path = proj_dir + "productionModel/berryBox_fcn_developmental_v4_model3.pth"


# Normalization data path
normalization_path = proj_dir + "productionModel/berryBox_fcn_developmental_v4_model_normalization_param.txt"

# Path to the blank image with the color card - this is used for color correction
cc_img_path = proj_dir + "resources/color_checker_standard1.JPG"

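# Properties for regionprops
# Newer scikit-image releases name the axis lengths "axis_major_length" and
# "axis_minor_length"; older releases use "major_axis_length" and "minor_axis_length".
# Use the names that match the installed version.
# region_properties = ["area", "axis_major_length", "axis_minor_length", "eccentricity"] # For local runs
region_properties = ["area", "major_axis_length", "minor_axis_length", "eccentricity"] # For colab runs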

# Maximum object area (in pixels) to keep
max_area = 15000
# Minimum object area (in pixels) to keep
min_area = 200

# Should the pipeline save segmented images
save_segmentation_image = True
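
Because the paths above are built by string concatenation, a missing `/` runs two folder names together. A sketch of the same layout built with `os.path.join`, which adds separators only where needed; `build_paths` is a hypothetical helper, and the folder names are the ones used in the cell above.

```python
import os

def build_paths(drive_root):
    """Return the main BerryBox paths (hypothetical helper for illustration)."""
    proj_dir = os.path.join(drive_root, "ImageAnalysis", "BerryBox")
    return {
        "proj_dir": proj_dir,
        "inference_dir": os.path.join(proj_dir, "imagesToSegment"),
        "cc_img_path": os.path.join(proj_dir, "resources", "color_checker_standard1.JPG"),
        "output_directory": os.path.join(proj_dir, "output"),
    }

# A trailing slash on the root makes no difference;
# both print /content/drive/MyDrive/ImageAnalysis/BerryBox/imagesToSegment
print(build_paths("/content/drive/MyDrive")["inference_dir"])
print(build_paths("/content/drive/MyDrive/")["inference_dir"])
```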

# **Image Segmentation**

Run the image inference pipeline. For each image, this pipeline will:
1. Read in the image, decode the QR code, and find the color card, which is used to set the pixel scale and correct color.
2. Run the image through the prediction model.
3. Threshold the model output to create a mask for each material.
4. Identify individual objects in the mask.
5. Measure object properties and save the results.
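
Step 4 relies on connected-component labeling (the pipeline calls `skimage.measure.label`, which by default treats diagonal neighbors in a 2D image as connected). A dependency-free sketch of the idea, assuming a binary mask where nonzero pixels are foreground; `label_components` is a hypothetical helper, not part of the pipeline.

```python
from collections import deque

def label_components(mask):
    """Label 8-connected foreground regions in a 2D list-of-lists mask.

    Background (0) pixels get label 0; regions are numbered 1, 2, ... in
    row-major order of their first pixel. Returns (labels, number_of_regions).
    """
    rows, cols = len(mask), len(mask[0])
    labels = [[0] * cols for _ in range(rows)]
    next_label = 0
    for r in range(rows):
        for c in range(cols):
            if mask[r][c] and not labels[r][c]:
                next_label += 1
                labels[r][c] = next_label
                queue = deque([(r, c)])
                # Breadth-first flood fill over the 8 neighbors
                while queue:
                    y, x = queue.popleft()
                    for dy in (-1, 0, 1):
                        for dx in (-1, 0, 1):
                            ny, nx = y + dy, x + dx
                            if (0 <= ny < rows and 0 <= nx < cols
                                    and mask[ny][nx] and not labels[ny][nx]):
                                labels[ny][nx] = next_label
                                queue.append((ny, nx))
    return labels, next_label

# Two made-up "berries": a 2x2 blob and a diagonal pair (touching corners count as connected)
mask = [[255, 255, 0, 0],
        [255, 255, 0, 0],
        [0,   0,   0, 255],
        [0,   0,   255, 0]]
labels, n = label_components(mask)
print(n)  # 2
```

Because any touching foreground pixels merge into one region, berries that touch in the image are counted as a single object; the pipeline notes a watershed or erosion step as a possible fix but currently skips it.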

## Import packages and load a specific model

In [None]:
# Load relevant packages
import os
import torch
import torch.nn as nn
import numpy as np
import cv2 as cv
from plantcv import plantcv as pcv
from torchvision.models.segmentation.fcn import FCNHead
from torchvision.models.segmentation import fcn_resnet101
import torchvision.transforms as T
from PIL import Image
from scipy import ndimage as ndi
import pandas as pd
from tqdm import tqdm
from skimage.color import rgb2gray, label2rgb
from skimage.transform import rescale, resize, downscale_local_mean
from skimage.morphology import binary_erosion
from skimage.measure import label, regionprops_table, regionprops
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt

# Name of the model group
current_model_name = os.path.basename(model_path).replace(".pth", "")
model_group = current_model_name.split("_model")[0]


## Specify directories
# Directory to store output files
output_directory = os.path.join(proj_dir, "output")
# Directory for the segmented images
seg_output_directory = os.path.join(output_directory, "segmented_images")

# Create these directories if they do not already exist (existing files are kept)
os.makedirs(output_directory, exist_ok=True)
os.makedirs(seg_output_directory, exist_ok=True)

# How many materials?
num_materials = len(materials)

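# Build the FCN-ResNet101 architecture (without pretrained segmentation weights)
# and replace the classifier head so it has one output channel per material
model = fcn_resnet101(pretrained=False)
model.classifier = FCNHead(2048, num_materials)
# Use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Load the trained weights specified above
# (strict=False skips any keys that do not match the architecture)
model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
# Put the model in evaluation mode (fixed batch-norm statistics, no dropout)
model.eval()

print("Model loaded!")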


## Load the normalization information ##
# Read the values saved in the model training log. Each line is
# "<name>\t<value>", where <value> is either a number or a tensor string
# such as "tensor([0.1, 0.2, 0.3])".
norm_params = {}
with open(normalization_path, "r") as file:
    for line in file:
        # Strip off the newline; separate by tab
        tabs = line.strip().split("\t")
        if len(tabs) < 2:
            continue  # skip blank or malformed lines

        # The first field is the variable name
        var_name = tabs[0]

        # Parse the second field
        if tabs[1].startswith("tensor"):
            # Extract the numbers between the brackets and convert to a tensor
            var_value = tabs[1].split("[")[1].split("]")[0].split(", ")
            var_value = torch.tensor(np.array([float(x) for x in var_value]))
        else:
            var_value = float(tabs[1])

        # Store the value under its name
        norm_params[var_name] = var_value

# Assign the mean, standard deviation, and image size used during training
mean = norm_params["normalization_mean"]
std = norm_params["normalization_std"]
newW = int(norm_params["image_scale_newW"])
newH = int(norm_params["image_scale_newH"])

# Find the color card in the source file
# Read in the color checker standard file
cc_img = np.array(Image.open(cc_img_path).resize((newW, newH)), dtype = "uint8")
# Find the color card in the color checker standard file
df1, start, space = pcv.transform.find_color_card(rgb_img = cc_img)
# Create a mask
# Use these outputs to create a labeled color card mask
target_mask = pcv.transform.create_color_card_mask(rgb_img = cc_img, radius = 25, start_coord = start, 
                                                   spacing = space, ncols = 4, nrows = 6)
# Get the color matrix of the target (reference) color card
target_headers, target_matrix = pcv.transform.get_color_matrix(cc_img, target_mask)


# Load a QR code detector
detector = cv.QRCodeDetector()

# Search an image tile by tile for a QR code.
# sr/sc are the starting rows/columns of the tiles and gh/gw the tile height/width.
# Returns the decoded text, the QR code corner points, and the column (i) and row (j)
# index of the last tile searched. The text is "" if no QR code was found.
def find_qr_grids(img, sr, sc, gh, gw):
    i, j = None, None
    for j, r0 in enumerate(sr):
        for i, c0 in enumerate(sc):
            r1 = r0 + gh
            c1 = c0 + gw

            # Crop the image to the current tile
            img2_crop = img[r0:r1, c0:c1]

            # Attempt to find and decode a QR code in this tile
            collection_id, points, _ = detector.detectAndDecode(img2_crop)

            # If the QR code is found, stop
            if collection_id != "":
                return collection_id, points, i, j

    # No QR code was found in any tile
    return "", None, i, j




Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth



Model loaded!
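
The normalization file is read line by line. A standalone sketch of that parser that returns a dictionary instead of creating global variables, assuming each line holds a name, a tab, and either a plain number or a `tensor([...])` string, as the loading cell expects. `parse_normalization_text` is a hypothetical helper, the sample text is invented, and values are returned as plain floats and lists so the sketch needs only the standard library (the notebook converts vectors to `torch.tensor`).

```python
def parse_normalization_text(text):
    """Parse 'name<TAB>value' lines into a dict of floats or lists of floats."""
    params = {}
    for line in text.splitlines():
        if not line.strip():
            continue  # skip blank lines
        name, value = line.strip().split("\t")[:2]
        if value.startswith("tensor"):
            # e.g. "tensor([0.1, 0.2, 0.3])" -> [0.1, 0.2, 0.3]
            inner = value.split("[")[1].split("]")[0]
            params[name] = [float(x) for x in inner.split(",")]
        else:
            params[name] = float(value)
    return params

# Invented example in the expected format (not real training output)
sample = (
    "normalization_mean\ttensor([120.5, 110.0, 90.25])\n"
    "normalization_std\ttensor([40.0, 38.5, 35.0])\n"
    "image_scale_newW\t1024\n"
    "image_scale_newH\t768\n"
)
params = parse_normalization_text(sample)
print(params["normalization_mean"], int(params["image_scale_newW"]))
```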


## Run the image processing pipeline

In [None]:
## Iterate over images and run through the prediction model ##

# Directory containing the images to segment
dir_name = inference_dir
filenames = [x for x in os.listdir(dir_name) if x.upper().endswith((".JPG", ".PNG"))]
filenames.sort()
print(str(len(filenames)) + " images found.")

# Create an empty list to store region data
all_region_df = []

# Iterate over the images
# tqdm produces a progress bar
for filename in tqdm(filenames):

# # ## TESTING ##
# i = 0
# filename = filenames[i]
# ################
    
    # Open the image and make sure it has three (RGB) channels
    image = Image.open(os.path.join(dir_name, filename)).convert("RGB")

    # Use a PNG name for the outputs. The conversion happens in memory, so no extra
    # copy is written to the input folder (where a rerun would pick it up again).
    if not filename.upper().endswith(".PNG"):
        filename = os.path.splitext(filename)[0] + ".PNG"

    # Rescale the image to the size used during training
    image = image.resize((newW, newH))
    # Convert the image to a NumPy array
    # Need to use uint8 for the QR code detector and the color correction
    image = np.array(image, dtype = "uint8")

    ###
    # Find and read the QR code
    ###

    # Crop the image to speed up the QR code and color card finder
    half_newW = int(newW / 2)
    img1_resize = image[:, half_newW:newW, :]

    # Split images into a grid
    # Grid size, rows x cols
    split_grid_size = (3, 1)
    # Get the dimensions of each element of the grid
    grid_h = int(newH / split_grid_size[0])
    grid_w = int(half_newW / split_grid_size[1])
    # Starting points
    start_rows = [x for x in range(0, newH, grid_h)]
    start_cols = [x for x in range(0, half_newW, grid_w)]
    start_rows.reverse()
    start_cols.reverse()
    # Find the QR code
    # When a QR code is not found, cid is an empty string
    cid, pts, i, j = find_qr_grids(img = img1_resize, sr = start_rows, sc = start_cols, gh = grid_h, gw = grid_w)


    ###
    # Pixel size determination
    ###

    # Try to find the color card in the cropped image
    # (find_color_card raises an error if no card is detected)
    try:
        df1, start, space = pcv.transform.find_color_card(rgb_img = img1_resize)
        color_card_found = True
    except Exception:
        color_card_found = False

    # If the color card was not found, skip pixel scaling and color correction
    if color_card_found:

        # Calculate the average box (chip) width and height in pixels
        box_w = np.mean(df1['width'])
        box_h = np.mean(df1['height'])
        # We know boxes are about 1.1581 cm on each side
        # Calculate the number of pixels per cm
        pixel_per_cm = np.mean([x / 1.1581 for x in [box_w, box_h]])
        cm_per_pixel = 1 / pixel_per_cm
        # Area of one pixel in cm^2
        cm2_per_pixel = cm_per_pixel ** 2

        ###
        # Color correction
        ###

        # Use the card position and spacing to create a labeled color card mask
        # The radius setting needs to be large enough to capture the color on each square, but not
        # so big that it overlaps with adjacent squares.
        source_mask = pcv.transform.create_color_card_mask(rgb_img = img1_resize, radius = 15, start_coord = start,
                                                            spacing = space, ncols = 4, nrows = 6)
        # Get the source matrix
        source_headers, source_matrix = pcv.transform.get_color_matrix(img1_resize, source_mask)
        ## Run color correction ##
        # matrix_a holds the average RGB values (and their powers) for each color chip in the source image,
        # matrix_m is a Moore-Penrose inverse matrix, and
        # matrix_b holds the corresponding values for each color chip in the target (reference) image
        matrix_a, matrix_m, matrix_b = pcv.transform.get_matrix_m(target_matrix = target_matrix, source_matrix = source_matrix)
        # deviance measures how much the source image deviates from the target image's color space.
        # Two images with the same color space should have a deviance of ~0.
        # transformation_matrix is a 9x9 matrix of transformation coefficients
        deviance, transformation_matrix = pcv.transform.calc_transformation_matrix(matrix_m, matrix_b)

        image3 = pcv.transform.apply_transformation_matrix(source_img = image, target_img = cc_img, transformation_matrix = transformation_matrix)

    else:
        # No scale reference: leave the image uncorrected and report sizes as NaN
        # rather than reusing the scale from a previous image
        cm_per_pixel = np.nan
        cm2_per_pixel = np.nan
        image3 = image

    ###
    # Use the FCN model to predict berry pixels
    ###

    # Create a tensor from the image
    image3 = image3.astype("float")
    new_im = np.zeros((3, newH, newW))
    new_im[0,:,:] = image3[:,:,0]
    new_im[1,:,:] = image3[:,:,1]
    new_im[2,:,:] = image3[:,:,2]
    image_tensor = new_im
    image_tensor = torch.from_numpy(image_tensor)
    # Normalize the tensor and send it to the GPU
    image_tensor = T.Normalize(mean=mean, std=std)(image_tensor)
    image_tensor.unsqueeze_(0)
    image_tensor = image_tensor.to(device=device, dtype=torch.float32)

    # Run the image through the prediction model
    with torch.no_grad():
        mask = model(image_tensor)['out']
        mask = nn.Sigmoid()(mask)
        mask = mask.cpu().detach().numpy()

    ###
    # Measure berry properties
    ###

    # Iterate over materials to print
    for mat_to_print in materials_toprint:

        # Find the index of this material in the materials list
        # (the same index selects the model's output channel)
        mat_idx = next(k for k, x in enumerate(materials) if x.name == mat_to_print)
        # Get the material at this index
        mat = materials[mat_idx]
        # Threshold this material's confidence map. np.where returns a new array,
        # so the model output in `mask` is left unchanged, and the result does not
        # depend on output_val being above the threshold.
        mat_mask = np.where(mask[0, mat_idx, :, :] >= mat.confidence_threshold, mat.output_val, 0)

        # Perform object segmentation and regionprop calculation
        # This is from https://github.com/danforthcenter/plantcv/blob/master/plantcv/plantcv/watershed.py
        # Convert the mat_mask to 8-bit
        mat_mask = mat_mask.astype("uint8")

        # Run watershed here? Or binary erosion?
        # For now, skip
        # Run distance transform
        # dist_transform = cv.distanceTransformWithLabels(mat_mask, distanceType = cv.DIST_L2, maskSize = 0)[0]
        # local_max = feature.peak_local_max(dist_transform, indices = False, min_distance = distance, labels = mat_mask)
        # markers = ndi.label(local_max, structure=np.ones((3, 3)))[0]
        # dist_transform1 = -dist_transform
        # seg1 = segmentation.watershed(dist_transform1, markers, mask = mat_mask)
        seg1 = mat_mask

        ## Estimate berry traits
        # Label connected objects in the segmentation output
        label_mat = label(np.array(seg1), background = 0)

        # Regionprops to record; dict.fromkeys removes duplicates while keeping a stable column order
        region_properties1 = list(dict.fromkeys(["label", "bbox"] + region_properties))
        region_properties_names = region_properties1 + [x + "_intensity_mean" for x in ["red", "green", "blue"]] + [x + "_intensity_sd" for x in ["red", "green", "blue"]]
        region_properties_names = region_properties_names + [x + "_intensity_mean" for x in ["hue", "sat", "val"]] + [x + "_intensity_sd" for x in ["hue", "sat", "val"]]
        region_properties_names = tuple(["file_name", "collection_id", "color_corrected", "material", "label"] + [x for x in region_properties_names if x != "label"])

        # Empty dictionary of lists to store data (one entry per region)
        regions_dict = {key: [] for key in region_properties_names}

        # Convert the image to HSV once per image
        # (for 8-bit images OpenCV uses H in 0-179 and S, V in 0-255)
        image3_hsv = cv.cvtColor(image3.astype("uint8"), cv.COLOR_RGB2HSV)

        # Iterate over regions in the image
        for region in regionprops(label_image = label_mat):

            # Add manual keys
            regions_dict["file_name"].append(filename)
            regions_dict["collection_id"].append(cid)
            regions_dict["color_corrected"].append(str(color_card_found))
            regions_dict["material"].append(mat_to_print)

            # Add props to the dictionary
            for prop in region_properties1:
                regions_dict[prop].append(region[prop])

            # Pixel values inside the region (N x 3 arrays)
            rows, cols = region.coords[:, 0], region.coords[:, 1]
            berry_rgb_values = image3[rows, cols, :]
            berry_hsv_values = image3_hsv[rows, cols, :]

            # Mean and standard deviation of each channel
            for c, left in enumerate(["red", "green", "blue"]):
                regions_dict[left + "_intensity_mean"].append(np.mean(berry_rgb_values[:, c]))
                regions_dict[left + "_intensity_sd"].append(np.std(berry_rgb_values[:, c]))

            for c, left in enumerate(["hue", "sat", "val"]):
                regions_dict[left + "_intensity_mean"].append(np.mean(berry_hsv_values[:, c]))
                regions_dict[left + "_intensity_sd"].append(np.std(berry_hsv_values[:, c]))

        # Convert the regions_dict to a data frame
        regions_df = pd.DataFrame(regions_dict)

        # Remove regions that are too large or too small (areas are still in pixels here)
        regions_df = regions_df[(regions_df["area"] <= max_area) & (regions_df["area"] >= min_area)].copy()

        # Convert area to cm^2 and axis lengths to cm
        # (accepts both the older and the newer scikit-image property names)
        regions_df["area"] = regions_df["area"] * cm2_per_pixel
        for col in ["major_axis_length", "minor_axis_length", "axis_major_length", "axis_minor_length"]:
            if col in regions_df.columns:
                regions_df[col] = regions_df[col] * cm_per_pixel

        # Save the region data
        all_region_df.append(regions_df)

        ### Save a segmentation image ###
        if save_segmentation_image:
            # Overlay the labeled regions on the image. label2rgb returns floats in [0, 1],
            # so convert to 8-bit before drawing with OpenCV.
            image_use = label2rgb(label_mat, image3.astype("uint8"), alpha=0.3, bg_label = 0)
            image_use = np.ascontiguousarray((image_use * 255).astype("uint8"))

            # Draw a green rectangle around each berry that passed the area filter
            for bbox in regions_df["bbox"]:
                minr, minc, maxr, maxc = [int(v) for v in bbox]
                cv.rectangle(image_use, (minc, minr), (maxc, maxr), (0, 255, 0), 2)

            image_use_save = Image.fromarray(image_use)
            image_use_save.save(os.path.join(seg_output_directory, mat_to_print + "-segmented-" + cid + "-" + filename))

# Save the region data
# Merge the region data frames
all_regions_data = pd.concat(all_region_df)
# Drop the bbox column
all_regions_data = all_regions_data.drop(columns = "bbox")
region_filename = os.path.join(output_directory, current_model_name + "_InferenceImageRegionData.csv")
all_regions_data.to_csv(region_filename, index = False)
            
print("\nImage analysis pipeline complete!")




10 images found.


10it [01:21,  8.11s/it]


Image analysis pipeline complete!
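
The `*_intensity_mean` and `*_intensity_sd` columns in the output CSV are the per-channel mean and population standard deviation (`np.std` with its default `ddof=0`) over each region's pixels. A NumPy sketch of that calculation, using a hypothetical helper `channel_stats` and a made-up 2 x 2 image; `coords` follows the `(row, col)` layout of scikit-image's `region.coords`. For the HSV columns, keep in mind that OpenCV stores 8-bit hue as 0 to 179 and saturation and value as 0 to 255.

```python
import numpy as np

def channel_stats(image, coords):
    """Mean and population SD (ddof=0, as np.std) of each channel over coords.

    image:  H x W x C array
    coords: N x 2 integer array of (row, col) pixel positions
    """
    pixels = image[coords[:, 0], coords[:, 1]].astype(float)  # N x C
    return pixels.mean(axis=0), pixels.std(axis=0)

# Made-up 2 x 2 RGB image; the "region" is the top row
image = np.array([[[200, 10, 20], [100, 30, 40]],
                  [[0, 0, 0], [0, 0, 0]]], dtype=np.uint8)
coords = np.array([[0, 0], [0, 1]])
means, sds = channel_stats(image, coords)
print(means, sds)  # means 150, 20, 30; SDs 50, 10, 10
```

Indexing with the coordinate arrays gathers all of a region's pixels at once, which avoids a Python loop over every pixel.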



