In [1]:
%reload_ext autoreload
%autoreload 2
%reload_ext notexbook
%texify

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import time
import sys
import pathlib
import pickle
from pathlib import Path
from skimage.io import imread
import seaborn as sns
sns.set_style('white')
sys.setrecursionlimit(10000)
%matplotlib qt5

In [3]:
import cellbgnet
import cellbgnet.utils
from cellbgnet.datasets import DataSimulator
from cellbgnet.utils.hardware import cpu, gpu
from cellbgnet.model import CellBGModel
from cellbgnet.utils.plot_funcs import plot_od, plot_train_record

In [4]:
model_path = Path('/mnt/sda1/SMLAT/training_runs/model_rotated_45_venus_beads.pkl')

In [5]:
with open(model_path, 'rb') as f:
    chromo_model = pickle.load(f)

In [6]:
chromo_model.evaluation_params['eval_imgs'].shape

(30, 1041, 1302)

In [7]:
plot_train_record(chromo_model)

#### Photon count range, get background average from simulation using model paramteres

In [8]:
mol_photons = (chromo_model.data_generator.simulation_params['min_photon'] + 1) /2 * chromo_model.data_generator.psf_params['photon_scale']

In [9]:
mol_photons

2875.0

#### Plot network outputs for the eval images after tiling and re-tiling of one eval image

In [10]:
from cellbgnet.analyze_eval import recognition, plot_full_img_predictions, assemble_full_img_predictions

In [11]:
eval_img = chromo_model.evaluation_params['eval_imgs']

In [12]:
eval_img.shape

(30, 1041, 1302)

In [13]:
plt.figure()
plt.imshow(eval_img[0], cmap='gray')
plt.colorbar()
plt.show()

In [14]:
fov_size = [eval_img.shape[2] * 65, eval_img.shape[1] * 65]

In [15]:
fov_size

[84630, 67665]

In [16]:
eval_img.mean(0).max()

129.0405049641927

In [17]:
eval_img.mean()

113.72014866131897

In [51]:
preds_tmp, n_per_img, plot_data = recognition(model=chromo_model, eval_imgs_all=eval_img,
                                             batch_size=16, use_tqdm=False,
                                             nms=True, candidate_threshold=0.05,
                                             nms_threshold=0.05, 
                                             pixel_nm=chromo_model.data_generator.psf_params['pixel_size_xy'],
                                             plot_num=1,
                                             win_size=128,
                                             padding=True,
                                             start_field_pos=[0, 0],
                                             padded_background=chromo_model.evaluation_params['padded_background'])

processing area:99/99, input field_xy:[1258 1301 1001 1040], use_coordconv:True, retain locs in area:[1278, 1301, 1021, 1040]


In [19]:
plot_full_img_predictions(chromo_model, plot_infs=plot_data, eval_csv=None, plot_num=1, fov_size=fov_size, pixel_size=[65, 65])

In [20]:
img_infs = assemble_full_img_predictions(chromo_model, plot_data)

In [21]:
img_infs.keys()

dict_keys(['Probs', 'XO', 'YO', 'ZO', 'Int', 'BG', 'XO_sig', 'YO_sig', 'ZO_sig', 'Int_sig', 'Probs_ps', 'XO_ps', 'YO_ps', 'ZO_ps', 'Samples_ps', 'raw_img', 'only_bg'])

In [52]:
ground_truth = chromo_model.evaluation_params['ground_truth']

In [53]:
perf_dict, matches = limited_matching(ground_truth, preds_tmp, min_int=0, limited_x=[0, fov_size[0]],
                                     limited_y=[0, fov_size[1]], border=450, print_res=True, tolerance=250,
                                      tolerance_ax=np.inf)

FOV: x=[0, 84630] y=[0, 67665]
after FOV and border segmentation,truth: 2888 ,preds: 2926
Length of preds_list: 2926
83 85
75 76
99 99
104 104
93 94
93 94
93 93
79 83
96 96
76 77
89 91
100 100
111 112
94 95
84 86
111 111
99 99
94 94
98 98
98 100
92 94
123 126
108 110
111 113
101 102
99 101
90 91
103 104
94 97
98 101
Recall: 1.000
Precision: 0.987
Jaccard: 98.633
RMSE_lat: 25.714
RMSE_ax: 46.150
RMSE_vol: 52.830
Jaccard/RMSE: 3.836
Eff_lat: 74.249
Eff_ax: 76.885
Eff_3d: 75.567
FN: 1.0 FP: 39.0


