In [1]:
%reload_ext autoreload
%autoreload 2
%reload_ext notexbook
%texify

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import time
import sys
import pathlib
import pickle
from pathlib import Path
from skimage.io import imread
import seaborn as sns
sns.set_style('white')
sys.setrecursionlimit(10000)
%matplotlib qt5

In [3]:
import cellbgnet
import cellbgnet.utils
from cellbgnet.datasets import DataSimulator
from cellbgnet.utils.hardware import cpu, gpu
from cellbgnet.model import CellBGModel
from cellbgnet.utils.plot_funcs import plot_od, plot_train_record
from cellbgnet.analyze_eval import recognition, plot_full_img_predictions, assemble_full_img_predictions

In [4]:
model_path = Path('/mnt/sda1/SMLAT/training_runs/model_rotated_45_venus25nm_dec12_700range.pkl')

In [5]:
with open(model_path, 'rb') as f:
    dots_model = pickle.load(f)

In [6]:
dots_model.evaluation_params['eval_imgs'].shape

(30, 1041, 1302)

In [7]:
plot_train_record(dots_model)

In [8]:
eval_img = dots_model.evaluation_params['eval_imgs']
fov_size = [eval_img.shape[2] * 65, eval_img.shape[1] * 65]

In [9]:
eval_img.shape

(30, 1041, 1302)

In [10]:
plt.figure()
plt.imshow(eval_img[0], cmap='gray')
plt.colorbar()
plt.show()

In [13]:
preds_tmp, n_per_img, plot_data = recognition(model=dots_model, eval_imgs_all=eval_img,
                                             batch_size=16, use_tqdm=False,
                                             nms=True, candidate_threshold=0.05,
                                             nms_threshold=0.05, 
                                             pixel_nm=dots_model.data_generator.psf_params['pixel_size_xy'],
                                             plot_num=1,
                                             win_size=128,
                                             padding=True,
                                             start_field_pos=[0, 0],
                                             padded_background=dots_model.evaluation_params['padded_background'])

processing area:99/99, input field_xy:[1258 1301 1001 1040], use_coordconv:True, retain locs in area:[1278, 1301, 1021, 1040]


In [20]:
preds_tmp

[[1.0,
  1.0,
  28071.548828125,
  6910.64306640625,
  -580.4024658203125,
  2122.10009765625,
  1.0663213729858398,
  13.355852127075195,
  9.590226173400879,
  85.94347381591797,
  210.8159942626953,
  -0.13002897799015045,
  0.3175883889198303],
 [3.0,
  1.0,
  27930.359375,
  7111.51123046875,
  493.86138916015625,
  1982.6844482421875,
  0.5421614646911621,
  25.71227264404297,
  30.051565170288086,
  143.42201232910156,
  355.0440979003906,
  -0.30216529965400696,
  -0.5921372175216675],
 [4.0,
  1.0,
  32968.1796875,
  7495.14404296875,
  -67.52432250976562,
  1903.1248779296875,
  1.0937373638153076,
  12.749629020690918,
  11.122882843017578,
  57.60746383666992,
  164.44122314453125,
  0.2027347832918167,
  0.3099079728126526],
 [1.0,
  1.0,
  33730.6875,
  7066.62451171875,
  584.8718872070312,
  2222.175048828125,
  1.0424195528030396,
  11.358304023742676,
  10.661754608154297,
  50.04265213012695,
  166.85073852539062,
  -0.06637296825647354,
  -0.2827022075653076],
 [2.0

In [15]:
plot_full_img_predictions(dots_model, plot_infs=plot_data, eval_csv=None, plot_num=1, fov_size=fov_size, pixel_size=[65, 65])

In [16]:
img_infs = assemble_full_img_predictions(dots_model, plot_data)

In [17]:
img_infs.keys()

dict_keys(['Probs', 'XO', 'YO', 'ZO', 'Int', 'BG', 'XO_sig', 'YO_sig', 'ZO_sig', 'Int_sig', 'Probs_ps', 'XO_ps', 'YO_ps', 'ZO_ps', 'Samples_ps', 'raw_img', 'only_bg'])

In [18]:
img_infs['Probs']

array([[1.12535155e-07, 1.12535155e-07, 1.12535155e-07, ...,
        1.12535155e-07, 1.12535155e-07, 1.12535155e-07],
       [1.12535155e-07, 1.12535155e-07, 1.12535155e-07, ...,
        1.12535155e-07, 1.12535155e-07, 1.12535155e-07],
       [1.12535155e-07, 1.12535155e-07, 1.12535155e-07, ...,
        1.12535155e-07, 1.12535155e-07, 1.12535155e-07],
       ...,
       [1.12535155e-07, 1.12535155e-07, 1.12535155e-07, ...,
        1.12535155e-07, 1.12535155e-07, 1.12535155e-07],
       [1.12535155e-07, 1.12535155e-07, 1.12535155e-07, ...,
        1.12535155e-07, 1.12535155e-07, 1.12535155e-07],
       [1.12535155e-07, 1.12535155e-07, 1.12535155e-07, ...,
        1.12535155e-07, 1.12535155e-07, 1.57667114e-06]])

In [19]:
plt.figure()
plt.imshow(img_infs['Probs'])
plt.show()