# Brain region colors

## Imports

In [None]:
import numpy as np

In [None]:
import colorio
from colorio.cs import ColorCoordinates, HSV, OKLAB, SRGB1, XYZ100

In [None]:
from ibllib.atlas.regions import BrainRegions, FILE_BERYL

In [None]:
from IPython.core.display import display, HTML

In [None]:
from ipywidgets import interact

In [None]:
import matplotlib.pyplot as plt
from matplotlib import colors as mcolors

%matplotlib inline

plt.rcParams["figure.dpi"] = 100
plt.rcParams["axes.grid"] = False
#sns.set_theme(style="white")

## BrainRegions

In [None]:
# Shared BrainRegions instance; the helper functions in this notebook read it
# as a module-level global (br.id, br.parent, br.rgb, br.level, ...).
br = BrainRegions()

In [None]:
# Sorted array loaded from FILE_BERYL; used throughout as the set of Beryl
# region ids (membership tests via np.isin).
beryl = np.sort(np.load(FILE_BERYL))

In [None]:
# Ids returned by br.ancestors() for the Beryl regions. show_children_colors()
# only lists regions whose id is in `kept`.
# NOTE(review): presumably this includes the Beryl regions themselves as well
# as their ancestors -- confirm against BrainRegions.ancestors().
kept = br.ancestors(beryl)['id']

In [None]:
# Ids of a few landmark regions. The trailing number on each line appears to
# be the region's level in the hierarchy (the notes further down list the
# thalamus, id 549, as LVL4) -- TODO confirm against br.level.
root = 997  # 0
basic = 8  # 1
cerebrum = 567  # 2
cortex = 688  # 3
isocortex = 315  # 4
nuclei = 623  # 3
brainstem = 343  # 2
interbrain = 1129  # 3
midbrain = 313  # 3
hindbrain = 1065  # 3
cerebellum = 512  # 2
thalamus = 549  # 4
hypothalamus = 1097  # 4

## Brain regions attributes

In [None]:
br.acronym

In [None]:
br.id

In [None]:
br.level

In [None]:
br.mappings

In [None]:
br.name

In [None]:
br.parent

In [None]:
br.rgb

## Color analysis of Beryl regions

### Utility functions

In [None]:
def r2h(rgb_float):
    """Convert RGB colors to HSV.

    Thin wrapper around ``matplotlib.colors.rgb_to_hsv``. The input is passed
    through as-is and is expected to hold floats in [0, 1].
    """
    hsv_float = mcolors.rgb_to_hsv(rgb_float)
    return hsv_float

def h2r(hsv_float):
    """Convert HSV colors to RGB.

    Thin wrapper around ``matplotlib.colors.hsv_to_rgb``. The input is passed
    through as-is and is expected to hold floats in [0, 1].
    """
    rgb_float = mcolors.hsv_to_rgb(hsv_float)
    return rgb_float

In [None]:
def children(ids):
    """Return the ids of all regions whose parent is one of `ids`.

    `ids` may be a single id or a sequence of ids. The lookup is done against
    the module-level `br` instance, and the result follows the order of
    `br.id`.
    """
    parent_ids = np.array(ids)
    has_listed_parent = np.isin(br.parent, parent_ids)
    return br.id[has_listed_parent]

In [None]:
def children_colors(rid, restrict_to_beryl=None):
    """Return the direct children of `rid` together with their atlas colors.

    Parameters
    ----------
    rid : region id (or sequence of ids) whose children are looked up.
    restrict_to_beryl : when truthy, keep only the children that are listed
        in the module-level `beryl` array.

    Returns
    -------
    (ids, rgb) where `ids` are the selected child ids and `rgb` holds the
    rows of `br.rgb` for those ids, divided by 255.
    """
    ids = children(rid)
    if restrict_to_beryl:
        # Boolean-mask selection keeps the original order of the children.
        ids = ids[np.isin(ids, beryl)]
    rgb_float = br.rgb[np.isin(br.id, ids)] / 255.0
    return ids, rgb_float

### Plotting functions

