In [1]:
%reload_ext autoreload
%autoreload 2
%reload_ext notexbook
%texify

### Loss function balancing at different z heights to get to equi-probable detections over full range

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import cellbgnet
import cellbgnet.utils
from cellbgnet.datasets import DataSimulator
from cellbgnet.utils.hardware import cpu, gpu
from cellbgnet.model import CellBGModel
from cellbgnet.simulation.psf_kernel import SMAPSplineCoefficient
from cellbgnet.generic.emitter import EmitterSet
from cellbgnet.train_loss_infer import generate_probmap_cells
from cellbgnet.analyze_eval import spline_crlb_plot
from cellbgnet.utils.plot_funcs import plot_psf
from skimage.io import imread
from skimage.measure import label
import random
import edt
from skimage.filters import gaussian
from scipy.ndimage import rotate
import random
import pickle
import pathlib
from pathlib import Path
from scipy.spatial.distance import cdist
%matplotlib qt5

In [3]:
from cellbgnet.utils.plot_funcs import plot_od, plot_train_record
from cellbgnet.analyze_eval import recognition, plot_full_img_predictions, assemble_full_img_predictions, limited_matching, assess

In [4]:
model_path = Path('/mnt/sda1/SMLAT/training_runs/model_debug_factor_offset_changed.pkl')

In [5]:
# Deserialize the trained model object stored at `model_path`.
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only load checkpoints that come from a trusted source.
with open(model_path, 'rb') as f:
    model = pickle.load(f)

##### Getting eval images and plotting ground truths and predicted values on images

In [6]:
all_eval_images = model.evaluation_params['eval_imgs']

In [7]:
all_eval_images.shape

(30, 1041, 1302)

In [8]:
ground_truth = model.evaluation_params['ground_truth']

In [9]:
ground_truth_np = np.array(ground_truth)

In [10]:
# Pick one evaluation frame and the ground-truth rows that belong to it.
image_number = 0
single_eval_image = all_eval_images[image_number]
# Column 1 of the ground-truth table is matched against image_number + 1,
# i.e. the frame ids in that column appear to be 1-based.
single_ground_truth = ground_truth_np[np.where(ground_truth_np[:, 1] == image_number + 1)]
# Prepend a length-1 axis to the frame for PlotFrameCoord.
single_eval_plot_image = torch.from_numpy(single_eval_image)[None,:]
# NOTE(review): torch.from_numpy shares memory with the NumPy slice, so the two
# in-place divisions below also rescale columns 2 and 3 of single_ground_truth.
# Use .clone() here if the array must keep its original units.
pos_gt = torch.from_numpy(single_ground_truth[:, 2:5])
# Divide x and y by the per-axis pixel size (presumably nm -> pixels; confirm units).
pos_gt[:, 0] /= model.psf_params['pixel_size_xy'][0]
pos_gt[:, 1] /= model.psf_params['pixel_size_xy'][1]
# Column 5 is used as the photon count of each ground-truth emitter.
phot_gt = torch.from_numpy(single_ground_truth[:, 5])

In [11]:
pos_gt

