## Example evaluation pipeline

In [3]:
from src.models.HCNet import HCNet
from src.utils import calculate_indexes, remove_edge_cells
import src.functional
from src.cell import cell

import skimage.io as io
import torch
import click
from tqdm import tqdm
from itertools import product
import matplotlib.pyplot as plt


default_path = '/media/DataStorage/ToAnalyze/' \
               'Jul 23 Control m2.lif - TileScan 1 Merged.tif'

print('Loading Model...')
model = torch.jit.script(HCNet(in_channels=3, out_channels=4, complexity=15)).cuda()
model.load_state_dict(torch.load('Dec_17_REALLY_GOOD.hcnet'))
model.eval()


# (Z, Y, X, C)
print('Loading Image...')
image_base = io.imread(default_path)
base_im_shape = image_base.shape



out_img = torch.zeros((1, image_base.shape[2], image_base.shape[1], image_base.shape[0]), dtype=torch.int16) # (1, X, Y, Z)

max_cell = 0

x_ind = calculate_indexes(25, 613, base_im_shape[2], base_im_shape[2])
y_ind = calculate_indexes(25, 613, base_im_shape[1], base_im_shape[1])
total = len(x_ind) * len(y_ind)

with torch.no_grad():
    for (x, y) in tqdm(product(x_ind, y_ind), total=total):

        image = torch.from_numpy(image_base[:, y[0]:y[1], x[0]:x[1], [0, 2, 3]] / 2 ** 16).unsqueeze(0)
        image = image.transpose(1, 3).transpose(0, -1).squeeze().unsqueeze(0).sub(0.5).div(0.5).cuda()

        if image.max() == -1:
            continue

        out = model(image.float().cuda(), 5)
        prob_map = torch.sigmoid(out[:, -1, ...]).unsqueeze(1)
        out = out[:, 0:3:1, ...]
        embed = src.functional.vector_to_embedding(out)
        centroids = src.functional.estimate_centroids(embed, 0.001, 40)


        if centroids.nelement() == 0:
            del image, out, centroids
            continue

        out = src.functional.embedding_to_probability(out, cent, torch.tensor([0.006]))
        print(cent.shape[1])

        
        value, out = out.max(1)
        out[prob_map[:,0,...] < 0.5]=0

        max_cell = out.max()
        out = out[..., 0:image.shape[-1]].cpu().to(out_img.dtype)

        # post processing
        out = remove_edge_cells(out)
        # out = remove_small_cells(out)

        out_img[:, x[0]:x[1]-1, y[0]:y[1]-1:, 0:image.shape[-1]-1][out != 0] = out[out != 0]

        del image, out, cent, value
        

torch.save(out_img, 'out_image.trch')
out_img = out_img.squeeze(0).int().numpy().transpose((2,1,0))
io.imsave('figures/big_img.tif', out_img)        

curve, percent, apex = src.functional.get_cochlear_length(out_img > 0, 10)
print(curve.shape)

plt.imshow(out_img[0,...].gt(0).sum(-1).gt(3))
plt.plot(curve[0,:], curve[1,:], 'r')
plt.show()

cells = []
for u in tqdm(torch.unique(out_img)):
    cells.append(cell(image_base.unsqueeze(0), (out_img == u).unsqueeze(0)))

torch.save(cells, 'cells.trch')


Loading Model...
Loading Image...


100%|██████████| 210/210 [07:30<00:00,  2.14s/it]


AttributeError: 'numpy.ndarray' object has no attribute 'gt'

In [4]:
out_img.max()

0

In [None]:
out_img =  torch.load ('out_image.trch')
print(out_img.max())
out_img = out_img.squeeze(0).float().numpy().transpose((2,1,0))
io.imsave('figures/big_img.tif', out_img)

In [None]:
from src.models.HCNet_legacy import HCNet
from src.utils import calculate_indexes, remove_edge_cells
import src.functional
import skimage.io as io
import torch
import click
from tqdm import tqdm
from itertools import product
import matplotlib.pyplot as plt

default_path = '/media/DataStorage/ToAnalyze/' \
               'Jul 23 Control m2.lif - TileScan 1 Merged.tif'
    
image_base = io.imread(default_path)
out_img = torch.load('out_image.trch')
curve, percent, apex = src.functional.get_cochlear_length(out_img > 0, .5)
print(curve.shape)

plt.figure(figsize = (15,15))
plt.imshow(image_base.max(0)[:,:,[1,2,3]] / 2**16)
plt.plot(curve[1,:], curve[0,:], 'r', linewidth = 5)
plt.savefig('curveature.png')
plt.show()

In [None]:
import matplotlib.pyplot as plt
plt.figure(figsize = (15,15))
plt.imshow(image_base.max(0)[:,:,[1,2,3]] / 2**16)
plt.plot(curve[1,:], curve[0,:], 'r', linewidth = 5)
plt.show()

In [None]:
image_base.max(0).max()

## Watershed Example

