In [1]:
%reload_ext autoreload
%autoreload 2
%reload_ext notexbook
%texify

### Plotting for parameter estimation in chromosomal data

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
import pathlib
from pathlib import Path
import torch
sns.set_style('white')
from skimage.io import imread, imsave
import scipy.stats
import edt
from skimage import segmentation
import pickle
%matplotlib qt5

In [3]:
# Input directories of one imaging position. They are globbed below for
# '*.png' (masks) and '*.tiff' (phase and venus images).
MASKS_DIR = Path('/mnt/sda1/SMLAT/data/real_data/chromosome_dots/EXP-23-CA3045/pool_3045_ter_pos01/phase_venus_mask/')
PHASE_DIR = Path('/mnt/sda1/SMLAT/data/real_data/chromosome_dots/EXP-23-CA3045/pool_3045_ter_pos01/phase_venus/')
VENUS_DIR = Path('/mnt/sda1/SMLAT/data/real_data/chromosome_dots/EXP-23-CA3045/pool_3045_ter_pos01/venus/')
# Pickle file that is loaded further down as `edt_noise_map`.
edt_dump_path = Path('/mnt/sda1/SMLAT/data/real_data/chromosome_dots/EXP-23-CA3045/edt_noise.pkl')

In [4]:
# Sorted listings of the mask, phase and venus files. Frames are paired by
# list position further down, so this presumably relies on the three
# directories sorting into the same frame order -- TODO confirm.
mask_filenames, phase_filenames, venus_filenames = (
    sorted(folder.glob(pattern))
    for folder, pattern in (
        (MASKS_DIR, '*.png'),
        (PHASE_DIR, '*.tiff'),
        (VENUS_DIR, '*.tiff'),
    )
)

In [5]:
# Load the phase image, cell mask and venus image of a single frame.
index = 45
phase_img, mask_img, venus_img = (
    imread(file_list[index])
    for file_list in (phase_filenames, mask_filenames, venus_filenames)
)

In [6]:
from cellbgnet.param_estimation import *
from cellbgnet.plotting import *

In [7]:
 # Per-distance gamma parameters pooled over every mask / venus pair in the two
 # directories. save_filename=None: presumably nothing is written to disk --
 # TODO confirm against get_full_edt_maps.
 edt_pool_alphas, edt_pool_betas = get_full_edt_maps(
     MASKS_DIR,
     VENUS_DIR,
     save_filename=None,
     bg_cutoff_percentile=80,
 )

100%|█████████████████████████████████████████████████████████████| 301/301 [00:46<00:00,  6.41it/s]


In [8]:
edt_pool_alphas, edt_pool_betas

({1: 215.611487316387,
  2: 218.34196215012432,
  3: 222.02796471622432,
  4: 224.02947978628708,
  5: 225.8350492254097,
  6: 229.52910371210007,
  0: 180.30555399547802},
 {1: 0.5267494232005538,
  2: 0.5235224548764762,
  3: 0.517300767333869,
  4: 0.5137420714595694,
  5: 0.510302348908574,
  6: 0.5037518533667191,
  0: 0.6027986764257343})

### Chip background distribution fit

