# colab-WormHeadCounter v0.1

This model is fine-tuned to detect the head of the worm and count fluorescent signals in specific reporter images.

Performance with a different reporter cannot be guaranteed.

1.   Upload your images into a single directory 📁
2.   Select a pretrained model
3.   Set the parameters
4.   Runtime > Run all
5.   Download the results directory

Any questions? Ask Youngjun 🤗

In [None]:
#@title Set parameters and `Runtime` -> `Run all`
from google.colab import files
import os
import re
import hashlib
import random
import glob

from sys import version_info
# "major.minor" string of the interpreter running this notebook.
python_version = "{}.{}".format(version_info.major, version_info.minor)

def add_hash(x, y):
    """Return ``x`` joined by "_" to the first five hex digits of SHA-1(``y``)."""
    digest = hashlib.sha1(y.encode()).hexdigest()
    return "_".join((x, digest[:5]))

# User-tunable settings. The trailing `#@param` / `#@markdown` comments are Colab
# form annotations, so they are kept exactly as written.

# Model identifier handed to SegformerForSemanticSegmentation.from_pretrained.
pretrained_model_name = "ypark-bioinfo/segformer-b5-finetuned-ce-head-image_ver1.1" #@param ['ypark-bioinfo/segformer-b5-finetuned-ce-head-image_ver1.1'] {type:"string"}
# Subtracted from each image's minimum intensity in the preprocessing cell to set
# the lower bound used for min-max scaling.
brightness = 50 #@param [100, 70, 50, 30, 10] {type:"raw"}
#@markdown - It changes the minimum signal of the image before min-max normalization.
# Passed to get_results, which skips labelled regions whose mask sum is below it.
# NOTE(review): the name is misspelled ("minimun") but matches the get_results
# keyword argument; rename both together if it is ever corrected.
minimun_head_size = 7500 #@param {type:"raw"}
#@markdown - It changes the minimum head size to detect.
# Forwarded through get_results to skimage's peak_local_max(min_distance=...).
min_distance = 2 #@param {type:"raw"}
#@markdown - The pixel distance between signal peaks affects the total number of detected peaks. (recommend value: 1~10)
# Forwarded through get_results to skimage's peak_local_max(threshold_abs=...).
threshold_abs = 0.35 #@param {type:"raw"}
#@markdown - The peak detection threshold affects the total number of detected peaks. (recommend value: 0.05~0.50)
# Directory whose entries are globbed and converted by the preprocessing cell.
raw_image_directory = '/content/test_images' #@param {type:"raw"}
#@markdown - Please provide the location of the images.

# Check the results folder

## Utils

In [None]:
import glob
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import skimage
from scipy import ndimage as ndi
from skimage.feature import peak_local_max
from skimage.morphology import erosion
from skimage.morphology import footprint_rectangle
import matplotlib.patches as mpatches
from skimage.measure import label, regionprops
from skimage.segmentation import clear_border

