In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import pathlib
from pathlib import Path
from skimage.io import imread
import torch.nn.functional as F
import matplotlib as mpl
mpl.use('QtAgg')
%matplotlib qt

#### Hard-coded constants

In [2]:
# Wavelet parameters.
# NOTE(review): neither `waveletPlanes` nor `levels` is passed to compute_wavelet_planes
# in this notebook — confirm whether they should be, or remove them.
waveletPlanes = 2
levels = 3

# All input data lives in one directory; change DATA_DIR to run on another machine.
DATA_DIR = Path('/home/pk/Documents/waveletCode/data')
fluor_image_path = DATA_DIR / 'img_000000000.tiff'
seg_image_path = DATA_DIR / 'img_000000000_mask.tiff'

# Fluorescence image as float32 so the Anscombe/wavelet transforms operate on floats.
fluor_img = imread(fluor_image_path).astype('float32')
# Segmentation mask; binary_spot_mask treats pixels > 0 as foreground.
seg_image = imread(seg_image_path)

In [3]:
# Anscombe transform 2*sqrt(x + 3/8) of the fluorescence image, computed with torch.
# torch.from_numpy shares memory with fluor_img (no copy is made).
fluor_img_tensor = torch.from_numpy(fluor_img)
anscombe_trans_tensor = 2 * (fluor_img_tensor + 3 / 8).sqrt()

In [4]:
anscombe_trans_tensor = 2 * torch.sqrt(fluor_img_tensor + 3/8)

In [5]:
anscombe_trans_image = 2 * np.sqrt(fluor_img+3/8)

In [6]:
anscombe_trans_tensor.device

device(type='cpu')

In [7]:
from rtseg.dotdetection.detect import compute_wavelet_planes

In [8]:
w = compute_wavelet_planes(anscombe_trans_tensor, device='cpu')

In [9]:
w.shape

torch.Size([4, 1404, 3200])

In [10]:
%timeit w = compute_wavelet_planes(anscombe_trans_tensor, device='cpu')

1.67 s ± 20.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
w_gpu = compute_wavelet_planes(anscombe_trans_tensor, device='cuda:0')

In [12]:
%timeit w_gpu = compute_wavelet_planes(anscombe_trans_tensor, device='cuda:0')

3.05 ms ± 66.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [11]:
w.device, w_gpu.device

(device(type='cpu'), device(type='cuda', index=0))

In [12]:
torch.max(torch.abs((w - w_gpu.cpu()))) 

tensor(1.9073e-05)

