In [1]:
%reload_ext autoreload
%autoreload 2
%reload_ext notexbook
%texify

### Plotting for parameter estimation in chromosomal data

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
import pathlib
from pathlib import Path
import torch
sns.set_style('white')
from skimage.io import imread, imsave
import scipy.stats
import edt
from skimage import segmentation
import pickle
%matplotlib qt5

In [3]:
# All inputs for this position live under one root; change it here to point the
# notebook at a different dataset.
DATA_ROOT = Path('/mnt/sda1/SMLAT/data/real_data/replisome_dots/pool_3051_repli_pos01')
MASKS_DIR = DATA_ROOT / 'phase_venus_mask'
PHASE_DIR = DATA_ROOT / 'phase_venus'
VENUS_DIR = DATA_ROOT / 'venus'
edt_dump_path = DATA_ROOT / 'edt_noise.pkl'

In [4]:
# Sort each listing so that the same index selects the same frame in every channel.
# NOTE(review): assumes mask (.png), phase and venus (.tiff) filenames sort in the
# same frame order and that the three directories hold the same frames — confirm.
mask_filenames = sorted(MASKS_DIR.glob('*.png'))
phase_filenames = sorted(PHASE_DIR.glob('*.tiff'))
venus_filenames = sorted(VENUS_DIR.glob('*.tiff'))

In [5]:
index = 45
# Read the phase, mask and venus images of frame `index` (in that order).
phase_img, mask_img, venus_img = (
    imread(filenames[index])
    for filenames in (phase_filenames, mask_filenames, venus_filenames)
)

In [6]:
from cellbgnet.param_estimation import *
from cellbgnet.plotting import *

In [7]:
 edt_pool_alphas, edt_pool_betas = get_full_edt_maps(MASKS_DIR, VENUS_DIR, save_filename=None)

100%|█████████████████████████████████████████████████████████████| 301/301 [00:49<00:00,  6.07it/s]


In [8]:
# Pooled gamma parameters per integer distance (first dict: alphas, second: betas).
# NOTE(review): key 0 presumably holds the fit outside the cells, matching the
# convention of the collection loop further down — confirm.
edt_pool_alphas, edt_pool_betas

({1: 109.41085851651867,
  2: 109.46318904734099,
  3: 112.78694062995059,
  4: 115.7861086468145,
  5: 118.67457586429566,
  6: 120.66542572044445,
  0: 137.51260796750466},
 {1: 1.1580122722645674,
  2: 1.1774333961388979,
  3: 1.1586569054314162,
  4: 1.1366204048159263,
  5: 1.1146217447885085,
  6: 1.1013536627764136,
  0: 0.8401670299654705})

### Chip background distribution fit