tensor([[ 428.8350,   98.6586,   63.2603],
        [ 672.7158,  100.2869,  146.5399],
        [ 723.7394,  101.1278,   40.0457],
        [ 630.3995,  102.7299,  183.6089],
        [ 829.1885,  102.7786,  370.0405],
        [ 970.6247,  103.4018, -165.6273],
        [ 712.7186,  104.1456,  182.0688],
        [ 984.1409,  104.3181,  462.1411],
        [ 991.2529,  104.3582, -320.4955],
        [ 696.0889,  105.2057,  -93.5657],
        [ 695.3876,  106.4184,  179.6660],
        [ 504.2025,  107.0870,  -77.1151],
        [ 855.7621,  107.0682, -138.1775],
        [ 870.2623,  107.3833,  348.8920],
        [ 695.2911,  146.5059, -499.9881],
        [ 657.4214,  151.1000,  139.7282],
        [ 803.5670,  151.4060,  232.1150],
        [ 880.5832,  150.6203, -163.9624],
        [ 472.7848,  151.9511,  356.7798],
        [ 787.2191,  152.1462,  312.4704],
        [ 923.3469,  151.8769, -138.7312],
        [ 860.7389,  188.3015,  -37.3298],
        [ 571.5889,  188.6805, -103.7272],
        [ 5

In [12]:
single_ground_truth.shape

(321, 6)

In [13]:
single_eval_plot_image.shape

torch.Size([1, 1041, 1302])

In [14]:
from cellbgnet.utils.plot_funcs import PlotFrameCoord

In [15]:
PlotFrameCoord(single_eval_plot_image, pos_tar=pos_gt, annotate_tar_z=True).plot()

#### Make predictions on the single image

In [16]:
fov_size = [all_eval_images.shape[2] * 65, all_eval_images.shape[1] * 65]

In [17]:
fov_size

[84630, 67665]

In [18]:
single_eval_image.shape

(1041, 1302)

In [19]:
preds_tmp, n_per_img, plot_data = recognition(model=model, eval_imgs_all=single_eval_image[np.newaxis,:],
                                             batch_size=16, use_tqdm=False,
                                             nms=True, candidate_threshold=0.2,
                                             nms_threshold=0.5, 
                                             pixel_nm=model.data_generator.psf_params['pixel_size_xy'],
                                             plot_num=1,
                                             win_size=128,
                                             padding=True,
                                             start_field_pos=[0, 0],
                                             padded_background=model.evaluation_params['padded_background'])

processing area:99/99, input field_xy:[1258 1301 1001 1040], use_coordconv:True, retain locs in area:[1278, 1301, 1021, 1040]


In [20]:
# Convert the prediction list to an array and extract positions / photon counts.
preds_tmp_np = np.array(preds_tmp)
# torch.from_numpy shares memory with preds_tmp_np; clone so that the in-place
# unit conversion below does not silently rescale the array as well.
pos_out = torch.from_numpy(preds_tmp_np[:, 2:5]).clone()
phot_out = torch.from_numpy(preds_tmp_np[:, 5])
# Divide x and y by the per-axis pixel size (presumably nm -> pixels; confirm units).
pos_out[:, 0] /= model.psf_params['pixel_size_xy'][0]
pos_out[:, 1] /= model.psf_params['pixel_size_xy'][1]

In [21]:
img_infs = assemble_full_img_predictions(model, plot_data)

In [22]:
img_infs.keys()

dict_keys(['Probs', 'XO', 'YO', 'ZO', 'Int', 'BG', 'XO_sig', 'YO_sig', 'ZO_sig', 'Int_sig', 'Probs_ps', 'XO_ps', 'YO_ps', 'ZO_ps', 'Samples_ps', 'raw_img', 'only_bg'])

In [25]:
plt.figure()
plt.imshow(img_infs['Probs'])
plt.show()

In [24]:
plt.figure()
plt.imshow(img_infs['XO'])
plt.show()

In [25]:
plt.figure()
plt.imshow(img_infs['YO'])
plt.show()

In [28]:
plt.figure()
PlotFrameCoord(torch.from_numpy(img_infs['Probs'])[None,:], pos_tar=pos_gt, pos_out=pos_out,
               annotate_out_z=True, annotate_tar_z=True, frame_cmap='viridis', plot_colorbar_frame=True).plot()

In [30]:
plt.figure()
PlotFrameCoord(torch.from_numpy(img_infs['BG'])[None,:], pos_tar=pos_gt, pos_out=pos_out,
               annotate_out_z=True, annotate_tar_z=True, frame_cmap='viridis', plot_colorbar_frame=True).plot()

In [27]:
plt.figure()
PlotFrameCoord(torch.from_numpy(img_infs['Samples_ps'])[None,:], pos_tar=pos_gt, pos_out=pos_out,
               annotate_out_z=True, annotate_tar_z=True, frame_cmap='viridis',).plot()

In [28]:
plt.figure()
PlotFrameCoord(torch.from_numpy(img_infs['XO_sig'])[None,:], pos_tar=pos_gt, pos_out=pos_out,
               annotate_out_z=True, annotate_tar_z=True, frame_cmap='viridis').plot()

In [23]:
PlotFrameCoord(torch.from_numpy(img_infs['XO'])[None,:], pos_tar=pos_gt, pos_out=pos_out,
               annotate_out_z=True, annotate_tar_z=True).plot()

In [33]:
plt.figure()
PlotFrameCoord(torch.from_numpy(img_infs['YO'])[None,:], pos_tar=pos_gt, pos_out=pos_out,
               annotate_out_z=True, annotate_tar_z=True).plot()

In [29]:
PlotFrameCoord(single_eval_plot_image, pos_tar=pos_gt, pos_out=pos_out,
               annotate_out_z=True, annotate_tar_z=True).plot()

In [28]:
for key, value in img_infs.items():
    plt.figure()
    plt.imshow(value[400: 700, 500:800])
    plt.colorbar()
    plt.title(key)
    plt.show()

In [31]:
S = img_infs['Samples_ps']

In [32]:
S.shape

(1041, 1302)

In [33]:
S = torch.from_numpy(S)[None, None, :, :]

In [34]:
plt.figure()
plt.imshow(S[0, 0])
plt.show()

In [35]:
plt.figure()
plt.imshow(img_infs['Probs'])
plt.show()

In [23]:
plt.figure()
PlotFrameCoord(torch.from_numpy(img_infs['Probs'])[None,:], pos_tar=pos_gt, pos_out=pos_out,
               annotate_out_z=True, annotate_tar_z=True, frame_cmap='viridis').plot()

In [24]:
pos_out.shape, pos_gt.shape

(torch.Size([305, 3]), torch.Size([321, 3]))

In [25]:
fov_size

[84630, 67665]

#### Checking matches and RMSE skews in y vs z to see if y distribution is somehow wider than z or x distribution

In [26]:
pos_gt.shape,pos_out.shape

(torch.Size([321, 3]), torch.Size([305, 3]))

In [27]:
# NOTE(review): multiplies columns 2-3 of single_ground_truth by a hard-coded
# 65.0 in place (presumably the pixel size in nm; confirm against
# model.psf_params['pixel_size_xy']). pos_gt was created with torch.from_numpy
# on a slice of this same array, so pos_gt's x/y change as well. The cell is not
# idempotent: every re-run multiplies again.
single_ground_truth[:, 2:4] *= 65.0

In [28]:
single_ground_truth_list = single_ground_truth.tolist()

In [29]:
single_ground_truth_list

[[1.0,
  1.0,
  27874.274444580078,
  6412.806510925293,
  63.26025724411011,
  1967.9381847381592],
 [2.0,
  1.0,
  43726.524353027344,
  6518.647232055664,
  146.53992652893066,
  981.4921617507935],
 [3.0,
  1.0,
  47043.05969238281,
  6573.305511474609,
  40.045738220214844,
  1361.0200881958008],
 [4.0,
  1.0,
  40975.965881347656,
  6677.442283630371,
  183.60894918441772,
  1338.6248052120209],
 [5.0,
  1.0,
  53897.254943847656,
  6680.610160827637,
  370.0404763221741,
  2442.0154094696045],
 [6.0,
  1.0,
  63090.60516357422,
  6721.117630004883,
  -165.62727093696594,
  2999.8978972434998],
 [7.0,
  1.0,
  46326.71081542969,
  6769.464454650879,
  182.06876516342163,
  1949.516773223877],
 [8.0,
  1.0,
  63969.16046142578,
  6780.678977966309,
  462.14109659194946,
  2048.5260486602783],
 [9.0,
  1.0,
  64431.436462402344,
  6783.282508850098,
  -320.49545645713806,
  2838.1049036979675],
 [10.0,
  1.0,
  45245.7763671875,
  6838.370742797852,
  -93.56573224067688,
  1013.890

In [34]:
len(preds_tmp)

305

In [35]:
len(single_ground_truth_list)

321

In [30]:
preds_tmp

[[1.0,
  1.0,
  27876.439453125,
  6419.39990234375,
  54.13360595703125,
  2100.329833984375,
  1.051848292350769,
  12.910325050354004,
  7.962179183959961,
  19.237895965576172,
  119.81484985351562,
  -0.1317095011472702,
  -0.2399992197751999],
 [2.0,
  1.0,
  32781.20703125,
  6947.05517578125,
  -93.6034164428711,
  1081.0718994140625,
  0.9395679831504822,
  14.742554664611816,
  16.16427993774414,
  30.721887588500977,
  102.64664459228516,
  0.3262811005115509,
  -0.1222243532538414],
 [1.0,
  1.0,
  40961.49609375,
  6694.92529296875,
  143.8317108154297,
  1286.730712890625,
  0.9888822436332703,
  28.0789852142334,
  12.303084373474121,
  29.693077087402344,
  125.6823501586914,
  0.17686861753463745,
  -0.0011471593752503395],
 [1.0,
  1.0,
  43723.2890625,
  6510.40625,
  186.57858276367188,
  1119.80419921875,
  0.8768191337585449,
  36.87968826293945,
  13.514337539672852,
  27.10999870300293,
  113.3970947265625,
  -0.33400586247444153,
  0.16009347140789032],
 [2.0,


In [20]:
perf, matches = limited_matching(ground_truth, preds_tmp, min_int=0.0, 
                                 limited_x=[0, fov_size[0]],
                                 limited_y=[0, fov_size[1]],
                                border=False, print_res=True, tolerance=250, tolerance_ax=np.inf)

FOV: x=[0, 84630] y=[0, 67665]
after FOV and border segmentation,truth: 10205 ,preds: 8531
Recall: 0.835
Precision: 0.999
Jaccard: 83.452
RMSE_lat: 39.980
RMSE_ax: 49.804
RMSE_vol: 63.866
Jaccard/RMSE: 2.087
Eff_lat: 56.731
Eff_ax: 70.101
Eff_3d: 63.416
FN: 1682.0 FP: 8.0


In [21]:
matches.shape

(8523, 12)

In [22]:
matches[4]

array([6.81358453e+04, 2.39877374e+04, 8.44882727e+01, 2.71717840e+03,
       6.81350859e+04, 2.39897246e+04, 8.07179489e+01, 2.77438037e+03,
       1.09022915e+00, 2.02472630e+01, 1.16772621e+02, 2.32106015e-01])

In [23]:
x_diffs = matches[:, 0] - matches[:, 4]
y_diffs = matches[:, 1] - matches[:, 5]
z_diffs = matches[:, 2] - matches[:, 6]
ph_diffs = matches[:, 3] - matches[:, 7]

In [24]:
plt.figure()
plt.hist(x_diffs, bins=100)
plt.title(f"Mean: {np.mean(x_diffs)}, std: {np.std(x_diffs)}")
plt.xlabel('Differnce btw x and gt_x')
plt.ylabel('Counts')
plt.show()

In [25]:
plt.figure()
plt.hist(y_diffs, bins=100)
plt.title(f"Mean: {np.mean(y_diffs)}, std: {np.std(y_diffs)}")
plt.xlabel('Differnce btw y and gt_y')
plt.ylabel('Counts')
plt.show()

In [26]:
plt.figure()
plt.hist(z_diffs, bins=100)
plt.title(f"Mean: {np.mean(z_diffs)}, std: {np.std(z_diffs)}")
plt.xlabel('Differnce btw z and gt_z')
plt.ylabel('Counts')
plt.show()

In [28]:
np.mean(y_diffs)

0.7636621466713819

In [29]:
np.mean(z_diffs)

5.64333775986471

In [30]:
np.mean(x_diffs)

-2.8665829292146685

In [31]:
np.mean(ph_diffs)

-48.28006091660111

In [33]:
np.std(x_diffs)

47.5545421380998

In [35]:
np.std(y_diffs)

42.16064040454709

In [36]:
np.std(z_diffs)

52.89254293053526

In [100]:
plt.figure()
plt.hist(y_diffs, bins=100)
plt.show()

In [101]:
plt.figure()
plt.hist(z_diffs, bins=100)
plt.show()

#### Make noise-free psf and other things using ground truth for calculating all loss components

In [1]:
def make_psfs_sim(model, pos_gt, phot_gt, image_shape):
    """Placeholder for building simulated PSF images from ground truth.

    Not implemented yet: none of the arguments are used and the function
    always returns None.
    """
    return None

In [28]:
pos_gt

tensor([[ 990.8008,  100.3148, -218.1367],
        [1045.1801,  101.8276, -498.4584],
        [ 582.6328,  103.8077, -334.6290],
        ...,
        [ 560.3091,  957.3259,  289.8114],
        [ 856.0256,  957.2742,  228.9757],
        [ 483.5667,  997.2133,   32.6427]], dtype=torch.float64)

In [39]:
x_os_vals = pos_gt[:, 0] % 1

In [40]:
x_centers = torch.round(pos_gt[:, 0])
y_centers = torch.round(pos_gt[:, 1])

In [41]:
x_os_vals = pos_gt[:, 0] - x_centers
y_os_vals = pos_gt[:, 1] - y_centers

In [43]:
plt.figure()
plt.hist(x_os_vals.numpy(), bins=50)
plt.show()


In [44]:
plt.figure()
plt.hist(y_os_vals.numpy(), bins=50)
plt.show()

In [54]:
# Side-by-side check of sub-pixel x offset, rounded x centre and raw x coordinate.
# Uses x_os_vals (pos_gt[:, 0] - x_centers); the name `x_os` is only bound further
# down in the notebook to an image-shaped tensor, so it fails on a fresh kernel.
checks = torch.stack((x_os_vals, x_centers, pos_gt[:, 0]), dim=1)

In [60]:
torch.set_printoptions(sci_mode=False)

In [61]:
print(checks)

tensor([[    -0.1992,    991.0000,    990.8008],
        [     0.1801,   1045.0000,   1045.1801],
        [    -0.3672,    583.0000,    582.6328],
        [     0.4778,    765.0000,    765.4778],
        [    -0.2069,    678.0000,    677.7931],
        [    -0.3398,    837.0000,    836.6602],
        [     0.2245,    607.0000,    607.2245],
        [     0.2552,    962.0000,    962.2552],
        [    -0.0344,   1008.0000,   1007.9656],
        [    -0.0966,    884.0000,    883.9034],
        [    -0.4121,    962.0000,    961.5879],
        [     0.1062,    995.0000,    995.1062],
        [     0.1448,    476.0000,    476.1448],
        [    -0.0737,    589.0000,    588.9263],
        [    -0.4006,    951.0000,    950.5994],
        [    -0.2094,    965.0000,    964.7906],
        [     0.3410,    966.0000,    966.3410],
        [    -0.2516,   1099.0000,   1098.7484],
        [    -0.4905,    845.0000,    844.5095],
        [     0.3820,    873.0000,    873.3820],
        [     0.0673

##### Now we have all the images needed to compare loss functions, one eval image at a time

#### Using eval data to untangle issues with the individual loss components

#### Do statistics on loss functions values as a function of z and photon count by pooling images

#### Adjust loss functions as necessary using various strategies and masking techniques and dice loss

#### prob_map

In [8]:
prob_map = np.zeros((1, 128, 128))

In [9]:
prob_map[:, 63, 63] = 1.0

In [10]:
plt.figure()
plt.imshow(prob_map[0])
plt.show()

In [11]:
prob_map = torch.from_numpy(prob_map)

### imgs_sim, empty arrays

In [12]:
batch_size = 1

In [13]:
imgs_sim = torch.zeros([batch_size, 1, prob_map.shape[1], prob_map.shape[2]])
xyzi_gt = torch.zeros([batch_size, 0, 4])
s_mask = torch.zeros([batch_size, 0])
pix_cor = torch.zeros([batch_size, 0, 2])

### Sampling

In [14]:
blink_p = prob_map.reshape(prob_map.shape[0], 1, prob_map.shape[1], prob_map.shape[2])

In [15]:
locs1 = torch.distributions.Binomial(1, blink_p).sample()

In [16]:
locs1.shape

torch.Size([1, 1, 128, 128])

In [17]:
plt.figure()
plt.imshow(locs1[0, 0].numpy())
plt.show()

In [18]:
zeros = torch.zeros_like(locs1)

In [19]:
z = torch.clone(locs1) * 0.0
x_os = torch.clone(locs1) * 0.0
y_os = torch.clone(locs1) * 0.0
ints = torch.clone(locs1) * 0.5

In [20]:
z *= locs1
x_os *= locs1
y_os *= locs1
ints *= locs1

In [21]:
fig, ax = plt.subplots(nrows=2, ncols=3)
ax[0, 0].imshow(locs1[0, 0].numpy())
ax[0, 0].set_title('Dot location')

ax[0, 1].imshow(x_os[0, 0].numpy())
ax[0, 1].set_title('X offsets')

ax[1, 0].imshow(y_os[0, 0].numpy())
ax[1, 0].set_title('Y offsets')

ax[1, 1].imshow(z[0, 0].numpy())
ax[1, 1].set_title('z value in [-1, 1]')

ax[0, 2].imshow(ints[0, 0].numpy())
ax[0, 2].set_title('Intensities')
plt.tight_layout()
plt.show()

#### Placing PSFs on the emitter location (simulate_psfs method)

In [22]:
(batch_size, n_inp, h, w) = locs1.shape 

In [23]:
xyzi = torch.cat([x_os.reshape([-1, 1, h, w]), y_os.reshape([-1, 1, h, w]),
                  z.reshape([-1, 1, h, w]), ints.reshape([-1, 1, h, w])], 1)

In [24]:
# Collapse the sampled location map to a stack of (h, w) frames.
S = locs1.reshape([-1, h, w])
# Number of frames per entry of xyzi; xyzi is repeated so it lines up with S.
n_samples = S.shape[0] // xyzi.shape[0]
XYZI_rep  = xyzi.repeat_interleave(n_samples, 0)

# Tuple of index tensors (one per axis of S) addressing every non-zero entry.
s_inds = tuple(S.nonzero().transpose(1, 0))

# Gather the per-emitter values at those indices and append two singleton axes.
# NOTE(review): x_os_vals / y_os_vals rebind the names used earlier in the
# notebook for the ground-truth sub-pixel offsets.
x_os_vals = (XYZI_rep[:, 0][s_inds])[:, None, None]
y_os_vals = (XYZI_rep[:, 1][s_inds])[:, None, None]
# NOTE(review): `psf_params` is not defined in any visible cell (earlier cells
# use model.psf_params); confirm where it comes from before a Restart & Run All.
z_vals = psf_params['z_scale'] * (XYZI_rep[:, 2][s_inds])[:, None, None]
i_vals = psf_params['photon_scale'] * (XYZI_rep[:, 3][s_inds])[:, None, None]

In [129]:
s_inds

(tensor([0]), tensor([63]), tensor([63]))

In [130]:
x_os_vals, y_os_vals, z_vals, i_vals

(tensor([[[0.]]], dtype=torch.float64),
 tensor([[[0.]]], dtype=torch.float64),
 tensor([[[0.]]], dtype=torch.float64),
 tensor([[[1500.]]], dtype=torch.float64))

In [131]:
n_emitters = len(s_inds[0])

In [133]:
xyz = torch.zeros((n_emitters, 3))
psf_int_vals = torch.ones((n_emitters, ))
psf_size = psf_params['psf_size'] 
xyz[:, 0] = psf_size // 2 + x_os_vals[:, 0, 0]
xyz[:, 1] = psf_size // 2 + y_os_vals[:, 0, 0]
xyz[:, 2] = z_vals[:, 0, 0]
# each frame will have one emitter only
frame_ix = torch.arange(n_emitters)

# it is fine to do on CPU as we don't do decode like large pre-sampling batches
# it will be slow, but whatever.. 
em = EmitterSet(xyz=xyz, phot=psf_int_vals.cpu(), frame_ix=frame_ix.long().cpu(),
                id=torch.arange(n_emitters).long(), xy_unit='px',
                px_size=psf_params['pixel_size_xy'])

psf = SMAPSplineCoefficient(psf_params['calib_file']).init_spline(
        xextent=[-0.5, psf_size-0.5],
        yextent=[-0.5, psf_size-0.5],
        img_shape=[psf_size, psf_size],
        device='cpu',
        roi_size=None, roi_auto_center=None
)


In [134]:
psf_size

41

In [135]:
print(em)

EmitterSet
::num emitters: 1
::xy unit: px
::px size: tensor([65., 65.])
::frame range: 0 - 0
::spanned volume: [20. 20.  0.] - [20. 20.  0.]


In [136]:
psf_size // 2

20

In [137]:
em.xyz_px, em.phot, em.frame_ix

(tensor([[20., 20.,  0.]]), tensor([1.]), tensor([0]))

In [138]:
psf_sim = psf.forward(em.xyz_px, em.phot, em.frame_ix, ix_low=0, ix_high=batch_size-1)

In [139]:
plt.figure()
plt.imshow(psf_sim[0].cpu().numpy())
plt.show()

In [140]:
psf._coeff.shape

torch.Size([40, 40, 53, 64])

In [141]:
psf_sim.shape

torch.Size([1, 41, 41])

In [142]:
psf.ref0

tensor([20., 20., 28.])

In [143]:
psf.vx_size

tensor([ 1.,  1., 25.])

In [144]:
psf.roi_size_px

torch.Size([40, 40])

In [145]:
psf.ref_re

In [146]:
psf.roi_size_px

torch.Size([40, 40])

In [147]:
psf._roi_native

torch.Size([40, 40])

In [148]:
psf._roi_size_nm

(tensor(40.), tensor(40.))

In [149]:
psf.max_roi_chunk

500000

In [150]:
xyz

tensor([[20., 20.,  0.]])

In [151]:
xyz_r, xyz_px = psf.frame2roi_coord(xyz)

In [152]:
xyz_r, xyz_px

(tensor([[20., 20.,  0.]]), tensor([[0, 0]], dtype=torch.int32))

In [153]:
psf.coord2impl(xyz_r)

tensor([[ 0.,  0., 28.]])

In [159]:
i_vals

tensor([[[1500.]]], dtype=torch.float64)

In [211]:
# Four test positions that differ only in the third coordinate (0, 1, 200, 10).
# NOTE(review): this rebinds `xyz`, which earlier held the emitter coordinates
# passed to EmitterSet.
xyz = torch.tensor([[0.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0],
                    [0.0, 0.0, 200.0],
                    [0.0, 0.0, 10.0]])
# NOTE(review): `ints` rebinds the name of the earlier intensity map, and the
# forward_rois call that follows passes `i_vals`, not `ints`; confirm which
# one is intended.
ints = torch.tensor([[[1500.0]],[[1500.0]], [[1500.0]], [[1500.0]]])

In [212]:
result = psf.forward_rois(xyz, phot=i_vals)

In [213]:
result.shape

torch.Size([4, 40, 40])

In [214]:
plt.figure()
plt.imshow(result[0].cpu().numpy())
plt.show()

In [215]:
plt.figure()
plt.imshow(result[1].cpu().numpy())
plt.show()

In [216]:
plt.figure()
plt.imshow(result[2].cpu().numpy())
plt.show()

In [217]:
plt.figure()
plt.imshow(result[3].cpu().numpy())
plt.show()

### Different loss functions' components

#### MSE loss

In [6]:
def mse_loss(psf_imgs_est, psf_imgs_gt):
    """Sum of squared differences over the last two axes.

    Args:
        psf_imgs_est: estimated images (tensor).
        psf_imgs_gt: reference images (tensor), compared element-wise with
            ``psf_imgs_est``.

    Returns:
        Tensor of element-wise squared errors summed over the last axis and
        then over the (new) last axis; all leading axes are kept.
    """
    squared_err = nn.functional.mse_loss(psf_imgs_est, psf_imgs_gt, reduction='none')
    per_row = squared_err.sum(-1)
    return per_row.sum(-1)

#### Cross-entropy loss

In [7]:
def cross_entropy_loss(P, locs):
    """Binary cross-entropy between a probability map and a location map.

    Args:
        P: tensor of probabilities in [0, 1].
        locs: tensor of targets, combined element-wise with ``P`` (1 where an
            emitter is present, 0 elsewhere).

    Returns:
        Tensor of the element-wise cross-entropy summed over the last two
        axes; all leading axes are kept.
    """
    # Floor the logs at -100 (the same convention torch.nn.BCELoss uses) so that
    # probabilities of exactly 0 or 1 contribute a finite value instead of
    # 0 * -inf = nan.
    log_p = torch.log(P).clamp(min=-100.0)
    log_not_p = torch.log(1 - P).clamp(min=-100.0)
    loss_cse = -(locs * log_p + (1 - locs) * log_not_p)
    loss_cse = loss_cse.sum(-1).sum(-1)
    return loss_cse

#### Count loss

In [9]:
def count_loss(P, locs, eps=1e-8):
    """Gaussian negative log-likelihood of the emitter count given ``P``.

    The count is modelled as a sum of independent Bernoulli variables with
    probabilities ``P``: mean ``sum(P)`` and variance ``sum(P - P**2)``, both
    taken over the last two axes of ``P``.

    Args:
        P: tensor of detection probabilities; its last two axes are summed.
        locs: either a location map with the same shape as ``P`` (summed over
            the same two axes) or any other tensor whose last axis is summed
            to obtain the true count (e.g. a ``(batch, n_emitters)`` mask).
        eps: lower bound applied to the variance so that a map made only of
            0s and 1s does not produce a division by zero / ``log(0)``.

    Returns:
        Tensor with the negative log-likelihood per sample.
    """
    prob_mean = P.sum(-1).sum(-1)
    # Variance is exactly 0 when every probability is 0 or 1; floor it.
    prob_var = (P - P ** 2).sum(-1).sum(-1).clamp(min=eps)
    if locs.shape == P.shape:
        # Pixel map: reduce over the same axes as P so shapes line up.
        X = locs.sum(-1).sum(-1)
    else:
        X = locs.sum(-1)
    return 1 / 2 * ((X - prob_mean) ** 2) / prob_var + 1 / 2 * torch.log(2 * np.pi * prob_var)


#### Localization loss

In [None]:
def localization_loss(P, xyzi_est, xyzi_sig, xyzi_gt, s_mask):
    