In [14]:
import sys
import os
sys.path.insert(0,os.path.abspath('..'))
from spatial_graphs.AmiraSpatialGraph import AmiraSpatialGraph,MatchBarrels
from spatial_graphs.Landmarks import Landmarks
from spatial_graphs.Surfaces import Surface
from spatial_graphs.Vectors import Vectors
from spatial_graphs.Alignment import Alignment
from dask import compute,multiprocessing,delayed
import pathlib
import shutil
import glob
import pandas as pd
import vtk
from scipy.spatial import distance
import numpy as np
from sklearn.decomposition import PCA
import SimpleITK as sitk

In [6]:
def erode_image(im,radius):
    """Binary-erode a 255/0 mask using an annulus-shaped structuring element.

    Parameters
    ----------
    im : SimpleITK.Image
        Binary image; pixels equal to 255 are treated as foreground and the
        background value is 0.
    radius : int or sequence of int
        Kernel radius passed to ``SetKernelRadius``.

    Returns
    -------
    SimpleITK.Image
        The eroded image.
    """
    eroder = sitk.BinaryErodeImageFilter()
    eroder.SetKernelType(sitk.sitkAnnulus)
    eroder.SetKernelRadius(radius)
    eroder.SetForegroundValue(255)
    eroder.SetBackgroundValue(0)
    return eroder.Execute(im)

In [7]:
def convertPointsToImage2D(pts,templateimg):
    """Rasterise pixel coordinates into a blank 2-D image shaped like ``templateimg``.

    Parameters
    ----------
    pts : sequence of points
        Each point is ``(x, y, ...)`` in pixel units: column 0 is the image
        column (x) and column 1 the image row (y). Values are truncated to
        int. Extra columns are ignored.
    templateimg : SimpleITK.Image
        Supplies the output width, height, pixel type and physical metadata
        (origin / spacing, and direction when it is a plain 2-D image).

    Returns
    -------
    SimpleITK.Image
        2-D image of the template's width/height and pixel type. It is 0
        everywhere except 255 at every point. Points outside the image are
        clamped onto the nearest border pixel. Negative indices therefore no
        longer wrap around to the opposite edge, and points past the far edge
        no longer raise IndexError.
    """
    width, height = templateimg.GetWidth(), templateimg.GetHeight()
    arr = sitk.GetArrayFromImage(sitk.Image(width, height, templateimg.GetPixelID()))
    if len(pts) > 0:
        pts_arr = np.asarray(pts)
        # astype truncates toward zero, matching the previous int() mapping.
        cols = np.clip(pts_arr[:, 0].astype(np.int64), 0, width - 1)
        rows = np.clip(pts_arr[:, 1].astype(np.int64), 0, height - 1)
        # The numpy array is indexed (row, col) = (y, x).
        arr[rows, cols] = 255
    im = sitk.GetImageFromArray(arr)
    # Put the mask in the template's physical frame so it can be combined with
    # the template (e.g. sitk.Mask checks origin/spacing/direction). The old
    # code set the origin to (width, height), which is not a physical location.
    if im.GetSize() == templateimg.GetSize():
        im.CopyInformation(templateimg)
    else:
        # Template is not a plain 2-D image of the same size (e.g. 3-D): copy
        # the x/y part of its geometry only.
        im.SetOrigin(templateimg.GetOrigin()[:2])
        im.SetSpacing(templateimg.GetSpacing()[:2])
    return im

In [8]:
def scaleToImageCoords2D(points,XY_RESOLUTION=1,Z_RESOLUTION=1):
    """Convert 2-D coordinates into integer pixel coordinates.

    Each coordinate is divided by ``XY_RESOLUTION`` and then rounded to the
    nearest integer. Ties round half to even, the same rule as Python's
    built-in ``round``.

    Parameters
    ----------
    points : array-like
        Sequence of ``(x, y)`` pairs; reshaped to ``(len(points), 2)``.
    XY_RESOLUTION : float
        Coordinate units per pixel in x and y.
    Z_RESOLUTION : float
        Unused (z scaling is disabled); kept for signature compatibility.

    Returns
    -------
    numpy.ndarray
        int64 array of shape ``(len(points), 2)``.
    """
    # Divide in floating point *before* rounding. The previous version cast to
    # int64 first, which truncated fractional coordinates and could move a
    # point by more than a pixel once divided by a resolution < 1.
    scaled = np.reshape(np.asarray(points, dtype=np.float64), [len(points), 2]) / XY_RESOLUTION
    return np.round(scaled).astype(np.int64)

