In [1]:
# 1_compute_and_save_with_bands_and_plots.ipynb

import numpy as np
from pathlib import Path
import pickle
import mne
from scipy import signal
from functools import lru_cache
import logging
import matplotlib.pyplot as plt
from dataclasses import dataclass
import re
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp
from tqdm import tqdm
import nibabel as nib

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('PSD_Compute')

@dataclass
class Config:
    """Tunable settings for the voxel PSD extraction.

    The four annotated attributes are dataclass fields and can be overridden
    through the constructor, in this order. ``FREQ_BANDS`` carries no
    annotation, so it is a plain class attribute shared by every instance.
    """

    # Root of the project tree; OUTPUT_DIR is resolved relative to it.
    PROJECT_BASE: str = '/home/jaizor/jaizor/xtra'
    # Sampling frequency (Hz) handed to the filter and PSD routines.
    SFREQ: float = 500.0
    # Length in seconds of one PSD analysis window.
    PSD_WINDOW_SEC: float = 2.0
    # Output folder, relative to PROJECT_BASE.
    OUTPUT_DIR: str = 'derivatives/psd_voxel_cache'
    # Band name -> (low, high) frequency edges in Hz.
    FREQ_BANDS = dict(
        Delta=(1, 4),
        Theta=(4, 8),
        Alpha=(8, 12),
        Low_Beta=(12, 20),
        High_Beta=(20, 30),
        Low_Gamma=(30, 50),
        High_Gamma=(50, 100),
    )

