This notebook displays on the SPAM model of sulci the value of the parapeter for each sulcus. The entry point is a csv containing the value to plot for each sulcal region.

Requires anatomist from the BrainVISA software suite

# 1. Imports 

In [113]:
import anatomist.api as anatomist
from soma.qt_gui.qtThread import QtThreadCall
from soma.qt_gui.qt_backend import Qt

In [114]:
from soma import aims
import json
import os
import pandas as pd
import numpy as np

from PIL import Image, ImageFont, ImageDraw, ImageOps

# 2. Paths and constants

In [None]:
# Path where the models and region configuration lie
path_to_deep_folding = "/neurospin/dico/data/deep_folding/current"

# JSON file with sulcal regions containing list of sulci
# json_regions = f"{path_to_deep_folding}/sulci_regions_gridsearch.json"
# json_regions = f"{path_to_deep_folding}/sulci_regions_champollion_V1.json"
json_regions = f"{os.getcwd()}"\
    "/../../view_gene_database/"\
    "region_to_sulci_with_smaller_region_size.json"

# Path to parameter file
path_summary = f"{path_to_deep_folding}/models/Champollion_V1_after_ablation/analysis/QTIM"
path_file = "IHI_QTIM_resid_sex_age.csv"
file_to_display = f"{path_summary}/{path_file}"

# Parameters and p value
param = "auc"
p_value = "p_value"
threshold = 0.05/56.
minVal = 0.5
maxVal = 1.0

SNAPSHOT = True


In [116]:
# Gets SPAM models on which visualization is done
Rspam_model = aims.carto.Paths.findResourceFile(
    "models/models_2008/descriptive_models/"
    "segments/global_registered_spam_right/meshes/Rspam_model_meshes_1.arg")
Lspam_model = aims.carto.Paths.findResourceFile(
    "models/models_2008/descriptive_models/segments/"
    "global_registered_spam_left/meshes/Lspam_model_meshes_1.arg")

In [117]:
path_summary

'/neurospin/dico/data/deep_folding/current/models/Champollion_V1_after_ablation/analysis/QTIM'

# 2. Preprocessing

In [118]:
with open(json_regions) as f:
    regions = json.load(f)
print(len(regions))

50


In [119]:
# next(iter(regions[[brain]])) # Need the change to this if using the inital json file
next(iter(regions))

'F.C.L.p.-subsc.-F.C.L.a.-INSULA._left'

In [120]:
df = pd.read_csv(file_to_display)[["region", param, p_value]]
df.head()

Unnamed: 0,region,auc,p_value
0,F.Coll.-S.Rh._left,0.850495,8.9e-05
1,S.T.i.-S.O.T.lat._left,0.716641,8.9e-05
2,F.Coll.-S.Rh._right,0.696203,8.9e-05
3,S.T.i.-S.O.T.lat._right,0.692656,8.9e-05
4,F.P.O.-S.Cu.-Sc.Cal._left,0.68862,8.9e-05


In [121]:
res = df.set_index("region")
res["side"] = res.index.str.split('_').str[-1]
res = res.reset_index()

In [122]:
def get_sulci(region):
    if region in regions:
        return regions[f"{region}"]
    else:
        return []

In [123]:
get_sulci("S.C.-sylv._left")

['S.C._left', 'S.C.sylvian._left']

In [124]:
len(regions.keys())

50

In [125]:
res['sulcus'] = res.apply(lambda x: get_sulci(x.region), axis=1)
res.head()

Unnamed: 0,region,auc,p_value,side,sulcus
0,F.Coll.-S.Rh._left,0.850495,8.9e-05,left,"[F.Coll._left, S.Rh._left]"
1,S.T.i.-S.O.T.lat._left,0.716641,8.9e-05,left,"[S.O.T.lat.ant._left, S.O.T.lat.int._left, S.O..."
2,F.Coll.-S.Rh._right,0.696203,8.9e-05,right,"[F.Coll._right, S.Rh._right]"
3,S.T.i.-S.O.T.lat._right,0.692656,8.9e-05,right,"[S.O.T.lat.ant._right, S.O.T.lat.int._right, S..."
4,F.P.O.-S.Cu.-Sc.Cal._left,0.68862,8.9e-05,left,"[F.P.O._left, S.Cu._left, S.Pa.t._left]"


In [126]:
res[res["sulcus"].str.len() == 0]

