# Introduction

This notebook will take a data-driven approach to generating word lists for mental functions that are related to brain circuitry. The overall process is as follows:

1. Cluster brain structures into circuits by PMI-weighted co-occurrences with mental function terms.
2. Identify the mental function terms most highly associated with each circuit over a range of list lengths.
3. Select the list length for each circuit that maximizes word-structure classification performance. 
4. Select the number of circuits that maximizes circuit-function classification performance.

# Load the data

In [1]:
import numpy as np
import pandas as pd
import sys
sys.path.append("..")
import utilities, ontology

## Cluster range to visualize

In [2]:
# Numbers of circuits (k) to process: every k from 2 through 25 inclusive.
circuit_counts = range(2, 25 + 1)

## Brain activation coordinates

In [3]:
# Activation matrix: rows are reported as documents, columns as brain structures.
act_bin = utilities.load_coordinates()
print("Document N={}, Structure N={}".format(*act_bin.shape))

Document N=18155, Structure N=114


## Terms for mental functions

In [4]:
# Version tag of the document-term matrix to load; binarize=True requests the
# binarized variant.
version = 190325
dtm_bin = utilities.load_doc_term_matrix(binarize=True, version=version)

In [5]:
# Keep only "cogneuro" lexicon terms that are also columns of the DTM, sorted.
lexicon = sorted(set(utilities.load_lexicon(["cogneuro"])) & set(dtm_bin.columns))
len(lexicon)

1683

In [6]:
# Restrict the binarized DTM to the lexicon terms (columns in lexicon order).
dtm_bin = dtm_bin[lexicon]
print("Document N={}, Term N={}".format(*dtm_bin.shape))

Document N=18155, Term N=1683


In [7]:
# Total occurrences of lexicon terms, from the raw (binarize=False) counts
dtm = utilities.load_doc_term_matrix(version=version, binarize=False)[lexicon]
dtm.values.sum()

4831488

## Document splits

In [8]:
def _read_split(split):
    """Return the integer PMIDs listed one per line in ../data/splits/<split>.txt.

    The file is closed before returning. Blank lines are skipped rather than
    raising ValueError from int("").
    """
    with open("../data/splits/{}.txt".format(split)) as split_file:
        return [int(line.strip()) for line in split_file if line.strip()]


train, val = [_read_split(split) for split in ["train", "validation"]]
print("Training N={}, Validation N={}".format(len(train), len(val)))

Training N=12708, Validation N=3631


# Name the domains

In [9]:
k2terms, k2name = {}, {}
for k in circuit_counts:
    lists, circuits = ontology.load_ontology(k)
    # Candidate name terms per circuit (0-based i -> cluster i+1). Sorted rather
    # than raw set order: set iteration order of strings varies between runs
    # (hash randomization) and would feed the tie-breaking of the ranking below.
    k2terms[k] = {i: sorted(set(lists.loc[lists["CLUSTER"] == i+1, "TOKEN"])) for i in range(k)}
    k2name[k] = {i+1: "" for i in range(k)}
    names, degs = [""]*k, [0]*k
    # Name each circuit after its highest-degree remaining candidate term. If that
    # term already names another circuit, drop it from this circuit's candidates
    # and retry on the next pass.
    while "" in names:
        progressed = False
        for i in range(k):
            if names[i]:
                # A named circuit's candidate list no longer changes, so
                # recomputing its centrality (assuming term_degree_centrality is
                # deterministic) would only reproduce the same name.
                continue
            if not k2terms[k][i]:
                raise RuntimeError(
                    "k={}: circuit {} ran out of candidate names".format(k, i+1))
            degrees = ontology.term_degree_centrality(i+1, lists, circuits, dtm_bin, train)
            degrees = degrees.loc[k2terms[k][i]].sort_values(ascending=False)
            top_term = degrees.index[0]
            name = top_term.upper()
            if name not in names:
                names[i] = name
                degs[i] = max(degrees)
                k2name[k][i+1] = name
                progressed = True
            elif degs[names.index(name)] > degs[i]:
                # degs[i] is still 0 here because circuit i is unnamed.
                k2terms[k][i] = [term for term in k2terms[k][i] if term != top_term]
                progressed = True
        if not progressed:
            # A full pass that names nothing and removes nothing would repeat
            # forever; fail loudly instead.
            raise RuntimeError("k={}: circuit naming made no progress".format(k))

