### Temporary: sbatch script

In [None]:
#!/bin/bash
#SBATCH --time=10:00:00
#SBATCH --mem=0
#SBATCH --cpus-per-task=1
#SBATCH --ntasks-per-node=192
#SBATCH --nodes=1
#SBATCH --array=0-47

# Create the logs directory if it does not exist
mkdir -p ./logs/outputs
mkdir -p ./logs/errors

# Activate the virtual environment
module load scipy-stack/2025a
source $HOME/env/bin/activate

# Note: Outputs will go into "outputs" folder, so be sure to backup the results of last step
mkdir images # For whitened PSD plots

# Run the script
inputs=("sub_NVAR008_251016_rest1" "sub_NVAR008_251016_rest1" "sub_NVAR008_251016_rest2" "sub_NVAR008_251016_rest2" "sub_NVAR008_251017_rest1" "sub_NVAR008_251017_rest1" "sub_NVAR008_251017_rest2" "sub_NVAR008_251017_rest2" "sub_NVAR008_251023_rest1" "sub_NVAR008_251023_rest1" "sub_NVAR008_251023_rest2" "sub_NVAR008_251023_rest2" "sub_NVAR008_251113_rest1" "sub_NVAR008_251113_rest1" "sub_NVAR008_251113_rest2" "sub_NVAR008_251113_rest2" "sub_NVAR010_251027_rest1" "sub_NVAR010_251027_rest1" "sub_NVAR010_251027_rest2" "sub_NVAR010_251027_rest2" "sub_NVAR010_251028_rest1" "sub_NVAR010_251028_rest1" "sub_NVAR010_251028_rest2" "sub_NVAR010_251028_rest2" "sub_NVAR010_251103_rest1" "sub_NVAR010_251103_rest1" "sub_NVAR010_251103_rest2" "sub_NVAR010_251103_rest2" "sub_NVAR010_251124_rest1" "sub_NVAR010_251124_rest1" "sub_NVAR010_251124_rest2" "sub_NVAR010_251124_rest2" "sub_NVAR011_251030_rest1" "sub_NVAR011_251030_rest1" "sub_NVAR011_251030_rest2" "sub_NVAR011_251030_rest2" "sub_NVAR011_251031_rest1" "sub_NVAR011_251031_rest1" "sub_NVAR011_251031_rest2" "sub_NVAR011_251031_rest2" "sub_NVAR011_251106_rest1" "sub_NVAR011_251106_rest1" "sub_NVAR011_251106_rest2" "sub_NVAR011_251106_rest2" "sub_NVAR011_251127_rest1" "sub_NVAR011_251127_rest1" "sub_NVAR011_251127_rest2" "sub_NVAR011_251127_rest2")
input=${inputs[$SLURM_ARRAY_TASK_ID]}
python -u /home/isw3/scratch/sprint/0221_post_sprint.py --idx "$input"

# Deactivate the virtual environment
deactivate


### Set up

In [None]:
# Note that code for visualizing sprint psd outputs is in 0205_plot_param.ipynb

# IMPORT PACKAGES
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import pickle as pkl
import mne
import copy as copy
from pathlib import Path
from mne.time_frequency import psd_array_welch
import argparse

# GET ARGUMENT
#prefix = "sub_NVAR008_251016_rest1"
parser = argparse.ArgumentParser()
parser.add_argument("--idx", type=str, required=True)
args = parser.parse_args()
prefix = args.idx
print(f"Now working on {prefix}")

# SET PATHS
base_dir = "C:/meg/0215_NVAR_sprint_fooof"

# CONSTANTS
N_WINDOWS = 115
N_VERTICES = 8196
# Set "example_stc" - this is the original beamformed stc outputted by the generic task-free processing script
# it will be used to get vertex numbers where relevant
example_stc = mne.read_source_estimate(os.path.join(base_dir, "example_stc"))

### Postprocess fooof output: Reformat csvs

In [3]:
# REFORMAT FOOOF CSVS 

def reformat(prefix):
    """
    Input: files prefix_vertex_index.csv (note: vertex index, not vertex number) where each row is window
    Output: files prefix_window_num.csv where each row is vertex index
    """
    for window in range(N_WINDOWS): # Number of windows
        rows = []

        for data_index in range(N_VERTICES): # Number of vertices

            og_csv_path = f"{base_dir}/output/{prefix}_fooof_vertex{data_index}.csv"
            og_df = pd.read_csv(og_csv_path)

            row = og_df.iloc[window].copy()
            row['vertex_index'] = data_index
            rows.append(row)

        pd.DataFrame(rows).to_csv(f"{base_dir}/output/{prefix}_fooof_window{window}.csv", index=False)

reformat(prefix)

### Postprocess SPRINT output: Break into stcs

In [6]:
# CREATE INDIVIDUAL STCS
# One stc for each time window
# You should have N_WINDOWS stcs, and each is PSD vs freq

# Do not log it
# Starts at 2 because I noticed during whitening the model fit worse there
# Ends at 161 because that's 40 Hz

# Load sprint output
with open(os.path.join(base_dir, "output", prefix + "_sprint.pkl"), "rb") as f:
    output = pkl.load(f)

# Check dimensions of sprint output
assert output["TF"].shape[0] == N_VERTICES, f"Expected {N_VERTICES} vertices, got {output['TF'].shape[0]}"
assert output["TF"].shape[1] == N_WINDOWS, f"Expected {N_WINDOWS} windows, got {output['TF'].shape[1]}"
assert len(example_stc.vertices[0]) + len(example_stc.vertices[1]) == N_VERTICES
freq_dim = output["freqs"][2:161]
assert freq_dim[0] > 1.5, f"Unexpected start frequency: {freq_dim[0]}"
assert 39.5 < freq_dim[-1] < 40.5, f"Unexpected end frequency: {freq_dim[-1]}"
print(f"Frequency range: {freq_dim[0]:.2f} to {freq_dim[-1]:.2f} Hz ({len(freq_dim)} bins)")

