# Intrasession

In [None]:
import nrrd
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import zscore
from scipy.spatial.distance import jensenshannon

def load_nrrd(file_path):
    """Read one NRRD file from disk.

    Parameters
    ----------
    file_path : str or path-like
        Location of the ``.nrrd`` file to read.

    Returns
    -------
    tuple
        ``(data, header)`` as produced by ``nrrd.read``.
    """
    volume, meta = nrrd.read(file_path)
    return (volume, meta)

def get_segment_data(image, segmentation, segment_id):
    """Select the voxels of ``image`` whose label in ``segmentation`` is ``segment_id``.

    The label array is compared element-wise with ``segment_id`` and the
    resulting boolean mask indexes ``image``. When both arrays have the same
    shape, the result is a flat 1-D array of the selected intensities.
    """
    in_segment = segmentation == segment_id
    return image[in_segment]

def calculate_jsd(p, q, bins=100):
    """Return the base-2 Jensen-Shannon divergence between two samples.

    Both samples are histogrammed with identical bin edges spanning the
    *pooled* minimum and maximum, so no value from either sample is dropped.
    ``scipy.spatial.distance.jensenshannon`` normalises each histogram to a
    probability vector and returns the JS *distance*. Squaring it gives the
    divergence, which lies in [0, 1] for base 2.

    Parameters
    ----------
    p, q : array-like
        Samples to compare (e.g. voxel intensities of one segment). They are
        flattened.
    bins : int, optional
        Number of histogram bins (default 100).

    Returns
    -------
    float
        0 for identical histograms, 1 for non-overlapping ones.

    Raises
    ------
    ValueError
        If either sample is empty.
    """
    p = np.asarray(p).ravel()
    q = np.asarray(q).ravel()
    if p.size == 0 or q.size == 0:
        raise ValueError("calculate_jsd requires two non-empty samples")

    # Shared range: binning q over p's range alone would discard every q value
    # outside it (and give NaN when q lies entirely outside p's range).
    lo = min(p.min(), q.min())
    hi = max(p.max(), q.max())
    p_hist, _ = np.histogram(p, bins=bins, range=(lo, hi))
    q_hist, _ = np.histogram(q, bins=bins, range=(lo, hi))
    # Raw counts are enough: jensenshannon rescales each vector to sum to 1.
    return float(jensenshannon(p_hist, q_hist, base=2) ** 2)


def summarize_segment(segment_data):
    """Compute descriptive statistics for one segment's voxel values.

    Returns a dict with keys "Mean", "Std Dev" (population, ``np.std``
    default), "Min", "Max" and "Voxel Count" (``len`` of the input), in that
    order.
    """
    reducers = (
        ("Mean", np.mean),
        ("Std Dev", np.std),
        ("Min", np.min),
        ("Max", np.max),
        ("Voxel Count", len),
    )
    return {label: reduce_fn(segment_data) for label, reduce_fn in reducers}

def _side_stats(prefix, values):
    """Return the "<prefix> Mean" ... "<prefix> Voxel Count" summary columns for one side.

    An empty ``values`` array (label absent from that segmentation) yields NaN
    statistics and a voxel count of 0 instead of raising.
    """
    if values.size == 0:
        stats = dict.fromkeys(("Mean", "Std Dev", "Min", "Max"), float("nan"))
        stats["Voxel Count"] = 0
    else:
        stats = summarize_segment(values)
    return {f"{prefix} {key}": value for key, value in stats.items()}


def _plot_pair(ax, test_values, retest_values, jsd, n_bins=100):
    """Overlay Test/Retest density histograms on ``ax`` and annotate the JSD.

    Both histograms use the same bin edges (spanning the pooled data), so the
    bars are directly comparable. Returns ``(peak, handles)``: ``peak`` is the
    tallest bar drawn (0.0 if nothing was drawn) and ``handles`` holds the
    Test and Retest bar containers (``None`` for an empty side).
    """
    series = (
        (test_values, 'Test', 'blue'),
        (retest_values, 'Retest', 'orange'),
    )
    handles = [None, None]
    peak = 0.0
    non_empty = [values for values, _, _ in series if values.size]
    if non_empty:
        edges = np.histogram_bin_edges(np.concatenate(non_empty), bins=n_bins)
        for i, (values, label, color) in enumerate(series):
            if values.size == 0:
                continue
            heights, _, bars = ax.hist(
                values, bins=edges, alpha=0.6, label=label, color=color, density=True
            )
            peak = max(peak, float(heights.max()))
            handles[i] = bars
    text = f"JSD = {jsd:.4f}" if np.isfinite(jsd) else "JSD = n/a (empty segment)"
    ax.text(0.95, 0.9, text, transform=ax.transAxes, ha='right', va='top', fontsize=10)
    return peak, handles