#### Testing real image

In [23]:
import skimage.io as sio

In [24]:
images_dir = Path('/mnt/sda1/SMLAT/data/real_data/only_beads/EXP-23-CA3091/beadsLocalizationAccuracy/lowIntensity/Pos01/Pos0/')
img_filenames = sorted(list(images_dir.glob('*.tif')))

In [25]:
eval_imgs_all = []
for filename in img_filenames:
    eval_imgs_all.append(sio.imread(filename))

In [26]:
eval_imgs_all = np.stack(eval_imgs_all)

In [27]:
eval_imgs_all = eval_imgs_all[:, 1041:, :]

In [28]:
eval_imgs_all.shape

(200, 1041, 1302)

In [40]:
fov_size = [eval_imgs_all.shape[2] * 65, eval_imgs_all.shape[1] * 65]

In [29]:

preds_tmp, n_per_img, plot_data = recognition(model=chromo_model, eval_imgs_all=eval_imgs_all,
                                             batch_size=16, use_tqdm=False,
                                             nms=True, candidate_threshold=0.4,
                                             nms_threshold=0.5, 
                                             pixel_nm=chromo_model.data_generator.psf_params['pixel_size_xy'],
                                             plot_num=1,
                                             win_size=128,
                                             padding=True,
                                             start_field_pos=[0, 0],
                                             padded_background=chromo_model.evaluation_params['padded_background'])


processing area:99/99, input field_xy:[1258 1301 1001 1040], use_coordconv:True, retain locs in area:[1278, 1301, 1021, 1040]


In [30]:
img_infs = assemble_full_img_predictions(chromo_model, plot_data)

In [31]:
img_infs.keys()

dict_keys(['Probs', 'XO', 'YO', 'ZO', 'Int', 'BG', 'XO_sig', 'YO_sig', 'ZO_sig', 'Int_sig', 'Probs_ps', 'XO_ps', 'YO_ps', 'ZO_ps', 'Samples_ps', 'raw_img', 'only_bg'])

### Matching emitters based on dist

In [41]:

preds_tmp_np = np.array(preds_tmp)

In [42]:
frame_wise_xyz = [preds_tmp_np[np.where(preds_tmp_np[:, 1] == i)].tolist() for i in range(1, 201)]

In [43]:
len(frame_wise_xyz)

200

### For emitter in first frame, find the closest emitter in all the other frames < 250 nm distance from it

In [80]:
from scipy.spatial.distance import cdist
from cellbgnet.analyze_eval import limited_matching

In [81]:
first_frame_localizations = frame_wise_xyz[0]

In [82]:
frame_wise_xyz[1]

