In [2]:
import numpy as np
from PIL import Image, ImageDraw
from SMM.smm_io.Scan import Scan
from SMM.smm_io.ATF import load_gal

from skimage.filters import *
from skimage.morphology import *
from skimage.feature import *
from skimage.segmentation import *
from skimage.transform import *
from scipy.ndimage import distance_transform_edt
from skimage import draw as skdraw

In [3]:
# Source scan; kept as a named constant so it can be changed in one place.
SCAN_PATH = '/Volumes/RMW_3/RBD_Panel/SMM-opt/50001918_2020-12-30_S46_A1.tif'
# load_tif's result is indexed; this notebook only works with its first entry.
scan = Scan.load_tif(SCAN_PATH)[0]
print(scan)

Scan(50001918, 532nm, Standard Green)


In [4]:
# GAL layout file describing the printed array.
GAL_PATH = '/Volumes/RMW_3/Compact-Set1.gal'
gal = load_gal(GAL_PATH, map_blocks=True)

# Split IDs such as "30-I12" into plate number, row letter and column number.
# IDs that do not match this pattern come back as <NA>.
ids = gal.ID.str.extract(r'(\d+)-([A-P])(\d+)\Z').convert_dtypes()
ids.columns = ['Plate', 'Row', 'Column']

# Sentinel spots: plate 30, rows I-P. Unparsed IDs yield <NA> in the mask,
# which would make boolean indexing raise; treat them as "not a sentinel".
is_sentinel = (ids.Plate == '30') & ids.Row.str.contains(r'[I-P]', na=False)
sentinels = gal[is_sentinel.fillna(False).astype(bool)]

  return pd_object.str.upper().str.replace('.', '')


In [5]:
def impose_grid(pil, gal, xpos, ypos, res):
    """Outline every GAL spot as a red circle on a PIL image.

    Parameters
    ----------
    pil : PIL.Image.Image
        Image that is drawn on in place.
    gal : DataFrame
        Must have ``X``, ``Y`` and ``DIA`` columns.
    xpos, ypos : number
        Offsets subtracted from ``X`` / ``Y`` before scaling.
    res : number
        Scale factor; shifted coordinates and the radius are floor-divided
        by it to get pixel positions.

    Returns
    -------
    PIL.Image.Image
        ``pil`` itself, so the call can be chained or displayed directly.
    """
    # Draw on the image that was passed in (the original referenced an
    # undefined name `image` here).
    draw = ImageDraw.Draw(pil)
    for _, x, y, dia in gal[['X', 'Y', 'DIA']].itertuples():
        cx = (x - xpos) // res
        cy = (y - ypos) // res
        radius = (dia / 2) // res
        draw.ellipse((cx - radius, cy - radius, cx + radius, cy + radius), outline='red')
    return pil

In [19]:
def make_mask(scan, gal):
    """Rasterize GAL spot circles into a cropped boolean mask.

    Each GAL row's (X, Y, DIA) is mapped into scan pixel coordinates using the
    scan's x/y offsets and resolution. A filled disk of diameter DIA is drawn
    for each row, clipped to the scan's shape. The result is cropped to the
    bounding box of all drawn pixels.

    Parameters
    ----------
    scan : Scan
        Reads ``data.shape``, ``x_offset``, ``y_offset`` and ``resolution``.
        NOTE(review): presumably the offsets/resolution share units with the
        GAL coordinates — confirm.
    gal : DataFrame
        Must have ``X``, ``Y`` and ``DIA`` columns.

    Returns
    -------
    mask : ndarray of bool
        True inside any disk, cropped to the disks' bounding box.
    rmin, cmin : int
        Scan-pixel row/column of the mask's top-left corner.

    Raises
    ------
    ValueError
        If ``gal`` has no rows or none of the disks overlaps the scan.
    """
    xpos = scan.x_offset
    ypos = scan.y_offset
    res = scan.resolution
    shape = scan.data.shape

    # (row, col) pixel indices of every disk, clipped to the scan's shape.
    indices = [
        skdraw.disk(((y - ypos) / res, (x - xpos) / res), (dia / 2) / res, shape=shape)
        for _, x, y, dia in gal[['X', 'Y', 'DIA']].itertuples()
    ]
    if not indices:
        raise ValueError("make_mask: gal contains no spots")

    rows = np.concatenate([rr for rr, _ in indices])
    cols = np.concatenate([cc for _, cc in indices])
    if rows.size == 0:
        raise ValueError("make_mask: no spot overlaps the scan area")

    # Shift indices so the mask starts at the disks' bounding-box corner.
    rmin = rows.min()
    cmin = cols.min()
    rows = rows - rmin
    cols = cols - cmin

    output = np.zeros((rows.max() + 1, cols.max() + 1), dtype=bool)
    output[rows, cols] = True
    return output, rmin, cmin

