In [None]:
import tensorflow as tf
import numpy as np
import voxelmorph as vxm
import neurite as ne
import nibabel as nib
import os, sys

import numpy as np
from skimage.transform import resize
import nibabel as nib

import matplotlib.pyplot as plt
from matplotlib.widgets import Slider

def load_nii_shape(file_path):
    """Return the ``shape`` reported by nibabel for the NIfTI image at ``file_path``."""
    nifti = nib.load(file_path)
    shape = nifti.shape
    return shape

from skimage.transform import rescale

def preprocessing_for_unet(file_path, target_shape=(160, 192, 224), order=None, anti_aliasing=True):
    """Load a 3D NIfTI volume and fit it into ``target_shape`` for the network.

    The volume is rescaled by a single isotropic factor (the largest one that fits
    every axis inside ``target_shape``, so the aspect ratio is preserved). It is then
    zero-padded symmetrically up to ``target_shape`` (an odd remainder goes after the
    content), and finally min-max normalized to [0, 1].

    Args:
        file_path: path to a NIfTI file readable by nibabel.
        target_shape: 3-tuple output shape.
        order: spline order forwarded to skimage's ``rescale``; ``None`` keeps
            skimage's default. Use 0 for label/mask volumes.
        anti_aliasing: forwarded to ``rescale``. Set False for masks so labels
            are not blurred.

    Returns:
        float ndarray of shape ``target_shape`` with values in [0, 1]. A constant
        volume (max == min) yields all zeros instead of NaNs.

    Raises:
        ValueError: if the image is not 3D after dropping trailing singleton axes,
            or ``target_shape`` does not have exactly three entries.
    """
    image = nib.load(file_path).get_fdata()

    # Some NIfTI files store a 3D volume as (X, Y, Z, 1); drop such trailing singleton axes.
    while image.ndim > 3 and image.shape[-1] == 1:
        image = image[..., 0]
    if image.ndim != 3 or len(target_shape) != 3:
        raise ValueError(
            f"expected a 3D volume and a 3-element target_shape, got image shape "
            f"{image.shape} and target_shape {tuple(target_shape)}"
        )

    # One factor for all axes keeps proportions; the smallest one guarantees every axis fits.
    scale = min(target_dim / original_dim for target_dim, original_dim in zip(target_shape, image.shape))
    rescale_kwargs = {'anti_aliasing': anti_aliasing}
    if order is not None:
        rescale_kwargs['order'] = order
    resized_image = rescale(image, [scale] * 3, **rescale_kwargs)

    # Centre the resized content inside target_shape with zero padding.
    pad_width = []
    for target_dim, resized_dim in zip(target_shape, resized_image.shape):
        total_pad = target_dim - resized_dim
        before_pad = total_pad // 2
        pad_width.append((before_pad, total_pad - before_pad))
    padded_image = np.pad(resized_image, pad_width=pad_width, mode='constant', constant_values=0)

    # Min-max normalize to [0, 1]; guard the constant-volume case against 0/0.
    lo, hi = padded_image.min(), padded_image.max()
    if hi == lo:
        return np.zeros_like(padded_image)
    return (padded_image - lo) / (hi - lo)

# Network input grid; every volume is resampled and padded to this shape by preprocessing_for_unet.
vol_shape = (160, 192, 224)
# Encoder / decoder feature counts for the VoxelMorph U-Net.
nb_features = [
    [16, 32, 32, 32],
    [32, 32, 32, 32, 32, 16, 16]
]

# int_steps=0: per VoxelMorph's VxmDense docs, no flow integration (non-diffeomorphic warp).
# Verify this for the installed version.
vxm_model = vxm.networks.VxmDense(vol_shape, nb_features, int_steps=0)

