# In [None]:
import argparse
import functools
import gzip
import io
import os
import sys

import matplotlib.pyplot as plt
import nibabel as nb
import numpy as np
import pandas as pd
import s3fs
import scipy.io
from nsd_access import NSDAccess

@functools.lru_cache(maxsize=8)
def _read_behavior_table(subject):
    """Download and parse *subject*'s full ``responses.tsv`` from the public NSD bucket.

    Cached per subject so callers that iterate over sessions fetch and parse
    the file once instead of once per session. The cached DataFrame is never
    handed out directly; ``read_behavior_from_s3`` returns filtered copies.
    """
    s3 = s3fs.S3FileSystem(anon=True)
    behavior_file_path = f"s3://natural-scenes-dataset/nsddata/ppdata/{subject}/behav/responses.tsv"
    with s3.open(behavior_file_path, 'rb') as f:
        return pd.read_csv(f, delimiter='\t')


def read_behavior_from_s3(subject, session_index):
    """Return the behaviour rows of *subject* whose ``SESSION`` equals *session_index*.

    Args:
        subject: NSD subject folder name, e.g. ``'subj01'``.
        session_index: value compared as-is against the ``SESSION`` column
            (``main`` passes 1-based indices 1..37).

    Returns:
        A new DataFrame (boolean-mask selection, so a copy) holding only that
        session's rows, with the original row index preserved.
    """
    behavior = _read_behavior_table(subject)
    session_behavior = behavior[behavior['SESSION'] == session_index]
    return session_behavior

def _session_file_is_valid(session_file):
    """Return True when a cached per-session ``.npy`` can be reused.

    The file is rejected, and the reason printed, when it is missing, cannot be
    loaded, or contains only zeros or only NaNs.
    """
    if not os.path.exists(session_file):
        return False
    try:
        existing = np.load(session_file, mmap_mode='r')
        is_empty = np.all(existing == 0) or np.isnan(existing).all()
    except Exception as e:  # e.g. a truncated file left by an interrupted run
        print(f"Error loading existing {session_file}, re-downloading. Error: {e}")
        return False
    if is_empty:
        print(f"Session file {session_file} is empty or invalid. Re-downloading.")
        return False
    print(f"Already downloaded: {session_file}")
    return True


def _download_session_betas(fs, subject, session_index, session_file):
    """Fetch one session's beta NIfTI from S3 and save it to *session_file*.

    The 4-D volume is transposed with ``(3, 1, 2, 0)``: the last axis (trials)
    moves first, and the spatial axes follow in file order 1, 2, 0. That is the
    same spatial order ``_load_atlas`` applies to the atlas.

    Returns True on success. On any error, prints it and returns False.
    """
    s3_path = f's3://natural-scenes-dataset/nsddata_betas/ppdata/{subject}/func1pt8mm/betas_fithrf_GLMdenoise_RR/betas_session{session_index:02d}.nii.gz'
    try:
        # Both the S3 handle and the gzip wrapper are closed on exit.
        with fs.open(s3_path, 'rb') as f, gzip.GzipFile(fileobj=f) as gz:
            img = nb.Nifti1Image.from_bytes(gz.read())
            beta_data = img.get_fdata(dtype=np.float32)
        beta_data = np.transpose(beta_data, (3, 1, 2, 0))
        np.save(session_file, beta_data)
    except Exception as e:
        print(f"Error downloading session {session_index} from {s3_path}: {e}")
        return False
    print(f"Saved session {session_index} to {session_file}")
    return True


