In [237]:
import anatomist.api as ana
from soma import aims
import numpy as np
import argparse
import os
import json
import glob

In [253]:
def simple_yaml_loader(filepath):
    """
    Parse a flat ``key: value`` text file into a dict of strings.

    Blank lines and lines starting with ``#`` (after stripping) are ignored,
    as are lines that contain no colon. Every other line is split on its
    FIRST colon; key and value are both stripped of surrounding whitespace.
    A key that appears several times keeps its last value. There is no
    support for nesting, lists, quoting or type conversion.
    """
    entries = {}
    with open(filepath, 'r') as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            # Nothing to parse on empty lines and comment lines.
            if stripped == "" or stripped.startswith("#"):
                continue
            name, separator, remainder = stripped.partition(":")
            if separator:
                entries[name.strip()] = remainder.strip()
    return entries

def load_configs(path: str):
    """
    Load decoder config and corresponding encoder config.

    Both files are read from the ``.hydra`` directory located in the PARENT
    directory of ``path`` (``os.path.dirname(path)``):
    ``.hydra/encoder_config.yaml`` and ``.hydra/decoder_config.yaml``.

    Returns
    -------
    tuple or None
        ``(region_name, side, loss)`` where ``loss`` is the decoder config's
        ``loss`` entry, ``region_name`` is the third-from-last component of
        the encoder config's ``numpy_all`` path and ``side`` is the first
        character of its last component. ``None`` is returned (after
        printing which file is missing) when either config file is absent.

    Raises
    ------
    KeyError
        If ``loss`` or ``numpy_all`` is missing from the parsed configs.
    ValueError
        If ``numpy_all`` has fewer than three ``/``-separated components.
    """
    config_dir = os.path.join(os.path.dirname(path), '.hydra')
    encoder_config_path = os.path.join(config_dir, 'encoder_config.yaml')
    decoder_config_path = os.path.join(config_dir, 'decoder_config.yaml')

    # Check both files up front so every missing file is reported, and so a
    # missing encoder config yields None instead of an unbound variable.
    all_found = True
    if not os.path.exists(encoder_config_path):
        print("No encoder_config.yaml found at: ", encoder_config_path)
        all_found = False
    if not os.path.exists(decoder_config_path):
        print("No decoder_config.yaml found at: ", decoder_config_path)
        all_found = False
    if not all_found:
        return None

    encoder_cfg = simple_yaml_loader(encoder_config_path)
    decoder_cfg = simple_yaml_loader(decoder_config_path)

    loss = decoder_cfg['loss']
    # e.g. ".../<region>/mask/Lskeleton.npy" -> ("<region>", "mask", "Lskeleton.npy")
    region_name, mask, side_skeleton = (encoder_cfg['numpy_all']).split('/')[-3:]
    print(region_name, mask, side_skeleton)
    side = side_skeleton[0]
    return region_name, side, loss

def build_gradient(pal):
    """
    Fill an Anatomist palette from the gradient spec stored in its header.

    The gradient string is read from ``pal.header()['palette_gradients']``,
    rendered by a ``GradientWidget`` into ``pal.shape[0]`` 4-byte colours,
    copied into the palette's ``'v'`` array and the palette is updated.
    """
    gradient_spec = pal.header()['palette_gradients']
    widget = ana.cpp.GradientWidget(None, 'gradientwidget', gradient_spec)
    widget.setHasAlpha(True)
    n_colors = pal.shape[0]
    # Keep the rendered gradient object in a local while its buffer is read.
    rendered = widget.fillGradient(n_colors, True)
    raw_buffer = rendered.data()
    palette_values = pal.np['v']
    colors = np.frombuffer(raw_buffer, dtype=np.uint8).reshape((n_colors, 4))
    palette_values[:, 0, 0, 0, :] = colors
    # Convert BGRA to RGBA: reverse the first three channels, keep alpha.
    palette_values[:, 0, 0, 0, :3] = palette_values[:, 0, 0, 0, :3][:, ::-1]
    pal.update()


