In [15]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from scipy import ndimage
from skimage.feature import peak_local_max
from matplotlib import cm
from scipy import ndimage
from skimage.segmentation import watershed
import pickle
from tqdm.auto import tqdm
import copy

MIN = 0
MAX = 2 ** 16 - 1
NROWS = 700
NCOLS = 1100

##############################################################################
# import all images per sequence, return images
def import_sequences():
    PATH = 'sequences/0'
    imgs = []
    for i in tqdm(range(1,5), desc='Importing Images'):
        imgs_path = []
        for image in glob(f'{PATH}{i}/*.tif'):
            img = cv2.imread(image, -1)
            imgs_path.append(cv2.normalize(img, dst=None, alpha=MIN, beta=MAX,
                             norm_type=cv2.NORM_MINMAX))
        imgs.append(imgs_path)
    for i in range(4):
        dump_obj(imgs[i], f'imgs_seq_0{i+1}.p')

##############################################################################
# get white points of an image, return dict of coord tuples
def points(img):
    pts = {}
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            if img[i,j] == MAX:
                pts[(i,j)] = None
    return pts

##############################################################################
# create an image from points
def create_img(pts):
    new = np.zeros((NROWS, NCOLS))
    for (i, j) in pts:
        new[i, j] = MAX
    return new.astype('uint16')

##############################################################################
# flush all cells that touch the border of the image
def flush_border_cells(pts, img):
    to_check = []
    border_pts = {}
    border_img = np.zeros((NROWS, NCOLS))
    for pt in pts:
        if pt[0] == 0 or pt[1] == 0 or pt[0] == NROWS-1 or pt[1] == NCOLS-1:
            to_check.append(pt)
            border_pts[pt] = None
            border_img[pt[0], pt[1]] = MAX
    while len(to_check) > 0:
        b_pt = to_check.pop(0)
        for i in range(-1, 2):
            for j in range(-1, 2):
                if (b_pt[0]+i, b_pt[1]+j) in pts and\
                   (b_pt[0]+i, b_pt[1]+j) not in border_pts:
                    to_check.append((b_pt[0]+i, b_pt[1]+j))
                    border_pts[(b_pt[0]+i, b_pt[1]+j)] = None
                    border_img[b_pt[0]+i, b_pt[1]+j] = MAX
    new = img - border_img
    return points(new), new.astype('uint16')

##############################################################################
# binary erosion of points with 3x3 filter
def erode(pts):
    eroded1_pts = {}
    for (x,y) in pts:
        if (x,y-1) in pts and (x, y+1) in pts:
            eroded1_pts[(x,y)] = pts[(x,y)]
    eroded2_pts = {}
    for (x,y) in eroded1_pts:
        if (x-1,y) in eroded1_pts and (x+1, y) in eroded1_pts:
            eroded2_pts[(x,y)] = eroded1_pts[(x,y)]
    return eroded2_pts

##############################################################################
# check coords within image bounds
def check_coords(pt):
    if pt[0] < 0 or pt[0] > NROWS-1 or pt[1] < 0 or pt[1] > NCOLS-1:
        return False
    return True

##############################################################################
# binary dialtion of points with 3x3 filter
def dilate(pts):
    dilated1_pts = pts.copy()
    for (x,y) in pts:
        if (x,y-1) not in pts and check_coords((x,y-1)):
            dilated1_pts[(x, y-1)] = pts[(x,y)]
        if (x, y+1) not in pts and check_coords((x,y+1)):
            dilated1_pts[(x, y+1)] = pts[(x,y)]
    dilated2_pts = dilated1_pts.copy()
    for (x,y) in dilated1_pts:
        if (x-1,y) not in dilated1_pts and check_coords((x-1,y)):
            dilated2_pts[(x-1, y)] = dilated1_pts[(x,y)]
        if (x+1, y) not in dilated1_pts and check_coords((x+1,y)):  
            dilated2_pts[(x+1, y)] = dilated1_pts[(x,y)]
    return dilated2_pts

##############################################################################
# get point-label pairs grouped by label
def get_labels(pts):
    label = {}
    for pt, lbl in pts.items():
        if lbl not in label:
            label[lbl] = {pt: lbl}
        else:
            label[lbl][pt] = lbl
    return label

