In [11]:
from pathlib import Path
import numpy as np
import nibabel as nib
import torch
import pandas as pd
import re 

In [53]:
params_dir = Path("/shared/slice_inflate/data/output/20230314__11_19_08_dummy-o2sqp1hb")

hla_params_files = list(params_dir.glob('hla_params*.pt'))

# The gap between the phase and the epoch number is matched lazily (`.*?`).
# A greedy `.*` would swallow all but the last digit, so "..._12.pt" would
# be parsed as epoch 2 instead of 12.
_PARAMS_FILE_PATTERN = re.compile(r".*(train|test).*?([0-9]+)\.pt$")


def parse_phase_and_epoch(path):
    """Extract ``(phase, epoch)`` from a parameter file path.

    Args:
        path: Path (or string) containing ``train`` or ``test`` and ending
            in ``<digits>.pt``.

    Returns:
        Tuple ``(phase, epoch)`` with ``phase`` a string and ``epoch`` an int.

    Raises:
        ValueError: If the path does not match the expected pattern.
    """
    mt = _PARAMS_FILE_PATTERN.match(str(path))
    if mt is None:
        raise ValueError(f"Cannot parse phase/epoch from parameter file: {path}")
    return mt[1], int(mt[2])


def hla_params_to_frame(param_dict, phase, epoch):
    """Build one DataFrame (one row per sample id) from a loaded parameter dict.

    Args:
        param_dict: Mapping with the keys ``epx_hla_theta_aps``,
            ``epx_hla_theta_t_offsets`` and ``epx_hla_theta_zps``; each value
            is itself a mapping from sample id to that sample's parameter.
        phase: Phase label stored in the ``phase`` column.
        epoch: Epoch number stored in the ``epoch`` column.

    Returns:
        ``pd.DataFrame`` with columns view, sample, epoch, theta_ap,
        theta_t_offsets, theta_zp, phase; rows are ordered by sorted sample id.

    Raises:
        KeyError: If a sample id of ``epx_hla_theta_aps`` is missing from one
            of the other two mappings.
    """
    epx_theta_aps = param_dict['epx_hla_theta_aps']
    epx_theta_t_offsets = param_dict['epx_hla_theta_t_offsets']
    epx_theta_zps = param_dict['epx_hla_theta_zps']

    # Use one id ordering for all three mappings so the columns cannot get
    # misaligned if the mappings ever differ in their keys.
    ids = sorted(epx_theta_aps)

    data = dict(
        view='hla',
        sample=ids,
        epoch=epoch,
        theta_ap=[epx_theta_aps[sample_id] for sample_id in ids],
        theta_t_offsets=[epx_theta_t_offsets[sample_id] for sample_id in ids],
        theta_zp=[epx_theta_zps[sample_id] for sample_id in ids],
        phase=phase,
    )
    return pd.DataFrame(data)


frames = []
for fl in hla_params_files:
    phase, epoch = parse_phase_and_epoch(fl)
    # NOTE: torch.load unpickles the file; only use it on trusted run outputs.
    param_dict = torch.load(fl)
    frames.append(hla_params_to_frame(param_dict, phase, epoch))

# Concatenate once instead of growing the frame inside the loop.
# `df` stays None when no parameter files were found.
df = pd.concat(frames) if frames else None

In [54]:
# Display the HLA parameter table assembled in the previous cell.
df

Unnamed: 0,view,sample,epoch,theta_ap,theta_t_offsets,theta_zp,phase
0,hla,1001-mr,0,"[tensor(0.7136), tensor(0.5254), tensor(-0.027...","[tensor(0.0267), tensor(-0.0410), tensor(-0.02...",[tensor(1.)],test


In [2]:
import pymeshfix as mf
import pyvista as pv

# Deliberate hard stop: this assertion always fails, so execution does not
# continue past this cell unless the line is removed or commented out.
assert False, "Caution: pyvista plots require a machine with attached physical display or a display emulation (Xvfb package) - otherwise the kernel may die when generating renderings."

ModuleNotFoundError: No module named 'pyvista'

# Create mesh from shape voxels

In [None]:
import numpy as np
import torch
import h5py
import pymeshfix as mf
import nibabel as nib
import pyvista as pv
from pathlib import Path
from skimage import measure