# Subject 1 (IXI012) and subject 2 (IXI013): MRA intensity volumes and vessel masks.
# preprocessing_for_unet already min-max normalizes each volume to [0, 1],
# so no further scaling is applied here.
val_volume_1 = preprocessing_for_unet('/home/IXI012-HH-1211-MRA_400x400x133_isores=0.6000.nii')
seg_volume_1 = preprocessing_for_unet('/home/IXI012-HH-1211-MRA_vessel_mask.nii')
val_volume_2 = preprocessing_for_unet('/home/IXI013-HH-1212-MRA_400x400x133_isores=0.6000.nii')
seg_volume_2 = preprocessing_for_unet('/home/IXI013-HH-1212-MRA_vessel_mask.nii')

# VxmDense takes [moving, fixed] (per the VoxelMorph API), each shaped (batch, *vol_shape, channel).
# Subject 1 is therefore registered onto subject 2.
val_input = [
    val_volume_1[np.newaxis, ..., np.newaxis],
    val_volume_2[np.newaxis, ..., np.newaxis],
]

# Load the trained 3D model and predict.
vxm_model.load_weights('/home/brain_3d.h5')
val_pred = vxm_model.predict(val_input)

moved_pred = val_pred[0].squeeze()  # moving image resampled into the fixed image's grid
pred_warp = val_pred[1]             # predicted displacement field

# Nearest-neighbour resampling copies mask values instead of blending them.
warp_model = vxm.networks.Transform(vol_shape, interp_method='nearest')
# Apply the predicted field to the MOVING subject's vessel mask (not its intensity image),
# so warped_seg is that mask expressed in the fixed subject's grid.
warped_seg = warp_model.predict([seg_volume_1[np.newaxis, ..., np.newaxis], pred_warp])

warped_seg_final = warped_seg.squeeze()

In [None]:
def save_nii_image(image_data, affine, output_path):
    """Wrap ``image_data`` in a NIfTI-1 image with ``affine`` and write it to ``output_path``."""
    nib.save(nib.Nifti1Image(image_data, affine), output_path)

# VoxelMorph warps its first (moving, IXI012) input onto its second (fixed, IXI013) input,
# so warped_seg_final lives on the FIXED subject's preprocessed grid. Build the affine from
# that subject's header rather than reusing the moving subject's.
fixed_nii = nib.load('/home/IXI013-HH-1212-MRA_400x400x133_isores=0.6000.nii')
fixed_shape = np.asarray(fixed_nii.shape[:3], dtype=float)

# Re-derive the geometry preprocessing_for_unet applied to that scan:
# one isotropic scale factor, skimage's rounded output size,
# then centred zero padding up to vol_shape.
scale = min(t / o for t, o in zip(vol_shape, fixed_shape))
resized_shape = np.maximum(np.round(fixed_shape * scale), 1)
pad_before = (np.asarray(vol_shape, dtype=float) - resized_shape) // 2

# Map padded-grid voxel i to source voxel (i - pad + 0.5) * (source / resized) - 0.5.
# This assumes skimage's centre-aligned resampling convention.
# Composing with the scan's affine then gives world coordinates.
voxel_step = fixed_shape / resized_shape
grid_to_fixed = np.eye(4)
grid_to_fixed[:3, :3] = np.diag(voxel_step)
grid_to_fixed[:3, 3] = (0.5 - pad_before) * voxel_step - 0.5
affine = fixed_nii.affine @ grid_to_fixed

# Save the warped segmentation image as a NIfTI file
output_path = os.path.join('/home/', 'warped_seg_final.nii')
save_nii_image(warped_seg_final, affine, output_path)

In [None]:
from scipy.ndimage import zoom

def _crop_preprocessing_padding(image, original_shape, target_shape):
    """Cut away the centred zero padding preprocessing_for_unet adds for a volume of ``original_shape``."""
    if tuple(image.shape) != tuple(target_shape):
        raise ValueError(
            f"image shape {tuple(image.shape)} does not match target_shape {tuple(target_shape)}"
        )
    scale = min(t / o for t, o in zip(target_shape, original_shape))
    # Same rounding skimage's rescale uses for its output size.
    resized_shape = np.maximum(np.round(np.asarray(original_shape, dtype=float) * scale), 1).astype(int)
    crop = []
    for target_dim, resized_dim in zip(target_shape, resized_shape):
        before = (target_dim - int(resized_dim)) // 2
        crop.append(slice(before, before + int(resized_dim)))
    return image[tuple(crop)]