In [18]:
def getMidLine(ptlist,limit):
    """Densify a polyline by repeated midpoint insertion.

    Each pass inserts the midpoint between every pair of consecutive points,
    turning ``n`` points into ``2n - 1``. Passes repeat while the list has
    fewer than ``limit / 2`` points.

    Parameters
    ----------
    ptlist : list of points
        Each point is indexable. Only ``pt[0]`` and ``pt[1]`` are used to build
        midpoints; the original points are kept as given.
    limit : int or float
        Subdivision stops once ``len(ptlist) >= limit / 2``.

    Returns
    -------
    list
        The densified list, or ``ptlist`` itself when no pass is needed.
        Inserted midpoints are ``[x, y]`` lists. A list with fewer than two
        points cannot be subdivided and is returned unchanged (the previous
        recursive version raised UnboundLocalError in that case).
    """
    while 2 <= len(ptlist) < limit / 2:
        refined = []
        for start, end in zip(ptlist, ptlist[1:]):
            refined.append(start)
            refined.append([(start[0] + end[0]) / 2, (start[1] + end[1]) / 2])
        refined.append(ptlist[-1])
        ptlist = refined
    return ptlist

In [19]:
def resample_contours(contours,res,downscale=1):
    """Densify closed contours by subdividing every edge with ``getMidLine``.

    Each contour is treated as closed: the last vertex is joined back to the
    first. For every edge the subdivision limit is ``int(edge_length / res)``.
    The edge length is the Euclidean distance over all coordinates of the two
    vertices. Only x and y are kept in the output.

    Parameters
    ----------
    contours : iterable of sequences of points
        Each point is ``(x, y, ...)`` with at least two coordinates.
    res : float
        Target spacing, in the same units as the coordinates. Must be > 0.
    downscale : float
        Output x and y are divided by this factor.

    Returns
    -------
    list of list of [x, y]
        One list per input contour. Both endpoints of every edge are emitted,
        so each vertex appears twice (end of one edge, start of the next).

    Raises
    ------
    ValueError
        If ``res`` is not positive.
    """
    if res <= 0:
        raise ValueError('res must be positive, got {}'.format(res))
    resampled_conts_list = []
    for cnt in contours:
        n_pts = len(cnt)
        # Convert once per contour instead of once per vertex (was O(n^2)).
        cnt_arr = np.asarray(cnt, dtype=float)
        resampled_cont_list = []
        for i in range(n_pts):
            j = (i + 1) % n_pts  # the last vertex closes back onto the first
            # [0, 0] yields a scalar: int() on a 1x1 ndarray is deprecated.
            dist = distance.cdist(cnt_arr[i].reshape(1, -1), cnt_arr[j].reshape(1, -1))[0, 0]
            num_pts_inserted = int(dist / res)
            ptlist = getMidLine([cnt[i], cnt[j]], num_pts_inserted)
            resampled_cont_list.extend([pt[0] / downscale, pt[1] / downscale] for pt in ptlist)
        resampled_conts_list.append(resampled_cont_list)
    return resampled_conts_list

In [23]:
def dilate_image(im,radius):
    """Binary-dilate a 255/0 mask using an annulus-shaped structuring element.

    Parameters
    ----------
    im : SimpleITK.Image
        Binary image; pixels equal to 255 are treated as foreground and the
        background value is 0.
    radius : int or sequence of int
        Kernel radius passed to ``SetKernelRadius``.

    Returns
    -------
    SimpleITK.Image
        The dilated image.
    """
    dilater = sitk.BinaryDilateImageFilter()
    dilater.SetKernelType(sitk.sitkAnnulus)
    dilater.SetKernelRadius(radius)
    dilater.SetForegroundValue(255)
    dilater.SetBackgroundValue(0)
    return dilater.Execute(im)

