In [1]:
%reload_ext autoreload
%autoreload 2
#%reload_ext notexbook'
#%texify

#### Generating Evaluation data using 30 random segmentation masks

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch

import os

os.chdir('/home/axtr7550/Chromosome_project/cellbgnet/')
import cellbgnet
import cellbgnet.utils

from cellbgnet.datasets import DataSimulator
from cellbgnet.utils.hardware import cpu, gpu
from cellbgnet.model import CellBGModel
from cellbgnet.simulation.psf_kernel import SMAPSplineCoefficient
from cellbgnet.generic.emitter import EmitterSet
from cellbgnet.train_loss_infer import generate_probmap_cells

os.chdir('/home/axtr7550/Chromosome_project/cellbgnet/notebooks')

from skimage.io import imread
from skimage.measure import label
import random
import edt
from skimage.filters import gaussian
from scipy.ndimage import rotate
import random
import pickle
import pathlib
from pathlib import Path
%matplotlib qt5

In [3]:
param_file = '../cellbgnet/utils/reference_files/reference_Axel_45deg_lower_photon.yaml'
param = cellbgnet.utils.param_io.ParamHandling().load_params(param_file)

In [4]:
psf_params = param.PSF.to_dict()
simulation_params = param.Simulation.to_dict()
hardware_params = param.Hardware.to_dict()
train_size = simulation_params['train_size']

In [5]:
# Load the pickled EDT noise map from the path configured in the parameter file.
# The 'alphas' and 'betas' entries of the loaded object are read further down.
# NOTE(review): pickle.load can execute arbitrary code — only load files you trust.
with open(param.Simulation.edt_noise_map_path, 'rb') as fp:
    edt_noise_map = pickle.load(fp)

In [6]:
model = cellbgnet.model.CellBGModel(param)

INITIATING CUDA IMPLEMENTATION
training sliding windows on camera chip:
Area num: 0, field_xy: [0, 127, 0, 127]
Area num: 1, field_xy: [114, 241, 0, 127]
Area num: 2, field_xy: [228, 355, 0, 127]
Area num: 3, field_xy: [342, 469, 0, 127]
Area num: 4, field_xy: [456, 583, 0, 127]
Area num: 5, field_xy: [570, 697, 0, 127]
Area num: 6, field_xy: [684, 811, 0, 127]
Area num: 7, field_xy: [798, 925, 0, 127]
Area num: 8, field_xy: [912, 1039, 0, 127]
Area num: 9, field_xy: [1026, 1153, 0, 127]
Area num: 10, field_xy: [1140, 1267, 0, 127]
Area num: 11, field_xy: [1174, 1301, 0, 127]
Area num: 12, field_xy: [0, 127, 114, 241]
Area num: 13, field_xy: [114, 241, 114, 241]
Area num: 14, field_xy: [228, 355, 114, 241]
Area num: 15, field_xy: [342, 469, 114, 241]
Area num: 16, field_xy: [456, 583, 114, 241]
Area num: 17, field_xy: [570, 697, 114, 241]
Area num: 18, field_xy: [684, 811, 114, 241]
Area num: 19, field_xy: [798, 925, 114, 241]
Area num: 20, field_xy: [912, 1039, 114, 241]
Area num: 21,

#### Read 30 random segmentation masks

In [7]:
filenames = model.data_generator.cell_mask_filenames


In [8]:
# Pick the 30 evaluation masks with a dedicated, seeded generator so that a
# fresh run of the notebook selects the same files. `choices` samples WITH
# replacement, so the same file may be picked more than once.
_mask_rng = random.Random(0)
random_filenames = _mask_rng.choices(filenames, k=30)

In [9]:
# Read every selected mask from disk. A plain loop is kept (instead of a
# comprehension) because `img` is inspected right after the loop.
cell_masks = []
for file in random_filenames:
    img = imread(file)
    cell_masks.append(img)



In [10]:
img.shape

(1041, 1302)

In [11]:
cell_masks_batch = np.stack(cell_masks)

In [12]:
cell_masks_batch.shape

(30, 1041, 1302)

In [26]:
plt.figure()
plt.imshow(cell_masks_batch[0])
plt.show()

In [13]:
# Map passed to the sampler as `prob_map`: 0.003125 on every pixel with a
# non-zero mask label, 0 everywhere else.
# NOTE(review): 0.003125 is hard-coded — presumably the per-pixel emitter probability inside cells; confirm.
prob_map = (cell_masks_batch > 0) * 0.003125

