In [1]:
import cv2
import glob
import numpy as np
import pandas as pd
import dask as d
import matplotlib.pyplot as plt
from skimage.measure import block_reduce
from skimage.transform import resize
from sklearn.mixture import GaussianMixture

In [2]:
# Directory that is globbed for the input *.tif images.
# NOTE(review): "D:" is not followed by a path separator; on Windows that form
# is resolved relative to the current directory of drive D — confirm that
# "D:/cell_area_data/..." was not intended.
cell_data_path = "D:cell_area_data/HMVEC/"
# Directory that receives area_results.csv and the *_processed.tif images.
# NOTE(review): nothing in this notebook creates the directory — presumably it
# must already exist; verify before running.
output_data_path = "D:cell_area_data/results"

In [3]:
# Integer factor by which each image side is divided when loading, and
# multiplied again when the thresholded images are upsampled for export.
ds_fact = 8
# Margin subtracted from the mask radius; it is divided by ds_fact before use,
# so it is expressed in pixels of the full-resolution image.
pinhole_cut = 50
# Multiplier on the brighter GMM component's standard deviation when the
# foreground threshold is computed (threshold = mean + sd * sd_coef).
sd_coef = 0.5
# Source intensity range used to normalise the loaded images and to rescale
# the exported ones.
# NOTE(review): 2**16 is 65536, one above the largest uint16 value (65535) —
# confirm this upper bound is intended.
tif_min = 0
tif_max = 2**16
# Target intensity range of the normalised images; gs_max is also the value
# written into the circular mask and the cap on the foreground threshold.
gs_min = 0
gs_max = 255

