In [1]:
%reload_ext autoreload
%autoreload 2
#%reload_ext notexbook'
#%texify

#### Generating Evaluation data using 30 random segmentation masks

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch

import os

os.chdir('/home/axtr7550/Chromosome_project/cellbgnet/')
import cellbgnet
import cellbgnet.utils

from cellbgnet.datasets import DataSimulator
from cellbgnet.utils.hardware import cpu, gpu
from cellbgnet.model import CellBGModel
from cellbgnet.simulation.psf_kernel import SMAPSplineCoefficient
from cellbgnet.generic.emitter import EmitterSet
from cellbgnet.train_loss_infer import generate_probmap_cells

os.chdir('/home/axtr7550/Chromosome_project/cellbgnet/notebooks')

from skimage.io import imread
from skimage.measure import label
import random
import edt
from skimage.filters import gaussian
from scipy.ndimage import rotate
import random
import pickle
import pathlib
from pathlib import Path
%matplotlib qt5

In [3]:
param_file = '../cellbgnet/utils/reference_files/reference.yaml'
param = cellbgnet.utils.param_io.ParamHandling().load_params(param_file)

In [4]:
psf_params = param.PSF.to_dict()
simulation_params = param.Simulation.to_dict()
hardware_params = param.Hardware.to_dict()
train_size = simulation_params['train_size']

In [5]:
# Load the EDT noise map from the pickle file named in the parameter file;
# later cells read its 'alphas' and 'betas' entries.
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# this path must only ever point at a trusted file.
with open(param.Simulation.edt_noise_map_path, 'rb') as fp:
    edt_noise_map = pickle.load(fp)

In [6]:
model = cellbgnet.model.CellBGModel(param)

INITIATING CPU IMPLEMENTATION
training sliding windows on camera chip:
Area num: 0, field_xy: [0, 127, 0, 127]
Area num: 1, field_xy: [114, 241, 0, 127]
Area num: 2, field_xy: [228, 355, 0, 127]
Area num: 3, field_xy: [342, 469, 0, 127]
Area num: 4, field_xy: [456, 583, 0, 127]
Area num: 5, field_xy: [570, 697, 0, 127]
Area num: 6, field_xy: [684, 811, 0, 127]
Area num: 7, field_xy: [798, 925, 0, 127]
Area num: 8, field_xy: [912, 1039, 0, 127]
Area num: 9, field_xy: [1026, 1153, 0, 127]
Area num: 10, field_xy: [1140, 1267, 0, 127]
Area num: 11, field_xy: [1174, 1301, 0, 127]
Area num: 12, field_xy: [0, 127, 114, 241]
Area num: 13, field_xy: [114, 241, 114, 241]
Area num: 14, field_xy: [228, 355, 114, 241]
Area num: 15, field_xy: [342, 469, 114, 241]
Area num: 16, field_xy: [456, 583, 114, 241]
Area num: 17, field_xy: [570, 697, 114, 241]
Area num: 18, field_xy: [684, 811, 114, 241]
Area num: 19, field_xy: [798, 925, 114, 241]
Area num: 20, field_xy: [912, 1039, 114, 241]
Area num: 21, 

#### Read 30 random segmentation masks

In [7]:
filenames = model.data_generator.cell_mask_filenames


In [8]:
# Pick 30 mask files at random.
# NOTE(review): random.choices samples WITH replacement, so the same file can
# be drawn more than once (random.sample would give distinct files). No seed
# is set in the visible cells, so the selection differs from run to run.
random_filenames = random.choices(filenames, k=30)

In [9]:
cell_masks = []
for i, file in enumerate(random_filenames,0):
    img = imread(file)
    #img = rotate(img, random_angles[i], reshape=False)
    cell_masks.append(img)



In [10]:
img.shape

(1041, 1302)

In [11]:
cell_masks_batch = np.stack(cell_masks)

In [12]:
cell_masks_batch.shape

(30, 1041, 1302)

In [12]:
plt.figure()
plt.imshow(cell_masks_batch[0])
plt.show()

In [13]:
# Per-pixel value handed to data_gen.sampling as `prob_map`:
# 0.003125 wherever the mask value is > 0, and 0.0 everywhere else.
prob_map = np.where(cell_masks_batch > 0, 0.003125, 0.0)

In [14]:
plt.figure()
plt.imshow(prob_map[0])
plt.show()