def plot_combined_histograms(
    test_image,
    retest_image,
    test_image_z,
    retest_image_z,
    test_segmentation,
    retest_segmentation,
    segment_names,
    output_stem="intrasession_histogram",
):
    """
    Create a combined plot of histograms for unnormalized and normalized data.

    Normalization is done at the image level (by the caller) before segment
    extraction. Row 0 shows raw intensities and row 1 the z-scored ones. There
    is one column per segment.

    Parameters
    ----------
    test_image, retest_image : ndarray
        Raw image volumes.
    test_image_z, retest_image_z : ndarray
        The same volumes after image-level normalisation.
    test_segmentation, retest_segmentation : ndarray
        Label volumes. They are assumed to match the image shapes
        (TODO confirm).
    segment_names : dict
        Mapping of label value -> display name. Must not be empty.
    output_stem : str, optional
        File name stem for the saved figure. It is written as
        ``<stem>.png/.pdf/.svg/.eps``.

    Returns
    -------
    pandas.DataFrame
        One row per (segment, "Unnormalized"/"Normalized") with Test/Retest
        mean, std, min, max, voxel count and the JSD. A segment missing from
        either image gets NaN statistics and NaN JSD.

    Raises
    ------
    ValueError
        If ``segment_names`` is empty.
    """
    if not segment_names:
        raise ValueError("segment_names is empty; nothing to plot")

    # squeeze=False keeps `axes` 2-D even when there is only one segment.
    fig, axes = plt.subplots(
        2, len(segment_names), figsize=(18, 12), sharey='row', squeeze=False
    )
    rows = (
        ("Unnormalized", test_image, retest_image),
        ("Normalized", test_image_z, retest_image_z),
    )
    max_density = [0.0, 0.0]
    legend_handles = [None, None]
    segment_summaries = []

    for col, (segment_id, segment_name) in enumerate(segment_names.items()):
        for row, (kind, test_img, retest_img) in enumerate(rows):
            test_values = get_segment_data(test_img, test_segmentation, segment_id)
            retest_values = get_segment_data(retest_img, retest_segmentation, segment_id)

            # The JSD is undefined when one side has no voxels for this label.
            if test_values.size and retest_values.size:
                jsd = calculate_jsd(test_values, retest_values)
            else:
                jsd = float("nan")

            summary = {"Segment": segment_name, "Type": kind}
            summary.update(_side_stats("Test", test_values))
            summary.update(_side_stats("Retest", retest_values))
            summary["JSD"] = jsd
            segment_summaries.append(summary)

            peak, handles = _plot_pair(axes[row, col], test_values, retest_values, jsd)
            max_density[row] = max(max_density[row], peak)
            for i, handle in enumerate(handles):
                if legend_handles[i] is None and handle is not None:
                    legend_handles[i] = handle

        axes[0, col].set_title(f"Segment {segment_name}", fontweight='bold')

    # Common y-limit per row. Fall back to 1.0 so an all-empty row does not
    # produce a degenerate (0, 0) range.
    for row in range(2):
        top = max_density[row] * 1.1 if max_density[row] > 0 else 1.0
        for ax in axes[row]:
            ax.set_ylim(0, top)

    for row, row_label in enumerate(("Non-Normalised", "Normalised")):
        axes[row, 0].set_ylabel("Density", fontsize=12, labelpad=10)
        axes[row, 0].annotate(
            row_label, xy=(-0.4, 0.5), xycoords='axes fraction',
            fontweight='bold', fontsize=16, ha='center', va='center', rotation='vertical'
        )

    # Pass explicit handles so the legend never depends on artist order.
    shown = [(h, lab) for h, lab in zip(legend_handles, ('Test', 'Retest')) if h is not None]
    if shown:
        handles, labels = zip(*shown)
        legend = fig.legend(
            handles, labels, loc='lower center', fontsize=14, frameon=True,
            framealpha=1, ncol=2, bbox_to_anchor=(0.53, 0.47)
        )
        for text in legend.get_texts():
            text.set_fontweight('bold')

    fig.tight_layout(rect=[0.05, 0.05, 0.95, 0.95])
    fig.subplots_adjust(hspace=0.4)
    # NOTE: the EPS (PostScript) backend cannot render alpha; bars appear opaque there.
    for fmt in ("png", "pdf", "svg", "eps"):
        fig.savefig(f"{output_stem}.{fmt}", dpi=300, format=fmt)
    plt.show()

    return pd.DataFrame(segment_summaries)