In [9]:
def chromo_mean_var_bg_outside(fluor_img, cellseg_mask, dilate=True, roi=None,
                 plot=False, dilate_px=1, return_alpha_beta=False, save_path=None):
    """
    Fit a gamma distribution (location fixed at 0) to the fluorescence pixels
    that lie outside the segmented cells.

    Arguments:
    ----------
        fluor_img (np.ndarray): fluorescence image
        cellseg_mask (np.ndarray): corresponding cell mask for the fluorescence
                    image; pixels equal to 0 are treated as background
        dilate (bool): expand the cell labels by `dilate_px` pixels before the
                    background pixels are selected
        roi (list of 4 ints): [row, col, height, width] of the region to
                    evaluate, in case the cell masks are not reliable
                    everywhere. None uses the full image.
        plot (bool): plot a histogram of the background pixels together with a
                    histogram of samples drawn from the fitted distribution
        dilate_px (int): expansion distance used when `dilate` is True
        return_alpha_beta (bool): return the shape (alpha) and scale (beta) of
                    the fit instead of its mean and variance
        save_path (path-like or None): where to save the figure; only used
                    when `plot` is True

    Returns:
    ----------
        (alpha, beta) of the fitted gamma distribution if `return_alpha_beta`
        is True, otherwise its (mean, variance) = (alpha * beta, alpha * beta**2)

    Raises:
    ----------
        ValueError: if no background pixels are left to fit
    """
    if roi is not None:
        window = (slice(roi[0], roi[0] + roi[2]), slice(roi[1], roi[1] + roi[3]))
        cellseg_mask = cellseg_mask[window]
        fluor_img = fluor_img[window]

    if dilate:
        cellseg_mask = segmentation.expand_labels(cellseg_mask, distance=dilate_px)

    # Boolean indexing yields the background pixels directly, in row-major order.
    only_pixels_outside = fluor_img[cellseg_mask == 0]
    if only_pixels_outside.size == 0:
        raise ValueError("No background pixels: the cell mask covers the whole region.")

    # remove outliers in the outside pixels at 99%
    cutoff = np.percentile(only_pixels_outside, 99)
    collect_bg_only = only_pixels_outside[only_pixels_outside < cutoff]
    if collect_bg_only.size == 0:
        raise ValueError("No background pixels left below the 99th percentile cutoff.")

    # scipy parameterisation: fit_alpha is the shape, fit_beta the scale.
    fit_alpha, fit_loc, fit_beta = scipy.stats.gamma.fit(collect_bg_only, floc=0.0)

    if plot:
        low, high = collect_bg_only.min(), collect_bg_only.max()
        bin_edges = np.arange(low, high)
        fig, ax = plt.subplots()
        ax.hist(collect_bg_only, bins=bin_edges, histtype='step', label='data')
        ax.hist(np.random.gamma(shape=fit_alpha, scale=fit_beta, size=collect_bg_only.size),
                bins=bin_edges, histtype='step', label='fit')
        ax.set_xlabel('Chip background (ADU)', fontsize=16)
        ax.set_ylabel('Counts', fontsize=16)
        ax.tick_params(axis='x', labelsize=16)
        ax.tick_params(axis='y', labelsize=16)
        ax.legend()
        # Save before showing so the file is written from the fully drawn
        # figure regardless of what the backend does on show().
        if save_path is not None:
            fig.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)
        plt.show()

    if return_alpha_beta:
        return fit_alpha, fit_beta
    return fit_alpha * fit_beta, fit_alpha * fit_beta * fit_beta


In [10]:
# Switch to frame 46; this rebinds the three images loaded earlier for frame 45.
index = 46
phase_img, mask_img, venus_img = (
    imread(file_list[index])
    for file_list in (phase_filenames, mask_filenames, venus_filenames)
)


In [11]:
# Output directory for every figure saved by this notebook.
# NOTE(review): none of the visible cells creates this directory -- it
# presumably has to exist before the first savefig call.
SAVE_DIR = Path('/mnt/sda1/SMLAT/figures/supplementary/param_est_chromo/')

In [12]:
# Fit the chip background of the current frame, plot it and save the figure.
# The cell output is the (alpha, beta) pair of the fit.
save_path = SAVE_DIR / 'chip_bg.svg'
chromo_mean_var_bg_outside(
    venus_img,
    mask_img,
    plot=True,
    return_alpha_beta=True,
    save_path=save_path,
)

(167.08758682864607, 0.6420951539956198)

### EDT distribution fits

In [13]:
# Statistics of the pixels inside the cells of the current frame. The result
# displayed in the next cell is a dict with 'mean', 'stddev', 'counts' and
# 'values', each keyed by an integer distance index (used further down as the
# distance from the cell boundary).
edt_data = chromo_edt_mean_variance_inside(venus_img, mask_img, plot=True, return_values=True)