# TODO magic numbers based on circle area
def register_array(scan, sentinels, block_size=35, threshold_factor=2,
                   min_object_size=50, search_fraction=8/9):
    """Find the pixel shift that best aligns the sentinel spot mask with the scan.

    The top ``search_fraction`` of ``scan.data`` is binarised against a
    local-mean threshold, then opened, and small objects are removed. The
    result is turned into a squared distance-to-nearest-spot map. The inverted
    sentinel mask from ``make_mask`` is matched against that map with
    ``match_template``, and the best-scoring location is taken.

    Parameters
    ----------
    scan : Scan
        Scan whose ``data`` is searched.
    sentinels : DataFrame
        GAL rows (``X``, ``Y``, ``DIA``) used to build the template.
    block_size : int
        Window size for ``threshold_local``.
    threshold_factor : float
        A pixel is foreground when brighter than this multiple of its local mean.
    min_object_size : int
        Connected foreground regions smaller than this (pixels) are dropped.
    search_fraction : float
        Fraction of the image rows, from the top, that is searched.

    Returns
    -------
    (dr, dc) : tuple of int
        Row/column shift, in pixels, from the GAL-predicted top-left corner of
        the sentinel mask (``rmin``, ``cmin`` from ``make_mask``) to the
        best-matching location in the scan.
    """
    # Slice bounds must be integers; search_fraction * n is a float.
    lower_border = int(search_fraction * scan.data.shape[0])
    region = scan.data[:lower_border]

    thresh = region > (threshold_local(region, block_size, method='mean') * threshold_factor)
    thresh = binary_opening(thresh)
    # Use the returned array; remove_small_objects' `in_place` flag is deprecated.
    thresh = remove_small_objects(thresh, min_object_size)
    # Squared distance from every pixel to the nearest detected spot.
    dist_sq = distance_transform_edt(~thresh) ** 2

    # make_mask returns (mask, rmin, cmin), not just the mask.
    mask, rmin, cmin = make_mask(scan, sentinels)
    # NOTE(review): the template is the *inverted* mask, as originally written;
    # confirm this is the intended polarity against a distance map.
    score = match_template(dist_sq, ~mask)  # a single correlation map
    r, c = np.unravel_index(np.argmax(score), score.shape)
    return int(r - rmin), int(c - cmin)


    
    

In [44]:
# Scan rows (pixel indices) searched for the sentinel pattern.
ROW_START, ROW_END = 300, 6400

cropped = scan.data[ROW_START:ROW_END]
thresh = cropped > (threshold_local(cropped, 35, method='mean') * 2)
thresh = binary_opening(thresh)
# Use the returned array; remove_small_objects' `in_place` flag is deprecated.
thresh = remove_small_objects(thresh, 50)
# make_mask returns (mask, rmin, cmin); only the mask is the template here.
mask, rmin, cmin = make_mask(scan, sentinels)

d2 = match_template(thresh, mask)
r, c = np.unravel_index(np.argmax(d2), d2.shape)

# Correlation map (negative scores shown as black) with the best match marked.
out1 = Image.fromarray((np.clip(d2, 0, 1) * 255).astype(np.uint8)).convert('RGB')
draw = ImageDraw.Draw(out1)
draw.ellipse([c - 5, r - 5, c + 5, r + 5], fill='red')
out1.show()

# Best-match window on the full scan; rows are shifted back by the crop start.
out2 = Image.fromarray(scan.data).convert('RGB')
draw = ImageDraw.Draw(out2)
draw.rectangle([c, r + ROW_START, c + mask.shape[1], r + ROW_START + mask.shape[0]], outline='red')
out2.show()

In [45]:
from functools import lru_cache
import numpy as np
import pandas as pd
from scipy.spatial import cKDTree
from scipy.ndimage import label as ndi_label
from skimage import transform as tf
from skimage.measure import regionprops
from skimage.morphology import binary_opening as opening