def _zscore_image(image, label):
    """Return ``image`` standardised by its own global mean and population std.

    Raises ValueError when the std is zero or non-finite. In that case
    z-scoring would produce inf/NaN voxels that break the later histograms.
    """
    mean = np.mean(image)
    std = np.std(image)
    if not np.isfinite(std) or std == 0:
        raise ValueError(f"cannot z-score the {label} image: global std is {std!r}")
    return (image - mean) / std


def _segment_label_mapping(header):
    """Map label value -> segment name from a segmentation NRRD header.

    Reads keys of the form ``SegmentN_Name`` and ``SegmentN_LabelValue`` (the
    layout presumably written by 3D Slicer's .seg.nrrd export; confirm for
    other sources). Label 0 (background) is excluded. If two segments share a
    label value, the later one wins.
    """
    names = {
        int(key.split('_')[0].replace('Segment', '')): value
        for key, value in header.items()
        if key.endswith('_Name') and key.startswith('Segment')
    }
    mapping = {}
    for index, name in names.items():
        label = int(header[f"Segment{index}_LabelValue"])
        if label != 0:
            mapping[label] = name
    return mapping


def process_and_visualize(test_image_path, retest_image_path, test_segmentation_path, retest_segmentation_path):
    """Load images and segmentations, plot per-segment histograms and save statistics.

    Each image is z-scored globally (whole volume) before segments are
    extracted. The summary DataFrame returned by ``plot_combined_histograms``
    is written to 'Intrasession_segment_statistics.xlsx', shown, and returned.

    Raises
    ------
    ValueError
        If an image is constant (std 0), or if the test segmentation header
        defines no non-background segments.
    """
    test_image, _ = load_nrrd(test_image_path)
    retest_image, _ = load_nrrd(retest_image_path)
    test_segmentation, test_header = load_nrrd(test_segmentation_path)
    retest_segmentation, _ = load_nrrd(retest_segmentation_path)

    test_image_z = _zscore_image(test_image, "test")
    retest_image_z = _zscore_image(retest_image, "retest")

    # NOTE(review): label values come from the *test* segmentation header only
    # and are also applied to the retest segmentation. Confirm both
    # segmentations use the same label values.
    segment_label_mapping = _segment_label_mapping(test_header)
    if not segment_label_mapping:
        raise ValueError(
            f"no non-background segments found in header of {test_segmentation_path!r}"
        )

    print(f"Segment Names and Labels (excluding background): {segment_label_mapping}")

    segment_summary_df = plot_combined_histograms(
        test_image,
        retest_image,
        test_image_z,
        retest_image_z,
        test_segmentation,
        retest_segmentation,
        segment_label_mapping
    )

    segment_summary_df.to_excel('Intrasession_segment_statistics.xlsx')
    try:
        show = display  # provided by IPython / Jupyter
    except NameError:
        show = print  # plain `python script.py` run: fall back to text output
    show(segment_summary_df)

    return segment_summary_df


# Placeholder input locations: point these at the real .nrrd files before running.
test_image_path = r"\PATH.nrrd"
retest_image_path = r"\PATH.nrrd"
test_segmentation_path = r"\PATH.nrrd"
retest_segmentation_path = r"\PATH.nrrd"

if __name__ == "__main__":
    segment_summary_df = process_and_visualize(
        test_image_path=test_image_path,
        retest_image_path=retest_image_path,
        test_segmentation_path=test_segmentation_path,
        retest_segmentation_path=retest_segmentation_path,
    )