class PSDDataExtractor:
    def __init__(self, lcmv_base_dir, coordinates, region_names, config=None):
        self.lcmv_base_dir = Path(lcmv_base_dir)
        self.coordinates = np.array(coordinates)
        self.region_names = region_names
        self.config = config or Config()
        self.project_base = Path(self.config.PROJECT_BASE)
        self.output_dir = self.project_base / self.config.OUTPUT_DIR
        self.fig_dir = self.output_dir / "figures"
        self.fig_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.subject_folders = self._find_subject_folders()
        self.subjects = self._extract_subjects()
        print(f"✅ Found {len(self.subjects)} subjects: {self.subjects}")

    def _find_subject_folders(self):
        return [f for f in self.lcmv_base_dir.iterdir() 
                if f.is_dir() and f.name.startswith('sub') and '_lcmv_' in f.name]

    def _extract_subjects(self):
        subjects = set()
        for folder in self.subject_folders:
            match = re.search(r'(sub\w+)_lcmv_', folder.name)
            if match:
                subjects.add(match.group(1))
        return sorted(subjects)

    def _parse_folder_name(self, folder_name):
        parts = folder_name.lower().split('_')
        try:
            lcmv_idx = parts.index('lcmv')
        except ValueError:
            return None, None
        condition_raw = '_'.join(parts[lcmv_idx+1:-1])
        med_state = parts[-1]
        
        condition_map = {
            'bima': 'bima_activity',
            'hands': 'hands_move',
            'rest': 'rest_eyes_closed',
            'eyes_closed': 'rest_eyes_closed',
            'eyes_open': 'rest_eyes_open'
        }
        condition = condition_map.get(condition_raw, condition_raw)
        if condition in ['bima_activity', 'hands_move', 'rest_eyes_closed', 'rest_eyes_open'] and med_state in ['off', 'on']:
            return condition, med_state
        return None, None

    @lru_cache(maxsize=4)
    def _load_stc(self, folder_path, condition):
        h5_path = folder_path / "source_estimate_LCMV.h5"
        if h5_path.exists():
            return mne.read_source_estimate(str(h5_path))
        meta_path = folder_path / "computation_metadata.pkl"
        if meta_path.exists():
            with open(meta_path, 'rb') as f:
                meta = pickle.load(f)
            stc_path = Path(meta['stc_file'])
            if not stc_path.exists():
                stc_path = folder_path / stc_path.name
            if stc_path.exists():
                return mne.read_source_estimate(str(stc_path))
        raise FileNotFoundError(f"STC not found in {folder_path}")

    @lru_cache(maxsize=1)
    def _load_source_points(self, folder_path):
        points_path = folder_path / "source_space_points_mm.npy"
        if points_path.exists():
            return np.load(points_path)
        raise FileNotFoundError(f"source_space_points_mm.npy not found in {folder_path}")

    def _find_nearest_voxels(self, source_points):
        results = []
        for coord in self.coordinates:
            distances = np.linalg.norm(source_points - coord, axis=1)
            idx = np.argmin(distances)
            results.append((idx, source_points[idx], distances[idx]))
        return results

    def _extract_3x3x3_avg(self, stc, voxel_idx, source_points):
        x_vals = np.unique(np.round(source_points[:, 0], 3))
        y_vals = np.unique(np.round(source_points[:, 1], 3))
        z_vals = np.unique(np.round(source_points[:, 2], 3))
        shape = (len(x_vals), len(y_vals), len(z_vals))

        def find_closest(array, value):
            return np.argmin(np.abs(array - value))

        x, y, z = source_points[voxel_idx]
        x_idx = find_closest(x_vals, x)
        y_idx = find_closest(y_vals, y)
        z_idx = find_closest(z_vals, z)

        neighbors = []
        for dx in [-1, 0, 1]:
            for dy in [-1, 0, 1]:
                for dz in [-1, 0, 1]:
                    nx, ny, nz = x_idx + dx, y_idx + dy, z_idx + dz
                    if (0 <= nx < shape[0]) and (0 <= ny < shape[1]) and (0 <= nz < shape[2]):
                        try:
                            flat_idx = np.ravel_multi_index((nx, ny, nz), shape)
                            if flat_idx < stc.data.shape[0]:
                                neighbors.append(stc.data[flat_idx, :])
                        except ValueError:
                            continue
        return np.mean(neighbors, axis=0) if neighbors else stc.data[voxel_idx, :]

    def _compute_band_powers(self, freqs, psd):
        """Compute band powers from PSD."""
        band_powers = {}
        for band_name, (low, high) in self.config.FREQ_BANDS.items():
            mask = (freqs >= low) & (freqs <= high)
            power = np.mean(psd[mask]) if mask.any() else 0.0
            band_powers[band_name] = {
                'power': float(power),
                'freqs': freqs[mask],
                'psd': psd[mask]
            }
        return band_powers

    def _compute_psd_dual(self, ts):
        """Estimate the PSD of one time series with multitaper and Welch.

        Steps:
          1. Keep only the real part of ``ts``.
          2. Trim to a whole number of ``PSD_WINDOW_SEC`` windows.
          3. High-pass at 0.5 Hz (4th-order Butterworth, applied forwards
             and backwards with ``filtfilt``).
          4. Multitaper: one estimate per non-overlapping window (1-100 Hz),
             averaged across windows. Any failure is logged as a warning and
             the ``'multitaper'`` entry is simply left out.
          5. Welch: Hann windows of one window length with 50 % overlap,
             restricted to 1-100 Hz.

        Returns:
            ``(results, filtered)`` where ``results`` maps ``'multitaper'``
            (when it succeeded) and ``'welch'`` to dicts with ``'freqs'``,
            ``'psd'`` (both float32) and ``'band_powers'``, and ``filtered``
            is the trimmed, high-passed series as float32.

        Raises:
            ValueError: when ``ts`` is shorter than one analysis window.
        """
        # Discard any imaginary component so the filter output is real and
        # the float32 casts below are safe.
        ts = np.real(ts)
        sfreq = self.config.SFREQ
        samples_per_window = int(self.config.PSD_WINDOW_SEC * sfreq)
        if len(ts) < samples_per_window:
            raise ValueError("Time series too short")

        window_count = len(ts) // samples_per_window
        usable = ts[:window_count * samples_per_window]

        # Cut-off is expressed as a fraction of the Nyquist frequency.
        b, a = signal.butter(4, 0.5 / (sfreq * 0.5), btype='high')
        filtered = signal.filtfilt(b, a, usable)

        results = {}

        # --- Multitaper (best effort) ---
        try:
            from mne.time_frequency import psd_array_multitaper
            estimates = [
                psd_array_multitaper(
                    x=segment, sfreq=sfreq, fmin=1, fmax=100,
                    bandwidth=None, adaptive=False, normalization='length',
                    low_bias=True, verbose=False, n_jobs=1
                )
                for segment in filtered.reshape(window_count, samples_per_window)
            ]
            freqs_mt = estimates[-1][1]
            psd_mt_avg = np.mean([psd for psd, _ in estimates], axis=0)
            results['multitaper'] = {
                'freqs': freqs_mt.astype(np.float32),
                'psd': psd_mt_avg.astype(np.float32),
                'band_powers': self._compute_band_powers(freqs_mt, psd_mt_avg)
            }
        except Exception as e:
            logger.warning(f"Multitaper failed: {e}")

        # --- Welch ---
        freqs_w, psd_w = signal.welch(
            filtered, fs=sfreq, window='hann', nperseg=samples_per_window,
            noverlap=samples_per_window // 2, detrend='constant'
        )
        keep = (freqs_w >= 1) & (freqs_w <= 100)
        freqs_w, psd_w = freqs_w[keep], psd_w[keep]
        results['welch'] = {
            'freqs': freqs_w.astype(np.float32),
            'psd': psd_w.astype(np.float32),
            'band_powers': self._compute_band_powers(freqs_w, psd_w)
        }

        return results, filtered.astype(np.float32)


    def _load_t1_image(self):
        t1_path = self.project_base / "derivatives/lcmv/fsaverage/mri/T1.mgz"
        if not t1_path.exists():
            raise FileNotFoundError(f"T1.mgz not found: {t1_path}")
        img = nib.load(str(t1_path))
        return nib.as_closest_canonical(img)

    def _plot_and_save_orthoview(self):
        """Plot the requested coordinates on sagittal and coronal T1 slices.

        The coordinates are mapped to voxel indices with the inverse of the
        image affine. Each panel shows the slice through the mean voxel
        position, with every region marked by a circle and dashed
        cross-hairs. The figure is written to ``<fig_dir>/orthoview.png``.

        This is best-effort: any failure is reported with a printed message
        and the method returns normally so the extraction can continue.

        The slice index is validated against the sliced axis *before* the
        slice is taken, so an out-of-range (or negative) centre gives the
        "Out of Range" panel instead of an IndexError or a silently wrapped
        slice. The figure is closed on every exit path.
        """
        fig = None
        try:
            img = self._load_t1_image()
            data = img.get_fdata()
            inv_affine = np.linalg.inv(img.affine)

            # World coordinates -> voxel indices (homogeneous coordinates).
            coords_array = np.array(self.coordinates)
            homog = np.column_stack([coords_array, np.ones(len(coords_array))])
            voxel_coords = (inv_affine @ homog.T).T[:, :3].round().astype(int)

            # Slice through the mean position of all regions.
            cx, cy, _ = voxel_coords.mean(axis=0).astype(int)

            fig, axes = plt.subplots(1, 2, figsize=(18, 7), dpi=120)

            colors = plt.cm.Set1(np.linspace(0, 1, len(self.region_names)))

            # (sliced axis, slice index, title, axis letter, x label, y label,
            #  voxel components shown on the panel's x and y axes)
            views = [
                (0, cx, "Sagittal", "X", "Y (P ← → A)", "Z (I ← → S)", (1, 2)),
                (1, cy, "Coronal", "Y", "X (L ← → R)", "Z (I ← → S)", (0, 2))
            ]

            for ax, (axis, center, view_name, axis_name, xlabel, ylabel, (ix, iy)) in zip(axes, views):
                if 0 <= center < data.shape[axis]:
                    slice_data = np.take(data, center, axis=axis)
                    ax.imshow(slice_data.T, cmap="gray", origin="lower")
                    ax.set_title(f"{view_name} | {axis_name} = {center}", fontsize=13, fontweight='bold')
                else:
                    ax.text(0.5, 0.5, "Out of Range", ha="center", color="red", transform=ax.transAxes)

                ax.set_xlabel(xlabel, fontsize=11)
                ax.set_ylabel(ylabel, fontsize=11)
                ax.set_xticks([])
                ax.set_yticks([])

                for voxel, color, name in zip(voxel_coords, colors, self.region_names):
                    plot_x, plot_y = voxel[ix], voxel[iy]
                    ax.plot(plot_x, plot_y, 'o', color=color, ms=12, mfc='none', mew=2.5)
                    ax.axvline(plot_x, color=color, ls='--', alpha=0.5, lw=1)
                    ax.axhline(plot_y, color=color, ls='--', alpha=0.5, lw=1)

            fig.suptitle(f"Brain Locations: {len(self.coordinates)} Region(s)", fontsize=16, fontweight='bold', y=0.97)
            fig.subplots_adjust(right=0.85, wspace=0.35, top=0.85, left=0.08)

            # Save figure
            fig_path = self.fig_dir / "orthoview.png"
            fig.savefig(fig_path, dpi=150, bbox_inches='tight')
            print(f"🖼️ Saved orthoview plot: {fig_path}")

        except Exception as e:
            print(f"❌ Could not save orthoview: {e}")
        finally:
            # Release the figure even when plotting or saving failed.
            if fig is not None:
                plt.close(fig)

    def _process_single_subject(self, args):
        folder, condition, med_state, subject_id = args
        try:
            stc = self._load_stc(folder, condition)
            source_points = self._load_source_points(folder)
            voxel_info = self._find_nearest_voxels(source_points)

            data = {
                'subject': subject_id,
                'condition': condition,
                'med_state': med_state,
                'region_names': self.region_names,
                'coordinates_requested': self.coordinates.tolist(),
                'results': []
            }

            for (voxel_idx, actual_coord, distance), name in zip(voxel_info, self.region_names):
                ts_3x3 = self._extract_3x3x3_avg(stc, voxel_idx, source_points)
                psd_results, filtered_ts = self._compute_psd_dual(ts_3x3)

                data['results'].append({
                    'region_name': name,
                    'voxel_idx': int(voxel_idx),
                    'actual_coord': actual_coord.tolist(),
                    'distance': float(distance),
                    'time_series': ts_3x3.astype(np.float32),
                    'filtered_ts': filtered_ts,
                    'psd_dual': psd_results  # includes 'band_powers' now!
                })

            return data

        except Exception as e:
            print(f"❌ Failed {folder.name}: {e}")
            return None

    def run(self, parallel=True, max_workers=None):
        """Run the extraction for every recording and save the results.

        Steps:
          1. Save the orthoview figure (best effort).
          2. Group the recording folders by ``(condition, med_state)``.
             Folders whose name cannot be parsed, or that yield no subject
             id, are skipped.
          3. Process each group, in a process pool when ``parallel`` is true
             and the group has more than one folder, otherwise sequentially.
          4. Save all successful results to ``psd_voxel_all_data.npz`` under
             the key ``'data'`` (an object array of dicts, so it must be
             read back with ``allow_pickle=True``).

        Args:
            parallel: Use a ``ProcessPoolExecutor`` per group.
            max_workers: Pool size; defaults to ``min(cpu_count, 4)``.

        Returns:
            List of per-recording result dicts (failed recordings omitted).
        """
        # Save orthoview first
        self._plot_and_save_orthoview()

        subject_pattern = re.compile(r'(sub\w+)_lcmv_')
        groups = {}
        for folder in self.subject_folders:
            condition, med_state = self._parse_folder_name(folder.name)
            subject_match = subject_pattern.search(folder.name)
            # Skip folders without a subject id instead of crashing on
            # ``None.group(1)`` later.
            if condition and med_state and subject_match:
                groups.setdefault((condition, med_state), []).append(
                    (folder, condition, med_state, subject_match.group(1))
                )

        all_data = []

        if max_workers is None:
            max_workers = min(mp.cpu_count(), 4)

        total = sum(len(args_list) for args_list in groups.values())
        with tqdm(total=total, desc="Processing subjects") as pbar:
            for (condition, med_state), args_list in groups.items():
                print(f"\n📌 Processing {condition}_{med_state}...")

                if parallel and len(args_list) > 1:
                    with ProcessPoolExecutor(max_workers=max_workers) as executor:
                        futures = [executor.submit(self._process_single_subject, args) for args in args_list]
                        for future in as_completed(futures):
                            result = future.result()
                            if result:
                                all_data.append(result)
                            pbar.update(1)
                else:
                    for args in args_list:
                        result = self._process_single_subject(args)
                        if result:
                            all_data.append(result)
                        pbar.update(1)

        # Save all data. Every keyword passed to savez_compressed becomes an
        # array in the archive, so no ``allow_pickle`` keyword here: it would
        # be stored as a stray array rather than act as an option.
        output_file = self.output_dir / "psd_voxel_all_data.npz"
        np.savez_compressed(output_file, data=np.array(all_data, dtype=object))
        print(f"\n🎉 SUCCESS! All data saved to:\n   {output_file}")
        return all_data