In [13]:
def plot_wavelets(wavelet_planes):
    """Show every wavelet plane of a stack as grayscale images, two per row.

    Args:
        wavelet_planes: tensor where ``wavelet_planes[i]`` is a 2-D plane, e.g. the
            ``(n_planes, H, W)`` output of ``compute_wavelet_planes``. It may live on any
            device; each plane is copied to host memory for plotting.
    """
    n_planes = len(wavelet_planes)
    ncols = 2
    nrows = max(1, (n_planes + ncols - 1) // ncols)
    # squeeze=False keeps `ax` 2-D even for a single row, so ax[r, c] always works.
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    for i in range(nrows * ncols):
        axis = ax[i // ncols, i % ncols]
        if i >= n_planes:
            # Unused grid cell when the number of planes is odd.
            axis.axis('off')
            continue
        # detach().cpu() so CUDA tensors and tensors requiring grad can become numpy arrays.
        axis.imshow(wavelet_planes[i].detach().cpu().numpy(), cmap='gray')
        axis.set_title(f"i = {i}th plane")
    plt.show()

In [14]:
plot_wavelets(w_gpu.cpu())

In [15]:
plot_wavelets(w)

In [16]:
seg_image = torch.from_numpy(seg_image)

#### Constructing a binary mask of the spots from the wavelet planes

There are four levels (wavelet planes): all levels are filtered and the working-plane level is
picked as the filtered image. After this, the filtered image is cleaned using binary morphological
operations.

In [17]:
def binary_spot_mask(w, seg_mask, wavelet_plane_no=2, device='cpu', noise_threshold=3.0,
                          noise_level_division=0.7):
    """Threshold one wavelet plane into a binary spot mask.

    The noise level of the chosen plane is estimated from the pixels inside the
    segmentation mask as ``median(|w - mean(w)|) / noise_level_division``. A pixel is
    marked as a spot when its coefficient is positive and its magnitude is at least
    ``noise_threshold * noise_level``.

    Args:
        w: stack of wavelet planes indexed as ``w[plane, row, col]``.
        seg_mask: 2-D segmentation mask with the same spatial shape as one plane;
            pixels ``> 0`` are used to estimate the noise level.
        wavelet_plane_no: index of the plane to threshold.
        device: torch device the computation runs on (the result lives there too).
        noise_threshold: multiple of the noise level used as the threshold.
        noise_level_division: divisor applied to the median absolute deviation.

    Returns:
        Tensor with the plane's shape and dtype: 1 for spot pixels, 0 elsewhere.

    Raises:
        ValueError: if ``seg_mask`` has no pixels ``> 0`` (the noise level is undefined).
    """
    # Only the selected plane is needed, so only that plane is moved and reduced.
    plane = w[wavelet_plane_no].to(device)
    inside = seg_mask.to(device) > 0
    if not bool(inside.any()):
        raise ValueError("seg_mask has no foreground pixels (> 0); cannot estimate the noise level")

    values = plane[inside]
    noise_level = torch.median(torch.abs(values - values.mean())) / noise_level_division
    threshold = noise_threshold * noise_level

    # Keep positive coefficients whose magnitude reaches the threshold. This equals
    # "(|w| >= threshold) multiplied by w, then > 0".
    binary_mask = torch.logical_and(plane.abs() >= threshold, plane > 0).to(plane.dtype)

    # TODO: clean up the binary mask (clean / hbreak / vbreak kernels).
    return binary_mask

In [29]:
spot_mask = binary_spot_mask(w, seg_image, wavelet_plane_no=1, device='cuda:0', noise_threshold=3.0)

In [30]:
spot_mask.device

device(type='cuda', index=0)

In [31]:
plt.figure()
plt.imshow(spot_mask.cpu())
plt.show()

#### Clean-up kernels

We clean the mask using convolutions on the GPU, reproducing the binary morphological operations of MATLAB's `bwmorph`:

1. clean -> remove isolated pixels
2. hbreak -> remove H-connected pixels, i.e. a single pixel joining two horizontal lines (vbreak is the vertical equivalent)
3. spur -> remove spurious hanging pixels; we skip spur for now
4. clean -> clean again
5. thicken -> probably a bit more complicated; dilate the spots later

In [32]:
spot_mask.device

device(type='cuda', index=0)

In [33]:
spot_mask.shape

torch.Size([1404, 3200])

In [37]:
def clean_spots(image):
    """Remove isolated foreground pixels from a binary image (MATLAB-style "clean").

    A foreground pixel is isolated when its 3x3 neighbourhood, itself included, sums to 1.

    Args:
        image: floating-point tensor of shape ``(N, 1, H, W)`` holding 0/1 values,
            the layout ``F.conv2d`` needs for a single-channel kernel.

    Returns:
        A new tensor with isolated pixels set to 0; ``image`` itself is not modified.
    """
    clean_kernel = torch.ones((1, 1, 3, 3), dtype=image.dtype, device=image.device)
    conved = F.conv2d(image, clean_kernel, padding='same')
    # Work on a copy: callers pass views such as spot_mask[None, None, :], and writing
    # into them would silently change the source mask for later cells.
    cleaned = image.clone()
    # Background pixels with exactly one ON neighbour also sum to 1; zeroing them is a no-op.
    cleaned[conved == 1] = 0
    return cleaned

def clean_hbreak(image):
    """Remove H-connected pixels (MATLAB-style "hbreak").

    A foreground pixel is removed when its 3x3 neighbourhood is exactly
        1 1 1
        0 1 0
        1 1 1
    i.e. it is the single pixel linking a full top row and a full bottom row.

    Args:
        image: floating-point tensor of shape ``(N, 1, H, W)`` holding 0/1 values.

    Returns:
        A new tensor with the matching pixels set to 0; ``image`` itself is not modified.
    """
    opts = dict(dtype=image.dtype, device=image.device)
    # Cells that must be ON: top row, centre, bottom row (7 cells).
    kernel1 = torch.tensor([[1.0, 1.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0]], **opts)[None, None, :]
    # Cells that must be OFF: left and right of the centre.
    kernel2 = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 1.0], [0.0, 0.0, 0.0]], **opts)[None, None, :]
    conved1 = F.conv2d(image, kernel1, padding='same')
    conved2 = F.conv2d(image, kernel2, padding='same')
    # Left/right neighbours must be background (sum 0). Requiring them to be ON (sum 2)
    # would match a solid 3x3 block and erode the interior of every spot.
    result = torch.logical_and(conved1 == 7, conved2 == 0)
    cleaned = image.clone()
    cleaned[result] = 0
    return cleaned

