In [1]:
%reload_ext autoreload
%autoreload 2
%reload_ext notexbook
%texify

### Balancing the loss function across z heights to get equally probable detections over the full z range

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import cellbgnet
import cellbgnet.utils
from cellbgnet.datasets import DataSimulator
from cellbgnet.utils.hardware import cpu, gpu
from cellbgnet.model import CellBGModel
from cellbgnet.simulation.psf_kernel import SMAPSplineCoefficient
from cellbgnet.generic.emitter import EmitterSet
from cellbgnet.train_loss_infer import generate_probmap_cells
from cellbgnet.analyze_eval import spline_crlb_plot
from cellbgnet.utils.plot_funcs import plot_psf
from skimage.io import imread
from skimage.measure import label
import random
import edt
from skimage.filters import gaussian
from scipy.ndimage import rotate
import random
import pickle
import pathlib
from pathlib import Path
from scipy.spatial.distance import cdist
%matplotlib qt5

In [3]:
from cellbgnet.utils.plot_funcs import plot_od, plot_train_record
from cellbgnet.analyze_eval import recognition, plot_full_img_predictions, assemble_full_img_predictions, limited_matching, assess

In [4]:
model_path = Path('/mnt/sda1/SMLAT/training_runs/model_rotated_45_venus_long_run.pkl')

In [5]:
# Load the trained model object; it also carries the evaluation data used below
# (model.evaluation_params['eval_imgs'], ['ground_truth'], ...).
# NOTE(review): pickle.load can execute arbitrary code — only load model files
# from a trusted source.
with open(model_path, 'rb') as f:
    model = pickle.load(f)

##### Getting eval images and plotting ground truths and predicted values on images

In [6]:
all_eval_images = model.evaluation_params['eval_imgs']

In [7]:
all_eval_images.shape

(30, 1041, 1302)

In [8]:
ground_truth = model.evaluation_params['ground_truth']

In [9]:
ground_truth_np = np.array(ground_truth)

In [10]:
# Select one evaluation frame and the ground-truth emitters that belong to it.
image_number = 0
single_eval_image = all_eval_images[image_number]
# Frame ids in ground-truth column 1 are 1-based, hence `image_number + 1`.
single_ground_truth = ground_truth_np[np.where(ground_truth_np[:, 1] == image_number + 1)]
# Add a leading axis for plotting: (1, H, W).
single_eval_plot_image = torch.from_numpy(single_eval_image)[None,:]
# Columns 2:5 hold x, y, z; x and y are converted to pixels below, z is left as-is.
# NOTE(review): torch.from_numpy on a numpy slice shares memory, so the in-place
# divisions below ALSO rewrite single_ground_truth[:, 2:4] in pixel units.
# Use `.clone()` if single_ground_truth must keep its original values.
pos_gt = torch.from_numpy(single_ground_truth[:, 2:5])
pos_gt[:, 0] /= model.psf_params['pixel_size_xy'][0]
pos_gt[:, 1] /= model.psf_params['pixel_size_xy'][1]
# Column 5: photon counts (also a view into single_ground_truth).
phot_gt = torch.from_numpy(single_ground_truth[:, 5])

In [11]:
pos_gt

tensor([[ 404.9698,  109.6079, -112.6771],
        [ 468.2611,  112.1723,  -24.0199],
        [ 397.7820,  112.8342,  -40.9140],
        ...,
        [1042.2054, 1032.9781,  -54.9651],
        [1040.0507, 1033.7281,  -50.6181],
        [1040.8280, 1035.7101,  495.7499]], dtype=torch.float64)

In [12]:
single_ground_truth.shape

(756, 6)

In [13]:
single_ground_truth

array([[ 1.00000000e+00,  1.00000000e+00,  4.04969849e+02,
         1.09607903e+02, -1.12677127e+02,  9.23177004e+02],
       [ 2.00000000e+00,  1.00000000e+00,  4.68261108e+02,
         1.12172279e+02, -2.40198672e+01,  1.74059594e+03],
       [ 3.00000000e+00,  1.00000000e+00,  3.97781982e+02,
         1.12834221e+02, -4.09139991e+01,  2.35499775e+03],
       ...,
       [ 7.54000000e+02,  1.00000000e+00,  1.04220544e+03,
         1.03297815e+03, -5.49651384e+01,  1.52198231e+03],
       [ 7.55000000e+02,  1.00000000e+00,  1.04005066e+03,
         1.03372815e+03, -5.06181121e+01,  2.03393501e+03],
       [ 7.56000000e+02,  1.00000000e+00,  1.04082800e+03,
         1.03571008e+03,  4.95749891e+02,  1.35437813e+03]])

