In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from skimage.io import imread
from rtseg.oldseg.transforms import UnetTestTransforms
from rtseg.oldseg.networks import model_dict as seg_model_dict
from rtseg.utils.param_io import load_params
from pathlib import Path
import torch
%matplotlib qt5

In [3]:
phase_path = Path('/home/pk/Documents/REALTIME/data/test/img_000000000.tiff')
fluor_path = Path('/home/pk/Documents/REALTIME/data/test/img_000000000_fluor.tiff')

In [4]:
phase_img = imread(phase_path)
fluor_img = imread(fluor_path)

In [5]:
fig, ax = plt.subplots(nrows=2, ncols=1)
ax[0].imshow(phase_img, cmap='gray')
ax[1].imshow(fluor_img, cmap='gray', vmin=0, vmax=500)
plt.show()

#### Segment phase

In [6]:
params_path = Path("/home/pk/Documents/rtseg/rtseg/resources/reference_params/reference_linux.yaml")
params = load_params(params_path, ref_type='expt')

In [7]:
segment_params = params.Segmentation

In [8]:
model = seg_model_dict[segment_params.architecture]
segment_model = model.parse(channels_by_scale=segment_params.model_params.channels_by_scale,
                                 num_outputs=segment_params.model_params.num_outputs,
                                 upsample_type=segment_params.model_params.upsample_type,
                                 feature_fusion_type=segment_params.model_params.feature_fusion_type).to(device='cuda:0')
segment_model_path = params.Segmentation.model_paths.both
segment_model.load_state_dict(torch.load(segment_model_path, map_location='cuda:0'))
segment_model.eval()

Unet(
  (down_layers): Sequential(
    (0): ConvBlock(
      (block): Sequential(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ConvBlock(
      (block): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_st

In [9]:
pre_segment_transforms = UnetTestTransforms()

In [10]:
raw_shape = phase_img.shape

In [11]:
seg_sample = pre_segment_transforms({'phase': phase_img.astype('float32'), 'raw_shape': raw_shape})

In [12]:
seg_sample

{'phase': tensor([[[ 0.1105,  0.1157,  0.1183,  ...,  0.1289,  0.1285,  0.1298],
          [ 0.1097,  0.1155,  0.1156,  ...,  0.1262,  0.1259,  0.1262],
          [ 0.1120,  0.1160,  0.1195,  ...,  0.1283,  0.1295,  0.1289],
          ...,
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000]]]),
 'raw_shape': (891, 2048)}

In [15]:
with torch.inference_mode():
    seg_pred = segment_model(seg_sample['phase'].unsqueeze(0).to('cuda:0')).sigmoid().cpu().numpy().squeeze(0)

In [16]:
seg_pred.shape

(2, 896, 2048)

In [17]:
cell_mask = seg_pred[0][:raw_shape[0], :raw_shape[1]] > 0.9

In [18]:
channel_mask = seg_pred[1][:raw_shape[0], :raw_shape[1]] > 0.8

In [19]:
fig, ax = plt.subplots(nrows=2, ncols=1)
ax[0].imshow(cell_mask)
ax[1].imshow(channel_mask)
plt.show()

In [20]:
cell_mask.shape, fluor_img.shape

((891, 2048), (891, 2048))

#### Detect dots

In [21]:
from rtseg.dotdetection.detect import compute_spot_binary_mask, compute_wavelet_planes
from skimage.measure import label, regionprops

In [23]:
binary_spot_mask = compute_spot_binary_mask(fluor_img, cell_mask, noise_threshold=3.0, wavelet_plane_no=1, device='cuda:0')

In [24]:
binary_spot_mask.shape

(891, 2048)

In [25]:
plt.figure()
plt.imshow(binary_spot_mask)
plt.show()

In [26]:
spot_stats = regionprops(label(binary_spot_mask), fluor_img)
min_spot_area = 0.0 # 4 is default
max_axes_ratio = 1.7

In [27]:
spot_filtered = [item for item in spot_stats if item.area > min_spot_area]

spot_axes_ratio = np.array([spot.axis_major_length / spot.axis_minor_length for spot in spot_filtered])
spot_areas = np.array([spot.area for spot in spot_filtered])
# indices in spot filtered based on area
dot_coords = [spot.centroid_weighted for spot in spot_filtered]

In [28]:
len(dot_coords)
dot_coords_np = np.array(dot_coords)

In [31]:
def plot_dots(fluor_image, dots):
    plt.figure()
    plt.imshow(fluor_image, cmap='gray')
    plt.plot(dots[:, 1], dots[:, 0], 'ro')
    plt.show()

In [32]:
plot_dots(fluor_img, dot_coords_np)

#### Calculate backbones

In [34]:
cell_mask_label = label(cell_mask)

In [35]:
plt.figure()
plt.imshow(cell_mask_label)
plt.show()

In [33]:
def plot_dots_on_mask(mask, dots):
    plt.figure()
    plt.imshow(mask)
    plt.plot(dots[:, 1], dots[:, 0], 'ro')
    plt.show()

In [36]:
plot_dots_on_mask(cell_mask_label, dot_coords_np)

In [37]:
from rtseg.cells.utils import regionprops_custom

In [39]:
props = regionprops_custom(cell_mask_label)

In [40]:
len(props)

741

In [41]:
dot_coords_np.shape

(1212, 2)

In [42]:
props[2].arc_length, props[2].poles, props[2].fit_coeff

(array([19.28411794]),
 array([[-0.5       ,  3.82288664],
        [ 6.5       , -4.67106966]]),
 array([-0.65794834,  2.73426772,  5.35450759]))

In [45]:
dot_coords_int = dot_coords_np.astype('int')

In [49]:
x, y = dot_coords_int[:, 0], dot_coords_int[:, 1]

In [51]:
cell_mask_label[x, y]

array([  0,   0,   0, ..., 721, 733,   0], dtype=int32)

In [62]:
dot_labels = cell_mask_label[x, y]

In [63]:
len(x), len(y), len(dot_labels)

(1212, 1212, 1212)

In [67]:
dots_inside_cells_idx = np.nonzero(dot_labels)[0]

In [100]:
dots_inside_cells_idx[20:40]

array([ 97,  99, 100, 101, 102, 103, 104, 105, 107, 108, 109, 110, 111,
       112, 113, 114, 115, 116, 117, 118])

In [89]:
cell_labels_for_dots = dot_labels[dots_inside_cells_idx]

In [101]:
cell_labels_for_dots[20:40]

array([ 99,  59, 101,  73,  84,  63,  79,  92,  98,  91,  81,  93, 111,
       107, 108, 106,  95, 109,  88, 115], dtype=int32)

In [105]:
plt.figure()
plt.imshow(cell_mask_label == 115)
plt.plot(dot_coords_np[118][1], dot_coords_np[118][0], 'ro')
plt.show()

In [106]:
plt.figure()
plt.imshow(props[114].image)
plt.show()

In [107]:
props[114].arc_length

array([9.69695333])

In [108]:
props[114].poles

array([[-0.5       , 22.2831658 ],
       [ 7.5       , 26.26673705]])

In [113]:
props[114].axis_major_length

59.02163246623325

In [115]:
%%timeit
np.rot90(cell_mask_label)

5.2 μs ± 14.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [116]:
np.rot90(cell_mask_label).shape

(2048, 891)

In [110]:
plt.figure()
plt.imshow(np.rot90(cell_mask_label))
plt.show()

In [98]:
dot_coords_np[:, 0].max()

716.5766723842196

In [99]:
dot_coords_np[:, 1].max()

2039.506221719457

In [95]:
cell_mask_label[

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int32)

#### Make forkplots