In [None]:
def plot_hsv(ids, hsv_float):
    """Plot the H, S and V components of a list of colors on the current axes.

    For the color at position i, the letters 'H', 'S' and 'V' are drawn at
    x = i and at a height equal to the corresponding component; each letter
    is drawn in the color it describes. The x ticks are labelled with the
    acronyms of the regions whose id is in `ids`.
    """
    count = len(hsv_float)
    for position, triplet in enumerate(hsv_float):
        # The text color is the same for the three letters of one region.
        letter_color = tuple(h2r(triplet))
        for channel, letter in enumerate('HSV'):
            plt.text(
                position, triplet[channel], letter, size=20,
                color=letter_color)
    plt.xlim(-.5, count)
    plt.ylim(0, 1.2)
    tick_labels = br.acronym[np.isin(br.id, ids)]
    plt.xticks(np.arange(count), labels=tick_labels)

In [None]:
def color_rectangle(level, rgb, is_custom=False, color='#333'):
    """Return the HTML of a small colored box labelled with its hex code.

    Parameters
    ----------
    level : indentation depth; the left margin is ``level * 30`` pixels.
    rgb : three floats, each at most 1 (checked with an assertion).
    is_custom : accepted for the callers' convenience; not used here.
    color : CSS color of the hex label drawn inside the box.

    Returns
    -------
    A ``<div>`` element terminated by a newline.
    """
    channels = np.array(rgb)
    assert np.all(channels <= 1)
    # Scale to 0..255 (truncating) and convert to plain Python ints so that
    # the tuple formats as "(r, g, b)".
    rgb255 = tuple(int(value) for value in (channels * 255).astype(np.uint8))
    hex_label = "#%02x%02x%02x" % rgb255
    css = ''.join([
        'width: 100px; height: 30px; ',
        f'background-color: rgb{rgb255}; ',
        f'margin: 5px 10px 0 {level * 30}px; ',
        'padding: 2px 0 0 5px; ',
        'font-size: .75em; ',
        'border: 1px solid #aaa; ',
        f'color: {color}; ',
    ])
    return f'<div style="{css}">{hex_label}</div>\n'

### Analysis

In [None]:
def plot_children(ids, **kwargs):
    """Plot the HSV components of the children of `ids` and show their colors.

    Parameters
    ----------
    ids : region id (or sequence of ids) whose direct children are analysed.
    **kwargs : forwarded unchanged to show_children_colors(); the
        `restrict_to_beryl` entry, if present, is also applied to the HSV
        plot.
    """
    restrict = kwargs.get('restrict_to_beryl', None)
    cids, rgb = children_colors(ids, restrict_to_beryl=restrict)
    plot_hsv(cids, r2h(rgb))
    # show_children_colors() only builds and returns an HTML string; render
    # it explicitly, otherwise the result is silently discarded.
    display(HTML(show_children_colors(ids, **kwargs)))

## Beryl regions with colors

Display all Beryl regions with their colors:

## Thalamus

- LVL4 Thalamus (549, TH) remains at #ff7080
- LVL5 Thalamus sensory-motor cortex related (864, DORsm) remains at #ff8084
    - **but all of its children need a different color**
    - **LVL6 Ventral group of the dorsal thalamus (637, VENT) #ff8084** needs a different color, and its children too
- LVL5 Thalamus polymodal association cortex related (856, DORpm) remains at #ff909f
    - **but all of its children need a different color**
    - **and each of their LVL7 grandchildren too**

## Color space conversion