@lru_cache(maxsize=4)
def _find_bright_regions(scan, channel):
    """Label bright connected regions in the upper part of one scan channel.

    Results are memoised per (scan, channel) by ``lru_cache``, so ``scan``
    must be hashable.

    Parameters
    ----------
    scan : Scan
        Must expose ``channels`` and support ``scan[channel]`` indexing.
    channel : hashable
        One of ``scan.channels``.

    Returns
    -------
    list of skimage RegionProperties
        One entry per labelled bright region, in (row, col) coordinates.

    Raises
    ------
    ValueError
        If ``channel`` is not in ``scan.channels``.
    """
    if channel not in scan.channels:
        raise ValueError("Invalid channel %s for guide detection" % channel)
    im = scan[channel]
    # Only the top 5/6 of the rows are searched.
    # NOTE(review): presumably this skips a non-array strip at the bottom — confirm.
    im = im[0:int(5 * im.shape[0] / 6)]
    # TODO Implement a less arbitrary threshold for creating the binary image
    binary_image = im > (im.mean() * 1.3)
    # Two passes of binary opening suppress small bright specks before labelling.
    binary_image = opening(opening(binary_image))
    labeled_image, _ = ndi_label(binary_image)
    # regionprops always works in (row, col) coordinates. The old
    # `coordinates='rc'` keyword was deprecated in skimage 0.16 and later
    # removed, so passing it now raises TypeError.
    return regionprops(labeled_image)


@lru_cache(maxsize=8)
def get_spots(scan, channel, radius):
    """Detect roughly round spots of the expected size in one scan channel.

    A region is kept when all of these hold (with ``radius`` converted to pixels):
    * its area lies strictly between 3.14*(radius/2)**2 and 3.14*(radius*2)**2;
    * its equivalent diameter exceeds ``radius`` and is below 4 * ``radius``;
    * its eccentricity is below 0.8.

    Parameters
    ----------
    scan : Scan
        Hashable scan exposing ``resolution`` and ``offset``.
    channel : hashable
        Channel passed to ``_find_bright_regions``.
    radius : float
        Expected spot radius in scan (physical) units.

    Returns
    -------
    scipy.spatial.cKDTree
        Tree over the kept region centroids, mapped back to scan units via
        ``centroid * resolution + offset``.

    Raises
    ------
    RuntimeError
        If no region passes the filters.
    """
    pixel_radius = radius / scan.resolution
    smallest_area = 3.14 * (pixel_radius / 2) ** 2
    largest_area = 3.14 * (pixel_radius * 2) ** 2

    centroids = []
    for region in _find_bright_regions(scan, channel):
        if not (smallest_area < region.area < largest_area):
            continue
        if not (region.equivalent_diameter > pixel_radius > region.equivalent_diameter / 4):
            continue
        if not region.eccentricity < 0.8:
            continue
        centroids.append(region.centroid)

    if not centroids:
        raise RuntimeError("No spots could be detected in channel %s" % channel)
    return cKDTree(np.array(centroids) * scan.resolution + scan.offset)


