In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
import os
import seaborn as sns

from IPython.display import clear_output
import gudhi as gd  
from scipy.spatial.distance import cdist
from scipy.stats import mannwhitneyu

from skimage.io import imread, imsave
from skimage.color import rgb2gray, rgb2hsv
from skimage.draw import ellipse
from skimage.filters import threshold_otsu
from skimage.measure import label, moments, regionprops
from skimage.morphology import binary_dilation
from skimage.segmentation import mark_boundaries

from sklearn.decomposition import PCA

def draw(clone_names, ns, nys, name, dtype=None, binwidth=None, kde=True, discrete=True):
    """Plot the distribution of a per-sample statistic, coloured by clone.

    Parameters
    ----------
    clone_names : sequence of str
        Display name of every clone; entry ``c`` labels the samples with ``nys == c``.
    ns : array-like
        One value of the statistic per sample.
    nys : array-like
        Clone index of every sample, aligned with ``ns``.
    name : str
        Name of the statistic; used as the DataFrame column and hence the x-axis label.
    dtype : str or type, optional
        When given, the statistic column is cast to this dtype and ``binwidth``
        is forwarded to ``sns.displot``. When ``None``, the values are kept as
        strings and ``binwidth`` is not forwarded.
    binwidth : number, optional
        Histogram bin width (only used together with ``dtype``).
    kde : bool
        Forwarded to ``sns.displot``.
    discrete : bool
        Accepted for backward compatibility; currently not used.
    """
    ns = np.asarray(ns)
    nys = np.asarray(nys)

    labels = []
    values = []
    for c, clone_name in enumerate(clone_names):
        # evaluate the boolean mask once per clone instead of once per row
        clone_values = ns[nys == c]
        labels.extend([clone_name] * clone_values.shape[0])
        values.extend(clone_values.tolist())

    # building the frame column-wise also works when there are no samples at all
    df = pd.DataFrame({'clone': labels, name: values}, columns=['clone', name])

    plot_kwargs = dict(x=name, hue='clone', kde=kde, linewidth=0)
    if dtype is not None:
        df = df.astype({name: dtype})
        plot_kwargs['binwidth'] = binwidth
    else:
        # legacy behaviour: without an explicit dtype the values are plotted as strings
        df[name] = df[name].astype(str)

    sns.displot(data=df, **plot_kwargs)

def print_pvalues(clone_names, ns, nys, header=None):
    """Print the matrix of pairwise Mann-Whitney U p-values between clones.

    Two matrices are printed: the p-values themselves and a boolean matrix
    marking the pairs with ``p < 0.01``.

    Parameters
    ----------
    clone_names : sequence
        One entry per clone; only its length is used.
    ns : array-like
        One value of the compared statistic per sample.
    nys : array-like
        Clone index of every sample, aligned with ``ns``.
    header : str, optional
        Line printed before the matrices. Nothing is printed when ``None``;
        the function no longer prints a fixed 'number of bacteria per patch'
        title, which was wrong for every other statistic.
    """
    if header is not None:
        print(header)

    ns = np.asarray(ns)
    nys = np.asarray(nys)
    groups = [ns[nys == c] for c in range(len(clone_names))]

    # NaN marks pairs that cannot be tested because one of the groups is empty
    pvalues = np.full((len(groups), len(groups)), np.nan)
    for c1, first in enumerate(groups):
        for c2, second in enumerate(groups):
            if len(first) > 0 and len(second) > 0:
                pvalues[c1, c2] = mannwhitneyu(first, second).pvalue
    print(pvalues)
    print(pvalues < 0.01)


In [None]:
# Root directory of the patch database. The MIL_TOP_PATCHES_DIR environment
# variable overrides the default so the notebook can run on another machine
# without editing this cell.
dirname = os.environ.get('MIL_TOP_PATCHES_DIR',
                         '/Users/bartoszzielinski/Databases/strains/mil_top_patches')
# CSV name template ('TT' is replaced by each entry of train_test) and the clone names
csvname, clone_names = ('ABC_TT_top_patches.csv', ['A', 'B', 'C'])
train_test = ['train', 'test']
# homology dimension whose persistence intervals are kept in step 5a
dimension = 0


