In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from skimage.io import imread
from rtseg.oldseg.transforms import UnetTestTransforms
from rtseg.oldseg.networks import model_dict as seg_model_dict
from rtseg.utils.param_io import load_params
from rtseg.cells.utils import compute_projected_points, compute_arc_length
from pathlib import Path
import torch
%matplotlib qt5

In [3]:
phase_path = Path('/home/pk/Documents/rtclient/rtclient/resources/test_images/phase_dummy.tiff')
fluor_path = Path('/home/pk/Documents/rtclient/rtclient/resources/test_images/fluor_dummy.tiff')

In [4]:
phase_img = imread(phase_path)
fluor_img = imread(fluor_path)

In [5]:
fig, ax = plt.subplots(nrows=1, ncols=2)
ax[0].imshow(phase_img, cmap='gray')
ax[1].imshow(fluor_img, cmap='gray', vmin=0, vmax=500)
plt.show()

#### Segment phase

In [6]:
params_path = Path("/home/pk/Documents/rtseg/rtseg/resources/reference_params/reference_linux.yaml")
params = load_params(params_path, ref_type='expt')

In [7]:
segment_params = params.Segmentation

In [8]:
model = seg_model_dict[segment_params.architecture]
segment_model = model.parse(channels_by_scale=segment_params.model_params.channels_by_scale,
                                 num_outputs=segment_params.model_params.num_outputs,
                                 upsample_type=segment_params.model_params.upsample_type,
                                 feature_fusion_type=segment_params.model_params.feature_fusion_type).to(device='cuda:0')
segment_model_path = params.Segmentation.model_paths.both
segment_model.load_state_dict(torch.load(segment_model_path, map_location='cuda:0'))
segment_model.eval()

