In [1]:
import nibabel as nib
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from nilearn import plotting
from PIL import Image
import math

# Destination folder for the per-state NIfTI maps written further down;
# exist_ok=True keeps this cell safe to re-run.
output_dir = 'state_maps'
os.makedirs(name=output_dir, exist_ok=True)

In [2]:
# Activation table: each row is treated downstream as one state's activation profile
# (the file name suggests k=14 states), with one '<network>_mean_activation' column per network.
state_vectors = pd.read_csv("state-activation-profiles/mean_activations_k14.csv")

ACTIVATION_SUFFIX = '_mean_activation'

# Derive the activation columns and the network names from the SAME filtered list so that
# network_names[j] always labels activation_cols[j]. The next cell loads masks in
# network_names order and zips them with each state vector, so the two must stay aligned.
# (Filtering names with a substring test but columns with endswith() could misalign them.)
activation_cols = [col for col in state_vectors.columns if col.endswith(ACTIVATION_SUFFIX)]
network_names = [col[:-len(ACTIVATION_SUFFIX)] for col in activation_cols]
assert activation_cols, f"no columns ending in '{ACTIVATION_SUFFIX}' found"

# Each row becomes a list of per-network activations, ordered like network_names
state_vector_list = state_vectors[activation_cols].values.tolist()

In [3]:
# Load each Shirer network mask once, in network_names order, so that mask j pairs with
# activation j of every state vector.
mask_path_template = 'shirer-masks/network-{}_space-MNI152NLin2009cAsym_res-02_thresh-01_mask.nii'
mask_imgs = [nib.load(mask_path_template.format(name)) for name in network_names]
network_masks = [img.get_fdata() for img in mask_imgs]

# Every output map reuses the first mask's affine. That is only valid if all masks share
# one voxel grid, so check it rather than assume it.
reference_affine = mask_imgs[0].affine
assert all(np.allclose(img.affine, reference_affine) for img in mask_imgs), (
    "network masks do not share a common affine"
)

# One state map per row of the activation table (instead of a hard-coded count).
state_map_list = []
for weights in state_vector_list:
    # zip() would silently drop networks if the lengths disagreed
    assert len(weights) == len(network_masks), (
        "state vector length does not match the number of network masks"
    )
    # State map = activation-weighted sum of the network masks
    state_map = sum(w * m for w, m in zip(weights, network_masks))
    state_map_list.append(state_map)

# Save each map as state_<n>.nii.gz, numbered from 1 to match the plotting cells below
for state_num, state_map in enumerate(state_map_list, start=1):
    nifti_img = nib.Nifti1Image(state_map, affine=reference_affine)
    nib.save(nifti_img, f'state_maps/state_{state_num}.nii.gz')

In [4]:
n_states = 14  # number of state maps written above (k=14)
state_map_dir = "state_maps"
output_dir = "state_views"
os.makedirs(output_dir, exist_ok=True)

# Shared display settings so every state is drawn on the same colour scale and the
# panels in the summary grid are directly comparable.
CUT_COORDS = (0, 0, 0)   # (x, y, z) cut point in MNI mm; pass None to let nilearn choose
COLOR_LIMIT = 3          # colour scale spans [-COLOR_LIMIT, COLOR_LIMIT]
DISPLAY_THRESHOLD = 0.5  # values below this in absolute value are not drawn

for state_num in range(1, n_states + 1):
    state_path = os.path.join(state_map_dir, f"state_{state_num}.nii.gz")
    output_path = os.path.join(output_dir, f"state_{state_num}_views.png")

    # With output_file set, nilearn writes the figure to disk and closes the display
    plotting.plot_stat_map(
        state_path,
        display_mode='ortho',  # sagittal, coronal and axial views
        cut_coords=CUT_COORDS,
        title=f"State {state_num}",
        output_file=output_path,
        colorbar=True,
        vmin=-COLOR_LIMIT,
        vmax=COLOR_LIMIT,
        threshold=DISPLAY_THRESHOLD,
    )

In [6]:
# Tile the per-state PNGs written by the plotting cell into one summary image.
input_dir = "state_views"
output_path = "state_views/state_maps_summary.png"

# Load one image per state (n_states comes from the plotting cell above). The with-block
# closes each file handle; convert() returns an independent, fully loaded RGB copy that
# matches the canvas mode.
images = []
for i in range(n_states):
    with Image.open(os.path.join(input_dir, f"state_{i+1}_views.png")) as img:
        images.append(img.convert('RGB'))
assert images, "no per-state images to combine"

# Size every grid cell to the largest image so panels never overlap or get clipped,
# even if the exported PNGs differ slightly in size.
img_width = max(img.width for img in images)
img_height = max(img.height for img in images)

# Grid layout
n_cols = 4
n_rows = math.ceil(len(images) / n_cols)

# Blank white canvas; unused cells in the last row stay white
summary_image = Image.new('RGB', (img_width * n_cols, img_height * n_rows), color=(255, 255, 255))

# Paste row-major: image idx lands at row idx // n_cols, column idx % n_cols
for idx, img in enumerate(images):
    row, col = divmod(idx, n_cols)
    summary_image.paste(img, (col * img_width, row * img_height))

# Save the final image
summary_image.save(output_path)
print(f"Saved summary to: {output_path}")

Saved summary to: state_views/state_maps_summary.png