def npy_to_nii(npy_path, voxel_size=(2.0, 2.0, 2.0)):
    """
    Convert a NumPy volume to a NIfTI (.nii.gz) file.

    The array is loaded, cast to float32, wrapped in an ``aims.Volume`` and
    written next to the input file.

    Parameters
    ----------
    npy_path : str
        Path of the ``.npy`` file to convert.
    voxel_size : sequence of float, optional
        Value stored in the volume header's ``voxel_size`` entry. Defaults
        to ``(2.0, 2.0, 2.0)``, the value that was previously hard-coded
        (presumably millimetres - confirm against the dataset).

    Returns
    -------
    str
        Path of the written ``.nii.gz`` file.
    """
    vol_npy = np.load(npy_path).astype(np.float32)
    vol_aims = aims.Volume(vol_npy)
    vol_aims.header()['voxel_size'] = list(voxel_size)

    # Swap only the trailing extension: a plain str.replace would also
    # rewrite any ".npy" inside a directory name, and would leave the path
    # unchanged (overwriting the input) when the extension is absent.
    stem, extension = os.path.splitext(npy_path)
    nii_path = (stem if extension == '.npy' else npy_path) + '.nii.gz'

    aims.write(vol_aims, nii_path)
    return nii_path


def ensure_nii_exists(file_path):
    """
    Return the path of a ``.nii.gz`` file for the given input.

    An existing ``.nii.gz`` path is returned unchanged; an existing ``.npy``
    file is converted with ``npy_to_nii`` and the converted path is
    returned. Anything else (missing file or other extension) raises
    ``FileNotFoundError``.
    """
    is_nifti = file_path.endswith('.nii.gz')
    is_numpy = file_path.endswith('.npy')

    if (is_nifti or is_numpy) and os.path.isfile(file_path):
        return file_path if is_nifti else npy_to_nii(file_path)

    raise FileNotFoundError(f"File not found or unsupported format: {file_path}")


def load_and_prepare_volume(anatomist, file_path, referential, palette=None, min_val=None, max_val=None):
    """
    Read a volume file and return it as an Anatomist volume-rendering fusion.

    The file is read with ``aims.read``, converted to an Anatomist object,
    and wrapped by ``fusionObjects`` using ``VolumeRenderingFusionMethod``.
    When ``palette`` is truthy it is applied to the fusion with the given
    ``min_val``/``max_val`` in absolute mode. The fusion is then assigned to
    ``referential`` and returned.
    """
    volume = aims.read(file_path)
    volume_object = anatomist.toAObject(volume)
    rendering = anatomist.fusionObjects(
        objects=[volume_object],
        method='VolumeRenderingFusionMethod',
    )

    if palette:
        rendering.setPalette(palette, minVal=min_val, maxVal=max_val, absoluteMode=True)

    # NOTE(review): presumably hands ownership of the fusion to the Python
    # reference returned below - confirm against the Anatomist API docs.
    rendering.releaseAppRef()
    rendering.assignReferential(referential)
    return rendering

# ===========================
# Camera Utilities
# ===========================

def print_camera_infos(window):
    """
    Print the ``view_quaternion`` and ``zoom`` entries of a window's infos.

    ``window.getInfos()`` is expected to return a dict-like object; entries
    that are absent are printed as ``None``. Any exception raised along the
    way is caught and reported on stdout instead of propagating.
    """
    try:
        infos = window.getInfos()
        quaternion, zoom_level = infos.get('view_quaternion'), infos.get('zoom')
        print("---- Camera Infos ----")
        print(f"Quaternion : {quaternion}")
        print(f"Zoom       : {zoom_level}")
        print("----------------------")
    except Exception as e:
        print(f"Error while accessing window camera info: {e}")

# ===========================
# Snapshot Utilities
# ===========================