In [10]:
# Display order of circuits for each k: k2order[k] lists the 1-based cluster IDs
# (the keys of k2name[k]) in the order circuits appear in figures and in the
# exported viewer files. Each value is expected to be a permutation of 1..k.
# NOTE(review): hand-curated; the criterion used to choose this order is not
# recorded here.
k2order = {2: [1,2],
           3: [1,2,3],
           4: [4,3,1,2],
           5: [4,1,3,2,5],
           6: [3,6,5,4,2,1],
           7: [6,3,7,2,5,4,1],
           8: [4,8,5,1,3,2,7,6],
           9: [5,9,1,2,8,3,6,7,4],
           10: [5,1,10,3,9,6,8,2,7,4],
           11: [1,11,8,9,10,5,2,3,4,6,7],
           12: [4,6,7,9,3,1,10,11,12,2,8,5],
           13: [5,10,7,1,9,13,2,6,11,3,4,12,8],
           14: [11,8,6,7,10,13,12,3,1,9,4,14,5,2],
           15: [2,13,14,12,6,7,15,5,4,8,10,11,3,1,9],
           16: [1,2,6,10,7,15,16,9,13,4,3,12,11,5,14,8],
           17: [7,8,10,13,4,17,5,15,6,11,3,2,14,1,9,12,16],
           18: [4,10,14,15,5,1,9,11,6,13,17,2,8,12,3,16,7,18],
           19: [1,10,16,11,14,5,6,12,13,9,18,3,7,4,2,17,19,15,8],
           20: [3,5,13,12,18,20,7,15,10,17,6,1,4,8,11,19,16,9,2,14],
           21: [3,6,12,17,9,18,20,5,11,15,7,16,2,21,13,19,1,4,10,8,14],
           22: [3,13,20,22,6,14,8,11,9,12,17,1,16,5,7,18,2,19,10,4,21,15],
           23: [1,8,10,9,14,15,20,19,21,4,23,12,17,18,11,3,2,5,16,6,13,7,22],
           24: [9,15,17,18,21,4,22,23,7,19,10,12,16,14,5,11,2,8,13,6,20,3,1,24],
           25: [3,13,20,22,23,6,24,8,11,9,12,17,1,16,18,5,2,7,25,19,10,4,21,14,15]}

In [11]:
# Circuit names for each k, arranged in the display order given by k2order.
k2name_ordered = dict(
    (n_circuits, [k2name[n_circuits][cluster_id] for cluster_id in k2order[n_circuits]])
    for n_circuits in circuit_counts
)

In [12]:
def _to_title(domain):
    """Turn a snake_case domain name into a space-separated Title Case label."""
    return domain.replace("_", " ").title()


k2name_titles = {}
for k in circuit_counts:
    k2name_titles[k] = list(map(_to_title, k2name_ordered[k]))
k2name_titles