In [16]:
plt.figure()
plt.imshow(prob_map[0])
plt.show()

In [14]:
data_gen = DataSimulator(psf_params, simulation_params, hardware_params)

INITIATING CUDA IMPLEMENTATION


In [15]:
S, X_os, Y_os, Z, I, field_xy = data_gen.sampling(batch_size=30, prob_map=gpu(prob_map), local_context=None,
                                                  iter_num=None)

In [102]:
plt.close('all')

In [None]:

S.shape

In [19]:
plt.figure()
plt.imshow(S[0][0].cpu().numpy())
plt.show()

In [20]:
index = 0
fig, ax = plt.subplots(nrows=1, ncols=6)
ax[0].imshow(S[index][0].cpu().numpy())
ax[0].set_title('Dot location')
x_os_img = ax[1].imshow(X_os[index][0].cpu().numpy())
ax[1].set_title('X offset')
#fig.colorbar(x_os_img, ax=ax[1])
y_os_img = ax[2].imshow(Y_os[index][0].cpu().numpy())
ax[2].set_title('Y offset')
#fig.colorbar(y_os_img, ax=ax[2])
z_vals_img = ax[3].imshow(Z[index][0].cpu().numpy())
ax[3].set_title('Z value')
#fig.colorbar(z_vals_img, ax=ax[3])
i_vals_img = ax[4].imshow(I[index][0].cpu().numpy())
ax[4].set_title('Intensity Value')
#fig.colorbar(i_vals_img, ax=ax[4])
ax[5].imshow(cell_masks_batch[index])
ax[5].set_title('cell mask')
plt.show()

In [16]:
imgs_sim = data_gen.simulate_psfs(S, X_os, Y_os, Z, I)
#6 seconds on cuda for 128

In [101]:
#imgs_sim.to('cpu');
#torch.cuda.empty_cache()

In [22]:
# Side-by-side overview of one frame: sampling probability, source mask,
# simulated PSF image, emitter pixel locations and intensity map.
index = 0
fig, ax = plt.subplots(nrows=1, ncols=5)
ax[0].imshow(prob_map[index])
ax[0].set_title('Probability map')
ax[1].imshow(cell_masks_batch[index])
ax[1].set_title('Cell mask sampled')
ax[2].imshow(imgs_sim[index, 0].cpu().numpy())
ax[2].set_title('Simulated PSFs')
# squeeze() lets this work both before and after S has lost its channel axis
ax[3].imshow(S[index].squeeze().cpu().numpy())
ax[3].set_title('Emitter pixel locs')
ax[4].imshow(I[index, 0].cpu().numpy() * psf_params['photon_scale'])
ax[4].set_title('Intensity values')
plt.show()

In [19]:
S.shape

torch.Size([30, 1041, 1302])

In [17]:
# Collect the four per-pixel maps (x offset, y offset, z, intensity) into one
# tensor with the maps laid out along dim 1, taking index 0 of dim 1 from each.
xyzi = torch.stack([X_os[:, 0], Y_os[:, 0], Z[:, 0], I[:, 0]], dim=1)

In [18]:
# Drop dim 1 (the channel axis) of S. Guarded on the rank so that running this
# cell a second time does not slice away one of the spatial axes.
if S.dim() == 4:
    S = S[:, 0]

In [20]:
# Indices of every non-zero entry of S, as one 1-D index tensor per dimension.
s_inds = S.nonzero(as_tuple=True)
# Read the sub-pixel xy offsets, z position and photon value at each of those
# pixels; the result has one row per emitter.
xyzi_true = xyzi[s_inds[0], :, s_inds[1], s_inds[2]]

In [21]:
xyzi_true

tensor([[-0.1504,  0.2609, -0.6300,  0.9678],
        [ 0.2061,  0.0233, -0.1468,  0.2924],
        [-0.3372, -0.0662,  0.6360,  0.5397],
        ...,
        [-0.0041,  0.0134, -0.0677,  0.3519],
        [ 0.3476, -0.1170, -0.5400,  0.5809],
        [ 0.3627,  0.4325,  0.8345,  0.9288]], device='cuda:0')

In [22]:
# Absolute x position: sub-pixel offset + index along the last axis of S + 0.5.
# Rebuilt from `xyzi` instead of incremented in place, so running this cell
# again does not add the pixel index a second time.
xyzi_true[:, 0] = xyzi[s_inds[0], 0, s_inds[1], s_inds[2]] + (s_inds[2] + 0.5)