def _load_atlas(fs, subject, atlasname):
    """Load the func1pt8mm ROI volume *atlasname* and its label table from S3.

    Returns:
        ``(atlas_data, atlas_mapping)``.
        - ``atlas_data`` is the volume transposed with ``(1, 2, 0)`` so its axes
          line up with the spatial axes of the saved session betas.
        - ``atlas_mapping`` maps the first column of the ``.ctab`` file (the
          label id) to its second column (the ROI name).
    """
    atlas_path = f's3://natural-scenes-dataset/nsddata/ppdata/{subject}/func1pt8mm/roi/{atlasname}.nii.gz'
    with fs.open(atlas_path, 'rb') as f, gzip.GzipFile(fileobj=f) as gz:
        atlas_img = nb.Nifti1Image.from_bytes(gz.read())
        atlas_data = atlas_img.get_fdata()
    atlas_data = atlas_data.transpose(1, 2, 0)

    mapping_path = f's3://natural-scenes-dataset/nsddata/freesurfer/fsaverage/label/{atlasname}.mgz.ctab'
    with fs.open(mapping_path, 'rb') as f:
        atlas_mapping_df = pd.read_csv(f, delimiter=' ', header=None, index_col=0)
    atlas_mapping = atlas_mapping_df[1].to_dict()
    return atlas_data, atlas_mapping


def _extract_roi_betas(betas_all_flat, roi_indices, chunk_size=5000):
    """Copy columns *roi_indices* of the 2-D *betas_all_flat* into a float32 array.

    Rows are read *chunk_size* at a time directly into a preallocated result.
    This avoids holding every chunk plus their ``vstack`` copy at once.
    """
    n_rows = betas_all_flat.shape[0]
    betas_roi = np.empty((n_rows, len(roi_indices)), dtype=np.float32)
    for start in range(0, n_rows, chunk_size):
        stop = start + chunk_size  # slicing clips at n_rows
        betas_roi[start:stop] = betas_all_flat[start:stop, roi_indices]
    return betas_roi


def _average_by_stimulus(betas_roi, stims_all, stims_unique):
    """Average the rows of *betas_roi* per stimulus id, ordered like *stims_unique*.

    ``stims_all[i]`` is the stimulus id of row ``i``. Grouping on a plain array,
    rather than a column concatenated from another DataFrame, avoids pandas
    index alignment silently pairing rows with the wrong stimulus.
    """
    grouped = pd.DataFrame(betas_roi).groupby(np.asarray(stims_all)).mean()
    return grouped.loc[stims_unique].to_numpy(dtype=np.float32)