In [14]:
data_gen = DataSimulator(psf_params, simulation_params, hardware_params)

INITIATING CPU IMPLEMENTATION


In [15]:
# Sample emitters from the probability map. Later plotting cells label the
# outputs as: S = dot location, X_os / Y_os = x / y offsets, Z = z value,
# I = intensity value.
# NOTE(review): batch_size=128 is requested, but the recorded S.shape below has
# a leading dimension of 30 (the number of masks in prob_map) — confirm how
# DataSimulator.sampling treats batch_size when prob_map is supplied.
S, X_os, Y_os, Z, I, field_xy = data_gen.sampling(batch_size=128, prob_map=gpu(prob_map), local_context=None,
                                                  iter_num=None)

In [16]:
S.shape

torch.Size([30, 1, 1041, 1302])

In [18]:
plt.figure()
plt.imshow(S[0][0].cpu().numpy())
plt.show()

In [17]:
index = 0
fig, ax = plt.subplots(nrows=1, ncols=6)
ax[0].imshow(S[index][0].cpu().numpy())
ax[0].set_title('Dot location')
x_os_img = ax[1].imshow(X_os[index][0].cpu().numpy())
ax[1].set_title('X offset')
#fig.colorbar(x_os_img, ax=ax[1])
y_os_img = ax[2].imshow(Y_os[index][0].cpu().numpy())
ax[2].set_title('Y offset')
#fig.colorbar(y_os_img, ax=ax[2])
z_vals_img = ax[3].imshow(Z[index][0].cpu().numpy())
ax[3].set_title('Z value')
#fig.colorbar(z_vals_img, ax=ax[3])
i_vals_img = ax[4].imshow(I[index][0].cpu().numpy())
ax[4].set_title('Intensity Value')
#fig.colorbar(i_vals_img, ax=ax[4])
ax[5].imshow(cell_masks_batch[index])
ax[5].set_title('cell mask')
plt.show()

In [20]:
imgs_sim = data_gen.simulate_psfs(S, X_os, Y_os, Z, I)
#6 seconds on cuda for 128

ON CPU /n

In [21]:
# NOTE(review): Tensor.to() returns a tensor and does not modify `imgs_sim` in
# place; the result is discarded here, so `imgs_sim` stays on whatever device
# it was created on. Assign the result if a CPU copy is actually wanted.
imgs_sim.to('cpu');
# Ask PyTorch to release cached, unused GPU memory.
torch.cuda.empty_cache()

In [110]:
index = 14
fig, ax = plt.subplots(nrows=1, ncols=5)
# One (image, title) pair per panel, listed left to right.
panels = (
    (prob_map[index], 'Probabilty map'),
    (cell_masks_batch[index], 'Cell mask sampled'),
    (imgs_sim[index, 0].cpu().numpy(), 'Simulated PSFs'),
    (S[index, 0].cpu().numpy(), 'Emitter pixel locs'),
    (I[index, 0].cpu().numpy() * psf_params['photon_scale'], 'Intensity values'),
)
for axis, (image, title) in zip(ax, panels):
    axis.imshow(image)
    axis.set_title(title)
plt.show()

In [111]:
S.shape

torch.Size([16, 1, 1041, 1302])

In [112]:
xyzi = torch.cat([X_os[:, :, None], Y_os[:, :, None], Z[:, :, None], I[:, :, None]], 2)
xyzi = xyzi[:, 0]

In [113]:
S = S[:, 0]

In [114]:
s_inds = tuple(S.nonzero().transpose(1, 0))
# get these molecules' sub-pixel xy offsets, z positions and photons
xyzi_true = xyzi[s_inds[0], :, s_inds[1], s_inds[2]]

In [115]:
xyzi_true

tensor([[-0.4194, -0.0597, -0.5988,  0.9801],
        [ 0.0782,  0.0996,  0.8808,  0.5664],
        [ 0.2575, -0.3706,  0.2145,  0.9321],
        ...,
        [-0.4316,  0.1570,  0.2781,  0.2726],
        [ 0.0060,  0.2557, -0.3518,  0.3088],
        [ 0.3582, -0.3238,  0.9471,  0.3994]], device='cuda:0')

In [72]:
xyzi_true[:, 0] += s_inds[2] + 0.5

In [73]:
xyzi_true[:, 1] += s_inds[1] + 0.5