def clean_vbreak(image):
    """Remove vertically H-connected pixels (MATLAB-style "vbreak"; the transpose of hbreak).

    A foreground pixel is removed when its 3x3 neighbourhood is exactly
        1 0 1
        1 1 1
        1 0 1
    i.e. it is the single pixel linking a full left column and a full right column.

    Args:
        image: floating-point tensor of shape ``(N, 1, H, W)`` holding 0/1 values.

    Returns:
        A new tensor with the matching pixels set to 0; ``image`` itself is not modified.
    """
    opts = dict(dtype=image.dtype, device=image.device)
    # Cells that must be ON: left column, centre, right column (7 cells).
    kernel1 = torch.tensor([[1.0, 0.0, 1.0], [1.0, 1.0, 1.0], [1.0, 0.0, 1.0]], **opts)[None, None, :]
    # Cells that must be OFF: above and below the centre.
    kernel2 = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]], **opts)[None, None, :]
    conved1 = F.conv2d(image, kernel1, padding='same')
    conved2 = F.conv2d(image, kernel2, padding='same')
    # Top/bottom neighbours must be background (sum 0). Requiring them to be ON (sum 2)
    # would match a solid 3x3 block and erode the interior of every spot.
    result = torch.logical_and(conved1 == 7, conved2 == 0)
    cleaned = image.clone()
    cleaned[result] = 0
    return cleaned

In [38]:
clean_spots(spot_mask[None,None,:])

tensor([[[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]]]], device='cuda:0')

In [36]:
plt.figure()
plt.imshow(clean_spots(spot_mask[None,None,:]).cpu().numpy()[0][0], cmap='gray')
plt.show()

In [None]:
def gen_image(shape=(100, 100), p=0.5, generator=None):
    """Generate a random binary test image with two leading singleton dims.

    Each pixel is independently 1 with probability ``p`` and 0 otherwise.

    Args:
        shape: spatial shape of the image (default ``(100, 100)``).
        p: probability that a pixel is 1 (default 0.5).
        generator: optional ``torch.Generator``; pass a seeded one for a reproducible image.

    Returns:
        Tensor of 0/1 values in the default float dtype, shape ``(1, 1, *shape)``.
    """
    prob_dist = torch.ones(shape) * float(p)
    # Bernoulli(probs).sample() is a thin wrapper over torch.bernoulli; calling it
    # directly allows passing a generator.
    return torch.bernoulli(prob_dist, generator=generator)[None, None, :]

In [80]:
image = gen_image()
plt.figure()
plt.imshow(image[0][0].numpy(), cmap='gray')
plt.show()

In [81]:
def clean_spots(image):
    """Remove isolated foreground pixels from a binary image (MATLAB-style "clean").

    A foreground pixel is isolated when its 3x3 neighbourhood, itself included, sums to 1.

    Args:
        image: floating-point tensor of shape ``(N, 1, H, W)`` holding 0/1 values,
            the layout ``F.conv2d`` needs for a single-channel kernel.

    Returns:
        A new tensor with isolated pixels set to 0; ``image`` itself is not modified.
    """
    clean_kernel = torch.ones((1, 1, 3, 3), dtype=image.dtype, device=image.device)
    conved = F.conv2d(image, clean_kernel, padding='same')
    # Work on a copy so the "before" image shown in later cells is not silently altered.
    cleaned = image.clone()
    # Background pixels with exactly one ON neighbour also sum to 1; zeroing them is a no-op.
    cleaned[conved == 1] = 0
    return cleaned

In [82]:
plt.figure()
plt.imshow(clean_spots(image)[0][0].numpy(), cmap='gray')
plt.show()