##############################################################################
# segment the given image
def segment(img):
    # CLAHE preprocess: segment from background
    clahe = cv2.createCLAHE(clipLimit=80.0, tileGridSize=(25,25))
    cl1 = clahe.apply(img)

    hist = cv2.calcHist([cl1],[0],None,[257],[0,MAX])
    from_max = hist[np.argmax(hist):]
    for intensity, _ in enumerate(from_max[:-1]):
        if from_max[intensity+1] >= from_max[intensity]:
            for i in range(30):
                if intensity + i >= len(from_max):
                    break
                if from_max[intensity + i] < intensity - 250:
                    break
            else:
                break

    final = (intensity + np.argmax(hist)) * 255
    seg = np.zeros_like(cl1)
    cv2.threshold(cl1, dst=seg, thresh=final, maxval=MAX,
                  type=cv2.THRESH_BINARY)
    opens = np.zeros_like(seg)
    elem = cv2.getStructuringElement(shape=cv2.MORPH_RECT, ksize=(5,5))
    cv2.morphologyEx(seg, dst=opens, op=cv2.MORPH_OPEN, kernel=elem)

    opens = create_img(erode(points(opens)))

    flushed_pts, _ = flush_border_cells(points(opens), opens)

    # watershed: segment cells from each other
    distance = ndimage.distance_transform_edt(opens)
    coords = peak_local_max(distance, footprint=np.ones((5,5)),
                            min_distance = 20, labels = opens)
    mask = np.zeros(distance.shape, dtype = bool)
    mask[tuple(coords.T)] = True
    markers, _ = ndimage.label(mask)
    ws_labels = watershed(-distance, markers, mask = opens)
    ws_pts = {}
    for i in range(opens.shape[0]):
        for j in range(opens.shape[1]):
            if ws_labels[i,j] == 0:
                continue
            ws_pts[(i,j)] = ws_labels[i,j]-1

    # seperate open all cells
    labels = get_labels(ws_pts)
    new_pts = {}
    for label, cell_pts in labels.items():
        dilated = cell_pts.copy()
        for _ in range(3):
            dilated = dilate(dilated)

        eroded = dilated.copy()
        for _ in range(3):
            eroded = erode(eroded)

        for pt, label in eroded.items():
            new_pts[pt] = label

    # mix local labels
    all_pts = new_pts.copy()
    cents = centroids(new_pts)
    len(cents)
    shuffled = np.arange(0, len(cents))
    np.random.shuffle(shuffled)
    swap_labels = {}
    for i, label in enumerate(cents.values()):
        swap_labels[label] = shuffled[i]
    pts = {}
    for pt, label in all_pts.items():
        pts[pt] = swap_labels[label]
    
    # delete small cells
    labels = get_labels(pts)
    temp_pts = pts.copy()
    del_labels = []
    for label,pts1 in labels.items():
        if len(pts1) < 50:
            del_labels.append(label)
    for pt, label in pts.items():
        if label in del_labels: 
            del temp_pts[pt]

    # adjust flushed points to reflect changes
    temp_flushed_pts = flushed_pts.copy()
    for pt in flushed_pts.keys():
        if pt not in temp_pts:
            del temp_flushed_pts[pt]
        else:
            temp_flushed_pts[pt] = temp_pts[pt]
    return temp_pts, temp_flushed_pts

##############################################################################
# dump given object to filepath
def dump_obj(obj, filepath):
    with open(filepath, 'wb') as outfile:
        pickle.dump(obj, outfile)

##############################################################################
# load given object from filepath
def load_obj(filepath):
    with open(filepath, 'rb') as infile:
        return pickle.load(infile)

##############################################################################
# load an image sequence
def load_imgs(seq_num):
    return load_obj(f'imgs_seq_0{seq_num}.p')

##############################################################################
# segment the images and store the data
def segment_store():
    for sdx in range(1, 5):
        sequence = load_imgs(sdx)
        seg = {}
        seg_flushed = {}
        for fdx, frame in enumerate(tqdm(sequence,
                                         desc=f'Segmenting sequence 0{sdx}')):
            seg[fdx], seg_flushed[fdx] = segment(frame)
        dump_obj(seg, f'segmented_0{sdx}.p')
        dump_obj(seg_flushed, f'segmented_flushed_0{sdx}.p')