In [None]:
# 1. generate grayscale patches
# One-off preprocessing step: set to True to (re)write the images under <tt>/gs/.
GENERATE_GRAYSCALE = False
if GENERATE_GRAYSCALE:
    for tt in train_test:
        csvname_tt = csvname.replace('TT', tt)
        data = pd.read_csv('{}/{}'.format(dirname, csvname_tt))

        for i, filename in enumerate(data.path):
            # countdown of the files still to be processed
            print('{}-'.format(len(data.path) - i), end='')

            gs_filename = '{}/{}/gs/{}'.format(dirname, tt, filename)
            # exist_ok replaces the separate (race-prone) existence check
            os.makedirs(os.path.dirname(gs_filename), exist_ok=True)

            im = imread('{}/{}/original/{}'.format(dirname, tt, filename))
            # inverted 8-bit grayscale version of the patch
            grayscale = np.uint8(255 - rgb2gray(im) * 255)
            imsave(gs_filename, grayscale)
    print("Success!")


In [None]:
# 2. run CellProfiler

In [None]:
# 3. reject improper segmentations and show them
# Set to True to review every patch interactively and rebuild the reject lists;
# otherwise the previously saved lists are loaded.
REVIEW_SEGMENTATIONS = False
for tt in train_test:
    csvname_tt = csvname.replace('TT', tt)
    data = pd.read_csv('{}/{}'.format(dirname, csvname_tt))
    reject_filename = '{}/{}_to_reject.npy'.format(dirname, csvname_tt[:-4])

    if REVIEW_SEGMENTATIONS:
        to_reject = []
        for i, filename in enumerate(data.path):
            # countdown of the patches still to be reviewed
            print('{}-'.format(len(data.path) - i), end='')
            segment_filename = '{}/{}/gs/{}_segment.npy'.format(dirname, tt, filename[:-4])

            im = imread('{}/{}/original/{}'.format(dirname, tt, filename))
            segment = np.load(segment_filename)

            plt.subplot(1, 2, 1)
            if segment.dtype == np.float32:
                # float32 segmentations are shown without boundary overlay
                plt.imshow(im)
            else:
                plt.imshow(mark_boundaries(im, segment))
            plt.title('{}/{}'.format(i, len(data.path)))
            plt.subplot(1, 2, 2)
            plt.imshow(segment)
            plt.show()

            # the reviewer types 0 to mark the segmentation as bad
            key = input('bad - 0')
            clear_output()
            if key == '0':
                to_reject.append(i)
        # explicit integer dtype so that an empty list is still usable as an index
        to_reject = np.array(to_reject, dtype=int)

        np.save(reject_filename, to_reject)
    else:
        to_reject = np.load(reject_filename).astype(int)

    for c in range(len(clone_names)):
        print('Filtered out {}: {} / {}'.format(clone_names[c],
                                                (data.y_true.iloc[to_reject] == c).sum(),
                                                (data.y_true == c).sum()))

    if len(to_reject) == 0:
        # nothing was rejected, so there is nothing to display
        continue

    # smallest square grid holding all rejected patches; subplot needs integers
    k = int(np.ceil(np.sqrt(len(to_reject))))
    plt.figure(figsize=(k * 4, k * 4))
    for i, rejected_index in enumerate(to_reject):
        filename = data.path.iloc[rejected_index]
        im = imread('{}/{}/original/{}'.format(dirname, tt, filename))

        plt.subplot(k, k, i + 1)
        plt.imshow(im)
        plt.axis('off')
        plt.title(clone_names[data.y_true.iloc[rejected_index]])


In [None]:
# 4a. statistics on bacteria and connected components - prepare data
ims = []   # original patches
objs = []  # label images of the segmented bacteria
ns = []    # number of bacteria per patch (maximum label)
nys = []   # clone index of every patch
ps = []    # value of column 'A' for every patch
for tt in train_test:
    csvname_tt = csvname.replace('TT', tt)
    data = pd.read_csv('{}/{}'.format(dirname, csvname_tt))
    to_reject = np.load('{}/{}_to_reject.npy'.format(dirname, csvname_tt[:-4]))
    # a set makes the membership test below O(1) instead of a scan of the array
    rejected = set(to_reject.tolist())

    for i, filename in enumerate(data.path):
        if i in rejected:
            continue

        im = imread('{}/{}/original/{}'.format(dirname, tt, filename))
        obj = np.load('{}/{}/gs/{}_segment.npy'.format(dirname, tt, filename[:-4]))
        n = int(obj.max())

        ims.append(im)
        objs.append(obj)
        ns.append(n)
        nys.append(data.y_true[i])  # MAYBE SHOULD BE y_pred?
        ps.append(data.A[i])