In [None]:
import skimage.segmentation
import skimage.io as io
import numpy as np
import matplotlib.pyplot as plt


image = io.imread('/media/DataStorage/Dropbox (Partners HealthCare)/HairCellInstance/data/'\
                  'test/C2-Jul-1-AAV2-PHP.B-CMV3-m2.lif---m2.og.tif')
image.shape
segments_watershed = skimage.segmentation.watershed(image[..., 3], markers=500, compactness=20)

plt.imshow(skimage.segmentation.mark_boundaries(segments_watershed[22,...], image[22,:,:,[0,2,3]]))
io.imsave('watershed.tif',segments_watershed)
plt.imshow(skimage.segmentation.mark_boundaries(segments_watershed[22,...], image[22,:,:,[0,2,3]]))
plt.show()

## Transforms Example

In [None]:
import src.dataloader
import src.transforms as t
import matplotlib.pyplot as plt
import torchvision.transforms
from torch.utils.data import DataLoader

print('Loading Images...')
transforms = torchvision.transforms.Compose([
    t.save_image('figures/0_base_image'),
    t.nul_crop(rate=1),
    t.save_image('figures/1_nul_crop'),
    t.random_crop(shape=(256, 256, 16)),
    t.save_image('figures/2_random_crop'),
    t.elastic_deformation(grid_shape=(3, 3, 2), scale=1.5),
    t.save_image('figures/3_elastic_deformation'),
    t.to_cuda(),
    t.random_h_flip(rate=1),
    t.save_image('figures/4_horizontal_flip'),
    t.random_v_flip(rate=1),
    t.save_image('figures/5_vertical_flip'),
    t.random_affine(shear=(-15, 15)),
    t.save_image('figures/6_random_affine'),
    t.adjust_brightness(range_brightness = (-0.2, 0.2)),
    t.save_image('figures/7_adjust_brightness'),
    t.adjust_gamma(),
    t.save_image('figures/8_adjust_gamma'),
    t.adjust_centroids(),
])
data = src.dataloader.dataset('/media/DataStorage/Dropbox (Partners HealthCare)/HairCellInstance/data/validate', transforms=transforms)
dl = DataLoader(data, batch_size=1, shuffle=False, num_workers=0)
print('Done')

_ = data[0]

In [None]:
import src.cell
import torch
import matplotlib.pyplot as plt
import numpy as np

cells = torch.load('cells.trch')

gfp = []
dapi = []
myo = []
actin = []
volume = []

for c in cells:
    gfp.append(c.gfp.item())
    dapi.append(c.dapi.item())
    myo.append(c.myo7a.item())
    actin.append(c.actin.item())
    volume.append(c.volume.item())
    
myo = torch.tensor(myo)
gfp = torch.tensor(gfp)
dapi = torch.tensor(dapi)
actin = torch.tensor(actin)
volume = torch.tensor(volume)

plt.hist(gfp[myo > 0.05].numpy(), bins=20,color = 'g')
ax = plt.gca()
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
plt.xlabel('GFP Cell Intensity')
plt.show()

plt.hist(dapi[myo > 0.05].numpy(), bins=20,color = 'b')
ax = plt.gca()
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
plt.xlabel('DAPI Cell Intensity')
plt.show()

plt.hist(myo[myo > 0.05].numpy(), bins=20, color = 'y')
ax = plt.gca()
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
plt.xlabel('DAPI Cell Intensity')
plt.show()

plt.hist(volume[myo > 0.05].numpy(), bins=20, color = 'k')
ax = plt.gca()
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
plt.xlabel('DAPI Cell Intensity')
plt.show()

In [None]:
import matplotlib.pyplot as plt
import torch
import skimage.io as io

rcdnet = io.imread('/media/DataStorage/Dropbox (Partners HealthCare)/mini-CMV project/Jul 18 AAV2-PHP.B-CMV4 m2.lif - TileScan 1 Merged_segmentation.tif')
unet = io.imread('/media/DataStorage/ToAnalyze/Jul 18 AAV2-PHP.B-CMV4 m2.lif - TileScan 1 Merged_cellBycell/test_unqiue_mask.tif')

rcdnet = torch.from_numpy(rcdnet)
unet = torch.from_numpy(unet)

unique, counts = torch.unique(rcdnet, return_counts=True)
plt.hist(counts.numpy())

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
plt.xlabel('DAPI Cell Intensity')

plt.show()


In [None]:
unique, counts = torch.unique(rcdnet, return_counts=True)
plt.figure(figsize=(10,5))
plt.hist(counts[unique != 0].numpy(), bins=40,alpha=0.5)


unique, counts = torch.unique(unet, return_counts=True)
plt.hist(counts[unique != 0].numpy(), bins=40,alpha=0.5)
ax = plt.gca()
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
plt.legend(['HCNet','UNet'])
plt.xlabel('Cell Volumes')
plt.axvline(5000, color='k')
plt.axvline(15000, color='k')
plt.savefig('volumes.png',dpi=400)

plt.show()