In [23]:
# Absolute y position: sub-pixel offset + index along the middle axis of S + 0.5.
# Rebuilt from `xyzi` instead of incremented in place, so running this cell
# again does not add the pixel index a second time.
xyzi_true[:, 1] = xyzi[s_inds[0], 1, s_inds[1], s_inds[2]] + (s_inds[1] + 0.5)

In [24]:
# Independent copy of the (x, y, z) columns. A bare slice is only a view, so
# in-place edits of pos_tar would silently modify xyzi_true as well.
pos_tar = xyzi_true[:, :3].clone()

In [25]:
pos_tar.shape

torch.Size([10047, 3])

In [26]:
# Scale the z column by 500.
# NOTE(review): 500 is hard-coded — presumably converts the normalised z to physical units; confirm against the PSF parameters.
# This is an in-place update: running the cell again multiplies by 500 once more.
pos_tar[:, 2] *= 500

In [27]:
pos_tar

tensor([[ 592.3497,   99.7609, -314.9914],
        [1016.7061,   99.5233,  -73.4096],
        [ 932.1628,  100.4338,  318.0047],
        ...,
        [ 619.4959,  955.5134,  -33.8575],
        [ 721.8476,  955.3829, -269.9925],
        [ 736.8627,  956.9324,  417.2450]], device='cuda:0')

In [28]:
xyzi_true[0, 3]

tensor(0.9678, device='cuda:0')

In [29]:
# FIXME: `single_psfs` is not defined anywhere in this notebook, so this cell
# raises a NameError on a fresh run.
# NOTE(review): 0.3924 and 3000 are unexplained scale factors — confirm their meaning or remove the cell.
plt.figure()
plt.imshow((single_psfs[0] * 0.3924 *  3000).numpy())
plt.show()

type: name 'single_psfs' is not defined

In [30]:
from cellbgnet.utils.plot_funcs import PlotFrameCoord

In [31]:
imgs_sim.shape

torch.Size([30, 1, 1041, 1302])

In [32]:
# Overlay ground-truth emitters on one simulated frame. pos_tar contains the
# emitters of every frame in the batch, so keep only the rows whose batch
# index (s_inds[0]) matches the frame being displayed.
frame_idx = 0
in_frame = s_inds[0] == frame_idx
plt.figure()
PlotFrameCoord(imgs_sim[frame_idx].cpu(), pos_tar=pos_tar[in_frame].cpu(),
               plot_colorbar_frame=True, annotate_tar_z=True).plot()
plt.show()

In [33]:
plt.figure()
plt.imshow(imgs_sim[13, 0].cpu().numpy())
plt.show()

### Noise adding inside and outside cells

In [34]:

# Euclidean distance transform of each mask, computed frame by frame.
# The buffer is allocated as float64 explicitly: np.zeros_like would inherit
# the mask dtype, and an integer buffer truncates the fractional distances
# before they are rounded up in a later cell.
dists = np.zeros(cell_masks_batch.shape, dtype=np.float64)
for i in range(len(cell_masks_batch)):
    dists[i] = edt.edt(cell_masks_batch[i])

In [35]:
dists.shape

(30, 1041, 1302)

In [36]:
index = 0
fig, ax = plt.subplots(nrows=1, ncols=3)
ax[0].imshow(cell_masks_batch[index])
ax[0].set_title('Cell mask')
ax[1].imshow(dists[index])
ax[1].set_title('Edt')
ax[2].imshow(prob_map[index])
ax[2].set_title('Prob map')
plt.show()

In [37]:
def mean_bg_mask(dists, mean_map):
    """Replace distance values with the value mapped to them in ``mean_map``.

    Parameters
    ----------
    dists : array-like
        Distance values (any shape). The input is not modified.
    mean_map : mapping
        Maps a distance (anything accepted by ``int()``) to the value that
        should be written wherever ``dists`` equals that distance.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``dists``. Floating-point inputs keep their
        dtype; every other dtype is promoted to float64 so that fractional
        mapped values are not truncated. Entries whose distance has no key in
        ``mean_map`` keep their original distance value.
    """
    dists = np.asarray(dists)
    # An integer result buffer would silently drop the fractional part of the
    # mapped values, so only floating dtypes are carried over unchanged.
    out_dtype = dists.dtype if np.issubdtype(dists.dtype, np.floating) else np.float64
    bg_mask = dists.astype(out_dtype, copy=True)
    for edt_val, mean_bg in mean_map.items():
        # Compare against the untouched input so earlier writes cannot be re-matched.
        bg_mask[dists == int(edt_val)] = mean_bg
    return bg_mask

