In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from skimage.io import imread
from rtseg.oldseg.transforms import UnetTestTransforms
from rtseg.oldseg.networks import model_dict as seg_model_dict
from rtseg.utils.param_io import load_params
from rtseg.cells.utils import compute_projected_points, compute_arc_length
from pathlib import Path
import torch
%matplotlib qt5

In [3]:
phase_path = Path('/home/pk/Documents/rtclient/rtclient/resources/test_images/phase_dummy.tiff')
fluor_path = Path('/home/pk/Documents/rtclient/rtclient/resources/test_images/fluor_dummy.tiff')

In [4]:
phase_img = imread(phase_path)
fluor_img = imread(fluor_path)

In [5]:
fig, ax = plt.subplots(nrows=1, ncols=2)
ax[0].imshow(phase_img, cmap='gray')
ax[1].imshow(fluor_img, cmap='gray', vmin=0, vmax=500)
plt.show()

#### Segment phase

In [6]:
params_path = Path("/home/pk/Documents/rtseg/rtseg/resources/reference_params/reference_linux.yaml")
params = load_params(params_path, ref_type='expt')

In [7]:
segment_params = params.Segmentation

In [8]:
model = seg_model_dict[segment_params.architecture]
segment_model = model.parse(channels_by_scale=segment_params.model_params.channels_by_scale,
                                 num_outputs=segment_params.model_params.num_outputs,
                                 upsample_type=segment_params.model_params.upsample_type,
                                 feature_fusion_type=segment_params.model_params.feature_fusion_type).to(device='cuda:0')
segment_model_path = params.Segmentation.model_paths.both
segment_model.load_state_dict(torch.load(segment_model_path, map_location='cuda:0'))
segment_model.eval()

