# Devel

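Development notebook: runs a YOLO nuclei detector tile by tile over a whole-slide image, merges the per-tile boxes, and pushes the final boxes to the DSA as annotations.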

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Imports
import sys
sys.path.append('/jcDataStore/Projects/NeuroTK-Dash')

import numpy as np 
import torch.nn as nn
from typing import Optional, Tuple
import large_image
from ultralytics import YOLO
import matplotlib.pyplot as plt
import cv2 as cv
from tqdm.notebook import tqdm
from shapely.geometry import Polygon
from geopandas import GeoDataFrame, GeoSeries

from neurotk import login

### YOLO Inference for DAPI Nuclear Detection.

In [8]:
# Parameters that will later be exposed as command-line arguments (CLIs).
args = dict(
    in_file=(
        '/jcDataStore/Data/NeuroTK-Dash/example-images/'
        '9-12-2023 E11-33 IGHM 568 GFAP FITC DAPI.lif'
    ),
    frame=2,  # frame of the multi-frame image to analyze (presumably DAPI - confirm)
    device='0',  # GPU id(s), or 'cpu' / 'cuda'
    max_det=1000,  # maximum detections per tile
    conf_thr=0.5,  # predictions below this confidence are discarded
    iou_thr=0.4,  # IoU threshold for NMS, within tiles and across tiles
    fill=[0, 0, 0],  # RGB padding for tiles that run past the image edge
    contained_thr=0.6,  # drop boxes with more than this fraction inside another
)

# Convert the dictionary to an object with attributes, so parameters can be
# read as `args.name` (the way an argparse namespace would be).
class ArgumentParser:
    """Expose every key of a dictionary as an attribute of the same name.

    Args:
        my_dict (dict): Mapping of attribute names (str) to values.
    """

    def __init__(self, my_dict):
        pairs = my_dict.items()
        for attr_name, attr_value in pairs:
            setattr(self, attr_name, attr_value)
            
    
# Rebinds `args` from the dict to the attribute wrapper. Re-running only this
# line fails, because the wrapper has no `.items()`; re-run the whole cell.
args = ArgumentParser(args)

# Load the trained YOLO nuclei-detection weights.
model = YOLO(
    '/jcDataStore/Data/NeuroTK-Dash/models/nuclei-detection/version1/weights/best.pt'
)

In [4]:
# Girder client authenticated against the DSA REST API.
# NOTE(review): this URL is plain HTTP; if `login` sends a password, it travels
# unencrypted. Check whether the server also serves HTTPS.
gc = login(
    'http://glasslab.neurology.emory.edu:8080/api/v1',
    username='jvizcar',
)

In [5]:
def non_max_suppression(df, thr):
    """Greedy non-max suppression (nms) over a table of scored boxes.

    Boxes are visited from highest to lowest `conf`. Each visited box is kept,
    and every remaining box whose IoU with it is greater than `thr` is dropped.
    Adapted from https://github.com/rbgirshick/fast-rcnn/blob/master/lib/utils/nms.py

    INPUTS
    ------
    df : dataframe
        one row per box; must contain the x1, y1, x2, y2 and conf columns, with
        (x1, y1) the top-left corner and (x2, y2) the bottom-right corner
    thr : float
        IoU threshold used for nms; boxes with IoU <= thr survive a round

    RETURN
    ------
    df : dataframe
        the kept rows in order of decreasing conf, labelled by their position
        in the input (the input index is reset first)

    """
    # Positions found below are reused as row labels, so index must be 0..n-1.
    df = df.reset_index(drop=True)
    left, top, right, bottom, scores = (
        df[['x1', 'y1', 'x2', 'y2', 'conf']].to_numpy().T
    )

    # Inclusive-pixel (+1) convention, as in the reference implementation.
    areas = (right - left + 1) * (bottom - top + 1)
    pending = scores.argsort()[::-1]  # highest confidence first

    kept = []
    while pending.size > 0:
        best, others = pending[0], pending[1:]
        kept.append(best)

        overlap_w = np.maximum(
            0.0,
            np.minimum(right[best], right[others])
            - np.maximum(left[best], left[others]) + 1
        )
        overlap_h = np.maximum(
            0.0,
            np.minimum(bottom[best], bottom[others])
            - np.maximum(top[best], top[others]) + 1
        )
        inter = overlap_w * overlap_h
        iou = inter / (areas[best] + areas[others] - inter)

        # Survivors keep their confidence order for the next round.
        pending = others[np.nonzero(iou <= thr)[0]]

    return df.loc[kept]


def remove_contained_boxes(df, thr):
    """Remove boxes that are contained, or mostly contained, in another box.

    Boxes are visited in the row order of `df`. For each box not yet removed,
    every other not-yet-removed box that has more than `thr` of its own area
    inside the visited box is removed. A removed box is never visited, so it
    cannot remove other boxes.

    INPUTS
    ------
    df : geodataframe
        info about each box; needs x1, y1, x2, y2 columns, a box geometry, and
        a unique index (rows are looked up with `df.loc`)
    thr : float
        a box is removed when the fraction of its area lying inside another
        box is strictly greater than this threshold

    RETURNS
    -------
    df : geodataframe
        the boxes that are left, in their original order

    """
    # A set gives O(1) membership checks; a list was scanned on every lookup.
    removed = set()

    for i, geo in df.geometry.items():
        # Don't check boxes that have already been removed.
        if i in removed:
            continue

        r = df.loc[i]

        # Cheap bounding-box pre-filter: candidates are the other, not yet
        # removed boxes whose extents overlap (or touch) the current box.
        overlapping = df[
            (~df.index.isin(removed | {i}))
            & ~((r.y2 < df.y1) | (r.y1 > df.y2) | (r.x2 < df.x1) | (r.x1 > df.x2))
        ]

        # Fraction of each candidate's own area that lies inside the current box.
        perc_overlap = overlapping.intersection(geo).area / overlapping.area

        removed.update(overlapping[perc_overlap > thr].index)

    return df.drop(index=list(removed))

In [9]:
# WSI inference: a tile-reading helper plus the tiled inference function.
def _read_tile(ts, x, y, fr_tile_size, tile_size, mag, frame, fill):
    """Read one square region and return it as a tile_size x tile_size array.

    Args:
        ts: large_image tile source to read from.
        x, y (int): Top-left corner of the region, in the coordinates passed to
            `getRegion`'s `region` argument (full-resolution pixels).
        fr_tile_size (int): Side length of the region in those coordinates.
        tile_size (int): Side length of the returned array.
        mag (float): Magnification passed to `getRegion` via `scale` (may be None).
        frame (int): Frame passed to `getRegion` (may be None).
        fill: Constant value used to pad tiles smaller than `tile_size`.

    Returns:
        numpy.ndarray: Array of shape (tile_size, tile_size, C). Single-channel
            tiles are converted to 3 channels with OpenCV's GRAY2RGB.
    """
    img = ts.getRegion(
        region={
            'left': x, 'top': y,
            'right': x + fr_tile_size, 'bottom': y + fr_tile_size
        },
        format=large_image.constants.TILE_FORMAT_NUMPY,
        scale={'magnification': mag},
        frame=frame
    )[0]

    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[2] == 1:
        img = cv.cvtColor(img[:, :, 0], cv.COLOR_GRAY2RGB)

    # A rescaled region can round up past tile_size; crop it so the padding
    # below is never negative.
    img = np.ascontiguousarray(img[:tile_size, :tile_size])

    # Pad edge tiles on the bottom/right up to the full tile size.
    h, w = img.shape[:2]
    if (h, w) != (tile_size, tile_size):
        img = cv.copyMakeBorder(
            img, 0, tile_size - h, 0, tile_size - w,
            cv.BORDER_CONSTANT, None, fill
        )

    return img


def wsi_inference(
    fp: str, model: nn.Module, mask: Optional[np.ndarray] = None, 
    frame: Optional[int] = None, mag: Optional[float] = None,
    tile_size: int = 1280, stride: int = 960, batch_size: int = 10,
    device: str = 'cpu', max_det: int = 1000, iou_thr: float = 0.6, 
    conf_thr: float = 0.5, fill: Tuple[int, int, int] = (255, 255, 255),
    contained_thr: float = 0.8 
):
    """Inference a YOLO model on a large image by tiling it into smaller
    overlapping regions and then merging predictions.

    Args:
        fp (str): Filepath to image, must be openable by large_image.
        model (torch.nn.Module): YOLO model used to predict labels (must
            provide `predict`).
        mask (numpy.ndarray): Optional low resolution mask used to narrow
            down the regions to analyze. Not supported yet: if given, a
            message is printed once and every tile is still analyzed.
        frame (int): Optional frame of multiplex image to analyze.
        mag (float): Optional magnification to analyze images at. Tiling is
            only rescaled when the image metadata reports its magnification.
        tile_size (int): Tile size, in pixels at the analysis magnification.
        stride (int): Stride between tiles, in pixels at the analysis
            magnification.
        batch_size (int): Number of tiles to predict in bulk.
        device (str): Options are 'cuda', 'cpu', or the ids of GPU devices
            (e.g. '0' or '0,1').
        max_det (int): Maximum number of detections per tile.
        iou_thr (float): IoU threshold for NMS, used both by the model inside
            each tile and when merging boxes across tiles.
        conf_thr (float): Confidence threshold, predictions below this
            threshold are discarded.
        fill (Tuple[int, int, int]): RGB color used to pad tiles that extend
            past the image edge.
        contained_thr (float): After NMS, boxes with more than this fraction
            of their area inside another box are removed.

    Returns:
        geopandas.GeoDataFrame: One row per box with columns label, x1, y1,
            x2, y2, conf, geometry (box polygon) and box_area. Coordinates are
            full-resolution pixel coordinates of the image.
    """
    # Get image tilesource, currently throwing warnings for these images.
    ts = large_image.getTileSource(fp)
    ts_metadata = ts.getMetadata()

    # fr_to_mag converts analysis-magnification pixels into full-resolution
    # pixels (scan magnification / desired magnification).
    scan_mag = ts_metadata.get('magnification')
    if mag is not None and scan_mag is not None:
        fr_to_mag = scan_mag / mag
        # Floor at 1 so range() never gets a zero step.
        fr_tile_size = max(1, int(tile_size * fr_to_mag))
        fr_stride = max(1, int(stride * fr_to_mag))
    else:
        fr_to_mag = 1.0
        fr_tile_size, fr_stride = tile_size, stride

    if mask is not None:
        # TODO: skip tiles that are not sufficiently inside the mask.
        print("Mask support not currently available, defaulting to "
              "include all tiles.")

    # Top-left (x, y) of every tile, row by row, in full-resolution pixels.
    xys = [
        (x, y)
        for y in range(0, ts_metadata['sizeY'], fr_stride)
        for x in range(0, ts_metadata['sizeX'], fr_stride)
    ]

    rows = []  # one record per predicted box

    # Predict on tiles in batches.
    idx = list(range(0, len(xys), batch_size))

    print(f'Predicting on tiles for {len(idx)} batches.')
    for i in tqdm(idx):
        batch_xys = xys[i:i + batch_size]
        imgs = [
            _read_tile(ts, x, y, fr_tile_size, tile_size, mag, frame, fill)
            for x, y in batch_xys
        ]

        batch_out = model.predict(
            imgs,
            device=device,
            max_det=max_det,
            iou=iou_thr,
            conf=conf_thr,
            imgsz=tile_size,
            verbose=False
        )

        for (x, y), out in zip(batch_xys, batch_out):
            boxes = out.boxes
            # Copy each result tensor to host once per tile, not once per box.
            labels = boxes.cls.cpu().detach().numpy()
            xyxys = boxes.xyxy.cpu().detach().numpy()
            confs = boxes.conf.cpu().detach().numpy()

            for label, box, cf in zip(labels, xyxys, confs):
                # Tile pixels -> full-resolution pixels. Python floats keep the
                # offset in float64 (the model output may be float32) and give
                # numeric, not object, DataFrame columns.
                x1 = x + float(box[0]) * fr_to_mag
                y1 = y + float(box[1]) * fr_to_mag
                x2 = x + float(box[2]) * fr_to_mag
                y2 = y + float(box[3]) * fr_to_mag

                rows.append([
                    int(label), x1, y1, x2, y2, float(cf),
                    Polygon([(x1, y1), (x2, y1), (x2, y2), (x1, y2)]),
                    (x2 - x1) * (y2 - y1)
                ])

    # Compile boxes into dataframe.
    pred_df = GeoDataFrame(
        rows,
        columns=['label', 'x1', 'y1', 'x2', 'y2', 'conf', 'geometry',
                 'box_area']
    )

    print(f"Merging overlapping boxes from a starting {len(pred_df)} boxes...")
    pred_df = non_max_suppression(pred_df, iou_thr)
    pred_df = remove_contained_boxes(pred_df, contained_thr)
    print(f'    {len(pred_df)} final predicted boxes.')

    return pred_df

    
# Run tiled inference. The CLI-style parameters in `args` share their names
# with wsi_inference's keyword arguments, so they are forwarded by name.
pred_df = wsi_inference(
    args.in_file,
    model,
    **{
        name: getattr(args, name)
        for name in (
            'frame', 'device', 'max_det', 'iou_thr', 'conf_thr', 'fill',
            'contained_thr',
        )
    },
)

Predicting on tiles for 2 batches.


  0%|          | 0/2 [00:00<?, ?it/s]

Merging overlapping boxes from a starting 1620 boxes...
    894 final predicted boxes.


In [10]:
# Push results to DSA as annotations: one green rectangle per predicted box.
# NOTE(review): the item id below is hard-coded; presumably it is the DSA item
# for args.in_file - confirm before pushing results for another image.
elements = [
    {
        'lineColor': 'rgb(0,255,0)',
        'lineWidth': 2,
        'rotation': 0,
        'type': 'rectangle',
        # Rectangles are given by their center (x, y, z) plus width/height.
        'center': [(box.x2 + box.x1) / 2, (box.y2 + box.y1) / 2, 0],
        'width': box.x2 - box.x1,
        'height': box.y2 - box.y1,
        'label': {'value': 'nucleus'},
        'group': 'nucleus',
    }
    for box in pred_df.itertuples(index=False)
]

_ = gc.post(
    '/annotation?itemId=65088a9b9a8ab9ec771ba6b6',
    json={
        'name': 'yolo-inference-test-smaller-thresholds',
        'description': '',
        'elements': elements,
    }
)