In [None]:
def oklab_to_xyz(L, a, b):
    """Convert Oklab coordinates to XYZ.

    Parameters
    ----------
    L, a, b : scalars or 1-D arrays of equal length holding the Oklab
        lightness and the two opponent components.

    Returns
    -------
    Array of shape (3, n) with one XYZ column per input color.
    """
    # Forward matrices of the Oklab definition, written out row by row:
    # XYZ -> LMS, then (cube-rooted) LMS -> Lab. Both are inverted below.
    xyz_to_lms = np.array([
        [+0.8189330101, +0.3618667424, -0.1288597137],
        [+0.0329845436, +0.9293118715, +0.0361456387],
        [+0.0482003018, +0.2643662691, +0.6338517070],
    ])
    lms_cbrt_to_lab = np.array([
        [+0.2104542553, +0.7936177850, -0.0040720468],
        [+1.9779984951, -2.4285922050, +0.4505937099],
        [+0.0259040371, +0.7827717662, -0.8086757660],
    ])

    # One column per color.
    lab = np.c_[L, a, b].T
    lms_cbrt = np.linalg.inv(lms_cbrt_to_lab) @ lab
    # Undo the cube-root non-linearity.
    lms = lms_cbrt ** 3
    return np.linalg.inv(xyz_to_lms) @ lms

In [None]:
def xyz_to_rgb(xyz):
    """Apply a fixed 3x3 XYZ -> RGB matrix to `xyz`.

    `xyz` is a length-3 vector or a (3, n) array with one color per column;
    the result has the same shape. Only the matrix product is applied here:
    no transfer function (gamma) and no clipping.
    NOTE(review): the coefficients look like the usual XYZ -> linear sRGB
    (D65) matrix -- confirm.
    """
    red_row = [3.2404542, -1.5371385, -0.4985314]
    green_row = [-0.9692660, 1.8760108, 0.0415560]
    blue_row = [0.0556434, -0.2040259, 1.0572252]
    xyz_to_rgb_matrix = np.array([red_row, green_row, blue_row])
    rgb = xyz_to_rgb_matrix @ xyz
    return rgb

In [None]:
def oklab_to_rgb(L, a, b):
    """Convert Oklab coordinates to RGB by chaining the two converters.

    Equivalent to ``xyz_to_rgb(oklab_to_xyz(L, a, b))``; the result has one
    color per column.
    """
    return xyz_to_rgb(oklab_to_xyz(L, a, b))

In [None]:
def make_gradient(L0, C0, n):
    """Return `n` colors of constant Oklab lightness and chroma, sweeping hue.

    Parameters
    ----------
    L0 : Oklab lightness shared by all colors.
    C0 : chroma (radius in the a/b plane) shared by all colors.
    n : number of colors.

    Returns
    -------
    Array of shape (n, 3), one RGB row per hue, as produced by
    oklab_to_rgb(). Values are not clipped and may fall outside [0, 1].
    """
    L = L0 * np.ones(n)
    # Hue is circular: -pi and +pi are the same angle, so the end point is
    # excluded to avoid returning the same color twice.
    h = np.linspace(-np.pi, np.pi, n, endpoint=False)
    a = C0 * np.cos(h)
    b = C0 * np.sin(h)
    return oklab_to_rgb(L, a, b).T

In [None]:
@interact(L0=(0.0, 1.0, 0.01), C0=(0.0, 1.0, 0.01))
def f(L0=0.8, C0=0.1):
    """Display a 256-step hue gradient for the given lightness and chroma.

    Registered with ipywidgets' `interact`, which supplies `L0` and `C0`.
    The gradient is rendered as a row of 5px-wide colored <div> stripes.
    """
    stripes = []
    for rgb in make_gradient(L0, C0, 256):
        # Convert to plain Python ints: a tuple of NumPy scalars formats as
        # "(np.int32(204), ...)" under NumPy >= 2, which is not valid CSS.
        rgb255 = tuple(int(v) for v in (rgb * 255).astype(np.int32))
        stripes.append(
            f'<div style="display: flexbox; width: 5px; height: 100px; background-color: rgb{rgb255};"></div>\n')
    display(HTML('<div style="display: flex;">' + ''.join(stripes) + '</div>'))

## Generating color variations

### Functions

