In [12]:
import sys
import os
sys.path.insert(0,os.path.abspath('..'))
from spatial_graphs.AmiraSpatialGraph import AmiraSpatialGraph,MatchBarrels
from spatial_graphs.Landmarks import Landmarks
from spatial_graphs.Surfaces import Surface
from spatial_graphs.Vectors import Vectors
from spatial_graphs.Alignment import Alignment
from dask import compute,multiprocessing,delayed
import pathlib
import shutil
import glob
import pandas as pd
import vtk
from scipy.spatial import distance
import numpy as np
import SimpleITK as sitk
import itk

In [4]:
def add_z_pt_above_below(pt,z_offset,pt_list):
    above = [pt[0],pt[1],pt[2]+z_offset]
    below = [pt[0],pt[1],pt[2]-z_offset]
    pt_list.append(below)
    pt_list.append(above)
    pt_list.append(pt)
    return pt_list

In [5]:
def validate_pt(pt,im):
    if pt[0]<0:
        pt[0] = 0
    if pt[0]>im.GetSize()[0]:
        pt[0] = im.GetSize()[0]
    if pt[1] < 0:
        pt[1] = 0
    if pt[1] > im.GetSize()[1]:
        pt[1] = im.GetSize()[1]
        
    return pt

In [6]:
def get_min_max(pts):
    min_x = 9999
    max_x = 0
    min_y = 9999
    max_y = 0
    
    for pt in pts:
        if pt[0] < min_x:
            min_x = pt[0]
        if pt[0] > max_x:
            max_x = pt[0]
        if pt[1] < min_y:
            min_y = pt[1]
        if pt[1] > max_y:
            max_y = pt[1]
    return min_x,max_x,min_y,max_y

In [7]:
def get_real_foreground_im(bb_pts,masked_im):
    # get actual ROI as foreground
    pt_list=[]
    pt_list = add_z_pt_above_below(bb_pts[0],0.5,pt_list)
    pt_list = add_z_pt_above_below(bb_pts[1],0.5,pt_list)
    pt_list = add_z_pt_above_below(bb_pts[2],0.5,pt_list)
    pt_list = add_z_pt_above_below(bb_pts[3],0.5,pt_list)

    surf=Surface(pts=pt_list).create_delunay_surface_3d(return_hull=True,output_filename=output_path+'bla.vtk')

    np_im = sitk.GetArrayFromImage(masked_im)
    nonzeros = np.where(np_im>0)
    inds_pts = np.array([nonzeros[0],nonzeros[1],np.zeros_like(nonzeros[0])]).transpose()

    l = Landmarks(pts=inds_pts)
    selected = l.get_landmarks_within_given_surface(surf)
    print(selected)
    
    return selected

In [8]:
def get_bounding_plane(cube,z_coord,tx_mat,im,neg_x=False,coronal=False):
    
    intersectoin_plane = Surface(polydata=cube.get_intersection_plane(z_coord))
    
    if intersectoin_plane.surface.GetNumberOfPoints() > 3:
        
        
        if neg_x:
            neg_mat = [-1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]
            intersectoin_plane.apply_transformation(neg_mat,inverse=True)
        if coronal:
            coronal_tr_mat = [-1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1]
            intersectoin_plane.apply_transformation(coronal_tr_mat,inverse=True)
            #intersectoin_plane.write_surface_mesh(output_path+'before.vtk')
        intersectoin_plane.apply_transformation(tx_mat,inverse=True)
#         intersectoin_plane_cube = Surface(polydata=intersectoin_plane.create_delunay_surface_3d(return_hull=True,\
#                                                              output_filename=output_path+'vS1_plane.vtk'))
        #intersectoin_plane.write_surface_mesh(output_path+'{}_sec_cutting_plane.vtk'.format(sec_num))
        bounds = intersectoin_plane.surface.GetBounds()
        #pt_list = []
        
        #vtkpts = intersectoin_plane.surface.GetPoints()
        #print(selected)
        #XY_RES = 1
         