In [14]:
single_eval_plot_image.shape

torch.Size([1, 1041, 1302])

In [15]:
from cellbgnet.utils.plot_funcs import PlotFrameCoord

In [16]:
PlotFrameCoord(single_eval_plot_image, pos_tar=pos_gt, annotate_tar_z=True).plot()

#### Make predictions on the single image

In [17]:
# Field of view in nm as [x extent, y extent]: image width (axis 2) and height
# (axis 1) times the model's pixel size, instead of a hard-coded 65 nm, so the
# matching FOV stays correct for other calibrations.
fov_size = [all_eval_images.shape[2] * float(model.psf_params['pixel_size_xy'][0]),
            all_eval_images.shape[1] * float(model.psf_params['pixel_size_xy'][1])]

In [18]:
fov_size

[84630, 67665]

In [19]:
single_eval_image.shape

(1041, 1302)

In [20]:
# Run the model on the single frame (np.newaxis adds the leading image axis).
# preds_tmp is the list of detections (one row per emitter). plot_data is what
# assemble_full_img_predictions turns into full-image output maps below.
# n_per_img is presumably the detection count per image.
# NOTE(review): pixel_nm is read from model.data_generator.psf_params, while the
# rest of this notebook uses model.psf_params — presumably identical; confirm.
preds_tmp, n_per_img, plot_data = recognition(model=model, eval_imgs_all=single_eval_image[np.newaxis,:],
                                             batch_size=16, use_tqdm=False,
                                             nms=True, candidate_threshold=0.2,
                                             nms_threshold=0.5, 
                                             pixel_nm=model.data_generator.psf_params['pixel_size_xy'],
                                             plot_num=1,
                                             win_size=128,
                                             padding=True,
                                             start_field_pos=[0, 0],
                                             padded_background=model.evaluation_params['padded_background'])

processing area:99/99, input field_xy:[1258 1301 1001 1040], use_coordconv:True, retain locs in area:[1278, 1301, 1021, 1040]


In [21]:
# Predictions use the same column layout as the ground truth: 2:5 -> x, y, z and
# 5 -> photons. x and y are divided by the pixel size so they overlay the images.
# NOTE: pos_out is a view into preds_tmp_np (torch.from_numpy shares memory), so
# preds_tmp_np is rescaled too; the list preds_tmp keeps its original values.
preds_tmp_np = np.array(preds_tmp)
pos_out = torch.from_numpy(preds_tmp_np[:, 2:5])
phot_out = torch.from_numpy(preds_tmp_np[:, 5])
pos_out[:, 0] /= model.psf_params['pixel_size_xy'][0]
pos_out[:, 1] /= model.psf_params['pixel_size_xy'][1]

In [22]:
img_infs = assemble_full_img_predictions(model, plot_data)

In [23]:
img_infs.keys()

dict_keys(['Probs', 'XO', 'YO', 'ZO', 'Int', 'BG', 'XO_sig', 'YO_sig', 'ZO_sig', 'Int_sig', 'Probs_ps', 'XO_ps', 'YO_ps', 'ZO_ps', 'Samples_ps', 'raw_img', 'only_bg'])

In [24]:
plt.figure()
plt.imshow(img_infs['Probs'])
plt.show()

In [25]:
plt.figure()
plt.imshow(img_infs['XO'])
plt.show()

In [26]:
plt.figure()
plt.imshow(img_infs['YO'])
plt.show()

In [27]:
plt.figure()
plt.imshow(img_infs['ZO'])
plt.show()

In [28]:
plt.figure()
PlotFrameCoord(torch.from_numpy(img_infs['Probs'])[None,:], pos_tar=pos_gt, pos_out=pos_out,
               annotate_out_z=True, annotate_tar_z=True, frame_cmap='viridis', plot_colorbar_frame=True).plot()

In [29]:
plt.figure()
PlotFrameCoord(torch.from_numpy(img_infs['BG'])[None,:], pos_tar=pos_gt, pos_out=pos_out,
               annotate_out_z=True, annotate_tar_z=True, frame_cmap='viridis', plot_colorbar_frame=True).plot()

In [30]:
plt.figure()
PlotFrameCoord(torch.from_numpy(img_infs['Samples_ps'])[None,:], pos_tar=pos_gt, pos_out=pos_out,
               annotate_out_z=True, annotate_tar_z=True, frame_cmap='viridis',).plot()