def rescale_to_original(image, original_shape=(400, 400, 133), target_shape=None):
    """Resample a 3D volume to ``original_shape`` with linear interpolation.

    Args:
        image: 3D array to resample.
        original_shape: desired output shape.
        target_shape: optional. When given, ``image`` is taken to be the padded grid
            produced by ``preprocessing_for_unet(..., target_shape=target_shape)``
            from a volume of ``original_shape``. That step's centred zero padding is
            cropped off before zooming, so the padding does not stretch the content.
            ``None`` (default) zooms ``image`` as-is.

    Returns:
        ndarray of shape ``original_shape``.

    Raises:
        ValueError: if ``target_shape`` is given and ``image.shape`` differs from it.
    """
    if target_shape is not None:
        image = _crop_preprocessing_padding(image, original_shape, target_shape)

    # Per-axis zoom factors from the current shape to the requested one.
    zoom_factors = [original_dim / current_dim for original_dim, current_dim in zip(original_shape, image.shape)]
    return zoom(image, zoom_factors, order=1)

# warped_seg_final sits on the padded vol_shape grid from preprocessing_for_unet.
# Crop that padding before zooming back; otherwise the anatomy is stretched non-uniformly.
rescaled_warped_seg_final = rescale_to_original(warped_seg_final, target_shape=vol_shape)

# Check if the rescaled image has the original dimensions
print("Rescaled image shape:", rescaled_warped_seg_final.shape)


In [None]:
# The warp maps the moving subject onto the FIXED subject (IXI013), so the rescaled mask is
# in that scan's space at its native resolution. Save it with that scan's affine.
# This reuses rescaled_warped_seg_final computed in the previous cell instead of recomputing it.
fixed_nii = nib.load('/home/IXI013-HH-1212-MRA_400x400x133_isores=0.6000.nii')
affine = fixed_nii.affine
if tuple(rescaled_warped_seg_final.shape) != tuple(fixed_nii.shape[:3]):
    raise ValueError(
        f"rescaled volume shape {rescaled_warped_seg_final.shape} does not match "
        f"the reference scan shape {fixed_nii.shape[:3]}; the affine would be wrong"
    )

# Save the warped segmentation image as a NIfTI file
output_path = os.path.join('/home/', 'rescaled_warped_seg_final.nii')
save_nii_image(rescaled_warped_seg_final, affine, output_path)

In [None]:
import matplotlib.pyplot as plt