##############################################################################
# from points and labels create a coloured image
def colour_label_image(pts_labels, max):
    cmap = copy.copy(cm.get_cmap('hsv'))
    new_img = np.zeros((NROWS, NCOLS))
    cmap.set_bad(color='black')
    for (pt, label) in pts_labels.items():
        new_img[pt[0],pt[1]] = label
    norm_new = plt.Normalize(new_img.min(), max)
    for i in range(NROWS):
        for j in range(NCOLS):
            if (i,j) not in pts_labels:
                new_img[i,j] = np.nan
    rgba_img = cmap(norm_new(new_img))
    return rgba_img, cmap, norm_new

##############################################################################
# plot scatter of centroids over the colour labelled image
def plt_rgb_img(centroids,pts_labels,traj,max,save,seq_data,num_cells,
    avg_size,avg_disp, num_div,og_img, splitting, alert):
    rgb_img, cmap, norm_new = colour_label_image(pts_labels, max)
    plt.figure(figsize=(15,15))
    new = cv2.cvtColor(og_img, cv2.COLOR_GRAY2RGB)
    new = (new * (255/MAX)).astype('uint8')
    for i in range(NROWS):
        for j in range(NCOLS):
            if (i,j) in pts_labels.keys():
                new[i,j,:] = (rgb_img[i,j,:-1] * 255).astype('uint8')
            elif (i,j) in alert.keys():
                new[i,j,:] = np.ones((3)) * 255
    plt.imshow(new)
    for label, trajectories in traj.items():
        plt.plot(trajectories[1], trajectories[0], c=cmap(norm_new(label)))
    plt.title(f'Sequence 0{seq_data[0]}: frame {seq_data[1]:02d}')
    plt.text(0,750, f'Cell count: {num_cells:02d}')
    plt.text(250, 750, f'Average cell size: {avg_size}')
    if avg_disp == -1:
        plt.text(500,750, 'Average displacement: N/A')
    else:
        plt.text(500,750, f'Average displacement: {avg_disp}')
    plt.text(750,750, f'Currently dividing: {num_div}')
    plt.savefig(save)
    plt.close()

##############################################################################
# get cell contours (outlines) from points
def contours(pts):
    dilated2_pts = dilate(pts)
    sub = {}
    for pt, label in dilated2_pts.items():
        if pt not in pts:
            sub[pt] = label
    return sub

##############################################################################
# get centroids from points
def centroids(pts):
    labels = {}
    for (x,y), label in pts.items():
        if label not in labels:
            labels[label] = {'x': [x], 'y': [y]}
            continue
        labels[label]['x'].append(x)
        labels[label]['y'].append(y)
    centres = {}
    for label, data in labels.items():
        centres[int(sum(data['x'])/len(data['x'])), 
                int(sum(data['y'])/len(data['y']))] = label
    return centres

##############################################################################
# get euclidean distances between the centroids of two frames
def distance_matrix(cent1, cent2):
    dist_mat = np.zeros((len(cent1), len(cent2)))
    rows, cols = {}, {}
    for i, pt1 in enumerate(cent1.keys()):
        rows[i] = pt1
        for j, pt2 in enumerate(cent2.keys()):
            dist_mat[i,j] = np.sqrt(np.square(pt1[0]-pt2[0]) + np.square(pt1[1]-pt2[1]))
    for j, pt2 in enumerate(cent2.keys()):
        cols[j] = pt2
    return dist_mat, rows, cols