In [None]:
def variants(rgb, seed=0, hstd=0, sstd=0, vstd=0):
    """Return randomly perturbed versions of the given RGB colors.

    Each color is converted to HSV, independent normal noise is added to the
    three components, and the result is converted back to RGB.

    Parameters
    ----------
    rgb : 2-D array with one RGB color per row (floats in [0, 1]).
    seed : accepted but currently NOT applied (the seeding call is disabled),
        so results depend on the global NumPy random state.
    hstd, sstd, vstd : standard deviations of the noise added to hue,
        saturation and value. `None` is treated as 0 (no perturbation),
        because the callers in this file default these arguments to `None`.

    Returns
    -------
    Array of perturbed RGB colors, one per input row.
    """
    hstd, sstd, vstd = (0 if std is None else std for std in (hstd, sstd, vstd))
    hsv = r2h(rgb)
    n = len(hsv)
    #np.random.seed(seed)
    hsv[:, 0] += np.random.normal(size=n, loc=0, scale=hstd)
    # Hue is cyclic: wrap around instead of clipping.
    hsv[:, 0] %= 1
    hsv[:, 1] += np.random.normal(size=n, loc=0, scale=sstd)
    hsv[:, 2] += np.random.normal(size=n, loc=0, scale=vstd)
    # Saturation and value are clamped to the valid range.
    hsv = np.clip(hsv, 0, 1)
    return h2r(hsv)

In [None]:
def variant_children_colors(id, color, hstd=None, sstd=None, vstd=None):
    """Map each direct child of `id` to a color, perturbing identical ones.

    The children keep their atlas colors unless all of them share exactly
    the same color (which includes the case of a single child); in that case
    the colors are replaced by random variants.

    Parameters
    ----------
    id : parent region id.
    color : not used by this function.
    hstd, sstd, vstd : forwarded to variants().

    Returns
    -------
    dict mapping child id -> RGB tuple; empty when `id` has no children.
    """
    child_ids, child_rgb = children_colors(id, restrict_to_beryl=False)
    if len(child_ids) == 0:
        return {}
    # A per-channel spread of (numerically) zero means every child has the
    # same color.
    all_identical = np.all(np.std(child_rgb, axis=0) < 1e-10)
    if all_identical:
        child_rgb = variants(child_rgb, hstd=hstd, sstd=sstd, vstd=vstd)
    return {
        child_id: tuple(row.ravel())
        for child_id, row in zip(child_ids, child_rgb)
    }

In [None]:
def variant_descendent_colors(id, color, hstd=None, sstd=None, vstd=None, decrease_coef=None):
    """Recursively build a color mapping for every descendant of `id`.

    Calls variant_children_colors() on `id`, then recurses into each child
    with the three standard deviations divided by `decrease_coef`.

    Parameters
    ----------
    id : region id at the root of the subtree.
    color : forwarded to variant_children_colors().
    hstd, sstd, vstd : standard deviations of the H, S, V noise at this
        level. `None` is treated as 0 (no perturbation).
    decrease_coef : divisor applied to the stds at each level of recursion.
        `None` is treated as 1 (stds stay constant).

    Returns
    -------
    dict mapping descendant id -> RGB tuple.
    """
    # The defaults are None, which cannot be divided; normalise them first.
    hstd, sstd, vstd = (0 if std is None else std for std in (hstd, sstd, vstd))
    if decrease_coef is None:
        decrease_coef = 1
    custom = variant_children_colors(id, color, hstd=hstd, sstd=sstd, vstd=vstd)
    # Iterate over a snapshot: `custom` grows while we recurse.
    for child, child_color in list(custom.items()):
        custom.update(variant_descendent_colors(
            child, child_color,
            hstd=hstd / decrease_coef,
            sstd=sstd / decrease_coef,
            vstd=vstd / decrease_coef,
            decrease_coef=decrease_coef))
    return custom

### Testing

In [None]:
# JavaScript embedded in the exported HTML page. expand() opens every
# <details> element; collapse(level) first expands everything, then closes
# the elements carrying the class "level-<level>". Uses jQuery ($), which the
# page template loads from a CDN.
js = '''
function expand() {
    $("details").attr("open", true);
}

function collapse(level) {
    expand();
    $(".level-" + level).attr("open", false);
}
'''

In [None]:
# Stylesheet embedded in the exported HTML page: region labels turn black
# when hovered.
css = '''
.region-label:hover {color: #000 !important;}
'''