ims = np.stack(ims)
# NOTE(review): the uint8 cast wraps label values above 255
objs = np.stack(objs).astype(np.uint8)
ns = np.stack(ns)
nys = np.stack(nys)
ps = np.stack(ps)

# show number of bacteria per patch
# (no plt.figure() here: sns.displot creates its own figure, so an extra one stays empty)
draw(clone_names, ns, nys, 'number of bacteria per patch', dtype='int32', binwidth=5)
plt.ylabel('density')


In [None]:
# 4b. statistics on bacteria and connected components - count number of connected components
# For every range of `step` bacteria counts, draw the same number of patches from
# each clone so that the clones have matching count distributions.
inter_objs_clones_sample = [[] for c in range(len(clone_names))]
inter_ims_clones_sample = [[] for c in range(len(clone_names))]

step = 5
# "+ 1" so that the range containing ns.max() is visited even when ns.max()
# is an exact multiple of step
for s in range(0, int(ns.max()) + 1, step):
    # patches with at least one bacterium and a count in [s, s + step)
    limit_inter = np.logical_and(ns >= max(1, s), ns < s + step)
    take = [np.logical_and(limit_inter, nys == c).sum() for c in range(len(clone_names))]
    take_min = np.min(take)

    if take_min > 0:
        for c in range(len(clone_names)):
            limit_inter_clone = np.logical_and(limit_inter, nys == c)

            # take_min patches of this clone, without replacement
            perm = np.random.choice(limit_inter_clone.sum(), take_min, replace=False)
            inter_objs_clones_sample[c].append(objs[limit_inter_clone][perm])
            inter_ims_clones_sample[c].append(ims[limit_inter_clone][perm])

for c in range(len(clone_names)):
    if inter_objs_clones_sample[c]:
        inter_objs_clones_sample[c] = np.concatenate(inter_objs_clones_sample[c])
        inter_ims_clones_sample[c] = np.concatenate(inter_ims_clones_sample[c])
    else:
        # no range had patches of every clone: keep empty arrays of the right shape
        inter_objs_clones_sample[c] = objs[:0]
        inter_ims_clones_sample[c] = ims[:0]

    # shuffle so that the samples are no longer ordered by bacteria count
    perm = np.random.permutation(inter_objs_clones_sample[c].shape[0])
    inter_objs_clones_sample[c] = inter_objs_clones_sample[c][perm]
    inter_ims_clones_sample[c] = inter_ims_clones_sample[c][perm]


In [None]:
# 4c. statistics on bacteria and connected components
ccs_sample = []  # number of connected components per sampled patch
ns_sample = []   # number of bacteria per sampled patch
ys_sample = []   # clone index per sampled patch
example_obj = None  # last processed segmentation, shown as an illustration below
example_cc = None   # its connected-component labelling
for c in range(len(clone_names)):
    for obj in inter_objs_clones_sample[c]:
        obj_cc = label(obj > 0)  # connected components
        ns_sample.append(obj.max())
        ccs_sample.append(obj_cc.max())
        ys_sample.append(c)
        example_obj, example_cc = obj, obj_cc

ccs_sample = np.array(ccs_sample)
ns_sample = np.array(ns_sample)
ys_sample = np.array(ys_sample)

# the illustration no longer depends on variables leaked from the loop
if example_obj is not None:
    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(example_obj)
    plt.axis('off')
    plt.title('Bacteria segmentation')
    plt.subplot(1, 2, 2)
    plt.imshow(example_cc)
    plt.axis('off')
    plt.title('Connected component')

# sns.displot (used by draw) creates its own figure, so no plt.figure() is needed
draw(clone_names, ns_sample, ys_sample, 'number of bacteria per patch', dtype='int32', binwidth=5)
plt.ylabel('density')
print('number of bacteria per patch')
print_pvalues(clone_names, ns_sample, ys_sample)