In [9]:
def chromo_mean_var_bg_outside(fluor_img, cellseg_mask, dilate=True, roi=None,
                 plot=False, dilate_px=1, return_alpha_beta=False, save_path=None):
    """
    Fit a gamma distribution to the fluorescence pixels outside the cells.

    The image and mask are optionally cropped to an ROI and the cell labels are
    optionally dilated. Every pixel with label 0 is treated as background. Pixels
    at or above the 99th percentile of the background are discarded as outliers,
    and a gamma distribution with its location fixed at 0 is fitted to the rest.

    Arguments:
    ----------
        fluor_img (np.ndarray): fluorescence image
        cellseg_mask (np.ndarray): labelled cell mask, same shape as fluor_img;
                                   0 marks pixels outside the cells
        dilate (bool): expand the cell labels by `dilate_px` pixels before
                       selecting the background, default True
        roi (list of 4 ints): (row, col, height, width) of a region to restrict
                              the fit to, in case the cell masks are poor
                              elsewhere; None uses the whole image
        plot (bool): plot a histogram of the data against samples drawn from
                     the fitted gamma distribution
        dilate_px (int): number of pixels to expand the labels by, default 1
        return_alpha_beta (bool): return the fitted shape (alpha) and scale
                                  (beta) instead of the mean and variance
        save_path (str or Path): if given and plot is True, save the figure here

    Returns:
    ----------
        (mean, variance) of the fitted gamma distribution, i.e.
        (alpha * beta, alpha * beta**2), or (alpha, beta) when
        return_alpha_beta is True.

    Raises:
    ----------
        ValueError: if there are no background pixels to fit.
    """
    if roi is not None:
        cellseg_mask = cellseg_mask[roi[0]: roi[0] + roi[2], roi[1]: roi[1] + roi[3]]
        fluor_img = fluor_img[roi[0]: roi[0] + roi[2], roi[1]: roi[1] + roi[3]]

    if dilate:
        cellseg_mask = segmentation.expand_labels(cellseg_mask, distance=dilate_px)

    # Background = unlabelled pixels (boolean indexing returns them in row-major order).
    only_pixels_outside = np.asarray(fluor_img)[np.asarray(cellseg_mask) == 0].ravel()
    if only_pixels_outside.size == 0:
        raise ValueError('no pixels outside the cells to fit the background to')

    # Remove outliers in the outside pixels at the 99th percentile.
    cutoff = np.percentile(only_pixels_outside, 99)
    collect_bg_only = only_pixels_outside[only_pixels_outside < cutoff]
    if collect_bg_only.size == 0:
        raise ValueError('no background pixels left below the 99th percentile '
                         '(all outside pixels share the same value?)')

    # Fit a gamma distribution; beta here is the scale parameter.
    fit_alpha, _, fit_beta = scipy.stats.gamma.fit(collect_bg_only, floc=0.0)

    if plot:
        low, high = float(collect_bg_only.min()), float(collect_bg_only.max())
        # Unit-width bins whose last bin [high, high + 1] still contains the maximum value.
        bins = np.arange(low, high + 2)
        fig, ax = plt.subplots()
        ax.hist(collect_bg_only, bins=bins, histtype='step', label='data')
        ax.hist(np.random.gamma(shape=fit_alpha, scale=fit_beta, size=collect_bg_only.size),
                bins=bins, histtype='step', label='fit')
        ax.set_xlabel('Chip background (ADU)', fontsize=16)
        ax.set_ylabel('Counts', fontsize=16)
        ax.tick_params(axis='x', labelsize=16)
        ax.tick_params(axis='y', labelsize=16)
        ax.legend()
        plt.show()
        if save_path is not None:
            fig.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)

    if return_alpha_beta:
        return fit_alpha, fit_beta
    return fit_alpha * fit_beta, fit_alpha * fit_beta * fit_beta


In [10]:
index = 46
# Frame 46 replaces the frame loaded earlier for the single-frame fits that follow.
phase_img, mask_img, venus_img = map(
    imread, (phase_filenames[index], mask_filenames[index], venus_filenames[index]))


In [11]:
# Output directory for the supplementary figures saved below.
SAVE_DIR = Path('/mnt/sda1/SMLAT/figures/supplementary') / 'param_est_replisome'

In [12]:
# Fit the outside-cell background of the current frame and save the histogram figure.
save_path = SAVE_DIR / 'chip_bg.svg'
chromo_mean_var_bg_outside(
    venus_img, mask_img,
    plot=True, return_alpha_beta=True, save_path=save_path,
)

(123.13663336708984, 0.898685492828329)

### Euclidean distance transform (EDT) distribution fits

In [13]:
# Per-distance statistics of the in-cell pixels of this frame; with return_values=True
# the result also appears to include the raw pixel values per distance (see output keys).
edt_data = chromo_edt_mean_variance_inside(venus_img, mask_img, plot=True, return_values=True)

In [14]:
# Keys 'mean', 'stddev', 'counts', 'values', each indexed by integer distance from the cell boundary.
edt_data