# Intersession

In [None]:
import nrrd
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import zscore
from scipy.spatial.distance import jensenshannon

def load_nrrd(file_path):
    """Read an NRRD file and hand back its voxel array and header.

    Returns
    -------
    tuple
        ``(data, header)`` as produced by ``nrrd.read(file_path)``.
    """
    array_data, nrrd_header = nrrd.read(file_path)
    return (array_data, nrrd_header)

def get_segment_data(image, segmentation, segment_id):
    """Return the intensities of ``image`` at voxels labelled ``segment_id``.

    A boolean mask ``segmentation == segment_id`` indexes ``image``. For
    same-shaped arrays, the result is a flat 1-D array.
    """
    label_mask = segmentation == segment_id
    selected_voxels = image[label_mask]
    return selected_voxels

def calculate_jsd(p, q, bins=100):
    """Return the base-2 Jensen-Shannon divergence between two samples.

    Both samples are histogrammed with identical bin edges spanning the
    *pooled* minimum and maximum, so no value from either sample is dropped.
    ``scipy.spatial.distance.jensenshannon`` normalises each histogram to a
    probability vector and returns the JS *distance*. Squaring it gives the
    divergence, which lies in [0, 1] for base 2.

    Parameters
    ----------
    p, q : array-like
        Samples to compare (e.g. voxel intensities of one segment). They are
        flattened.
    bins : int, optional
        Number of histogram bins (default 100).

    Returns
    -------
    float
        0 for identical histograms, 1 for non-overlapping ones.

    Raises
    ------
    ValueError
        If either sample is empty.
    """
    p = np.asarray(p).ravel()
    q = np.asarray(q).ravel()
    if p.size == 0 or q.size == 0:
        raise ValueError("calculate_jsd requires two non-empty samples")

    # Shared range: binning q over p's range alone would discard every q value
    # outside it (and give NaN when q lies entirely outside p's range).
    lo = min(p.min(), q.min())
    hi = max(p.max(), q.max())
    p_hist, _ = np.histogram(p, bins=bins, range=(lo, hi))
    q_hist, _ = np.histogram(q, bins=bins, range=(lo, hi))
    # Raw counts are enough: jensenshannon rescales each vector to sum to 1.
    return float(jensenshannon(p_hist, q_hist, base=2) ** 2)


def summarize_segment(segment_data):
    """Build a statistics dict for one segment's voxel values.

    Keys, in order: "Mean", "Std Dev" (population, ``np.std`` default),
    "Min", "Max", "Voxel Count" (``len`` of the input).
    """
    labels = ("Mean", "Std Dev", "Min", "Max", "Voxel Count")
    values = (
        np.mean(segment_data),
        np.std(segment_data),
        np.min(segment_data),
        np.max(segment_data),
        len(segment_data),
    )
    return dict(zip(labels, values))

def _side_stats(prefix, values):
    """Return the "<prefix> Mean" ... "<prefix> Voxel Count" summary columns for one side.

    An empty ``values`` array (label absent from that segmentation) yields NaN
    statistics and a voxel count of 0 instead of raising.
    """
    if values.size == 0:
        stats = dict.fromkeys(("Mean", "Std Dev", "Min", "Max"), float("nan"))
        stats["Voxel Count"] = 0
    else:
        stats = summarize_segment(values)
    return {f"{prefix} {key}": value for key, value in stats.items()}


def _plot_pair(ax, test_values, retest_values, jsd, n_bins=100):
    """Overlay Test/Retest density histograms on ``ax`` and annotate the JSD.

    Both histograms use the same bin edges (spanning the pooled data), so the
    bars are directly comparable. Returns ``(peak, handles)``: ``peak`` is the
    tallest bar drawn (0.0 if nothing was drawn) and ``handles`` holds the
    Test and Retest bar containers (``None`` for an empty side).
    """
    series = (
        (test_values, 'Test', 'blue'),
        (retest_values, 'Retest', 'orange'),
    )
    handles = [None, None]
    peak = 0.0
    non_empty = [values for values, _, _ in series if values.size]
    if non_empty:
        edges = np.histogram_bin_edges(np.concatenate(non_empty), bins=n_bins)
        for i, (values, label, color) in enumerate(series):
            if values.size == 0:
                continue
            heights, _, bars = ax.hist(
                values, bins=edges, alpha=0.6, label=label, color=color, density=True
            )
            peak = max(peak, float(heights.max()))
            handles[i] = bars
    text = f"JSD = {jsd:.4f}" if np.isfinite(jsd) else "JSD = n/a (empty segment)"
    ax.text(0.95, 0.9, text, transform=ax.transAxes, ha='right', va='top', fontsize=10)
    return peak, handles