{2: ['Arousal', 'Manipulation'],
 3: ['Arousal', 'Manipulation', 'Language'],
 4: ['Memory', 'Reaction Time', 'Vision', 'Hearing'],
 5: ['Memory', 'Reaction Time', 'Manipulation', 'Vision', 'Language'],
 6: ['Memory',
  'Reward',
  'Reaction Time',
  'Manipulation',
  'Vision',
  'Language'],
 7: ['Emotion',
  'Anticipation',
  'Cognitive Process',
  'Manipulation',
  'Vision',
  'Hearing',
  'Word'],
 8: ['Memory',
  'Episodic Memory',
  'Emotion',
  'Reaction Time',
  'Manipulation',
  'Vision',
  'Hearing',
  'Language'],
 9: ['Episodic Memory',
  'Reward',
  'Anticipation',
  'Arousal',
  'Manipulation',
  'Memory',
  'Vision',
  'Hearing',
  'Language'],
 10: ['Memory',
  'Reward',
  'Arousal',
  'Cognitive',
  'Manipulation',
  'Episodic Memory',
  'Recall',
  'Vision',
  'Hearing',
  'Language'],
 11: ['Memory',
  'Episodic Memory',
  'Emotion',
  'Decision Making',
  'Anticipation',
  'Arousal',
  'Manipulation',
  'Vision',
  'Hearing',
  'Language',
  'Meaning'],
 12: ['Memor

# Visualize the term lists

In [13]:
# Named hex colors shared by all figure palettes.
c = {
    "red": "#CE7D69",
    "orange": "#BA7E39",
    "yellow": "#CEBE6D",
    "chartreuse": "#AEC87C",
    "green": "#77B58A",
    "blue": "#7597D0",
    "magenta": "#B07EB6",
    "purple": "#7D74A3",
    "brown": "#846B43",
    "pink": "#CF7593",
    "slate": "#6F8099",
    "crimson": "#8C4058",
    "gold": "#D8AE54",
    "teal": "#5AA8A7",
    "indigo": "#3A3C7C",
    "lobster": "#FF7B5B",
    "olive": "#72662A",
    "lime": "#91E580",
    "sky": "#BCD5FF",
    "fuschia": "#E291DD",
    "violet": "#8B75EA",
    "tan": "#E0BD84",
    "berry": "#D64F7C",
    "mint": "#A4EACA",
    "sun": "#F4FF6B",
}

# Per-framework color sequences: the i-th domain of a framework is drawn in the
# i-th color. "data-driven" uses all 25 named colors, one per circuit up to k=25.
palettes = {
    framework: [c[color_name] for color_name in color_names]
    for framework, color_names in (
        ("data-driven", (
            "blue", "magenta", "yellow", "green", "red",
            "purple", "chartreuse", "orange", "pink", "brown",
            "slate", "crimson", "gold", "teal", "indigo",
            "lobster", "olive", "lime", "sky", "fuschia",
            "violet", "tan", "berry", "mint", "sun",
        )),
        ("rdoc", ("blue", "red", "green", "purple", "yellow", "orange")),
        ("dsm", (
            "purple", "chartreuse", "orange", "blue", "red",
            "magenta", "yellow", "green", "brown",
        )),
    )
}

In [14]:
def plot_wordclouds(k, domains, lists, dtm, framework="data-driven"):
    """Render and save one word cloud per domain of a k-circuit solution.

    Parameters
    ----------
    k : int
        Number of circuits; used only in the output file name.
    domains : sequence of str
        Domain names to plot, in order. The i-th domain is drawn in the i-th
        color of ``palettes[framework]``, cycling when there are more domains
        than colors.
    lists : pandas.DataFrame
        Word lists with at least "DOMAIN" and "TOKEN" columns.
    dtm : pandas.DataFrame
        Document-term matrix; each token's column sum sets its word size.
    framework : str
        Key into the module-level ``palettes`` dict.

    Writes ``figures/lists/kvals/k{k:02d}_wordcloud_{domain}.png`` for every
    domain that has at least one token; domains without tokens are skipped.
    """
    from wordcloud import WordCloud
    import matplotlib.pyplot as plt

    palette = palettes[framework]
    for i, dom in enumerate(domains):
        color = palette[i % len(palette)]

        # The color is bound as a default so the callback does not depend on
        # the loop variable at call time.
        def color_func(word, font_size, position, orientation,
                       random_state=None, _color=color, **kwargs):
            return _color

        tkns = lists.loc[lists["DOMAIN"] == dom, "TOKEN"]
        freq = dtm[tkns].sum().values
        labels = [t.replace("_", " ") for t in tkns]
        # generate_from_frequencies takes a word -> frequency dict; a bare zip
        # iterator is rejected by current wordcloud releases.
        frequencies = dict(zip(labels, freq))
        if not frequencies:
            # WordCloud raises on an empty vocabulary.
            continue

        cloud = WordCloud(background_color="rgba(255, 255, 255, 0)", mode="RGB",
                          max_font_size=100, prefer_horizontal=1, scale=20, margin=3,
                          width=550, height=850, font_path=utilities.arial,
                          random_state=42).generate_from_frequencies(frequencies)

        plt.figure(1, figsize=(2, 10))
        plt.axis("off")
        plt.imshow(cloud.recolor(color_func=color_func, random_state=42))
        file_name = "figures/lists/kvals/k{:02d}_wordcloud_{}.png".format(k, dom)
        plt.savefig(file_name,
                    dpi=800, bbox_inches="tight")
        utilities.transparent_background(file_name)
        plt.close()

In [None]:
# Word clouds for every k, with domains plotted in display order (k2order).
for k in circuit_counts:
    lists = pd.read_csv("lists/lists_k{:02d}_oplen.csv".format(k), index_col=None, header=0)
    lists["DOMAIN"] = [k2name[k][i] for i in lists["CLUSTER"]]
    # Stack the rows domain by domain in display order. One pd.concat replaces
    # the per-domain DataFrame.append, which was removed in pandas 2.0 and
    # copied the growing frame on every iteration.
    lists_ordered = pd.concat(
        [lists.loc[lists["DOMAIN"] == name] for name in k2name_ordered[k]])
    plot_wordclouds(k, k2name_ordered[k], lists_ordered, dtm_bin)

# Visualize the circuits

In [16]:
# Atlas passed as the template to utilities.map_plane when drawing circuit maps.
atlas = utilities.load_atlas()

In [17]:
# Custom two-stop colormaps running from white (1, 1, 1) to a target RGB color.
_WHITE = (1, 1, 1)
purples = utilities.make_cmap([_WHITE, (0.365, 0, 0.878)])
chartreuses = utilities.make_cmap([_WHITE, (0.345, 0.769, 0)])
magentas = utilities.make_cmap([_WHITE, (0.620, 0, 0.686)])
yellows = utilities.make_cmap([_WHITE, (0.937, 0.749, 0)])
browns = utilities.make_cmap([_WHITE, (0.82, 0.502, 0)])
pinks = utilities.make_cmap([_WHITE, (0.788, 0, 0.604)])

# One colormap per circuit, in the same hue order as the first ten colors of
# palettes["data-driven"].
cmaps = [
    "Blues", magentas, yellows, "Greens", "Reds",
    purples, chartreuses, "Oranges", pinks, browns,
]

In [18]:
# z-plane maps of each circuit solution. Each circuit needs its own colormap
# (presumably map_plane pairs cmaps[j] with column j — confirm), so only
# k <= len(cmaps) can be drawn. The bound is inclusive: k == len(cmaps) still
# has one colormap per circuit.
for k in [n for n in circuit_counts if n <= len(cmaps)]:
    circuits = pd.read_csv("circuits/circuits_k{:02d}.csv".format(k), index_col=None, header=0)
    circuits["DOMAIN"] = [k2name[k][i] for i in circuits["CLUSTER"]]
    # Binary structure-by-circuit membership matrix, columns in display order.
    circuit_mat = pd.DataFrame(0.0, index=act_bin.columns, columns=k2name_ordered[k])
    for name in k2name_ordered[k]:
        structures = circuits.loc[circuits["DOMAIN"] == name, "STRUCTURE"]
        for structure in structures:
            circuit_mat.loc[structure, name] = 1.0
    utilities.map_plane(circuit_mat, atlas, "figures/circuits/kvals/k{:02d}".format(k),
                        suffix="_z", cmaps=cmaps, plane="z", cbar=False, vmin=0.0, vmax=2.0,
                        verbose=False, print_fig=False, annotate=True)

  data[slices] *= 1.e-3


# Export the results

## File structure

In [19]:
import os

In [20]:
# One export directory per k for the viewer. makedirs also creates missing
# parent directories, and exist_ok avoids the check-then-create race.
for k in circuit_counts:
    os.makedirs("../../nke-viewer/data/k{:02d}".format(k), exist_ok=True)

## Word lists

In [21]:
# Total occurrences of each lexicon term: column sums of the raw-count DTM.
freq = dtm.sum(axis=0)
freq.head()

3d_object                 318
abductive_reasoning        11
abstract_analogy            6
abstract_concrete_task      8
abstract_knowledge         57
dtype: int64

In [22]:
# Export each k's word lists for the viewer: clusters are renumbered by display
# position (k2order) and terms are ranked by R within each cluster.
for k in circuit_counts:
    lists = pd.read_csv("lists/lists_k{:02d}_oplen.csv".format(k), index_col=None, header=0)
    lists["DOMAIN"] = [k2name[k][i] for i in lists["CLUSTER"]]
    # After this renumbering, sorting by CLUSTER already yields display order,
    # so no separate per-domain reordering pass is needed.
    lists["CLUSTER"] = [k2name_ordered[k].index(dom) + 1 for dom in lists["DOMAIN"]]
    # Raises KeyError if a token is missing from freq.
    lists["FREQUENCY"] = freq.loc[lists["TOKEN"]].values
    lists = lists.sort_values(["CLUSTER", "R"], ascending=[True, False])
    out_file = "../../nke-viewer/data/k{:02d}/words_k{:02d}.csv".format(k, k)
    lists.to_csv(out_file, index=False)

## Brain circuits

In [23]:
from nilearn import image, plotting

In [24]:
def _relabel(mat, mapping):
    """Replace label values in *mat* in place, one ``old: new`` pair at a time.

    Pairs are applied in dict order, so a value written by an earlier pair can
    be rewritten by a later one (e.g. ``{21: 7, 7: 8}`` sends 21 to 8).
    """
    for old, new in mapping.items():
        mat[mat == old] = new


def _shift_subcortical(mat, del_dic, lab_dic):
    """Relabel one hemisphere slab of the subcortical atlas.

    Applies *del_dic* and then *lab_dic* in place. It then adds 48 so that the
    kept labels follow the cortical ones, resets the background (0 + 48) to 0,
    and returns the new shifted array.
    """
    _relabel(mat, del_dic)
    _relabel(mat, lab_dic)
    mat = mat + 48
    mat[mat == 48] = 0
    return mat


def load_atlas_2mm():
    """Build the combined 2 mm label atlas used to export circuit maps.

    Combines the Harvard-Oxford cortical and subcortical atlases and the
    cerebellar atlas (maxprob, thr25, 2 mm) separately for the two slabs along
    the first axis:

    * ``[46:]`` (treated as left): cortical labels (expected 1-48), subcortical
      labels shifted to 49-56, and cerebellum 57.
    * ``[:46]`` (treated as right): the same layout offset by 57 (58-114).

    Voxels whose summed label falls outside a slab's range, which happens where
    source atlases overlap, are set to 0.

    Returns
    -------
    Image built with ``nilearn.image.new_img_like`` on the subcortical atlas,
    holding the combined integer label volume.
    """
    import numpy as np
    from nilearn import image

    cer = "../data/brain/atlases/Cerebellum-MNIfnirt-maxprob-thr25-2mm.nii.gz"
    cor = "../data/brain/atlases/HarvardOxford-cort-maxprob-thr25-2mm.nii.gz"
    sub = "../data/brain/atlases/HarvardOxford-sub-maxprob-thr25-2mm.nii.gz"

    # Subcortical labels to discard, then per-slab renumbering of the kept ones
    # to 1-8 (before the +48 shift).
    # NOTE(review): presumably the discarded labels are the white matter /
    # cortex / ventricle classes of the Harvard-Oxford subcortical LUT — confirm.
    sub_del_dic = {1:0, 2:0, 3:0, 12:0, 13:0, 14:0}
    sub_lab_dic_L = {4:1, 5:2, 6:3, 7:4, 9:5, 10:6, 11:7, 8:8}
    # Ends with 8:8, mirroring the left map. A trailing 7:8 would re-map the
    # voxels just set by 21:7 (pairs apply in order) into label 8, leaving
    # target 7 empty in this slab.
    sub_lab_dic_R = {15:1, 16:2, 17:3, 18:4, 19:5, 20:6, 21:7, 8:8}

    # Load each atlas once. np.array(img.dataobj) gives a writable copy with the
    # same values and dtype as get_data(), which was removed in nibabel 5.0.
    sub_data = np.array(image.load_img(sub).dataobj)
    cor_data = np.array(image.load_img(cor).dataobj)
    cer_data = np.array(image.load_img(cer).dataobj)

    sub_mat_L = _shift_subcortical(sub_data[46:,:,:], sub_del_dic, sub_lab_dic_L)
    sub_mat_R = _shift_subcortical(sub_data[:46,:,:], sub_del_dic, sub_lab_dic_R)

    cor_mat_L = cor_data[46:,:,:]
    cor_mat_R = cor_data[:46,:,:]

    # Cortical + subcortical; sums above 56 come from overlapping voxels.
    mat_L = np.add(sub_mat_L, cor_mat_L)
    mat_L[mat_L > 56] = 0
    mat_R = np.add(sub_mat_R, cor_mat_R)
    mat_R[mat_R > 56] = 0

    # Offset the right slab by 57; background (0 + 57) falls below 58 and is reset.
    mat_R = mat_R + 57
    mat_R[mat_R > 113] = 0
    mat_R[mat_R < 58] = 0

    # Any cerebellar voxel becomes a single label per slab (57 / 114).
    cer_mat_L = cer_data[46:,:,:]
    cer_mat_R = cer_data[:46,:,:]
    cer_mat_L[cer_mat_L > 0] = 57
    cer_mat_R[cer_mat_R > 0] = 114

    # Overlaps between cerebellum and other structures exceed the maximum and are dropped.
    mat_L = np.add(mat_L, cer_mat_L)
    mat_L[mat_L > 57] = 0
    mat_R = np.add(mat_R, cer_mat_R)
    mat_R[mat_R > 114] = 0

    mat = np.concatenate((mat_R, mat_L), axis=0)
    return image.new_img_like(sub, mat)

In [25]:
# 2 mm label atlas whose labels are mapped to circuit values when writing the
# per-circuit NIfTI files for the viewer.
atlas_2mm = load_atlas_2mm()

In [26]:
# Export one NIfTI image per circuit and k for the viewer. Each voxel takes the
# value of its atlas structure: 0 outside the circuit, 1 + U[0, 1) jitter inside.
# Label 0 is background; label L (1-based) takes the value of act_bin.columns[L - 1].
# NOTE(review): assumes load_atlas_2mm numbers structures in act_bin column
# order; confirm.
atlas_labels = np.asarray(atlas_2mm.dataobj).astype(int)
if atlas_labels.min() < 0 or atlas_labels.max() > len(act_bin.columns):
    raise ValueError("atlas labels do not fit the {} structures".format(len(act_bin.columns)))
# Fixed seed so the exported jitter, and therefore the files, are reproducible.
rng = np.random.RandomState(42)

for k in circuit_counts:
    circuits = pd.read_csv("circuits/circuits_k{:02d}.csv".format(k), index_col=None, header=0)
    circuits["DOMAIN"] = [k2name[k][i] for i in circuits["CLUSTER"]]
    columns = ["dom{:02d}".format(i) for i in range(1, k+1)]
    circuit_mat = pd.DataFrame(0.0, index=act_bin.columns, columns=columns)
    for n, name in enumerate(k2name_ordered[k]):
        structures = circuits.loc[circuits["DOMAIN"] == name, "STRUCTURE"]
        for structure in structures:
            circuit_mat.loc[structure, "dom{:02d}".format(n+1)] = 1.0 + rng.uniform()

    for feature in circuit_mat.columns:
        # Float lookup table indexed by atlas label. Writing into a copy of the
        # atlas array would truncate the jitter when the atlas has an integer
        # dtype; the table also maps every label in one vectorized pass.
        lut = np.concatenate(([0.0], circuit_mat[feature].values.astype(float)))
        stat_img = image.new_img_like(atlas_2mm, lut[atlas_labels])

        img_file = "../../nke-viewer/data/k{:02d}/circuit_k{:02d}_{}.nii.gz".format(k, k, feature)
        stat_img.to_filename(img_file)