In [14]:
edt_data

{'mean': {1: 110.54809422869256,
  2: 111.22860292091485,
  3: 111.76904231625835,
  4: 111.85523630083216,
  5: 112.13982626127631,
  6: 112.25685920577617,
  7: 110.55555555555556},
 'stddev': {1: 7.670276180894463,
  2: 7.628324454356531,
  3: 7.477441901492216,
  4: 7.554174757560259,
  5: 7.451233205396816,
  6: 7.393619668069318,
  7: 7.461919787399236},
 'counts': {1: 25003, 2: 18145, 3: 13470, 4: 12738, 5: 11972, 6: 5540, 7: 180},
 'values': {1: array([ 93, 120, 122, ...,  97,  99, 103], dtype=uint16),
  2: array([105, 102, 113, ..., 103, 108, 111], dtype=uint16),
  3: array([118, 116, 119, ..., 110, 107, 105], dtype=uint16),
  4: array([ 99, 100, 100, ...,  97, 109, 123], dtype=uint16),
  5: array([111, 104, 117, ..., 101,  96, 116], dtype=uint16),
  6: array([119, 119, 118, ...,  96, 111, 116], dtype=uint16),
  7: array([114, 115, 100,  92, 116, 109, 121, 123, 107, 124, 122,  97, 113,
         115, 118, 114, 105, 122, 119, 114, 113, 114, 124, 118, 111, 114,
         108, 117,

In [15]:
# For each distance-from-boundary bin (1-6), fit a gamma distribution to the
# in-cell pixel values and overlay a histogram of samples from that fit.
fig, ax = plt.subplots()
for edt_index in range(1, 7):
    edt_values = edt_data['values'][edt_index]
    low, high = edt_values.min(), edt_values.max()
    # NOTE(review): the location is pinned to -1.5; the reason is not visible
    # in this notebook -- confirm it matches the parameter-estimation code.
    fit_alpha_edt, fit_loc_edt, fit_beta_edt = scipy.stats.gamma.fit(edt_values, floc=-1.5)
    # np.random.gamma has no location parameter, so the fitted location is
    # added explicitly; without it the curve is shifted by 1.5 ADU relative
    # to the distribution that was actually fitted.
    edt_samples = fit_loc_edt + np.random.gamma(shape=fit_alpha_edt, scale=fit_beta_edt, size=100000)
    ax.hist(edt_samples, bins=np.arange(low, high), histtype='step',
            label=str(edt_index), density=True, stacked=False)

# Axis and legend styling only needs to happen once, after all curves exist.
ax.set_xlabel('Background inside cells (ADU)', fontsize=16)
ax.set_ylabel('Fitted gamma distribution', fontsize=16)
ax.tick_params(axis='x', labelsize=16)
ax.tick_params(axis='y', labelsize=16)
leg = ax.legend()
leg.set_title('Distance from cell boundary', prop={'size':16})
plt.setp(leg.get_texts(), fontsize='16')

# Save before showing so the file is written from the fully drawn figure.
save_path = SAVE_DIR / Path('edt_vs_cellbg.svg')
fig.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)
plt.show()


#### Distributions of the mean and standard deviation over images (2 violin plots)

In [16]:
from tqdm import tqdm
import math

In [17]:
# Per-frame mean and standard deviation of the background, keyed by distance
# from the cell boundary: 1-6 come from the pixels inside the cells, 0 is the
# chip background outside the cells.
N_SKIPPED_FRAMES = 21  # frames 0-20 are left out; the reason is not recorded here
collected_means = {key: [] for key in range(7)}
collected_stddev = {key: [] for key in range(7)}
for index in tqdm(range(len(venus_filenames))):
    if index < N_SKIPPED_FRAMES:
        continue
    # The phase image is not needed for these statistics, so it is no longer
    # read on every iteration.
    # NOTE: mask_img / venus_img stay notebook-level names on purpose -- the
    # distance-map cell further down reads mask_img, i.e. the last frame
    # loaded by this loop.
    mask_img = imread(mask_filenames[index])
    venus_img = imread(venus_filenames[index])
    inside = chromo_edt_mean_variance_inside(venus_img, mask_img)
    means = inside['mean']
    stddev = inside['stddev']
    # With return_alpha_beta=False this returns (mean, variance) of the fit.
    outside = chromo_mean_var_bg_outside(venus_img, mask_img, plot=False, return_alpha_beta=False)
    means[0] = outside[0]
    stddev[0] = math.sqrt(outside[1])
    # Distances beyond 6 are dropped so every column has one entry per frame.
    for key, value in means.items():
        if key in collected_means:
            collected_means[key].append(value)
    for key, value in stddev.items():
        if key in collected_stddev:
            collected_stddev[key].append(value)

100%|█████████████████████████████████████████████████████████████| 301/301 [00:57<00:00,  5.27it/s]


In [18]:
import pandas as pd

In [19]:
collected_means

{0: [107.9961226836825,
  108.1085862819322,
  108.00040207874713,
  107.77776633795307,
  107.87058252410904,
  107.88919370029126,
  107.91083987982213,
  107.63533365563711,
  107.79550595270472,
  107.59693600243202,
  107.62569195293045,
  107.76605142251651,
  107.68867759011343,
  107.75183177598228,
  107.46397445463613,
  107.50192808082704,
  107.64803119299832,
  107.46280226666488,
  107.49047432504148,
  107.47460197436459,
  107.46878353919388,
  107.35924052263202,
  107.45648238954116,
  107.41778137863268,
  107.1849206109989,
  107.286129795496,
  107.25596149657356,
  107.32351902644575,
  107.28088508717516,
  107.21242809384694,
  107.27022562732832,
  107.09688121429596,
  107.31140938136137,
  107.25054026254247,
  107.17591440468989,
  107.25437947839413,
  107.11049625163791,
  107.12529221061939,
  107.0998776331692,
  107.03702001980136,
  107.03524293486981,
  106.97746299964065,
  107.0365312267692,
  106.87501014480506,
  106.90638980572385,
  106.81501079

In [20]:
# One column per distance key (0 = outside the cells), one row per processed frame.
df = pd.DataFrame(collected_means)

In [21]:
df

Unnamed: 0,0,1,2,3,4,5,6
0,107.996123,112.858934,113.451803,114.151804,114.583221,114.650188,115.017205
1,108.108586,113.106584,113.955211,114.480481,114.982948,115.176334,115.296296
2,108.000402,112.711304,113.428422,114.149151,114.267519,114.575242,114.729177
3,107.777766,112.253606,112.943294,113.551822,113.839501,113.864919,114.022193
4,107.870583,112.244426,113.005462,113.522426,113.954859,114.313812,114.214345
...,...,...,...,...,...,...,...
275,106.242904,108.005850,108.426894,108.765548,108.967960,109.103792,109.214105
276,106.188200,107.879983,108.251482,108.664973,108.847729,109.033926,109.252876
277,106.189695,107.821361,108.378872,108.689416,108.874768,109.158347,109.381128
278,106.154334,107.980981,108.343761,108.679847,108.950577,108.967627,109.276701


In [22]:
import seaborn as sns

In [23]:


# Violin plot of the per-frame means (`df` holds collected_means here).
# A fresh figure is created explicitly: without `ax=`, seaborn draws into the
# current axes, which under the interactive qt5 backend selected at the top of
# the notebook can be a figure left open by an earlier cell.
fig, ax = plt.subplots()
sns.violinplot(data=df, linewidth=0.5, ax=ax)
ax.set_xlabel('Euclidean distance from boundary', fontsize=16)
ax.set_ylabel('Mean of fitted gamma distribution', fontsize=16)
ax.tick_params(axis='x', labelsize=16)
ax.tick_params(axis='y', labelsize=16)
save_path = SAVE_DIR / Path('edt_vs_mean_over_time.svg')
fig.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)

In [24]:
# NOTE: rebinds `df` (previously the means table) to the standard deviations;
# re-running the means plot after this cell would plot the wrong table.
df = pd.DataFrame(collected_stddev)

In [25]:


# Violin plot of the per-frame standard deviations (`df` holds
# collected_stddev here). A fresh figure is created explicitly: without `ax=`,
# seaborn draws into the current axes, which under the interactive qt5 backend
# selected at the top of the notebook can be the figure of the previous plot.
fig, ax = plt.subplots()
sns.violinplot(data=df, linewidth=0.5, ax=ax)
ax.set_xlabel('Euclidean distance from boundary', fontsize=16)
ax.set_ylabel('Standard-deviation of fitted gamma distribution', fontsize=16)
ax.tick_params(axis='x', labelsize=16)
ax.tick_params(axis='y', labelsize=16)
save_path = SAVE_DIR / Path('edt_vs_stddev_over_time.svg')
fig.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)

#### Plot of edt mean, variance map on the cell mask

In [26]:
# Load the stored per-distance gamma parameters ('alphas' / 'betas').
# NOTE(review): unpickling can execute arbitrary code -- only load dumps you
# produced yourself.
# NOTE(review): in the saved outputs, the values in this dump differ (except
# for key 0) from those returned by the get_full_edt_maps call earlier in the
# notebook, which ran with save_filename=None -- confirm which set is meant to
# drive the figures below.
edt_noise_map = pickle.loads(edt_dump_path.read_bytes())

In [27]:
edt_noise_map

{'alphas': {1: 236.54638292153774,
  2: 240.7631598605963,
  3: 245.80542733063263,
  4: 248.2721979977929,
  5: 250.56159873895828,
  6: 255.11618135604485,
  0: 180.30555399547802},
 'betas': {1: 0.47607266437417495,
  2: 0.47003777203611585,
  3: 0.46211845907868015,
  4: 0.45833689490682666,
  5: 0.4544948106907408,
  6: 0.4476987237184442,
  0: 0.6027986764257343}}

In [28]:
def alpha_bg_mask(dists, mean_map):
    """
    Replace every distance value in `dists` by the parameter mapped to it.

    Arguments:
    ----------
        dists (np.ndarray): array of whole-number distance values
        mean_map (dict): distance value -> parameter value (the notebook passes
                    the fitted gamma alphas)

    Returns:
    ----------
        np.ndarray with the shape of `dists`. Floating inputs keep their dtype;
        integer inputs are promoted to float so that fractional parameters are
        not truncated. `dists` itself is not modified.

    Pixels whose distance is not a key of `mean_map` keep their distance value.
    NOTE(review): distances larger than the largest key are therefore left as
    raw distances -- confirm this is intended.
    """
    dists = np.asarray(dists)
    param_img = np.array(dists, dtype=np.result_type(dists.dtype, np.float32))
    for edt_val, mean_bg in mean_map.items():
        # The mask is taken from the untouched input, so an already written
        # parameter can never be mistaken for a distance.
        param_img[dists == int(edt_val)] = mean_bg
    return param_img

In [29]:
def beta_bg_mask(dists, mean_map):
    """
    Replace every distance value in `dists` by the parameter mapped to it.

    Arguments:
    ----------
        dists (np.ndarray): array of whole-number distance values
        mean_map (dict): distance value -> parameter value (the notebook passes
                    the fitted gamma betas)

    Returns:
    ----------
        np.ndarray with the shape of `dists`. Floating inputs keep their dtype;
        integer inputs are promoted to float, otherwise betas below 1 would be
        truncated to 0. `dists` itself is not modified.

    Pixels whose distance is not a key of `mean_map` keep their distance value.
    NOTE(review): distances larger than the largest key are therefore left as
    raw distances -- confirm this is intended.
    """
    dists = np.asarray(dists)
    param_img = np.array(dists, dtype=np.result_type(dists.dtype, np.float32))
    for edt_val, mean_bg in mean_map.items():
        # The mask is taken from the untouched input, so an already written
        # parameter can never be mistaken for a distance.
        param_img[dists == int(edt_val)] = mean_bg
    return param_img

In [30]:
def mean_bg_mask(dists, alpha_map, beta_map):
    """
    Build a per-pixel image of the gamma mean (alpha * beta) from a distance map.

    Arguments:
    ----------
        dists (np.ndarray): array of whole-number distance values
        alpha_map (dict): distance value -> gamma alpha
        beta_map (dict): distance value -> gamma beta; must contain every key
                    of `alpha_map`

    Returns:
    ----------
        np.ndarray with the shape of `dists`. Floating inputs keep their dtype;
        integer inputs are promoted to float so the means are not truncated.
        `dists` itself is not modified.

    Pixels whose distance is not a key of `alpha_map` keep their distance value.
    """
    dists = np.asarray(dists)
    mean_img = np.array(dists, dtype=np.result_type(dists.dtype, np.float32))
    for edt_val, alpha_bg in alpha_map.items():
        mean_img[dists == int(edt_val)] = alpha_bg * beta_map[edt_val]
    return mean_img

In [31]:
def variance_bg_mask(dists, alpha_map, beta_map):
    """
    Build a per-pixel image of the gamma standard deviation from a distance map.

    Despite the name, the value written is sqrt(alpha * beta**2), i.e. the
    standard deviation and not the variance.

    Arguments:
    ----------
        dists (np.ndarray): array of whole-number distance values
        alpha_map (dict): distance value -> gamma alpha
        beta_map (dict): distance value -> gamma beta; must contain every key
                    of `alpha_map`

    Returns:
    ----------
        np.ndarray with the shape of `dists`. Floating inputs keep their dtype;
        integer inputs are promoted to float so the values are not truncated.
        `dists` itself is not modified.

    Pixels whose distance is not a key of `alpha_map` keep their distance value.
    """
    dists = np.asarray(dists)
    stddev_img = np.array(dists, dtype=np.result_type(dists.dtype, np.float32))
    for edt_val, alpha_bg in alpha_map.items():
        beta_bg = beta_map[edt_val]
        stddev_img[dists == int(edt_val)] = math.sqrt(alpha_bg * beta_bg * beta_bg)
    return stddev_img

In [32]:

# Distance map of the cell mask, rounded to whole pixels so it can be matched
# against the integer keys of the fitted parameter maps. edt.edt is presumably
# the Euclidean distance transform of the `edt` package -- TODO confirm.
# The former np.zeros_like() pre-allocation was overwritten immediately and the
# np.ceil() after np.round() could not change an already whole-number array,
# so both are dropped.
# NOTE: mask_img is whatever frame was loaded last; the per-frame collection
# loop earlier in the notebook rebinds it.
dists = np.round(edt.edt(mask_img))

In [33]:
# Turn the distance map into per-pixel parameter images using the loaded fit.
fitted_alpha_map, fitted_beta_map = edt_noise_map['alphas'], edt_noise_map['betas']

alpha_bg_cells = alpha_bg_mask(dists, fitted_alpha_map)
beta_bg_cells = beta_bg_mask(dists, fitted_beta_map)
mean_bg_cells = mean_bg_mask(dists, fitted_alpha_map, fitted_beta_map)
# variance_bg_mask writes sqrt(alpha * beta**2), hence the `stddev` name here.
stddev_bg_cells = variance_bg_mask(dists, fitted_alpha_map, fitted_beta_map)

In [34]:
# Tensor views of the parameter images. beta_t is the element-wise reciprocal
# of the betas because it is passed as `rate` to torch's Gamma below (assuming
# the stored betas are scale parameters -- TODO confirm).
alpha_t = torch.from_numpy(alpha_bg_cells)
beta_t = torch.from_numpy(beta_bg_cells).reciprocal()

In [35]:
# Per-pixel gamma distribution: concentration = alpha image, rate = reciprocal-of-beta image.
m = torch.distributions.gamma.Gamma(concentration=alpha_t, rate=beta_t)

In [36]:
# One random draw per pixel.
# NOTE(review): no RNG seed is set in the visible cells, so this draw (and the
# figure saved from it further down) differs between runs.
sample = m.sample()

In [37]:
# Side-by-side view of the per-pixel alpha image, the beta image and the
# distance map both were built from.
fig, ax = plt.subplots(nrows=1, ncols=3)
panels = (
    (alpha_bg_cells, 'Alpha values'),
    (beta_bg_cells, 'Beta values'),
    (dists, 'EDT'),
)
for panel_ax, (panel_img, panel_title) in zip(ax, panels):
    panel_ax.imshow(panel_img)
    panel_ax.set_title(panel_title)
plt.show()

In [38]:
# Show the full sampled background image in grayscale on a new figure.
sample_np = sample.cpu().numpy()
plt.figure()
plt.imshow(sample_np, cmap='gray')
plt.show()

In [39]:
# Crop window as [row, col, height, width]; the export cells below slice with
# roi[0]: roi[0] + roi[2] on rows and roi[1]: roi[1] + roi[3] on columns.
roi = [500, 400, 256, 256]

In [40]:
# Export the ROI crop of the per-pixel mean background with a colour bar.
save_phase = SAVE_DIR / 'mean_image.svg'
roi_rows = slice(roi[0], roi[0] + roi[2])
roi_cols = slice(roi[1], roi[1] + roi[3])

fig, ax = plt.subplots(nrows=1, ncols=1)
im = ax.imshow(mean_bg_cells[roi_rows, roi_cols])
ax.set_axis_off()

# Narrow axes attached to the right of the image to hold the colour bar.
cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
fig.colorbar(im, cax=cax)
fig.savefig(save_phase, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)
plt.close(fig)

In [41]:
# Export the ROI crop of the per-pixel standard-deviation image with a colour bar.
save_phase = SAVE_DIR / 'stddev.svg'
roi_rows = slice(roi[0], roi[0] + roi[2])
roi_cols = slice(roi[1], roi[1] + roi[3])

fig, ax = plt.subplots(nrows=1, ncols=1)
im = ax.imshow(stddev_bg_cells[roi_rows, roi_cols])
ax.set_axis_off()

# Narrow axes attached to the right of the image to hold the colour bar.
cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
fig.colorbar(im, cax=cax)
fig.savefig(save_phase, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)
plt.close(fig)

In [42]:
# Export the ROI crop of the rounded distance map with a colour bar.
save_phase = SAVE_DIR / 'distance_from_boundary.svg'
roi_rows = slice(roi[0], roi[0] + roi[2])
roi_cols = slice(roi[1], roi[1] + roi[3])

fig, ax = plt.subplots(nrows=1, ncols=1)
im = ax.imshow(dists[roi_rows, roi_cols])
ax.set_axis_off()

# Narrow axes attached to the right of the image to hold the colour bar.
cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
fig.colorbar(im, cax=cax)
fig.savefig(save_phase, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)
plt.close(fig)

In [43]:
# Export the ROI crop of the sampled background (grayscale) with a colour bar.
save_phase = SAVE_DIR / 'sampled_bg.svg'
roi_rows = slice(roi[0], roi[0] + roi[2])
roi_cols = slice(roi[1], roi[1] + roi[3])

fig, ax = plt.subplots(nrows=1, ncols=1)
sample_crop = sample.cpu().numpy()[roi_rows, roi_cols]
im = ax.imshow(sample_crop, cmap='gray')
ax.set_axis_off()

# Narrow axes attached to the right of the image to hold the colour bar.
cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
fig.colorbar(im, cax=cax)
fig.savefig(save_phase, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)
plt.close(fig)