##############################################################################
# tracking
def nearest_neighbour(seq, show_img=True, seq_num=None, flushed_pts=None, imgs=None):
    # intialise values and tracking data
    init_centroid = centroids(seq[0])
    cent_global_labels = {0: init_centroid.copy(), -1: 1+ max(list(init_centroid.values()))}

    cent_local_labels = {0: init_centroid.copy()}

    displacement = {0: {}}
    for label in cent_global_labels[0].values():
            displacement[0][label] = 0

    # distance matrices
    dist_mats = {}

    # centroids of cells about to split
    split = {}

    for t in tqdm(range(len(seq)-1), desc=f'Global labels sequence 0{seq_num}'): 
        cent0, cent1 = centroids(seq[t]), centroids(seq[t+1])
        dist_mat, rows, cols = distance_matrix(cent0, cent1)
        dist_mats[t] = dist_mat.copy()
        cent_local_labels[t+1] = cent1.copy()

        # get global labels for centroids
        cent_global_labels[t+1] = {}
        displacement[t+1] = {}
        new_cells = cent1.copy()
        thresh = 50
        for i, row in enumerate(dist_mat):
            if np.amin(row) > thresh:
                continue
            if cols[np.argmin(row)] in cent_global_labels[t+1].keys():
                if np.amin(row) < displacement[t+1][cent_global_labels[t+1][cols[np.argmin(row)]]]:
                    cent_global_labels[t+1][cols[np.argmin(row)]] = cent_global_labels[t][rows[i]]
                    displacement[t+1][cent_global_labels[t][rows[i]]] = np.amin(row)
            else:
                cent_global_labels[t+1][cols[np.argmin(row)]] = cent_global_labels[t][rows[i]]
                displacement[t+1][cent_global_labels[t][rows[i]]] = np.amin(row)
            if cols[np.argmin(row)] in new_cells:
                del new_cells[cols[np.argmin(row)]]
        
        for new in new_cells.keys():
            cent_global_labels[t+1][new] = cent_global_labels[-1]
            cent_global_labels[-1] += 1

        
        # get potential cell splits
        split[t] = []
        closest = {}
        for j, col in enumerate(dist_mat.T):
            for i in np.argsort(col):  
                if col[i] <= thresh:
                    if np.amin(dist_mat[i]) > 10:
                        close = False
                    else:
                        close = True           
                    area = 0
                    for label in seq[t+1].values():
                        if label == cent_local_labels[t+1][cols[j]]:
                            area += 1
                    if i not in closest:
                        closest[i] = [(area, close)]
                    else:
                        closest[i].append((area, close))
                elif col[i] > thresh:
                    break
        
        temp_closest = closest.copy()
        for i, areas in temp_closest.items():
            if len(areas) <= 2:
                continue
            closest[i] = []
            for area in areas:
                if not area[1]:
                    closest[i].append(area)

        for i, areas in closest.items():
            old_area = 0
            for label in seq[t].values():
                if label == cent_local_labels[t][rows[i]]:
                    old_area += 1
            for m, area1 in enumerate(areas):
                for k, area2 in enumerate(areas):
                    if m == k:
                        continue
                    if old_area * 0.5 <= area1[0] + area2[0] <= old_area * 1.25:
                        if rows[i] not in split[t]:
                            split[t].append(rows[i])
                        break

    # checking that very faint cells are omitted if they are only segmented 
    # sporadically in the background, not if near other cells as causes collsions
    # get label count up to required frame
    frames = 3
    label_count = {}
    for t in range(len(seq)-frames):
        for label in cent_global_labels[t].values():
            if label not in label_count:
                label_count[label] = 1
            else:
                label_count[label] += 1

    # add potential noise labels to noise
    noise = []
    for label, count in label_count.items():
        if count <= frames:
            for t in range(len(seq)-frames, len(seq)):
                if label in cent_global_labels[t].values():
                    break
            else:
                noise.append(label)

    # get label count for last frames
    label_count = {}
    for t in range(len(seq)-frames, len(seq)):
        for label in cent_global_labels[t].values():
            if label not in label_count:
                label_count[label] = 1
            else:
                label_count[label] += 1

    # add potential noise labels to noise
    for label, count in label_count.items():
        if count <= 1:
            noise.append(label)

    # points that are close to other points must not be removed
    for t in range(len(seq)):
        cent = centroids(seq[t])
        dist_mat, rows, cols = distance_matrix(cent, cent)
        for m in range(dist_mat.shape[0]):
            for n in range(dist_mat.shape[1]):
                if m == n:
                    dist_mat[m,n] = np.inf
        for i, row in enumerate(dist_mat):
            if np.amin(row) <= thresh and \
               cent_global_labels[t][rows[i]] in noise:
                noise.remove(cent_global_labels[t][rows[i]])

    # delete these points from the data, they are background noise
    temp_cent_global_labels = cent_global_labels.copy()
    for del_label in noise:
        for t in range(len(seq)):
            if del_label in displacement[t]:
                del displacement[t][del_label]
            if t < len(seq)-1:
                temp_split = split[t].copy()
                for cent in temp_split:
                    if cent_global_labels[t][cent] == del_label:
                        split[t].remove(cent)
            for cent, label in list(temp_cent_global_labels[t].items()):
                if label == del_label:
                    del cent_global_labels[t][cent]

    # if cells are split, reconcile future labels to reflect this change
    # for each frame loop through splitting cell centroids
    for t in range(len(seq)-1):
        for split_cent in split[t]:
            # search for the splitting cell label in the next frame
            for label in cent_global_labels[t+1].values():
                # if label is found
                if label == cent_global_labels[t][split_cent]:
                    # loop through frames onwards and change labels
                    for f in range(t+1, len(seq)):
                        for cent2, label2 in cent_global_labels[f].items():
                            if label2 == label:
                                cent_global_labels[f][cent2] = cent_global_labels[-1]
                                displacement[f][cent_global_labels[-1]] = 0
                                #if label2 in displacement[f]:
                                del displacement[f][label2]
                                #break
                    cent_global_labels[-1] += 1
                    break

    # mix global centroid labels
    shuffled = list(range(cent_global_labels[-1]))
    np.random.shuffle(shuffled)
    for t in range(len(seq)):
        for pt, label in cent_global_labels[t].items():
            cent_global_labels[t][pt] = shuffled[label]
    new_disp = {}
    for t in range(len(seq)):
        new_disp[t] = {}
        for label, dist in displacement[t].items():
            new_disp[t][shuffled[label]] = dist
    displacement = new_disp.copy()

    # track centroid positions over time
    centroid_paths = {}
    for pt, label in cent_global_labels[0].items():
        centroid_paths[label] = {}
        centroid_paths[label][0] = pt

    # get cell contours
    seq_contours = {}
    for t, pts in enumerate(seq.values()):
        seq_contours[t] = contours(pts)

    # global labels for cell contours
    global_labels = {}
    global_labels[0] = {}
    for pt, local in seq_contours[0].items():
        for cent, global_label in cent_global_labels[0].items():
            if cent_local_labels[0][cent] == local:
                global_labels[0][pt] = global_label
                continue

    # piecewise cell trajectories
    traj = {}
    for pt, label in cent_global_labels[0].items():
        traj[label] = {}
        traj[label][0] = [[pt[0]],[pt[1]]]

    for t in range(len(seq)-1):
        # update paths of centroids
        for pt, label in cent_global_labels[t].items():
            if label not in centroid_paths:
                centroid_paths[label] = {}
            centroid_paths[label][t+1] = pt
        
        # get global labels for each point of cell contours
        global_labels[t+1] = {}
        for pt, local in seq_contours[t+1].items():
            for cent, global_label in cent_global_labels[t+1].items():
                if cent_local_labels[t+1][cent] == local:
                    global_labels[t+1][pt] = global_label
                    continue

        # update trajectories for centorids in image
        for pt, label in cent_global_labels[t+1].items():
            if label not in traj:
                traj[label] = {}
                traj[label][t+1] = [[pt[0]],[pt[1]]]
            else:
                traj[label][t+1] = [traj[label][t][0]+[pt[0]],traj[label][t][1]+[pt[1]]]

    # cell extrema, min-max order
    extrema = {}
    for t in range(len(seq)):
        extrema[t] = {}
        for pt, label in global_labels[t].items():
            for cent, cent_label in cent_global_labels[t].items():
                if cent_label == label:
                    break
            dist = ((cent[0]-pt[0]) ** 2 + (cent[1]-pt[1]) ** 2) ** 0.5
            if cent not in extrema[t]:
                extrema[t][cent] = [dist, dist]
            if dist < extrema[t][cent][0]:
                extrema[t][cent][0] = dist
            elif dist > extrema[t][cent][1]:
                extrema[t][cent][1] = dist
        # add potential splits to list
        if t not in split.keys():
            split[t] = []
        for cent, dists in extrema[t].items():
            # cell meets criteria
            if dists[0]/dists[1] < 0.3:
                # check cell splits in the window
                for i, f in enumerate(range(t, len(seq))):
                    if i >= 5:
                        break
                    # cell label non-existant in window, add split to all frames
                    if cent_global_labels[t][cent] not in cent_global_labels[f].values():
                        dist_mat, rows, cols = distance_matrix(cent_global_labels[f-1], cent_global_labels[f])
                        for r, row in enumerate(dist_mat):
                            if cent_global_labels[f-1][rows[r]] == cent_global_labels[t][cent]:
                                break
                        args = np.argsort(row)
                        dist1, dist2 = row[args[0]], row[args[1]]
                        if 0.7 <= dist1/dist2 <= 1.3:
                            for k in range(t, t+i):
                                for cent2, label2 in cent_global_labels[k].items():
                                    if label2 == cent_global_labels[t][cent]:
                                        if cent2 not in split[k]:
                                            split[k].append(cent2)
                        break

    if show_img:
        for t in range(len(seq)):
            valid_traj = {}
            for label, times in traj.items():
                if label not in cent_global_labels[t].values():
                    continue
                valid_traj[label] = traj[label][t]
            if seq_num == None:
                plt_rgb_img(cent_global_labels[t], global_labels[t], traj=valid_traj, 
                max=cent_global_labels[-1])
            else:
                flushed_cents = centroids(flushed_pts[t])
                if t == 0:
                    avg_disp = None
                else:
                    avg_disp = 0
                    for disp in displacement[t].values():
                        avg_disp += disp
                    avg_disp = round(avg_disp/len(cent_global_labels[t]), 2)
                dividing = {}
                for split_cent in split[t]:
                    for pt, label in global_labels[t].items():
                        if label == cent_global_labels[t][split_cent]:
                            dividing[pt] = 0
                alert_pts = dilate(dilate(dividing))
                plt_rgb_img(cent_global_labels[t], global_labels[t], traj=valid_traj, 
                max=cent_global_labels[-1], save=f'track_{seq_num}_{t}.png',
                seq_data=(seq_num, t), num_cells=len(cent_global_labels[t]),
                avg_size=round(len(flushed_pts[t])/len(flushed_cents), 2),
                avg_disp=avg_disp, num_div=len(split[t]), og_img=imgs[t],
                splitting=split[t], alert=alert_pts)
    return global_labels, split, traj