def plot_combined_histograms(
    test_image,
    retest_image,
    test_image_z,
    retest_image_z,
    test_segmentation,
    retest_segmentation,
    segment_names,
    output_stem="intersession_histogram",
):
    """
    Create a combined plot of histograms for unnormalized and normalized data.

    Normalization is done at the image level (by the caller) before segment
    extraction. Row 0 shows raw intensities and row 1 the z-scored ones. There
    is one column per segment.

    Parameters
    ----------
    test_image, retest_image : ndarray
        Raw image volumes.
    test_image_z, retest_image_z : ndarray
        The same volumes after image-level normalisation.
    test_segmentation, retest_segmentation : ndarray
        Label volumes. They are assumed to match the image shapes
        (TODO confirm).
    segment_names : dict
        Mapping of label value -> display name. Must not be empty.
    output_stem : str, optional
        File name stem for the saved figure. It is written as
        ``<stem>.png/.pdf/.svg/.eps``.

    Returns
    -------
    pandas.DataFrame
        One row per (segment, "Unnormalized"/"Normalized") with Test/Retest
        mean, std, min, max, voxel count and the JSD. A segment missing from
        either image gets NaN statistics and NaN JSD.

    Raises
    ------
    ValueError
        If ``segment_names`` is empty.
    """
    if not segment_names:
        raise ValueError("segment_names is empty; nothing to plot")

    # squeeze=False keeps `axes` 2-D even when there is only one segment.
    fig, axes = plt.subplots(
        2, len(segment_names), figsize=(18, 12), sharey='row', squeeze=False
    )
    rows = (
        ("Unnormalized", test_image, retest_image),
        ("Normalized", test_image_z, retest_image_z),
    )
    max_density = [0.0, 0.0]
    legend_handles = [None, None]
    segment_summaries = []

    for col, (segment_id, segment_name) in enumerate(segment_names.items()):
        for row, (kind, test_img, retest_img) in enumerate(rows):
            test_values = get_segment_data(test_img, test_segmentation, segment_id)
            retest_values = get_segment_data(retest_img, retest_segmentation, segment_id)

            # The JSD is undefined when one side has no voxels for this label.
            if test_values.size and retest_values.size:
                jsd = calculate_jsd(test_values, retest_values)
            else:
                jsd = float("nan")

            summary = {"Segment": segment_name, "Type": kind}
            summary.update(_side_stats("Test", test_values))
            summary.update(_side_stats("Retest", retest_values))
            summary["JSD"] = jsd
            segment_summaries.append(summary)

            peak, handles = _plot_pair(axes[row, col], test_values, retest_values, jsd)
            max_density[row] = max(max_density[row], peak)
            for i, handle in enumerate(handles):
                if legend_handles[i] is None and handle is not None:
                    legend_handles[i] = handle

        axes[0, col].set_title(f"Segment {segment_name}", fontweight='bold')

    # Common y-limit per row. Fall back to 1.0 so an all-empty row does not
    # produce a degenerate (0, 0) range.
    for row in range(2):
        top = max_density[row] * 1.1 if max_density[row] > 0 else 1.0
        for ax in axes[row]:
            ax.set_ylim(0, top)

    for row, row_label in enumerate(("Non-Normalised", "Normalised")):
        axes[row, 0].set_ylabel("Density", fontsize=12, labelpad=10)
        axes[row, 0].annotate(
            row_label, xy=(-0.4, 0.5), xycoords='axes fraction',
            fontweight='bold', fontsize=16, ha='center', va='center', rotation='vertical'
        )

    # Pass explicit handles so the legend never depends on artist order.
    shown = [(h, lab) for h, lab in zip(legend_handles, ('Test', 'Retest')) if h is not None]
    if shown:
        handles, labels = zip(*shown)
        legend = fig.legend(
            handles, labels, loc='lower center', fontsize=14, frameon=True,
            framealpha=1, ncol=2, bbox_to_anchor=(0.53, 0.47)
        )
        for text in legend.get_texts():
            text.set_fontweight('bold')

    fig.tight_layout(rect=[0.05, 0.05, 0.95, 0.95])
    fig.subplots_adjust(hspace=0.4)
    # NOTE: the EPS (PostScript) backend cannot render alpha; bars appear opaque there.
    for fmt in ("png", "pdf", "svg", "eps"):
        fig.savefig(f"{output_stem}.{fmt}", dpi=300, format=fmt)
    plt.show()

    return pd.DataFrame(segment_summaries)