In [28]:
plt.figure()
PlotFrameCoord(torch.from_numpy(img_infs['XO_sig'])[None,:], pos_tar=pos_gt, pos_out=pos_out,
               annotate_out_z=True, annotate_tar_z=True, frame_cmap='viridis').plot()

In [26]:
PlotFrameCoord(torch.from_numpy(img_infs['XO'])[None,:], pos_tar=pos_gt, pos_out=pos_out,
               annotate_out_z=True, annotate_tar_z=True).plot()

In [27]:
plt.figure()
PlotFrameCoord(torch.from_numpy(img_infs['YO'])[None,:], pos_tar=pos_gt, pos_out=pos_out,
               annotate_out_z=True, annotate_tar_z=True).plot()

In [28]:
PlotFrameCoord(single_eval_plot_image, pos_tar=pos_gt, pos_out=pos_out,
               annotate_out_z=True, annotate_tar_z=True).plot()

In [31]:
# Show the same 300 x 300 px crop of every inferred output map, one figure each.
for map_name, full_map in img_infs.items():
    fig, ax = plt.subplots()
    crop_img = ax.imshow(full_map[400: 700, 500:800])
    fig.colorbar(crop_img, ax=ax)
    ax.set_title(map_name)
    plt.show()

In [32]:
S = img_infs['Samples_ps']

In [33]:
S.shape

(1041, 1302)

In [34]:
S = torch.from_numpy(S)[None, None, :, :]

In [35]:
plt.figure()
plt.imshow(S[0, 0])
plt.show()

In [36]:
plt.figure()
plt.imshow(img_infs['Probs'])
plt.show()

In [34]:
plt.figure()
PlotFrameCoord(torch.from_numpy(img_infs['Probs'])[None,:], pos_tar=pos_gt, pos_out=pos_out,
               annotate_out_z=True, annotate_tar_z=True, frame_cmap='viridis').plot()

In [37]:
pos_out.shape, pos_gt.shape

(torch.Size([684, 3]), torch.Size([756, 3]))

In [38]:
fov_size

[84630, 67665]

#### Checking matches and per-axis residuals (RMSE skew in y vs. z) to see whether the y error distribution is wider than the x or z distributions

In [39]:
pos_gt.shape,pos_out.shape

(torch.Size([756, 3]), torch.Size([684, 3]))

In [40]:
# Re-extract this frame's ground truth (nm units) from the untouched ground_truth_np.
# Boolean indexing returns a fresh copy, so this cell is idempotent.
# The old in-place `*= 65.0` hard-coded the pixel size and multiplied again on
# every re-run. Because pos_gt is a torch.from_numpy view of
# single_ground_truth[:, 2:5], it also silently turned pos_gt x/y back into nm.
single_ground_truth = ground_truth_np[ground_truth_np[:, 1] == image_number + 1]

In [41]:
single_ground_truth_list = single_ground_truth.tolist()

In [42]:
single_ground_truth_list