In [24]:
def convertContourCoordsIntoBinaryImage(cont,im,resample_res = 0,xy_res=None):
    """Build a filled binary mask (255 inside, 0 outside) from a contour.

    Parameters
    ----------
    cont : sequence of points
        Contour vertices in the same units as ``xy_res``. Only x and y are
        used.
    im : SimpleITK.Image
        Template image that provides the output size, pixel type and origin.
    resample_res : float
        If non-zero, the closed contour's edges are densified with
        ``resample_contours`` at roughly this spacing before rasterising.
        If 0, the vertices are rasterised as-is.
    xy_res : float, optional
        Coordinate units per pixel. Defaults to the notebook-level ``XY_RES``.

    Returns
    -------
    SimpleITK.Image
        Mask with the contour outline and its interior set to 255.
    """
    if xy_res is None:
        xy_res = XY_RES  # notebook-level config constant, resolved at call time
    if resample_res != 0:
        [pia_pts] = resample_contours([cont], resample_res)
    else:
        # Previously this path left `resampled_pia` undefined (UnboundLocalError).
        # Keep only x/y, because scaleToImageCoords2D expects 2-D points.
        pia_pts = [[pt[0], pt[1]] for pt in cont]
    pia_imcords = scaleToImageCoords2D(pia_pts, XY_RESOLUTION=xy_res)
    pia_cont_im = convertPointsToImage2D(pia_imcords, im)
    pia_cont_im.SetOrigin(im.GetOrigin())
    # Dilate the traced outline (presumably to bridge 1-px gaps so it is
    # closed), fill its interior, then erode by the same radius.
    pia_hole_filled = erode_image(
        sitk.BinaryFillhole(dilate_image(pia_cont_im, 1), foregroundValue=255, fullyConnected=False),
        1,
    )
    return pia_hole_filled

In [33]:
# --- Configuration -----------------------------------------------------------
# Coordinate units per pixel, used to convert contour coordinates to pixel
# indices (presumably um per pixel -- confirm against the image metadata).
XY_RES = 0.868
# Z resolution (presumably section thickness); not used by the cells shown here.
Z_RES = 50

exp_name = 'MG48_3Day_bs'

input_image_path = '/nas1/Data_aman/00_Rabies/{}/00_Images/00_Confocal/ch_00_stacks/'.format(exp_name)
# NOTE(review): the contour folder is hard-coded for MG48 and does not follow exp_name.
input_contours_path = '/nas1/Share/Project_Rabies/MG48_bs_contour_labeled/'
output_path = '/nas1/Data_aman/00_Rabies/{}/00_Images/00_Confocal/masked_for_rabies_auto_detection/'.format(exp_name)

# parents=True: create missing intermediate folders instead of raising FileNotFoundError.
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)

# Single-section default for interactive use; the batch loop below rebinds sec_num.
sec_num = 21

In [38]:
# Mask each listed section's image with its filled pia contour and save it.
# NOTE(review): 82 breaks the every-10th pattern (81?) -- confirm intentional.
for sec_num in [1, 11, 21, 31, 41, 51, 61, 71, 82, 91]:
    section_stem = 'S{:03d}'.format(sec_num)
    im = sitk.ReadImage(input_image_path + section_stem + '_00.tif')

    spatial_graph_file = input_contours_path + section_stem + '.am'
    sg = AmiraSpatialGraph(spatial_graph_file)

    # Filled binary mask of the pia outline (contour densified at 0.1 spacing).
    pia_bin = convertContourCoordsIntoBinaryImage(sg.pia.edge_pt_coords, im, resample_res=0.1)
    # Keep the image intensities inside the pia mask only.
    pia_only = sitk.Mask(im, pia_bin)
    # NOTE(review): output name is not zero-padded, unlike the input names.
    sitk.WriteImage(pia_only, output_path + 'S{}_Pia_Only.tif'.format(sec_num))