In [None]:
# Region whose subtree is processed and exported.
area = root
# Standard deviations of the normal noise added to hue, saturation and value
# by variants().
hstd = .025
sstd = .03
vstd = .02
# variant_descendent_colors() divides the three stds by this factor at each
# level of recursion.
decrease_coef = 1.01

#np.random.seed(10)
# Atlas color of region 864 (DORsm in the notes above), scaled to [0, 1], and
# one random variant of it.
# NOTE(review): `colorv` is passed on as the `color` argument, which
# variant_children_colors() does not use.
idx = np.nonzero(br.id == 864)[0]
color = br.rgb[idx] / 255.0
colorv = variants(np.array(color), hstd=hstd, sstd=sstd, vstd=vstd)

In [None]:
# Mapping region id -> RGB tuple for every descendant of `area`.
# NOTE: `custom` is rebound to a gradient-based mapping in a later cell.
custom = variant_descendent_colors(area, colorv, hstd=hstd, sstd=sstd, vstd=vstd, decrease_coef=decrease_coef)

In [None]:
# Beryl region ids, in the order in which they appear in br.id.
ids = br.id[np.isin(br.id, beryl)]

In [None]:
# Alternative color scheme: one color per Beryl region, taken from a hue sweep
# at fixed lightness L0 and chroma C0 (see make_gradient).
L0 = .85
C0 = .1
n = len(beryl)
grad = make_gradient(L0, C0, n)
# make_gradient() does not clip its output; keep every component in [0, 1].
grad = np.clip(grad, 0, 1)
# Gradient rows are assigned to the regions with a fixed offset of -50
# (modulo n). This rebinds `custom`, replacing the mapping computed earlier.
custom = {id: grad[(-50 + i) % n] for i, id in enumerate(ids)}

def show_children_colors(ids, recursive=False, restrict_to_beryl=False, max_level=10, custom=None):
    """Build an HTML tree showing regions with their atlas and custom colors.

    Every listed region becomes a ``<details>`` element whose summary shows
    the atlas color, optionally a custom color, and a text label. Elements
    are nested according to the regions' levels.

    Parameters
    ----------
    ids : region id (or sequence of ids) at the top of the listing.
    recursive : list all descendants (``br.descendants``) instead of only
        the direct children.
    restrict_to_beryl : in non-recursive mode, skip the direct children that
        are not Beryl regions.
    max_level : deepest level to include.
    custom : optional dict mapping region id -> RGB (floats in [0, 1]). A
        region whose custom color differs from its atlas color is
        highlighted and gets two extra preview lines.

    Returns
    -------
    The HTML fragment as a string. Nothing is displayed by this function.

    Notes
    -----
    Listing stops at the first region named 'fiber tracts'.
    NOTE(review): the nesting assumes that the rows of `br` are ordered so
    that every region directly follows its ancestors -- confirm.
    """
    if recursive:
        cids = br.descendants(ids)['id']
    else:
        cids = children(ids)
    idx = np.isin(br.id, cids) & (br.id >= 0) & np.isin(br.id, kept) & (br.level <= max_level)
    # The requested regions themselves are always listed.
    idx |= np.isin(br.id, ids)

    ids = br.id[idx]
    names = br.name[idx]
    acronyms = br.acronym[idx]
    colors = br.rgb[idx] / 255.0
    levels = br.level[idx].astype(np.int64)
    inberyls = np.isin(ids, beryl)
    l0 = levels.min()

    parts = []
    # Levels of the <details> elements that are currently open, outermost
    # first. Tracking them explicitly lets us emit exactly the right number
    # of closing tags, including after the early `break` below.
    open_levels = []
    for rid, name, acronym, color, level, inberyl in zip(ids, names, acronyms, colors, levels, inberyls):
        if name == 'fiber tracts':
            break
        if restrict_to_beryl and not recursive and level == l0 + 1 and not inberyl:
            continue
        rgb = tuple(color)

        # Close every open element that is not an ancestor of this region.
        while open_levels and open_levels[-1] >= level:
            open_levels.pop()
            parts.append('</details>\n')
        open_levels.append(level)
        parts.append(f'<details open="true" id="area-{rid}" class="level-{level}">\n')

        parts.append(
            f'<summary style="list-style: none; cursor: pointer;">\n'
            f'<div style="display: flex;">\n')

        # Atlas color, indented by the depth relative to the top region.
        parts.append(color_rectangle(level - l0, rgb))

        # Optional custom color. Without a `custom` mapping the region simply
        # keeps its atlas color.
        if custom:
            custom_color = custom.get(rid, rgb)
            is_custom = np.abs(np.array(custom_color) - np.array(rgb)).max() > 1e-10
            parts.append(color_rectangle(
                0, (1, 1, 1) if not is_custom else custom_color,
                is_custom=is_custom, color='#fff' if not is_custom else '#333'))
        else:
            custom_color = rgb
            is_custom = False

        # Label
        parts.append(
            f'<div class="region-label" style="padding-top: 7px; '
            f'font-weight: {"bold" if inberyl else "normal"}; '
            f'color: {"#555" if inberyl else "#999"}; '
            f'background-color: {"none" if not is_custom else "#ffd"}; '
            f'">L{int(level)} {name} (#{rid}, {acronym})</div>\n')

        # Preview lines of the custom color on a white and a black background.
        if is_custom:
            c = np.array(custom_color)
            c = tuple(map(int, (c * 255).astype(np.uint8)))
            parts.append(
                f'<div style="margin-left: 20px; margin-top: 5px;">'
                # white background
                f'<div style="width: 150px; padding: 5px; background-color: #fff;">'
                f'<div style=" height: 6px; background-color: rgb{c};"></div>'
                f'</div>'
                # black background
                f'<div style="width: 150px; padding: 4px; background-color: #000;">'
                f'<div style=" height: 6px; background-color: rgb{c};"></div>'
                f'</div>'
                f'</div>\n')

        parts.append(
            f'</div>\n'
            f'</summary>\n')

    # Close whatever is still open.
    parts.append('</details>' * len(open_levels))
    return ''.join(parts)
    