In [None]:
def replace_label_values(label):
    """Remap MMWHS label intensities to small consecutive class indices.

    Label conventions of the datasets involved:

        STRUCTURE           MMWHS   ACDC    NNUNET
        background          0       0       0
        left_myocardium     205     2       1
        left_atrium         420     N/A     2
        ?                   421     N/A     N/A
        left_ventricle      500     3       3
        right_atrium        550     N/A     4
        right_ventricle     600     1       5
        ascending_aorta     820     N/A     6
        pulmonary_artery    850     N/A     7

    This function sends 421, 820 and 850 to 0; the remaining MMWHS values
    become 1..5. Values that are not listed are copied unchanged.

    Args:
        label: Tensor holding MMWHS label values.

    Returns:
        A remapped clone of ``label``; the input tensor is not modified.
    """
    # MMWHS intensity -> class index
    mmwhs_to_class = {
        0: 0,
        205: 1,
        420: 2,
        421: 0,
        500: 3,
        550: 4,
        600: 5,
        820: 0,
        850: 0,
    }

    remapped = label.clone()
    for mmwhs_value, class_index in mmwhs_to_class.items():
        # Masks are taken from the untouched input; no target index equals a
        # later source value, so this matches in-place sequential replacement.
        remapped[label == mmwhs_value] = class_index
    return remapped

In [None]:
# Case id interpolated into the `mr_train_{CASE}_*_registered.nii.gz` file names loaded below.
CASE = "1004"

In [None]:
# Load the registered label volume, remap its values to class indices 0..5
# and keep its affine for converting mesh vertices afterwards.
nii_shape = nib.load(f"mr_train_{CASE}_label_registered.nii.gz")
shape_data = replace_label_values(torch.as_tensor(nii_shape.get_fdata())).long()
shape_affine = torch.as_tensor(nii_shape.affine)
image_sample = nib.load(f"mr_train_{CASE}_image_registered.nii.gz").get_fdata()
SPACING = (1.,1.,1.)
STEP_SIZE = 2
CLASSES = ['background', 'MYO', 'LA', 'LV', 'RA', 'RV']

# One surface mesh per foreground class, keyed by class tag.
heart_data = {}
for class_idx, tag in enumerate(CLASSES):
    if class_idx == 0:
        # Background gets no mesh.
        continue

    # Binary (0/1) mask of the current class. A direct comparison replaces
    # building the full one-hot volume on every iteration just to keep one
    # channel. Unlike one_hot, labels >= len(CLASSES) do not raise here; they
    # simply belong to no mask.
    sub_label = (shape_data == class_idx).long()
    verts, faces, normals, values = measure.marching_cubes(sub_label.cpu().numpy(), spacing=SPACING, step_size=STEP_SIZE)

    # Append a homogeneous 1 to every vertex and apply the 4x4 NIfTI affine.
    homogeneous_verts = torch.cat([torch.tensor(verts.copy()), torch.ones(len(verts), 1)], dim=1)
    mm_verts = (shape_affine @ homogeneous_verts.T.double()).T[:, :3]

    data = dict(
        verts=torch.as_tensor(mm_verts.numpy().copy()),
        faces=torch.as_tensor(faces.copy()),
        normals=torch.as_tensor(normals.copy()),
        values=torch.as_tensor(values.copy())
    )
    heart_data[tag] = data

In [55]:
import re
optimized_sa_affines = []

run_data_path = Path("/shared/slice_inflate/data/output/20230310__01_40_03_devout-brook-811-stage-1")


def key(_path):
    """Sort key: the first run of digits in ``_path`` that is directly followed by ``.pt``."""
    return int(re.match(r".*?([0-9]+)\.pt", str(_path))[1])


# All *.pt files of the run, ordered by the number extracted by `key`.
file_paths = sorted(run_data_path.glob('*.pt'), key=key)

# Load only the first file of the sorted list, if there is one.
if file_paths:
    fl = file_paths[0]
    content = torch.load(fl)
    content.keys()

In [50]:
# Check the number-extraction regex against a single path. The original call
# passed the whole list to str(), which only worked because the regex lazily
# latched onto the first "<digits>.pt" inside the list's repr.
mt = int(re.match(r".*?([0-9]+)\.pt", str(file_paths[0]))[1])
print(mt)

0


In [None]:
def load_view_matrix(mat_path):
    """Read a text matrix with ``np.loadtxt`` and wrap it as a torch tensor."""
    return torch.from_numpy(np.loadtxt(mat_path))


# NOTE(review): these files are named after case 1002, while CASE is "1004"
# for the label volume loaded above - confirm that this mix is intended.
hla_mat = load_view_matrix("mmwhs_1002_4CH.mat")
sa_mat = load_view_matrix("mmwhs_1002_SA.mat")