Unet(
  (down_layers): Sequential(
    (0): ConvBlock(
      (block): Sequential(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ConvBlock(
      (block): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_st

In [9]:
pre_segment_transforms = UnetTestTransforms()

In [10]:
raw_shape = phase_img.shape

In [11]:
seg_sample = pre_segment_transforms({'phase': phase_img.astype('float32'), 'raw_shape': raw_shape})

In [12]:
seg_sample

{'phase': tensor([[[ 0.0775,  0.0847,  0.0867,  ..., -0.5000, -0.5000, -0.5000],
          [ 0.0833,  0.0787,  0.0904,  ..., -0.5000, -0.5000, -0.5000],
          [ 0.0816,  0.0884,  0.0862,  ..., -0.5000, -0.5000, -0.5000],
          ...,
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000]]]),
 'raw_shape': (1500, 3036)}

In [13]:
with torch.inference_mode():
    seg_pred = segment_model(seg_sample['phase'].unsqueeze(0).to('cuda:0')).sigmoid().cpu().numpy().squeeze(0)

In [14]:
seg_pred.shape

(2, 1504, 3040)

In [15]:
cell_mask = seg_pred[0][:raw_shape[0], :raw_shape[1]] > 0.85

In [16]:
channel_mask = seg_pred[1][:raw_shape[0], :raw_shape[1]] > 0.8

In [17]:
fig, ax = plt.subplots(nrows=1, ncols=2)
ax[0].imshow(cell_mask)
ax[1].imshow(channel_mask)
plt.show()

In [18]:
cell_mask.shape, fluor_img.shape

((1500, 3036), (1500, 3036))

#### Detect dots

In [19]:
from rtseg.dotdetection.detect import compute_spot_binary_mask, compute_wavelet_planes
from skimage.measure import label, regionprops

In [20]:
cell_mask_label = label(cell_mask)

In [21]:
cell_mask_rot = np.rot90(cell_mask_label).copy()
fluor_img_rot = np.rot90(fluor_img).copy()

In [22]:
binary_spot_mask = compute_spot_binary_mask(fluor_img_rot, cell_mask_rot, noise_threshold=3.0, wavelet_plane_no=2, device='cuda:0')

In [23]:
binary_spot_mask.shape

(3036, 1500)

In [24]:
plt.figure()
plt.imshow(binary_spot_mask)
plt.show()

In [25]:
spot_stats = regionprops(label(binary_spot_mask), fluor_img_rot)
min_spot_area = 0.0 # 4 is default
max_axes_ratio = 1.7

In [26]:
spot_filtered = [item for item in spot_stats if item.area > min_spot_area]

spot_axes_ratio = np.array([spot.axis_major_length / spot.axis_minor_length for spot in spot_filtered])
spot_areas = np.array([spot.area for spot in spot_filtered])
# indices in spot filtered based on area
dot_coords = [spot.centroid_weighted for spot in spot_filtered]

In [27]:
len(dot_coords)
dot_coords_np = np.array(dot_coords)

In [28]:
def plot_dots(fluor_image, dots):
    plt.figure()
    plt.imshow(fluor_image, cmap='gray')
    plt.plot(dots[:, 1], dots[:, 0], 'ro')
    plt.show()

In [29]:
plot_dots(fluor_img_rot, dot_coords_np)

In [30]:
dot_coords_np[:, 0].max()

3015.8693693693695

In [31]:
dots_on_raw_img = np.zeros_like(dot_coords_np)

In [32]:
raw_shape

(1500, 3036)

In [33]:
dots_on_raw_img[:, 0] = dot_coords_np[:, 1]
dots_on_raw_img[:, 1] = -dot_coords_np[:, 0] + raw_shape[1] - 1

In [34]:
plot_dots(fluor_img, dots_on_raw_img)

#### Calculate backbones

In [35]:
#cell_mask_label = label(cell_mask_rot)
cell_mask_label = np.rot90(cell_mask_label)

In [36]:
plt.figure()
plt.imshow(cell_mask_label)
plt.show()

In [37]:
def plot_dots_on_mask(mask, dots):
    plt.figure()
    plt.imshow(mask)
    plt.plot(dots[:, 1], dots[:, 0], 'ro')
    plt.show()

In [38]:
plot_dots_on_mask(cell_mask_label, dot_coords_np)

In [39]:
from rtseg.cells.utils import regionprops_custom

In [40]:
props = regionprops_custom(cell_mask_label)

In [41]:
len(props)

427

In [42]:
dot_coords_np.shape

(730, 2)

In [44]:
dot_coords_int = dot_coords_np.astype('int')

In [45]:
x, y = dot_coords_int[:, 0], dot_coords_int[:, 1]

In [46]:
cell_mask_label[x, y]

array([  0,   0, 133, 102,   0, 102, 170,   0, 210,   0, 102, 102, 244,
       133, 191, 244, 276, 276, 313,   0, 313, 345, 345, 369, 313, 369,
       410, 369, 392, 410,   0,   0, 138, 121,   0, 203, 235, 179, 235,
       317, 179, 203, 259, 342,   0, 259, 290, 290, 342, 387, 366, 404,
       404,   0, 387, 404,   0,   0, 134, 158, 158,   0,  92,   0, 113,
         0, 183, 231, 231, 262,   0, 310, 341, 341, 326, 341, 341, 341,
         0, 417,   0,   0, 100,   0, 163, 131, 131, 195, 100, 224, 224,
       163, 195, 249, 322, 299, 275, 299, 322, 365, 365, 275, 299, 365,
         0, 365, 187, 169, 241, 271, 187, 212, 212, 308, 241, 343, 271,
       343, 389, 361, 375, 389, 412, 402, 402, 117, 117, 161, 221, 265,
       221, 349, 265, 349, 325, 265, 325, 349, 325, 376, 349, 376,   0,
       111, 143,   0,   0,  83, 166, 143, 166, 111, 218, 285, 237, 304,
       285, 237, 324,   0, 324, 304, 324, 324, 397, 324, 324, 397, 397,
        88,   0, 106, 106, 142, 200, 171, 142, 247, 226, 226, 26

In [47]:
dot_labels = cell_mask_label[x, y]

In [48]:
unique_cell_labels, dot_counts = np.unique(dot_labels, return_counts=True, return_index=False)

In [49]:
len(x), len(y), len(dot_labels)

(730, 730, 730)

In [50]:
dot_idx = np.where(dot_labels == 42)[0]

In [101]:

data_to_save = []
for i, single_cell_label in enumerate(unique_cell_labels, 0):
    if single_cell_label != 0:
        dot_idxs = np.where(dot_labels == single_cell_label)[0]
        dots_per_cell = dot_counts[i]
        # now each of the dot idx will give a set of dots associated with 
        # cell label
        cell_prop = props[single_cell_label-1]
        fit_coeff = cell_prop.fit_coeff
        poles = cell_prop.poles
        img = cell_prop.image
        img_size = img.shape
        bbox = cell_prop.bbox
        arc_length = cell_prop.arc_length[0]
        x_data = np.arange(-0.5, img_size[1]+0.5)
        y_data = fit_coeff[0] * x_data**2 + fit_coeff[1] * x_data + fit_coeff[2]
        # for all dots inside the cell, compute internal coordinates

        #plt.figure()
        #plt.imshow(img)
        #plt.plot(x_data, y_data, 'r--')
        #plt.plot(poles[:, 0], poles[:, 1], '*')
        
        
        for dot_idx in dot_idxs:
            dot_x, dot_y = dot_coords_np[dot_idx]
            local_x, local_y = dot_x - bbox[0], dot_y - bbox[1]
            projected_point, internal_y = compute_projected_points(fit_coeff, np.array([[local_y, local_x]]))
            distance_to_pole_along_arc = compute_arc_length(fit_coeff, poles[0, 0], projected_point[0, 0])

            #plt.plot([local_y, projected_point[0, 0]], [local_x, projected_point[0, 1]], 'b--')
            #plt.plot(local_y, local_x, 'go')
            #plt.plot(projected_point[0, 0], projected_point[0, 1], 'b*')
            #print(distance_to_pole_along_arc[0], arc_length)
            dot_datapoint = {'position': 0,
                             'time': 0,
                             'trap': None,
                             'cell_label': single_cell_label,
                             'area': cell_prop.area,
                             'length': arc_length,
                             'normalization_counts': dots_per_cell,
                             'internal_coord': (distance_to_pole_along_arc[0], internal_y[0]),
                             'normalized_internal_x': distance_to_pole_along_arc[0]/arc_length,
                             'bbox': bbox,
                             'global_coords': (dot_x, dot_y),
                             'local_coords': (local_x, local_y),
                            }
            data_to_save.append(dot_datapoint)
            #plt.title(f"Distance to pole: {distance_to_pole_along_arc[0]}, arc_length: {cell_prop.arc_length[0]}")
        #plt.show()


In [102]:
len(data_to_save)

618

In [103]:
data_to_save[2]

{'position': 0,
 'time': 0,
 'trap': None,
 'cell_label': 41,
 'area': 585.0,
 'length': 55.741157281982076,
 'normalization_counts': 2,
 'internal_coord': (25.340510464455427, 1.4915226399727761),
 'normalized_internal_x': 0.45461041176923256,
 'bbox': (2082, 228, 2104, 286),
 'global_coords': (2090.4628331419594, 254.0679579379948),
 'local_coords': (8.462833141959436, 26.067957937994805)}

In [104]:
import pandas as pd

In [105]:
df = pd.DataFrame(data_to_save)

In [106]:
df

Unnamed: 0,position,time,trap,cell_label,area,length,normalization_counts,internal_coord,normalized_internal_x,bbox,global_coords,local_coords
0,0,0,,40,364.0,28.951477,1,"(8.011039603279928, -3.454163227423582)",0.276706,"(2490, 226, 2511, 258)","(2503.938373853415, 237.00935078813785)","(13.938373853415214, 11.009350788137851)"
1,0,0,,41,585.0,55.741157,2,"(48.411024534828144, 4.679151480968298)",0.868497,"(2082, 228, 2104, 286)","(2083.6719550757825, 277.44032820848327)","(1.6719550757825345, 49.440328208483265)"
2,0,0,,41,585.0,55.741157,2,"(25.340510464455427, 1.4915226399727761)",0.454610,"(2082, 228, 2104, 286)","(2090.4628331419594, 254.0679579379948)","(8.462833141959436, 26.067957937994805)"
3,0,0,,42,2081.0,149.029251,6,"(32.70655568593311, 6.623083151824417)",0.219464,"(1220, 228, 1259, 405)","(1227.7887299371946, 287.36479413817165)","(7.788729937194603, 59.364794138171646)"
4,0,0,,42,2081.0,149.029251,6,"(129.43659064821478, 3.541787385039819)",0.868531,"(1220, 228, 1259, 405)","(1232.4931852599698, 383.934628975265)","(12.493185259969778, 155.934628975265)"
...,...,...,...,...,...,...,...,...,...,...,...,...
613,0,0,,410,914.0,76.301727,2,"(11.258338367911813, -0.5400395566602205)",0.147550,"(125, 1184, 147, 1264)","(131.40410535427228, 1199.7460121182144)","(6.404105354272275, 15.746012118214367)"
614,0,0,,410,914.0,76.301727,2,"(62.83648493084219, -0.8507057531632514)",0.823526,"(125, 1184, 147, 1264)","(137.9744318181818, 1250.4667446524063)","(12.974431818181813, 66.4667446524063)"
615,0,0,,412,582.0,47.668006,1,"(25.011671640703934, 0.1284026062246491)",0.524706,"(535, 1198, 552, 1247)","(541.90560426992, 1222.9154657516838)","(6.905604269919991, 24.91546575168377)"
616,0,0,,413,369.0,37.364992,1,"(20.873044352579818, 1.0844074711967786)",0.558626,"(944, 1204, 963, 1241)","(953.1993706264751, 1224.1914206451079)","(9.199370626475115, 20.191420645107883)"


In [83]:
df2 = pd.DataFrame(data_to_save)
df3 = pd.concat([df, df2])

In [92]:
# each index will give the dot coordinate of from the dot_coords_np array
dots_inside_cells_idx = np.nonzero(dot_labels)[0]

In [93]:
# each element is the label number of the cell in this array
cell_labels_for_dots = dot_labels[dots_inside_cells_idx]

In [94]:
i = 41
dot_x, dot_y = dot_coords_np[dots_inside_cells_idx[i]]
plt.figure()
plt.imshow(cell_mask_label == cell_labels_for_dots[i])
plt.plot(dot_y, dot_x, 'ro')
plt.show()

In [57]:
len(cell_labels_for_dots)

618

In [59]:

for i in range(len(cell_labels_for_dots)):
    cell_label = cell_labels_for_dots[i]
    dot_x, dot_y = dot_coords_np[dots_inside_cells_idx[i]]
    cell_prop = props[cell_label-1]
    fit_coeff = cell_prop.fit_coeff
    poles = cell_prop.poles
    img = cell_prop.image
    img_size = img.shape
    x_data = np.arange(-0.5, img_size[1]+0.5)
    y_data = fit_coeff[0] * x_data**2 + fit_coeff[1] * x_data + fit_coeff[2]
    local_x, local_y = dot_x - cell_prop.bbox[0], dot_y - cell_prop.bbox[1]
    
    
    projected_point, _ = compute_projected_points(fit_coeff, np.array([[local_y, local_x]]))
    #distance_to_pole = np.sqrt((projected_point[0, 0] - poles[0, 1])**2 + (projected_point[0, 1] - poles[0, 0])**2)
    distance_to_pole_along_arc = compute_arc_length(fit_coeff, poles[0, 0], projected_point[0,0])

In [60]:
plt.figure()
plt.imshow(img)
plt.plot(x_data, y_data, 'r--')
plt.plot(poles[:, 0], poles[:, 1], '*')
plt.plot(local_y, local_x, 'go')
plt.plot(projected_point[0, 0], projected_point[0, 1], 'b*')
plt.title(f"Distance to pole: {distance_to_pole_along_arc[0]}, arc_length: {cell_prop.arc_length[0]}")
plt.show()

In [121]:
projected_point

array([[35.32503436,  4.16863798]])

In [106]:
projected_point, _ = compute_projected_points(fit_coeff, np.array([[local_y, local_x]]))


In [107]:
projected_point

array([[12.25525112,  4.09683844]])

In [91]:
internal_x

3.1760833790455294

In [89]:
poles[0]

array([-0.5      ,  2.5322699])

In [86]:
internal_x, internal_y

(3.1760833790455294, 8.991771804717473)

In [75]:
dot_x, dot_y

(45.17608337904553, 491.9917718047175)

In [80]:
plt.figure()
plt.imshow(cell_mask_label == cell_label)
plt.plot(dot_y, dot_x, 'ro')
plt.show()

In [82]:
cell_prop.bbox

(42, 483, 51, 528)

In [59]:
plt.figure()
plt.imshow(cell_mask_label == 31)
plt.plot(dot_coords_np[43][1], dot_coords_np[43][0], 'ro')
plt.show()

In [61]:
props[30].arc_length

array([25.06184152])

In [62]:
props[30].poles

array([[-0.5       ,  5.28563777],
       [24.5       ,  3.6100112 ]])

In [64]:
props[30].axis_major_length

26.441570332452198

In [65]:
props[30].poles

array([[-0.5       ,  5.28563777],
       [24.5       ,  3.6100112 ]])

#### Make forkplots

In [116]:
import pandas as pd
import numpy as np

# Create a sample DataFrame
df = pd.DataFrame({
    'A': np.random.randn(1000),
    'B': np.random.randn(1000),
    'C': pd.date_range('20250101', periods=1000)
})


In [117]:
df

Unnamed: 0,A,B,C
0,0.445403,-0.001812,2025-01-01
1,0.938464,1.193299,2025-01-02
2,-0.640070,1.164751,2025-01-03
3,-0.136731,1.272015,2025-01-04
4,-0.105946,-0.035150,2025-01-05
...,...,...,...
995,0.690558,-0.606212,2027-09-23
996,1.021049,1.539551,2027-09-24
997,-0.314263,-0.134031,2027-09-25
998,-0.330640,-1.569597,2027-09-26