def snapshot_all(subjects, side, region, save_dir):
    """
    Write PNG snapshots of the input and decoded window of every subject.

    Windows are looked up in the module-level ``dic_windows`` under the keys
    ``w_decoded_<i>`` and ``w_input_<i>``, where ``i`` is the subject's
    position in ``subjects``. Subjects lacking either window are skipped
    with a message. ``save_dir`` is created if needed.

    Returns the list of written image paths: for each saved subject, the
    input image path followed by the decoded image path.
    """
    os.makedirs(save_dir, exist_ok=True)

    saved_paths = []

    for index, subject in enumerate(subjects):
        decoded_key, input_key = f'w_decoded_{index}', f'w_input_{index}'

        if not (decoded_key in dic_windows and input_key in dic_windows):
            print(f"Skipping {subject}: window not found.")
            continue

        decoded_window = dic_windows[decoded_key]
        input_window = dic_windows[input_key]

        # Hide the cursor in both windows before taking the pictures.
        decoded_window.setHasCursor(0)
        input_window.setHasCursor(0)

        decoded_png = os.path.join(save_dir, f"{subject}_{region}_{side}_decoded.png")
        decoded_window.snapshot(decoded_png, width=1200, height=900)

        input_png = os.path.join(save_dir, f"{subject}_{region}_{side}_input.png")
        input_window.snapshot(input_png, width=1200, height=900)

        saved_paths.extend([input_png, decoded_png])

        print(f"Saved snapshots for {subject}")

    return saved_paths

def run_visualization(folder_name, subjects=None, nsubjects=4, crops=None):
    """
    Programmatic entry point for visualization (no CLI).

    Creates the Anatomist application, a two-column windows block and an
    empty window registry, stores them in the module-level globals ``a``,
    ``block`` and ``dic_windows`` (which ``plot_ana`` and ``snapshot_all``
    read), then calls ``plot_ana`` for ``folder_name``.

    When ``load_configs`` returns a ``(region, side, loss)`` triple it is
    forwarded to ``plot_ana``; otherwise ``plot_ana`` runs with its default
    region/side and the ``'bce'`` loss palette.
    """
    global a, block, dic_windows

    a = ana.Anatomist()
    nb_columns = 2
    block = a.createWindowsBlock(nb_columns)
    dic_windows = {}

    plot_kwargs = {
        'recon_dir': folder_name,
        'n_subjects_to_display': nsubjects,
        'listsub': subjects,
        'crops': crops,
    }

    # Load the configs once: a second call would re-read both files and
    # print the region information twice.
    configs = load_configs(folder_name)
    if configs is not None:
        region_name, side, loss = configs
        plot_kwargs.update(region=region_name, side=side, loss_name=loss)
    else:
        plot_kwargs['loss_name'] = 'bce'

    plot_ana(**plot_kwargs)


In [254]:
def get_decoded_files(recon_dir, listsub, n_subjects_to_display):
    """
    Get decoded files either from provided subjects or randomly from directory.

    Parameters
    ----------
    recon_dir : str
        Directory containing the ``<subject>_decoded.npy`` files.
    listsub : sequence of str or None
        Subject identifiers. When non-empty, one path per subject is built
        (without checking that the file exists) and nothing is sampled.
    n_subjects_to_display : int
        Number of files to draw at random when ``listsub`` is empty. If it
        exceeds the number of available files, all of them are returned.

    Returns
    -------
    list of str
        Paths of the decoded ``.npy`` files.

    Raises
    ------
    FileNotFoundError
        If ``listsub`` is empty and ``recon_dir`` has no ``*_decoded.npy``.
    """
    if listsub:
        return [os.path.join(recon_dir, f"{sub}_decoded.npy") for sub in listsub]

    print("No list of subjects provided, taking random subjects.")
    # Sorted so that a seeded draw does not depend on filesystem order.
    candidates = sorted(glob.glob(os.path.join(recon_dir, "*_decoded.npy")))

    if not candidates:
        raise FileNotFoundError(f"No decoded files found in {recon_dir}")

    # Sampling without replacement fails when asking for more items than
    # exist, so clamp the request to what is available.
    sample_size = min(n_subjects_to_display, len(candidates))
    if sample_size < n_subjects_to_display:
        print(f"Only {len(candidates)} decoded files available, "
              f"{n_subjects_to_display} requested.")

    selected = np.random.choice(len(candidates), size=sample_size, replace=False)
    decoded_files = [candidates[i] for i in selected]
    print("Randomly selected decoded files:", [os.path.basename(f) for f in decoded_files])

    return decoded_files