def _zscore_image(image, label):
    """Return ``image`` standardised by its own global mean and population std.

    Raises ValueError when the std is zero or non-finite. In that case
    z-scoring would produce inf/NaN voxels that break the later histograms.
    """
    mean = np.mean(image)
    std = np.std(image)
    if not np.isfinite(std) or std == 0:
        raise ValueError(f"cannot z-score the {label} image: global std is {std!r}")
    return (image - mean) / std


def _segment_label_mapping(header):
    """Map label value -> segment name from a segmentation NRRD header.

    Reads keys of the form ``SegmentN_Name`` and ``SegmentN_LabelValue`` (the
    layout presumably written by 3D Slicer's .seg.nrrd export; confirm for
    other sources). Label 0 (background) is excluded. If two segments share a
    label value, the later one wins.
    """
    names = {
        int(key.split('_')[0].replace('Segment', '')): value
        for key, value in header.items()
        if key.endswith('_Name') and key.startswith('Segment')
    }
    mapping = {}
    for index, name in names.items():
        label = int(header[f"Segment{index}_LabelValue"])
        if label != 0:
            mapping[label] = name
    return mapping


def process_and_visualize(test_image_path, retest_image_path, test_segmentation_path, retest_segmentation_path):
    """Load images and segmentations, plot per-segment histograms and save statistics.

    Each image is z-scored globally (whole volume) before segments are
    extracted. The summary DataFrame returned by ``plot_combined_histograms``
    is written to 'Intersession_segment_statistics.xlsx', shown, and returned.

    Raises
    ------
    ValueError
        If an image is constant (std 0), or if the test segmentation header
        defines no non-background segments.
    """
    test_image, _ = load_nrrd(test_image_path)
    retest_image, _ = load_nrrd(retest_image_path)
    test_segmentation, test_header = load_nrrd(test_segmentation_path)
    retest_segmentation, _ = load_nrrd(retest_segmentation_path)

    # Global z-score normalisation on the entire images.
    test_image_z = _zscore_image(test_image, "test")
    retest_image_z = _zscore_image(retest_image, "retest")

    # NOTE(review): label values come from the *test* segmentation header only
    # and are also applied to the retest segmentation. Confirm both
    # segmentations use the same label values.
    segment_label_mapping = _segment_label_mapping(test_header)
    if not segment_label_mapping:
        raise ValueError(
            f"no non-background segments found in header of {test_segmentation_path!r}"
        )

    print(f"Segment Names and Labels (excluding background): {segment_label_mapping}")

    segment_summary_df = plot_combined_histograms(
        test_image,
        retest_image,
        test_image_z,
        retest_image_z,
        test_segmentation,
        retest_segmentation,
        segment_label_mapping
    )

    segment_summary_df.to_excel('Intersession_segment_statistics.xlsx')
    try:
        show = display  # provided by IPython / Jupyter
    except NameError:
        show = print  # plain `python script.py` run: fall back to text output
    show(segment_summary_df)

    return segment_summary_df

# Placeholder input locations: replace with the real .nrrd files before running.
test_image_path = r"\PATH.nrrd"
retest_image_path = r"\PATH.nrrd"
test_segmentation_path = r"\PATH.nrrd"
retest_segmentation_path = r"\PATH.nrrd"

if __name__ == "__main__":
    segment_summary_df = process_and_visualize(
        test_image_path=test_image_path,
        retest_image_path=retest_image_path,
        test_segmentation_path=test_segmentation_path,
        retest_segmentation_path=retest_segmentation_path,
    )