##############################################################################
# track and store the images
def track_store():
    for i in range(1,5):
        imgs = load_imgs(i)
        seq = load_obj(f'segmented_0{i}.p')
        flushed = load_obj(f'segmented_flushed_0{i}.p') 
        global_labels, splits, traj = nearest_neighbour(seq, True, i, flushed, imgs)
        dump_obj(global_labels, f'global_labels_0{i}.p')
        dump_obj(splits, f'splits_0{i}.p')
        dump_obj(traj, f'traj_0{i}.p')

In [4]:
import_sequences()

Importing Images: 100%|██████████| 4/4 [00:12<00:00,  3.23s/it]


In [12]:
segment_store()

Segmenting sequence 01: 100%|██████████| 92/92 [06:09<00:00,  4.01s/it]
Segmenting sequence 02: 100%|██████████| 92/92 [08:31<00:00,  5.56s/it]
Segmenting sequence 03: 100%|██████████| 92/92 [06:47<00:00,  4.43s/it]
Segmenting sequence 04: 100%|██████████| 92/92 [07:30<00:00,  4.90s/it]


In [16]:
track_store()

Global labels sequence 01: 100%|██████████| 91/91 [05:01<00:00,  3.31s/it]
Global labels sequence 02: 100%|██████████| 91/91 [1:02:27<00:00, 41.18s/it]
Global labels sequence 03: 100%|██████████| 91/91 [13:30<00:00,  8.91s/it]
Global labels sequence 04: 100%|██████████| 91/91 [27:23<00:00, 18.06s/it]