# Target regions and their coordinates, kept side by side so the two lists
# handed to the extractor cannot drift out of step.
# NOTE(review): coordinates are presumably in mm, in the same space as
# source_space_points_mm.npy — confirm.
REGION_COORDINATES = {
    "Cortical Hand Knob - Left Anterior": [-37, -9, 67],
    "Cortical Hand Knob - Left Posterior": [-42, -20, 67],
    "Cortical Hand Knob - Right Anterior": [38, -7, 66],
    "Cortical Hand Knob - Right Posterior": [42, -18, 66],
    "Primary Motor Cortex (M1) - Right": [40, -14, 46],
    "Subthalamic Nucleus (STN) - Left": [-11.89, -14.51, -6.40],
    "Subthalamic Nucleus (STN) - Right": [12.53, -13.97, -6.57],
    "Primary Visual Cortex (V1) - Left": [-8, -76, 10],
    "Primary Visual Cortex (V1) - Right": [7, -76, 10],
}

region_names = list(REGION_COORDINATES)
coordinates = list(REGION_COORDINATES.values())

# Build the extractor (creates the output folders and scans for subjects).
extractor = PSDDataExtractor(
    lcmv_base_dir="/home/jaizor/jaizor/xtra/derivatives/lcmv",
    coordinates=coordinates,
    region_names=region_names
)

