# Introduction

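This notebook takes a data-driven approach to generating word lists for mental functions that are related to brain circuitry. The overall framework is as follows (this notebook loads the resulting ontologies for k = 2–25 circuits, names each domain after its most central term, and visualizes and exports the word lists and circuit maps):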

1. Cluster brain structures into circuits by PMI-weighted co-occurrences with mental function terms.
2. Identify the mental function terms most strongly associated with each circuit over a range of list lengths.
3. Select the list length for each circuit that maximizes word-structure classification performance. 
4. Select the number of circuits that maximizes circuit-function classification performance.

# Load the data

In [1]:
import numpy as np
import pandas as pd
import sys
sys.path.append("..")
import utilities, ontology

In [2]:
# Suffix identifying which saved ontology variant ontology.load_ontology reads
# (presumably the logistic-regression one -- confirm against the ontology module).
suffix = "_logreg"
# Cerebellar parcellation passed to every loader: "seg" keeps separate lobule
# labels, "combo" merges them into fewer regions (compare the two branches of
# load_atlas_2mm further down).
cerebellum = "seg"

## Cluster range to visualize

In [3]:
# Numbers of circuits k to process: 2 through 25 inclusive.
circuit_counts = range(2, 25 + 1)

## Brain activation coordinates

In [4]:
# Activation matrix (documents x structures). Uses the configured `cerebellum`
# parcellation rather than a hard-coded "seg", so it stays consistent with the
# atlas and ontologies loaded later in the notebook.
act_bin = utilities.load_coordinates(cerebellum=cerebellum)
print("Document N={}, Structure N={}".format(*act_bin.shape))

Document N=18155, Structure N=148


## Terms for mental functions

In [5]:
# Version tag of the document-term matrix to load. With binarize=True each entry
# is presumably 1 if the term occurs in the document and 0 otherwise (confirm in
# utilities.load_doc_term_matrix).
version = 190325
dtm_bin = utilities.load_doc_term_matrix(version=version, binarize=True)

In [6]:
# Cognitive-neuroscience lexicon, restricted to terms that are columns of the
# document-term matrix, in alphabetical order.
lexicon = sorted(set(utilities.load_lexicon(["cogneuro"])) & set(dtm_bin.columns))
len(lexicon)

1683

In [7]:
# Keep only the lexicon columns of the document-term matrix.
dtm_bin = dtm_bin.loc[:, lexicon]
print("Document N={}, Term N={}".format(*dtm_bin.shape))

Document N=18155, Term N=1683


## Document splits

In [8]:
def _read_pmids(split):
    """Return the integer PMIDs listed one per line in ../data/splits/<split>.txt."""
    # `with` closes the file; the original comprehension left both handles open.
    with open("../data/splits/{}.txt".format(split)) as f:
        # int() tolerates surrounding whitespace; blank lines are skipped rather
        # than crashing on int("").
        return [int(line) for line in f if line.strip()]

train, val = [_read_pmids(split) for split in ["train", "validation"]]
print("Training N={}, Validation N={}".format(len(train), len(val)))

Training N=12708, Validation N=3631


# Name the domains

In [9]:
# Name each circuit after its term with the highest degree centrality. Names must
# be unique within a k: when a circuit's top term is already another circuit's
# name, that term is dropped from its candidates and the next one is tried.
k2terms, k2name = {}, {}
for k in circuit_counts:
    print("Processing k={}".format(k))
    lists, circuits = ontology.load_ontology(k, suffix=suffix, cerebellum=cerebellum)
    # sorted() instead of list(set(...)) so that ties are broken the same way on
    # every run (string hashing, and hence set order, varies between processes).
    k2terms[k] = {i: sorted(set(lists.loc[lists["CLUSTER"] == i+1, "TOKEN"])) for i in range(k)}
    k2name[k] = {i+1: "" for i in range(k)}
    # Centrality depends only on the cluster, `lists` and `dtm_bin`, none of which
    # change in the naming loop, so compute it once per cluster instead of on every pass.
    centrality = {i: ontology.term_degree_centrality(i+1, lists, dtm_bin, dtm_bin.index)
                  for i in range(k)}
    names, degs = [""]*k, [0]*k  # degs[i]: centrality of the term circuit i is named after
    while "" in names:
        for i in range(k):
            if names[i]:
                continue  # already named; a name is never reassigned
            if not k2terms[k][i]:
                raise ValueError("k={}: no unused term left to name circuit {}".format(k, i+1))
            degrees = centrality[i].loc[k2terms[k][i]].sort_values(ascending=False, kind="mergesort")
            top_term = degrees.index[0]
            name = top_term.upper().replace("_", " ")
            if name not in names:
                names[i] = name
                degs[i] = max(degrees)
                k2name[k][i+1] = name
            else:
                # Name already taken: drop this term and retry on the next pass.
                # Always dropping it guarantees progress; the previous
                # `degs[name_idx] > degs[i]` guard compared against 0 for every
                # unnamed circuit and looped forever whenever it was false.
                k2terms[k][i] = [term for term in k2terms[k][i] if term != top_term]

Processing k=2
Processing k=3
Processing k=4
Processing k=5
Processing k=6
Processing k=7
Processing k=8
Processing k=9
Processing k=10
Processing k=11
Processing k=12
Processing k=13
Processing k=14
Processing k=15
Processing k=16
Processing k=17
Processing k=18
Processing k=19
Processing k=20
Processing k=21
Processing k=22
Processing k=23
Processing k=24
Processing k=25


In [10]:
# Every distinct circuit name used for any k, alphabetically (reference list
# for curating names_in_order).
names = sorted({name for k in circuit_counts for name in k2name[k].values()})

In [11]:
# Display order for circuit names; k2order sorts each k's circuits by where their
# name falls in this list. The sub-lists below only group the entries for
# readability and are flattened into a single list of 40 names.
names_in_order = [
    name
    for group in [
        ["MEMORY", "EPISODIC MEMORY", "RETRIEVAL"],
        ["EMOTION", "MOOD", "REWARD", "VALENCE"],
        ["DECISION MAKING", "AGENCY", "JUDGING"],
        ["AROUSAL", "ANTICIPATION", "REACTION TIME"],
        ["COGNITION", "COGNITIVE", "COGNITIVE FUNCTION", "COGNITIVE PROCESS"],
        ["REPRESENTATION", "MANIPULATION"],
        ["PLANNING", "PREPARATION", "EXECUTION"],
        ["MOVEMENT", "MOTOR CONTROL", "MOTOR LEARNING", "COORDINATION"],
        ["ARM", "FOOT", "HAND"],
        ["REST", "HUNGER"],
        ["VESTIBULAR", "COVERT"],
        ["VISION", "IMAGERY", "HEARING", "PERCEPTION"],
        ["LANGUAGE", "WORD", "MEANING"],
    ]
    for name in group
]

In [12]:
# Order each k's circuits by the position of their names in names_in_order.
# Every name must be listed there: a missing name would silently drop that circuit
# from k2order / k2name_ordered, and the export cell's k2name_ordered[k].index(dom)
# would later fail with a less helpful error.
k2order = {}
for k in circuit_counts:
    unlisted = sorted(set(k2name[k].values()) - set(names_in_order))
    if unlisted:
        raise ValueError("k={}: names missing from names_in_order: {}".format(k, unlisted))
    name2i = {name: i for i, name in k2name[k].items()}
    k2order[k] = [name2i[name] for name in names_in_order if name in name2i]
k2name_ordered = {k: [k2name[k][i] for i in k2order[k]] for k in circuit_counts}

In [13]:
# Title-case display labels for each k, in display order.
k2name_titles = {
    k: [dom.replace("_", " ").title() for dom in k2name_ordered[k]]
    for k in circuit_counts
}
k2name_titles

{2: ['Manipulation', 'Movement'],
 3: ['Arousal', 'Manipulation', 'Execution'],
 4: ['Arousal', 'Manipulation', 'Movement', 'Hearing'],
 5: ['Arousal', 'Cognitive', 'Movement', 'Vision', 'Hearing'],
 6: ['Memory', 'Reward', 'Cognitive', 'Movement', 'Vision', 'Hearing'],
 7: ['Memory',
  'Reward',
  'Cognitive',
  'Execution',
  'Movement',
  'Vision',
  'Hearing'],
 8: ['Emotion',
  'Reward',
  'Manipulation',
  'Execution',
  'Movement',
  'Motor Control',
  'Vision',
  'Language'],
 9: ['Memory',
  'Reward',
  'Cognitive',
  'Manipulation',
  'Execution',
  'Movement',
  'Vestibular',
  'Vision',
  'Language'],
 10: ['Memory',
  'Emotion',
  'Reward',
  'Cognitive',
  'Manipulation',
  'Execution',
  'Movement',
  'Vision',
  'Perception',
  'Language'],
 11: ['Memory',
  'Emotion',
  'Anticipation',
  'Cognitive',
  'Execution',
  'Movement',
  'Motor Control',
  'Coordination',
  'Rest',
  'Vision',
  'Hearing'],
 12: ['Memory',
  'Emotion',
  'Reward',
  'Cognitive',
  'Manipulati

# Visualize the term lists

In [14]:
# Named hex colours used across the figures.
c = {
    "red": "#CE7D69", "orange": "#BA7E39", "yellow": "#CEBE6D",
    "chartreuse": "#AEC87C", "green": "#77B58A", "blue": "#7597D0",
    "magenta": "#B07EB6", "purple": "#7D74A3", "brown": "#846B43",
    "pink": "#CF7593", "slate": "#6F8099", "crimson": "#8C4058",
    "gold": "#D8AE54", "teal": "#5AA8A7", "indigo": "#3A3C7C",
    "lobster": "#FF7B5B", "olive": "#72662A", "lime": "#91E580",
    "sky": "#BCD5FF", "fuschia": "#E291DD", "violet": "#8B75EA",
    "tan": "#E0BD84", "berry": "#D64F7C", "mint": "#A4EACA",
    "sun": "#F4FF6B",
}

# Palette per framework: the i-th domain of a figure is drawn in the i-th colour.
# "data-driven" has 25 colours, one per circuit up to k = 25.
palettes = {
    framework_name: [c[colour] for colour in colour_names]
    for framework_name, colour_names in [
        ("data-driven", ["blue", "magenta", "yellow", "green", "red",
                         "purple", "chartreuse", "orange", "pink", "brown",
                         "slate", "crimson", "gold", "teal", "indigo",
                         "lobster", "olive", "lime", "sky", "fuschia",
                         "violet", "tan", "berry", "mint", "sun"]),
        ("rdoc", ["blue", "red", "green", "purple", "yellow", "orange"]),
        ("dsm", ["purple", "chartreuse", "orange", "blue", "red",
                 "magenta", "yellow", "green", "brown"]),
    ]
}

In [15]:
def plot_wordclouds(k, domains, lists, dtm, framework="data-driven", suffix="lr", cerebellum="combo"):
    """Save one word-cloud PNG per domain, skipping files that already exist.

    Parameters
    ----------
    k : int
        Number of circuits; used only in the output file name.
    domains : list of str
        Domain names in display order; the i-th domain is coloured with
        ``palettes[framework][i]``.
    lists : pandas.DataFrame
        Term lists with "DOMAIN" and "TOKEN" columns.
    dtm : pandas.DataFrame
        Document-term matrix; each word is sized by the sum of its column.
    framework, suffix, cerebellum : str
        Select the palette and the output directory
        ``figures/lists/{framework}_{suffix}_{cerebellum}_kvals/``.
    """
    import os
    from wordcloud import WordCloud
    from style import style
    import matplotlib.pyplot as plt

    for i, dom in enumerate(domains):

        file_name = "figures/lists/{}_{}_{}_kvals/k{:02d}_wordcloud_{}.png".format(framework, suffix, cerebellum, k, dom)
        if os.path.exists(file_name):
            continue
        # savefig does not create missing directories.
        os.makedirs(os.path.dirname(file_name), exist_ok=True)

        colour = palettes[framework][i]

        def color_func(word, font_size, position, orientation, random_state=None, **kwargs):
            # Every word of this domain gets the domain's single colour.
            return colour

        tkns = lists.loc[lists["DOMAIN"] == dom, "TOKEN"]
        freq = dtm[tkns].sum().values
        dic = {t.replace("_", " "): f for t, f in zip(tkns, freq)}

        cloud = WordCloud(background_color="rgba(255, 255, 255, 0)", mode="RGB",
                          max_font_size=100, prefer_horizontal=1, scale=20, margin=3,
                          width=550, height=850, font_path=style.font,
                          random_state=42).generate_from_frequencies(dic)

        # A fresh figure per cloud, closed by handle, instead of reusing figure
        # number 1 (which could inherit content if an earlier save failed).
        fig = plt.figure(figsize=(2, 10))
        plt.axis("off")
        plt.imshow(cloud.recolor(color_func=color_func, random_state=42))
        plt.savefig(file_name, dpi=800, bbox_inches="tight")
        utilities.transparent_background(file_name)
        plt.close(fig)

In [16]:
# Draw the word clouds for every k, domains stacked in display order.
for k in circuit_counts:
    lists, circuits = ontology.load_ontology(k, suffix=suffix, cerebellum=cerebellum)
    lists["DOMAIN"] = [k2name[k][i] for i in lists["CLUSTER"]]
    # One pd.concat instead of repeated DataFrame.append, which was removed in
    # pandas 2.0 and copies the whole frame on every call.
    lists_ordered = pd.concat([lists.loc[lists["DOMAIN"] == name] for name in k2name_ordered[k]])
    plot_wordclouds(k, k2name_ordered[k], lists_ordered, dtm_bin, cerebellum=cerebellum)

# Visualize the circuits

In [17]:
import os
from style import style

In [18]:
# Label atlas for the configured cerebellar parcellation; passed to
# utilities.map_plane when plotting the circuits below.
atlas = utilities.load_atlas(cerebellum=cerebellum)



In [19]:
# Custom white-to-colour maps. Together with the named matplotlib maps they give one
# colour map per circuit, following the same hue order as the first ten colours of
# palettes["data-driven"].
purples, chartreuses, magentas, yellows, browns, pinks = [
    style.make_cmap([(1, 1, 1), rgb])
    for rgb in [
        (0.365, 0, 0.878),  # purple
        (0.345, 0.769, 0),  # chartreuse
        (0.620, 0, 0.686),  # magenta
        (0.937, 0.749, 0),  # yellow
        (0.82, 0.502, 0),   # brown
        (0.788, 0, 0.604),  # pink
    ]
]
cmaps = [
    "Blues", magentas, yellows, "Greens", "Reds",
    purples, chartreuses, "Oranges", pinks, browns,
]

In [20]:
# Plot each circuit's structures on axial (z) slices. Only k <= len(cmaps) is
# drawn, because every circuit needs its own colour map.
framework = "data-driven"
for k in range(2, len(cmaps) + 1):
    path = "figures/circuits/{}_{}_{}_kvals/k{:02}".format(framework, "lr", cerebellum, k)
    os.makedirs(path, exist_ok=True)
    # Pass `cerebellum` like every other load_ontology call, so the circuit
    # structures match act_bin's columns and the atlas.
    _, circuits = ontology.load_ontology(k, suffix=suffix, cerebellum=cerebellum)
    circuits["DOMAIN"] = [k2name[k][i] for i in circuits["CLUSTER"]]
    # .loc assignment below would silently add rows for unknown structures.
    unknown = set(circuits["STRUCTURE"]) - set(act_bin.columns)
    if unknown:
        raise ValueError("k={}: circuit structures not in act_bin: {}".format(k, sorted(unknown)))
    # Binary structure x circuit membership matrix, columns in display order.
    circuit_mat = pd.DataFrame(0.0, index=act_bin.columns, columns=k2name_ordered[k])
    for name in k2name_ordered[k]:
        structures = circuits.loc[circuits["DOMAIN"] == name, "STRUCTURE"]
        for structure in structures:
            circuit_mat.loc[structure, name] = 1.0
    utilities.map_plane(circuit_mat, atlas, path,
                        suffix="_z", cmaps=cmaps, plane="z", cbar=False, vmin=0.0, vmax=2.0,
                        verbose=False, print_fig=False, annotate=True)

  get_mask_bounds(new_img_like(img, not_mask, affine))


# Export the results

## File structure

In [21]:
# One output folder per k in the viewer's data directory. makedirs also creates
# missing parents and is a no-op when the folder already exists.
for k in circuit_counts:
    os.makedirs("../../nke-cerebellum-viewer/data/k{:02d}".format(k), exist_ok=True)

## Word lists

In [22]:
# Number of documents mentioning each term (column sums of the binarized matrix).
freq = dtm_bin.sum(axis="index")
freq.head()

3d_object                 171.0
abductive_reasoning         2.0
abstract_analogy            3.0
abstract_concrete_task      6.0
abstract_knowledge         40.0
dtype: float64

In [23]:
# Export each k's word lists as CSV. Clusters are renumbered 1..k in display order,
# and terms are sorted by descending R within each cluster.
for k in circuit_counts:
    lists, _ = ontology.load_ontology(k, suffix=suffix, cerebellum=cerebellum)
    lists["DOMAIN"] = [k2name[k][i] for i in lists["CLUSTER"]]
    rank = {dom: pos + 1 for pos, dom in enumerate(k2name_ordered[k])}
    lists["CLUSTER"] = [rank[dom] for dom in lists["DOMAIN"]]
    # Document frequency per term; raises KeyError if a token is not in dtm_bin.
    lists["FREQUENCY"] = freq.loc[lists["TOKEN"]].values
    # (The former `lists_ordered` frame was never used and relied on the removed
    # DataFrame.append; sorting by CLUSTER already yields display order.)
    lists = lists.sort_values(["CLUSTER", "R"], ascending=[True, False])
    out_file = "../../nke-cerebellum-viewer/data/k{:02d}/words_k{:02d}.csv".format(k, k)
    lists.to_csv(out_file, index=False)

## Brain circuits

In [24]:
from nilearn import image, plotting
from statsmodels.stats.multitest import multipletests

In [25]:
def load_atlas_2mm(cerebellum="combo"):
    """Build a 2 mm label atlas from the Harvard-Oxford and cerebellar max-prob atlases.

    The x axis is split at voxel index 46: ``[:46]`` is treated as the right
    hemisphere and ``[46:]`` as the left. Labels are laid out as:

    * left: cortical labels as stored in the cortical atlas, subcortical labels
      renumbered to 49-56, cerebellum 57-59 ("combo") or 57-74 ("seg");
    * right: the same cortical/subcortical structures offset by 59 ("combo") or
      74 ("seg"), with cerebellum 116-118 ("combo") or 131-148 ("seg").

    Voxels whose summed labels fall outside a hemisphere's range (overlaps between
    the source atlases) are set to 0.

    Parameters
    ----------
    cerebellum : {"combo", "seg"}
        Cerebellar parcellation: merged regions or separate lobules.

    Returns
    -------
    Niimg-like image (from nilearn ``new_img_like``) holding the integer labels,
    with 0 as background.

    Raises
    ------
    ValueError
        If ``cerebellum`` is neither "combo" nor "seg".
    """
    import numpy as np
    from nilearn import image

    if cerebellum not in ("combo", "seg"):
        raise ValueError("cerebellum must be 'combo' or 'seg', got {!r}".format(cerebellum))

    cer = "../data/brain/atlases/Cerebellum-MNIfnirt-maxprob-thr25-2mm.nii.gz"
    cor = "../data/brain/atlases/HarvardOxford-cort-maxprob-thr25-2mm.nii.gz"
    sub = "../data/brain/atlases/HarvardOxford-sub-maxprob-thr25-2mm.nii.gz"

    # Load each atlas once. np.asanyarray(img.dataobj) is nibabel's documented
    # replacement for get_data() (removed in nibabel 5) and keeps the on-disk dtype.
    sub_data, cor_data, cer_data = [np.asanyarray(image.load_img(path).dataobj)
                                    for path in (sub, cor, cer)]

    # Right-hemisphere offset; it is also the highest left-hemisphere label.
    offset = {"combo": 59, "seg": 74}[cerebellum]
    # Output dtype as before, widened only if it could not hold the largest label.
    out_dtype = np.promote_types(np.result_type(sub_data, cor_data, cer_data),
                                 np.min_scalar_type(2 * offset))

    def _split(data):
        # Work in int64 so sums such as 130 + 148 cannot wrap around if the
        # atlases are stored as small integer types. Returns (left, right).
        data = data.astype(np.int64)
        return data[46:, :, :], data[:46, :, :]

    def _relabel(mat, mapping):
        # Sequential masked assignments, in the mapping's insertion order.
        for old, new in mapping.items():
            mat[mat == old] = new
        return mat

    sub_del_dic = {1: 0, 2: 0, 3: 0, 12: 0, 13: 0, 14: 0}
    sub_lab_dic_L = {4: 1, 5: 2, 6: 3, 7: 4, 9: 5, 10: 6, 11: 7, 8: 8}
    # Last entry was `7: 8`. Applied after `21: 7`, it merged label 21 (presumably the
    # right counterpart of left 11 -> 7) into label 8. `8: 8` mirrors the left map.
    sub_lab_dic_R = {15: 1, 16: 2, 17: 3, 18: 4, 19: 5, 20: 6, 21: 7, 8: 8}

    def _subcortical(mat, mapping):
        # Zero the labels in sub_del_dic, renumber the rest to 1-8, shift to 49-56.
        mat = _relabel(_relabel(mat, sub_del_dic), mapping) + 48
        mat[mat == 48] = 0
        return mat

    sub_L, sub_R = _split(sub_data)
    cor_L, cor_R = _split(cor_data)
    cer_L, cer_R = _split(cer_data)

    mat_L = _subcortical(sub_L, sub_lab_dic_L) + cor_L
    mat_L[mat_L > 56] = 0
    mat_R = _subcortical(sub_R, sub_lab_dic_R) + cor_R
    mat_R[mat_R > 56] = 0

    mat_R = mat_R + offset
    mat_R[mat_R > 2 * offset] = 0
    mat_R[mat_R <= offset] = 0  # background (0 + offset)

    if cerebellum == "combo":
        cer_L[np.isin(cer_L, [1, 3, 5, 14, 17, 20, 23, 26])] = 57
        cer_L[np.isin(cer_L, [8, 11])] = 58
        cer_L[np.isin(cer_L, [6, 9, 12, 15, 18, 21, 24, 27])] = 59
        cer_R[np.isin(cer_R, [2, 4, 7, 16, 19, 22, 25, 28])] = 116
        cer_R[np.isin(cer_R, [10, 13])] = 117
        cer_R[np.isin(cer_R, [6, 9, 12, 15, 18, 21, 24, 27])] = 118
    else:  # "seg"
        _relabel(cer_L, {1: 57, 3: 58, 5: 59, 6: 69, 8: 65, 9: 67, 11: 66, 12: 68,
                         14: 60, 15: 70, 17: 61, 18: 71, 20: 62, 21: 72,
                         23: 63, 24: 73, 26: 64, 27: 74})
        _relabel(cer_R, {2: 131, 4: 132, 6: 143, 7: 133, 9: 141, 10: 139,
                         12: 142, 13: 140, 15: 144, 16: 134, 18: 145, 19: 135,
                         21: 146, 22: 136, 24: 147, 25: 137, 27: 148, 28: 138})

    mat_L = mat_L + cer_L
    # The "seg" branch used `> 75`, letting overlap voxels become label 75, the
    # first right-hemisphere label; the valid left range ends at `offset`.
    mat_L[mat_L > offset] = 0
    mat_R = mat_R + cer_R
    mat_R[mat_R > 2 * offset] = 0

    mat = np.concatenate((mat_R, mat_L), axis=0).astype(out_dtype)
    return image.new_img_like(sub, mat)

In [26]:
# 2 mm label atlas with the notebook's cerebellar parcellation; the circuit maps
# exported below are written on its voxel grid.
atlas_2mm = load_atlas_2mm(cerebellum=cerebellum)

In [27]:
def threshold_pmi_by_fdr(pmi, act_bin, scores, n_iter=1000, verbose=False, alpha=0.01):
    """Keep only PMI values that survive a permutation test with FDR correction.

    Parameters
    ----------
    pmi : pandas.DataFrame
        Observed PMI; rows ordered like ``act_bin.columns`` (structures) and
        columns like ``scores.columns`` (domains).
    act_bin, scores : pandas.DataFrame
        Passed to ``ontology.compute_cooccurrences_null`` to build the null.
    n_iter : int
        Number of null draws; also the denominator of the p-values.
    verbose : bool
        Forwarded to the null computation.
    alpha : float
        Benjamini-Hochberg FDR threshold (previously hard-coded as 0.01).

    Returns
    -------
    pandas.DataFrame
        ``pmi`` with every entry whose FDR-adjusted p-value is not below
        ``alpha``, and every NaN, replaced by 0.0.
    """
    pmi_null = ontology.compute_cooccurrences_null(act_bin, scores, n_iter=n_iter, verbose=verbose)
    n_struct, n_dom = len(act_bin.columns), len(scores.columns)
    # pmi_null is indexed as (structure, domain, draw), as in the former loop
    # `pmi_null[i, j, :]`; vectorizing also gives a float (not object) p array.
    null = np.asarray(pmi_null)[:n_struct, :n_dom, :]
    obs = pmi.values[:n_struct, :n_dom, np.newaxis]
    # One-sided permutation p-value: share of null draws strictly above the observed PMI.
    p = (null > obs).sum(axis=2) / float(n_iter)
    fdr = multipletests(p.ravel(), method="fdr_bh")[1]
    fdr = pd.DataFrame(fdr.reshape(p.shape), index=act_bin.columns, columns=scores.columns)
    pmi_thres = pmi[fdr < alpha]
    return pmi_thres.fillna(0.0)

In [28]:
# Export one NIfTI map per circuit and k for the viewer. Each structure's voxels
# carry its FDR-thresholded PMI with its own circuit and 0 for all other circuits.
labels_2mm = np.asanyarray(atlas_2mm.dataobj)  # same label array for every k and circuit
for k in circuit_counts:
    print("Processing k={:02d}".format(k))
    lists, circuits = ontology.load_ontology(k, suffix=suffix, cerebellum=cerebellum)
    scores = utilities.score_lists(lists, dtm_bin, label_var="CLUSTER").loc[act_bin.index]
    pmi = ontology.compute_cooccurrences(act_bin, scores, positive=True)
    pmi = threshold_pmi_by_fdr(pmi, act_bin, scores, n_iter=1000, verbose=False)

    # Zero each structure's PMI for every circuit except the one it belongs to
    # (a structure's first row in `circuits` decides its circuit).
    struct2circuit = circuits.drop_duplicates("STRUCTURE").set_index("STRUCTURE")["CLUSTER"]
    missing = pmi.index.difference(struct2circuit.index)
    if len(missing) > 0:
        raise KeyError("k={}: structures without a circuit: {}".format(k, list(missing)))
    own_circuit = struct2circuit.loc[pmi.index].values
    for k_i in range(1, k + 1):
        pmi.loc[own_circuit != k_i, k_i] = 0

    for f, feature in enumerate(k2order[k]):
        # Fill a fresh float map from the untouched label array. Writing into a copy
        # of the labels truncated PMI values for integer atlases, could let an
        # assigned value be matched as a later label, and left unlisted labels as
        # raw label values; those voxels are now 0.
        # Assumes row i of pmi corresponds to atlas label i + 1 -- TODO confirm.
        stat_map = np.zeros(labels_2mm.shape, dtype=float)
        for label, value in enumerate(pmi[feature], start=1):
            stat_map[labels_2mm == label] = value
        stat_img = image.new_img_like(atlas_2mm, stat_map)

        img_file = "../../nke-cerebellum-viewer/data/k{:02d}/circuit_k{:02d}_dom{:02d}.nii.gz".format(k, k, f+1)
        stat_img.to_filename(img_file)

Processing k=02
Processing k=03
Processing k=04
Processing k=05
Processing k=06
Processing k=07
Processing k=08
Processing k=09
Processing k=10
Processing k=11
Processing k=12
Processing k=13
Processing k=14
Processing k=15
Processing k=16
Processing k=17
Processing k=18
Processing k=19
Processing k=20
Processing k=21
Processing k=22
Processing k=23
Processing k=24
Processing k=25