In [4]:
def load_img(img_name, dsamp=True, ds_fact=2, ds_func=np.mean):
    """Read an image from disk and optionally shrink it by an integer factor.

    Parameters
    ----------
    img_name : str
        Path of the image file, handed to ``cv2.imread`` with the
        ``cv2.IMREAD_ANYDEPTH`` flag.
    dsamp : bool, default True
        When true, the image is resized to ``shape // ds_fact`` along its
        first two axes; when false, it is returned exactly as read.
    ds_fact : int, default 2
        Integer divisor applied to the first two image dimensions.
    ds_func : callable, default np.mean
        Unused. Retained only so existing callers that pass it keep working.

    Returns
    -------
    numpy.ndarray
        The array read by OpenCV, or the output of ``skimage.transform.resize``
        (called with ``anti_aliasing=True, preserve_range=True``) when
        ``dsamp`` is true.

    Raises
    ------
    FileNotFoundError
        If ``cv2.imread`` could not read ``img_name``.
    """
    img = cv2.imread(img_name, cv2.IMREAD_ANYDEPTH)
    if img is None:
        # cv2.imread reports a missing or unreadable file by returning None
        # instead of raising; fail here with a message naming the file rather
        # than with an AttributeError on `.shape` further down.
        raise FileNotFoundError(f"Could not read image: {img_name!r}")
    if not dsamp:
        return img
    target_shape = (img.shape[0] // ds_fact, img.shape[1] // ds_fact)
    return resize(img, target_shape, anti_aliasing=True, preserve_range=True)

    
def min_max_(x, a, b, mn, mx):
    """Linearly map values from the range [mn, mx] onto the range [a, b].

    ``x`` is converted to float first, so the result is always a float array.
    Values outside [mn, mx] are not clipped; they are mapped by the same
    linear rule.
    """
    values = x.astype(float)
    scaled = (values - mn) * (b - a)
    return a + scaled / (mx - mn)


def gen_circ_mask(center, rad, shape, mask_val):
    """Build a uint8 mask holding a filled circle on a zero background.

    ``center`` and ``rad`` are forwarded unchanged to ``cv2.circle``, which
    takes the centre as an (x, y) point. ``shape`` is the shape of the array
    to allocate and ``mask_val`` is the value drawn inside the circle.
    """
    canvas = np.zeros(shape, dtype="uint8")
    return cv2.circle(
        canvas,
        center,
        rad,
        mask_val,
        thickness=cv2.FILLED,
    )


def apply_mask(img, mask):
    """Return ``img`` restricted to the non-zero region of ``mask``.

    The image is AND-ed with itself via ``cv2.bitwise_and`` with ``mask``
    passed as the operation mask, so the masked-in pixels keep their values.
    """
    masked_img = cv2.bitwise_and(img, img, mask=mask)
    return masked_img


def load_and_norm(img_name, a, b, mn, mx, dsamp=True, ds_fact=2, ds_func=np.mean):
    """Load an image and rescale its intensities from [mn, mx] to [a, b].

    Parameters
    ----------
    img_name : str
        Path of the image file to read.
    a, b : number
        Lower and upper bound of the target intensity range.
    mn, mx : number
        Lower and upper bound of the source intensity range.
    dsamp : bool, default True
        Forwarded to ``load_img``; whether to downsample after reading.
    ds_fact : int, default 2
        Forwarded to ``load_img``; the downsampling factor.
    ds_func : callable, default np.mean
        Forwarded to ``load_img``.

    Returns
    -------
    numpy.ndarray
        The float array produced by ``min_max_`` for the loaded image.
    """
    # Forward the caller's options instead of hard-coding dsamp=True and
    # ds_func=np.mean, so that dsamp=False is actually honoured.
    img = load_img(img_name, dsamp=dsamp, ds_fact=ds_fact, ds_func=ds_func)
    return min_max_(img, a, b, mn, mx)


def exec_threshold(masked, pinhole_idx, sd_coef, rs):
    """Zero out the pixels inside the mask that are not above a GMM threshold.

    A two-component Gaussian mixture is fitted to the intensities found at
    ``pinhole_idx``. The component with the larger mean is treated as the
    foreground, and the threshold is its mean plus ``sd_coef`` times its
    standard deviation, capped at the module-level ``gs_max``.

    Parameters
    ----------
    masked : numpy.ndarray
        Image to threshold; it is not modified.
    pinhole_idx : tuple of two index arrays
        Indices selecting the pixels to model, indexable as
        ``pinhole_idx[0]`` and ``pinhole_idx[1]``.
    sd_coef : number
        Multiplier on the foreground component's standard deviation.
    rs : int, RandomState or None
        Passed to ``GaussianMixture`` as ``random_state``.

    Returns
    -------
    numpy.ndarray
        Copy of ``masked`` in which every selected pixel that is not strictly
        above the threshold has been set to 0.
    """
    # One column of samples: the intensities of the selected pixels
    samples = masked[pinhole_idx][:, None]
    mixture = GaussianMixture(n_components=2, random_state=rs)
    mixture.fit(samples)

    # Per-component mean and standard deviation (single feature, so each
    # covariance entry is a variance)
    comp_means = mixture.means_.squeeze()
    comp_sds = np.sqrt(mixture.covariances_.squeeze())

    # The brighter component defines the foreground threshold
    fg = np.argmax(comp_means)
    fg_thresh = min(gs_max, comp_means[fg] + comp_sds[fg] * sd_coef)

    # Everything not strictly above the threshold counts as background
    is_background = np.logical_not(samples.squeeze() > fg_thresh)
    bg_rows = pinhole_idx[0][is_background]
    bg_cols = pinhole_idx[1][is_background]

    thresholded = np.copy(masked)
    thresholded[(bg_rows, bg_cols)] = 0
    return thresholded


def mask_and_threshold(img, circ_mask, pinhole_idx, sd_coef, rs):
    """Apply the circular mask to ``img`` and then the GMM threshold.

    The masked image is converted to float before being passed, together with
    ``pinhole_idx``, ``sd_coef`` and ``rs``, to ``exec_threshold``, whose
    result is returned.
    """
    masked_float = apply_mask(img, circ_mask).astype(float)
    thresholded = exec_threshold(masked_float, pinhole_idx, sd_coef, rs)
    return thresholded


def compute_area_pct(img, ref_area):
    """Return the count of strictly positive pixels as a percentage of ``ref_area``."""
    positive = img > 0
    fraction = np.sum(positive) / ref_area
    return fraction * 100

In [5]:
# glob returns matches in arbitrary, OS-dependent order; sort them so the
# image order (and therefore the row order of the results) is reproducible.
img_names = sorted(
    img.replace("\\", "/") for img in glob.glob(f"{cell_data_path}/*.tif")
)
if not img_names:
    # Fail here with a clear message rather than with an IndexError when the
    # first loaded image is accessed later on.
    raise FileNotFoundError(f"No .tif images found in {cell_data_path!r}")

In [6]:
# One lazy load-and-normalise task per image, all evaluated in a single
# dask compute call.
load_tasks = [
    d.delayed(load_and_norm)(
        img_path,
        gs_min,
        gs_max,
        tif_min,
        tif_max,
        dsamp=True,
        ds_fact=ds_fact,
        ds_func=np.mean,
    )
    for img_path in img_names
]
gs_ds_imgs = d.compute(load_tasks)[0]

In [7]:
img_shape = gs_ds_imgs[0].shape
# cv2.circle takes its centre as (x, y), i.e. (column, row), whereas a numpy
# shape is (rows, columns). Build the centre in OpenCV order so the circle is
# centred for non-square images too (no change for square images).
img_center = (img_shape[1] // 2, img_shape[0] // 2)
# Radius: the smaller half-dimension, so the circle always fits inside the
# image, minus the pinhole margin expressed in downsampled pixels.
circ_rad = min(img_center) - (pinhole_cut // ds_fact)

In [8]:
# Circular mask, the indices of the pixels it covers, and how many there are
# (used as the reference area for the area percentages).
circ_mask = gen_circ_mask(img_center, circ_rad, img_shape, gs_max)
inside_pinhole = circ_mask > 0
pinhole_idx = np.where(inside_pinhole)
circ_pix_area = np.sum(inside_pinhole)

In [9]:
# Seed used for every GMM fit.
gmm_seed = 123
# Give each image its own RandomState. A single instance shared by all tasks
# is consumed in whatever order dask's threaded scheduler happens to run them,
# so the per-image fits would not be reproducible between runs.
gmm_masked_all = d.compute(
    [
        d.delayed(mask_and_threshold)(
            img, circ_mask, pinhole_idx, sd_coef, np.random.RandomState(seed=gmm_seed)
        )
        for img in gs_ds_imgs
    ]
)[0]


In [10]:
# Resize every thresholded image by ds_fact along both axes.
upsample_tasks = [
    d.delayed(resize)(
        small_img,
        (small_img.shape[0] * ds_fact, small_img.shape[1] * ds_fact),
    )
    for small_img in gmm_masked_all
]
us_gmm_masked_all = d.compute(upsample_tasks)[0]

In [11]:
# Percentage of the circular mask area left non-zero in each thresholded image.
area_tasks = [
    d.delayed(compute_area_pct)(thresholded_img, circ_pix_area)
    for thresholded_img in gmm_masked_all
]
area_pcts = d.compute(area_tasks)[0]

In [12]:
# Image id = file name without its directory and without its last four
# characters (the ".tif" suffix matched by the glob).
img_ids = [img_path.rsplit("/", 1)[-1][:-4] for img_path in img_names]
area_df = pd.DataFrame({"image_id": img_ids, "area_pct": area_pcts})

In [13]:
# Write the per-image area percentages to the results directory.
area_csv_path = f"{output_data_path}/area_results.csv"
area_df.to_csv(area_csv_path, index=False)

In [14]:
# Export each upsampled, thresholded image as a 16-bit TIFF.
uint16_max = np.iinfo(np.uint16).max
for img_id, us_img in zip(img_ids, us_gmm_masked_all):
    out_path = f"{output_data_path}/{img_id}_processed.tif"
    rescaled = min_max_(us_img, tif_min, tif_max, gs_min, gs_max)
    # Clip before the cast: a float outside [0, 65535] would otherwise wrap
    # around when converted to uint16.
    out_img = np.clip(rescaled, 0, uint16_max).astype(np.uint16)
    # cv2.imwrite returns False on failure instead of raising (e.g. when the
    # output directory does not exist); surface that rather than silently
    # producing no file.
    if not cv2.imwrite(out_path, out_img):
        raise IOError(f"Failed to write image: {out_path!r}")

In [15]:
# idx = 50

In [16]:
# gmm_masked = mask_and_threshold(gs_ds_imgs[idx], circ_mask, pinhole_idx, sd_coef, np.random.RandomState(seed=123))
# compute_area_pct(gmm_masked, circ_pix_area)

In [17]:
# plt.imshow(gmm_masked, cmap="gray")
# plt.show()

In [18]:
# us_gmm_masked = resize(gmm_masked, (gmm_masked.shape[0] * ds_fact, gmm_masked.shape[1] * ds_fact))

In [19]:
# img_id = img_names[idx].split("/")[-1][:-4]
# cv2.imwrite(
#     f"{output_data_path}/{img_id}_processed.tif",
#     min_max_(us_gmm_masked, tif_min, tif_max, gs_min, gs_max).astype(np.uint16)
# )