Unnamed: 0,region,auc,p_value,side,sulcus
14,S.C.-S.Pe.C._right,0.593671,0.002053,right,[]
24,S.C.-S.Po.C._right,0.562229,0.0383,right,[]
43,S.C.-S.Pe.C._left,0.505049,0.430587,left,[]
44,F.I.P.-F.I.P.Po.C.inf._left,0.503647,0.461298,left,[]
45,F.I.P.-F.I.P.Po.C.inf._right,0.502115,0.476297,right,[]
51,S.C.-S.Po.C._left,0.492078,0.583876,left,[]


In [127]:
res = res.sort_values(by=param, ascending=False)
res = res.explode("sulcus")
res[res.region.str.contains("S.T.s.")]

Unnamed: 0,region,auc,p_value,side,sulcus
6,S.T.i.-S.T.s.-S.T.pol._left,0.652506,0.000179,left,S.T.pol._left
10,S.T.s._left,0.629113,0.000268,left,S.T.s._left
26,S.T.s._right,0.558147,0.072136,right,S.T.s._right
34,S.T.i.-S.T.s.-S.T.pol._right,0.536855,0.172395,right,S.T.pol._right
36,S.T.s.br._left,0.529865,0.18668,left,F.I.P.r.int.1_left
36,S.T.s.br._left,0.529865,0.18668,left,F.I.P.r.int.2_left
36,S.T.s.br._left,0.529865,0.18668,left,S.GSM._left
36,S.T.s.br._left,0.529865,0.18668,left,S.T.s.ter.asc.ant._left
36,S.T.s.br._left,0.529865,0.18668,left,S.T.s.ter.asc.post._left
50,S.T.s.br._right,0.493057,0.576288,right,F.I.P.r.int.1_right


In [128]:
res[res["sulcus"].str.len() == 0]

Unnamed: 0,region,auc,p_value,side,sulcus


In [129]:
res[param].tolist()

