# Visualization of MRI data with labels

In [None]:
# import
import os
import numpy as np

import nibabel as nib

import matplotlib.pyplot as plt

## Dataset visualization

In [None]:
# Load the first MRI scan and pull its voxel values out as a float64 array.
mri_1 = nib.load(filename='../MRI/00001.nii')
mri_1_data = mri_1.get_fdata(dtype=np.float64)
np.shape(mri_1_data)

In [None]:
def show(slices, size, col=5, cmap=None, aspect=6, step=5):
   """Plot a sequence of 2-D slices on a grid with `col` panels per row.

   Each slice is transposed and drawn with origin="lower". Panel i is titled
   ``Slice {size - i*step}``; `step` defaults to 5 so existing calls keep
   their titles.

   Args:
      slices: sequence of 2-D arrays to draw, in panel order.
      size: number shown in the title of the first panel.
      col: number of panels per row.
      cmap: matplotlib colormap (None -> matplotlib default).
      aspect: aspect ratio passed to ``imshow``.
      step: amount the title number decreases from one panel to the next.
   """
   n_panels = len(slices)
   if n_panels == 0:
      return  # nothing to draw; plt.subplots rejects a grid with 0 rows
   n_rows = -(-n_panels // col)  # ceiling division
   # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so ravel() is safe.
   _, axes = plt.subplots(n_rows, col, figsize=(15, 2 * n_rows), squeeze=False)
   flat_axes = axes.ravel()
   for i, (ax, image_slice) in enumerate(zip(flat_axes, slices)):
      ax.imshow(image_slice.T, cmap=cmap, origin="lower", aspect=aspect)
      ax.set_title(f'Slice {size - i * step}')
   # Hide the unused panels of an incomplete last row.
   for ax in flat_axes[n_panels:]:
      ax.axis("off")
   # Adjust layout to prevent overlap of titles
   plt.tight_layout()

In [None]:
def show_slices(data, start, end, lap, col=5, cmap=None, aspect=6):
   """Display a grid of slices ``data[:, idx, :]`` taken along axis 1.

   Slice indices run start, start - lap, start - 2*lap, ... down to 0
   (inclusive). `end` is the maximum NUMBER of slices to show, not an index;
   a value <= 0 means "no limit". Each panel is titled with the actual slice
   index it shows.

   Args:
      data: 3-D array to slice.
      start: index (along axis 1) of the first slice shown.
      end: maximum number of slices to show.
      lap: distance between consecutive slice indices; must be positive.
      col, cmap, aspect: forwarded to ``show``.

   Raises:
      ValueError: if `lap` is not a positive step.
   """
   if lap <= 0:
      raise ValueError(f"lap must be a positive step, got {lap}")
   indices = list(range(start, -1, -lap))
   if end > 0:
      indices = indices[:end]
   if not indices:
      return  # start < 0: there is nothing to show
   slices = [data[:, idx, :] for idx in indices]
   show(slices, data.shape[1], col, cmap, aspect)
   # show() derives titles from `size` with a fixed spacing, which matches
   # neither `lap` nor the true start index; relabel with the real indices.
   for ax, idx in zip(plt.gcf().axes, indices):
      ax.set_title(f'Slice {idx}')
    

In [None]:
show_slices(
    mri_1_data,
    start=mri_1_data.shape[1] - 1,
    end=25,
    lap=5,
    cmap="gray",
)

## Label dataset

In [None]:
# Load the label volume matching the first MRI scan.
# Forward slashes are portable (Windows accepts them too) and, unlike '\M', '\L'
# or '\l', contain no backslash escape sequences inside the string literal.
mri_1_label = nib.load('../MRI/Labels/labels_00001.nii')
mri_1_label_data = mri_1_label.get_fdata()
mri_1_label_data.shape

In [None]:
show_slices(
    mri_1_label_data,
    start=mri_1_label_data.shape[1] - 1,
    end=25,
    lap=5,
)

In [None]:
# Peak value of the label volume, then of the raw scan.
print("Mri label data: \n", mri_1_label_data.max())
print("Mri data: \n", mri_1_data.max())


## Image transformation

In [None]:
# Rotate the label volume by 90 degrees in the plane spanned by axes 0 and 2;
# axis 1 (the axis the comparison slices below are taken along) is unchanged.
rot_90_data = np.rot90(mri_1_label_data, k=1, axes=(0, 2))
# NOTE(review): np.eye(4) replaces the source image's affine, so voxel spacing and
# orientation from the original header are not carried over — confirm intended.
rot_90_img = nib.Nifti1Image(rot_90_data, np.eye(4))
# Writing into a missing folder fails, so make sure the output folder exists.
os.makedirs('../outputs', exist_ok=True)
nib.save(rot_90_img, '../outputs/rot_90_img.nii')

# Display the original and transformed images (same axis-1 slice, 20 from the end).
original_slice = mri_1_label_data[:, mri_1_label_data.shape[1] - 20, :]
transformed_slice = rot_90_img.get_fdata()[:, rot_90_img.shape[1] - 20, :]

fig, (ax_orig, ax_rot) = plt.subplots(1, 2, figsize=(12, 6))
ax_orig.imshow(original_slice.T, aspect=6)
ax_orig.set_title('Original Image Slice')
# The rotation swaps the lengths of axes 0 and 2, so the display aspect is inverted.
ax_rot.imshow(transformed_slice.T, aspect=1/6)
ax_rot.set_title('Transformed Image Slice')

plt.show()


This seems to work.

Next, let's try to apply the transformation voxel by voxel!

### Padding
To leave enough room for the rotation along every dimension, we first zero-pad the volume into a cube whose side equals its largest dimension.

In [None]:
def padding(original_array):
    """Zero-pad an array so that every axis is as long as its longest axis.

    The extra length on each axis is split as evenly as possible between the
    two sides; when it is odd, the extra element goes on the right/end side.
    Works for arrays of any dimensionality (the notebook uses 3-D volumes).

    Parameters
    ----------
    original_array : array_like
        Array to pad.

    Returns
    -------
    np.ndarray
        Array of shape ``(max_dim,) * original_array.ndim``, where the input
        starts at offset ``(max_dim - s) // 2`` along an axis of length ``s``.
    """
    original_array = np.asarray(original_array)

    # Find the maximum dimension
    max_dim = max(original_array.shape)

    # (left, right) padding for each axis: left gets the floor of half the gap.
    pad_width = [
        ((max_dim - s) // 2, max_dim - s - (max_dim - s) // 2)
        for s in original_array.shape
    ]

    # Pad the array with zeros
    padded_array = np.pad(original_array, pad_width, mode='constant')

    # Verify the shapes
    print("Original Array Shape:", original_array.shape)
    print("Padded Array Shape:", padded_array.shape)

    return padded_array

In [None]:
# Crop the label volume to a region of interest:
# axis 0 -> indices 20..489, axis 1 -> 530..end, axis 2 kept whole.
interest_data = mri_1_label_data[20:490, 530:]
np.shape(interest_data)

In [None]:
# Zero-pad the cropped region into a cube so the rotation has room on every axis.
padded_array = padding(original_array=interest_data)

In [None]:
# Only a cell's last expression is displayed, so print the maximum explicitly.
print("Padded label data max:", np.max(padded_array))
# Start 400 slices in along axis 1, but never past the last valid slice.
start_slice = min(400, padded_array.shape[1] - 1)
show_slices(padded_array, start_slice, 5, 1, aspect=1)

### Regular Grid Interpolation
In the following, I follow a tutorial that uses `RegularGridInterpolator`; it can be found [here](https://medium.com/vitrox-publication/rotation-of-voxels-in-3d-space-using-python-c3b2fc0afda1).

A voxel (VOlume piXEL) is the 3D equivalent of a pixel in a 2D image. Volumetric data is stored as a 3D array in which the value of each element (voxel) represents some physical property (e.g. color or density) at that location in space.

For 3D data stored as explicit point coordinates (e.g. point clouds or meshes), a rotation is easily achieved by multiplying the coordinates of the points by a rotation matrix. A rotation matrix cannot act directly on voxel data, however, because the array values are not coordinates. Since each voxel's coordinate is implied by its position in the 3D array, we can construct an explicit coordinate system and perform the rotation by rotating that coordinate system in the opposite direction (i.e. by applying the inverse rotation).

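For volumetric voxel data (a 3D array), SciPy's `scipy.ndimage` module provides ready-made rotations (e.g. `scipy.ndimage.rotate` or `scipy.ndimage.affine_transform`), but it is not the only option: below, the rotation is implemented manually with `scipy.interpolate.RegularGridInterpolator`.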

In [None]:
# Create 3D coordinate grids
ex_x = np.linspace(0, 5, 6)
ex_y = np.linspace(0, 3, 4)
ex_z = np.linspace(0, 2, 3)

ex_xx, ex_yy, ex_zz = np.meshgrid(ex_x, ex_y, ex_z, indexing='ij')

print("xx: \n", ex_xx[0,:3,:3])
print("yy: \n", ex_yy[:3,0,:3])
print("zz: \n", ex_zz[:3,:3,0])

# Assume the center of the coordinate system is (2, 1, 1)
x_center, y_center, z_center = 2, 1, 1

# Shift the coordinate system to have the center at (2, 1, 1)
ex_coor = np.array([ex_xx - x_center, ex_yy - y_center, ex_zz - z_center])

print("xx centered: \n", ex_coor[0,0,:3,:3])
print("yy centered: \n", ex_coor[1,:3,0,:3])
print("zz centered: \n", ex_coor[2,:3,:3,0])

# Plot original and shifted coordinates
fig = plt.figure(figsize=(15, 5))

# Original coordinates
ax1 = fig.add_subplot(121, projection='3d')
ax1.scatter(ex_xx, ex_yy, ex_zz, c='b', label='Original Coordinates')
ax1.set_title('Original Coordinates')
ax1.set_xlabel('X-axis')
ax1.set_ylabel('Y-axis')
ax1.set_zlabel('Z-axis')
ax1.legend()

# Shifted coordinates
ax2 = fig.add_subplot(122, projection='3d')
ax2.scatter(ex_coor[0], ex_coor[1], ex_coor[2], c='r', label='Shifted Coordinates')
ax2.set_title('Shifted Coordinates (Center at (2, 1, 1))')
ax2.set_xlabel('X-axis')
ax2.set_ylabel('Y-axis')
ax2.set_zlabel('Z-axis')
ax2.legend()

plt.tight_layout()
plt.show()


We now apply the same procedure to our 3D label array.

In [None]:
from scipy.interpolate import RegularGridInterpolator

# Work on the padded cube; the transform is the identity for now (no rotation).
image = padded_array
trans_mat = np.identity(3)
print(np.shape(image))

Let's create a meshgrid for every dimension

In [None]:
# Build the index coordinates (x_i, y_i, z_i) of every voxel with a NumPy mesh grid.
Nx, Ny, Nz = image.shape
x, y, z = (np.arange(n, dtype=float) for n in (Nx, Ny, Nz))
xx, yy, zz = np.meshgrid(x, y, z, indexing='ij')

In [None]:
# 5x5 patches of each coordinate array, taken at index 0 of the matching axis.
for label, patch in (
    ("xx: \n", xx[0, :5, :5]),
    ("yy: \n", yy[:5, 0, :5]),
    ("zz: \n", zz[:5, :5, 0]),
):
    print(label, patch)

Then let's shift the origin of the meshgrid to its center.

In [None]:
# Integer index of the central voxel (rounded down on even-length axes).
mri_vox_center = np.subtract(image.shape, 1) // 2
x_center, y_center, z_center = mri_vox_center
print("Voxel center: ", mri_vox_center)
# Move the origin of the coordinate system to the central voxel.
coor = np.stack([xx - x_center, yy - y_center, zz - z_center])

In [None]:
# Same 5x5 patches as before, now in the centred coordinate system.
for label, patch in (
    ("coor xx: \n", coor[0, 0, :5, :5]),
    ("coor yy: \n", coor[1, :5, 0, :5]),
    ("coor zz: \n", coor[2, :5, :5, 0]),
):
    print(label, patch)


In [None]:
print("Coor shape: ", np.shape(coor))
# Indices of every zero in the stacked coordinate arrays (one index array per axis).
center = np.nonzero(coor == 0)
print("center: ", np.shape(center))

As we can see, there are 613,254 zero entries, which makes sense: each coordinate array is zero on exactly one plane (the one through the center voxel), so the count is $N_y N_z + N_x N_z + N_x N_y$.

In [None]:
# Each coordinate array coor[k] is zero on exactly one plane (the one through the
# centre voxel), so the expected number of zeros is the sum of the three plane areas.
# Computed from the actual grid size instead of hard-coded numbers from an earlier run.
expected_zero_count = Ny * Nz + Nx * Nz + Nx * Ny
assert expected_zero_count == np.shape(center)[1], "unexpected number of zero coordinates"
expected_zero_count

In [None]:
# At the centre voxel every centred coordinate should be exactly 0.
center_idx = (x_center, y_center, z_center)
print("X center: ", coor[(0, *center_idx)])
print("Y center: ", coor[(1, *center_idx)])
print("Z center: ", coor[(2, *center_idx)])

# Planes through the centre: each coordinate array is zero on its own plane.
print("coor xx at center: \n", coor[0, x_center, :5, :5])
print("coor yy at center: \n", coor[1, :5, y_center, :5])
print("coor zz at center: \n", coor[2, :5, :5, z_center])

Now we need to apply the rotation to the coordinates

In [None]:
# Define a 3x3 rotation matrix (example: 45 degrees around the z-axis)
theta = np.radians(45)
rotation_matrix = np.array([
    [np.cos(theta), -np.sin(theta), 0],
    [np.sin(theta), np.cos(theta), 0],
    [0, 0, 1]
])

# Rotate the centred toy coordinates, i.e. rotate about the toy centre (2, 1, 1).
coor_reshaped = ex_coor.reshape(3, -1)
rotated_coor_reshaped = rotation_matrix @ coor_reshaped

# Reshape back to the original (3, nx, ny, nz) layout
rotated_coor = rotated_coor_reshaped.reshape(3, *ex_coor.shape[1:])

# Shift back into the uncentred frame so both panels share the same axes and the plot
# shows a pure rotation. The toy centre is spelled out because x_center/y_center/
# z_center are reassigned to the MRI voxel centre in an earlier cell.
ex_center = np.array([2, 1, 1]).reshape(3, 1, 1, 1)
rotated_coor_uncentred = rotated_coor + ex_center

# Plot original and rotated coordinates
fig = plt.figure(figsize=(15, 5))

# Original coordinates
ax1 = fig.add_subplot(121, projection='3d')
ax1.scatter(ex_xx, ex_yy, ex_zz, c='b', label='Original Coordinates')
ax1.set_title('Original Coordinates')
ax1.set_xlabel('X-axis')
ax1.set_ylabel('Y-axis')
ax1.set_zlabel('Z-axis')
ax1.legend()

# Rotated coordinates (same frame as the left panel)
ax2 = fig.add_subplot(122, projection='3d')
ax2.scatter(rotated_coor_uncentred[0], rotated_coor_uncentred[1], rotated_coor_uncentred[2],
            c='r', label='Rotated Coordinates')
ax2.set_title('Rotated Coordinates')
ax2.set_xlabel('X-axis')
ax2.set_ylabel('Y-axis')
ax2.set_zlabel('Z-axis')
ax2.legend()

plt.tight_layout()
plt.show()

Now we apply the same operation to the coordinates of the MRI volume.

In [None]:
# Backward mapping: for every voxel (x_i, y_i, z_i) of the OUTPUT grid, find the input
# location it comes from, (x_i', y_i', z_i') = M^-1 (x_i, y_i, z_i), with M = trans_mat.
# Using M itself is only correct when M is its own inverse (e.g. the identity).
coor_reshaped = coor.reshape(3, -1)
inv_trans_mat = np.linalg.inv(trans_mat)
rotated_coor_reshaped = inv_trans_mat @ coor_reshaped

# Reshape back to the original shape
coor_prime = rotated_coor_reshaped.reshape(3, *coor.shape[1:])

print(coor_prime.shape)

In [None]:
# Undo the centring: add the centre voxel back so the coordinates are array indices again.
xx_prime, yy_prime, zz_prime = (
    component + offset
    for component, offset in zip(coor_prime, (x_center, y_center, z_center))
)

In [None]:

# Keep only the output voxels whose source coordinate (x', y', z') lies inside the
# original volume, i.e. 0 <= x' <= Nx-1, 0 <= y' <= Ny-1 and 0 <= z' <= Nz-1;
# the others have nothing to interpolate from.
valid_voxel = (
    (xx_prime >= 0) & (xx_prime <= Nx - 1)
    & (yy_prime >= 0) & (yy_prime <= Ny - 1)
    & (zz_prime >= 0) & (zz_prime <= Nz - 1)
)
x_valid_idx, y_valid_idx, z_valid_idx = np.nonzero(valid_voxel)

In [None]:

# Output volume with the input's size; voxels mapping outside the input stay 0.
image_transformed_data = np.zeros((Nx, Ny, Nz))

# Nearest-neighbour lookup on the regular voxel grid: every output value is copied
# from an existing input voxel.
data_w_coor = RegularGridInterpolator((x, y, z), image, method="nearest")
valid_idx = (x_valid_idx, y_valid_idx, z_valid_idx)
interp_points = np.column_stack(
    [xx_prime[valid_idx], yy_prime[valid_idx], zz_prime[valid_idx]]
)
interp_result = data_w_coor(interp_points)
image_transformed_data[valid_idx] = interp_result

In [None]:
# Total padding added along each axis (left + right).
padding_difference = np.subtract(padded_array.shape, interest_data.shape)
padding_difference

In [None]:
# Dump one 2-D slice of each volume as text for side-by-side inspection.
# NOTE(review): the original volume is sliced along axis 1 but the transformed one
# along axis 2 (with an index computed from shape[1]), and the offsets (200 vs.
# 130 + half the axis-1 padding) do not obviously select the same anatomy —
# confirm both the axis and the index are intended.
orig_idx = mri_1_label_data.shape[1] - 200
np.savetxt("../outputs/original_array.txt", mri_1_label_data[:, orig_idx, :], fmt='%.3f')
trans_idx = image_transformed_data.shape[1] - (130 + padding_difference[1] // 2)
np.savetxt("../outputs/transformed_array.txt", image_transformed_data[:, :, trans_idx], fmt='%.3f')


In [None]:
print(np.max(image_transformed_data))

# Same axis-1 slice (240 from the end) of the padded input and of the transformed volume.
original_slice = padded_array[:, padded_array.shape[1] - 240, :]
transformed_slice = image_transformed_data[:, image_transformed_data.shape[1] - 240, :]

fig, (ax_orig, ax_trans) = plt.subplots(1, 2, figsize=(12, 6))
ax_orig.imshow(original_slice, origin='upper')
ax_orig.set_title('Original Image Slice')

ax_trans.imshow(transformed_slice, origin='upper')
ax_trans.set_title('Transformed Image Slice')

plt.show()

In [None]:
# Wrap the result in a NIfTI image and write it as .nii.gz.
# NOTE(review): np.eye(4) drops the source image's affine (voxel spacing/orientation) — confirm intended.
image_transformed = nib.Nifti1Image(image_transformed_data, affine=np.eye(4))
nib.save(image_transformed, filename='../outputs/image_transformed.nii.gz')