Unet(
  (down_layers): Sequential(
    (0): ConvBlock(
      (block): Sequential(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ConvBlock(
      (block): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_st

In [9]:
pre_segment_transforms = UnetTestTransforms()

In [10]:
raw_shape = phase_img.shape

In [11]:
seg_sample = pre_segment_transforms({'phase': phase_img.astype('float32'), 'raw_shape': raw_shape})

In [12]:
seg_sample

{'phase': tensor([[[ 0.0775,  0.0847,  0.0867,  ..., -0.5000, -0.5000, -0.5000],
          [ 0.0833,  0.0787,  0.0904,  ..., -0.5000, -0.5000, -0.5000],
          [ 0.0816,  0.0884,  0.0862,  ..., -0.5000, -0.5000, -0.5000],
          ...,
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000]]]),
 'raw_shape': (1500, 3036)}

In [13]:
with torch.inference_mode():
    seg_pred = segment_model(seg_sample['phase'].unsqueeze(0).to('cuda:0')).sigmoid().cpu().numpy().squeeze(0)

In [14]:
seg_pred.shape

(2, 1504, 3040)

In [22]:
cell_mask = seg_pred[0][:raw_shape[0], :raw_shape[1]] > 0.85

In [23]:
channel_mask = seg_pred[1][:raw_shape[0], :raw_shape[1]] > 0.8

In [24]:
fig, ax = plt.subplots(nrows=1, ncols=2)
ax[0].imshow(cell_mask)
ax[1].imshow(channel_mask)
plt.show()

In [25]:
cell_mask.shape, fluor_img.shape

((1500, 3036), (1500, 3036))

#### Detect dots

In [26]:
from rtseg.dotdetection.detect import compute_spot_binary_mask, compute_wavelet_planes
from skimage.measure import label, regionprops

In [27]:
cell_mask_rot = np.rot90(cell_mask).copy()
fluor_img_rot = np.rot90(fluor_img).copy()

In [105]:
plt.figure()
plt.imshow(label(cell_mask).T)
plt.show()

In [106]:
plt.figure()
plt.imshow(label(cell_mask))
plt.show()

In [32]:
binary_spot_mask = compute_spot_binary_mask(fluor_img_rot, cell_mask_rot, noise_threshold=3.0, wavelet_plane_no=2, device='cuda:0')

In [33]:
binary_spot_mask.shape

(3036, 1500)

In [34]:
plt.figure()
plt.imshow(binary_spot_mask)
plt.show()

In [35]:
spot_stats = regionprops(label(binary_spot_mask), fluor_img_rot)
min_spot_area = 0.0 # 4 is default
max_axes_ratio = 1.7

In [36]:
spot_filtered = [item for item in spot_stats if item.area > min_spot_area]

spot_axes_ratio = np.array([spot.axis_major_length / spot.axis_minor_length for spot in spot_filtered])
spot_areas = np.array([spot.area for spot in spot_filtered])
# indices in spot filtered based on area
dot_coords = [spot.centroid_weighted for spot in spot_filtered]

In [37]:
len(dot_coords)
dot_coords_np = np.array(dot_coords)

In [107]:
def plot_dots(fluor_image, dots):
    plt.figure()
    plt.imshow(fluor_image, cmap='gray')
    plt.plot(dots[:, 1], dots[:, 0], 'ro')
    plt.show()

In [108]:
plot_dots(fluor_img_rot, dot_coords_np)

In [99]:
plt.figure()
plt.matshow(fluor_img, cmap='gray')
plt.show()

In [100]:
plt.matshow(phase_img, cmap='gray')

<matplotlib.image.AxesImage at 0x79d7ee06f810>

In [102]:
plt.matshow(fluor_img.T, cmap='gray')

<matplotlib.image.AxesImage at 0x79d7ecfe4650>

In [98]:
plt.figure()
plt.imshow(fluor_img.T, cmap='gray')
plt.show()

In [85]:
dot_coords_np[:, 0].max()

3015.8693693693695

In [90]:
dots_on_raw_img = np.zeros_like(dot_coords_np)

In [65]:
raw_shape

(1500, 3036)

In [111]:
dots_on_raw_img[:, 0] = dot_coords_np[:, 1]
dots_on_raw_img[:, 1] = -dot_coords_np[:, 0] + raw_shape[1] - 1

In [112]:
plot_dots(fluor_img, dots_on_raw_img)

#### Calculate backbones

In [40]:
cell_mask_label = label(cell_mask_rot)

In [41]:
plt.figure()
plt.imshow(cell_mask_label)
plt.show()

In [42]:
def plot_dots_on_mask(mask, dots):
    plt.figure()
    plt.imshow(mask)
    plt.plot(dots[:, 1], dots[:, 0], 'ro')
    plt.show()

In [43]:
plot_dots_on_mask(cell_mask_label, dot_coords_np)

In [44]:
from rtseg.cells.utils import regionprops_custom

In [45]:
props = regionprops_custom(cell_mask_label)

In [46]:
len(props)

427

In [47]:
dot_coords_np.shape

(730, 2)

In [48]:
props[2].arc_length, props[2].poles, props[2].fit_coeff

(array([46.02826353]),
 array([[-0.5       ,  7.24097261],
        [45.5       ,  5.73080659]]),
 array([-4.63617945e-04, -1.19668886e-02,  7.23510507e+00]))

In [49]:
dot_coords_int = dot_coords_np.astype('int')

In [50]:
x, y = dot_coords_int[:, 0], dot_coords_int[:, 1]

In [51]:
cell_mask_label[x, y]

array([  0,   0,   2,   1,   0,   1,   4,   0,   6,   0,   1,   1,   7,
         2,   5,   7,   8,   8,   9,   0,   9,  10,  10,  11,   9,  11,
        13,  11,  12,  13,   0,   0,  18,  17,   0,  21,  22,  20,  22,
        25,  20,  21,  23,  26,   0,  23,  24,  24,  26,  29,  28,  30,
        30,   0,  29,  30,   0,   0,  34,  33,  33,   0,  31,   0,  32,
         0,  35,  37,  37,  38,   0,  41,  42,  42,  43,  42,  42,  42,
         0,  45,   0,   0,  46,   0,  48,  47,  47,  49,  46,  50,  50,
        48,  49,  51,  56,  55,  53,  55,  56,  58,  58,  53,  55,  58,
         0,  58,  61,  60,  63,  64,  61,  62,  62,  65,  63,  67,  64,
        67,  70,  68,  69,  70,  72,  71,  71,  75,  75,  76,  77,  78,
        77,  82,  78,  82,  81,  78,  81,  82,  81,  83,  82,  83,   0,
        87,  88,   0,   0,  86,  89,  88,  89,  87,  90,  95,  93,  96,
        95,  93,  97,   0,  97,  96,  97,  97,  98,  97,  97,  98,  98,
       102,   0, 103, 103, 104, 106, 105, 104, 108, 107, 107, 10

In [52]:
dot_labels = cell_mask_label[x, y]

In [53]:
len(x), len(y), len(dot_labels)

(730, 730, 730)

In [54]:
# each index will give the dot coordinate of from the dot_coords_np array
dots_inside_cells_idx = np.nonzero(dot_labels)[0]

In [55]:
# each element is the label number of the cell in this array
cell_labels_for_dots = dot_labels[dots_inside_cells_idx]

In [56]:
plt.figure()
plt.imshow(cell_mask_label == cell_labels_for_dots[i])
plt.plot(dot_y, dot_x, 'ro')
plt.show()

NameError: name 'i' is not defined

In [57]:
len(cell_labels_for_dots)

618

In [59]:

for i in range(len(cell_labels_for_dots)):
    cell_label = cell_labels_for_dots[i]
    dot_x, dot_y = dot_coords_np[dots_inside_cells_idx[i]]
    cell_prop = props[cell_label-1]
    fit_coeff = cell_prop.fit_coeff
    poles = cell_prop.poles
    img = cell_prop.image
    img_size = img.shape
    x_data = np.arange(-0.5, img_size[1]+0.5)
    y_data = fit_coeff[0] * x_data**2 + fit_coeff[1] * x_data + fit_coeff[2]
    local_x, local_y = dot_x - cell_prop.bbox[0], dot_y - cell_prop.bbox[1]
    
    
    projected_point, _ = compute_projected_points(fit_coeff, np.array([[local_y, local_x]]))
    #distance_to_pole = np.sqrt((projected_point[0, 0] - poles[0, 1])**2 + (projected_point[0, 1] - poles[0, 0])**2)
    distance_to_pole_along_arc = compute_arc_length(fit_coeff, poles[0, 0], projected_point[0,0])

In [60]:
plt.figure()
plt.imshow(img)
plt.plot(x_data, y_data, 'r--')
plt.plot(poles[:, 0], poles[:, 1], '*')
plt.plot(local_y, local_x, 'go')
plt.plot(projected_point[0, 0], projected_point[0, 1], 'b*')
plt.title(f"Distance to pole: {distance_to_pole_along_arc[0]}, arc_length: {cell_prop.arc_length[0]}")
plt.show()

In [121]:
projected_point

array([[35.32503436,  4.16863798]])

In [106]:
projected_point, _ = compute_projected_points(fit_coeff, np.array([[local_y, local_x]]))


In [107]:
projected_point

array([[12.25525112,  4.09683844]])

In [91]:
internal_x

3.1760833790455294

In [89]:
poles[0]

array([-0.5      ,  2.5322699])

In [86]:
internal_x, internal_y

(3.1760833790455294, 8.991771804717473)

In [75]:
dot_x, dot_y

(45.17608337904553, 491.9917718047175)

In [80]:
plt.figure()
plt.imshow(cell_mask_label == cell_label)
plt.plot(dot_y, dot_x, 'ro')
plt.show()

In [82]:
cell_prop.bbox

(42, 483, 51, 528)

In [59]:
plt.figure()
plt.imshow(cell_mask_label == 31)
plt.plot(dot_coords_np[43][1], dot_coords_np[43][0], 'ro')
plt.show()

In [61]:
props[30].arc_length

array([25.06184152])

In [62]:
props[30].poles

array([[-0.5       ,  5.28563777],
       [24.5       ,  3.6100112 ]])

In [64]:
props[30].axis_major_length

26.441570332452198

In [65]:
props[30].poles

array([[-0.5       ,  5.28563777],
       [24.5       ,  3.6100112 ]])

#### Make forkplots