class Alignment:
    """Align a GAL spot layout to a TIFF scan using guide spots.

    Parameters
    ----------
    gal : DataFrame
        Spot layout with ``Block``, ``Name``, ``ID``, ``Y``, ``X`` and
        ``Radius`` columns. It must also support the ``matching`` method used
        by the registration helpers.
    scan : Scan
        Must be a TIFF scan (``scan.image.format == "TIFF"``).
    guide_names : dict
        Maps scan channel -> guide-name pattern. Every key must be one of
        ``scan.channels``.

    Attributes
    ----------
    aligned : DataFrame
        Refined ``Y``/``X`` (and ``Radius``) per GAL row. Rows that no step
        has filled in yet stay NaN; ``_current_position`` then falls back to
        the GAL values.
    """

    # (transform type, match radius as a multiple of the mean spot radius)
    # steps fed to icp_register: rigid fits first, then a freer model.
    _ARRAY_TRANSFORMATIONS = [('euclidean', 10)] * 50 + [('similarity', 4)] * 50
    _BLOCK_TRANSFORMATIONS = [('euclidean', 3)] * 50 + [('affine', 2)] * 50

    # A guide point counts as matched when its transformed position lies
    # within this distance of a detected spot. The distance is in GAL
    # coordinate units; scan spots are mapped into those units by get_spots.
    _MATCH_DISTANCE = 80
    # Minimum fraction of guide points that must be matched.
    _MIN_MATCH_FRACTION = 0.8
    # Largest accepted |rotation|. skimage reports rotation in radians.
    # NOTE(review): 3 rad is ~172 degrees; this may have been meant as degrees — confirm.
    _MAX_ROTATION = 3

    def __init__(self, gal, scan, guide_names):
        if scan.image.format != "TIFF":
            raise ValueError("Alignment to non-TIFF images is not supported")
        if len(guide_names) == 0:
            raise ValueError("Guide names must be provided in at least one channel")
        missing = set(guide_names) - set(scan.channels)
        if missing:
            raise ValueError("Channel %s doesn't exist in the scan" % missing)

        self.aligned = pd.DataFrame(index=gal.index, columns=['Y', 'X', 'Radius'])
        self.gal = gal
        self.scan = scan
        self.guides = guide_names

    def _current_position(self):
        """Return GAL positions, overridden by any non-NaN aligned values."""
        current = self.gal[['Block', 'Name', 'ID', 'Y', 'X', 'Radius']].copy()
        current.update(self.aligned)
        return current

    @staticmethod
    def _validate_fit(xform, distance, max_scale_error):
        """Raise RuntimeError if a registration result is not trustworthy.

        Parameters
        ----------
        xform : transform exposing ``rotation`` and ``scale``
        distance : ndarray
            Per-guide-point distance to the nearest detected spot.
        max_scale_error : float
            Largest accepted ``|1 - scale|`` on either axis.
        """
        total = len(distance)
        matched = int(np.count_nonzero(np.asarray(distance) < Alignment._MATCH_DISTANCE))
        # An empty distance array previously caused ZeroDivisionError.
        if total == 0 or matched / total < Alignment._MIN_MATCH_FRACTION:
            raise RuntimeError("Inadequate matching (%s/%s)" % (matched, total))
        # Check both axis scales (affine fits can scale them independently)
        # and the rotation magnitude in either direction.
        scale_error = max(abs(1 - s) for s in xform.scale)
        if abs(xform.rotation) > Alignment._MAX_ROTATION or scale_error > max_scale_error:
            raise RuntimeError("Unreasonable transformation")

    def approximate(self):
        """Coarsely translate all spots so the guide pattern covers the most detected spots."""
        gal = self._current_position()
        xform = approximate_placement(gal, self.scan, self.guides)
        self.aligned[['Y', 'X']] = xform(gal[['Y', 'X']])

    def as_array(self):
        """Refine all spots with one global transform (rigid, then similarity)."""
        gal = self._current_position()
        xform, distance = icp_register(gal, self.scan, self.guides, Alignment._ARRAY_TRANSFORMATIONS)
        self._validate_fit(xform, distance, max_scale_error=0.05)
        self.aligned[['Y', 'X']] = xform(gal[['Y', 'X']])

    def as_blocks(self):
        """Refine each block independently (rigid, then affine).

        Blocks that fail validation keep their previous positions. All
        failures are reported together at the end in a single RuntimeError.
        """
        gal = self._current_position()
        failed = []
        for i, block in gal.groupby("Block"):
            try:
                xform, distance = icp_register(block, self.scan, self.guides, Alignment._BLOCK_TRANSFORMATIONS)
                self._validate_fit(xform, distance, max_scale_error=0.01)
                self.aligned.loc[block.index, ['Y', 'X']] = xform(block[['Y', 'X']])
            except (ValueError, RuntimeError) as ex:
                failed.append("Failed block %s: %s" % (i, str(ex)))
        if failed:
            raise RuntimeError('\n'.join(failed))