#         pt1 = [int(vtkpts.GetPoint(0)[0]/XY_RES),int(vtkpts.GetPoint(0)[1]),0]
#         pt2 = [int(vtkpts.GetPoint(1)[0]/XY_RES),int(vtkpts.GetPoint(1)[1]),0]
#         pt3 = [int(vtkpts.GetPoint(2)[0]/XY_RES),int(vtkpts.GetPoint(2)[1]),0]
#         pt4 = [int(vtkpts.GetPoint(3)[0]/XY_RES),int(vtkpts.GetPoint(3)[1]),0]
        
        # get im inds within the BB of ROI
#         if neg_x:
#             pt1 = validate_pt([-int(bounds[0]/XY_RES),int(bounds[2]/XY_RES),0],im)
#             pt2 = validate_pt([-int(bounds[1]/XY_RES),int(bounds[3]/XY_RES),0],im)
#         else:
        pt1 = validate_pt([int(bounds[0]/XY_RES),int(bounds[2]/XY_RES),0],im)
        pt2 = validate_pt([int(bounds[1]/XY_RES),int(bounds[3]/XY_RES),0],im)
        #pt3 = validate_pt(pt3,im)
        #pt4 = validate_pt(pt4,im)
        
        return [pt1,pt2]
    else:
        return None

In [9]:
def mask_im_region_from_bb(np_im,bb_pts):
    #np_im = sitk.GetArrayFromImage(im)
    #np_im[:,:] = 0
    
    min_x,max_x,min_y,max_y = get_min_max(bb_pts)
    print(min_x,max_x,min_y,max_y)
    np_im[bb_pts[0][1]:bb_pts[1][1],bb_pts[0][0]:bb_pts[1][0]] = 255
    masked_im = sitk.GetImageFromArray(np_im)
    #sitk.WriteImage(masked_im,'{}_binary_masked.tif'.format(op_filename))
    
    return np_im

In [10]:
def mark_binary_im(masked_im):
    contour_im = sitk.BinaryContour(masked_im,backgroundValue=0,foregroundValue=255)
    dil = sitk.BinaryDilateImageFilter()
    dil.SetKernelRadius(10)
    dil.SetBackgroundValue(0)
    dil.SetForegroundValue(255)
    contour_dilated = dil.Execute(contour_im)
    marked_im = sitk.Or(im,contour_dilated)
    return marked_im
    

In [3]:
animal_name = 'MG49_rhs'
image_name = 'MG49_3Day_rhs'
section_sg_path = '/nas1/Data_Mythreya/MotorCortexProject/V9/vM1_Ref_Frame/Original_Data/Spatial_Graphs/misc/Section_Graphs/{}.am/'.format(animal_name)

tx_file = section_sg_path +'aligntobarrel.hx'
file_pattern = '*S*_ct.am'
images_path = '/nas1/Data_aman/00_Rabies/{}/00_Images/00_Confocal/ch_01_stacks/'.format(image_name)
image_file_pattern = 'S*_01.tif'
#rabies_landmarks_path = '/nas1/Data_Mythreya/MotorCortexProject/V4/0_Inputs/SpatialGraphs/Section_Graphs/Rabies/{}/rabies/'.format(animal_name)
#neun_landmarks_path = '/nas1/Data_Mythreya/MotorCortexProject/V0/0_Inputs/Landmarks/Manual/Rabies/NeuN/{}/'.format(animal_name)

surface_path_vS1 = '/nas1/Data_Mythreya/MotorCortexProject/V9/vM1_Ref_Frame/Outputs/Surfaces/{}_s1_hull.vtk'.format(animal_name)
surface_path_vM1 = '/nas1/Data_Mythreya/MotorCortexProject/V9/vM1_Ref_Frame/Outputs/vM1_Ref_Surfaces/{}_vM1.vtk'.format(animal_name)

