In [None]:
from ml4cvd.arguments import _get_tmap
import h5py
import numpy as np
from ml4cvd.tensor_from_file import _mri_hd5_to_structured_grids
import h5py
import time
import matplotlib.pyplot as plt
import os
import sys

In [None]:
# Notebook owner; set ML4CVD_USER to run this notebook from another account's home directory.
USER = os.environ.get('ML4CVD_USER', 'pdiachil')
# Root under which tensors and VTK outputs below are read and written.
HOME_PATH = os.path.join('/home', USER)

# Add position, orientation, width, height, and thickness tensors to HD5
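Tensorize MRI field IDs 20208 and 20209 for sample IDs in the range 2000000–2000200; the HD5 tensors are written to `{HOME_PATH}/mri_tensors/`.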

In [None]:
from ml4cvd.arguments import parse_args
from ml4cvd.recipes import run

# Assemble the command line expected by ml4cvd's argument parser, one option group
# at a time, then hand the parsed arguments to the tensorize recipe.
tensorize_argv = ['tensorize']
tensorize_argv += ['--mode', 'tensorize']
tensorize_argv += ['--zip_folder', '/mnt/disks/sax-and-lax-zip-2019-09-30/']
tensorize_argv += ['--xml_folder', '/mnt/disks/ecg-rest-xml-36k/']
tensorize_argv += ['--output_folder', f'{HOME_PATH}/mri_tensors/']
tensorize_argv += ['--tensors', f'{HOME_PATH}/mri_tensors/']
tensorize_argv += ['--mri_field_ids', '20208', '20209']
# NOTE(review): passed with no values — presumably no ECG XML fields are tensorized; confirm.
tensorize_argv += ['--xml_field_ids']
tensorize_argv += ['--min_sample_id', '2000000']
tensorize_argv += ['--max_sample_id', '2000200']

sys.argv = tensorize_argv
args = parse_args()
run(args)

# Extract the center of gravity of the myocardium in each slice

In [None]:
import vtk
import vtk.util.numpy_support  # load the submodule explicitly; `import vtk` alone does not guarantee it
from ml4cvd.defines import EPS, MRI_SEGMENTED_CHANNEL_MAP, MRI_FRAMES

# For every cine frame: find the myocardium cells of each short-axis slice, compute
# their center of gravity (COG), fit a line through the COGs, and write the COG line
# and the thresholded myocardium surface to .vtp files.
# cell_pts, dims, indices, segmented_arr, cogs and line_pts are intentionally left at
# notebook scope because the inspection/plotting cells below read them.
MYOCARDIUM_LABEL = MRI_SEGMENTED_CHANNEL_MAP['myocardium']
SAMPLE_HD5 = f'{HOME_PATH}/mri_tensors/2000119.hd5'
OUTPUT_DIR = os.path.join(HOME_PATH, 'mri_tensors')

with h5py.File(SAMPLE_HD5, 'r') as hd5:
    cine_segmented_grids = _mri_hd5_to_structured_grids(hd5, 'cine_segmented_sax_inlinevf_segmented')
    for cine_segmented_grid in cine_segmented_grids:
        # Coordinates of every cell center, in the grid's cell order
        cell_centers = vtk.vtkCellCenters()
        cell_centers.SetInputData(cine_segmented_grid)
        cell_centers.Update()
        cell_pts = vtk.util.numpy_support.vtk_to_numpy(cell_centers.GetOutput().GetPoints().GetData())
        # Subtract 1 to get cell dimensions rather than point dimensions
        dims = [dim - 1 for dim in cine_segmented_grid.GetDimensions()]
        ncells_per_slice = dims[0] * dims[1]
        # indices[t * dims[2] + s] holds the flat cell ids of myocardium in slice s of frame t
        indices = []
        for t in range(MRI_FRAMES):
            arr_name = f'cine_segmented_sax_inlinevf_segmented_{t}'
            segmented_arr = vtk.util.numpy_support.vtk_to_numpy(cine_segmented_grid.GetCellData().GetArray(arr_name))
            # VTK numbers cells with the first axis varying fastest, hence Fortran order
            segmented_arr = segmented_arr.reshape(*dims, order='F')
            cogs = np.full((dims[2], 3), np.nan)
            for s in range(dims[2]):
                slice_mask = ((segmented_arr[:, :, s] > MYOCARDIUM_LABEL - EPS) &
                              (segmented_arr[:, :, s] < MYOCARDIUM_LABEL + EPS))
                # Flattening in Fortran order yields i + j*dims[0], i.e. VTK's cell id within the
                # slice; this stays correct when dims[0] != dims[1].
                thresh_flat_indices = np.flatnonzero(slice_mask.ravel(order='F')) + s * ncells_per_slice
                indices.append(thresh_flat_indices)
                if len(thresh_flat_indices) > 0:
                    cogs[s, :] = np.mean(cell_pts[thresh_flat_indices], axis=0)

            # Slices without myocardium keep NaN COGs; fit only through the valid ones.
            valid_cogs = cogs[~np.isnan(cogs).any(axis=1)]
            if len(valid_cogs) >= 2:
                cogs_mean = np.mean(valid_cogs, axis=0)
                # First right-singular vector = principal direction of the centered COGs
                _, _, vv = np.linalg.svd(valid_cogs - cogs_mean, full_matrices=False)
                # Use the outermost slices that contain myocardium: cogs[0] / cogs[-1] are NaN
                # whenever an end slice has none.
                ventricle_length = np.linalg.norm(valid_cogs[0] - valid_cogs[-1])
                line_pts = vv[0] * np.mgrid[(-0.5*ventricle_length):(0.5*ventricle_length):2j][:, np.newaxis] + cogs_mean
                line_source = vtk.vtkLineSource()
                line_source.SetPoint1(valid_cogs[0, :])
                line_source.SetPoint2(valid_cogs[-1, :])
                line_source.Update()
                line_writer = vtk.vtkXMLPolyDataWriter()
                line_writer.SetInputConnection(line_source.GetOutputPort())
                line_writer.SetFileName(os.path.join(OUTPUT_DIR, f'cog_line_{t}.vtp'))
                line_writer.Write()
            else:
                print(f'Frame {t}: fewer than 2 slices contain myocardium; skipping COG line.')

            # Keep only this frame's myocardium cells and write their outer surface
            thresh_channel = vtk.vtkThreshold()
            thresh_channel.SetInputData(cine_segmented_grid)
            thresh_channel.ThresholdBetween(MYOCARDIUM_LABEL - EPS, MYOCARDIUM_LABEL + EPS)
            thresh_channel.SetInputArrayToProcess(0, 0, 0, vtk.vtkDataObject.FIELD_ASSOCIATION_CELLS, arr_name)
            thresh_channel.Update()
            thresh_surf = vtk.vtkDataSetSurfaceFilter()
            thresh_surf.SetInputConnection(thresh_channel.GetOutputPort())
            thresh_surf.Update()
            thresh_writer = vtk.vtkXMLPolyDataWriter()
            thresh_writer.SetInputConnection(thresh_surf.GetOutputPort())
            thresh_writer.SetFileName(os.path.join(OUTPUT_DIR, f'channel_thresh_{t}.vtp'))
            thresh_writer.Write()

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.imshow(segmented_arr[:, :, 7])
thresh_indices = np.nonzero((segmented_arr[:, :, 7] > MRI_SEGMENTED_CHANNEL_MAP['myocardium'] - EPS) &
                            (segmented_arr[:, :, 7] < MRI_SEGMENTED_CHANNEL_MAP['myocardium'] + EPS))
thresh_flat_indices = np.ravel_multi_index(thresh_indices, (dims[0], dims[1]))

In [None]:
# Binary image of the myocardium pixels found in the previous cell, drawn in the same
# orientation as the imshow of segmented_arr; sized from the data instead of assuming 256x256.
only_myocardium = np.zeros(segmented_arr.shape[:2])
only_myocardium[thresh_indices[0], thresh_indices[1]] = 1
plt.imshow(only_myocardium)

In [None]:
%matplotlib inline
from mpl_toolkits import mplot3d

fig = plt.figure()
ax = plt.axes(projection="3d")

for i in [1, 3, 5, 7]:
    ax.scatter3D(cell_pts[indices[i], 0], cell_pts[indices[i], 1], cell_pts[indices[i], 2], 
                 c=cell_pts[indices[i], 0], cmap='hsv');
ax.scatter3D(cogs[:, 0], cogs[:, 1], cogs[:, 2])
ax.plot3D(line_pts[:, 0], line_pts[:, 1], line_pts[:, 2])

# Test the new `cine_segmented_sax_inlinevf_axis` TMap

In [None]:
# Load the axis tensor for one sample through the TMap under test.
axis_tmap = _get_tmap('cine_segmented_sax_inlinevf_axis')
sample_hd5_path = f'{HOME_PATH}/mri_tensors/2000119.hd5'
with h5py.File(sample_hd5_path, 'r') as hd5:
    axes = axis_tmap.tensor_from_file(axis_tmap, hd5)

In [None]:
axes  # display the tensor returned by the axis TMap in the previous cell

In [None]:
axes[1, 0]  # inspect one entry; assumes axes is at least 2-D — TODO confirm what row 1 / column 0 represent