def visualize_orthogonal_slices(image, slice_indices=None):
    """Show one sagittal, one coronal and one axial slice of a 3D volume side by side.

    Axes 0, 1 and 2 of ``image`` are labelled X (sagittal), Y (coronal) and Z (axial).
    Every slice is transposed and drawn with ``origin='lower'``. All three panels
    therefore share one orientation convention, the same one used by
    ``overlay_orthogonal_slices``.

    Args:
        image: 3D array to display.
        slice_indices: optional (x, y, z) slice positions; defaults to the middle of each axis.
    """
    if slice_indices is None:
        # Default to the central slice along each axis.
        slice_indices = [dim // 2 for dim in image.shape]

    views = (
        ('Sagittal', 'X', image[slice_indices[0], :, :]),
        ('Coronal', 'Y', image[:, slice_indices[1], :]),
        ('Axial', 'Z', image[:, :, slice_indices[2]]),
    )

    _, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, index, (plane, axis_name, slice_2d) in zip(axes, slice_indices, views):
        # Transpose so the slice's first axis runs horizontally; origin='lower' puts index 0 at the bottom.
        ax.imshow(slice_2d.T, cmap='gray', origin='lower')
        ax.set_title(f'{plane} slice ({axis_name}) at {index}')
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Orthogonal slices through the warped segmentation.
visualize_orthogonal_slices(image=warped_seg_final)

In [None]:
# Orthogonal slices through the registered (moved) image predicted by the network.
visualize_orthogonal_slices(image=moved_pred)

In [None]:
import matplotlib.pyplot as plt

def overlay_orthogonal_slices(image1, image2, slice_indices=None, alpha=0.7):
    """Overlay two 3D volumes on a sagittal, a coronal and an axial slice.

    Each panel stacks three layers, bottom to top:
    - ``image1`` in opaque grayscale;
    - ``image1`` again with the 'hot' colormap at transparency ``alpha``;
    - ``image2`` with the 'cool' colormap at transparency ``alpha``.

    Slices are transposed and drawn with ``origin='lower'``.

    Args:
        image1: base 3D volume.
        image2: 3D volume drawn on top of ``image1``.
        slice_indices: optional (x, y, z) positions; defaults to the middle of each axis of ``image1``.
        alpha: transparency of the two colormapped layers.
    """
    if slice_indices is None:
        slice_indices = [extent // 2 for extent in image1.shape]

    # 'hot' over 'cool': overlapping regions mix red/yellow with blue/green.
    palette_1, palette_2 = 'hot', 'cool'
    # Alternative: 'jet' over 'gray' (rainbow mixed with grayscale where they overlap).
    # palette_1, palette_2 = 'jet', 'gray'

    panels = (('Sagittal', 'X'), ('Coronal', 'Y'), ('Axial', 'Z'))

    _, axes = plt.subplots(1, 3, figsize=(15, 5))
    for axis, (plane, axis_name) in enumerate(panels):
        # Select the 2D plane: index along `axis`, keep the other two axes whole.
        selector = [slice(None)] * 3
        selector[axis] = slice_indices[axis]
        selector = tuple(selector)

        base_plane = image1[selector].T
        top_plane = image2[selector].T

        ax = axes[axis]
        ax.imshow(base_plane, cmap='gray', alpha=1, origin='lower')
        ax.imshow(base_plane, cmap=palette_1, alpha=alpha, origin='lower')
        ax.imshow(top_plane, cmap=palette_2, alpha=alpha, origin='lower')
        ax.set_title(f'{plane} slice ({axis_name}) at {slice_indices[axis]}')
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Overlay the warped segmentation on the predicted moved image.
overlay_orthogonal_slices(warped_seg_final, moved_pred)

In [None]:
import plotly.graph_objs as go
from skimage import measure

def visualize_3d_volume(image, threshold=0.05, opacity=0.1, color='red'):
    """Render the ``threshold`` iso-surface of a 3D volume as an interactive Plotly mesh.

    Marching cubes runs on the intensity volume itself with ``level=threshold``.
    The surface is therefore interpolated between voxels, instead of following the
    staircase of a pre-binarized mask. If no voxel lies strictly on each side of
    ``threshold``, no surface exists: a warning is issued and nothing is drawn.

    Args:
        image: 3D array.
        threshold: iso-value at which the surface is extracted.
        opacity: mesh opacity passed to ``go.Mesh3d``.
        color: mesh color passed to ``go.Mesh3d``.
    """
    import warnings

    volume = np.asarray(image, dtype=float)
    lo, hi = volume.min(), volume.max()
    if not lo < threshold < hi:
        warnings.warn(
            f"threshold {threshold} is outside the open data range ({lo}, {hi}); "
            "no iso-surface to render"
        )
        return

    verts, faces, _, _ = measure.marching_cubes(volume, level=threshold)

    # Triangle mesh built from the marching-cubes vertices and faces.
    mesh = go.Mesh3d(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
                     i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
                     opacity=opacity, color=color)

    # Hide the axes so only the surface is shown.
    layout = go.Layout(scene=dict(xaxis=dict(visible=False),
                                  yaxis=dict(visible=False),
                                  zaxis=dict(visible=False)))

    fig = go.Figure(data=[mesh], layout=layout)
    fig.show()

# Visualize the warped segmentation image and the predicted moved image
visualize_3d_volume(warped_seg_final)
visualize_3d_volume(moved_pred)