In [1]:
%reload_ext autoreload
%autoreload 2
%reload_ext notexbook
%texify

### Plotting for parameter estimation in chromosomal data

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
import pathlib
from pathlib import Path
import torch
sns.set_style('white')
from skimage.io import imread, imsave
import scipy.stats
import edt
from skimage import segmentation
import pickle
%matplotlib qt5

In [3]:
# Input folders of the pooled chromosome-dots dataset (absolute, machine-specific paths):
# segmentation masks (*.png), phase-contrast images (*.tiff) and venus fluorescence images (*.tiff).
MASKS_DIR = Path('/mnt/sda1/SMLAT/data/real_data/chromosome_dots/pooled_8865/phase_venus_mask/')
PHASE_DIR = Path('/mnt/sda1/SMLAT/data/real_data/chromosome_dots/pooled_8865/phase_venus/')
VENUS_DIR = Path('/mnt/sda1/SMLAT/data/real_data/chromosome_dots/pooled_8865/venus/')
# Pickle file that is loaded further down as the EDT noise map (dict with 'alphas' and 'betas').
edt_dump_path = Path('/mnt/sda1/SMLAT/data/real_data/chromosome_dots/pooled_8865/edt_noise.pkl')

In [4]:
# Sorted file lists. The same integer index is used on all three lists below, so the
# sort order is assumed to pair each mask with its phase and venus image — TODO confirm
# that the three folders contain matching file names and equal file counts.
mask_filenames = sorted(MASKS_DIR.glob('*.png'))
phase_filenames = sorted(PHASE_DIR.glob('*.tiff'))
venus_filenames = sorted(VENUS_DIR.glob('*.tiff'))

In [5]:
# Load one example position (phase image, cell mask, venus fluorescence image).
# NOTE: `index`, `phase_img`, `mask_img` and `venus_img` are reassigned by later cells.
index = 45
phase_img = imread(phase_filenames[index])
mask_img = imread(mask_filenames[index])
venus_img = imread(venus_filenames[index])

In [6]:
from cellbgnet.param_estimation import *
from cellbgnet.plotting import *

In [7]:
 # get_full_edt_maps comes from one of the cellbgnet star-imports above; it presumably pools
 # the pixels of all images per distance bin and fits a gamma distribution to each — TODO confirm.
 # As shown in the next cell, both results are dicts keyed by distance bin (0..6).
 # save_filename=None: presumably nothing is written to disk — TODO confirm.
 edt_pool_alphas, edt_pool_betas = get_full_edt_maps(MASKS_DIR, VENUS_DIR, save_filename=None)

  aest = (3-s + np.sqrt((s-3)**2 + 24*s)) / (12*s)
  func = lambda a: np.log(a) - sc.digamma(a) - s
100%|█████████████████████████████████████████████████████████████| 601/601 [01:31<00:00,  6.57it/s]


In [8]:
edt_pool_alphas, edt_pool_betas

({1: 198.26783822336986,
  2: 202.19717812453737,
  3: 206.53614247375737,
  4: 211.50892435104902,
  5: 218.42672100183074,
  6: 236.5096097576979,
  0: 175.97058789482784},
 {1: 0.5817275038062398,
  2: 0.5755147241351868,
  3: 0.5680015773432711,
  4: 0.5567916977275827,
  5: 0.5408661434281803,
  6: 0.5033149478710346,
  0: 0.6266468482700739})

### Chip background distribution fit