In [92]:
plt.figure()
plt.imshow(image[0][0].numpy(), cmap='gray')
plt.show()

#### We do hbreak and vbreak at the same time, why not?

In [94]:
def clean_hbreak(image):
    """Remove H-connected pixels (MATLAB-style "hbreak").

    A foreground pixel is removed when its 3x3 neighbourhood is exactly
        1 1 1
        0 1 0
        1 1 1
    i.e. it is the single pixel linking a full top row and a full bottom row.

    Args:
        image: floating-point tensor of shape ``(N, 1, H, W)`` holding 0/1 values,
            on any device (the kernels are created on ``image.device``).

    Returns:
        A new tensor with the matching pixels set to 0; ``image`` itself is not modified.
    """
    opts = dict(dtype=image.dtype, device=image.device)
    # Cells that must be ON: top row, centre, bottom row (7 cells).
    kernel1 = torch.tensor([[1.0, 1.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0]], **opts)[None, None, :]
    # Cells that must be OFF: left and right of the centre.
    kernel2 = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 1.0], [0.0, 0.0, 0.0]], **opts)[None, None, :]
    conved1 = F.conv2d(image, kernel1, padding='same')
    conved2 = F.conv2d(image, kernel2, padding='same')
    # Left/right neighbours must be background (sum 0). Requiring them to be ON (sum 2)
    # would match a solid 3x3 block and erode the interior of every spot.
    result = torch.logical_and(conved1 == 7, conved2 == 0)
    cleaned = image.clone()
    cleaned[result] = 0
    return cleaned

In [95]:
def clean_vbreak(image):
    """Remove vertically H-connected pixels (MATLAB-style "vbreak"; the transpose of hbreak).

    A foreground pixel is removed when its 3x3 neighbourhood is exactly
        1 0 1
        1 1 1
        1 0 1
    i.e. it is the single pixel linking a full left column and a full right column.

    Args:
        image: floating-point tensor of shape ``(N, 1, H, W)`` holding 0/1 values,
            on any device (the kernels are created on ``image.device``).

    Returns:
        A new tensor with the matching pixels set to 0; ``image`` itself is not modified.
    """
    opts = dict(dtype=image.dtype, device=image.device)
    # Cells that must be ON: left column, centre, right column (7 cells).
    kernel1 = torch.tensor([[1.0, 0.0, 1.0], [1.0, 1.0, 1.0], [1.0, 0.0, 1.0]], **opts)[None, None, :]
    # Cells that must be OFF: above and below the centre.
    kernel2 = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]], **opts)[None, None, :]
    conved1 = F.conv2d(image, kernel1, padding='same')
    conved2 = F.conv2d(image, kernel2, padding='same')
    # Top/bottom neighbours must be background (sum 0). Requiring them to be ON (sum 2)
    # would match a solid 3x3 block and erode the interior of every spot.
    result = torch.logical_and(conved1 == 7, conved2 == 0)
    cleaned = image.clone()
    cleaned[result] = 0
    return cleaned

In [39]:
plt.figure()
plt.imshow(clean_hbreak(image)[0][0].numpy(), cmap='gray')
plt.show()

NameError: name 'image' is not defined

In [97]:
plt.figure()
plt.imshow(clean_vbreak(image)[0][0].numpy(), cmap='gray')
plt.show()

### Finding spot centroids

In [123]:
from skimage.measure import regionprops, label, regionprops_table

In [124]:
def find_spot_centroids(spot_bin_mask, fluor_image, min_spot_area=0, max_axes_ratio=1.7):
    """Placeholder for spot-centroid extraction; not implemented yet and returns None.

    TODO: the following cells appear to prototype the intended steps (label the binary
    mask, run ``regionprops`` with the fluorescence image, filter regions by area and by
    the major/minor axis ratio). Move that logic in here once it settles.
    """
    return None

In [133]:
%timeit spot_stats = regionprops_table(label(bin_mask_cpu), fluor_img, properties=['centroid_weighted', 'axis_major_length', 'axis_minor_length', 'area', 'orientation', 'coords'])

356 ms ± 2.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [164]:
spot_stats = regionprops(label(bin_mask_cpu), fluor_img)

In [165]:
len(spot_stats)

1576

In [173]:
min_spot_area = 5
max_spot_area = 100
max_axes_ratio = 1.7