output_path = '/nas1/Data_Mythreya/MotorCortexProject/Images_For_NeuN_Count/{}/'.format(animal_name)

XY_RES = 0.868

In [11]:
#sg_3d = AmiraSpatialGraph(glob.glob(path + file_pattern)[0],read_header_only=True)
#landmarks3d = Landmarks()
#landmarks3d_neun = Landmarks()
z_cood_offset = 50
vs1_hull = Surface(surface_path_vS1)
#s1_cube = Surface(polydata=s1_hull.create_delunay_surface_3d(make_cube=True,return_hull=True,\
#                                                             output_filename=output_path+'vS1_Cube.vtk'))

vm1_hull = Surface(surface_path_vM1)
#vm1_cube = Surface(polydata=vm1_hull.create_delunay_surface_3d(make_cube=True,return_hull=True,\
#                                                               output_filename=output_path+'vM1_Cube.vtk'))
#write_cube(vm1_cube.surface,output_path+'vM1_Cube.vtk')


for file in sorted(glob.glob(images_path + image_file_pattern)):
    #sg = AmiraSpatialGraph(file,axis_directions=[1,1,1])
    #print(file,tx_file)
    
    # read tx mat from the file
    with open(tx_file,'r') as f:
        lines = f.readlines()
        
    sec_num = int(os.path.basename(file)[1:-7])
    print(sec_num)
    #sec_num = 49
    
    str_to_compare = '"S{:03d}_ct.am" setTransform '.format(sec_num)
    #print(str_to_compare)
    tx_mat = []
    tx_mat_np = np.array([[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]).reshape(4,4)
    for line in lines:
        if line.startswith(str_to_compare):
            #tx_mat = (line[len(str_to_compare):-1].split(' '))
            for num in (line[len(str_to_compare):-1].split(' ')):
                tx_mat.append(float(num))
            tx_mat_np = np.array(tx_mat).reshape(4,4)
            break
    
    if os.path.exists(images_path+'S{}_01.tif'.format(sec_num)):
        im = sitk.ReadImage(images_path+'S{}_01.tif'.format(sec_num))
        npim = sitk.GetArrayFromImage(im)
        npim[:] = 0
        # get s1 and m1 bounding box for this plane
        s1_bb_pts = get_bounding_plane(vs1_hull,(sec_num-1) * z_cood_offset,tx_mat_np,im, neg_x=True,coronal=False)
        m1_bb_pts = get_bounding_plane(vm1_hull,(sec_num-1) * z_cood_offset,tx_mat,im, neg_x=True,coronal=False)
        print(s1_bb_pts)
        print(m1_bb_pts)
        #Landmarks(pts=[s1_bb_pts[0]]).write_landmarks(output_path+'{}_lower.landmarksAscii'.format(sec_num))
        #Landmarks(pts=[s1_bb_pts[1]]).write_landmarks(output_path+'{}_upper.landmarksAscii'.format(sec_num))
        if s1_bb_pts is not None :
            
            npim = mask_im_region_from_bb(npim,s1_bb_pts,)#output_path+'S'+str(sec_num)+'_vS1')
            #pixel_inds = get_real_foreground_im(s1_bb_pts,s1_masked_im)
            #Landmarks(pts=pixel_inds).write_landmarks(output_path+'s1_im_inds')
            #sitk.WriteImage(cropped_im,output_path+'{}_01_cropped.tif'.format(sec_num))
            
        if m1_bb_pts is not None :
            npim = mask_im_region_from_bb(npim,m1_bb_pts,)#output_path+'S'+str(sec_num)+'_vM1')
            #sitk.WriteImage(cropped_im,output_path+'{}_01_cropped.tif'.format(sec_num))
        
        #else:
        if s1_bb_pts is not None or m1_bb_pts is not None:
            sitk.WriteImage(sitk.GetImageFromArray(npim),output_path+'S{}_binary.tif'.format(sec_num))
            marked_im = mark_binary_im(sitk.GetImageFromArray(npim))
            sitk.WriteImage(marked_im,output_path+'S{}_boundary_marked.tif'.format(sec_num))
            
            masked_im = sitk.Mask(im,sitk.GetImageFromArray(npim))
            sitk.WriteImage(masked_im,output_path+'S{}_masked.tif'.format(sec_num))
    
    

100
None
None
101
None
None
102
None
None
10
[[2165, 2824, 0], [4663, 6107, 0]]
None
2165 4663 2824 6107
11
[[2227, 2847, 0], [4703, 6111, 0]]
None
2227 4703 2847 6111
13
[[2676, 3485, 0], [5069, 6694, 0]]
None
2676 5069 3485 6694
14
[[3002, 4311, 0], [5445, 7516, 0]]
None
3002 5445 4311 7516
15
[[2846, 4248, 0], [5237, 7436, 0]]
None
2846 5237 4248 7436
16
[[2872, 5023, 0], [5180, 8154, 0]]
None
2872 5180 5023 8154
17
[[2513, 4952, 0], [4780, 8065, 0]]
None
2513 4780 4952 8065
18
[[3510, 4713, 0], [5839, 7849, 0]]
None
3510 5839 4713 7849
19
[[3559, 4955, 0], [5934, 8065, 0]]
None
3559 5934 4955 8065
20
[[3044, 5451, 0], [5335, 8552, 0]]
None
3044 5335 5451 8552
21
[[3615, 5626, 0], [5994, 8689, 0]]
None
3615 5994 5626 8689
22
[[3287, 5474, 0], [5587, 8526, 0]]
None
3287 5587 5474 8526
23
[[4042, 6497, 0], [6309, 9537, 0]]
[[4079, 1287, 0], [5143, 1443, 0]]
4042 6309 6497 9537
4079 5143 1287 1443
24
[[4132, 5573, 0], [6465, 8567, 0]]
[[4949, 207, 0], [6529, 638, 0]]
4132 6465 5573 856

In [115]:
nr_yr = 15
init = [25000,25000,50000,50000,100000,100000,150000,150000,200000,200000,\
        300000,300000,350000,350000,40000,0,0,0,0,0,\
        0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
final = 0
interest = 1.25
for i in range(35):
    saving = init[i]
    final = (final +  saving) * interest
final

1012057770.9125042

In [None]:
cutPlane = vtk.vtkPlane()
center = Surface(surface_path_vS1).surface.GetCenter()
cutPlane.SetOrigin([center[0],center[1],30])
cutPlane.SetNormal(0, 0, 1)

cutter = vtk.vtkCutter()
cutter.SetInputData(Surface(surface_path_vS1).surface)
cutter.SetCutFunction(cutPlane)
cutter.SetValue(0,0)
cutter.Update()

stripper = vtk.vtkStripper()
stripper.SetInputData(cutter.GetOutput())
stripper.Update()
Surface(polydata=stripper.GetOutput()).write_surface_mesh(output_path+'plane.vtk')
circle = cutter.GetOutput()

In [None]:
whiteImage =vtk.vtkImageData()
bounds = circle.GetBounds()
whiteImage.SetSpacing([1,1,1])
dim = [int(bounds[1] - bounds[0]),int(bounds[1] - bounds[0]),1 ]
whiteImage.SetDimensions(dim)
whiteImage.SetExtent(0, dim[0] - 1, 0, dim[1] - 1, 0, dim[2] - 1)
whiteImage.SetOrigin([bounds[0],bounds[1],0])
whiteImage.AllocateScalars(vtk.VTK_UNSIGNED_CHAR,1)
inval = 255
outval = 0
count = whiteImage.GetNumberOfPoints()
for i in range(count):
    whiteImage.GetPointData().GetScalars().SetTuple1(i, inval)

extruder =vtk.vtkLinearExtrusionFilter()
extruder.SetInputData(circle)
extruder.SetScaleFactor(1.)
extruder.SetExtrusionTypeToNormalExtrusion()
extruder.SetVector(1, 1, 0)
extruder.Update()
Surface(polydata = extruder.GetOutput()).write_surface_mesh(output_path+'extrusion.vtk')

In [None]:
pol2stenc =vtk.vtkPolyDataToImageStencil()
pol2stenc.SetTolerance(0) 
pol2stenc.SetInputConnection(extruder.GetOutputPort())
pol2stenc.SetOutputOrigin([bounds[0],bounds[1],0])
pol2stenc.SetOutputSpacing([1,1,1])
pol2stenc.SetOutputWholeExtent(whiteImage.GetExtent())
pol2stenc.Update()

imgstenc = vtk.vtkImageStencil()
imgstenc.SetInputData(whiteImage)
imgstenc.SetStencilConnection(pol2stenc.GetOutputPort())
imgstenc.ReverseStencilOff()
imgstenc.SetBackgroundValue(outval)
imgstenc.Update()

imageWriter = vtk.vtkMetaImageWriter()
imageWriter.SetFileName("labelImage.mhd")
imageWriter.SetInputConnection(imgstenc.GetOutputPort())
imageWriter.Write()

In [None]:
Surface(polydata=surf).write_surface_mesh(output_path+'plane.vtk')

In [None]:
surf.get_intersection_plane(30)


In [None]:
Landmarks(pts=s1_bb_pts).write_landmarks(output_path+'s1_bb_pts')

In [None]:
intersectoin_plane = Surface(polydata=vm1_cube.get_intersection_plane((sec_num-1) * z_cood_offset))
intersectoin_plane.surface.GetPoints().GetPoint(0)
pt1 = [int(GetPoint(0)[0]/XY_RES),int(GetPoint(0)[1])]
pt2 = [int(GetPoint(1)[0]/XY_RES),int(GetPoint(1)[1])]
pt3 = [int(GetPoint(2)[0]/XY_RES),int(GetPoint(2)[1])]
pt4 = [int(GetPoint(3)[0]/XY_RES),int(GetPoint(3)[1])]


# boundary check and clip if exceeding image boudaries
if pt1 < 0:
    lower[0] = 0
if lower[1] < 0:
    lower[1] = 0
if upper[0] > im.GetWidth():
    upper[0] = im.GetWidth()
if upper[1] > im.GetHeight():
    upper[1] = im.GetHeight()

In [None]:
5000/XY_RES,1200/XY_RES

In [None]:
im = sitk.ReadImage(images_path+'S{}_01.tif'.format(sec_num))
transform = sitk.Transform()
transform.SetParameters(tx_mat)
txed_im = sitk.Resample(im,  transform, sitk.sitkLinear, 0.0, sitk.sitkUInt8)
sitk.WriteImage(txed_im,output_path+'S{}.tif'.format(sec_num))

In [None]:
transform = sitk.Euler2DTransform()
transform.SetTranslation([2592.38,532.041])
transform.SetParameters([tx_mat[0],tx_mat[4],tx_mat[1],tx_mat[5]])

In [None]:
im = sitk.ReadImage(images_path+'S{}_01.tif'.format(sec_num))
fil = sitk.ResampleImageFilter()
fil.SetInterpolator(sitk.sitkLinear)
fil.SetTransform(transform)
fil.SetDefaultPixelValue(0)
fil.SetOutputPixelType(sitk.sitkUInt8)
fil.SetSize([im.GetWidth()+2000,im.GetHeight()+2000])
#fil.SetOutputOrigin([-2592.38,-532.041,0])
#fil.SetOutputDirection()
txed_im = fil.Execute(im)
sitk.WriteImage(txed_im,output_path+'S{}.tif'.format(sec_num))

In [None]:
tx_mat

In [None]:
tx_mat[0],tx_mat[1],tx_mat[4],tx_mat[5]