In [74]:
# Take the x, y, z columns as an independent copy. A plain slice is a view of
# `xyzi_true`, so the in-place z scaling applied to `pos_tar` in a later cell
# would silently rescale `xyzi_true[:, 2]` as well.
pos_tar = xyzi_true[:, :3].clone()

In [75]:
pos_tar.shape

torch.Size([310, 3])

In [76]:
# Scale the z column by 500, in place.
# NOTE(review): because this mutates `pos_tar`, running the cell twice scales
# z twice. 500 is presumably the z range of the simulation (nm?) — confirm
# against the PSF parameters instead of hard-coding it.
pos_tar[:, 2] *= 500

In [77]:
pos_tar

tensor([[ 890.2706,   99.5039,  165.7851],
        [ 395.2706,  100.4327,   11.7564],
        [ 896.1362,  100.8500,  366.1547],
        [ 660.7841,  101.7123,  -56.5692],
        [ 466.5576,  103.6314,  224.9723],
        [ 880.7797,  103.7611,  377.2502],
        [ 759.2938,  104.0477,  121.3163],
        [ 962.3259,  104.4124, -123.3896],
        [ 515.4755,  105.3145, -177.2394],
        [ 964.3440,  106.8332,  -50.8567],
        [ 424.2227,  108.5312,  139.4918],
        [ 801.9818,  108.3000,   93.4972],
        [ 819.9946,  144.0252,   10.2363],
        [ 745.3091,  145.8513,  -73.6389],
        [ 866.8130,  145.3471, -332.9089],
        [ 911.1019,  146.7239,  459.3462],
        [ 519.8839,  147.5663,  -20.2745],
        [ 961.2925,  148.5636, -389.7341],
        [ 969.5105,  148.6700,  137.1928],
        [ 874.7700,  149.6493,  399.7937],
        [ 908.9656,  150.3132,  416.8392],
        [ 761.9042,  152.9388, -117.0293],
        [1083.5704,  152.9137,  195.6698],
        [ 9

In [78]:
xyzi_true[0, 3]

tensor(0.8469, device='cuda:0')

In [79]:
# NOTE(review): `single_psfs` is not defined anywhere in the visible notebook;
# the recorded output of this cell is a NameError, so a fresh-kernel run fails
# here. The factors 0.3924 and 3000 are also unexplained — define the input and
# name these constants, or delete the cell.
plt.figure()
plt.imshow((single_psfs[0] * 0.3924 *  3000).numpy())
plt.show()

type: name 'single_psfs' is not defined

In [35]:
from cellbgnet.utils.plot_funcs import PlotFrameCoord

In [36]:
imgs_sim.shape

torch.Size([1, 1, 1041, 1302])

In [37]:
plt.figure()
PlotFrameCoord(imgs_sim[0].cpu(), pos_tar=pos_tar.cpu(),plot_colorbar_frame=True, annotate_tar_z=True).plot()
plt.show()

In [25]:
plt.figure()
plt.imshow(imgs_sim[13, 0].cpu().numpy())
plt.show()

### Adding noise inside and outside cells

In [52]:

# Euclidean distance transform of every mask in the batch.
# The buffer is allocated as float32 explicitly: np.zeros_like(cell_masks_batch)
# inherits the masks' dtype, and with integer label images the fractional
# distances would be truncated on assignment, making the np.ceil applied in a
# later cell a no-op.
# NOTE(review): confirm that rounding distances up matches how the EDT noise
# map was fitted.
dists = np.zeros(cell_masks_batch.shape, dtype=np.float32)
for i, mask in enumerate(cell_masks_batch):
    dists[i] = edt.edt(mask)

In [53]:
dists.shape

(30, 1041, 1302)

In [54]:
index = 12
fig, ax = plt.subplots(nrows=1, ncols=3)
ax[0].imshow(cell_masks_batch[index])
ax[0].set_title('Cell mask')
ax[1].imshow(dists[index])
ax[1].set_title('Edt')
ax[2].imshow(prob_map[index])
ax[2].set_title('Prob map')
plt.show()

In [55]:
def mean_bg_mask(dists, mean_map):
    """Replace each distance value in ``dists`` with its entry from ``mean_map``.

    Parameters
    ----------
    dists : array-like
        Distance values. A pixel matches a map key when ``dists == int(key)``.
    mean_map : dict
        Maps a distance value (anything ``int()`` accepts) to the value that
        should be written for pixels at that distance.

    Returns
    -------
    numpy.ndarray
        A new floating-point array with the shape of ``dists``. Pixels whose
        distance has no entry in ``mean_map`` keep their distance value.
        ``dists`` itself is not modified.
    """
    dists = np.asarray(dists)
    # Keep a floating input dtype unchanged, but promote integer/bool input to
    # float64 so fractional mapped values are not truncated when written.
    out_dtype = dists.dtype if np.issubdtype(dists.dtype, np.floating) else np.float64
    dists_copy = dists.astype(out_dtype, copy=True)
    for edt_val, mean_bg in mean_map.items():
        # Match against the untouched input so freshly written values are
        # never mistaken for distances.
        dists_copy[dists == int(edt_val)] = mean_bg
    return dists_copy

In [56]:
def variance_bg_mask(dists, variance_map):
    """Build a per-pixel map by looking up each distance in ``variance_map``.

    Parameters
    ----------
    dists : array-like
        Distance values. A pixel is selected for a key when
        ``dists == int(key)``.
    variance_map : dict
        Maps a distance value (anything ``int()`` accepts) to the value
        assigned to every pixel at that distance.

    Returns
    -------
    numpy.ndarray
        A new floating-point array shaped like ``dists``. Distances without an
        entry in ``variance_map`` are left as the distance value. The input
        array is left untouched.
    """
    dists = np.asarray(dists)
    # An integer (or bool) buffer would silently truncate fractional values on
    # assignment, so only floating dtypes are kept as-is.
    if np.issubdtype(dists.dtype, np.floating):
        out_dtype = dists.dtype
    else:
        out_dtype = np.float64
    dists_copy = dists.astype(out_dtype, copy=True)
    for edt_val, stddev_bg in variance_map.items():
        # The mask is computed on the original distances, not on the output
        # that is being filled in.
        dists_copy[dists == int(edt_val)] = stddev_bg
    return dists_copy

In [57]:
dists = np.ceil(dists)

In [58]:
#mean_bg_cells = mean_bg_mask(dists, mean_map)
#stddev_bg_cells = variance_bg_mask(dists, stddev_map)
fitted_beta_map = edt_noise_map['betas']
fitted_alpha_map = edt_noise_map['alphas']
alpha_bg_cells = mean_bg_mask(dists, fitted_alpha_map)
beta_bg_cells = variance_bg_mask(dists, fitted_beta_map)

In [59]:
index = 10
fig, ax = plt.subplots(nrows=1, ncols=2)
#ax[0].imshow(mean_bg_cells[index])
ax[0].imshow(alpha_bg_cells[index])
ax[0].set_title('Alpha values')
#ax[0].set_title('Mean bg values')
#ax[1].imshow(stddev_bg_cells[index])
ax[1].imshow(beta_bg_cells[index])
#ax[1].set_title('Std dev bg values')
ax[1].set_title('Beta values')
plt.show()

In [60]:
# Wrap the per-pixel alpha map as a tensor; it is used as the Gamma
# `concentration` in the next cell.
alpha_t = torch.from_numpy(alpha_bg_cells)
# Reciprocal of the per-pixel beta map; it is used as the Gamma `rate` in the
# next cell (so the fitted betas are presumably scale parameters — confirm).
# NOTE(review): any pixel where beta_bg_cells is 0 becomes inf here; verify the
# map never produces zeros (e.g. for distance values missing from the fit).
beta_t = 1.0/torch.from_numpy(beta_bg_cells)

In [61]:
alpha_t.shape, beta_t.shape

(torch.Size([30, 1041, 1302]), torch.Size([30, 1041, 1302]))

In [62]:
m = torch.distributions.gamma.Gamma(concentration=alpha_t, rate=beta_t)
sample = m.sample()

In [63]:
index = 10
fig, ax = plt.subplots(nrows=1, ncols=3)
#ax[0].imshow(mean_bg_cells[index])
ax[0].imshow(alpha_bg_cells[index])
ax[0].set_title('Alpha bg cells')
#ax[0].set_title('Mean bg values')
#ax[1].imshow(stddev_bg_cells[index])
#ax[1].set_title('Std dev bg values')
ax[1].imshow(beta_bg_cells[index])
ax[1].set_title('Beta bg cells')
img = ax[2].imshow(sample[index].cpu().numpy(), cmap='gray')
ax[2].set_title('Sampled bg values')
fig.colorbar(img, ax=ax[2])
plt.show()

In [64]:
index = 15
plt.figure()
plt.imshow(sample[index].cpu().numpy(), cmap='gray')
plt.show()

### Dot overlay

In [65]:
# Camera constants; the same names are reused when converting back to camera
# counts in the camera-noise section.
# NOTE(review): these are hard-coded here — presumably they should agree with
# the camera settings in the parameter file; confirm.
baseline = 103.0
e_per_adu = 0.39
qe = 0.95
# Subtract the baseline, then scale by e_per_adu / qe
# (presumably ADU -> electrons -> photons — confirm units).
bg_photons = (sample - baseline) * e_per_adu / qe

In [66]:
bg_photons = bg_photons[:, None]

In [67]:
bg_photons = torch.clamp(bg_photons, min=0.0)

In [68]:
imgs_sim.shape, bg_photons.shape

(torch.Size([30, 1, 1041, 1302]), torch.Size([30, 1, 1041, 1302]))

In [69]:
bg_photons = bg_photons.to('cuda:0')

In [70]:
plt.figure()
plt.imshow(bg_photons[index, 0].cpu().numpy())
plt.show()

In [71]:
after_dots = imgs_sim + bg_photons

In [72]:
index = 10
fig, ax = plt.subplots(nrows=1, ncols=6)
ax[0].imshow(prob_map[index])
ax[0].set_title('Probabilty map')
ax[1].imshow(cell_masks_batch[index])
ax[1].set_title('Cell mask sampled')
ax[2].imshow(imgs_sim[index, 0].cpu().numpy())
ax[2].set_title('Simulated PSFs')
ax[3].imshow(S[index, 0].cpu().numpy())
ax[3].set_title('Emitter pixel locs')
ax[4].imshow(I[index, 0].cpu().numpy() * psf_params['photon_scale'])
ax[4].set_title('Intensity values')
ax[5].imshow(after_dots[index][0].cpu().numpy(), cmap='gray')
ax[5].set_title('After overlaying dots')
plt.show()

### Adding camera noise

In [73]:
RN = 2.2

In [74]:
after_dots = torch.distributions.Poisson(after_dots * qe).sample()

In [75]:
zeros = torch.zeros_like(after_dots)

In [76]:
readout_noise = torch.distributions.Normal(zeros, zeros + RN).sample()

In [77]:
after_dots = after_dots + readout_noise

In [78]:
final_dots_for_net = torch.clamp((after_dots/ e_per_adu) + baseline, min=0)

In [79]:
index = 10
fig, ax = plt.subplots(nrows=1, ncols=7)
ax[0].imshow(sample[index].cpu().numpy(), cmap='gray')
ax[0].set_title('Sampled background ADU')
ax[1].imshow(cell_masks_batch[index])
ax[1].set_title('Cell mask sampled')
ax[2].imshow(imgs_sim[index, 0].cpu().numpy())
ax[2].set_title('Simulated PSFs')
ax[3].imshow(S[index, 0].cpu().numpy())
ax[3].set_title('Emitter pixel locs')
ax[4].imshow(I[index, 0].cpu().numpy() * psf_params['photon_scale'])
ax[4].set_title('Intensity values')
ax[5].imshow(after_dots[index][0].cpu().numpy(), cmap='gray')
ax[5].set_title('After overlaying dots (photons)')
ax[6].imshow(final_dots_for_net[index][0].cpu().numpy(), cmap='gray')
ax[6].set_title('After camer noise (ADU)')
plt.tight_layout()
plt.show()

In [80]:
plt.figure()
plt.imshow(final_dots_for_net[index][0].cpu().numpy(), cmap='gray')
plt.show()

In [None]:
# NOTE(review): unexecuted leftover cell. plt.imshow() requires an image
# argument, so this raises TypeError if it is ever run — supply the image or
# delete the cell.
plt.figure()
plt.imshow()

In [54]:
index = 3
plt.figure()
plt.imshow(final_dots_for_net[index][0].cpu().numpy(), cmap='gray')
plt.colorbar()
plt.show()

In [51]:
final_dots_for_net[index][0].cpu().numpy().max()

331.9748

In [52]:
final_dots_for_net[index][0].cpu().numpy().min()

76.83658