def _camera_view_for(region_views, region, side):
    """Return the camera quaternion stored for ``region``/``side``, or None."""
    region_view = (region_views or {}).get(region)
    if region_view and side in region_view:
        return region_view[side]["camera_view"]
    return None


def _show_in_new_window(kind, index, volume_object, camera_view):
    """Open a 3D window in the current block, show the object and register both.

    The object is stored under ``r_<kind>_<index>`` and its window under
    ``w_<kind>_<index>`` in the module-level ``dic_windows``.
    """
    dic_windows[f'r_{kind}_{index}'] = volume_object
    window = a.createWindow('3D', block=block)
    dic_windows[f'w_{kind}_{index}'] = window
    window.addObjects([volume_object])
    if camera_view:
        window.camera(view_quaternion=camera_view)


def plot_ana(recon_dir, n_subjects_to_display, loss_name, listsub,
             dataset="UkBioBank40", region="S.T.s.br.", side="L", crops=None, region_views=None):
    """
    Display pairs of input and decoded volumes in Anatomist,
    side-by-side for each subject.

    Relies on the module-level globals ``a`` (Anatomist instance), ``block``
    (windows block) and ``dic_windows`` (registry filled by this function).

    Parameters
    ----------
    recon_dir : str
        Directory holding ``<subject>_decoded.npy`` (and optionally
        ``<subject>_input.npy``) files.
    n_subjects_to_display : int
        Number of subjects to draw at random when ``listsub`` is empty.
    loss_name : str
        One of ``'bce'``, ``'mse'``, ``'ce'``; selects the palette gradient
        and value range of the decoded volumes. Other values raise KeyError.
    listsub : sequence of str or None
        Subjects to display.
    dataset, region, side : str
        Used to build the fallback crop directory for input volumes that
        are not found in ``recon_dir``; ``region`` and ``side`` also select
        the camera view.
    crops : str or None
        Directory to search for ``<subject>_cropped_skeleton.nii.gz``
        instead of the dataset-derived default.
    region_views : dict or None
        Mapping ``region -> side -> {"camera_view": quaternion}``. Regions
        or sides that are absent simply get no camera orientation.
    """
    referential = a.createReferential()

    # Palette settings based on loss
    palette_config = {
        'bce': {
            'gradient': "1;1#0;1;1;0#0.994872;0#0;0;0.635897;0.266667;1;1",
            'min_val': 0, 'max_val': 0.5
        },
        'mse': {
            'gradient': "1;1#0;1;1;0#0.994872;0#0;0;0.694872;0.244444;1;1",
            'min_val': 0, 'max_val': 0.5
        },
        'ce': {
            'gradient': "1;1#0;1;0.292308;0.733333;0.510256;0;0.679487;"
                        "0.733333#1;0#0;0;0.341026;0.111111;0.507692;"
                        "0.911111;0.697436;0.111111;1;0",
            'min_val': -1.6, 'max_val': 0.33
        }
    }

    decoded_files = get_decoded_files(recon_dir, listsub, n_subjects_to_display)
    loss_settings = palette_config[loss_name]

    # Prepare palette
    pal = a.createPalette('VR-palette')
    pal.header()['palette_gradients'] = loss_settings['gradient']
    build_gradient(pal)

    # Region and side are fixed for the whole call, so resolve the view once.
    camera_view = _camera_view_for(region_views, region, side)
    n_skipped = 0

    for i, decoded_path in enumerate(decoded_files):
        subject_id = os.path.basename(decoded_path).split('_decoded')[0]

        # ---- Load decoded file ----
        try:
            decoded_path = ensure_nii_exists(decoded_path)
            decoded_object = load_and_prepare_volume(
                a, decoded_path, referential,
                palette='VR-palette',
                min_val=loss_settings['min_val'],
                max_val=loss_settings['max_val']
            )
        except FileNotFoundError:
            print(f"ERROR: Decoded file not found for {subject_id}. Skipping.")
            n_skipped += 1
            continue

        _show_in_new_window('decoded', i, decoded_object, camera_view)
        if not camera_view:
            print('No camera view provided for this region')

        # ---- Load input file ----
        input_path = os.path.join(recon_dir, f"{subject_id}_input.npy")
        if not os.path.isfile(input_path):
            print(f"Local input file missing for {subject_id}, searching in original dataset...")
            if crops:
                mm_skeleton_path = crops
            else:
                mm_skeleton_path = f"/neurospin/dico/data/deep_folding/current/datasets/{dataset}/crops/2mm/{region}/mask/{side}crops"
            input_path = os.path.join(mm_skeleton_path, f"{subject_id}_cropped_skeleton.nii.gz")

        try:
            input_path = ensure_nii_exists(input_path)
            input_object = load_and_prepare_volume(a, input_path, referential)
        except FileNotFoundError:
            print(f"ERROR: Input file not found for {subject_id}. Skipping.")
            n_skipped += 1
            continue

        _show_in_new_window('input', i, input_object, camera_view)

    if n_skipped:
        print(f"{n_skipped} of {len(decoded_files)} subjects could not be fully loaded.")
    else:
        print("All subjects loaded and displayed successfully.")