{'mean': {1: 125.23398307927397,
  2: 127.33541082372494,
  3: 128.81921599884066,
  4: 129.88433244639225,
  5: 130.38572039333047,
  6: 130.81481009431556,
  7: 130.80388978930307},
 'stddev': {1: 12.104647647090697,
  2: 12.201133535294511,
  3: 12.23491159167863,
  4: 12.146472136468692,
  5: 11.933501852548865,
  6: 12.103344186123504,
  7: 11.925842624685318},
 'counts': {1: 26831, 2: 19254, 3: 13801, 4: 12778, 5: 11695, 6: 7846, 7: 617},
 'values': {1: array([118, 118, 147, ..., 119, 114,  96], dtype=uint16),
  2: array([146, 137, 127, ..., 106, 116, 115], dtype=uint16),
  3: array([133, 119, 128, ..., 108, 116, 114], dtype=uint16),
  4: array([121, 149, 111, ...,  99, 111, 110], dtype=uint16),
  5: array([124, 145, 148, ..., 107, 110, 116], dtype=uint16),
  6: array([146, 143, 138, ..., 106, 105, 137], dtype=uint16),
  7: array([143, 144, 133, 149, 126, 148, 132, 142, 127, 125, 126, 117, 151,
         141, 129, 125, 124, 109, 115, 143, 126, 129, 120, 120, 134, 142,
         143

In [15]:
# Overlay, for each integer distance from the cell boundary, samples from the gamma
# distribution fitted to the in-cell pixel values of this frame.
# NOTE(review): the fit fixes the gamma location at -1.5 ADU (floc); the samples add
# that location back so the plotted curves match the fitted model — confirm the
# -1.5 offset itself is intended. Distance 7 (617 pixels in this frame) is not plotted.
fig, ax = plt.subplots()
for edt_index in range(1, 7):
    values = edt_data['values'][edt_index]
    low, high = values.min(), values.max()
    fit_alpha_edt, fit_loc_edt, fit_beta_edt = scipy.stats.gamma.fit(values, floc=-1.5)
    samples = fit_loc_edt + np.random.gamma(shape=fit_alpha_edt, scale=fit_beta_edt, size=100000)
    ax.hist(samples, bins=np.arange(low, high), histtype='step',
            label=str(edt_index), density=True, stacked=False)
ax.set_xlabel('Background inside cells (ADU)', fontsize=16)
ax.set_ylabel('Fitted gamma distribution', fontsize=16)
ax.tick_params(axis='both', labelsize=16)
leg = ax.legend()
leg.set_title('Distance from cell boundary', prop={'size': 16})
plt.setp(leg.get_texts(), fontsize=16)
plt.show()
save_path = SAVE_DIR / 'edt_vs_cellbg.svg'
fig.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)


#### Distributions of the fitted mean and standard deviation over images (2 violin plots)

In [16]:
from tqdm import tqdm
import math

In [17]:
# For every frame after the first 21, collect the mean and standard deviation of the
# fitted background at each distance from the cell boundary (keys 1..6) and outside
# the cells (key 0).
# NOTE(review): the reason frames 0-20 are skipped is not recorded — document it.
FIRST_FRAME = 21
EDT_KEYS = range(7)
collected_means = {key: [] for key in EDT_KEYS}
collected_stddev = {key: [] for key in EDT_KEYS}
for index in tqdm(range(FIRST_FRAME, len(venus_filenames))):
    mask_img = imread(mask_filenames[index])
    venus_img = imread(venus_filenames[index])
    inside = chromo_edt_mean_variance_inside(venus_img, mask_img)
    outside_mean, outside_var = chromo_mean_var_bg_outside(
        venus_img, mask_img, plot=False, return_alpha_beta=False)
    frame_means = {**inside['mean'], 0: outside_mean}
    frame_stddev = {**inside['stddev'], 0: math.sqrt(outside_var)}
    for key in EDT_KEYS:
        # NaN placeholders keep every list the same length when a frame has no pixels
        # at some distance, so pd.DataFrame(collected_means) stays rectangular.
        collected_means[key].append(frame_means.get(key, np.nan))
        collected_stddev[key].append(frame_stddev.get(key, np.nan))

100%|█████████████████████████████████████████████████████████████| 301/301 [00:59<00:00,  5.03it/s]


In [18]:
import pandas as pd

In [19]:
# One list per key (0 = outside the cells, 1-6 = distance from the boundary), one entry per processed frame.
collected_means

{0: [110.83507141706333,
  110.9291849232906,
  110.8075098354913,
  110.8132108582304,
  110.65158918226052,
  110.8727145435153,
  110.64418221155671,
  110.6648308975277,
  110.6136412620689,
  110.61342177491595,
  110.55339764085852,
  110.52437146865876,
  110.6001203005492,
  110.54428059216185,
  110.55252611114383,
  110.59193319864502,
  110.50750861489061,
  110.55584424124045,
  110.62392228140594,
  110.63582570232893,
  110.57233042298947,
  110.68609173390156,
  110.54786091659109,
  110.56451477120045,
  110.70120832451903,
  110.6611060427244,
  110.7941972412562,
  110.71893844961762,
  110.45769270952695,
  110.51629733676332,
  110.5533126442323,
  110.6200342760078,
  110.80226821718735,
  110.6259679091815,
  110.59234904834783,
  110.78614088304685,
  110.57491864370752,
  110.62918199245824,
  110.59476131382495,
  110.69223092891158,
  110.71450731386898,
  110.64795674573396,
  110.751769755894,
  110.64980317229555,
  110.6202631353058,
  110.55512052352036,


In [20]:
# One column per key of collected_means, one row per processed frame.
df = pd.DataFrame(collected_means)

In [21]:
# Rows = processed frames, columns = key (0 = outside the cells, 1-6 = distance from the boundary).
df

Unnamed: 0,0,1,2,3,4,5,6
0,110.835071,126.193384,128.187290,130.122004,130.920458,131.623759,131.807382
1,110.929185,126.008046,128.034579,129.894310,130.573504,131.376284,131.478341
2,110.807510,125.973924,128.082323,129.829166,130.886444,131.561834,131.767436
3,110.813211,126.474654,128.774266,130.459862,131.495797,132.058708,132.124203
4,110.651589,125.528781,127.860082,129.522239,130.574973,131.001538,131.127670
...,...,...,...,...,...,...,...
275,110.079151,123.997330,126.149539,128.083824,129.160418,129.593344,130.672576
276,110.191267,124.341241,126.489415,128.515115,129.342530,130.245502,131.154901
277,110.150787,124.330952,126.604107,128.641328,129.353688,130.140323,131.141808
278,110.128810,124.273464,126.601249,128.449218,129.773929,130.489105,131.523101


In [22]:
import seaborn as sns

In [23]:


# Distribution over frames of the fitted background mean per distance (0 = outside the cells).
# An explicit figure keeps seaborn from drawing onto whatever axes happen to be current
# (with the qt5 backend an earlier figure may still be open).
fig, ax = plt.subplots()
sns.violinplot(data=df, linewidth=0.5, ax=ax)
ax.set_xlabel('Euclidean distance from boundary', fontsize=16)
ax.set_ylabel('Mean of fitted gamma distribution', fontsize=16)
ax.tick_params(axis='both', labelsize=16)
save_path = SAVE_DIR / 'edt_vs_mean_over_time.svg'
fig.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)

In [24]:
# NOTE(review): rebinds `df` (previously the per-frame means) to the per-frame standard
# deviations; a distinct name would avoid hidden-state confusion on partial re-runs.
df = pd.DataFrame(collected_stddev)

In [25]:


# Distribution over frames of the fitted background standard deviation per distance
# (0 = outside the cells). An explicit figure keeps seaborn from drawing onto the
# previous violin plot's axes if that window is still open.
fig, ax = plt.subplots()
sns.violinplot(data=df, linewidth=0.5, ax=ax)
ax.set_xlabel('Euclidean distance from boundary', fontsize=16)
ax.set_ylabel('Standard-deviation of fitted gamma distribution', fontsize=16)
ax.tick_params(axis='both', labelsize=16)
save_path = SAVE_DIR / 'edt_vs_stddev_over_time.svg'
fig.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)

#### Maps of the EDT-based background mean and standard deviation on the cell mask

In [26]:
# Load the saved per-distance gamma parameters ({'alphas': ..., 'betas': ...}, see output).
# NOTE(security): pickle.load can execute arbitrary code; only load files you created.
with open(edt_dump_path, 'rb') as fp:
    edt_noise_map = pickle.load(fp)

In [27]:
# Gamma shape ('alphas') and scale ('betas') per integer distance from the cell boundary.
edt_noise_map

{'alphas': {1: 97.36472669105797,
  2: 95.27145951791375,
  3: 96.68660962483388,
  4: 98.6097716401184,
  5: 100.68944893076487,
  6: 101.89646166836494,
  0: 137.51260796750466},
 'betas': {1: 1.3114575954947452,
  2: 1.367861823311651,
  3: 1.3705599129650432,
  4: 1.3553258709095097,
  5: 1.335349513429498,
  6: 1.3272756198198252,
  0: 0.8401670299654705}}

In [28]:
def alpha_bg_mask(dists, mean_map):
    """
    Map a distance-from-boundary image to per-pixel gamma shape (alpha) values.

    Arguments:
    ----------
        dists (np.ndarray): image of integer-valued distances from the cell boundary
        mean_map (dict): {distance: alpha}; keys are compared after int()

    Returns:
    ----------
        np.ndarray of the same shape as dists. Pixels whose distance is a key of
        mean_map get that value; all other pixels keep their distance value.
        Integer inputs are promoted to float64 so fractional values are not truncated.
        NOTE(review): unmapped distances (e.g. beyond the largest key) pass through
        unchanged — confirm this is intended.
    """
    dists = np.asarray(dists)
    # np.copy would keep an integer dtype and silently truncate the float values below.
    out_dtype = dists.dtype if np.issubdtype(dists.dtype, np.floating) else np.float64
    dists_copy = dists.astype(out_dtype, copy=True)
    for edt_val, mean_bg in mean_map.items():
        # Match against the untouched input so replaced values are never re-matched.
        dists_copy[dists == int(edt_val)] = mean_bg
    return dists_copy

In [29]:
def beta_bg_mask(dists, mean_map):
    """
    Map a distance-from-boundary image to per-pixel gamma scale (beta) values.

    Arguments:
    ----------
        dists (np.ndarray): image of integer-valued distances from the cell boundary
        mean_map (dict): {distance: beta}; keys are compared after int()
                         (the name is kept for backward compatibility)

    Returns:
    ----------
        np.ndarray of the same shape as dists. Pixels whose distance is a key of
        mean_map get that value; all other pixels keep their distance value.
        Integer inputs are promoted to float64 so fractional values are not truncated.
        NOTE(review): unmapped distances pass through unchanged — confirm intended.
    """
    dists = np.asarray(dists)
    if np.issubdtype(dists.dtype, np.floating):
        result = dists.copy()
    else:
        # An integer copy would truncate the float betas assigned below.
        result = dists.astype(np.float64)
    for edt_val, beta_bg in mean_map.items():
        # Compare against the input, not the partially filled result.
        result[dists == int(edt_val)] = beta_bg
    return result

In [30]:
def mean_bg_mask(dists, alpha_map, beta_map):
    """
    Map a distance-from-boundary image to the per-pixel gamma mean alpha * beta.

    Arguments:
    ----------
        dists (np.ndarray): image of integer-valued distances from the cell boundary
        alpha_map (dict): {distance: alpha}; keys are compared after int()
        beta_map (dict): {distance: beta (scale)}; must contain every key of alpha_map

    Returns:
    ----------
        np.ndarray of the same shape as dists. Pixels whose distance is a key of
        alpha_map get alpha * beta; all other pixels keep their distance value.
        Integer inputs are promoted to float64 so fractional values are not truncated.
        NOTE(review): unmapped distances pass through unchanged — confirm intended.

    Raises:
    ----------
        KeyError: if a key of alpha_map is missing from beta_map.
    """
    dists = np.asarray(dists)
    out_dtype = dists.dtype if np.issubdtype(dists.dtype, np.floating) else np.float64
    dists_copy = dists.astype(out_dtype, copy=True)
    for edt_val, alpha_bg in alpha_map.items():
        dists_copy[dists == int(edt_val)] = alpha_bg * beta_map[edt_val]
    return dists_copy

In [31]:
def variance_bg_mask(dists, alpha_map, beta_map):
    """
    Map a distance-from-boundary image to the per-pixel gamma STANDARD DEVIATION.

    Despite the name, each mapped pixel gets sqrt(alpha * beta**2), i.e. the
    standard deviation of a gamma distribution with shape alpha and scale beta
    (callers store the result as `stddev_bg_cells`).

    Arguments:
    ----------
        dists (np.ndarray): image of integer-valued distances from the cell boundary
        alpha_map (dict): {distance: alpha}; keys are compared after int()
        beta_map (dict): {distance: beta (scale)}; must contain every key of alpha_map

    Returns:
    ----------
        np.ndarray of the same shape as dists. Pixels whose distance is a key of
        alpha_map get sqrt(alpha) * beta; all other pixels keep their distance value.
        Integer inputs are promoted to float64 so fractional values are not truncated.
        NOTE(review): unmapped distances pass through unchanged — confirm intended.

    Raises:
    ----------
        KeyError: if a key of alpha_map is missing from beta_map.
    """
    dists = np.asarray(dists)
    out_dtype = dists.dtype if np.issubdtype(dists.dtype, np.floating) else np.float64
    dists_copy = dists.astype(out_dtype, copy=True)
    for edt_val, alpha_bg in alpha_map.items():
        beta_bg = beta_map[edt_val]
        dists_copy[dists == int(edt_val)] = math.sqrt(alpha_bg * beta_bg * beta_bg)
    return dists_copy

In [32]:

# Distance of each pixel from the cell boundary, rounded to the nearest integer so it
# can be matched against the integer keys of the EDT noise maps.
# NOTE(review): when run top to bottom, mask_img here is the last frame read by the
# collection loop above, not frame 46 — confirm this is the intended mask.
dists = np.round(edt.edt(mask_img))

In [33]:
# Per-pixel parameter maps of the background model for the current distance image.
fitted_alpha_map = edt_noise_map['alphas']
fitted_beta_map = edt_noise_map['betas']
alpha_bg_cells = alpha_bg_mask(dists, fitted_alpha_map)
beta_bg_cells = beta_bg_mask(dists, fitted_beta_map)
mean_bg_cells = mean_bg_mask(dists, fitted_alpha_map, fitted_beta_map)
stddev_bg_cells = variance_bg_mask(dists, fitted_alpha_map, fitted_beta_map)

In [34]:
# The fitted betas are gamma *scale* parameters (the mean maps use alpha * beta);
# torch's Gamma below is parameterised by rate = 1 / scale, hence the inversion.
alpha_t = torch.from_numpy(alpha_bg_cells)
beta_t = 1.0/torch.from_numpy(beta_bg_cells)

In [35]:
# Independent per-pixel gamma distributions over the background image (beta_t is a rate).
m = torch.distributions.gamma.Gamma(concentration=alpha_t, rate=beta_t)

In [36]:
# Fixed seed so the sampled background image (and sampled_bg.svg) is reproducible.
torch.manual_seed(0)
sample = m.sample()

In [37]:
# Side-by-side view of the per-pixel alpha and beta maps and the rounded distance image.
fig, ax = plt.subplots(nrows=1, ncols=3)
ax[0].imshow(alpha_bg_cells)
ax[0].set_title('Alpha values')
ax[1].imshow(beta_bg_cells)
ax[1].set_title('Beta values')
ax[2].imshow(dists)
ax[2].set_title('EDT')
plt.show()

In [38]:
# One draw from the per-pixel background model.
fig, ax = plt.subplots()
ax.imshow(sample.cpu().numpy(), cmap='gray')
plt.show()

In [39]:
# Crop for the saved maps below as (row, col, height, width), applied as [r:r+h, c:c+w].
roi = [500, 400, 256, 256]

In [40]:
# Per-pixel background mean (alpha * beta) inside the ROI crop.
save_path = SAVE_DIR / 'mean_image.svg'
fig, ax = plt.subplots(nrows=1, ncols=1)
im = ax.imshow(mean_bg_cells[roi[0]: roi[0] + roi[2], roi[1]: roi[1] + roi[3]])
ax.set_axis_off()

divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im, cax=cax)
fig.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)
plt.close(fig)

In [41]:
# Per-pixel background standard deviation (sqrt(alpha) * beta) inside the ROI crop.
save_path = SAVE_DIR / 'stddev.svg'
fig, ax = plt.subplots(nrows=1, ncols=1)
im = ax.imshow(stddev_bg_cells[roi[0]: roi[0] + roi[2], roi[1]: roi[1] + roi[3]])
ax.set_axis_off()

divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im, cax=cax)
fig.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)
plt.close(fig)

In [42]:
# Rounded distance from the cell boundary inside the ROI crop.
save_path = SAVE_DIR / 'distance_from_boundary.svg'
fig, ax = plt.subplots(nrows=1, ncols=1)
im = ax.imshow(dists[roi[0]: roi[0] + roi[2], roi[1]: roi[1] + roi[3]])
ax.set_axis_off()

divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im, cax=cax)
fig.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)
plt.close(fig)

In [43]:
# One sampled background image inside the ROI crop.
save_path = SAVE_DIR / 'sampled_bg.svg'
fig, ax = plt.subplots(nrows=1, ncols=1)
im = ax.imshow(sample.cpu().numpy()[roi[0]: roi[0] + roi[2], roi[1]: roi[1] + roi[3]], cmap='gray')
ax.set_axis_off()

divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im, cax=cax)
fig.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)
plt.close(fig)