TF_cropped = output["TF"][:, :, 2:161]

for window in range(TF_cropped.shape[1]): 
    slice_window = TF_cropped[:, window, :]
    new_stc = mne.SourceEstimate(
        data = slice_window, 
        vertices = example_stc.vertices, 
        tmin = 0.5, 
        tstep = 0.25
    )
    new_stc.save(os.path.join(f"{base_dir}/output/{prefix}_sprint_window{window}"), overwrite=True)

AssertionError: Unexpected start frequency: 0.5

### Whitening

In [None]:
# SUBTRACTION
# Do one window at a time
# Input: sprint_window.stc, fooof_window.csv
# Output: _whitened.stc

# freq_dim was computed in an earlier step, but it's worth checking again
freq_dim = output["freqs"][2:161]
assert freq_dim[0] > 1.5, f"Unexpected start frequency: {freq_dim[0]}"
assert 39.5 < freq_dim[-1] < 40.5, f"Unexpected end frequency: {freq_dim[-1]}"
print(f"Frequency range: {freq_dim[0]:.2f} to {freq_dim[-1]:.2f} Hz ({len(freq_dim)} bins)")

for window in range(N_WINDOWS): 
    print(str(window))

    ##### Prep: Load files
    stc = mne.read_source_estimate(f"{base_dir}/output/{prefix}_sprint_window{window}")
    csv = pd.read_csv(f"{base_dir}/output/{prefix}_fooof_window{window}.csv")


    ##### PSD Model (exponential)

    # This is the list where you will put the model for each vertex
    rows = []

    # Get average offset and exponent for this subject (to be used for fixing bad values)
    average_exponent = np.nanmean(csv.iloc[:, 2])
    average_offset = np.nanmean(csv.iloc[:, 1])

    # Loop through rows of csv file
    for i in range(len(csv)): 
        
        # Collect values
        vertex = int(csv.iloc[i, 0])
        exponent = csv.iloc[i, 2]
        offset = csv.iloc[i, 1]

        # Make sure values look okay; if there are errors, replace with subject mean
        if type(exponent) != np.float64 or not np.isfinite(exponent): 
            print(f"error in exponent for subject {prefix} window {window} vertex {vertex}, setting to subject average")
            exponent = average_exponent
        if type(offset) != np.float64 or not np.isfinite(offset): 
            print(f"error in offset for subject {prefix} window {window} vertex {vertex}, setting to subject average")
            offset = average_offset

        # Compute exponential for this vertex
        # In linear space - not in log space
        exponential = 10**(offset - (np.log10(freq_dim))*exponent)

        # Add the exponential for this vertex to the "rows" list
        rows.append(exponential)

    # The final array of all exponentials for all vertices (cleaned up)
    model = np.array(rows).squeeze()

    ##### Subtraction

    # Check that dims are same first
    assert model.shape == stc.data.shape, \
    f"Shape mismatch: model {model.shape} vs stc.data {stc.data.shape}"

    # Whiten PSD: Subtract model PSD from real PSD
    # model: comes in not logged (linear)
    # stc data: comes in not logged (linear)
    # frequency is linear for both
    # subtraction happens in linear-linear space
    whitened_psd = stc.data - model

    # Create a plot for whitened PSD, and save it for review
    plt.figure()
    plt.plot(freq_dim, np.average(model.T, axis=1), label="Aperiodic fit", color="blue")
    plt.plot(stc.times, np.average(stc.data.T, axis=1), label="Original PSD", color="black")
    plt.plot(freq_dim, np.average(whitened_psd.T, axis=1), label="Whitened PSD", color="red")
    plt.ylabel("Power")
    plt.xlabel("Frequency")
    plt.legend()
    plt.title("Window " + str(window) + " detrending")
    plt.savefig(f"{base_dir}/images/{prefix}_{window}", dpi=300, bbox_inches="tight") 
    plt.close()

    # Convert the whitened psd to stc, then save it
    whitened_psd_stc = mne.SourceEstimate(
        data=whitened_psd,
        vertices=stc.vertices,
        tmin=0, 
        tstep=1
        )
    whitened_psd_stc.save(f"{base_dir}/output/{prefix}_whitened_window{window}", overwrite=True)

### Canonical bands

In [None]:
# BREAK WHITENED STCS INTO FREQUENCY BANDS

#bands = {'delta' : [1, 4], 'theta' : [4, 7], 'alpha' : [8, 12], 'beta' : [15, 29], 'low_gamma' : [30, 40]}
bands = {'theta' : [4, 7], 'alpha' : [8, 12]}

for band in bands: 

    low = bands[band][0]
    high = bands[band][1]

    new_data = np.zeros(shape = (N_VERTICES, N_WINDOWS)) # num vertices x num windows

    # Loop through time windows
    for window in range(N_WINDOWS): 

        stc = mne.read_source_estimate(f"{base_dir}/output/{prefix}_whitened_window{window}")

        trimmed_data = stc.data[:, (stc.times <= high) & (stc.times >= low)]
        average_in_this_window = trimmed_data.mean(axis=1) # Compute average across all freqs

        new_data[:, window] = average_in_this_window

    new_stc = mne.SourceEstimate(
        data=new_data, 
        vertices=stc.vertices,
        tmin=0, 
        tstep=1
        )
    
    new_stc.save(f"{base_dir}/output/{prefix}_whitened_{band}", overwrite=True)