In [38]:
def variance_bg_mask(dists, variance_map):
    """Replace distance values with the value mapped to them in ``variance_map``.

    Parameters
    ----------
    dists : array-like
        Distance values (any shape). The input is not modified.
    variance_map : mapping
        Maps a distance (anything accepted by ``int()``) to the value that
        should be written wherever ``dists`` equals that distance.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``dists``. Floating-point inputs keep their
        dtype; every other dtype is promoted to float64 so that fractional
        mapped values are not truncated. Entries whose distance has no key in
        ``variance_map`` keep their original distance value.
    """
    dists = np.asarray(dists)
    # An integer result buffer would silently drop the fractional part of the
    # mapped values, so only floating dtypes are carried over unchanged.
    out_dtype = dists.dtype if np.issubdtype(dists.dtype, np.floating) else np.float64
    bg_mask = dists.astype(out_dtype, copy=True)
    for edt_val, stddev_bg in variance_map.items():
        # Compare against the untouched input so earlier writes cannot be re-matched.
        bg_mask[dists == int(edt_val)] = stddev_bg
    return bg_mask

In [39]:
dists = np.ceil(dists)

In [40]:
# Convert the distance image into two per-pixel parameter maps by looking up
# the fitted value stored for each distance in the noise map.
fitted_alpha_map = edt_noise_map['alphas']
alpha_bg_cells = mean_bg_mask(dists, fitted_alpha_map)

fitted_beta_map = edt_noise_map['betas']
beta_bg_cells = variance_bg_mask(dists, fitted_beta_map)

In [41]:
index = 0
fig, ax = plt.subplots(nrows=1, ncols=2)
#ax[0].imshow(mean_bg_cells[index])
ax[0].imshow(alpha_bg_cells[index])
ax[0].set_title('Alpha values')
#ax[0].set_title('Mean bg values')
#ax[1].imshow(stddev_bg_cells[index])
ax[1].imshow(beta_bg_cells[index])
#ax[1].set_title('Std dev bg values')
ax[1].set_title('Beta values')
plt.show()

In [42]:
# Wrap the per-pixel parameter maps as tensors for the Gamma distribution
# built in the next cell, which takes them as `concentration` and `rate`.
# The beta map is inverted here before it is used as the rate.
# NOTE(review): this implies the fitted betas are scale parameters — confirm; a zero beta would give an infinite rate.
alpha_t = torch.from_numpy(alpha_bg_cells)
beta_t = 1.0/torch.from_numpy(beta_bg_cells)

In [43]:
alpha_t.shape, beta_t.shape

(torch.Size([30, 1041, 1302]), torch.Size([30, 1041, 1302]))

In [44]:
m = torch.distributions.gamma.Gamma(concentration=alpha_t, rate=beta_t)
sample = m.sample()

In [45]:
index = 0
fig, ax = plt.subplots(nrows=1, ncols=3)
#ax[0].imshow(mean_bg_cells[index])
ax[0].imshow(alpha_bg_cells[index])
ax[0].set_title('Alpha bg cells')
#ax[0].set_title('Mean bg values')
#ax[1].imshow(stddev_bg_cells[index])
#ax[1].set_title('Std dev bg values')
ax[1].imshow(beta_bg_cells[index])
ax[1].set_title('Beta bg cells')
img = ax[2].imshow(sample[index].cpu().numpy(), cmap='gray')
ax[2].set_title('Sampled bg values')
fig.colorbar(img, ax=ax[2])
plt.show()

In [46]:
index = 0
plt.figure()
plt.imshow(sample[index].cpu().numpy(), cmap='gray')
plt.show()

### Dot overlay

In [47]:
# Camera constants used to convert the sampled background values.
# NOTE(review): these are hard-coded here — presumably they should match the camera settings in the parameter file; confirm.
baseline = 103.0
e_per_adu = 0.39
qe = 0.95
# Subtract the baseline, then scale by e_per_adu / qe.
# The result can be negative wherever sample < baseline; it is clamped in a later cell.
bg_photons = (sample - baseline) * e_per_adu / qe