In [255]:
# Read the region -> side -> camera-view mapping that is later handed to
# plot_ana through its region_views argument.
with open("/neurospin/dico/adufournet/2025_Champollion_Decoder/decoder/reconstruction/region_views.json", "r") as f:
    REGION_VIEWS = json.loads(f.read())

In [256]:
#%matplotlib qt5

In [257]:
# Column count passed to createWindowsBlock for every reconstruction folder.
nb_columns = 2
# Anatomist application; plot_ana reads it as the module-level global `a`.
a = ana.Anatomist()

In [258]:
# Render and snapshot the same subjects for every reconstruction folder of
# the run directory.
subjects = ["100206", "101410", "102715"]
nsubjects = len(subjects)
list_folder_name = glob.glob("/neurospin/dico/adufournet/2025_Champollion_Decoder/runs/Champollion_V1_after_ablation_256/*/reconstruction_best_model")
for folder_name in list_folder_name:
    if not os.path.isdir(folder_name):
        raise FileNotFoundError(f"Provided path not found: {folder_name}")

    # load_configs returns None when a config file is missing; unpacking
    # that directly would fail with a TypeError, so skip the folder instead.
    configs = load_configs(folder_name)
    if configs is None:
        print(f"Skipping {folder_name}: configuration could not be loaded.")
        continue
    region_name, side, loss = configs
    print(region_name)

    # Fresh windows block and registry per folder; plot_ana and snapshot_all
    # read both as module-level globals.
    # NOTE(review): rebinding dic_windows drops the references to the
    # previous folder's windows and objects; the output recorded for this
    # cell shows Anatomist "Cannot delete object" messages - confirm whether
    # explicit cleanup is needed between folders.
    block = a.createWindowsBlock(nb_columns)
    dic_windows = {}

    plot_ana(
        recon_dir=folder_name,
        n_subjects_to_display=nsubjects,
        listsub=subjects,
        region=region_name,
        side=side,
        loss_name=loss,
        crops=None,
        dataset='hcp',
        region_views=REGION_VIEWS,
    )

    snapshot_all(subjects=subjects, side=side, region=region_name, save_dir="/volatile/ad279118/tmp")


OCCIPITAL mask Lskeleton.npy
OCCIPITAL
Local input file missing for 100206, searching in original dataset...
Local input file missing for 101410, searching in original dataset...
Local input file missing for 102715, searching in original dataset...
All subjects loaded and displayed successfully.
Saved snapshots for 100206
Saved snapshots for 101410
Saved snapshots for 102715
S.T.i.-S.T.s.-S.T.pol. mask Rskeleton.npy
S.T.i.-S.T.s.-S.T.pol.
Local input file missing for 100206, searching in original dataset...
Local input file missing for 101410, searching in original dataset...
Local input file missing for 102715, searching in original dataset...
All subjects loaded and displayed successfully.
Saved snapshots for 100206
Saved snapshots for 101410
Saved snapshots for 102715
S.Or. mask Lskeleton.npy
S.Or.
Local input file missing for 100206, searching in original dataset...
Local input file missing for 101410, searching in original dataset...
Local input file missing for 102715, searching 

