In [1]:
from pathlib import Path
import numpy as np
import torch
import pymeshfix as mf
import nibabel as nib
import pyvista as pv
import nibabel as nib

# Deliberate guard: this always fails, so "Run All" stops here and the reader sees the
# caution below before any pyvista rendering cell runs. Skip this cell manually once a
# physical display or an Xvfb display emulation is available.
assert False, "Caution: pyvista plots require a machine with attached physical display or a display emulation (Xvfb package) - otherwise the kernel may die when generating renderings."

AssertionError: Caution: pyvista plots require a machine with attached physical display or a display emulation (Xvfb package) - otherwise the kernel may die when generating renderings.

# Create per-structure surface meshes from the label voxels

In [50]:
import numpy as np
import torch
import h5py
import pymeshfix as mf
import nibabel as nib
import pyvista as pv
from pathlib import Path
from skimage import measure

In [51]:
def replace_label_values(label, mapping=None):
    """Remap MMWHS label values to the contiguous class indices used in this notebook.

    Default mapping (the NNUNET column is given for reference only; aorta, pulmonary
    artery and label 421 are deliberately folded into background here):

        STRUCTURE           MMWHS   ACDC    NNUNET  -> returned
        background          0       0       0          0
        left_myocardium     205     2       1          1
        left_atrium         420     N/A     2          2
        ?                   421     N/A     N/A        0
        left_ventricle      500     3       3          3
        right_atrium        550     N/A     4          4
        right_ventricle     600     1       5          5
        ascending_aorta     820     N/A     6          0
        pulmonary_artery    850     N/A     7          0

    Args:
        label: tensor of label values (any shape; any dtype that compares with ints,
            e.g. the float64 array returned by nibabel's ``get_fdata``).
        mapping: optional ``{original_value: new_value}`` dict overriding the default
            table above.

    Returns:
        A new tensor with the same shape and dtype as ``label``. ``label`` itself is not
        modified. Values that do not appear in the mapping are left unchanged.
    """
    if mapping is None:
        mapping = {0: 0, 205: 1, 420: 2, 421: 0, 500: 3, 550: 4, 600: 5, 820: 0, 850: 0}

    modified_label = label.clone()
    for orig, new in mapping.items():
        # Match against the *untouched* input so that a value written by an earlier
        # iteration can never be picked up and remapped again by a later one.
        modified_label[label == orig] = new
    return modified_label

In [95]:
# Load the registered MMWHS label volume and remap its raw label values to the class
# indices used below (see replace_label_values), stored as int64.
nii_shape = nib.load("mr_train_1004_label_registered.nii.gz")
raw_label_values = torch.as_tensor(nii_shape.get_fdata())
shape_data = replace_label_values(raw_label_values).long()

In [264]:
shape_affine = torch.as_tensor(nii_shape.affine)
image_sample = nib.load("mr_train_1004_image_registered.nii.gz").get_fdata()
SPACING = (1,1,1)   # marching-cubes vertices therefore stay in voxel-index units
STEP_SIZE = 2       # larger step -> coarser but faster surface extraction
CLASSES = ['background', 'MYO', 'LA', 'LV', 'RA', 'RV']

# One surface mesh per foreground structure, keyed by its tag in CLASSES.
heart_data = {}
for class_idx, tag in enumerate(CLASSES):
    if class_idx == 0:  # background has no surface to extract
        continue

    # Binary mask of this class. Same values as one_hot(...)[..., class_idx], but it does not
    # materialise a full (X, Y, Z, len(CLASSES)) int64 tensor on every iteration, and it does
    # not fail on label values >= len(CLASSES).
    sub_label = (shape_data == class_idx).long()
    if not bool(sub_label.any()):
        raise ValueError(f"Class {tag!r} (index {class_idx}) is absent from the label volume; cannot extract a surface.")

    verts, faces, normals, values = measure.marching_cubes(sub_label.cpu().numpy(), spacing=SPACING, step_size=STEP_SIZE)
    data = dict(
        verts=torch.as_tensor(verts.copy()),
        faces=torch.as_tensor(faces.copy()),
        normals=torch.as_tensor(normals.copy()),
        values=torch.as_tensor(values.copy())
    )
    heart_data[tag] = data

In [265]:
# Slice-to-RAS view matrices (HLA / SA, per the file names), read from text files as
# float64 tensors.
# NOTE(review): these belong to case 1002 while the label volume is case 1004 (registered);
# presumably 1004 was registered into 1002's space — confirm.
hla_mat, sa_mat = (
    torch.from_numpy(np.loadtxt(mat_path))
    for mat_path in (
        "mmwhs_1002_HLA_red_slice_to_ras.mat",
        "mmwhs_1002_SA_yellow_slice_to_ras.mat",
    )
)