In [48]:
# Insert a channel axis at dim 1. Guarded on the rank so that running the
# cell a second time does not add yet another axis.
if bg_photons.dim() == 3:
    bg_photons = bg_photons[:, None]

In [49]:
bg_photons = torch.clamp(bg_photons, min=0.0)

In [50]:
imgs_sim.shape, bg_photons.shape

(torch.Size([30, 1, 1041, 1302]), torch.Size([30, 1, 1041, 1302]))

In [51]:
# Move the background to whichever device holds the simulated PSF images,
# rather than assuming the first CUDA device is the right one.
bg_photons = bg_photons.to(imgs_sim.device)

In [52]:
plt.figure()
plt.imshow(bg_photons[index, 0].cpu().numpy())
plt.show()

In [53]:
after_dots = imgs_sim + bg_photons


In [54]:
S.shape

torch.Size([30, 1041, 1302])

In [55]:
# Six-panel overview of one frame, ending with the PSFs added onto the
# sampled background.
index = 0
fig, ax = plt.subplots(nrows=1, ncols=6)
ax[0].imshow(prob_map[index])
ax[0].set_title('Probability map')
ax[1].imshow(cell_masks_batch[index])
ax[1].set_title('Cell mask sampled')
ax[2].imshow(imgs_sim[index, 0].cpu().numpy())
ax[2].set_title('Simulated PSFs')
# squeeze() lets this work both before and after S has lost its channel axis;
# without an image this panel would be an empty, titled axes.
ax[3].imshow(S[index].squeeze().cpu().numpy())
ax[3].set_title('Emitter pixel locs')
ax[4].imshow(I[index, 0].cpu().numpy() * psf_params['photon_scale'])
ax[4].set_title('Intensity values')
ax[5].imshow(after_dots[index][0].cpu().numpy(), cmap='gray')
ax[5].set_title('After overlaying dots')
plt.show()

### Adding camera noise

In [56]:
RN = 2.2

In [57]:
# Shot noise. Poisson requires a non-negative rate and rejects the tensor
# otherwise, so clip any negative expected counts to zero before sampling.
# NOTE: after_dots is overwritten, so re-running this cell resamples from the
# already-noisy values.
photon_rate = torch.clamp(after_dots * qe, min=0.0)
after_dots = torch.distributions.Poisson(photon_rate).sample()