[[1.0,
  1.0,
  26323.040161132812,
  7124.5136642456055,
  -112.67712712287903,
  923.1770038604736],
 [2.0,
  1.0,
  30436.972045898438,
  7291.19815826416,
  -24.019867181777954,
  1740.5959367752075],
 [3.0,
  1.0,
  25855.828857421875,
  7334.2243576049805,
  -40.91399908065796,
  2354.997754096985],
 [4.0,
  1.0,
  26447.940216064453,
  7349.622383117676,
  -357.7025532722473,
  2173.9929914474487],
 [5.0,
  1.0,
  32299.729461669922,
  7321.059417724609,
  -489.61642384529114,
  2049.1873025894165],
 [6.0,
  1.0,
  27625.852966308594,
  7440.637855529785,
  -452.7055025100708,
  2500.631332397461],
 [7.0,
  1.0,
  33655.12664794922,
  7563.219032287598,
  216.67331457138062,
  1295.5700755119324],
 [8.0,
  1.0,
  36583.578186035156,
  7572.381477355957,
  -277.67014503479004,
  1625.6259083747864],
 [9.0,
  1.0,
  30761.87484741211,
  7579.0375900268555,
  -113.72137069702148,
  812.0206296443939],
 [10.0,
  1.0,
  31634.276885986328,
  7635.766792297363,
  490.17196893692017,
 

In [43]:
preds_tmp

[[1.0,
  1.0,
  26253.744140625,
  7193.4482421875,
  42.609073638916016,
  989.2841796875,
  0.6344854831695557,
  27.281381607055664,
  24.080612182617188,
  95.92286682128906,
  144.73667907714844,
  -0.09623167663812637,
  -0.33156344294548035],
 [2.0,
  1.0,
  30425.71875,
  7305.64501953125,
  -55.469085693359375,
  1682.1983642578125,
  0.9934322834014893,
  15.451836585998535,
  14.255281448364258,
  41.902122497558594,
  125.0210952758789,
  0.08796855807304382,
  0.3945363759994507],
 [3.0,
  1.0,
  32285.83984375,
  7289.6162109375,
  -470.2463073730469,
  1901.7357177734375,
  0.934432864189148,
  14.453184127807617,
  14.92622184753418,
  17.721376419067383,
  102.18761444091797,
  -0.294775128364563,
  0.1479398012161255],
 [4.0,
  1.0,
  25847.388671875,
  7334.837890625,
  -51.79137420654297,
  1962.27734375,
  1.0225481986999512,
  12.510900497436523,
  12.156954765319824,
  37.589752197265625,
  130.73532104492188,
  -0.34786343574523926,
  -0.15633974969387054],
 [5.

In [44]:
len(preds_tmp)

684

In [45]:
len(single_ground_truth_list)

756

In [46]:
preds_tmp

[[1.0,
  1.0,
  26253.744140625,
  7193.4482421875,
  42.609073638916016,
  989.2841796875,
  0.6344854831695557,
  27.281381607055664,
  24.080612182617188,
  95.92286682128906,
  144.73667907714844,
  -0.09623167663812637,
  -0.33156344294548035],
 [2.0,
  1.0,
  30425.71875,
  7305.64501953125,
  -55.469085693359375,
  1682.1983642578125,
  0.9934322834014893,
  15.451836585998535,
  14.255281448364258,
  41.902122497558594,
  125.0210952758789,
  0.08796855807304382,
  0.3945363759994507],
 [3.0,
  1.0,
  32285.83984375,
  7289.6162109375,
  -470.2463073730469,
  1901.7357177734375,
  0.934432864189148,
  14.453184127807617,
  14.92622184753418,
  17.721376419067383,
  102.18761444091797,
  -0.294775128364563,
  0.1479398012161255],
 [4.0,
  1.0,
  25847.388671875,
  7334.837890625,
  -51.79137420654297,
  1962.27734375,
  1.0225481986999512,
  12.510900497436523,
  12.156954765319824,
  37.589752197265625,
  130.73532104492188,
  -0.34786343574523926,
  -0.15633974969387054],
 [5.

In [47]:
perf, matches = limited_matching(single_ground_truth_list, preds_tmp, min_int=0.0, 
                                 limited_x=[0, fov_size[0]],
                                 limited_y=[0, fov_size[1]],
                                border=False, print_res=True, tolerance=250, tolerance_ax=np.inf)

FOV: x=[0, 84630] y=[0, 67665]
after FOV and border segmentation,truth: 756 ,preds: 684
Recall: 0.903
Precision: 0.999
Jaccard: 90.225
RMSE_lat: 36.556
RMSE_ax: 89.749
RMSE_vol: 96.909
Jaccard/RMSE: 2.468
Eff_lat: 62.159
Eff_ax: 54.073
Eff_3d: 58.116
FN: 73.0 FP: 1.0


In [48]:
matches.shape

(683, 12)

In [49]:
matches[4]

array([ 6.36187897e+04,  3.61440506e+04, -3.45728904e+02,  2.77138317e+03,
        6.36174336e+04,  3.61444727e+04, -3.97310242e+02,  2.38585571e+03,
        5.22960901e-01,  3.82167320e+01,  3.90226288e+02, -2.70253658e-01])

In [50]:
# Per-axis residuals of the matched pairs. Columns 0-3 and 4-7 of `matches` hold
# the two matched (x, y, z, photons) records (presumably ground truth first,
# then prediction — confirm against limited_matching).
x_diffs, y_diffs, z_diffs, ph_diffs = (matches[:, col] - matches[:, col + 4] for col in range(4))

In [51]:
# Distribution of x residuals over all matched pairs (nm).
fig, ax = plt.subplots()
ax.hist(x_diffs, bins=100)
ax.set_title(f"Mean: {np.mean(x_diffs):.2f}, std: {np.std(x_diffs):.2f}")
ax.set_xlabel('Difference btw x and gt_x')
ax.set_ylabel('Counts')
plt.show()

In [52]:
# Distribution of y residuals over all matched pairs (nm).
fig, ax = plt.subplots()
ax.hist(y_diffs, bins=100)
ax.set_title(f"Mean: {np.mean(y_diffs):.2f}, std: {np.std(y_diffs):.2f}")
ax.set_xlabel('Difference btw y and gt_y')
ax.set_ylabel('Counts')
plt.show()

In [53]:
# Distribution of z residuals over all matched pairs (nm).
fig, ax = plt.subplots()
ax.hist(z_diffs, bins=100)
ax.set_title(f"Mean: {np.mean(z_diffs):.2f}, std: {np.std(z_diffs):.2f}")
ax.set_xlabel('Difference btw z and gt_z')
ax.set_ylabel('Counts')
plt.show()

In [54]:
np.mean(y_diffs)

0.4558806524095214

In [55]:
np.mean(z_diffs)

15.649960848566902

In [56]:
np.mean(x_diffs)

-0.571043313335011

In [57]:
np.mean(ph_diffs)

-5.251326297073029

In [58]:
np.std(x_diffs)

26.584505629595558

In [59]:
np.std(y_diffs)

25.08142171853473

In [60]:
np.std(z_diffs)

88.37418228992681

In [61]:
plt.figure()
plt.hist(y_diffs, bins=100)
plt.show()

In [101]:
plt.figure()
plt.hist(z_diffs, bins=100)
plt.show()

#### Make noise-free PSFs (and the other simulated targets) from the ground truth to compute every loss component

In [1]:
def make_psfs_sim(model, pos_gt, phot_gt, image_shape):
    """Render noise-free PSF images for ground-truth emitters (placeholder).

    Intended to produce the noise-free PSF targets needed to evaluate every
    loss component from the ground truth.

    Args:
        model: presumably the trained model whose PSF parameters are used.
        pos_gt: ground-truth positions (x, y, z) — TODO confirm units.
        phot_gt: ground-truth photon counts.
        image_shape: shape of the image to render into.

    Raises:
        NotImplementedError: always. The previous ``pass`` body silently
            returned ``None``, which would only surface later as an obscure
            error in whatever consumed the result.
    """
    raise NotImplementedError("make_psfs_sim is a placeholder and has not been implemented yet")

In [28]:
pos_gt

tensor([[ 990.8008,  100.3148, -218.1367],
        [1045.1801,  101.8276, -498.4584],
        [ 582.6328,  103.8077, -334.6290],
        ...,
        [ 560.3091,  957.3259,  289.8114],
        [ 856.0256,  957.2742,  228.9757],
        [ 483.5667,  997.2133,   32.6427]], dtype=torch.float64)

In [39]:
# Fractional part of x in [0, 1). This is superseded two cells below, where
# x_os_vals is reassigned to the offset from the nearest pixel centre.
x_os_vals = pos_gt[:, 0] % 1

In [40]:
x_centers = torch.round(pos_gt[:, 0])
y_centers = torch.round(pos_gt[:, 1])

In [41]:
# Sub-pixel offsets from the nearest pixel centre. torch.round rounds half to
# even, so every offset lies in [-0.5, 0.5].
x_os_vals = pos_gt[:, 0] - x_centers
y_os_vals = pos_gt[:, 1] - y_centers

In [43]:
plt.figure()
plt.hist(x_os_vals.numpy(), bins=50)
plt.show()


In [44]:
plt.figure()
plt.hist(y_os_vals.numpy(), bins=50)
plt.show()

In [54]:
# Columns: sub-pixel x offset, nearest pixel centre, original x (all in px).
# Uses x_os_vals from the cells above. `x_os` is only defined much later (as a
# per-pixel offset map), so referencing it here failed on a fresh kernel.
checks = torch.stack((x_os_vals, x_centers, pos_gt[:, 0]), dim=1)

In [60]:
torch.set_printoptions(sci_mode=False)

In [61]:
print(checks)

tensor([[    -0.1992,    991.0000,    990.8008],
        [     0.1801,   1045.0000,   1045.1801],
        [    -0.3672,    583.0000,    582.6328],
        [     0.4778,    765.0000,    765.4778],
        [    -0.2069,    678.0000,    677.7931],
        [    -0.3398,    837.0000,    836.6602],
        [     0.2245,    607.0000,    607.2245],
        [     0.2552,    962.0000,    962.2552],
        [    -0.0344,   1008.0000,   1007.9656],
        [    -0.0966,    884.0000,    883.9034],
        [    -0.4121,    962.0000,    961.5879],
        [     0.1062,    995.0000,    995.1062],
        [     0.1448,    476.0000,    476.1448],
        [    -0.0737,    589.0000,    588.9263],
        [    -0.4006,    951.0000,    950.5994],
        [    -0.2094,    965.0000,    964.7906],
        [     0.3410,    966.0000,    966.3410],
        [    -0.2516,   1099.0000,   1098.7484],
        [    -0.4905,    845.0000,    844.5095],
        [     0.3820,    873.0000,    873.3820],
        [     0.0673

##### Now we have all the images needed to compare loss functions, one eval image at a time

#### Using eval data to untangle the issues with the individual loss components

#### Do statistics on loss functions values as a function of z and photon count by pooling images

#### Adjust loss functions as necessary using various strategies, masking techniques and a Dice loss

#### prob_map

In [8]:
prob_map = np.zeros((1, 128, 128))

In [9]:
prob_map[:, 63, 63] = 1.0

In [10]:
plt.figure()
plt.imshow(prob_map[0])
plt.show()

In [11]:
prob_map = torch.from_numpy(prob_map)

### imgs_sim, empty arrays

In [12]:
batch_size = 1

In [13]:
# Empty per-batch placeholders: the simulated image (batch, 1, H, W) plus
# zero-length per-emitter tensors for xyzi ground truth (4 values per emitter),
# the emitter mask, and pixel coordinates (2 values per emitter).
imgs_sim = torch.zeros([batch_size, 1, prob_map.shape[1], prob_map.shape[2]])
xyzi_gt = torch.zeros([batch_size, 0, 4])
s_mask = torch.zeros([batch_size, 0])
pix_cor = torch.zeros([batch_size, 0, 2])

### Sampling

In [14]:
blink_p = prob_map.reshape(prob_map.shape[0], 1, prob_map.shape[1], prob_map.shape[2])

In [15]:
locs1 = torch.distributions.Binomial(1, blink_p).sample()

In [16]:
locs1.shape

torch.Size([1, 1, 128, 128])

In [17]:
plt.figure()
plt.imshow(locs1[0, 0].numpy())
plt.show()

In [18]:
zeros = torch.zeros_like(locs1)

In [19]:
# Per-pixel parameter maps shaped like locs1: zero z and zero x/y offsets
# everywhere, and an intensity of 0.5 (before photon_scale) at emitter pixels.
z = torch.zeros_like(locs1)
x_os = torch.zeros_like(locs1)
y_os = torch.zeros_like(locs1)
ints = 0.5 * locs1

In [20]:
z *= locs1
x_os *= locs1
y_os *= locs1
ints *= locs1

In [21]:
# Sanity-check the sampled emitter and its parameter maps in a 2 x 3 grid
# (the bottom-right panel stays empty).
fig, ax = plt.subplots(nrows=2, ncols=3)
panels = [
    ((0, 0), locs1, 'Dot location'),
    ((0, 1), x_os, 'X offsets'),
    ((1, 0), y_os, 'Y offsets'),
    ((1, 1), z, 'z value in [-1, 1]'),
    ((0, 2), ints, 'Intensities'),
]
for (row, col), param_map, panel_title in panels:
    ax[row, col].imshow(param_map[0, 0].numpy())
    ax[row, col].set_title(panel_title)
plt.tight_layout()
plt.show()

#### Placing PSFs on the emitter location (simulate_psfs method)

In [22]:
(batch_size, n_inp, h, w) = locs1.shape 

In [23]:
xyzi = torch.cat([x_os.reshape([-1, 1, h, w]), y_os.reshape([-1, 1, h, w]),
                  z.reshape([-1, 1, h, w]), ints.reshape([-1, 1, h, w])], 1)

In [24]:
# Bind the PSF parameters explicitly. This cell previously relied on a
# `psf_params` variable left over in the kernel, so it failed on a fresh run.
# NOTE(review): assumes model.psf_params provides 'z_scale' and 'photon_scale'
# (used here) as well as 'psf_size', 'pixel_size_xy' and 'calib_file' (used in
# later cells) — confirm against the trained model.
psf_params = model.psf_params

# Flatten the sampled location maps to (n_maps, h, w). Each xyzi parameter stack
# is repeated once per location map drawn from it.
S = locs1.reshape([-1, h, w])
n_samples = S.shape[0] // xyzi.shape[0]
XYZI_rep = xyzi.repeat_interleave(n_samples, 0)

# (map, row, col) index tensors of every sampled emitter.
s_inds = tuple(S.nonzero().transpose(1, 0))

# Gather each emitter's parameters at its pixel, shaped (n_emitters, 1, 1).
# z and intensity are rescaled by z_scale / photon_scale.
x_os_vals = (XYZI_rep[:, 0][s_inds])[:, None, None]
y_os_vals = (XYZI_rep[:, 1][s_inds])[:, None, None]
z_vals = psf_params['z_scale'] * (XYZI_rep[:, 2][s_inds])[:, None, None]
i_vals = psf_params['photon_scale'] * (XYZI_rep[:, 3][s_inds])[:, None, None]

In [129]:
s_inds

(tensor([0]), tensor([63]), tensor([63]))

In [130]:
x_os_vals, y_os_vals, z_vals, i_vals

(tensor([[[0.]]], dtype=torch.float64),
 tensor([[[0.]]], dtype=torch.float64),
 tensor([[[0.]]], dtype=torch.float64),
 tensor([[[1500.]]], dtype=torch.float64))

In [131]:
n_emitters = len(s_inds[0])

In [133]:
# Give every sampled emitter its own psf_size x psf_size frame, positioned at the
# frame centre plus its sub-pixel offset. Then set up the spline PSF model used
# to render those frames.
xyz = torch.zeros((n_emitters, 3))
# Unit photon count for every emitter (passed as phot below).
psf_int_vals = torch.ones((n_emitters, ))
psf_size = psf_params['psf_size'] 
# x, y in px: centre pixel (psf_size // 2) plus the emitter's offset.
xyz[:, 0] = psf_size // 2 + x_os_vals[:, 0, 0]
xyz[:, 1] = psf_size // 2 + y_os_vals[:, 0, 0]
# z was already multiplied by psf_params['z_scale'] when z_vals was gathered.
xyz[:, 2] = z_vals[:, 0, 0]
# Each frame holds exactly one emitter (frame index == emitter index).
frame_ix = torch.arange(n_emitters)

# Built on the CPU: we do not use DECODE-style large pre-sampled batches here,
# so the slower path is acceptable.
em = EmitterSet(xyz=xyz, phot=psf_int_vals.cpu(), frame_ix=frame_ix.long().cpu(),
                id=torch.arange(n_emitters).long(), xy_unit='px',
                px_size=psf_params['pixel_size_xy'])

# Spline PSF on a psf_size x psf_size pixel grid. The extents run from -0.5 to
# psf_size - 0.5, so pixel i spans [i - 0.5, i + 0.5] and integer coordinates
# fall on pixel centres.
psf = SMAPSplineCoefficient(psf_params['calib_file']).init_spline(
        xextent=[-0.5, psf_size-0.5],
        yextent=[-0.5, psf_size-0.5],
        img_shape=[psf_size, psf_size],
        device='cpu',
        roi_size=None, roi_auto_center=None
)


In [134]:
psf_size

41

In [135]:
print(em)

EmitterSet
::num emitters: 1
::xy unit: px
::px size: tensor([65., 65.])
::frame range: 0 - 0
::spanned volume: [20. 20.  0.] - [20. 20.  0.]


In [136]:
psf_size // 2

20

In [137]:
em.xyz_px, em.phot, em.frame_ix

(tensor([[20., 20.,  0.]]), tensor([1.]), tensor([0]))

In [138]:
psf_sim = psf.forward(em.xyz_px, em.phot, em.frame_ix, ix_low=0, ix_high=batch_size-1)

In [139]:
plt.figure()
plt.imshow(psf_sim[0].cpu().numpy())
plt.show()

In [140]:
psf._coeff.shape

torch.Size([40, 40, 53, 64])

In [141]:
psf_sim.shape

torch.Size([1, 41, 41])

In [142]:
psf.ref0

tensor([20., 20., 28.])

In [143]:
psf.vx_size

tensor([ 1.,  1., 25.])

In [144]:
psf.roi_size_px

torch.Size([40, 40])

In [145]:
psf.ref_re

In [146]:
psf.roi_size_px

torch.Size([40, 40])

In [147]:
psf._roi_native

torch.Size([40, 40])

In [148]:
psf._roi_size_nm

(tensor(40.), tensor(40.))

In [149]:
psf.max_roi_chunk

500000

In [150]:
xyz

tensor([[20., 20.,  0.]])

In [151]:
xyz_r, xyz_px = psf.frame2roi_coord(xyz)

In [152]:
xyz_r, xyz_px

(tensor([[20., 20.,  0.]]), tensor([[0, 0]], dtype=torch.int32))

In [153]:
psf.coord2impl(xyz_r)

tensor([[ 0.,  0., 28.]])

In [159]:
i_vals

tensor([[[1500.]]], dtype=torch.float64)

In [211]:
xyz = torch.tensor([[0.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0],
                    [0.0, 0.0, 200.0],
                    [0.0, 0.0, 10.0]])
ints = torch.tensor([[[1500.0]],[[1500.0]], [[1500.0]], [[1500.0]]])

In [212]:
# Render the four test emitters with their own intensities. The cell above
# defines `ints` (one 1500-photon value per emitter); the previous call passed
# `i_vals`, a single value left over from the sampled-emitter cells.
# NOTE(review): assumes forward_rois expects phot of shape (n_emitters,) — confirm.
result = psf.forward_rois(xyz, phot=ints.reshape(-1))

In [213]:
result.shape

torch.Size([4, 40, 40])

In [214]:
plt.figure()
plt.imshow(result[0].cpu().numpy())
plt.show()

In [215]:
plt.figure()
plt.imshow(result[1].cpu().numpy())
plt.show()

In [216]:
plt.figure()
plt.imshow(result[2].cpu().numpy())
plt.show()

In [217]:
plt.figure()
plt.imshow(result[3].cpu().numpy())
plt.show()

### Components of the different loss functions

#### MSE loss

In [6]:
def mse_loss(psf_imgs_est, psf_imgs_gt):
    """Summed squared error between estimated and ground-truth PSF images.

    The element-wise squared difference is summed over the last two (spatial)
    axes, leaving one value per leading index (e.g. per image).
    """
    squared_err = nn.functional.mse_loss(psf_imgs_est, psf_imgs_gt, reduction='none')
    per_row = squared_err.sum(-1)
    return per_row.sum(-1)

#### Cross-entropy loss

In [7]:
def cross_entropy_loss(P, locs, eps=1e-12):
    """Per-image binary cross-entropy between predicted probabilities and true locations.

    Args:
        P: predicted detection probabilities in [0, 1].
        locs: 0/1 location map, broadcastable against ``P``.
        eps: lower bound applied to the arguments of both logarithms. Without
            it, a probability saturated at exactly 0 or 1 gives
            ``0 * log(0) = 0 * -inf = nan``, which poisons the loss and its
            gradients. Values of ``P`` in ``[eps, 1 - eps]`` are unaffected.

    Returns:
        Pixel-wise BCE summed over the last two (spatial) axes.
    """
    # Clamp the log arguments, not the logs themselves, so gradients stay finite too.
    log_p = torch.log(P.clamp(min=eps))
    log_not_p = torch.log((1 - P).clamp(min=eps))
    loss_cse = -(locs * log_p + (1 - locs) * log_not_p)
    return loss_cse.sum(-1).sum(-1)

#### Count loss

In [9]:
def count_loss(P, locs):
    """Gaussian-approximated negative log-likelihood of the emitter count per image.

    The emitter count is modelled as a sum of independent Bernoulli variables
    with success probabilities ``P``. Its distribution is approximated by a
    Gaussian with mean ``sum(P)`` and variance ``sum(P - P**2)``, and the
    negative log-density of the true count is returned.

    Args:
        P: predicted detection probabilities; summed over the last two axes.
        locs: the true emitters, either
            * a per-pixel 0/1 location map with the same shape as ``P``
              (counted over the last two axes, like ``P``), or
            * a per-emitter mask such as ``s_mask`` (counted over the last axis).

    Returns:
        One loss value per image (shape of ``P`` without its last two axes).

    NOTE(review): if every probability is exactly 0 or 1 the variance is zero
    and the loss becomes inf/nan.
    """
    prob_mean = P.sum(-1).sum(-1)
    # Variance of a sum of independent Bernoullis: sum of p * (1 - p).
    prob_var = (P - P ** 2).sum(-1).sum(-1)
    if locs.shape == P.shape:
        # Location map: count over both spatial axes so X matches prob_mean.
        # Summing only the last axis would give per-row counts and mis-broadcast.
        X = locs.sum(-1).sum(-1)
    else:
        # Per-emitter mask: one entry per (possible) emitter along the last axis.
        X = locs.sum(-1)
    return 1 / 2 * ((X - prob_mean) ** 2) / prob_var + 1 / 2 * torch.log(2 * np.pi * prob_var)


#### Localization loss

In [None]:
def localization_loss(P, xyzi_est, xyzi_sig, xyzi_gt, s_mask):
    