draw(clone_names, ccs_sample, ys_sample, 'number of connected components per patch', dtype='int32', binwidth=1)
plt.ylabel('density')
print('number of connected components per patch')
print_pvalues(clone_names, ccs_sample, ys_sample)


In [None]:
# 5a. sparsity - prepare persistence
# For every sampled patch: take the centroids of the segmented regions, build a
# Rips complex on their pairwise distances and histogram the death times of
# the finite persistence intervals of the configured dimension.
pcims, pcobjs, pcs, pds_0, pcys = [], [], [], [], []
death_time_bins = np.arange(0, 125, 1)

for c in range(len(clone_names)):
    for obj, im in zip(inter_objs_clones_sample[c], inter_ims_clones_sample[c]):
        pc = np.stack([reg.centroid for reg in regionprops(obj)])
        if pc.shape[0] <= 1:
            # a single point has no pairwise distances to analyse
            continue

        skeleton = gd.RipsComplex(distance_matrix=cdist(pc, pc))
        Rips_simplex = skeleton.create_simplex_tree(max_dimension=2)
        persistence = Rips_simplex.persistence()

        # (birth, death) pairs of the requested dimension that die at a finite time
        persistence_dim_no_inf = [interval for dim, interval in persistence
                                  if dim == dimension and interval[1] != np.inf]
        if len(persistence_dim_no_inf) == 0:
            continue

        persistence_0 = np.stack(persistence_dim_no_inf)
        persistence_0_hist, _ = np.histogram(persistence_0[:, 1], bins=death_time_bins)

        pcims.append(im)
        pcobjs.append(obj)
        pcs.append(pc)
        pds_0.append(persistence_0_hist)
        pcys.append(c)

pcims = np.stack(pcims)
pcobjs = np.stack(pcobjs)
pds_0 = np.stack(pds_0)
pcys = np.stack(pcys)

# one figure for the full death-time range and one zoomed in
for lim in [[0, 124], [10, 28]]:
    plt.figure()
    for c in range(len(clone_names)):
        clone_histograms = pd.DataFrame(pds_0[pcys == c, :]).melt()
        sns.lineplot(x="variable", y="value", data=clone_histograms, ci=68)
    plt.legend(labels=clone_names)
    plt.xlabel('death time')
    plt.ylabel('mean number')
    plt.title('Persistence (dimension 0)')
    plt.xlim(lim)