In [None]:
# Hard-coded 4x4 matrices wrapped in one extra leading dimension (shape 1x4x4).
# The last column and last row are [0, 0, 0, 1], i.e. there is no translation part.
# In this notebook they are only referenced by the commented-out
# `pre_grid_sample_affine=...` arguments of the nifti_transform calls.
# NOTE(review): presumably the result of an optimization run for case 1004 - confirm origin.
optimized_sa_mat_1004 = torch.tensor(
    [[[ 0.3762,  0.0235, -0.9262,  0.0000],
      [ 0.6500,  0.7056,  0.2820,  0.0000],
      [ 0.6602, -0.7082,  0.2502,  0.0000],
      [ 0.0000,  0.0000,  0.0000,  1.0000]]]
)

# Counterpart of the matrix above for the HLA view.
optimized_hla_mat_1004 = torch.tensor(
    [[[ 0.8096, -0.5865,  0.0236,  0.0000],
      [ 0.5751,  0.7844, -0.2325,  0.0000],
      [ 0.1178,  0.2018,  0.9723,  0.0000],
      [ 0.0000,  0.0000,  0.0000,  1.0000]]]
)

In [None]:
from align_mmwhs import nifti_transform
FOV_MM = torch.tensor([224,224,224])
FOV_VOX = torch.tensor([160,160,160])

# Inputs shared by both view transforms. Two leading singleton axes are added
# to the label volume and one to the affine (presumably batch / channel axes -
# confirm against nifti_transform).
batched_label = shape_data.unsqueeze(0).unsqueeze(0)
batched_affine = shape_affine.unsqueeze(0)
common_kwargs = dict(fov_mm=FOV_MM, fov_vox=FOV_VOX, is_label=True)

with torch.no_grad():
    # Short-axis (SA) view
    sa_label, sa_affine, sa_grid_affine = nifti_transform(
        batched_label, batched_affine, sa_mat.unsqueeze(0),
        # pre_grid_sample_affine=optimized_sa_mat_1004,
        pre_grid_sample_affine=None,
        **common_kwargs,
    )

    # Horizontal-long-axis (HLA / 4CH) view
    hla_label, hla_affine, hla_grid_affine = nifti_transform(
        batched_label, batched_affine, hla_mat.unsqueeze(0),
        # pre_grid_sample_affine=optimized_hla_mat_1004,
        pre_grid_sample_affine=None,
        **common_kwargs,
    )


In [None]:
from matplotlib import pyplot as plt
# Slice at index 80 along the last axis of the SA label volume, transposed for display.
sa_label_slice = sa_label.squeeze()[:, :, 80]
plt.imshow(sa_label_slice.T)

In [None]:
from matplotlib import pyplot as plt
# Slice at index 80 along the last axis of the HLA label volume, transposed for display.
hla_label_slice = hla_label.squeeze()[:, :, 80]
plt.imshow(hla_label_slice.T)

In [None]:
# Nifti original plane slicing.
# sa_normal = torch.tensor([1.,0.,0.])  # working! for D slicing
# sa_normal = torch.tensor([0.,1.,0.])  # working! for H slicing

# Reference plane in the untransformed volume; the support point is voxel
# (64, 64, 64) mapped through the label affine. Both values are overwritten below.
sa_normal = torch.tensor([0.,0.,1.])  # working! for W slicing
sa_support = (shape_affine @ torch.tensor([64.,64.,64.,1.], dtype=torch.float64))[:3]
print("Non transformed", sa_normal, sa_support)

# Homogeneous direction (w=0) along the last voxel axis and homogeneous
# point (w=1) at voxel (80, 80, 80), i.e. half of FOV_VOX per axis.
last_axis_direction = torch.tensor([0., 0., 1., 0.], dtype=torch.float64)
center_voxel = torch.tensor([80., 80., 80., 1.], dtype=torch.float64)

# SA slicing
sa_normal = (sa_affine @ last_axis_direction)[0, :3]
sa_support = (sa_affine @ center_voxel)[0, :3]
print("Transformed SA", sa_normal, sa_support)

# HLA slicing
hla_normal = (hla_affine @ last_axis_direction)[0, :3]
hla_support = (hla_affine @ center_voxel)[0, :3]
print("Transformed HLA", hla_normal, hla_support)

In [None]:
h_vertices = heart_data['MYO']['verts']
h_faces = heart_data['MYO']['faces']

# pyvista expects every face to be prefixed with its vertex count (3 for
# triangles), so prepend a column of 3s to each class' face index table.
for tag, data in heart_data.items():
    faces = data['faces']
    num_faces = faces.shape[0]
    num_points = torch.tensor([3] * num_faces).unsqueeze(1)
    data['pyvista_faces'] = torch.cat((num_points, faces), dim=1)

In [None]:
# https://coolors.co/b8336a-726da8-7d8cc4-a0d2db-c490d1