# Compute everything and write the .npz archive.
results = extractor.run(parallel=True, max_workers=4)

✅ Found 12 subjects: ['sub1', 'sub10', 'sub11', 'sub12', 'sub14', 'sub2', 'sub3', 'sub5', 'sub6', 'sub7', 'sub8', 'sub9']
🖼️ Saved orthoview plot: /home/jaizor/jaizor/xtra/derivatives/psd_voxel_cache/figures/orthoview.png


Processing subjects:   0%|          | 0/75 [00:00<?, ?it/s]


📌 Processing rest_eyes_open_off...


Processing subjects:  15%|█▍        | 11/75 [00:10<00:51,  1.25it/s]


📌 Processing rest_eyes_closed_on...


Processing subjects:  28%|██▊       | 21/75 [00:17<00:35,  1.51it/s]


📌 Processing rest_eyes_closed_off...


Processing subjects:  44%|████▍     | 33/75 [00:27<00:28,  1.47it/s]


📌 Processing hands_move_on...


  'time_series': ts_3x3.astype(np.float32),
Processing subjects:  56%|█████▌    | 42/75 [00:33<00:19,  1.70it/s]


📌 Processing hands_move_off...


Processing subjects:  72%|███████▏  | 54/75 [00:44<00:14,  1.49it/s]