In [12]:
def chromo_mean_var_bg_outside(fluor_img, cellseg_mask, dilate=True, roi=None,
                 plot=False, dilate_px=1, return_alpha_beta=False, save_path=None):
    """
    Fit a gamma distribution (location fixed at 0) to the fluorescence pixels
    that lie outside the cells and report its parameters.

    Arguments:
    ----------
        fluor_img (np.ndarray): fluorescence image
        cellseg_mask (np.ndarray): corresponding cell mask for the fluorescence
                    image; pixels equal to 0 are treated as background
        dilate (bool): expand the cell labels before selecting the background,
                    so that pixels right next to the cells are excluded
        roi (list of 4 ints): [row, col, height, width] of the region to
                    analyse, in case the cell masks are not good everywhere
        plot (bool): plot histograms of the data and of samples drawn from
                    the fitted gamma distribution
        dilate_px (int): distance (in pixels) by which labels are expanded
                    when dilate is True, default is 1
        return_alpha_beta (bool): return (alpha, beta) = (shape, scale) of the
                    fit instead of its mean and variance
        save_path (str or Path): where to save the figure; only used when
                    plot is True

    Returns:
    ----------
        (mean, variance) of the fitted gamma distribution, i.e.
        (alpha * beta, alpha * beta**2), or (alpha, beta) when
        return_alpha_beta is True

    Raises:
    ----------
        ValueError: if there are no background pixels, or none survive the
                    99th-percentile outlier cut
    """
    if roi is not None:
        rows = slice(roi[0], roi[0] + roi[2])
        cols = slice(roi[1], roi[1] + roi[3])
        cellseg_mask = cellseg_mask[rows, cols]
        fluor_img = fluor_img[rows, cols]

    if dilate:
        cellseg_mask = segmentation.expand_labels(cellseg_mask, distance=dilate_px)

    # Boolean indexing picks the background pixels directly (row-major order);
    # no intermediate full-size image is needed.
    only_pixels_outside = fluor_img[cellseg_mask == 0].ravel()
    if only_pixels_outside.size == 0:
        raise ValueError("No background pixels found outside the (dilated) cell mask.")

    # remove outliers in the outside pixels at 99%
    cutoff = np.percentile(only_pixels_outside, 99)
    collect_bg_only = only_pixels_outside[only_pixels_outside < cutoff]
    if collect_bg_only.size == 0:
        raise ValueError("No background pixels left after removing outliers above the 99th percentile.")

    # now fit a gamma distribution with the location pinned to 0
    fit_alpha, fit_loc, fit_beta = scipy.stats.gamma.fit(collect_bg_only, floc=0.0)

    if plot:
        low, high = collect_bg_only.min(), collect_bg_only.max()
        # Unit-width bins whose last edge lies above `high`, so that the largest
        # value is counted and a narrow value range still yields a valid bin.
        bins = np.arange(float(low), float(high) + 2.0)
        fig, ax = plt.subplots()
        ax.hist(collect_bg_only, bins=bins, histtype='step', label='data')
        ax.hist(np.random.gamma(shape=fit_alpha, scale=fit_beta, size=len(collect_bg_only)),
                        bins=bins, histtype='step', label='fit')
        ax.set_xlabel('Chip background (ADU)', fontsize=16)
        ax.set_ylabel('Counts', fontsize=16)
        ax.tick_params(axis='x', labelsize=16)
        ax.tick_params(axis='y', labelsize=16)
        ax.legend()
        # Save before showing: on non-interactive backends the figure may be
        # released by plt.show(), which would produce an empty file.
        if save_path is not None:
            fig.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)
        plt.show()

    if return_alpha_beta:
        return fit_alpha, fit_beta
    else:
        return fit_alpha * fit_beta, fit_alpha * fit_beta * fit_beta


In [13]:
# Switch the example to index 46; this overwrites the images loaded for index 45 above.
index = 46
phase_img = imread(phase_filenames[index])
mask_img = imread(mask_filenames[index])
venus_img = imread(venus_filenames[index])


In [14]:
SAVE_DIR = Path('/mnt/sda1/SMLAT/figures/supplementary/param_est_chromo/')

In [15]:
# Fit the chip background of the current example image, plot it and save the figure.
# With return_alpha_beta=True the displayed tuple is (alpha, beta) of the gamma fit.
save_path = SAVE_DIR / Path('chip_bg.svg')
chromo_mean_var_bg_outside(venus_img, mask_img, plot=True, return_alpha_beta=True, save_path=save_path)

(161.07263108692743, 0.6733346526376891)

### EDT distribution fits

In [16]:
# chromo_edt_mean_variance_inside is not defined in this notebook (presumably provided by
# the cellbgnet star-imports). As displayed below, the result is a dict with the keys
# 'mean', 'stddev', 'counts', 'values' and 'fits', each keyed by distance bin.
edt_data = chromo_edt_mean_variance_inside(venus_img, mask_img, plot=True, return_values=True)

In [17]:
edt_data