In [None]:
# 5b. sparsity
# Show example patches ordered by the mean persistence histogram value in a
# death-time window, restricted to patches with a similar number of bacteria.
intervals = 10  # panels per figure
samples = 5     # number of figures
for dt in [[15, 25]]:
    for lim in [[35, 45]]:
        order = np.argsort(pds_0[:, dt[0]:dt[1]].mean(axis=1))
        # integer dtype keeps `order` usable as an index even when it is empty
        order = np.array([o for o in order if lim[0] < pcobjs[o].max() < lim[1]], dtype=int)

        # at least 1, otherwise every panel would show the same patch
        skip = max(1, order.shape[0] // intervals)
        for j in range(samples):
            if j >= order.shape[0]:
                break  # no patch left to start another figure
            plt.figure(figsize=(4 * intervals, 4))
            for i in range(intervals):
                position = i * skip + j
                if position >= order.shape[0]:
                    break  # fewer patches than panels: avoid indexing past the end
                index = order[position]
                plt.subplot(1, intervals, i + 1)
                plt.imshow(pcims[index])
                plt.axis('off')
                plt.title('{} ({})'.format(clone_names[pcys[index]], pcobjs[index].max()))

        print("patches in limit:", len(order))


In [None]:
# 6a. cell statistics - generate patches with separate cells
# Set to True to rebuild the cached X / XM / XN / y arrays; otherwise they are loaded.
# NOTE: arrays cached before the label-range fix below lack the last cell of every patch.
REGENERATE_CELL_PATCHES = False
if REGENERATE_CELL_PATCHES:
    X = []   # image crops centred on a cell
    XM = []  # mask of the central cell in the crop
    XN = []  # relabelled other cells in the crop
    y = []   # clone index of the source patch
    B = 16   # half of the crop side length

    for tt in train_test:
        csvname_tt = csvname.replace('TT', tt)
        data = pd.read_csv('{}/{}'.format(dirname, csvname_tt))
        to_reject = np.load('{}/{}_to_reject.npy'.format(dirname, csvname_tt[:-4]))
        # a set makes the membership test below O(1)
        rejected = set(to_reject.tolist())

        for i, filename in enumerate(data.path):
            if i in rejected:
                continue

            print('{}-'.format(len(data.path) - i), end='')
            im = imread('{}/{}/original/{}'.format(dirname, tt, filename))
            obj = np.load('{}/{}/gs/{}_segment.npy'.format(dirname, tt, filename[:-4]))

            n = int(obj.max())
            # n is itself a label value, so the range must include it
            for o in range(1, n + 1):
                mask = obj == o

                M = moments(mask)
                if M[0, 0] == 0:
                    continue  # this label value does not occur in the patch
                # centroid of the cell: cX along axis 0, cY along axis 1
                cX = int(M[1, 0] / M[0, 0])
                cY = int(M[0, 1] / M[0, 0])

                # only keep cells whose crop lies completely inside the patch
                if B < cX < im.shape[0] - B and B < cY < im.shape[1] - B:
                    window = (slice(cX - B, cX + B), slice(cY - B, cY + B))

                    patch = im[cX - B:cX + B, cY - B:cY + B, :]
                    patch_mask = mask[window]
                    # zero out the central cell, leaving only the other cells
                    patch_obj = obj[window] * (1 - patch_mask)
                    patch_obj_unique = np.sort(np.unique(patch_obj))
                    if patch_obj_unique[0] == 0:
                        patch_obj_unique = patch_obj_unique[1:]
                    # renumber the remaining labels consecutively from 1
                    for j in range(len(patch_obj_unique)):
                        patch_obj[patch_obj == patch_obj_unique[j]] = j + 1

                    X.append(patch)
                    XM.append(patch_mask)
                    XN.append(patch_obj)
                    y.append(data.y_true[i])  # MAYBE SHOULD BE y_pred?

    X = np.stack(X)
    XM = np.stack(XM)
    XN = np.stack(XN)
    y = np.stack(y)
    # move the colour channel in front of the two spatial axes
    X = np.einsum('ijkl->iljk', X)

    np.save('{}/X.npy'.format(dirname), X)
    np.save('{}/XM.npy'.format(dirname), XM)
    np.save('{}/XN.npy'.format(dirname), XN)
    np.save('{}/y.npy'.format(dirname), y)
else:
    X = np.load('{}/X.npy'.format(dirname))
    XM = np.load('{}/XM.npy'.format(dirname))
    XN = np.load('{}/XN.npy'.format(dirname))
    y = np.load('{}/y.npy'.format(dirname))


print('number of cells {}'.format([(y == c).sum() for c in range(len(clone_names))]))

# six random examples: crop, central-cell mask, remaining cells
plt.figure(figsize=(12, 6))
for i in range(6):
    ri = np.random.randint(X.shape[0])
    plt.subplot(3, 6, 3 * i + 1)
    plt.imshow(np.einsum('ljk->jkl', X[ri]))
    plt.subplot(3, 6, 3 * i + 2)
    plt.imshow(XM[ri])
    plt.subplot(3, 6, 3 * i + 3)
    plt.imshow(XN[ri])


In [None]:
# 6b. cell statistics - number of neighbors patches
# the largest label left in XN[i] is the number of other cells in crop i
nn = np.array([XN[i].max() for i in range(XN.shape[0])])

# sns.displot (used by draw) creates its own figure, so no plt.figure() is needed
draw(clone_names, nn, y, 'number of neighbors', dtype='int32', binwidth=1, kde=False)

# examples with exactly 0, 1 and 2 neighbors
for j in range(3):
    plt.figure(figsize=(12, 6))

    nn_lim_indices = np.where(nn == j)[0]
    nn_lim_indices = np.random.permutation(nn_lim_indices)

    # at most 6 examples; there may be fewer crops with exactly j neighbors
    for i in range(min(6, len(nn_lim_indices))):
        ri = nn_lim_indices[i]
        plt.subplot(3, 6, 3 * i + 1)
        plt.imshow(np.einsum('ljk->jkl', X[ri]))
        plt.subplot(3, 6, 3 * i + 2)
        plt.imshow(XM[ri])
        plt.subplot(3, 6, 3 * i + 3)
        plt.imshow(XN[ri])


In [None]:
# 6b. cell statistics - limit patches
NEIGHBOURS = 1
# shuffled positions of the crops whose neighbor count does not exceed NEIGHBOURS
nn_lim_indices = np.random.permutation(np.flatnonzero(nn <= NEIGHBOURS))

# apply the same selection and order to all four arrays
X_lim, XM_lim, XN_lim, y_lim = (array[nn_lim_indices] for array in (X, XM, XN, y))


In [None]:
# 6c. cell statistics - define correction methods and show examples
def get_ellipse_mask(mask, details=False):
    """Draw the ellipse described by the first region of ``mask``.

    The mask is cast to int64 and measured with ``regionprops``; an ellipse
    with the region's centroid, half axis lengths and orientation is drawn
    into an array of the same shape, clipped to the array bounds.

    Parameters
    ----------
    mask : ndarray
        2-D mask of the object.
    details : bool
        When true, also return the minor and major axis lengths of the region.

    Returns
    -------
    ndarray of int64, or tuple
        The ellipse mask (1 inside, 0 outside), or
        ``(ellipse mask, minor_axis_length, major_axis_length)`` when
        ``details`` is true.
    """
    labelled = mask.astype(np.int64)
    region = regionprops(labelled)[0]

    centre_row, centre_col = region.centroid
    orientation = region.orientation

    rows, cols = ellipse(centre_row, centre_col,
                         region.major_axis_length // 2, region.minor_axis_length // 2,
                         rotation=orientation)

    # drop the ellipse pixels that fall outside the array
    inside = (0 <= rows) & (rows < labelled.shape[0]) & (0 <= cols) & (cols < labelled.shape[1])

    mask_e = np.zeros(labelled.shape).astype(np.int64)
    mask_e[rows[inside], cols[inside]] = 1

    if not details:
        return mask_e
    return mask_e, region.minor_axis_length, region.major_axis_length


def correct_mask(im_gray, mask):
    """Refine a cell mask by Otsu-thresholding the gray image around it.

    The mask is dilated with a 5x5 structuring element, an Otsu threshold is
    computed from the gray values under the dilated mask, and the thresholded
    component (4-connectivity) containing the centre pixel is returned.

    Parameters
    ----------
    im_gray : ndarray
        2-D gray image, same shape as ``mask``.
    mask : ndarray
        2-D initial mask; the centre pixel is taken at ``mask.shape[0] // 2``
        on both axes.

    Returns
    -------
    ndarray of bool
        The corrected mask. When the centre pixel is not above the threshold,
        the input mask is returned (as bool) instead of the background.
    """
    s = mask.shape[0] // 2
    # The structuring element is passed positionally because its keyword was
    # renamed from `selem` to `footprint` between scikit-image releases.
    mask_dilated = binary_dilation(mask, np.ones((5, 5)))

    im_gray_masked = im_gray * mask_dilated
    threshold = threshold_otsu(im_gray[mask_dilated])
    mask_corrected_labels = label(im_gray_masked > threshold, connectivity=1)

    centre_label = mask_corrected_labels[s, s]
    if centre_label == 0:
        # Label 0 is the background: selecting it would mark everything
        # *except* the cells, so fall back to the uncorrected mask.
        return np.asarray(mask).astype(bool)

    # builtin bool dtype (the np.bool alias no longer exists in recent NumPy)
    return mask_corrected_labels == centre_label

# Show ten example crops with every stage of the mask correction side by side.
for i in range(30, 40):
    plt.figure(figsize=(24, 4))

    im = np.einsum('ljk->jkl', X_lim[i])
    mask = XM_lim[i]
    mask_n = XN_lim[i]
    mask_e = get_ellipse_mask(mask)

    im_gray = np.uint8(255 - rgb2gray(im) * 255)
    mask_corrected = correct_mask(im_gray, mask_e)
    mask_corrected_e = get_ellipse_mask(mask_corrected)

    # left to right: crop, inverted gray image, original mask, its ellipse,
    # corrected mask, ellipse of the corrected mask
    panels = [
        im,
        im_gray,
        mark_boundaries(im, mask),
        mark_boundaries(im, mask_e),
        mark_boundaries(im, mask_corrected),
        mark_boundaries(im, mask_corrected_e),
    ]
    for column, panel in enumerate(panels, start=1):
        plt.subplot(1, 6, column)
        plt.imshow(panel)
        plt.axis('off')


In [None]:
# 6d. cell statistics - prepare data
# (PCA is already imported in the imports cell at the top of the notebook)
props_lim = {'area': [], 'minor to major ellipse axis': [],
             'red': [], 'green': [], 'blue': [],
             'hue': [], 'saturation': [], 'value': [],
             'RGB PCA - 1st component': []}
XM_corr = []  # corrected mask of every crop in X_lim

for i in range(X_lim.shape[0]):
    im = np.einsum('ljk->jkl', X_lim[i])
    mask_e = get_ellipse_mask(XM_lim[i])

    im_gray = np.uint8(255 - rgb2gray(im) * 255)
    mask_corrected = correct_mask(im_gray, mask_e)
    XM_corr.append(mask_corrected)

    hsv_im = rgb2hsv(im)

    props = regionprops(mask_corrected.astype(np.int64))[0]
    props_lim['area'].append(mask_corrected.sum())
    props_lim['minor to major ellipse axis'].append(props.minor_axis_length / props.major_axis_length)
    # mean of every RGB and HSV channel over the pixels of the corrected mask
    for channel, key in enumerate(('red', 'green', 'blue')):
        props_lim[key].append(im[:, :, channel][mask_corrected].mean())
    for channel, key in enumerate(('hue', 'saturation', 'value')):
        props_lim[key].append(hsv_im[:, :, channel][mask_corrected].mean())

# rgb pca
# one row per cell, columns are the mean red / green / blue values
RGB = np.stack([props_lim['red'], props_lim['green'], props_lim['blue']]).transpose()
pca = PCA(n_components=3)
RGB_pca = pca.fit_transform(RGB)
print(pca.explained_variance_ratio_)
props_lim['RGB PCA - 1st component'] = RGB_pca[:, 0]


In [None]:
# 6e. cell statistics - present statistics
clone_names_lim = clone_names
take = [(y_lim == c).sum() for c in range(len(clone_names_lim))]
take_min = np.min(take)

# balance the clones: the same number of randomly chosen cells from each
all_limit_indices = []
for c in range(len(clone_names_lim)):
    limit_indices = np.random.choice(np.where(y_lim == c)[0], take_min, replace=False)
    all_limit_indices.append(limit_indices)
all_limit_indices = np.concatenate(all_limit_indices)

intervals = 10  # panels per figure
samples = 5     # figures per property

# NOTE(review): `limits` is not used below - presumably meant as x-axis limits; confirm
for prop_name, limits in [('area', []),
                          ('minor to major ellipse axis', [0, 1]),
                          ('RGB PCA - 1st component', [])]:
    props_lim[prop_name] = np.array(props_lim[prop_name])

    draw(clone_names_lim, props_lim[prop_name][all_limit_indices], y_lim[all_limit_indices], prop_name, dtype='float')
    print(prop_name)
    print_pvalues(clone_names_lim, props_lim[prop_name][all_limit_indices], y_lim[all_limit_indices])

    # example cells ordered by increasing value of the property
    order = np.argsort(props_lim[prop_name])

    # at least 1, otherwise every panel would show the same cell
    skip = max(1, X_lim.shape[0] // intervals)
    for j in range(samples):
        if j >= order.shape[0]:
            break  # no cell left to start another figure
        plt.figure(figsize=(4 * intervals, 4))
        for i in range(intervals):
            position = i * skip + j
            if position >= order.shape[0]:
                break  # i * skip + j can run past the end of `order`
            oi = order[position]
            im = np.einsum('ljk->jkl', X_lim[oi])
            mask = XM_corr[oi]

            plt.subplot(1, intervals, i + 1)
            plt.imshow(mark_boundaries(im, mask))
            plt.title('{:.2f}'.format(props_lim[prop_name][oi]))
            plt.axis('off')