def show_mask(mask, ax, random_color=False):
    """Draw a binary mask on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: Array of shape (H, W) or (H, W, C). Values greater than 0 are
            treated as foreground; for a 3-D mask only channel 0 is used.
        ax: Object exposing a matplotlib-style ``imshow`` method.
        random_color: When True, use a random RGB colour with alpha 0.6
            instead of the fixed colour (255, 144, 30) with alpha 0.7.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([255/255, 144/255, 30/255, 0.7])

    mask = np.asarray(mask)
    h, w = mask.shape[:2]
    # Accept plain 2-D masks as well as channel-last masks; indexing channel 0
    # unconditionally would raise IndexError on a 2-D array.
    plane = mask if mask.ndim == 2 else mask[:, :, 0]
    binary = np.where(plane > 0, 1, 0)

    # Broadcast the (H, W, 1) binary plane against the 4 RGBA components.
    mask_image = np.uint8(binary.reshape(h, w, 1) * 255 * color.reshape(1, 1, -1))
    ax.imshow(mask_image, interpolation='none')


def _strip_image_suffix(image_name):
    """Return the basename of ``image_name`` without a trailing '.image.png' or '.png'."""
    base = os.path.basename(image_name)
    # str.rstrip() removes a *set of characters*, not a suffix (it would turn
    # 'a.jpg.png' into 'a.j'), so the suffix is matched explicitly instead.
    for suffix in ('.image.png', '.png'):
        if base.endswith(suffix):
            return base[:-len(suffix)]
    return base


def _normalize_patch(filt_patch, bg_to_remove_level, bg_max):
    """Clip ``filt_patch`` to [bg_min, bg_max] and rescale it to 0..1 (float)."""
    levels = np.unique(filt_patch)
    # The floor is the ``bg_to_remove_level``-th unique intensity; when the patch
    # has too few distinct intensities the lowest one is used instead.
    try:
        bg_min = levels[bg_to_remove_level]
    except IndexError:
        bg_min = levels[0]

    # Work in float so unsigned image dtypes cannot wrap around.
    bg_min = float(bg_min)
    bg_max = float(bg_max)
    if bg_max <= bg_min:
        # No usable dynamic range: return a flat patch instead of dividing by zero.
        return np.zeros(filt_patch.shape, dtype=float)

    patch = filt_patch.astype(float)
    patch = np.where(patch > bg_max, bg_max, patch)
    patch = np.where(patch < bg_min, bg_min, patch)
    return (patch - bg_min) / (bg_max - bg_min)


def get_results(_test_images,
                image_names,
                output_dir,
                ml_mask=2,
                minimun_head_size=150,
                bg_to_remove_level=3,
                min_distance=10,
                threshold_abs=0.35,
                bg_max_value=160,
                ):
    """Count intensity peaks inside every segmented region and save the results.

    For each image the mask ``test_image[ml_mask]`` is split into connected
    regions. Every region that passes the size filter gets a four-panel figure
    saved to ``output_dir`` and one row in a summary CSV written at the end.

    If the summary CSV for the same ``ml_mask`` / ``min_distance`` /
    ``threshold_abs`` already exists in ``output_dir`` nothing is recomputed.
    The other parameters are not part of that file name.

    Args:
        _test_images: Sequence of indexable items; element 0 is the image array
            and element ``ml_mask`` is the mask to analyse.
        image_names: Paths matching ``_test_images`` by position; the basename
            without a trailing '.image.png' / '.png' labels the outputs.
        output_dir: Existing directory that receives the PNG figures and CSV.
        ml_mask: Index of the mask inside each ``_test_images`` item.
        minimun_head_size: Regions whose mask sum is below this are skipped.
        bg_to_remove_level: Index into the sorted unique intensities of the
            masked patch that is used as the normalization floor.
        min_distance: Forwarded to ``peak_local_max``.
        threshold_abs: Forwarded to ``peak_local_max``.
        bg_max_value: Upper clipping value used for normalization.

    Returns:
        None.
    """
    summary_name = f"Summary_label_{ml_mask}_mindist_{min_distance}_thrsh_{threshold_abs}"
    summary_path = f'{os.path.join(output_dir, summary_name)}.csv'
    if os.path.exists(summary_path):
        print('PASS! Summary already exists')
        return

    out_results = []
    for image_index, test_image in enumerate(_test_images):
        filename = _strip_image_suffix(image_names[image_index])

        seg_mask = np.array(test_image[ml_mask])
        vis = np.where(seg_mask == 1, 255, 0)
        if seg_mask.ndim == 2:
            # Replicate to three channels; the bbox unpacking below expects 3-D.
            seg_mask = np.repeat(np.expand_dims(seg_mask, 2), 3, axis=2)

        label_mask = label(seg_mask, connectivity=1)
        region_box = regionprops(label_mask)

        number_of_heads = len(region_box)
        print(number_of_heads)
        head_id = 0
        for region_index, region in enumerate(region_box):
            mask_patch = np.where(label_mask == region_index + 1, 1, 0)

            # NOTE: this sum runs over the whole labelled array, so a 2-D mask
            # replicated to three channels above counts every pixel three times.
            if np.sum(mask_patch) < minimun_head_size:
                print('skip due to min head size')
                continue

            head_id += 1

            minr, minc, _, maxr, maxc, _ = region.bbox
            rect = mpatches.Rectangle(
                (minc, minr),
                maxc - minc,
                maxr - minr,
                fill=False,
                edgecolor='red',
                linewidth=2,
            )

            fig, axes = plt.subplots(1, 4, figsize=(13, 3))
            try:
                axes[0].imshow(np.array(test_image[0]))
                axes[1].imshow(vis)
                axes[1].add_patch(rect)
                axes[1].set_title(f'#{head_id}')

                orig_patch = test_image[0][minr:maxr, minc:maxc]
                mask_patch = mask_patch[minr:maxr, minc:maxc, 1]
                # Shrink the region so peaks near its border are ignored.
                mask_patch = erosion(mask_patch, footprint_rectangle((20, 20)))

                if len(orig_patch.shape) > 2:
                    orig_patch = orig_patch[:, :, 0]
                filt_patch = np.where(mask_patch == 1, orig_patch, 0)
                norm_patch = _normalize_patch(filt_patch, bg_to_remove_level, bg_max_value)

                axes[2].imshow(orig_patch)
                axes[2].autoscale(False)

                coordinates = peak_local_max(norm_patch, min_distance=min_distance, threshold_abs=threshold_abs, threshold_rel=0.35)
                # Keep only the peaks that fall inside the eroded mask.
                final_coord = np.array([peak for peak in coordinates
                                        if mask_patch[peak[0], peak[1]] == 1])
                axes[3].imshow(orig_patch)
                axes[3].autoscale(False)
                if len(final_coord) > 0:
                    axes[3].plot(final_coord[:, 1], final_coord[:, 0], 'r.')

                plt.title(f'{filename}-{head_id}:{len(final_coord)}')

                head_size = np.sum(mask_patch)
                peak_cnt = len(final_coord)
                # Erosion can remove the whole region; report NaN rather than
                # dividing by zero.
                if head_size > 0:
                    peak_density = peak_cnt / head_size * 60000
                else:
                    peak_density = float('nan')

                plot_name = f'prediction_file_{filename}_label_{ml_mask}_numofheads_{number_of_heads}_headid_{head_id}_minheadsize_{minimun_head_size}_bg_{bg_to_remove_level}_mindist_{min_distance}_thrsh_{threshold_abs}.png'
                out_results.append({'filename': filename,
                                    'label': ml_mask,
                                    'num_of_heads': number_of_heads,
                                    'head_id': head_id,
                                    'head_size': head_size,
                                    'peak_cnt': peak_cnt,
                                    'peak_per_60k_pixels': peak_density,
                                    'minimun_head_size': minimun_head_size,
                                    'bg_to_remove_level': bg_to_remove_level,
                                    'min_distance': min_distance,
                                    'threshold_abs': threshold_abs,
                                    'output_image_name': plot_name,
                                    })
                for panel in axes:
                    panel.axis("off")
                fig.savefig(f'{os.path.join(output_dir, plot_name)}')
            finally:
                # Always release the figure, even if plotting or saving fails.
                plt.close(fig)

    pd.DataFrame.from_dict(out_results).to_csv(summary_path)

## Image preprocessing

In [None]:
import cv2
import glob
import os
import numpy as np

# Rescale every raw image to the 0-255 range and save it as a 3-channel PNG
# named "<original file name>.png" inside ./processed_images.
output_dir = "./processed_images"
os.makedirs(output_dir, exist_ok=True)

image_list = glob.glob(f"{raw_image_directory}/*")
for target_image in image_list:
    cv_x = cv2.imread(f'{target_image}', cv2.IMREAD_ANYDEPTH )
    if cv_x is None:
        # cv2.imread returns None (it does not raise) when a file cannot be
        # decoded, e.g. a sub-directory or a non-image file.
        print(f'skip (not a readable image): {target_image}')
        continue

    # Replicate the single channel three times along a new last axis.
    cv_x = np.repeat(cv_x[..., np.newaxis], 3, -1)
    print(np.min(cv_x))

    # Convert the extrema to native Python numbers first: subtracting
    # `brightness` from an unsigned NumPy scalar can wrap around, which would
    # make the clamp to 0 below ineffective.
    raw_min = np.min(cv_x).item()
    raw_max = np.max(cv_x).item()
    min_value = max(raw_min - brightness, 0)
    value_range = raw_max - min_value
    if value_range > 0:
        cv_x = (cv_x - min_value) / value_range * 255.0
    else:
        # Constant image with no dynamic range: avoid 0/0 and write zeros.
        cv_x = np.zeros(cv_x.shape, dtype=float)
    print(target_image, np.max(cv_x), np.min(cv_x))

    cv2.imwrite(f'{os.path.join(output_dir,os.path.basename(target_image))}.png', cv_x)

In [None]:
import glob
import datasets

# Build a dataset over the PNGs written by the preprocessing cell. The same
# relative path as the writer ("./processed_images") is used so both cells
# agree regardless of the working directory.
IMAGES = sorted(glob.glob("./processed_images/*"))
print(IMAGES)
if not IMAGES:
    print("WARNING: no images found in ./processed_images")

# No ground-truth maps exist at inference time, so the "label" column simply
# points at the same files as "pixel_values".
SEG_MAPS = sorted(IMAGES)
target_ds = datasets.Dataset.from_dict(
    {"pixel_values": IMAGES, "label": SEG_MAPS},
    features=datasets.Features({"pixel_values": datasets.Image(), "label": datasets.Image()}),
)
target_ds

## 🤗 Model download

In [None]:
from transformers import SegformerForSemanticSegmentation
from transformers import SegformerImageProcessor

import torch
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Class-index <-> class-name mappings handed to the model configuration.
id2label = {1: 'head', 0: 'bg'}
label2id = {class_name: class_index for class_index, class_name in id2label.items()}
num_labels = len(id2label)

# Load the selected checkpoint, move it to the device and switch to eval mode.
model = SegformerForSemanticSegmentation.from_pretrained(
    pretrained_model_name,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
model.to(device)
model.eval()

# The image processor configuration is taken from the "nvidia/mit-b5" repository.
baseline_model_name = "nvidia/mit-b5"
processor = SegformerImageProcessor.from_pretrained(baseline_model_name)


## Inference

In [None]:
import matplotlib.patches as mpatches
from skimage.measure import label, regionprops
from skimage.segmentation import clear_border
import numpy as np

import torch

# Each entry is a tuple (image array, label array, predicted mask array);
# get_results reads element 0 and, by default, element 2.
test_images = []

# Inference only: disable autograd so no computation graph is built or kept
# for the forward passes.
with torch.no_grad():
    for vis_index in range(len(target_ds)):
        # Fetch the row once; every access to target_ds[i] re-reads the sample.
        sample = target_ds[vis_index]
        image = sample['pixel_values']

        inputs = processor(images=image.convert("RGB"), return_tensors="pt")
        inputs = inputs.to(device)
        outputs = model(**inputs)
        # Resize the prediction back to the original (height, width).
        new_mask = processor.post_process_semantic_segmentation(outputs, target_sizes=[np.array(image).shape[:2]])
        m = new_mask[0].cpu().detach().numpy()

        test_images.append((np.array(image),
                            np.array(sample['label']),
                            m))

In [None]:
# Write the per-head figures and the summary CSV into ./results.
output_dir = "./results"
os.makedirs(output_dir, exist_ok=True)

get_results(
    test_images,
    IMAGES,
    output_dir,
    minimun_head_size=minimun_head_size,
    min_distance=min_distance,
    threshold_abs=threshold_abs,
)