[0.8504950079808514,
 0.8504950079808514,
 0.7166407063007402,
 0.7166407063007402,
 0.7166407063007402,
 0.7166407063007402,
 0.7166407063007402,
 0.7166407063007402,
 0.696202521807576,
 0.696202521807576,
 0.6926556081217279,
 0.6926556081217279,
 0.6926556081217279,
 0.6926556081217279,
 0.6926556081217279,
 0.6926556081217279,
 0.6886202435954812,
 0.6886202435954812,
 0.6886202435954812,
 0.6862223757514773,
 0.6862223757514773,
 0.6862223757514773,
 0.6862223757514773,
 0.6525062170959306,
 0.6496946987130938,
 0.6496946987130938,
 0.6496946987130938,
 0.6339499807634642,
 0.6339499807634642,
 0.631145116758524,
 0.631145116758524,
 0.631145116758524,
 0.631145116758524,
 0.631145116758524,
 0.6291134641099797,
 0.6225767138687908,
 0.6053137837751111,
 0.6001607849836882,
 0.5936711028152285,
 0.5900586710755834,
 0.5900586710755834,
 0.5861538168978517,
 0.5861538168978517,
 0.5843233368318901,
 0.5843233368318901,
 0.5843233368318901,
 0.580747960515266,
 0.580747960515266,
 

In [130]:
# for _, row in res.iterrows():
#     print(row.sulcus)

In [131]:
# res[res.sulcus=="S.F.orbitaire._right"]

In [132]:
# res

# 3. Anatomist functions

In [322]:


def set_color_property(res, side):
    global dic_window
    global param, p_value, threshold

    if side == "L":
        dic_window[f"aims{side}"] = aims.read(Lspam_model)
        dic_window[f"aims{side}"]['boundingbox_min'][0] = 0
    else:
        dic_window[f"aims{side}"] = aims.read(Rspam_model)
        dic_window[f"aims{side}"]['boundingbox_max'][0] = 0

    for vertex in dic_window[f"aims{side}"].vertices():
        vertex[param] = 0.
    print(f"param = {param}")
        
    unknown_vertices = []
    for vertex in dic_window[f"aims{side}"].vertices():
        vname = vertex.get('name')
        if vname == 'unknown':
            #print(f"Removing vertex with name: {name}")
            unknown_vertices.append(vertex)
    for vertex in unknown_vertices:
        dic_window[f"aims{side}"].removeVertex(vertex)

    for _, row in res.iterrows():
        for vertex in dic_window[f"aims{side}"].vertices():
            vname = vertex.get('name')
            if vname == row.sulcus:
                if row[p_value] < threshold:
                    vertex[param] = row[param]
    
    dic_window[f"ana{side}"] = a.toAObject(dic_window[f"aims{side}"])

    dic_window[f"ana{side}"].setColorMode(dic_window[f"ana{side}"].PropertyMap)
    dic_window[f"ana{side}"].setColorProperty(param)
    dic_window[f"ana{side}"].notifyObservers()
    
                
def visualize_whole_hemisphere(view_quaternion, side, i):
    global block
    global dic_window
    try:
        block
    except NameError:
        block = a.createWindowsBlock(4)

    dic_window[f"win{i}"] = a.createWindow('3D',
                                    block=block,
                                    no_decoration=True,
                                    options={'hidden': 1})
    dic_window[f"win{i}"].addObjects(dic_window[f"ana{side}"])
    
    # Trick to save the palette with the extrema colors
    dic_window[f"ana{side}"].setPalette("green_yellow_red")
    pal_im = dic_window[f"ana{side}"].palette().toQImage(256, 32)  # 256 x 32 est la taille que tu veux
    pal_im.save(f'/tmp/pal{i}.jpg')
    
    dic_window[f"ana{side}"].setPalette("green_yellow_red",
                              minVal=minVal, maxVal=maxVal,
                              absoluteMode=True)
    
    dic_window[f"win{i}"].camera(view_quaternion=view_quaternion)
    dic_window[f"win{i}"].setHasCursor(0)
    
    if SNAPSHOT:
        image_file = f"/tmp/snapshot{i}.jpg"
        dic_window[f"win{i}"].snapshot(image_file,
                                width=1250,
                                height=900)
        return image_file
    else:
        return ""
        

def get_bounding_box(img, threshold=254):
    # Convert the image to grayscale
    gray_img = ImageOps.grayscale(img)

    # Convert to a binary mask: black for non-white, white for white
    mask = gray_img.point(lambda p: 0 if p > threshold else 255)

    # Find the bounding box of the non-white region
    bbox = mask.getbbox()

    return bbox


def crop_to_bounding_box(img, threshold=254):
    # Get the bounding box
    bbox = get_bounding_box(img, threshold)

    if bbox:
        # Crop the image to the bounding box
        cropped_img = img.crop(bbox)
        return cropped_img
    else:
        # Return the original image if no non-white region is found
        return img
    
    
def zoom_image(source_img, zoom_factor=1.0):
    """
    Returns a zoomed version of source_img.

    Args:
        source_img (Image): The image to zoom and paste.
        zoom_factor (float): The zoom factor (e.g., 2.0 for 2x zoom).
    """
    # Calculate new size after zoom
    new_width = int(source_img.width * zoom_factor)
    new_height = int(source_img.height * zoom_factor)

    # Resize the image
    zoomed_img = source_img.resize((new_width, new_height), Image.LANCZOS)

    return zoomed_img


def align_images_horizontally_centered(images, separator_horizontal):
    """
    Aligns images horizontally, centered vertically.

    Args:
        images (list): List of images.
        separator_horizontal: separator between each image
    """

    # Calculate total width and max height
    total_width = sum(img.width for img in images) + (len(images)-1) * separator_horizontal
    max_height = max(img.height for img in images)

    # Create a new blank image
    combined = Image.new('RGB', (total_width, max_height), (255, 255, 255))

    # Paste each image horizontally, centered vertically
    x_offset = 0
    for img in images:
        # Calculate vertical offset to center the image
        y_offset = (max_height - img.height) // 2
        combined.paste(img, (x_offset, y_offset))
        x_offset += img.width + separator_horizontal

    # Returns the result
    return combined


def stack_images_vertically(image1, image2, separator_vertical):
    # Calculate the total height and the maximum width
    total_height = image1.height + image2.height + separator_vertical
    max_width = max(image1.width, image2.width)

    # Create a new blank image with the calculated dimensions
    new_image = Image.new('RGB', (max_width, total_height),(255, 255, 255))

    # Paste the first image at the top
    new_image.paste(image1, (0, 0))

    # Paste the second image below the first
    new_image.paste(image2, (0, image1.height + separator_vertical))

    return new_image


def match_widths_to_largest(image1, image2):
    # Determine the maximum width
    max_width = max(image1.width, image2.width)

    # Resize the images to match the maximum width
    if image1.width < max_width:
        # Calculate the new height to maintain aspect ratio
        ratio = max_width / image1.width
        new_height = int(image1.height * ratio)
        image1 = image1.resize((max_width, new_height), Image.LANCZOS)

    if image2.width < max_width:
        # Calculate the new height to maintain aspect ratio
        ratio = max_width / image2.width
        new_height = int(image2.height * ratio)
        image2 = image2.resize((max_width, new_height), Image.LANCZOS)

    return image1, image2


def draw_title(grid, title, font):  # ---- Title ----
    if title:
        draw = ImageDraw.Draw(grid)
        bbox = draw.textbbox((0, 0), title, font=font)
        text_w = bbox[2] - bbox[0]
        draw.text(
            ((grid.width - text_w) // 2, 5),
            title,
            fill=(0, 0, 0),
            font=font
        )

def add_vertical_palette_with_ticks_and_labels(
    main_image, palette_image, position,
    tick_positions, labels, criterion,
    tick_length=10, tick_color=(0, 0, 0), label_color=(0, 0, 0), font_size=12):
    """
    Paste a vertical palette image onto a main image with ticks and labels.

    Args:
        main_image: The main image (background).
        palette_image: The vertical palette image to paste.
        position: Tuple (x, y) for the top-left corner of the palette.
        tick_positions: List of relative positions (0-1) for ticks (e.g., [0, 0.25, 0.5, 0.75, 1]).
        labels: List of labels for each tick (e.g., ["0", "25", "50", "75", "100"]).
        tick_length: Length of ticks in pixels.
        tick_color: Color of ticks (RGB).
        label_color: Color of labels (RGB).
        font_size: Font size for labels.

    Returns:
        The combined image with palette, ticks, and labels.
    """
    # Paste the palette image
    main_image.paste(palette_image, position)

    # Create a drawing context
    draw = ImageDraw.Draw(main_image)

    # Load a font
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", font_size)
    except:
        font = ImageFont.load_default()

    # Calculate palette dimensions
    palette_width, palette_height = palette_image.size
    x, y = position

    # Draw ticks and labels
    for i, pos in enumerate(tick_positions):
        # Calculate absolute tick position (vertically)
        if i == (len(tick_positions)-1):
            offset = -5
        else:
            offset = 0    
            
        tick_y = y + int(pos * palette_height) + offset

        # Draw tick (horizontal line to the left of the palette)
        draw.line([(x - tick_length, tick_y), (x, tick_y)], fill=tick_color, width=1)

        # Get the width of the bounding box of the text
        bbox = draw.textbbox((0, 0), labels[i], font=font)
        width_text = bbox[2] - bbox[0]
        
        # Calculate label position (left of the tick)
        label_x = x - width_text - palette_width  # - tick_length # Adjust as needed
        label_y = tick_y - font_size // 2  # Center vertically
    
        # Draw label
        draw.text((label_x, label_y), labels[i], fill=label_color, font=font)

        # Get the width of the bounding box of the text
        bbox = draw.textbbox((0, 0), criterion, font=font)
        width_text = bbox[2] - bbox[0]
        height_text = bbox[3] - bbox[1]
        x_title = x - width_text//2
        y_title = y-height_text - 50
        font_title = ImageFont.truetype("DejaVuSans-Bold.ttf", 36)
        # Draw title of palette
        draw.text((x_title, y_title), criterion, fill=label_color, font=font_title)

    return main_image


def add_left_right_text(grid, label_color=(0, 0, 0)):
    """Draws L and R text"""
    draw = ImageDraw.Draw(grid)
    font = ImageFont.truetype("DejaVuSans-Bold.ttf", 120)
    offset_x = 50
    offset_y = 100
    draw.text((offset_x, offset_y), "L", fill=label_color, font=font)
    draw.text((offset_x, grid.height//2+offset_y), "R", fill=label_color, font=font)
    
    return grid


def create_grid(image_files, n_cols, out_path, title=None, criterion=None,
                palette_path=None, vmin=None, vmax=None):

    zoom_factors = [1, 1.5, 1.5, 1]
    zoom_factors = zoom_factors * 2
    separator_horizontal = 50
    separator_vertical = 10


    # Loads, crops, zooms and matches images
    imgs = [Image.open(f) for f in image_files]
    imgs = [crop_to_bounding_box(img) for img in imgs]
    imgs = [zoom_image(img, zoom_factor) for img, zoom_factor in zip(imgs, zoom_factors)]
    for i in range(n_cols):
        imgs[i], imgs[i+n_cols] = match_widths_to_largest(imgs[i], imgs[i+n_cols])
    
    # Creates grid
    grid_top = align_images_horizontally_centered(imgs[:n_cols], separator_horizontal)
    grid_bottom = align_images_horizontally_centered(imgs[n_cols:], separator_horizontal)
    grid = stack_images_vertically(grid_top, grid_bottom, separator_vertical)

    font_size = 36
    font = ImageFont.truetype("DejaVuSans.ttf", font_size)

    # # ---- Create final canvas ----
    # grid = Image.new(
    #     'RGB',
    #     (sum(widths[:n_cols]), title_h + n_rows * h + palette_h),
    #     (255, 255, 255)
    # )

    draw_title(grid, title, font)
    
    # ---- Palette handling ----
    palette_margin = 20
    if palette_path:
        pal_img = Image.open(palette_path)
        pal_img = pal_img.rotate(90, expand=True)
        pal_w, pal_h = pal_img.size
        x_pal = (grid.width - pal_w) - palette_margin
        y_pal = (grid.height - pal_h) // 2
        
        # Define tick positions (0 to 1) and labels
        # tick_positions = [0, (1.-0.9)/0.5, (1.-0.7)/0.5, 1]
        # labels = ["1.0", "0.9", "0.7", "0.5"]
        tick_positions = [0, 1]
        labels = [str(vmax), str(vmin)]

        # Add the vertical palette with ticks and labels
        grid = add_vertical_palette_with_ticks_and_labels(
            main_image=grid,
            palette_image=pal_img,
            position=(x_pal, y_pal),  # Top-left corner of the palette
            tick_positions=tick_positions,
            labels=labels,
            criterion=criterion,
            tick_length=15,
            tick_color=(0, 0, 0),
            label_color=(0, 0, 0),
            font_size=36
        )
        
        grid = add_left_right_text(grid)

    grid.save(out_path)
    print(f"Snapshot of the block available at {out_path}")


def visualize_whole(res, side, start):
    set_color_property(res, side)
    first_img = visualize_whole_hemisphere(middle_view if side == "L" else side_view, side, start+0)
    second_img = visualize_whole_hemisphere(top_view, side, start+1)
    third_img = visualize_whole_hemisphere(bottom_view, side, start+2)
    fourth_img = visualize_whole_hemisphere(side_view if side == "L" else middle_view, side, start+3)
    return [first_img, second_img, third_img, fourth_img]


# Main function

In [156]:
a = anatomist.Anatomist()

In [157]:
middle_view = [0.5, -0.5, -0.5, 0.5]
side_view = [0.5, 0.5, 0.5, 0.5]
bottom_view = [0, -1, 0, 0]
top_view = [0, 0, 0, -1]

In [158]:
%matplotlib qt5

In [172]:
dic_window = {} # Global dictionary of windows
left_images = visualize_whole(res, "L", 0)
right_images = visualize_whole(res, "R", 4)
image_files = left_images + right_images

Reading FGraph version 2.0
param = auc
bounding box found : 0, -80, -90
                     90, 120, 60


bindOtherFramebuffer 0.9 : OpenGL error: invalid operation
bindOtherFramebuffer 0.9 : OpenGL error: invalid operation
bindOtherFramebuffer 0.9 : OpenGL error: invalid operation


Reading FGraph version 2.0
param = auc
bounding box found : -90, -80, -90
                     0, 120, 60


bindOtherFramebuffer 0.9 : OpenGL error: invalid operation
bindOtherFramebuffer 0.9 : OpenGL error: invalid operation
bindOtherFramebuffer 0.9 : OpenGL error: invalid operation
bindOtherFramebuffer 0.9 : OpenGL error: invalid operation


In [323]:
if SNAPSHOT:
    create_grid(
        image_files, 4,f"/tmp/grid.png", title=None,
        criterion="AUC", palette_path=f'/tmp/pal0.jpg',
        vmin=0.5,
        vmax=int(1))

Snapshot of the block available at /tmp/grid.png


In [191]:
l = [1, 2, 3, 4]
l[0] = 't'
l

['t', 2, 3, 4]