-----------------------------------------------------------
Cannot delete object VolumeRendering: Volume_S16 (177),
check for multi-objects which contain it. There are still 1 other references to it

-----------------------------------------------------------
-----------------------------------------------------------
Cannot delete object VolumeRendering: Volume_S16 (175),
check for multi-objects which contain it. There are still 1 other references to it

-----------------------------------------------------------
-----------------------------------------------------------
Cannot delete object VolumeRendering: Volume_FLOAT (175),
check for multi-objects which contain it. There are still 1 other references to it

-----------------------------------------------------------
-----------------------------------------------------------
Cannot delete object Volume_S16 (175),
check for multi-objects which contain it. There are still 1 other references to it

-----------------------------------

: 

In [231]:
# ===========================
# Region-based camera views
# ===========================

# Layout: region name -> hemisphere side ("L" / "R") -> {"camera_view": q},
# where q is the 4-component quaternion handed to window.camera() as
# view_quaternion.
REGION_VIEWS = {
    "CINGULATE.": {
        "L": {"camera_view": [0.5, -0.5, -0.5, 0.5]},
        "R": {"camera_view": [0.5, 0.5, 0.5, 0.5]},
    },
    "F.C.L.p.-subsc.-F.C.L.a.-INSULA.": {
        "L": {"camera_view": [0.15, -0.66, 0.58, -0.43]},
        "R": {"camera_view": [0.15, 0.63, -0.44, -0.60]},
    },
    "F.C.M.post.-S.p.C.": {
        "L": {"camera_view": [0.81, 0.043, 0.58, 0.083]},
        "R": {"camera_view": [-0.45, -0.43, -0.54, -0.56]},
    },
    "F.Coll.-S.Rh.": {
        "L": {"camera_view": [0.36, 0.32, 0.74, 0.46]},
        "R": {"camera_view": [0.745, 0.46, 0.46, 0.14]},
    },
    "F.I.P.": {
        "L": {"camera_view": [0.01, 0.36, 0.74, 0.57]},
        "R": {"camera_view": [-0.29, 0.08, 0.56, 0.77]},
    },
    "F.P.O.-S.Cu.-Sc.Cal.": {
        "L": {"camera_view": [0.5, -0.5, -0.5, 0.5]},
        "R": {"camera_view": [0.5, 0.5, 0.5, 0.5]},
    },
    "LARGE_CINGULATE.": {
        "L": {"camera_view": [0.5, -0.5, -0.5, 0.5]},
        "R": {"camera_view": [0.5, 0.5, 0.5, 0.5]},
    },
    "Lobule_parietal_sup.": {
        "L": {"camera_view": [-0.07, 0.32, 0.81, 0.5]},
        "R": {"camera_view": [-0.3, 0.11, 0.62, 0.72]},
    },
    "OCCIPITAL": {
        "L": {"camera_view": [-0.38, 0.34, 0, 0.86]},
        "R": {"camera_view": [-0.5, -0.05, 0.3, 0.82]},
    },
    "S.C.-S.Pe.C.": {
        "L": {"camera_view": [0.34, 0.26, 0.48, 0.76]},
        "R": {"camera_view": [0.3, -0.24, -0.44, 0.8]},
    },
    "S.C.-S.Po.C.": {
        "L": {"camera_view": [0.28, 0.23, 0.48, 0.79]},
        "R": {"camera_view": [0.3, -0.24, -0.44, 0.8]},
    },
    "S.C.-sylv.": {
        "L": {"camera_view": [0.28, 0.23, 0.48, 0.79]},
        "R": {"camera_view": [0.3, -0.24, -0.44, 0.8]},
    },
    "S.F.inf.-BROCA-S.Pe.C.inf.": {
        "L": {"camera_view": [0.73, 0.21, 0.54, 0.33]},
        "R": {"camera_view": [0.70, -0.13, -0.60, 0.34]},
    },
    "S.F.int.-F.C.M.ant.": {
        "L": {"camera_view": [0.5, -0.5, -0.5, 0.5]},
        "R": {"camera_view": [0.5, 0.5, 0.5, 0.5]},
    },
    "S.F.int.-S.R.": {
        "L": {"camera_view": [0.5, -0.5, -0.5, 0.5]},
        "R": {"camera_view": [0.5, 0.5, 0.5, 0.5]},
    },
    "S.F.inter.-S.F.sup.": {
        "L": {"camera_view": [0.23, 0.3, 0.15, 0.9]},
        "R": {"camera_view": [0.4, -0.33, -0.15, 0.83]},
    },
    "S.F.marginal-S.F.inf.ant.": {
        "L": {"camera_view": [0.60, 0.27, 0.34, 0.66]},
        "R": {"camera_view": [0.52, -0.34, -0.25, 0.73]},
    },
    "S.F.median-S.F.pol.tr.-S.F.sup.": {
        "L": {"camera_view": [0.24, 0.18, 0.02, 0.95]},
        "R": {"camera_view": [0.24, -0.16, -0.02, 0.95]},
    },
    "S.Or.-S.Olf.": {
        "L": {"camera_view": [0.96, -0.11, 0.19, 0.15]},
        "R": {"camera_view": [0.97, 0.12, -0.05, 0.17]},
    },
    "S.Or.": {
        "L": {"camera_view": [0.96, -0.11, 0.19, 0.15]},
        "R": {"camera_view": [0.97, 0.12, -0.05, 0.17]},
    },
    "S.Pe.C.": {
        "L": {"camera_view": [0.44, 0.35, 0.45, 0.68]},
        "R": {"camera_view": [0.3, -0.24, -0.44, 0.8]},
    },
    "S.Po.C.": {
        "L": {"camera_view": [0.19, 0.27, 0.46, 0.81]},
        "R": {"camera_view": [0.197, -0.26, -0.49, 0.81]},
    },
    "S.T.i.-S.O.T.lat.": {
        "L": {"camera_view": [-0.59, 0.66, -0.17, 0.43]},
        "R": {"camera_view": [0.66, -0.54, -0.47, 0.19]},
    },
    "S.T.i.-S.T.s.-S.T.pol.": {
        "L": {"camera_view": [0.59, 0.32, 0.58, 0.43]},
        "R": {"camera_view": [0.67, -0.35, -0.55, 0.36]},
    },
    "S.T.s.": {
        "L": {"camera_view": [0.63, 0.27, 0.62, 0.37]},
        "R": {"camera_view": [0.62, -0.30, -0.60, 0.39]},
    },
    "S.T.s.br.": {
        "L": {"camera_view": [0.30, 0.53, 0.58, 0.53]},
        "R": {"camera_view": [0.42, -0.60, -0.55, 0.40]},
    },
    "Sc.Cal.-S.Li.": {
        "L": {"camera_view": [0.40, 0.50, 0.60, 0.47]},
        "R": {"camera_view": [0.68, 0.45, 0.45, 0.34]},
    },
    "S.s.P.-S.Pa.int.": {
        "L": {"camera_view": [0.3, -0.6, -0.4, 0.6]},
        "R": {"camera_view": [0.36, 0.60, 0.48, 0.52]},
    },
    "fronto-parietal_medial_face.": {
        "L": {"camera_view": [0.5, -0.5, -0.5, 0.5]},
        "R": {"camera_view": [0.5, 0.5, 0.5, 0.5]},
    },
}

In [252]:
# Report the camera orientation and zoom of the first decoded window.
print_camera_infos(dic_windows['w_decoded_0'])

---- Camera Infos ----
Quaternion : [-0.375619530677795, 0.336425125598907, -0.000197820365428925, 0.863555550575256]
Zoom       : 1
----------------------