In [266]:
# Optimised SA / HLA view matrices for case 1004, as batched (1, 4, 4) homogeneous matrices
# with zero translation (the 3x3 block appears to be a rotation — rows are ~unit length).
# NOTE(review): neither tensor is referenced by the other cells shown here; the planes below
# are derived from the matrices loaded from the mmwhs_1002_*.mat files instead.
optimized_sa_mat_1004 = torch.tensor([
    [0.3762, 0.0235, -0.9262, 0.0000],
    [0.6500, 0.7056, 0.2820, 0.0000],
    [0.6602, -0.7082, 0.2502, 0.0000],
    [0.0000, 0.0000, 0.0000, 1.0000],
]).unsqueeze(0)

optimized_hla_mat_1004 = torch.tensor([
    [0.8096, -0.5865, 0.0236, 0.0000],
    [0.5751, 0.7844, -0.2325, 0.0000],
    [0.1178, 0.2018, 0.9723, 0.0000],
    [0.0000, 0.0000, 0.0000, 1.0000],
]).unsqueeze(0)

In [267]:
from align_mmwhs import nifti_transform

# Field of view of the resliced volumes: 224 (mm) sampled on 160 voxels per axis.
FOV_MM = torch.tensor([224, 224, 224])
FOV_VOX = torch.tensor([160, 160, 160])

# Reslice the label volume along the HLA and SA view matrices. Inputs get leading batch
# dims: label -> (1, 1, *volume), affine and view matrix -> (1, *matrix)
# (presumably (1, 4, 4) — TODO confirm against nifti_transform).
with torch.no_grad():
    label_batch = shape_data[None, None]
    affine_batch = shape_affine[None]
    reslice_kwargs = dict(fov_mm=FOV_MM, fov_vox=FOV_VOX, is_label=True, pre_grid_sample_affine=None)
    hla_label, hla_affine = nifti_transform(label_batch, affine_batch, hla_mat[None], **reslice_kwargs)
    sa_label, sa_affine = nifti_transform(label_batch, affine_batch, sa_mat[None], **reslice_kwargs)



In [450]:
# Plane normals: the homogeneous direction [0, 0, 1, 0] (w = 0, so translation is ignored)
# is pushed through the inverse of each resliced affine. [0, :3] drops the batch dim and the
# homogeneous component. Both planes are anchored at the fixed point (64, 64, 64).
# NOTE(review): this expresses the world z-axis in the resliced-voxel frame; whether that is
# the frame of the marching-cubes vertices (original voxel indices) should be confirmed.
SLICE_Z_DIR = torch.tensor([0., 0., 1., 0.], dtype=torch.float64)

sa_normal = torch.matmul(sa_affine.inverse(), SLICE_Z_DIR)[0, :3]
sa_support = torch.tensor([64., 64., 64.])
print(sa_normal, sa_support)

hla_normal = torch.matmul(hla_affine.inverse(), SLICE_Z_DIR)[0, :3]
hla_support = torch.tensor([64., 64., 64.])
print(hla_normal, hla_support)

tensor([0.0390, 0.5021, 0.5065], dtype=torch.float64) tensor([64., 64., 64.])
tensor([ 0.2684, -0.1490,  0.6450], dtype=torch.float64) tensor([64., 64., 64.])


In [451]:
# Myocardium mesh kept for interactive inspection (not referenced by the other cells shown here).
h_vertices = heart_data['MYO']['verts']
h_faces = heart_data['MYO']['faces']

# PyVista's padded face format stores every triangle as [3, i, j, k]. Build that as an
# (n_faces, 4) integer tensor per structure; it is flattened with .view(-1) when plotting.
for tag, data in heart_data.items():
    faces = data['faces']
    num_faces = faces.shape[0]
    # Explicit integer dtype: torch.tensor([3] * 0) would be float32 and would turn an
    # empty face array into floats after torch.cat.
    num_points = torch.full((num_faces, 1), 3, dtype=torch.long)
    data['pyvista_faces'] = torch.cat([num_points, faces.long()], dim=1)

In [452]:
# Colour scheme: https://coolors.co/b8336a-726da8-7d8cc4-a0d2db-c490d1
# `palette` is used as the colormap for the per-structure meshes (presumably one entry per
# structure, in heart_data order). `dark_palette[i]` is a darker shade of `palette[i]`,
# used for plane outlines.
palette = ['#B8336A', '#726DA8', '#7D8CC4', '#A0D2DB', '#C490D1']
dark_palette = ['#4F172E', '#424064', '#485070', '#547378', '#73507C']