def main():
    """Build per-ROI NSD beta arrays for one subject and save train/test splits.

    All outputs go under ``./mrifeat/<subject>/``. The pipeline:

    1. Read behaviour for sessions 1-37. Save the 0-based stimulus id of every
       trial (``*_stims.npy``) and of each unique stimulus, in first-appearance
       order (``*_stims_ave.npy``).
    2. Download each session's betas from S3 into ``betas_sessions/``. Cached
       files are reused unless they fail to load or are all zero / all NaN.
    3. Concatenate all sessions into one float32 memmap on disk.
    4. For every non-zero label of the ``streams`` atlas:
       - gather that ROI's voxels;
       - average repeats per stimulus;
       - split rows into train (stimulus not in ``sharedix``) and test
         (stimulus in ``sharedix``);
       - save the results.

    Expects ``nsd_expdesign.mat`` in the current working directory.

    Raises:
        RuntimeError: if any session's betas could not be obtained.
        ValueError: if session shapes, behaviour length, or atlas shape are
            inconsistent with the betas.
    """
    if 'ipykernel' in sys.modules:
        subject = 'subj01'
    else:
        parser = argparse.ArgumentParser()
        parser.add_argument("--subject", type=str, default="subj01", help="subject name")
        opt = parser.parse_args()
        subject = opt.subject

    n_sessions = 37
    sessions = range(1, n_sessions + 1)
    fs = s3fs.S3FileSystem(anon=True)
    atlasname = 'streams'
    nsd_expdesign = scipy.io.loadmat('nsd_expdesign.mat')
    sharedix = nsd_expdesign['sharedix'] - 1  # 0-based, same convention as the stim ids below

    # Collect every session, then concatenate once: concat in a loop is quadratic.
    behs = pd.concat([read_behavior_from_s3(subject=subject, session_index=i) for i in sessions])
    stims_unique = behs['73KID'].unique() - 1
    stims_all = behs['73KID'].to_numpy() - 1

    savedir = f'./mrifeat/{subject}/'
    os.makedirs(savedir, exist_ok=True)

    if not os.path.exists(f'{savedir}/{subject}_stims.npy'):
        np.save(f'{savedir}/{subject}_stims.npy', stims_all)
        np.save(f'{savedir}/{subject}_stims_ave.npy', stims_unique)

    session_save_dir = os.path.join(savedir, 'betas_sessions')
    os.makedirs(session_save_dir, exist_ok=True)
    session_files = [os.path.join(session_save_dir, f'betas_session{i:02d}.npy') for i in sessions]

    failed_sessions = []
    for i, session_file in zip(sessions, session_files):
        print(f"Processing session {i}")
        if _session_file_is_valid(session_file):
            continue
        if not _download_session_betas(fs, subject, i, session_file):
            failed_sessions.append(i)
    if failed_sessions:
        # Continuing would concatenate missing or previously rejected files.
        raise RuntimeError(f"Could not obtain betas for sessions {failed_sessions}; re-run to retry.")

    print("Loading and concatenating session betas...")
    # mmap_mode='r' reads only the .npy headers here, not the volumes.
    session_shapes = [np.load(path, mmap_mode='r').shape for path in session_files]
    print("Sample shape:", session_shapes[0])
    spatial_shape = session_shapes[0][1:]
    mismatched = [i for i, shape in zip(sessions, session_shapes) if shape[1:] != spatial_shape]
    if mismatched:
        raise ValueError(f"Sessions {mismatched} differ in spatial shape from session 01 {spatial_shape}.")
    x, y, z = spatial_shape
    total_trials = sum(shape[0] for shape in session_shapes)

    if len(stims_all) < total_trials:
        raise ValueError(f"Behaviour lists {len(stims_all)} trials but the betas contain {total_trials}.")
    stims_all = stims_all[:total_trials]

    volume_shape = (total_trials, x, y, z)
    betas_all_path = f'{savedir}/{subject}_betas_all_memmap.npy'
    if os.path.exists(betas_all_path):
        print(f"Memmap file already exists at {betas_all_path}, skipping concatenation.")
    else:
        # Write to a temporary name and rename when complete. An interrupted
        # run then cannot leave a partial file that later runs would trust.
        tmp_path = betas_all_path + '.partial'
        writer = np.memmap(tmp_path, dtype=np.float32, mode='w+', shape=volume_shape)
        offset = 0
        for session_file in session_files:
            data = np.load(session_file)
            n = data.shape[0]
            writer[offset:offset + n] = data
            offset += n
        writer.flush()
        del writer, data  # release the mapping before renaming (required on Windows)
        os.replace(tmp_path, betas_all_path)
        print(f"Memmap saved to {betas_all_path}")
    betas_all = np.memmap(betas_all_path, dtype=np.float32, mode='r', shape=volume_shape)

    atlas_data, atlas_mapping = _load_atlas(fs, subject, atlasname)
    print("Atlas mapping keys:", atlas_mapping.keys())
    print("Atlas mapping values:", atlas_mapping.values())

    atlas_flat = atlas_data.flatten()
    print("Unique values in atlas_flat:", np.unique(atlas_flat))
    print("Type of atlas_flat[0]:", type(atlas_flat[0]))

    print("Atlas ROI voxel counts:")
    for label_id, roi_name in atlas_mapping.items():
        print(f"{label_id} ({roi_name}): {(atlas_flat == label_id).sum()}")

    # The atlas may cover at most the betas' last axis; the first two must match.
    atlas_z = atlas_data.shape[2]
    print("betas_all shape:", betas_all.shape)
    print("atlas shape:", atlas_data.shape)
    if atlas_data.shape[:2] != (x, y) or atlas_z > z:
        raise ValueError("Beta and atlas spatial dimensions do not match!")

    # Visual check of alignment between the betas and the atlas.
    z_slice = min(40, atlas_z - 1)
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(np.mean(betas_all[0, :, :, :atlas_z], axis=2), cmap='gray')
    plt.title('Mean Beta Volume (collapsed Z)')
    plt.subplot(1, 2, 2)
    plt.imshow(atlas_data[:, :, z_slice], cmap='tab10')
    plt.title(f'Atlas Slice z={z_slice}')
    plt.tight_layout()
    plt.show()

    print("Unique atlas voxel values:", np.unique(atlas_data))

    # Reshaping the uncropped, C-contiguous memmap is a view, not a copy.
    betas_all_flat = betas_all.reshape(total_trials, -1)
    print("betas_all_flat dtype:", betas_all_flat.dtype)
    print("betas_all_flat shape:", betas_all_flat.shape)
    print("betas_all_flat is memmap:", isinstance(betas_all_flat, np.memmap))

    print("Trying to read test chunk from betas_all_flat...")
    test_chunk = betas_all_flat[0:100, 0:100]
    print("Test chunk shape:", test_chunk.shape)

    import psutil  # only needed for the per-ROI memory report

    print("\nStarting ROI processing...")
    # Compute the shared-stimulus masks once. The old per-trial `stim in sharedix`
    # test scanned sharedix for every trial of every ROI.
    is_shared_trial = np.isin(stims_all, sharedix)
    is_shared_unique = np.isin(stims_unique, sharedix)

    for label_id, roi_name in atlas_mapping.items():
        if label_id == 0:
            print(f"Skipping ROI {roi_name} (ID: 0)")
            continue

        print(f"\nProcessing ROI {roi_name} (ID: {label_id})", flush=True)
        roi_mask = atlas_data == label_id
        voxel_count = roi_mask.sum()
        if voxel_count == 0:
            print(f"Skipping {roi_name}: no voxels found.")
            continue

        print(f"  ROI voxel count: {voxel_count}")
        print(f"  Available memory: {psutil.virtual_memory().available / 1e9:.2f} GB")

        # Flat column indices into the full (x, y, z) beta volume, in C order.
        # They equal np.where(atlas_flat == label_id)[0] when atlas_z == z.
        roi_indices = np.ravel_multi_index(np.nonzero(roi_mask), (x, y, z))

        # Extraction and averaging are best-effort per ROI: report and move on.
        try:
            betas_roi = _extract_roi_betas(betas_all_flat, roi_indices)
            print(f"  betas_roi shape: {betas_roi.shape}")
        except Exception as e:
            print(f"Error loading betas for {roi_name}: {e}")
            continue

        try:
            print("  Averaging betas by stimulus...")
            betas_roi_ave = _average_by_stimulus(betas_roi, stims_all, stims_unique)
            print(f"  betas_roi_ave shape: {betas_roi_ave.shape}")
        except Exception as e:
            print(f"Error in averaging for {roi_name}: {e}")
            continue

        # Train = stimuli outside sharedix, test = stimuli in sharedix; row order is preserved.
        splits = {
            'betas_tr': betas_roi[~is_shared_trial],
            'betas_te': betas_roi[is_shared_trial],
            'betas_ave_tr': betas_roi_ave[~is_shared_unique],
            'betas_ave_te': betas_roi_ave[is_shared_unique],
        }
        print(f"  Shapes - tr: {splits['betas_tr'].shape}, te: {splits['betas_te'].shape}, ave_tr: {splits['betas_ave_tr'].shape}, ave_te: {splits['betas_ave_te'].shape}")

        try:
            for suffix, array in splits.items():
                np.save(f'{savedir}/{subject}_{roi_name}_{suffix}.npy', array)
            print(f"  Saved all files for ROI {roi_name}")
        except OSError as e:
            print(f"Error saving files for {roi_name}: {e}")

# Entry point when run as a script. Presumably this cell is also executed inside
# Jupyter, where __name__ is "__main__" as well; main() handles that case by
# checking for ipykernel and skipping argparse.
if __name__ == "__main__":
    main()