{'mean': {1: 114.16397290848919,
  2: 115.24250132485426,
  3: 116.11537358382331,
  4: 116.56604215456674,
  5: 117.08593681631224,
  6: 118.04283604135894,
  7: 122.46153846153847},
 'stddev': {1: 8.134112888446257,
  2: 8.087184946385637,
  3: 7.977773529851019,
  4: 7.961130339295636,
  5: 7.806935779351558,
  6: 7.420375732005541,
  7: 4.068932663211752},
 'counts': {1: 25986, 2: 18870, 3: 13946, 4: 12810, 5: 11427, 6: 3385, 7: 13},
 'values': {1: array([102, 104, 109, ...,  99,  99, 102], dtype=uint16),
  2: array([117, 113, 115, ...,  97, 103,  95], dtype=uint16),
  3: array([120, 105, 110, ...,  99, 100, 101], dtype=uint16),
  4: array([116, 109, 110, ...,  90,  92,  92], dtype=uint16),
  5: array([123, 114, 120, ..., 106,  98, 106], dtype=uint16),
  6: array([126,  96, 107, ..., 113, 109, 115], dtype=uint16),
  7: array([122, 121, 124, 125, 126, 126, 121, 111, 123, 126, 118, 126, 123],
        dtype=uint16)},
 'fits': {1: {'alpha': 197.22734098656704,
   'loc': -1.5,
   'beta'

In [24]:
# Overlay, on one axis, the gamma distribution fitted to the intracellular
# background for each distance-from-boundary bin (1..6).
fig, ax = plt.subplots()
for edt_index in range(1, 7):
    edt_values = edt_data['values'][edt_index]
    low, high = edt_values.min(), edt_values.max()
    fit_alpha_edt, fit_loc_edt, fit_beta_edt = scipy.stats.gamma.fit(edt_values, floc=-1.5)
    # The fit uses a fixed location of -1.5; np.random.gamma has no location
    # parameter, so the samples are shifted by the fitted location to represent
    # the fitted distribution rather than one displaced by 1.5 ADU.
    fitted_samples = np.random.gamma(shape=fit_alpha_edt, scale=fit_beta_edt, size=100000) + fit_loc_edt
    ax.hist(fitted_samples, bins=np.arange(low, high), histtype='step',
            label=str(edt_index), density=True, stacked=False)

# Axis and legend styling only needs to happen once, after all curves are drawn.
ax.set_xlabel('Background inside cells (ADU)', fontsize=16)
ax.set_ylabel('Fitted gamma distribution', fontsize=16)
ax.tick_params(axis='x', labelsize=16)
ax.tick_params(axis='y', labelsize=16)
leg = ax.legend()
leg.set_title('Distance from cell boundary', prop={'size':16})
plt.setp(leg.get_texts(), fontsize='16')

# Save before showing so the file is written from a fully drawn figure on any backend.
save_path = SAVE_DIR / Path('edt_vs_cellbg.svg')
fig.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)
plt.show()


#### Distributions of the fitted means and standard deviations over images (2 violin plots)

In [25]:
from tqdm import tqdm
import math

In [26]:
# Per-image mean and standard deviation of the fitted background, grouped by
# distance bin: key 0 holds the chip background outside the cells, keys 1..6
# hold the values reported for the inside of the cells.
collected_means = {edt_bin: [] for edt_bin in range(7)}
collected_stddev = {edt_bin: [] for edt_bin in range(7)}
for index in tqdm(range(len(venus_filenames))):
    # Images with index 0..100 are skipped.
    if index <= 100:
        continue
    # NOTE: this reassigns the notebook-level index / *_img variables;
    # phase_img is read but not used inside this loop.
    phase_img = imread(phase_filenames[index])
    mask_img = imread(mask_filenames[index])
    venus_img = imread(venus_filenames[index])
    inside = chromo_edt_mean_variance_inside(venus_img, mask_img)
    means, stddev = inside['mean'], inside['stddev']
    outside = chromo_mean_var_bg_outside(venus_img, mask_img, plot=False, return_alpha_beta=False)
    # outside is (mean, variance); store it under bin 0, converting variance to std-dev.
    means[0] = outside[0]
    stddev[0] = math.sqrt(outside[1])
    # Bins that are not pre-declared in the collectors (e.g. 7) are ignored.
    for collected, per_image in ((collected_means, means), (collected_stddev, stddev)):
        for key, value in per_image.items():
            if key in collected:
                collected[key].append(value)

  aest = (3-s + np.sqrt((s-3)**2 + 24*s)) / (12*s)
  func = lambda a: np.log(a) - sc.digamma(a) - s
100%|█████████████████████████████████████████████████████████████| 601/601 [01:41<00:00,  5.92it/s]


In [32]:
import pandas as pd

In [33]:
collected_means

{0: [108.174613888942,
  108.16348177330076,
  108.2158797169576,
  108.1217168780194,
  108.37735217036783,
  108.30580563672792,
  108.0974363361944,
  108.04832405671043,
  108.16474464991065,
  108.07004695444229,
  108.20902443270366,
  108.1983960295321,
  108.17655800794658,
  108.31049640916916,
  108.1951580748122,
  107.99847016743352,
  107.90400197702458,
  108.25634014872979,
  108.29488022983362,
  108.26229032343919,
  108.23783202314168,
  108.37677853743688,
  108.19212987403841,
  108.13109729726133,
  108.01486223401514,
  108.26374965604558,
  108.03071836788942,
  108.1999850478469,
  108.03292967019632,
  108.09911760458527,
  108.05756406485501,
  108.12475051724567,
  108.14440577613125,
  108.18339849998962,
  108.16484632673689,
  108.1158131266562,
  108.23790180691972,
  107.98144769033429,
  108.21366485443995,
  108.08924326103654,
  108.18858789836246,
  108.17507284895983,
  108.05352923711794,
  108.1856147777884,
  108.20737707643958,
  108.05002666648

In [34]:
df = pd.DataFrame(collected_means)

In [35]:
df

Unnamed: 0,0,1,2,3,4,5,6
0,108.174614,113.694095,114.749436,115.393650,115.971240,116.251138,117.312738
1,108.163482,113.554231,114.597189,115.352183,115.940777,116.287523,117.113088
2,108.215880,113.689720,114.579276,115.531524,115.943919,116.368123,117.249921
3,108.121717,113.529056,114.549865,115.434758,115.896882,116.411044,117.432432
4,108.377352,113.816437,114.786449,115.674084,116.084050,116.660263,117.626662
...,...,...,...,...,...,...,...
495,107.919517,113.133504,113.902845,114.751329,115.103415,115.409082,115.731828
496,107.927306,112.823440,113.741084,114.470566,115.017280,115.385615,115.918157
497,107.855171,112.876162,113.684689,114.460268,114.867101,115.234766,116.044393
498,107.854763,112.910804,113.798416,114.646012,114.826792,115.295254,116.164596


In [36]:
import seaborn as sns

In [38]:


# Violin plot of the per-image fitted means, one violin per distance bin.
# A dedicated figure is created and passed to seaborn explicitly: without `ax=`
# seaborn draws on the current axes, which with a persistent (qt5) backend can
# be a figure left over from an earlier cell.
fig, ax = plt.subplots()
sns.violinplot(data=df, linewidth=0.5, ax=ax)
ax.set_xlabel('Euclidean distance from boundary', fontsize=16)
ax.set_ylabel('Mean of fitted gamma distribution', fontsize=16)
ax.tick_params(axis='x', labelsize=16)
ax.tick_params(axis='y', labelsize=16)
save_path = SAVE_DIR / Path('edt_vs_mean_over_time.svg')
fig.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)

In [39]:
df = pd.DataFrame(collected_stddev)

In [40]:


# Violin plot of the per-image fitted standard deviations, one violin per distance bin.
# A dedicated figure is created and passed to seaborn explicitly: without `ax=`
# seaborn draws on the current axes, so these violins could end up on top of a
# previously drawn plot when figures persist between cells (qt5 backend).
fig, ax = plt.subplots()
sns.violinplot(data=df, linewidth=0.5, ax=ax)
ax.set_xlabel('Euclidean distance from boundary', fontsize=16)
ax.set_ylabel('Standard-deviation of fitted gamma distribution', fontsize=16)
ax.tick_params(axis='x', labelsize=16)
ax.tick_params(axis='y', labelsize=16)
save_path = SAVE_DIR / Path('edt_vs_stddev_over_time.svg')
fig.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)

#### Plot of the EDT-based mean and standard-deviation maps on the cell mask

In [31]:
# Load the stored EDT noise map (dict with 'alphas' and 'betas', each keyed by distance bin).
# NOTE: pickle.load can execute arbitrary code; only load files you created yourself.
with open(edt_dump_path, 'rb') as fp:
    edt_noise_map = pickle.load(fp)

In [32]:
edt_noise_map

{'alphas': {1: 198.26783822336986,
  2: 202.19717812453737,
  3: 206.53614247375737,
  4: 211.50892435104902,
  5: 218.42672100183074,
  6: 236.5096097576979,
  0: 175.97058789482784},
 'betas': {1: 0.5817275038062398,
  2: 0.5755147241351868,
  3: 0.5680015773432711,
  4: 0.5567916977275827,
  5: 0.5408661434281803,
  6: 0.5033149478710346,
  0: 0.6266468482700739}}

In [33]:
def alpha_bg_mask(dists, mean_map):
    """
    Build a per-pixel parameter image by looking up every distance value in a map.

    Arguments:
    ----------
        dists (np.ndarray): array of distance values; it is not modified
        mean_map (dict): maps a distance value to the value written into
                    every pixel that has this distance

    Returns:
    ----------
        copy of dists (same dtype) in which each pixel whose distance equals
        int(key) holds the mapped value. Pixels whose distance is not a key
        keep their raw distance value.
        NOTE(review): for distances beyond the largest key this leaves the
        distance itself in the output — confirm this is intended.
    """
    filled = np.array(dists, copy=True)
    # Masks are always computed on the untouched input, so a value written for
    # one key can never be matched again by a later key.
    for edt_val in mean_map:
        filled[dists == int(edt_val)] = mean_map[edt_val]
    return filled

In [34]:
def beta_bg_mask(dists, mean_map):
    """
    Build a per-pixel parameter image by looking up every distance value in a map.

    Arguments:
    ----------
        dists (np.ndarray): array of distance values; it is not modified
        mean_map (dict): maps a distance value to the value written into
                    every pixel that has this distance

    Returns:
    ----------
        copy of dists (same dtype) in which each pixel whose distance equals
        int(key) holds the mapped value. Pixels whose distance is not a key
        keep their raw distance value.
    """
    lookup_img = np.array(dists, copy=True)
    # The comparison uses the original dists, not the partially filled copy.
    for edt_val in mean_map:
        hit = dists == int(edt_val)
        lookup_img[hit] = mean_map[edt_val]
    return lookup_img

In [35]:
def mean_bg_mask(dists, alpha_map, beta_map):
    """
    Build a per-pixel image of alpha * beta for the distance bin of each pixel.

    Arguments:
    ----------
        dists (np.ndarray): array of distance values; it is not modified
        alpha_map (dict): distance value -> alpha
        beta_map (dict): distance value -> beta; must contain every key of
                    alpha_map

    Returns:
    ----------
        copy of dists (same dtype) where each pixel whose distance equals
        int(key) holds alpha_map[key] * beta_map[key]. Pixels whose distance
        is not a key of alpha_map keep their raw distance value.
    """
    mean_img = np.array(dists, copy=True)
    for edt_val in alpha_map:
        bin_mean = alpha_map[edt_val] * beta_map[edt_val]
        mean_img[dists == int(edt_val)] = bin_mean
    return mean_img

In [36]:
def variance_bg_mask(dists, alpha_map, beta_map):
    """
    Build a per-pixel image of sqrt(alpha * beta**2) for the distance bin of
    each pixel. Despite the function name, the value written is the square
    root of alpha * beta**2, i.e. a standard deviation rather than a variance.

    Arguments:
    ----------
        dists (np.ndarray): array of distance values; it is not modified
        alpha_map (dict): distance value -> alpha
        beta_map (dict): distance value -> beta; must contain every key of
                    alpha_map

    Returns:
    ----------
        copy of dists (same dtype) where each pixel whose distance equals
        int(key) holds sqrt(alpha_map[key] * beta_map[key] ** 2). Pixels whose
        distance is not a key of alpha_map keep their raw distance value.
    """
    spread_img = np.array(dists, copy=True)
    for edt_val in alpha_map:
        scale = beta_map[edt_val]
        spread_img[dists == int(edt_val)] = math.sqrt(alpha_map[edt_val] * scale * scale)
    return spread_img

In [37]:

# Distance transform of the labelled cell mask, rounded to whole numbers so it can
# be matched against the integer distance bins of the fitted maps.
# The former np.zeros_like() pre-allocation was overwritten immediately, and np.ceil()
# after np.round() had no effect, so both were dropped.
# NOTE: mask_img is whichever mask was loaded last by the cells above.
dists = np.round(edt.edt(mask_img))

In [38]:
# Turn the per-bin fit parameters into per-pixel images using the distance map.
fitted_beta_map = edt_noise_map['betas']
fitted_alpha_map = edt_noise_map['alphas']
mean_bg_cells = mean_bg_mask(dists, fitted_alpha_map, fitted_beta_map)
# variance_bg_mask writes sqrt(alpha * beta**2), hence the "stddev" name here.
stddev_bg_cells = variance_bg_mask(dists, fitted_alpha_map, fitted_beta_map)
alpha_bg_cells = alpha_bg_mask(dists, fitted_alpha_map)
beta_bg_cells = beta_bg_mask(dists, fitted_beta_map)

In [39]:
# The Gamma distribution built in the next cell takes a rate, so the per-pixel beta
# (presumably a scale parameter, as in the fits above — TODO confirm) is inverted here.
alpha_t = torch.from_numpy(alpha_bg_cells)
beta_t = 1.0/torch.from_numpy(beta_bg_cells)

In [40]:
# Per-pixel gamma distribution with concentration alpha and rate 1/beta.
m = torch.distributions.gamma.Gamma(concentration=alpha_t, rate=beta_t)

In [41]:
# Draw one background image from the per-pixel distributions.
# No random seed is set in this notebook, so the sample differs between runs.
sample = m.sample()

In [44]:
# Side-by-side view of the per-pixel alpha map, beta map and the distance map.
fig, ax = plt.subplots(nrows=1, ncols=3)
panels = (
    (alpha_bg_cells, 'Alpha values'),
    (beta_bg_cells, 'Beta values'),
    (dists, 'EDT'),
)
for panel_ax, (panel_img, panel_title) in zip(ax, panels):
    panel_ax.imshow(panel_img)
    panel_ax.set_title(panel_title)
plt.show()

In [45]:
# Show the sampled background image in grayscale on a new figure.
plt.figure()
plt.imshow(sample.cpu().numpy(), cmap='gray')
plt.show()

In [55]:
roi = [500, 400, 256, 256]

In [56]:
# Export the ROI crop of the per-pixel mean background (with colorbar) as SVG.
save_phase = SAVE_DIR / 'mean_image.svg'
roi_rows = slice(roi[0], roi[0] + roi[2])
roi_cols = slice(roi[1], roi[1] + roi[3])
fig, ax = plt.subplots()
im = ax.imshow(mean_bg_cells[roi_rows, roi_cols])
ax.set_axis_off()

# Colorbar in its own axis to the right of the image.
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im, cax=cax)
fig.savefig(save_phase, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)
plt.close(fig)

In [58]:
# Export the ROI crop of the per-pixel standard-deviation map (with colorbar) as SVG.
save_phase = SAVE_DIR / 'stddev.svg'
roi_rows = slice(roi[0], roi[0] + roi[2])
roi_cols = slice(roi[1], roi[1] + roi[3])
fig, ax = plt.subplots()
im = ax.imshow(stddev_bg_cells[roi_rows, roi_cols])
ax.set_axis_off()

# Colorbar in its own axis to the right of the image.
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im, cax=cax)
fig.savefig(save_phase, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)
plt.close(fig)

In [59]:
# Export the ROI crop of the distance map (with colorbar) as SVG.
save_phase = SAVE_DIR / 'distance_from_boundary.svg'
roi_rows = slice(roi[0], roi[0] + roi[2])
roi_cols = slice(roi[1], roi[1] + roi[3])
fig, ax = plt.subplots()
im = ax.imshow(dists[roi_rows, roi_cols])
ax.set_axis_off()

# Colorbar in its own axis to the right of the image.
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im, cax=cax)
fig.savefig(save_phase, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)
plt.close(fig)

In [60]:
# Export the ROI crop of the sampled background image (grayscale, with colorbar) as SVG.
save_phase = SAVE_DIR / 'sampled_bg.svg'
roi_rows = slice(roi[0], roi[0] + roi[2])
roi_cols = slice(roi[1], roi[1] + roi[3])
fig, ax = plt.subplots()
im = ax.imshow(sample.cpu().numpy()[roi_rows, roi_cols], cmap='gray')
ax.set_axis_off()

# Colorbar in its own axis to the right of the image.
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im, cax=cax)
fig.savefig(save_phase, bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)
plt.close(fig)