# Five colors, passed as `cmap` for the mesh scalars and indexed for the plane fill color.
palette = [
    '#B8336A',
    '#726DA8',
    '#7D8CC4',
    '#A0D2DB',
    '#C490D1',
]
# Darker counterpart of each entry in `palette`; used for the plane outline edges.
dark_palette = [
    '#4F172E', 
    '#424064',
    '#485070',
    '#547378',
    '#73507C',
]

In [None]:
# Toggles for the individual scene elements.
SHOW_FULL_MESH = False

SHOW_SA_PLANE = True
SHOW_SA_SLICE = True

SHOW_HLA_PLANE = True
SHOW_HLA_SLICE = True

# Camera placement handed to plotter.show(): (position, focal point, view-up).
CPOS = [
    (456.42537425592207, -476.3133591268548, 11.507223342275584),
    (-49.71246948463747, 0.5437890350509349, -5.752938139380589),
    (0.6710467014241088, 0.7038135452178427, -0.23311546082513432)
]

# Fill / outline color of both cutting planes. These used to be picked with
# the index leaking out of the mesh loop below (its last value); selecting the
# last palette entry explicitly gives the same colors for the five structures
# without depending on the loop.
PLANE_COLOR = palette[-1]
PLANE_EDGE_COLOR = dark_palette[-1]

plotter = pv.Plotter(
    lighting='three lights'
)
plotter.background_color = "white"

# Cutting planes (single quad each) plus their outline for the SA and HLA views.
sa_plane = pv.Plane(center=sa_support.tolist(), direction=sa_normal.tolist(), i_size=200, j_size=200, i_resolution=1, j_resolution=1)
sa_plane.point_data.clear()
sa_edges = sa_plane.extract_feature_edges(boundary_edges=True, feature_edges=False, manifold_edges=False)

hla_plane = pv.Plane(center=hla_support.tolist(), direction=hla_normal.tolist(), i_size=200, j_size=200, i_resolution=1, j_resolution=1)
hla_plane.point_data.clear()
hla_edges = hla_plane.extract_feature_edges(boundary_edges=True, feature_edges=False, manifold_edges=False)

# Prepare meshes: one smoothed surface per structure. Every point carries the
# structure's index as scalar so that `cmap=palette` colors each structure.
smoothed_parts = []
for idx, (tag, data) in enumerate(heart_data.items()):
    surf = pv.PolyData(data['verts'].numpy(), data['pyvista_faces'].view(-1).numpy())
    scalars = np.array([idx] * data['verts'].shape[0])
    surf.point_data.set_scalars(scalars, 'scalars')
    smoothed_parts.append(surf.smooth_taubin(n_iter=100, pass_band=0.3))

block = pv.MultiBlock(smoothed_parts)
full_mesh = block.combine(merge_points=False)

# Cross-sections of the combined mesh along both planes.
sa_slice = full_mesh.slice(normal=sa_normal.tolist(), origin=sa_support.tolist())
hla_slice = full_mesh.slice(normal=hla_normal.tolist(), origin=hla_support.tolist())

if SHOW_FULL_MESH:
    plotter.add_mesh(full_mesh, name='all', cmap=palette, line_width=2, show_scalar_bar=False, smooth_shading=True)

# The slice actors contain all structures, so they get fixed names instead of
# being prefixed with whichever tag the mesh loop happened to end on.
if SHOW_SA_SLICE:
    plotter.add_mesh(sa_slice, name='sa_slice', cmap=palette, line_width=2, show_scalar_bar=False)

if SHOW_SA_PLANE:
    plotter.add_mesh(sa_plane, color=PLANE_COLOR, opacity=0.3, show_edges=False, line_width=2)
    plotter.add_mesh(sa_edges, color=PLANE_EDGE_COLOR, line_width=1)

if SHOW_HLA_SLICE:
    plotter.add_mesh(hla_slice, name='hla_slice', cmap=palette, line_width=2, show_scalar_bar=False)

if SHOW_HLA_PLANE:
    plotter.add_mesh(hla_plane, color=PLANE_COLOR, opacity=0.3, show_edges=False, line_width=2)
    plotter.add_mesh(hla_edges, color=PLANE_EDGE_COLOR, line_width=1)

plotter.view_isometric()
plotter.enable_parallel_projection()

# plotter.camera.position = (hla_support+hla_normal).tolist()
# plotter.camera.focal_point = hla_support.tolist()
# plotter.camera.up = (0.0, 1.0, 0.0)
plotter.camera.zoom(1.2)

plotter.show(
    window_size=[1200,1200],
    # jupyter_backend='static',
    cpos=CPOS
)

In [None]:
# Show the plotter's current camera position (presumably how the CPOS values above were obtained - confirm).
plotter.camera_position