# Build the HTML tree for the whole subtree of `area` (down to level 7) and
# export it as a standalone page.
s = show_children_colors(area, recursive=True, custom=custom, max_level=7)
#display(HTML(s))

# Write with an explicit encoding (the platform default is not always UTF-8)
# and declare that encoding in the page so region names render correctly.
with open("docs/index.html", "w", encoding="utf-8") as f:
    f.write(f'''
    <html>
    <head><meta charset="utf-8"><title>Brain region colors</title></head>
    <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.6.0/jquery.min.js"></script>
    <script>
    {js}
    </script>
    <style>
    {css}
    </style>
    <body>
    <div>
    This is a first attempt at providing a modified color map for the Allen Atlas brain regions, for which many regions have identical colors.
    Notes:
    <ul>
    <li>This list shows all Beryl brain regions and their ancestors.</li>
    <li>Click on a region to collapse/extend its descendents.</li>
    <li>For each region, two colors are shown: the original Allen Atlas color, and the optionally modified one.</li>
    <li>For each region, the name, region id, acronym, tree level are shown.</li>
    <li>The name is in bold if the region is a Beryl region.</li>
    <li>Regions with a modified color are highlighted in yellow.</li>
    <li>A modified color is proposed for a region if all of its siblings have the same color.</li>
    <li>A color variant is computed by adding a small random (normal) perturbation to the H, S, V components.</li>
    <li>Required improvements: find a smarter algorithm for generating visually distinct color variations.</li>
    </ul>
    </div>
    <div style="margin: 20px;">
    <button onclick="collapse(2);">Collapse to L2</button>
    <button onclick="collapse(3);">Collapse to L3</button>
    <button onclick="collapse(4);">Collapse to L4</button>
    <button onclick="collapse(5);">Collapse to L5</button>
    <button onclick="collapse(6);">Collapse to L6</button>
    <button onclick="expand();">Expand all</button>
    </div>
    {s}
    </body>
    </html>
    ''')