1. Get the brain graph for each participant

In [2]:
import os
import torch
import pandas as pd
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
import gc

from brain_graph import BrainGraph

data_dir = '/projectnb/ec523/projects/proj_GS_LQ_EPB/data/T1w_segmented/HCP/'

checkpoint_path = "brain_graph_dataset.pt"
completed_ids_path = "completed_subjects.txt"

# Resume support: reload graphs written by a previous (possibly crashed) run.
# weights_only=False is required because the checkpoint stores arbitrary
# Python graph objects rather than plain tensors, so only load checkpoints
# this pipeline wrote itself.
if os.path.exists(checkpoint_path):
    participant_graphs = torch.load(checkpoint_path, weights_only=False)
else:
    participant_graphs = []

# IDs whose graphs are already in the checkpoint. The ID log is only trusted
# when the checkpoint it describes exists. Otherwise those subjects would be
# skipped even though their graphs are gone, and they are reprocessed instead.
completed_subjects = set()
if os.path.exists(checkpoint_path) and os.path.exists(completed_ids_path):
    with open(completed_ids_path, 'r') as f:
        completed_subjects = {line.strip() for line in f if line.strip()}

# Load participant age info (subject IDs kept as strings to preserve leading zeros)
df = pd.read_csv(os.path.join(data_dir, 'ages.csv'), dtype={'subject_id': str})
participant_ages = dict(zip(df['subject_id'], df['age']))

# Build the work list: (participant_id, cortical-thickness file, wmparc
# segmentation file, age) for every not-yet-completed subject directory that
# has both images and a known age. Sorted so the processing order is
# reproducible across runs.
participants = []
for participant in sorted(os.listdir(data_dir)):
    if participant in completed_subjects:
        continue
    participant_path = os.path.join(data_dir, participant)
    if not os.path.isdir(participant_path):
        continue
    ct_file = os.path.join(participant_path, f"{participant}_CorThick.nii.gz")
    seg_file = os.path.join(participant_path, f"{participant}_wmparc.nii.gz")
    if not (os.path.exists(ct_file) and os.path.exists(seg_file)):
        continue
    age = participant_ages.get(participant)
    # pandas reads an empty age cell as NaN (not None), so check both.
    if age is None or pd.isna(age):
        continue
    participants.append((participant, ct_file, seg_file, age))

# Function to process one participant
def process_participant(args):
    """Build the brain graph for one participant (runs in a worker process).

    Args:
        args: ``(participant, ct_file, seg_file, age)`` tuple as built in the
            ``participants`` work list.

    Returns:
        ``(participant, graph)`` on success, where ``graph`` is
        ``BrainGraph(...).graph``. On failure it returns
        ``(participant, "ERROR: <ExceptionType>: <message>")``. Failures are
        reported as values rather than raised so one bad subject does not
        abort the batch.
    """
    participant, ct_file, seg_file, age = args
    # Bind before the try so the finally clause never hits an unbound name when
    # the BrainGraph constructor itself raises. Without this, an
    # UnboundLocalError would replace the error tuple and crash the caller.
    graph_builder = None
    try:
        graph_builder = BrainGraph(seg_file, ct_file)
        graph_builder.get_region_stats()
        graph_builder.get_region_centroids()
        graph_builder.create_adjacency_list(k=5)
        graph_builder.get_brain_graph(age)
        return (participant, graph_builder.graph)
    except Exception as e:
        # Include the type name: some exceptions (e.g. MemoryError()) have an
        # empty message, which would leave only "ERROR: " in the log.
        return (participant, f"ERROR: {type(e).__name__}: {e}")
    finally:
        # Drop the (presumably large) builder and collect eagerly, since
        # worker processes are reused across participants.
        graph_builder = None
        gc.collect()

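# Process participants in parallel (8 worker processes) and in batches, creating a fresh process pool for each batch.
# Batching is needed because a single long-running pool kept freezing or crashing after ~250 participants.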
def chunks(lst, n):
    """Lazily yield consecutive slices of ``lst`` with at most ``n`` items each.

    The final slice may be shorter than ``n``. Slices have the same type as
    ``lst`` (a list yields lists).
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

chunk_size = 100
max_workers = 8


def _atomic_torch_save(obj, path):
    """Save ``obj`` with ``torch.save`` via a temp file, then swap it into ``path``.

    ``os.replace`` is atomic on the same filesystem. A crash mid-write
    therefore leaves the previous checkpoint intact instead of a truncated
    file that ``torch.load`` cannot read on restart.
    """
    tmp_path = f"{path}.tmp"
    torch.save(obj, tmp_path)
    os.replace(tmp_path, path)


for batch in chunks(participants, chunk_size):
    # A fresh pool per batch keeps worker-side memory growth bounded.
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(process_participant, p): p[0] for p in batch}
        for future in tqdm(as_completed(futures), total=len(futures)):
            try:
                participant, result = future.result()
            except Exception as e:
                # process_participant reports its own errors as values, so this
                # covers pool-level failures, e.g. BrokenProcessPool when the OS
                # kills a worker. Log and keep going: the subject is not marked
                # completed, so a rerun will retry it.
                participant = futures[future]
                print(f"Failed for participant {participant}: ERROR: {type(e).__name__}: {e}")
                continue
            if isinstance(result, str) and result.startswith("ERROR"):
                print(f"Failed for participant {participant}: {result}")
                continue
            # Graphs are appended in completion order, not input order.
            participant_graphs.append(result)
            # Save the checkpoint first, then log the ID. A crash between the two
            # can only cause a reprocess, never a "completed" ID with no graph.
            # NOTE: this rewrites the whole list per subject (quadratic I/O);
            # acceptable at this dataset size.
            _atomic_torch_save(participant_graphs, checkpoint_path)
            with open(completed_ids_path, 'a') as f:
                f.write(f"{participant}\n")


100%|██████████| 100/100 [09:35<00:00,  5.75s/it]
100%|██████████| 100/100 [09:27<00:00,  5.68s/it]
100%|██████████| 100/100 [09:34<00:00,  5.75s/it]
100%|██████████| 50/50 [05:02<00:00,  6.05s/it]
