# BerryBox Image Analysis Pipeline


## Validation only

This notebook provides the code to validate a developmental FCN model. **It is not meant for production use.**


# **Materials**
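  Input the material mask names and other parameters below. Not every parameter described here is set in this notebook (several are only used during model training), and a few pipeline settings in the code cell (for example `distance` and `run_watershed`) are documented only by the comments next to them.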

  Specifically:
 
  **name** - The name for the material. This is pretty arbitrary, but it will be
  used to label output folders and images.
 
  **input_rgb_vals** - The RGB values of the material in the input mask image.
 
  **output_val** - The greyscale value of the mask when you output the images.
  This is arbitrary, but every material should have its own output value
  so the materials can be differentiated.
 
  **confidence_threshold** - The lower this number, the more pixels will be labeled as a specific material. The model outputs a confidence value between 0 and 1 (a sigmoid output) for every pixel and every material. By default, pixels with a confidence of 0.5 or greater are assigned to the material in question, but you can also label pixels with a lower confidence by lowering this parameter (or require a higher confidence by raising it).

  **training_image_directory /training_mask_directory**: Input the directory where your training images and masks are located.

  **validation_fraction**: The fraction of training images used to validate your model during training. These images are not an independent validation; they are part of the training process.

  **num_models**: Enter the number of models you want to iteratively train. Because these are statistical models, the performance of any given model will vary. Training more models will allow you to select the model that best fits your data.
  
  **num_epochs**: Enter the number of epochs you want to use to train your model. More is generally better, but takes more time.

  **batch_size**: Input your batch size. Larger batch sizes allow for faster training, but take up more VRAM. If you are running out of VRAM during training, decrease your batch size.

  **scale**: Input how you want your images scaled during model training and inference. When the scale is 1, your images will be used at full size for training. When the scale is less than 1, your images will be downsized according to the scale you set for training and inference, decreasing VRAM usage. If you run out of VRAM during training, consider rescaling your images.
  
  **normalization_path**: The path to the normalization data file that was saved during model training.

  **models_directory**: Directory where your models are saved.

  **model_group**: Name for the group of models you iteratively generate.

  **current_model_name**: Name for each individual model you generate; will automatically be labeled 1 through n for the number of models you specify above.

  **val_images/val_masks**: Input the directory where your independent validation images and masks are located. These images are not used for training and are used as an independent validation of your model.

  **csv_directory**: Directory where a CSV file of your validation results will be saved.

  **inference_directory**: Directory where the images you want analyzed are located.

  **output_directory**: Directory where you want your analysis results to be saved.
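A minimal NumPy sketch of how **confidence_threshold** and **output_val** turn one material's confidence map into a mask, assuming the confidences are the sigmoid outputs (0 to 1) described above. `threshold_confidence` is a hypothetical helper, not a function from this pipeline.

```python
import numpy as np


def threshold_confidence(confidence, confidence_threshold=0.5, output_val=255):
    """Turn a per-pixel confidence map (values in [0, 1]) into an 8-bit mask.

    Pixels at or above the threshold get `output_val`; all others get 0.
    """
    mask = np.where(confidence >= confidence_threshold, output_val, 0)
    return mask.astype(np.uint8)


# Toy 2 x 3 confidence map for the "berry" material
conf = np.array([[0.10, 0.55, 0.80],
                 [0.74, 0.75, 0.95]])

print(threshold_confidence(conf, 0.5))
print(threshold_confidence(conf, 0.75))
```

Raising the threshold from 0.5 to 0.75 (the value set for `berry` in the code below) drops the 0.55 and 0.74 pixels from the mask.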



In [22]:
#############################
#### Set user parameters ####
#############################

class Material:
 
  def __init__(self, name, input_rgb_vals, output_val, confidence_threshold=0.5):
    self.name = name
    self.input_rgb_vals = input_rgb_vals
    self.output_val = output_val
    self.confidence_threshold = confidence_threshold

#Creating a list of materials so we can iterate through it
materials = [
             Material("background", [0,0,0], 0, 0.5),
             Material("berry", [255,255,255], 255, 0.75),
             ]


# What material would you like to make inferences for?
materials_toprint = ["berry"]

# Project directory
# IMPORTANT - ALL DIRECTORIES NEED TO END IN A /
# proj_dir =  drive + "/ARS_Cranberry/ImageAnalysis/BerryBox/fcn_model_building/"
proj_dir = "/project/gifvl_vaccinium/cranberryImaging/BerryBox/fcn_model_building/"


#Decrease scale to decrease VRAM usage; if you run out of VRAM during training or inference, restart your runtime and downscale your images
scale = 0.3

# Distance for the watershed segmentation
distance = 10

# Properties for regionprops
# region_properties = ["area", "axis_major_length", "axis_minor_length", "eccentricity"] # For local runs
region_properties = ["area", "major_axis_length", "minor_axis_length", "eccentricity"] # For colab runs

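# Should the pipeline include watershed segmentation? Note: this can be unreliable
run_watershed = False

# Should the pipeline save an annotated copy of each image showing the segmented objects?
save_segmentation_image = True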

# Input deep learning model path 
# This file should end in ".pth"
model_path = proj_dir + "/model_output/berryBox_fcn_0.0.2/models/berryBox_fcn_0.0.2_model3.pth"

# Normalization data path
normalization_path = proj_dir + "/model_output/berryBox_fcn_0.0.2/berryBox_fcn_20220608-113848_model_normalization_param.txt"

# Name of the model group
model_group = "berryBox_fcn_0.0.2/"
current_model_name = model_group.replace("/", "") + "_model"

# Directory of images to segment
inference_directory = proj_dir + "/imagesToSegment"
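The two `region_properties` lines above differ only in the names of the axis-length properties, presumably because the local and Colab environments run different scikit-image versions. Rather than toggling comments by hand, the names could be picked from the installed version. This sketch assumes the `axis_*` names are available from scikit-image 0.19 onward (check the changelog for your installation); `regionprops_length_names` is a hypothetical helper.

```python
def regionprops_length_names(skimage_version):
    """Return regionprops names for the major and minor axis lengths.

    Assumption: scikit-image 0.19 and later provide the axis_* names, while
    older releases only provide the *_axis_length names.
    """
    major, minor = (int(part) for part in skimage_version.split(".")[:2])
    if (major, minor) >= (0, 19):
        return ["axis_major_length", "axis_minor_length"]
    return ["major_axis_length", "minor_axis_length"]


# In the notebook this could be called with skimage.__version__
example_properties = ["area"] + regionprops_length_names("0.18.3") + ["eccentricity"]
print(example_properties)
```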



# **Image Segmentation**

Run the image inference pipeline. This pipeline will:
1. Read in an inference image, decode the QR code (sample name), and estimate the pixel-to-cm scale from the color checker card
2. Run the image through the prediction model
3. Segment the relevant mask
4. Identify objects in the image
5. Measure object properties and save the results
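The scale calibration in step 1 is simple arithmetic: the color-card squares are measured on a 25% copy of the image, converted back to full resolution, divided by the known square size (the 1.1581 cm value the pipeline uses), and then adjusted for the `scale` applied before inference. A sketch of that arithmetic, using a hypothetical helper `cm_per_pixel_at_scale` and made-up box sizes:

```python
# Side length (cm) of one color-checker square, the value used by this pipeline
BOX_SIDE_CM = 1.1581


def cm_per_pixel_at_scale(box_widths, box_heights, scale_percent, scale):
    """Return (cm per pixel, cm^2 per pixel) for the image fed to the model.

    box_widths / box_heights are color-card square sizes (pixels) measured on
    a copy of the image downsized to `scale_percent` percent.
    """
    # Average square side on the downsized copy, then back to full resolution
    mean_w = sum(box_widths) / len(box_widths)
    mean_h = sum(box_heights) / len(box_heights)
    avg_box_dim = ((mean_w + mean_h) / 2) / (scale_percent / 100)
    pixel_per_cm = avg_box_dim / BOX_SIDE_CM
    # The model sees the image resized by `scale`, so fewer pixels span each cm
    pixel_per_cm_scale = pixel_per_cm * scale
    cm_per_pixel_scale = 1 / pixel_per_cm_scale
    return cm_per_pixel_scale, cm_per_pixel_scale ** 2


# Made-up measurements: squares about 29 px wide on a 25% copy, model scale 0.3
cm_px, cm2_px = cm_per_pixel_at_scale([29, 28, 30], [29, 29, 29], 25, 0.3)
print(f"{cm_px:.4f} cm per pixel, {cm2_px:.6f} cm^2 per pixel")
```

Region areas from `regionprops` are pixel counts, so they are multiplied by the squared factor, while axis lengths use the linear factor.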

## Import packages and load a specific model

In [23]:
# Load relevant packages
import os
import torch
import torch.nn as nn
import numpy as np
import cv2 as cv
from torchvision.models.segmentation.fcn import FCNHead
from torchvision.models.segmentation import fcn_resnet101
import torchvision.transforms as T
from PIL import Image
from scipy import ndimage as ndi
import pandas as pd
from skimage.color import rgb2gray, label2rgb
from skimage.transform import rescale, resize, downscale_local_mean
from skimage import feature, segmentation
from skimage.measure import label, regionprops_table, regionprops
from plantcv import plantcv as pcv
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt


## Specify directories
# Directory to store output segmented images
output_directory = proj_dir + "/segmentation_output"
# Directory of segmented images
seg_output_directory = output_directory + "/segmented_images"

# Empty and create these directories
for dirname in [output_directory, seg_output_directory]:
  if os.path.exists(dirname):
    # If exists, delete everything within it
    for subfile in os.listdir(dirname):
      if os.path.isdir(dirname + "/" + subfile):
        continue
      else:
        os.remove(dirname + "/" + subfile)
  else:
    # Else create
    os.mkdir(dirname)

# How many materials?
num_materials = len(materials)

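# Build the FCN-ResNet101 architecture (no pretrained weights; the trained weights are loaded below)
model = fcn_resnet101(pretrained=False)
model.classifier = FCNHead(2048, num_materials)
# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Load the model specified above
# map_location lets a GPU-trained checkpoint load on a CPU-only machine
model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
# Evaluation mode (fixed batch-norm statistics, no dropout) for inference
model.eval()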

print("Model loaded!")


## Load the normalization information ##
# Empty dict to store tensors
norm_tensors = {}

# Read in the normalization file
with open(normalization_path, "r") as file:
    for line in file:
        tabs = line.split("\t")
        
        # Create a vector of numeric characters
        num_char_vec = tabs[1].split("[")[1].split("]")[0].split(", ")
        # Convert this to numeric
        num_vec = [float(x) for x in num_char_vec]
        # Convert to np array
        num_arr = np.array(num_vec)
        
        # Convert to tensor and store
        norm_tensors[tabs[0]] = torch.tensor(num_arr)
        
# assign to mean and std
mean = norm_tensors['normalization_mean']
std = norm_tensors['normalization_std']

# Load a QR code detector
detector = cv.QRCodeDetector()

Model loaded!
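The normalization file is plain text: one tab-separated line per statistic, with the value written as the string form of a PyTorch tensor. A standalone sketch of the same parsing idea, assuming that format (the values below are illustrative, not this model's). It splits on commas rather than `", "`, so it also tolerates missing spaces, and cutting at the first `]` ignores a trailing `dtype=...` if one is printed. `parse_normalization_line` is a hypothetical helper.

```python
def parse_normalization_line(line):
    """Parse one line such as 'normalization_mean\ttensor([0.1, 0.2, 0.3])'.

    Returns the key and the numbers between the square brackets as floats.
    """
    key, value = line.rstrip("\n").split("\t", 1)
    inner = value.split("[", 1)[1].split("]", 1)[0]
    return key, [float(x) for x in inner.split(",")]


example_file = (
    "normalization_mean\ttensor([0.4850, 0.4560, 0.4060])\n"
    "normalization_std\ttensor([0.2290, 0.2240, 0.2250])\n"
)
params = dict(parse_normalization_line(line) for line in example_file.splitlines())
print(params["normalization_mean"])
```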


## Run the image processing pipeline
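Inside the loop, each resized image is converted from height × width × channel (PIL/NumPy order) to channel × height × width (PyTorch order) and normalized per channel before inference. A NumPy-only sketch of those two steps, useful for checking shapes and values without a GPU; `to_normalized_chw` is a hypothetical helper and the statistics are made up.

```python
import numpy as np


def to_normalized_chw(image_hwc, mean, std):
    """Reorder an H x W x 3 image to 3 x H x W and normalize each channel.

    Mirrors the manual channel copy plus T.Normalize: (x - mean[c]) / std[c].
    """
    chw = np.transpose(np.asarray(image_hwc, dtype=float)[:, :, :3], (2, 0, 1))
    mean = np.asarray(mean, dtype=float).reshape(3, 1, 1)
    std = np.asarray(std, dtype=float).reshape(3, 1, 1)
    # Add a leading batch dimension, like unsqueeze_(0)
    return ((chw - mean) / std)[np.newaxis]


# Tiny 2 x 2 RGB image with made-up statistics
img = np.array([[[10, 20, 30], [40, 50, 60]],
                [[70, 80, 90], [100, 110, 120]]])
out = to_normalized_chw(img, mean=[55, 65, 75], std=[10, 10, 10])
print(out.shape)
```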

In [34]:
## Iterate over images and run through the prediction model ##

# Shorter alias for the directory containing the images to segment
dir_name = inference_directory
filenames = os.listdir(dir_name)
print(str(len(filenames)) + " images found.")

# Create an empty list to store region data
all_region_df = []

# Iterate over the images
for i, filename in enumerate(filenames):

# ## TESTING ##
# i = 0
# filename = filenames[i]

    # Open the image (convert to RGB in case the file has an alpha channel or is greyscale)
    image = Image.open(dir_name +'/'+ filename).convert("RGB")
    image_cv = cv.imread(dir_name + "/" + filename)

    ## Find the QR code and determine the sample name ##

    # Crop the cv2 image
    h, w, d = image_cv.shape
    # Crop the image - this will look at the bottom-right corner
    start_h = int(h / 2)
    start_w = int(w / 2)
    image_cv_crop = image_cv[start_h:h, start_w:w, :]
    # Run the QR detector
    collection_id, points, _ = detector.detectAndDecode(image_cv_crop)

    ##

    ## Find the color checker card
    # Crop
    image_cv_crop = image_cv[:, start_w:w, :]
    h, w, d = image_cv_crop.shape
    # Downsize
    scale_percent = 25 # percent of original size
    new_h = int(h * scale_percent / 100)
    new_w = int(w * scale_percent / 100)

    # Find the color card
    df1, start, space = pcv.transform.find_color_card(rgb_img = cv.resize(image_cv_crop, (new_w, new_h)))

    # Calculate the average box width and height
    box_w = np.mean(df1['width'])
    box_h = np.mean(df1['height'])
    # Rescale the box width/height; this is the average full-scale box_dim
    avg_box_dim = np.mean([x / (scale_percent / 100) for x in [box_w, box_h]])
    # We know boxes are about 1.1581 cm on each side (square this to get area)
    # Calculate the number of pixel per cm
    pixel_per_cm = avg_box_dim / 1.1581
    cm_per_pixel = 1 / pixel_per_cm

    ##

    # Convert the original image to grayscale
    image_gray = np.asarray(image)
    image_gray = rgb2gray(image_gray)
    # Rescale
    image_gray = rescale(image_gray, scale, anti_aliasing=True)

    # Rescale the image
    # h, w, d = image.shape
    w, h = image.size
    newW, newH = int(scale * w), int(scale * h)
    # Convert image from nparray to pil image; remember to convert colors
    # image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
    # image = Image.fromarray(image)
    image = image.resize((newW, newH))
    image_gray = image_gray[0:newH, 0:newW] # Crop to the same size as the resized RGB image
    image = np.array(image, dtype = float)
    new_im = np.zeros((3, newH, newW))
    new_im[0,:,:] = image[:,:,0]
    new_im[1,:,:] = image[:,:,1]
    new_im[2,:,:] = image[:,:,2]
    image_tensor = new_im

    # Recalculate the pixels per cm
    pixel_per_cm_scale = pixel_per_cm * scale
    cm_per_pixel_scale = 1 / pixel_per_cm_scale
    cm2_per_pixel_scale = cm_per_pixel_scale ** 2

    # Create a tensor from the image
    image_tensor = torch.from_numpy(image_tensor)
    # Normalize the tensor and send it to the GPU
    image_tensor = T.Normalize(mean=mean, std=std)(image_tensor)
    image_tensor.unsqueeze_(0)
    image_tensor = image_tensor.to(device=device, dtype=torch.float32)

    # Run the image through the prediction model
    with torch.no_grad():
        mask = model(image_tensor)['out']
        mask = nn.Sigmoid()(mask)
        mask = mask.cpu().detach().numpy()

    # Iterate over materials to print
    for mat_to_print in materials_toprint:
        # Find the index of this material in the materials list
        mat_idx = [i for i, x in enumerate(materials) if x.name == mat_to_print][0]

        # Get the material at this index
        mat = materials[mat_idx]

        # Get the mask from the prediction model at this index
        mat_mask = mask[0,mat_idx,:,:]
        mat_mask[mat_mask >= mat.confidence_threshold] = mat.output_val
        mat_mask[mat_mask < mat.confidence_threshold] = 0

        # Perform object segmentation and regionprop calculation
        # This is from https://github.com/danforthcenter/plantcv/blob/master/plantcv/plantcv/watershed.py
        # Convert the mat_mask to 8-bit
        mat_mask = mat_mask.astype("uint8")

        # Run the watershed if called
        if run_watershed:
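            # Run distance transform
            dist_transform = cv.distanceTransformWithLabels(mat_mask, distanceType = cv.DIST_L2, maskSize = 0)[0]
            # peak_local_max returns peak coordinates (its `indices` argument was deprecated and
            # later removed in newer scikit-image releases), so build the boolean marker image here
            peak_coords = feature.peak_local_max(dist_transform, min_distance = distance, labels = mat_mask)
            local_max = np.zeros(dist_transform.shape, dtype = bool)
            local_max[tuple(peak_coords.T)] = True

            markers = ndi.label(local_max, structure=np.ones((3, 3)))[0]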
            dist_transform1 = -dist_transform
            seg1 = segmentation.watershed(dist_transform1, markers, mask = mat_mask)
        else:
            seg1 = mat_mask

        ## Estimate berry traits
        # Label the segmentation output
        label_mat = label(np.array(seg1), background = 0)

        # Regionprops
        region_properties1 = list(set(["label", "bbox"] + region_properties))
        region_properties_names = region_properties1 + [x + "_intensity_mean" for x in ["red", "green", "blue"]] + [x + "_intensity_sd" for x in ["red", "green", "blue"]]
        region_properties_names = region_properties_names + [x + "_intensity_mean" for x in ["hue", "sat", "val"]] + [x + "_intensity_sd" for x in ["hue", "sat", "val"]]
        region_properties_names = tuple(["unique_id", "material", "label"] + [x for x in region_properties_names if x != "label"])

        # Empty dictionary to store data
        regions_dict = {}

        # Initialize lists in the dictionary
        for key in region_properties_names:
            regions_dict[key] = []

        # Convert the image to HSV once per material (OpenCV expects 8-bit RGB here)
        image_hsv = cv.cvtColor(image.astype("uint8"), cv.COLOR_RGB2HSV)

        # Iterate over regions in the image
        for region in regionprops(label_image = label_mat):

            # Add manual keys (one entry per region)
            regions_dict["unique_id"].append(collection_id)
            regions_dict["material"].append(mat_to_print)

            # Add props to the dictionary
            for prop in region_properties1:
                # Convert lengths to cm (old and new scikit-image names) and areas to cm^2
                if prop in ["minor_axis_length", "major_axis_length", "axis_minor_length", "axis_major_length"]:
                    to_append = region[prop] * cm_per_pixel_scale
                elif prop == "area":
                    to_append = region[prop] * cm2_per_pixel_scale
                else:
                    to_append = region[prop]

                regions_dict[prop].append(to_append)

            # Gather this region's pixels at once; region.coords holds (row, col) pairs
            rows, cols = region.coords[:, 0], region.coords[:, 1]
            berry_rgb_values = image[rows, cols, :]
            berry_hsv_values = image_hsv[rows, cols, :]

            for c, left in enumerate(["red", "green", "blue"]):
                vals = berry_rgb_values[:, c]
                regions_dict[left + "_intensity_mean"].append(np.mean(vals))
                regions_dict[left + "_intensity_sd"].append(np.std(vals))

            for c, left in enumerate(["hue", "sat", "val"]):
                vals = berry_hsv_values[:, c]
                regions_dict[left + "_intensity_mean"].append(np.mean(vals))
                regions_dict[left + "_intensity_sd"].append(np.std(vals))

        # Convert the regions_dict to a data.frame
        regions_df = pd.DataFrame(regions_dict)

        # Remove excessively large regions (area greater than 5 cm^2)
        regions_df1 = regions_df[regions_df["area"] <= 5]
        
        # Save the region data
        all_region_df.append(regions_df1)

        ### Save a segmentation image ###
        if save_segmentation_image:
            image_use = image.astype("uint8")

            fig, ax = plt.subplots(figsize=(10, 6))
            ax.imshow(image_use)

            # Iterate over the retained regions (label and bounding box of each berry)
            for lab, bbox in zip(regions_df1["label"], regions_df1["bbox"]):
                # Draw a rectangle around each segmented berry
                minr, minc, maxr, maxc = bbox
                rect = mpatches.Rectangle((minc, minr), maxc - minc, maxr - minr,
                                          fill=False, edgecolor='red', linewidth=2)
                ax.add_patch(rect)
                # Add label
                ax.text(maxc, maxr, lab)

            ax.set_axis_off()
            # Save before closing; closing first makes plt.savefig write a new, empty figure
            fig.savefig(seg_output_directory + "/" + collection_id + "-" + mat_to_print + "-" + filename)
            plt.close(fig)
            

# Save the region data
# Merge the region data frames
all_regions_data = pd.concat(all_region_df)
region_filename = output_directory + "/" + current_model_name + "_InferenceImageRegionData.csv"
all_regions_data.to_csv(region_filename)
            
print("Segmentation complete!")

10 images found.
Segmentation complete!
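The color summaries are computed per region from the pixels listed in `region.coords`. A sketch of the same per-channel mean and standard deviation using NumPy fancy indexing instead of a Python loop over pixels, assuming a labeled image like the one returned by `label`; `region_color_stats` is a hypothetical helper and the toy image is made up.

```python
import numpy as np


def region_color_stats(image, label_image, label_value):
    """Per-channel mean and standard deviation over one labeled region.

    Gathers every pixel of the region at once with NumPy fancy indexing,
    instead of appending pixels one by one in a Python loop.
    """
    rows, cols = np.nonzero(label_image == label_value)
    pixels = image[rows, cols, :].astype(float)
    return pixels.mean(axis=0), pixels.std(axis=0)


# Toy 2 x 3 label image (0 = background) and matching RGB image
toy_labels = np.array([[0, 1, 1],
                       [0, 1, 2]])
toy_image = np.zeros((2, 3, 3))
toy_image[..., 0] = [[0, 100, 200],
                     [0, 150, 50]]  # red channel only
region_mean, region_sd = region_color_stats(toy_image, toy_labels, 1)
print(region_mean, region_sd)
```

`np.std` here uses the population formula (`ddof=0`), the same default the pipeline relies on.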



In [None]:


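##################################
#### Normalization parameters ####
##################################
# Note: BasicDataset, training_image_directory, training_mask_directory,
# batch_size, and model_group_directory are not defined in this notebook;
# define them before running this cell.
from torch.utils.data import DataLoader

dataset = BasicDataset(training_image_directory, training_mask_directory, scale=scale, transform=False)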
 
#!!!!!!!!!!!!!!!!!!!!!!!!!!Set batch size here!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# train, val=trainval_split(dataset, val_fraction=0.5)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)#, collate_fn=pad_collate)
#val_loader = DataLoader(val, batch_size=3, shuffle=False, num_workers=0, pin_memory=True)#, collate_fn=pad_collate)
nimages = 0
mean = 0.
std = 0.
for batch, _ in train_loader:
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Compute mean and std here
    mean += batch.mean(2).sum(0) 
    std += batch.std(2).sum(0)
 
# Final step
mean /= nimages
std /= nimages
 
print(mean)
print(std)

dataset.means=mean
dataset.stds=std 

nimages = 0
newmean = 0.
newstd = 0.
for batch, _ in train_loader:
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Compute mean and std here
    newmean += batch.mean(2).sum(0) 
    newstd += batch.std(2).sum(0)
 
# Final step
newmean /= nimages
newstd /= nimages
 
print(newmean)
print(newstd)


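## Save these normalization values for the production pipeline ##
# Save the raw training statistics (mean, std), which the inference pipeline
# passes to T.Normalize. newmean/newstd above are only a check of the
# normalization (close to 0 and 1 if BasicDataset applies dataset.means/dataset.stds).

# Open a file
param_filename = model_group_directory + model_group.replace("/", "") + "_normalization_param.txt"
handle = open(param_filename, "w")

# Write all the parameters to this file
handle.write("normalization_mean" + "\t" + str(mean) + "\n")
handle.write("normalization_std" + "\t" + str(std) + "\n")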

# Close the file
handle.close()
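The cell above averages each image's standard deviation, which is not the same as the standard deviation of all training pixels pooled together, because it ignores differences between image means. If exact dataset-wide statistics are wanted, one option is to accumulate running sums. This NumPy sketch is an alternative, not what the pipeline used to produce its saved parameters, and `channel_stats` is a hypothetical helper.

```python
import numpy as np


def channel_stats(batches):
    """Exact per-channel mean and (population) std over all pixels of all images.

    `batches` yields arrays shaped [B, C, H, W]. Per-channel running sums of
    x and x**2 are accumulated, so only one batch is in memory at a time.
    """
    total = sq_total = 0.0
    count = 0
    for batch in batches:
        batch = np.asarray(batch, dtype=np.float64)
        # [B, C, H, W] -> [C, B * H * W]
        flat = batch.transpose(1, 0, 2, 3).reshape(batch.shape[1], -1)
        total = total + flat.sum(axis=1)
        sq_total = sq_total + (flat ** 2).sum(axis=1)
        count += flat.shape[1]
    mean = total / count
    # Clip tiny negative values caused by floating-point rounding
    var = np.maximum(sq_total / count - mean ** 2, 0.0)
    return mean, np.sqrt(var)


# Toy data: one batch of 2 images with 3 channels of 2 x 2 pixels
toy_batches = [np.arange(24, dtype=float).reshape(2, 3, 2, 2)]
toy_mean, toy_std = channel_stats(toy_batches)
print(toy_mean, toy_std)
```

With the `DataLoader` above, this could be called as `channel_stats(batch.numpy() for batch, _ in train_loader)`. PyTorch's `Tensor.std` uses the unbiased estimator by default while this sketch uses the population formula; over millions of pixels the difference is negligible.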