In [174]:
# (index into spot_stats, area) for every region whose area lies within
# [min_spot_area, max_spot_area], both bounds inclusive.
# Despite its name, `area` holds (index, area) pairs; the next cell selects spots by index.
area = [
    (idx, region['area'])
    for idx, region in enumerate(spot_stats)
    if min_spot_area <= region['area'] <= max_spot_area
]

In [175]:
spots_filtered = [spot_stats[spot_idx] for spot_idx, spot_area in area]

In [176]:
len(spots_filtered)

1420

In [177]:
axes_ratio = [spot['axis_major_length'] / spot['axis_minor_length'] for spot in spots_filtered]

In [183]:
spots_filtered[3]['centroid_weighted']

(218.58425921621495, 489.79624378225094)

In [185]:
# Indices (into spots_filtered) whose major/minor axis ratio exceeds max_axes_ratio.
# NOTE(review): the name suggests these elongated regions are treated as candidate
# overlapping spots — confirm.
overlap_indices = [
    position
    for position, ratio in enumerate(axes_ratio)
    if ratio > max_axes_ratio
]

In [187]:
len(overlap_indices)

357

In [188]:
overlap_indices

[0,
 6,
 7,
 13,
 18,
 21,
 22,
 23,
 25,
 26,
 29,
 30,
 36,
 38,
 42,
 46,
 57,
 67,
 70,
 71,
 76,
 80,
 88,
 91,
 101,
 103,
 110,
 119,
 121,
 129,
 132,
 134,
 137,
 139,
 142,
 144,
 148,
 149,
 150,
 151,
 159,
 164,
 168,
 169,
 178,
 184,
 186,
 189,
 193,
 194,
 198,
 199,
 200,
 205,
 210,
 213,
 215,
 217,
 220,
 222,
 224,
 225,
 226,
 231,
 232,
 236,
 238,
 240,
 241,
 245,
 248,
 249,
 250,
 254,
 255,
 258,
 259,
 266,
 267,
 276,
 277,
 286,
 291,
 294,
 298,
 301,
 304,
 310,
 318,
 319,
 322,
 325,
 326,
 328,
 329,
 331,
 333,
 334,
 353,
 362,
 367,
 369,
 370,
 372,
 374,
 377,
 380,
 385,
 394,
 395,
 400,
 406,
 413,
 417,
 422,
 426,
 431,
 434,
 440,
 441,
 443,
 445,
 448,
 449,
 452,
 460,
 469,
 470,
 472,
 481,
 483,
 499,
 505,
 510,
 527,
 530,
 532,
 541,
 548,
 552,
 554,
 555,
 559,
 564,
 567,
 573,
 576,
 577,
 579,
 582,
 585,
 587,
 590,
 591,
 594,
 596,
 608,
 610,
 618,
 621,
 622,
 625,
 631,
 633,
 639,
 640,
 648,
 653,
 655,
 656,
 659,
 

#### Working with examples

In [253]:
def sd_intensity(regionmask):
    """Extra ``regionprops`` property: the mean of the region's mask.

    NOTE(review): despite the name, this is not a standard deviation of intensities.
    For a boolean bounding-box mask (which is what regionprops appears to pass), the
    mean is the fraction of the box covered by the region. The mask shape is printed
    as a debugging aid.
    """
    print(f"Region mask shape: {regionmask.shape}")
    mean_value = np.mean(regionmask)
    return mean_value

In [260]:
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from skimage.draw import ellipse
from skimage.measure import label, regionprops, regionprops_table
from skimage.transform import rotate


image = np.zeros((600, 600))

rr, cc = ellipse(300, 350, 100, 220)
image[rr, cc] = 1

image = rotate(image, angle=-15, order=0)

rr, cc = ellipse(100, 100, 60, 50)
image[rr, cc] = 1

label_img = label(image)
regions = regionprops(label_img, extra_properties=(sd_intensity,))

In [255]:
regions = regionprops(label_img, extra_properties=(sd_intensity,))

In [257]:
regions[0].sd_intensity

Region mask shape: (119, 99)


0.7973007384772091

In [258]:
regions[1].sd_intensity

Region mask shape: (225, 429)


0.7155969955969956

In [251]:
sd_intensity

<function __main__.sd_intensity(regionmask)>

In [252]:
list(regions[0])

['area',
 'area_bbox',
 'area_convex',
 'area_filled',
 'axis_major_length',
 'axis_minor_length',
 'bbox',
 'centroid',
 'centroid_local',
 'coords',
 'eccentricity',
 'equivalent_diameter_area',
 'euler_number',
 'extent',
 'feret_diameter_max',
 'image',
 'image_convex',
 'image_filled',
 'inertia_tensor',
 'inertia_tensor_eigvals',
 'label',
 'moments',
 'moments_central',
 'moments_hu',
 'moments_normalized',
 'orientation',
 'perimeter',
 'perimeter_crofton',
 'slice',
 'solidity']

In [244]:
plt.figure()
plt.imshow(label_img)
plt.show()

In [245]:
regions[0].label

1

In [246]:
regions[1].label

2

In [None]:
# NOTE(review): this cell held an unfinished definition, `def sd_intensity(region`, which is a
# SyntaxError and stops "Restart & Run All". It is kept here as a comment until completed;
# a working sd_intensity(regionmask) is defined in the "Working with examples" section.

In [211]:
label_img

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [212]:
plt.figure()
plt.imshow(label_img)
plt.show()

In [201]:
from skimage.io import imsave

In [208]:
imsave(Path('/home/pk/Documents/waveletCode/data/test.tiff'), label_img.astype('uint16'),
           plugin='tifffile', check_contrast=False)

In [220]:
# Visualise each region in `regions`: centroid (green dot), half minor axis (black line),
# half major axis (red line) and bounding box (blue rectangle), drawn over the test image.
fig, ax = plt.subplots()
ax.imshow(image, cmap=plt.cm.gray)

for props in regions:
    # centroid is unpacked as (row, col), i.e. (y, x) in plot coordinates
    y0, x0 = props.centroid
    print(f"(y0, x0): ({y0}, {x0})")
    # orientation angle in radians (fed to math.cos / math.sin below)
    orientation = props.orientation
    print(f"Orientation: {orientation}")
    # end point of the half minor axis
    x1 = x0 + math.cos(orientation) * 0.5 * props.axis_minor_length
    y1 = y0 - math.sin(orientation) * 0.5 * props.axis_minor_length
    # end point of the half major axis
    x2 = x0 - math.sin(orientation) * 0.5 * props.axis_major_length
    y2 = y0 - math.cos(orientation) * 0.5 * props.axis_major_length

    ax.plot((x0, x1), (y0, y1), '-k', linewidth=2.5)
    ax.plot((x0, x2), (y0, y2), '-r', linewidth=2.5)
    ax.plot(x0, y0, '.g', markersize=15)

    # bbox is unpacked as (min_row, min_col, max_row, max_col); the rectangle is closed
    # by repeating the first corner
    minr, minc, maxr, maxc = props.bbox
    bx = (minc, maxc, maxc, minc, minc)
    by = (minr, minr, maxr, maxr, minr)
    ax.plot(bx, by, '-b', linewidth=2.5)

# The example image was created as 600x600; y-limits are reversed so row 0 is at the top.
ax.axis((0, 600, 600, 0))
plt.show()

(y0, x0): (100.0, 100.0)
Orientation: 0.0
(y0, x0): (313.0515975851635, 348.1505219116008)
Orientation: 1.3089215222380706


In [197]:
-1.308 * 180 / math.pi

-74.94287960311168

In [219]:
list(regions[0])

['area',
 'area_bbox',
 'area_convex',
 'area_filled',
 'axis_major_length',
 'axis_minor_length',
 'bbox',
 'centroid',
 'centroid_local',
 'coords',
 'eccentricity',
 'equivalent_diameter_area',
 'euler_number',
 'extent',
 'feret_diameter_max',
 'image',
 'image_convex',
 'image_filled',
 'inertia_tensor',
 'inertia_tensor_eigvals',
 'label',
 'moments',
 'moments_central',
 'moments_hu',
 'moments_normalized',
 'orientation',
 'perimeter',
 'perimeter_crofton',
 'slice',
 'solidity']

In [261]:
import cv2

In [262]:
cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(5,5))

array([[0, 0, 1, 0, 0],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [0, 0, 1, 0, 0]], dtype=uint8)