type: Expected parameter rate (Tensor of shape (30, 1, 1041, 1302)) of distribution Poisson(rate: torch.Size([30, 1, 1041, 1302])) to satisfy the constraint GreaterThanEq(lower_bound=0.0), but found invalid values:
tensor([[[[ 3.6433,  0.0000,  2.6589,  ...,  2.1402,  0.0000,  7.4989],
          [ 0.0000,  5.6488,  4.4491,  ...,  2.2344,  9.9896,  3.1848],
          [ 0.0000,  0.2239,  4.4357,  ...,  5.9087,  0.0000,  0.8490],
          ...,
          [ 1.4178,  0.0000,  9.3616,  ...,  0.8897,  6.4457,  0.2179],
          [ 2.4254,  0.0000,  0.0000,  ...,  6.8551,  2.0930,  5.8479],
          [ 4.4903,  0.1232,  0.9024,  ..., 10.9690,  1.5086,  9.8907]]],


        [[[ 0.0000,  4.9036,  4.9885,  ...,  3.6938,  7.1447,  5.7808],
          [ 1.4903,  4.2388,  9.7796,  ...,  0.0000,  3.3015,  1.6471],
          [ 9.1195,  0.0000,  2.9933,  ...,  0.0000,  0.8935,  1.8073],
          ...,
          [ 4.3170,  1.4932,  0.0000,  ...,  2.1433,  3.4383,  5.7284],
          [ 7.1442, 11.0152,  0.0000,  ...,  0.9844,  0.0000,  0.0000],
          [ 3.1203,  4.5474,  5.8961,  ...,  6.1050,  0.0000,  1.9243]]],


        [[[ 1.8137,  0.0000,  0.0000,  ...,  0.0000,  3.8405,  1.1628],
          [ 0.0000,  1.5733,  1.0764,  ...,  0.0000,  1.4365,  1.7129],
          [ 0.0000,  0.0000,  0.4781,  ...,  3.4826,  4.2648,  0.0000],
          ...,
          [ 4.2627,  0.3268,  2.2367,  ...,  0.0000,  0.0000,  3.7093],
          [ 6.2466,  4.7369,  7.1070,  ...,  2.2688,  0.0000,  5.8247],
          [ 0.0000,  1.5018,  0.0000,  ...,  1.4406,  0.0000,  1.6580]]],


        ...,


        [[[ 2.7425,  7.8029,  0.0000,  ...,  0.7239,  0.0000,  0.0000],
          [ 0.0000,  4.5423,  2.1154,  ...,  0.0000,  3.6284,  5.8735],
          [ 0.0000,  0.0000,  2.1484,  ...,  0.0000,  0.0000,  3.8325],
          ...,
          [ 0.0000,  3.6809,  0.0000,  ...,  1.7758,  0.0000,  6.1637],
          [ 3.6228,  4.6392,  3.0163,  ...,  0.9121,  5.3673,  6.8132],
          [ 3.5586,  3.7652,  1.3335,  ...,  0.0000,  6.6496,  0.0000]]],


        [[[ 2.5033,  3.3326,  2.6954,  ...,  2.3767,  4.2891,  4.4320],
          [ 0.0000,  0.3468,  1.2668,  ...,  4.0585,  1.8026,  2.5472],
          [ 6.0020,  1.5130,  8.9904,  ...,  6.6324,  0.0000,  3.7817],
          ...,
          [ 5.6655,  0.0000,  0.0000,  ...,  3.1419,  0.0000,  1.3429],
          [ 2.8069,  3.0513,  9.4921,  ...,  5.6689,  5.3812,  0.8931],
          [ 2.3090,  3.5934,  4.2744,  ...,  0.0000,  4.5713,  1.7096]]],


        [[[ 7.6695,  9.9347,  2.9691,  ...,  0.0000,  0.6845,  0.0000],
          [ 1.7859,  9.3476,  5.2251,  ...,  5.0698,  0.1342,  7.2559],
          [ 3.8193,  6.8426,  2.6687,  ...,  3.4620,  0.2959,  3.0779],
          ...,
          [ 2.5391,  2.7302,  1.4567,  ...,  1.0544,  0.0000,  4.6802],
          [ 0.0000,  0.4532,  0.0000,  ...,  0.0000,  5.3342,  2.7940],
          [ 3.5668,  0.0000,  0.0828,  ...,  1.9880,  0.0000,  2.7743]]]],
       device='cuda:0', dtype=torch.float64)

In [58]:
zeros = torch.zeros_like(after_dots)

In [59]:
readout_noise = torch.distributions.Normal(zeros, zeros + RN).sample()

In [60]:
after_dots = after_dots + readout_noise

In [61]:
final_dots_for_net = torch.clamp((after_dots/ e_per_adu) + baseline, min=0)

In [62]:
# Six-panel overview of one frame, from the sampled background to the final
# image after conversion back to ADU.
index = 0
fig, ax = plt.subplots(nrows=1, ncols=6)
ax[0].imshow(sample[index].cpu().numpy(), cmap='gray')
ax[0].set_title('Sampled background ADU')
ax[1].imshow(cell_masks_batch[index])
ax[1].set_title('Cell mask sampled')
ax[2].imshow(imgs_sim[index, 0].cpu().numpy())
ax[2].set_title('Simulated PSFs')
ax[3].imshow(I[index, 0].cpu().numpy() * psf_params['photon_scale'])
ax[3].set_title('Intensity values')
ax[4].imshow(after_dots[index][0].cpu().numpy(), cmap='gray')
ax[4].set_title('After overlaying dots (photons)')
ax[5].imshow(final_dots_for_net[index][0].cpu().numpy(), cmap='gray')
ax[5].set_title('After camera noise (ADU)')
plt.tight_layout()
plt.show()

In [63]:
plt.figure()
plt.imshow(final_dots_for_net[index][0].cpu().numpy(), cmap='gray')
plt.show()

In [72]:
index = 3
plt.figure()
plt.imshow(final_dots_for_net[index][0].cpu().numpy(), cmap='gray', vmin=80, vmax=300)
plt.colorbar()
plt.show()

In [70]:
final_dots_for_net[index][0].cpu().numpy().max()

304.7641911094785

In [66]:
final_dots_for_net[index][0].cpu().numpy().min()

76.21392442269159

In [68]:
final_dots_for_net.shape

torch.Size([30, 1, 1041, 1302])