def approximate_placement(gal, scan, guide_names):
    """Coarsely place the GAL layout on the scan using a translation only.

    A window the size of the guide spots' bounding box (padded by 1%) is slid
    over the detected scan spots, with its top-left corner anchored at spot
    coordinates. The anchor that covers the most spots wins. The returned
    transform moves the GAL guides' top-left corner onto that anchor.

    Parameters
    ----------
    gal : DataFrame
        Needs ``Radius``, ``Y`` and ``X`` columns and a ``matching(pattern)``
        method selecting guide rows.
    scan : Scan
        Scan searched for spots via ``get_spots``.
    guide_names : dict
        Maps scan channel -> guide name. The names are OR-ed into one pattern.

    Returns
    -------
    skimage.transform.EuclideanTransform
        Pure translation. Its vector is in the same (Y, X) order as the
        coordinates it is applied to.

    Raises
    ------
    ValueError
        If fewer than 4 GAL rows match the guide names.
    RuntimeError
        If fewer than 4 spots are detected across the guide channels.
        ``get_spots`` also raises RuntimeError for a channel with no spots.
    """
    radius = gal.Radius.mean()
    guide_rows = gal.matching('|'.join(guide_names.values()))
    model = gal.loc[guide_rows, ['Y', 'X']].values
    if len(model) < 4:
        raise ValueError("Guide names do not match to rows in the array")
    # np.vstack needs a sequence (list/tuple). Passing a generator is deprecated
    # since NumPy 1.16 and rejected by newer releases.
    scan_points = np.vstack([get_spots(scan, channel, radius).data for channel in guide_names])
    if len(scan_points) < 4:
        raise RuntimeError("Guide spots cannot be detected in scan")

    window_height, window_width = np.ptp(model, 0) * 1.01
    bottom_edge, right_edge = np.max(scan_points, 0)
    model_origin = np.min(model, 0)
    best_point = model_origin
    best_size = 0
    for y_position in np.sort(scan_points[:, 0]):
        y_dist = scan_points[:, 0] - y_position
        covered_points = scan_points[(0 <= y_dist) & (y_dist <= window_height)]
        for x_position in np.sort(covered_points[:, 1]):
            x_dist = covered_points[:, 1] - x_position
            covered_count = ((0 <= x_dist) & (x_dist <= window_width)).sum()
            if covered_count > best_size:
                best_size = covered_count
                best_point = np.array([y_position, x_position])
            # Once the window reaches the right-most spot, moving further right
            # can only lose points.
            if x_position + window_width >= right_edge:
                break
        if y_position + window_height >= bottom_edge:
            break
    return tf.EuclideanTransform(translation=best_point - model_origin)


def icp_register(gal, scan, guides, transformations):
    """Register GAL guide spots to detected scan spots by iterative closest point.

    Parameters
    ----------
    gal : DataFrame
        Needs ``Radius``, ``Y`` and ``X`` columns and a ``matching(name)``
        method selecting the guide rows.
    scan : Scan
        Scan searched for spots via ``get_spots``.
    guides : dict
        Maps scan channel -> guide name.
    transformations : sequence of (str, float)
        Each entry is (skimage transform type, match radius as a multiple of
        the mean GAL radius). One closest-point fit is estimated per entry and
        composed onto the running transform.

    Returns
    -------
    (skimage.transform.AffineTransform, ndarray)
        The composed transform, and for every guide point the distance from
        its transformed position to the nearest detected spot.

    Raises
    ------
    RuntimeError
        If no channel yields both guide rows and detected spots. Also raised
        by ``_register_closest_point`` when a step matches too few points.
    """
    mean_radius = gal.Radius.mean()

    # (guide coordinates, KD-tree of spots detected in the same channel) pairs.
    pairs = []
    for channel, name in guides.items():
        guide_yx = gal.loc[gal.matching(name), ['Y', 'X']].values.astype('float64')
        spot_tree = get_spots(scan, channel, mean_radius)
        if guide_yx.size and spot_tree.data.size:
            pairs.append((guide_yx, spot_tree))

    if not pairs:
        raise RuntimeError("Pairing between guides and detected scan spots wasn't successful")

    models = [guide_yx for guide_yx, _ in pairs]
    trees = [spot_tree for _, spot_tree in pairs]

    # Each step fits against the points as moved by everything estimated so far.
    xform = tf.AffineTransform()
    for ttype, fractional_limit in transformations:
        moved = [xform(points) for points in models]
        xform = xform + _register_closest_point(moved, trees, ttype, mean_radius * fractional_limit)

    distance = np.concatenate([spot_tree.query(xform(guide_yx))[0] for guide_yx, spot_tree in pairs])
    return tf.AffineTransform(xform.params), distance


def _register_closest_point(models, trees, ttype='euclidean', limit=np.inf):
    # Pair the points between the models and the closest detected scan points
    q = []
    t = []
    for model, scene in zip(models, trees):
        d, i = scene.query(model)
        keep = d <= limit
        q.append(model[keep])
        t.append(scene.data[i[keep]])
    current_model = np.vstack(q)
    matched_scene = np.vstack(t)
    if len(current_model) < 4:
        raise RuntimeError("Insufficient guide spot matching to scan")

    # Estimate the next best transform and update the current transform
    return tf.estimate_transform(ttype, current_model, matched_scene)
