In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
import nibabel as nib
from scipy.stats import entropy
import matplotlib.patches as patches  # Import the patches module

In [None]:

# Input volume. An earlier run used the brain Jacobian map instead:
#   '/home/yasmine/OASIS3/CNN/TKDD Paper/MR/sub-OAS30001_ses-d0129_run-02_T1w.nii_reoriented.nii_brain_jacobian.nii.gz'
# NOTE(review): absolute local path — consider deriving it from a configurable data directory.
# NOTE(review): the 'sub-OASsub-' prefix looks doubled — confirm this is the intended filename.
JACOBIAN_PATH = '/home/yasmine/OASIS3/CNN/concat2/sub-OASsub-OAS30001_concat.nii.gz'

img = nib.load(JACOBIAN_PATH)
# Later cells index this array as data[z, y, x]; presumably it is 3-D — TODO confirm for the concat file.
jacobian_data = img.get_fdata()


In [None]:
# Patch size in voxels along array axes (0, 1, 2) of the volume.
patch_size = (16, 16, 16)

# Score every non-overlapping patch on a regular grid. The score ("information
# measure") is the standard deviation of the patch's voxel values. Measures tried
# earlier: the count of values != 1, the entropy of the flattened patch, and the
# number of unique values.
# When a volume dimension is not a multiple of patch_size, the last patch along
# that axis is truncated, so its std is computed over fewer voxels.
patch_information = []  # score of each patch, in grid order
patch_positions = []    # (z, y, x) corner index of each patch, same order as the scores

for z in range(0, jacobian_data.shape[0], patch_size[0]):
    for y in range(0, jacobian_data.shape[1], patch_size[1]):
        for x in range(0, jacobian_data.shape[2], patch_size[2]):
            patch = jacobian_data[z:z + patch_size[0], y:y + patch_size[1], x:x + patch_size[2]]
            patch_information.append(np.std(patch))
            patch_positions.append((z, y, x))


In [None]:
# Rank patches by score, highest first. `sorted_indices[r]` is the grid index of
# the patch at rank r; the two lists below are the positions and scores re-ordered
# by that ranking.
ascending_order = np.argsort(patch_information)
sorted_indices = ascending_order[::-1]

sorted_patches, sorted_information = [], []
for grid_idx in sorted_indices:
    sorted_patches.append(patch_positions[grid_idx])
    sorted_information.append(patch_information[grid_idx])


In [None]:
# Show the highest-scoring patches: the score, the raw values, and the first
# axis-2 plane of each one. There is one entry per grid patch, so printing and
# plotting all of them floods the notebook. The cap is configurable; set it to
# None to show every patch.
N_PATCHES_TO_SHOW = 10

ranked = zip(sorted_information[:N_PATCHES_TO_SHOW], sorted_patches[:N_PATCHES_TO_SHOW])
for i, (information, position) in enumerate(ranked):
    z, y, x = position
    patch = jacobian_data[z:z + patch_size[0], y:y + patch_size[1], x:x + patch_size[2]]

    print(f"Patch {i + 1}: Information = {information}")
    print("Patch Values:")
    print(patch)

    fig, ax = plt.subplots()
    ax.set_title(f"Patch {i + 1}")
    im = ax.imshow(patch[:, :, 0], cmap='gray', interpolation='nearest')
    ax.set_xlabel('array axis 1 (y)')
    ax.set_ylabel('array axis 0 (z)')
    fig.colorbar(im, ax=ax)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations


In [None]:

# Paint each patch's score (voxel std) into a volume with the data's shape, so
# every voxel carries the score of the patch it belongs to.
information_map = np.zeros(jacobian_data.shape)
for (z, y, x), information in zip(sorted_patches, sorted_information):
    information_map[z:z + patch_size[0], y:y + patch_size[1], x:x + patch_size[2]] = information

# Slice 0 along axis 2 would only show patches touching the x = 0 face of the
# volume. Show the middle slice instead.
map_slice = jacobian_data.shape[2] // 2

fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(information_map[:, :, map_slice], cmap='viridis', interpolation='nearest')
fig.colorbar(im, ax=ax, label='Information Measure')
ax.set_title(f'Information Measures for Patches (axis-2 slice {map_slice})')
ax.axis('off')
plt.show()







In [None]:


# Visualize the sorted patches by outlining them on one slice of the volume and
# drawing each patch's own pixels as an inset over its outline.
def highlight_patches(image, patches_to_highlight, slice_index=None):
    """Outline patches on one axis-2 slice of ``image`` and inset each patch's pixels.

    In the displayed slice ``image[:, :, slice_index]`` the rows are array axis 0
    (``z``) and the columns are array axis 1 (``y``). A patch with corner
    ``(z, y, x)`` is therefore outlined at column ``y``, row ``z``.

    Parameters
    ----------
    image : ndarray
        3-D volume indexed ``image[z, y, x]``, as in the patch-extraction loop.
    patches_to_highlight : iterable of (z, y, x)
        Corner indices of the patches to highlight.
    slice_index : int, optional
        Axis-2 index of the background slice. Negative values count from the end.
        Defaults to the middle (along axis 2) of the first highlighted patch, or
        0 when no patches are given.

    Raises
    ------
    IndexError
        If ``slice_index`` is outside the volume's axis-2 range.

    Notes
    -----
    Reads the notebook-level ``patch_size`` global. A patch that does not
    intersect the displayed slice is still outlined at its (z, y) footprint. Its
    inset shows the patch plane nearest to ``slice_index``.
    """
    positions = list(patches_to_highlight)
    if slice_index is None:
        # Default to a slice that actually passes through the top patch.
        slice_index = 0
        if positions:
            slice_index = min(positions[0][2] + patch_size[2] // 2, image.shape[2] - 1)
    # Normalizes negative indices and raises IndexError when out of range.
    slice_index = range(image.shape[2])[slice_index]
    n_rows, n_cols = image.shape[0], image.shape[1]

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(image[:, :, slice_index], cmap='gray', interpolation='nearest')

    for z, y, x in positions:
        # Cut the patch from `image` itself rather than relying on a notebook-level `patch`.
        patch = image[z:z + patch_size[0], y:y + patch_size[1], x:x + patch_size[2]]
        height, width = patch.shape[0], patch.shape[1]  # smaller than patch_size at the volume edge

        # imshow puts the column on the x-axis and the row on the y-axis. Pixel edges sit at index - 0.5.
        rect = patches.Rectangle((y - 0.5, z - 0.5), width, height,
                                 linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)

        # inset_axes bounds are axes fractions measured from the bottom-left, but
        # imshow's default origin='upper' puts row 0 at the top, so the row offset is flipped.
        axins = ax.inset_axes([y / n_cols, 1 - (z + height) / n_rows, width / n_cols, height / n_rows])
        plane = min(max(slice_index - x, 0), patch.shape[2] - 1)
        axins.imshow(patch[:, :, plane], cmap='gray', interpolation='nearest')
        axins.add_patch(patches.Rectangle((-0.5, -0.5), width, height,
                                          linewidth=1, edgecolor='r', facecolor='none'))
        axins.set_xticks([])
        axins.set_yticks([])

    ax.set_title('Highlighted Patches in Original Image')
    plt.show()

# Outline the three highest-scoring patches, then list their scores by rank.
top_n = 3
highlight_patches(jacobian_data, sorted_patches[:top_n])

for rank in range(1, top_n + 1):
    print(f"Patch {rank}: Information = {sorted_information[rank - 1]}")


In [None]:
# Visualize the sorted patches on maximum-intensity projections along each array axis.
def visualize_three_views(image, patches_to_highlight):
    """Outline patches on maximum-intensity projections of ``image`` along each axis.

    Each panel is ``np.max`` over one array axis. A patch with corner ``(z, y, x)``
    spanning ``patch_size`` voxels is outlined using the two axes that remain in
    that panel. In each panel, rows are the lower-numbered remaining axis and
    columns are the higher one, which is how ``np.max`` + ``imshow`` lay them out.

    Parameters
    ----------
    image : ndarray
        3-D volume indexed ``image[z, y, x]``, as in the patch-extraction loop.
    patches_to_highlight : iterable of (z, y, x)
        Corner indices (not centres) of the patches to outline.

    Notes
    -----
    Reads the notebook-level ``patch_size`` global.
    NOTE(review): the panel titles assume a particular voxel orientation. Confirm
    which anatomical plane each projected axis corresponds to for this data.
    """
    # (title, projected axis, row axis, column axis) for each panel.
    views = [
        ('Coronal View', 2, 0, 1),
        ('Sagittal View', 1, 0, 2),
        ('Axial View', 0, 1, 2),
    ]
    positions = list(patches_to_highlight)  # iterated once per panel

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    for ax, (title, proj_axis, row_axis, col_axis) in zip(axes, views):
        ax.imshow(np.max(image, axis=proj_axis), cmap='gray', interpolation='nearest')
        ax.set_title(title)
        ax.set_xlabel(f'array axis {col_axis}')
        ax.set_ylabel(f'array axis {row_axis}')

        for position in positions:
            row0, col0 = position[row_axis], position[col_axis]
            # Clip to the volume so truncated edge patches are drawn at their true size.
            height = min(patch_size[row_axis], image.shape[row_axis] - row0)
            width = min(patch_size[col_axis], image.shape[col_axis] - col0)
            # imshow pixel edges sit at index - 0.5.
            rect = patches.Rectangle((col0 - 0.5, row0 - 0.5), width, height,
                                     linewidth=1, edgecolor='r', facecolor='none')
            ax.add_patch(rect)

    plt.show()

# Outline the nine highest-scoring patches on all three projections. The score is
# the patch's voxel standard deviation, not its entropy.
top_n = 9
visualize_three_views(jacobian_data, sorted_patches[:top_n])