In [457]:
# Render the smoothed per-structure surfaces, the SA and HLA slice planes, and the
# cross-sections of the surfaces cut by those planes.

# Optional anti-aliasing via the global theme; it takes effect for Plotters created afterwards.
# pv.global_theme.anti_aliasing = 'ssaa'

# Both planes share one colour. The plane colour used to come implicitly from the loop
# variable `idx` leaked out of the mesh loop (i.e. the last structure). With 5 structures
# and 5 palette entries, that is palette[-1] / dark_palette[-1], now chosen explicitly.
PLANE_COLOR = palette[-1]
PLANE_EDGE_COLOR = dark_palette[-1]
PLANE_SIZE = 100  # edge length of the square planes, in the same units as the mesh vertices

# Camera position saved from an earlier session; assign it to `cpos` to reuse that view.
# cpos = [(-57.05794118047339, -173.82419298093834, 190.51142547607827),
#  (64.0, 64.0, 64.43975067138672),
#  (-0.8136483842355717, 0.5348925173329734, 0.2277417434991626)]
cpos = 'iso'


def _make_plane(center, normal):
    """Build a square plane through `center` with normal `normal`, plus its boundary outline.

    `center` and `normal` are 3-element tensors. Returns (plane, boundary_edges).
    """
    plane = pv.Plane(center=center.tolist(), direction=normal.tolist(),
                     i_size=PLANE_SIZE, j_size=PLANE_SIZE, i_resolution=1, j_resolution=1)
    # Drop any point arrays pv.Plane attaches, presumably so they cannot act as scalars when colouring.
    plane.point_data.clear()
    edges = plane.extract_feature_edges(boundary_edges=True, feature_edges=False, manifold_edges=False)
    return plane, edges


sa_plane, sa_edges = _make_plane(sa_support, sa_normal)
hla_plane, hla_edges = _make_plane(hla_support, hla_normal)

# One Taubin-smoothed surface per structure. Each carries a constant integer scalar (its
# position in heart_data), which the `palette` colormap turns into a per-structure colour.
class_meshes = []
for class_pos, data in enumerate(heart_data.values()):
    surf = pv.PolyData(data['verts'].numpy(), data['pyvista_faces'].view(-1).numpy())
    surf.point_data.set_scalars(np.full(data['verts'].shape[0], class_pos), 'scalars')
    class_meshes.append(surf.smooth_taubin(n_iter=100, pass_band=0.3))

full_mesh = pv.MultiBlock(class_meshes).combine(merge_points=False)

sa_slice = full_mesh.slice(normal=sa_normal.tolist(), origin=sa_support.tolist())
hla_slice = full_mesh.slice(normal=hla_normal.tolist(), origin=hla_support.tolist())

plotter = pv.Plotter(
    lighting='three lights'
)
plotter.background_color = "white"
plotter.add_mesh(full_mesh, name='all', cmap=palette, line_width=2, show_scalar_bar=False, smooth_shading=True)

# Slice actor names no longer depend on the leaked loop variable `tag` (previously always 'RV_...').
for view_name, slice_mesh, plane, edges in (
    ('sa', sa_slice, sa_plane, sa_edges),
    ('hla', hla_slice, hla_plane, hla_edges),
):
    plotter.add_mesh(slice_mesh, name=f'{view_name}_slice', cmap=palette, line_width=2, show_scalar_bar=False)
    plotter.add_mesh(plane, color=PLANE_COLOR, opacity=0.3, show_edges=False, line_width=2)
    plotter.add_mesh(edges, color=PLANE_EDGE_COLOR, line_width=1)

plotter.view_isometric()
plotter.show(
    window_size=[2000,2000],
    # jupyter_backend='static',  # uncomment for a non-interactive render
    cpos=cpos,
)

ViewInteractiveWidget(height=2000, layout=Layout(height='auto', width='100%'), width=2000)

In [248]:
# Display the current (possibly interactively adjusted) camera as
# (position, focal point, view-up). It can be pasted into `cpos` in the plotting cell
# to reproduce this view.
plotter.camera_position

[(-74.39510547025466, -266.25200577202395, 160.91771143179494),
 (100.36705418048375, 86.35979157449216, 89.34951083603468),
 (-0.7767468638865815, 0.47005278688638497, 0.41918335723487804)]