📌 Processing bima_activity_off...


Processing subjects:  88%|████████▊ | 66/75 [00:57<00:09,  1.04s/it]


📌 Processing rest_eyes_open_on...


Processing subjects: 100%|██████████| 75/75 [01:05<00:00,  1.15it/s]



🎉 SUCCESS! All data saved to:
   /home/jaizor/jaizor/xtra/derivatives/psd_voxel_cache/psd_voxel_all_data.npz


In [11]:
# 2_visualize_results.ipynb

import numpy as np
from pathlib import Path

# Re-open the archive written by the compute notebook.
data_file = "/home/jaizor/jaizor/xtra/derivatives/psd_voxel_cache/psd_voxel_all_data.npz"
# allow_pickle is required: the entries are pickled Python dicts.
loaded = np.load(data_file, allow_pickle=True)
all_data = loaded['data']  # object array: one dict per subject-condition

print(f"Loaded {len(all_data)} subject-condition combinations")
first_entry = all_data[0]
print("Example keys:", list(first_entry.keys()))
print("Regions:", first_entry['region_names'])

Loaded 75 subject-condition combinations
Example keys: ['subject', 'condition', 'med_state', 'region_names', 'coordinates_requested', 'results']
Regions: ['Cortical Hand Knob - Left Anterior', 'Cortical Hand Knob - Left Posterior', 'Cortical Hand Knob - Right Anterior', 'Cortical Hand Knob - Right Posterior', 'Primary Motor Cortex (M1) - Right', 'Subthalamic Nucleus (STN) - Left', 'Subthalamic Nucleus (STN) - Right', 'Primary Visual Cortex (V1) - Left', 'Primary Visual Cortex (V1) - Right']