[[2.0,
  2.0,
  39561.421875,
  10639.2373046875,
  -359.8374328613281,
  1586.7744140625,
  0.9642342329025269,
  13.04029369354248,
  12.381147384643555,
  28.80851173400879,
  113.49459075927734,
  -0.36274588108062744,
  -0.31941908597946167],
 [2.0,
  2.0,
  28460.236328125,
  36306.12890625,
  -155.01947021484375,
  2657.050537109375,
  1.2606549263000488,
  7.641078472137451,
  6.961223125457764,
  22.10407066345215,
  118.20963287353516,
  -0.15021535754203796,
  -0.4441536068916321],
 [2.0,
  2.0,
  50980.203125,
  35669.13671875,
  -327.28656005859375,
  2318.23046875,
  0.9578739404678345,
  10.343084335327148,
  9.956559181213379,
  23.9139404296875,
  134.94076538085938,
  0.31081029772758484,
  -0.24404215812683105],
 [2.0,
  2.0,
  63062.4140625,
  38295.67578125,
  -448.3127136230469,
  1112.360107421875,
  0.9987825751304626,
  18.236055374145508,
  18.254316329956055,
  39.090354919433594,
  102.3427963256836,
  0.1910075694322586,
  0.16425830125808716],
 [2.0,
  2.0

In [83]:
_, matches = limited_matching(frame_wise_xyz[0], frame_wise_xyz[1], min_int=0, limited_x=[0, fov_size[0]],
                             limited_y=[0, fov_size[1]], border=450, print_res=True, tolerance=250,
                             tolerance_ax=np.inf)

FOV: x=[0, 84630] y=[0, 67665]
after FOV and border segmentation,truth: 5 ,preds: 5
Length of preds_list: 5
Frame numbers: 1
5 0
matches is empty!
Recall: 0.000
Precision: 0.000
Jaccard: 0.000
RMSE_lat: 0.000
RMSE_ax: 0.000
RMSE_vol: 0.000
Jaccard/RMSE: nan
Eff_lat: 0.000
Eff_ax: 0.000
Eff_3d: 0.000
FN: 5.0 FP: 0.0


  jor = 100 * jaccard / rmse_lat


In [119]:
import copy

In [190]:

def match_two_frames(truth_origin, pred_list_origin, min_int, limited_x=[0, 204800], limited_y=[0, 204800],
                     border=450, print_res=True, tolerance=250, tolerance_ax=np.inf):
    #print('{}{}{}{}'.format('FOV: x=', limited_x, ' y=', limited_y))

    matches = []

    truth = copy.deepcopy(truth_origin)
    pred_list = copy.deepcopy(pred_list_origin)

    truth_array = np.array(truth)
    pred_array = np.array(pred_list)

    # filter prediction and gt according to limited_x;y
    t_inds = np.where(
        (truth_array[:, 2] < limited_x[0]) | (truth_array[:, 2] > limited_x[1]) |
        (truth_array[:, 3] < limited_y[0]) | (truth_array[:, 3] > limited_y[1]))
    p_inds = np.where(
        (pred_array[:, 2] < limited_x[0]) | (pred_array[:, 2] > limited_x[1]) |
        (pred_array[:, 3] < limited_y[0]) | (pred_array[:, 3] > limited_y[1]))
    for t in reversed(t_inds[0]):
        del (truth[t])
    for p in reversed(p_inds[0]):
        del (pred_list[p])

    if len(pred_list) == 0:
        perf_dict = {'recall': np.nan, 'precision': np.nan, 'jaccard': np.nan, 'f_score': np.nan, 'rmse_lat': np.nan,
                     'rmse_ax': np.nan,
                     'rmse_x': np.nan, 'rmse_y': np.nan, 'jor': np.nan, 'eff_lat': np.nan, 'eff_ax': np.nan,
                     'eff_3d': np.nan}
        print('after FOV segmentation, pred_list is empty!')
        return perf_dict, matches

    # delete molecules of ground truth/estimation in the margin area
    if border:
        test_arr = np.array(truth)
        pred_arr = np.array(pred_list)

        t_inds = np.where(
            (test_arr[:, 2] < limited_x[0] + border) | (test_arr[:, 2] > (limited_x[1] - border)) |
            (test_arr[:, 3] < limited_y[0] + border) | (test_arr[:, 3] > (limited_y[1] - border)))
        p_inds = np.where(
            (pred_arr[:, 2] < limited_x[0] + border) | (pred_arr[:, 2] > (limited_x[1] - border)) |
            (pred_arr[:, 3] < limited_y[0] + border) | (pred_arr[:, 3] > (limited_y[1] - border)))
        for t in reversed(t_inds[0]):
            del (truth[t])
        for p in reversed(p_inds[0]):
            del (pred_list[p])

    if len(pred_list) == 0:
        perf_dict = {'recall': np.nan, 'precision': np.nan, 'jaccard': np.nan, 'f_score': np.nan, 'rmse_lat': np.nan,
                     'rmse_ax': np.nan,
                     'rmse_x': np.nan, 'rmse_y': np.nan, 'jor': np.nan, 'eff_lat': np.nan, 'eff_ax': np.nan,
                     'eff_3d': np.nan}
        print('after border, pred_list is empty!')
        return perf_dict, matches

    #print('{}{}{}{}{}'.format('after FOV and border segmentation,'
    #                          , 'truth: ', len(truth), ' ,preds: ', len(pred_list)))

    TP = 0
    FP = 0.0001
    FN = 0.0001
    MSE_lat = 0
    MSE_ax = 0
    MSE_vol = 0

    if len(pred_list):
        #print(f"Length of preds_list: {len(pred_list)}")
        tests = copy.deepcopy(truth)  # gt in each frame
        preds = copy.deepcopy(pred_list)  # prediction in each frame

        #if len(truth) > 0:  # after border filtering and area segmentation, truth could be empty
        #    while truth[0][1] == i:
        #        tests.append(truth.pop(0))  # put all gt in the tests
        #        if len(truth) < 1:
        #            break
        #if len(pred_list) > 0:
        #    while pred_list[0][1] == i:
        #        preds.append(pred_list.pop(0))  # put all predictions in the preds
        #        if len(pred_list) < 1:
        #            break
        #print(len(tests), len(preds))
        # if preds is empty, it means no detection on the frame, all tests are FN
        if len(preds) == 0:
            FN += len(tests)
            # no need to calculate metric
        # if the gt of this frame is empty, all preds on this frame are FP
        if len(tests) == 0:
            FP += len(preds)
            # no need to calculate metric
        # calculate the Euclidean distance between all gt and preds, get a matrix [number of gt, number of preds]
        dist_arr = cdist(np.array(tests)[:, 2:4], np.array(preds)[:, 2:4])
        ax_arr = cdist(np.array(tests)[:, 4:5], np.array(preds)[:, 4:5])
        tot_arr = np.sqrt(dist_arr ** 2 + ax_arr ** 2)
        if tolerance_ax == np.inf:
            tot_arr = dist_arr

        match_tests = copy.deepcopy(tests)
        match_preds = copy.deepcopy(preds)
        #print(dist_arr)
        if dist_arr.size > 0:
            while dist_arr.min() < tolerance:
                r, c = np.where(tot_arr == tot_arr.min())  # select the positions pair with shortest distance
                r = r[0]
                c = c[0]
                if ax_arr[r, c] < tolerance_ax and dist_arr[r, c] < tolerance:  # compare the distance and tolerance
                    if match_tests[r][5] > min_int:  # photons should be larger than min_int

                        MSE_lat += dist_arr[r, c] ** 2
                        MSE_ax += ax_arr[r, c] ** 2
                        MSE_vol += dist_arr[r, c] ** 2 + ax_arr[r, c] ** 2
                        TP += 1
                        matches.append([match_tests[r][2], match_tests[r][3], match_tests[r][4], match_tests[r][5],
                                        match_preds[c][2], match_preds[c][3], match_preds[c][4], match_preds[c][5],
                                        match_preds[c][7], match_preds[c][8], match_preds[c][9],
                                        match_preds[c][10]])

                    dist_arr[r, :] = np.inf
                    dist_arr[:, c] = np.inf
                    tot_arr[r, :] = np.inf
                    tot_arr[:, c] = np.inf

                    tests[r][5] = -100  # photon cannot be negative, work as a flag
                    preds.pop()

                dist_arr[r, c] = np.inf
                tot_arr[r, c] = np.inf
                #print("matched one")

        for i in reversed(range(len(tests))):
            if tests[i][5] < min_int:  # delete matched gt
                del (tests[i])

        FP += len(preds)  # all remaining preds are FP
        FN += len(tests)  # all remaining gt are FN
    else:
        print('after border and FOV segmentation, pred list is empty!')

    precision = TP / (TP + FP)
    recall = TP / (TP + FN)
    jaccard = TP / (TP + FP + FN)
    rmse_lat = np.sqrt(MSE_lat / (TP + 0.00001))
    rmse_ax = np.sqrt(MSE_ax / (TP + 0.00001))
    rmse_vol = np.sqrt(MSE_vol / (TP + 0.00001))
    jor = 100 * jaccard / rmse_lat

    eff_lat = 100 - np.sqrt((100 - 100 * jaccard) ** 2 + 1 ** 2 * rmse_lat ** 2)
    eff_ax = 100 - np.sqrt((100 - 100 * jaccard) ** 2 + 0.5 ** 2 * rmse_ax ** 2)
    eff_3d = (eff_lat + eff_ax) / 2

    matches = np.array(matches)
    #print("Number of matches: ", matches.shape)
    rmse_x = np.nan
    rmse_y = np.nan
    rmse_z = np.nan
    rmse_i = np.nan
    if len(matches):
        rmse_x = np.sqrt(((matches[:, 0] - matches[:, 4]) ** 2).mean())
        rmse_y = np.sqrt(((matches[:, 1] - matches[:, 5]) ** 2).mean())
        rmse_z = np.sqrt(((matches[:, 2] - matches[:, 6]) ** 2).mean())
        rmse_i = np.sqrt(((matches[:, 3] - matches[:, 7]) ** 2).mean())
    else:
        print('matches is empty!')

    if print_res:
        print('{}{:0.3f}'.format('Recall: ', recall))
        print('{}{:0.3f}'.format('Precision: ', precision))
        print('{}{:0.3f}'.format('Jaccard: ', 100 * jaccard))
        print('{}{:0.3f}'.format('RMSE_lat: ', rmse_lat))
        print('{}{:0.3f}'.format('RMSE_ax: ', rmse_ax))
        print('{}{:0.3f}'.format('RMSE_vol: ', rmse_vol))
        print('{}{:0.3f}'.format('Jaccard/RMSE: ', jor))
        print('{}{:0.3f}'.format('Eff_lat: ', eff_lat))
        print('{}{:0.3f}'.format('Eff_ax: ', eff_ax))
        print('{}{:0.3f}'.format('Eff_3d: ', eff_3d))
        print('FN: ' + str(np.round(FN)) + ' FP: ' + str(np.round(FP)))

    perf_dict = {'recall': recall, 'precision': precision, 'jaccard': jaccard, 'rmse_lat': rmse_lat,
                 'rmse_ax': rmse_ax, 'rmse_vol': rmse_vol, 'rmse_x': rmse_x, 'rmse_y': rmse_y,
                 'rmse_z': rmse_z, 'rmse_i': rmse_i, 'jor': jor, 'eff_lat': eff_lat, 'eff_ax': eff_ax,
                 'eff_3d': eff_3d}

    return perf_dict, matches



### Match first bead

In [201]:
bead_data = []

In [202]:
for beadno in range(len(frame_wise_xyz[0])):
    single_bead_data = {}
    single_bead_data['bead_no'] = beadno
    single_bead_data['x'] = [frame_wise_xyz[0][beadno][2]]
    single_bead_data['y'] = [frame_wise_xyz[0][beadno][3]]
    single_bead_data['z'] = [frame_wise_xyz[0][beadno][4]]
    single_bead_data['ph'] = [frame_wise_xyz[0][beadno][5]]
    single_bead_data['x_sigma'] = [frame_wise_xyz[0][beadno][7]]
    single_bead_data['y_sigma'] = [frame_wise_xyz[0][beadno][8]]
    single_bead_data['z_sigma'] = [frame_wise_xyz[0][beadno][9]]
    single_bead_data['ph_sigma'] = [frame_wise_xyz[0][beadno][10]]
    for i in range(1, 200):
        perf_dict, matches = match_two_frames([frame_wise_xyz[0][beadno]], frame_wise_xyz[i], min_int=0, limited_x=[0, fov_size[0]],
                                     limited_y=[0, fov_size[1]], border=450, print_res=False, tolerance=450,
                                     tolerance_ax=np.inf)
        if len(matches) == 1:
            single_bead_data['x'].append(matches[0, 4])
            single_bead_data['y'].append(matches[0, 5])
            single_bead_data['z'].append(matches[0, 6])
            single_bead_data['ph'].append(matches[0, 7])
            single_bead_data['x_sigma'].append(matches[0, 8])
            single_bead_data['y_sigma'].append(matches[0, 9])
            single_bead_data['z_sigma'].append(matches[0, 10])
            single_bead_data['ph_sigma'].append(matches[0, 11])
            #print(f"First bead x: {matches[:, 0]} -- matched: {matches[:, 4]}")
            #print(f"First bead y: {matches[:, 1]} -- matched: {matches[:, 5]}")
            #print(f"First bead z: {matches[:, 2]} -- matched: {matches[:, 6]}")
            #print(f"First bead ph: {matches[:, 3]} -- matched: {matches[:, 7]}")
            #print("---------")
        else:
            print(f"Skipping frame {i} --- :(")
    bead_data.append(single_bead_data)

matches is empty!
Skipping frame 50 --- :(
matches is empty!
Skipping frame 79 --- :(
matches is empty!
Skipping frame 100 --- :(
matches is empty!
Skipping frame 109 --- :(
matches is empty!
Skipping frame 140 --- :(
matches is empty!
Skipping frame 151 --- :(
matches is empty!
Skipping frame 166 --- :(
matches is empty!
Skipping frame 170 --- :(
matches is empty!
Skipping frame 191 --- :(
matches is empty!
Skipping frame 49 --- :(
matches is empty!
Skipping frame 2 --- :(
matches is empty!
Skipping frame 8 --- :(
matches is empty!
Skipping frame 13 --- :(
matches is empty!
Skipping frame 36 --- :(
matches is empty!
Skipping frame 37 --- :(
matches is empty!
Skipping frame 42 --- :(
matches is empty!
Skipping frame 46 --- :(
matches is empty!
Skipping frame 49 --- :(
matches is empty!
Skipping frame 52 --- :(
matches is empty!
Skipping frame 56 --- :(
matches is empty!
Skipping frame 57 --- :(
matches is empty!
Skipping frame 62 --- :(
matches is empty!
Skipping frame 63 --- :(
matche

  jor = 100 * jaccard / rmse_lat


In [207]:
bead_data[0]['z']

[-354.8688049316406,
 -359.8374328613281,
 -372.4197692871094,
 -264.88800048828125,
 -347.1101379394531,
 -328.5448303222656,
 -332.4219055175781,
 -368.6199645996094,
 -346.70733642578125,
 -365.08038330078125,
 -422.2132263183594,
 -344.4670715332031,
 -340.6261901855469,
 -367.3184509277344,
 -362.8914489746094,
 -350.0606384277344,
 -335.74578857421875,
 -277.60968017578125,
 -324.6253662109375,
 -284.5195007324219,
 -375.144775390625,
 -354.16845703125,
 -325.4465637207031,
 -409.65631103515625,
 -358.8760070800781,
 -324.3769226074219,
 -341.0977478027344,
 -321.09515380859375,
 -334.6322937011719,
 -396.97149658203125,
 -282.299560546875,
 -367.8581848144531,
 -313.337646484375,
 -316.9072265625,
 -386.2077331542969,
 -298.71026611328125,
 -313.7299499511719,
 -325.0224914550781,
 -304.2319641113281,
 -328.16265869140625,
 -330.14813232421875,
 -306.33026123046875,
 -302.2763671875,
 -332.07666015625,
 -368.2107238769531,
 -266.541748046875,
 -288.615966796875,
 -288.9794311523

In [222]:
len(bead_data[0]['x'])

191

In [229]:
def plot_bead_data(single_bead_data):
    fig, ax = plt.subplots(nrows=4, ncols=4)
    ax[0, 0].hist(single_bead_data['x'], bins=30, density=True)
    ax[0, 0].axvspan(min(single_bead_data['x']), max(single_bead_data['x']), color='green', alpha=0.1)
    plt.tight_layout()
    plt.show()

In [230]:
plot_bead_data(bead_data[0])

#### Predictions are a list of numbers where each element corresponds to a localizaiton in the following order


    1. counter of the molecule per tile.. not a global counter on the frame, need to fix this
    2. image number used to index into the number of the image in the prediction arrays.
    Eval image is of shape [1, 1041, 1302], so, the image number will always be 1, if you give one image at a time
    3. x position in nm where 0 is top left corner
    4. y position in nm where 0 is top left corner
    5. z position in nm where 0 is from the reference 0 nm in height
    6. photon counts 
    7. probability afer nms
    8. x_sigma in nm 
    9. y_sigma in nm
    10. z_sigma in nm
    11. photon_counts_sigma
    12. x offset
    13. y offset
