In [None]:
import pandas as pd
import math
import networkx as nx
import numpy as np
import json
from collections import Counter
from sklearn.metrics.cluster import adjusted_rand_score, adjusted_mutual_info_score
from sklearn.metrics.pairwise import cosine_similarity
import seaborn as sns
import plotly
import string
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from tqdm import tqdm
from scipy.stats import pearsonr, spearmanr
from itertools import combinations

In [None]:
# Load paper metadata (re-indexed by paperId in the next cell).
# NOTE(review): read_json is applied to a '.csv' path — presumably the file actually
# contains JSON despite its extension; TODO confirm or rename the file.
df = pd.read_json('xai_scholar_metadata_expanded.csv')

In [None]:
df = df.set_index('paperId')

In [None]:
level = 0

In [None]:
# Load communities
comm_df = pd.read_json(f'communities/final_oslom_{level}.json')

In [None]:
# Convert calendar-year column labels (2000, 2001, ...) to year offsets (0, 1, ...).
# Guarded so that re-running this cell does not subtract 2000 a second time.
comm_df.columns = [col - 2000 if col >= 2000 else col for col in comm_df.columns]

In [None]:
# Load community labels (TF-IDF).
# Nested dict: str(year offset) -> str(community id) -> list of label terms.
with open(f'topic_labels_oslom_{level}.json', 'r') as f_in:
    label_dict = json.load(f_in)

In [None]:
# Load community labels (CTD)
with open(f'topic_labels_oslom_{level}_reweighted.json', 'r') as f_in:
    label_dict_reweighted = json.load(f_in)

In [None]:
community_sizes = {}
for year in range(24):
    community_sizes[year] = comm_df[year].explode().value_counts()

In [None]:
graphs = {}
for year in tqdm(range(24)):
    graphs[year] = nx.read_gexf(f'citation_graphs/{2000+year}.0.gexf')

In [None]:
# Generate community interaction networks, normalised by block size.
# For each year, node k of the block-model graph is community k (community 0 is treated
# as "unassigned" and dropped). The weight between communities i and j is the number of
# member pairs (u in i, v in j) with G.has_edge(u, v), divided by |i| * |j|.
community_graphs_normed = {}
for year in tqdm(range(24)):

    G = graphs[year]  # only read from below, so no defensive copy is needed
    communities = list(comm_df[year].explode().replace(0,np.nan).reset_index().groupby(year).agg(list)['index'])

    edge_weights = {}
    for id_u, u_community in enumerate(communities):
        # hits[v] = number of members u of this community (counted with multiplicity)
        # for which G.has_edge(u, v). `v in G[u]` is exactly G.has_edge(u, v), so walking
        # each member's adjacency once replaces the O(|u| * |v|) has_edge probes per
        # community pair of the naive double loop.
        hits = Counter(v for u in u_community if u in G for v in G[u])
        for id_v, v_community in enumerate(communities):
            weight = sum(hits[v] for v in v_community)
            denom = len(u_community) * len(v_community)
            edge_weights[id_u+1,id_v+1] = weight / denom

    # Create the block model graph
    # NOTE(review): nx.Graph is undirected, so for i != j the weights (i, j) and (j, i)
    # collapse onto a single edge and whichever is added last wins. If the citation
    # graphs are directed these two values can differ. TODO confirm this is intended.
    block_model_graph = nx.Graph()
    block_model_graph.add_nodes_from([community+1 for community in range(len(communities))])
    block_model_graph.add_weighted_edges_from([(u, v, weight) for (u, v), weight in edge_weights.items()])

    nx.set_node_attributes(block_model_graph, community_sizes[year], 'size')

    # Node label: first three TF-IDF terms of the community, joined with '_'.
    community_labels = {int(k) : "_".join(label_dict[str(year)][k][:3]) for k in label_dict[str(year)]}

    nx.set_node_attributes(block_model_graph, community_labels, 'label')

    # Node field: most frequent field of study among the community's papers.
    community_fields = {}
    for i, comm in enumerate(communities):
        fields = sorted(df.loc[comm]['fields'].explode().value_counts().head(1).index)
        community_fields[i+1] = "_".join(fields)

    nx.set_node_attributes(block_model_graph, community_fields, 'fields')

    community_graphs_normed[year] = block_model_graph

In [None]:
# Generate community interaction networks with raw citation counts.
# Self-interactions are excluded and only community pairs with more than one citation
# link get an edge. Node k of the block-model graph is community k (community 0 is
# treated as "unassigned" and dropped).
community_graphs = {}
for year in tqdm(range(24)):

    G = graphs[year]  # only read from below, so no defensive copy is needed
    communities = list(comm_df[year].explode().replace(0,np.nan).reset_index().groupby(year).agg(list)['index'])

    edge_weights = {}
    for id_u, u_community in enumerate(communities):
        # hits[v] = number of members u of this community (counted with multiplicity)
        # for which G.has_edge(u, v). `v in G[u]` is exactly G.has_edge(u, v), so this
        # replaces the O(|u| * |v|) has_edge probes per community pair.
        hits = Counter(v for u in u_community if u in G for v in G[u])
        for id_v, v_community in enumerate(communities):
            if id_u == id_v:
                continue  # no self-loops in this graph
            weight = sum(hits[v] for v in v_community)
            if weight > 1:
                edge_weights[id_u+1,id_v+1] = weight

    # Create the block model graph
    # NOTE(review): nx.Graph is undirected, so for i != j the counts (i, j) and (j, i)
    # collapse onto one edge and whichever is added last wins. TODO confirm intended.
    block_model_graph = nx.Graph()
    block_model_graph.add_nodes_from([community+1 for community in range(len(communities))])
    block_model_graph.add_weighted_edges_from([(u, v, weight) for (u, v), weight in edge_weights.items()])

    nx.set_node_attributes(block_model_graph, community_sizes[year], 'size')

    # Node label: first three TF-IDF terms of the community, joined with '_'.
    community_labels = {int(k) : "_".join(label_dict[str(year)][k][:3]) for k in label_dict[str(year)]}

    nx.set_node_attributes(block_model_graph, community_labels, 'label')

    # Node field: most frequent field of study among the community's papers.
    community_fields = {}
    for i, comm in enumerate(communities):
        fields = sorted(df.loc[comm]['fields'].explode().value_counts().head(1).index)
        community_fields[i+1] = "_".join(fields)

    nx.set_node_attributes(block_model_graph, community_fields, 'fields')

    community_graphs[year] = block_model_graph

## Tracking communities:

In [None]:
# Seed one dynamic-community track per year-0 community (community 0 = unassigned).
year0_comms = comm_df[0].explode().replace(0,np.nan).dropna().unique()
# track id -> {year offset: community id}
d_dict = {d: {0: d} for d in year0_comms}
# track id -> {year offset: Jaccard similarity with the previous appearance}.
# Seeds have no predecessor and get 0, like tracks born in later years. Previously the
# community id was stored here by mistake.
scores = {d: {0: 0} for d in year0_comms}

In [None]:
t = 0.3

In [None]:
def comm_lookup(year, comm_id, frame=None):
    """Return the set of index labels (papers) assigned to `comm_id` in column `year`.

    `frame` defaults to the global `comm_df` (rows = papers, columns = year offsets,
    cells = community id(s)). Cells holding lists are exploded, so a paper belonging
    to several communities is found under each of them.
    """
    if frame is None:
        frame = comm_df
    # Explode once (the original exploded the column twice per call).
    memberships = frame[year].explode()
    return set(memberships.index[(memberships == comm_id).to_numpy()])

In [None]:
# Greedy community tracking across years (Jaccard threshold t).
# Each community c of year y is compared with the latest appearance *before* y of every
# track that existed at the start of year y. A match (similarity >= t) extends the track.
# If the track has already been extended in year y, it branches (a split): the new track
# copies the history before y and takes c in year y. A community that matches no track
# starts a new one.
d_counter = max(list(d_dict))
member_sets = {}  # (year, community id) -> set of papers; avoids repeated comm_lookup calls
for y in range(1,24):
    comms = comm_df[y]
    # Tracks created while processing year y (new or branched) are not matched in year y.
    tracks_before_y = list(d_dict)
    for c in comms.explode().replace(0,np.nan).dropna().unique():
        match_found = False
        c_set = comm_lookup(y,c)
        for d in tracks_before_y:
            # Latest appearance strictly before y. max(d_dict[d]) would pick year y once the
            # track had been extended in this year, so c was compared with a sibling year-y
            # community instead of the track's true front, which mostly prevented splits
            # from being detected.
            front_year = max(yr for yr in d_dict[d] if yr < y)
            front_name = d_dict[d][front_year]
            key = (front_year, front_name)
            if key not in member_sets:
                member_sets[key] = comm_lookup(front_year, front_name)
            d_set = member_sets[key]
            j_sim = len(c_set & d_set) / len(c_set | d_set)
            if j_sim >= t:
                match_found = True
                if y not in d_dict[d]:
                    d_dict[d][y] = c
                    scores[d][y] = j_sim
                else:
                    # Branch: copy (do not alias) the shared history, then record c.
                    # Previously the new id aliased d's dict, overwriting d's year-y
                    # community, and scores[d_counter] did not exist (KeyError).
                    d_counter += 1
                    d_dict[d_counter] = {yr: cid for yr, cid in d_dict[d].items() if yr < y}
                    d_dict[d_counter][y] = c
                    scores[d_counter] = {yr: s for yr, s in scores[d].items() if yr < y}
                    scores[d_counter][y] = j_sim
        if not match_found:
            d_counter+=1
            d_dict[d_counter] = {y:c}
            scores[d_counter] = {y:0}

In [None]:
dyn_comm_df = pd.DataFrame(d_dict).T
scores_df = pd.DataFrame(scores).T

In [None]:
dyn_comm_df = dyn_comm_df[range(24)]
dyn_comm_df[dyn_comm_df.notna().sum(axis = 1) > 1]

In [None]:
dyn_comm_df = dyn_comm_df[dyn_comm_df.notna().sum(axis = 1) > 1][range(24)]

In [None]:
dyn_comm_df

## Community Metrics:

In [None]:
with open('embeddings/scibert.json','r') as infile:
    scibert = json.load(infile)

scibert = pd.Series(scibert)

In [None]:
size_df = pd.DataFrame()
for year in range(24):
    size_df[year] = dyn_comm_df[year].map(lambda x: community_sizes[year][x], na_action = 'ignore')

In [None]:
community_densities = {}
for year in range(24):
    comms = comm_df[year].reset_index().groupby(year).index.agg(list)
    dens = {}
    for comm_id, comm in comms.items(): 
        dens[comm_id] = nx.density(graphs[year].subgraph(comm))
    community_densities[year] = dens

In [None]:
dens_df = pd.DataFrame()
for year in range(24):
    dens_df[year] = dyn_comm_df[year].map(lambda x: community_densities[year][x], na_action = 'ignore')

In [None]:
community_disparity = {}
for year in range(24):
    comms = comm_df[year].reset_index().groupby(year).index.agg(list)
    dens = {}
    for comm_id, papers in comms.items(): 
        embs = scibert[papers]
        X = np.vstack(embs)
        m = X.mean(axis=0).reshape(1,-1)        
        dens[comm_id] = cosine_similarity(X,m).mean()
    community_disparity[year] = dens

In [None]:
disp_df = pd.DataFrame()
for year in range(24):
    disp_df[year] = dyn_comm_df[year].map(lambda x: community_disparity[year][x], na_action = 'ignore')

In [None]:
merges = []
for year in range(1,24):
    counts = dyn_comm_df[year].value_counts()
    cands = list(counts[counts > 1].index)
    # print('\n')
    for cand in cands:
        dyn_comms = dyn_comm_df[dyn_comm_df[year] == cand]
        dyn_comm_1 = dyn_comms.iloc[0]
        dyn_comm_1_id = dyn_comm_1.name
        for dyn_comm_2_id, dyn_comm_2 in dyn_comms.iloc[1:].iterrows():
            if not dyn_comm_1[range(year)].eq(dyn_comm_2[range(year)]).any():
                # print(f'comm {dyn_comm_1_id} and comm {dyn_comm_2_id} merge in {2000+year}')
                merges.append({'comm1':dyn_comm_1_id,'comm2':dyn_comm_2_id,'year':year})
            elif not dyn_comm_1[range(year,24)].dropna().eq(dyn_comm_2[range(year,24)].dropna(),fill_value = 0).all():
                print(f'comm {dyn_comm_1_id} and comm {dyn_comm_2_id} SPLIT after {2000+year}')
            

In [None]:
def compare_communities(year1, year2, comm1, comm2):
    """Mean pairwise cosine similarity between the SciBERT embeddings of two communities.

    Every paper of `comm1` (year offset `year1`) is compared with every paper of
    `comm2` (year offset `year2`), and the average over all pairs is returned.
    """
    def _embedding_matrix(year, comm):
        # One row per member paper, stacked from the per-paper embedding vectors.
        return np.vstack(scibert[list(comm_lookup(year, comm))])

    pairwise = cosine_similarity(_embedding_matrix(year1, comm1),
                                 _embedding_matrix(year2, comm2))
    return pairwise.mean()

In [None]:
def get_comm_interactions(year1, year2, comm1, comm2):
    """Density of citation-graph edges from community `comm1` to community `comm2`.

    Counts member pairs (u, v), with u in `comm1` (year offset `year1`) and v in
    `comm2` (year offset `year2`), such that the graph of the later of the two years
    has the edge (u, v). The count is divided by |comm1| * |comm2|.
    """
    members_1 = comm_lookup(year1, comm1)
    members_2 = comm_lookup(year2, comm2)
    # Read-only use: the original copied the whole citation graph on every call.
    G = graphs[max(year1, year2)]
    # `v in G[u]` is exactly G.has_edge(u, v). Walking each member's adjacency avoids
    # |comm1| * |comm2| has_edge probes.
    weight = sum(1 for u in members_1 if u in G for v in G[u] if v in members_2)
    return weight / (len(members_1) * len(members_2))

In [None]:
sim_by_year = {}
for year in range(1,24):
    other_communities = list(dyn_comm_df[year].dropna())
    sims = []
    for comms in combinations(other_communities, r = 2):
        sims.append(compare_communities(year, year, *comms)) 
    sim_by_year[year] = sims

In [None]:
int_by_year = {}
for year in tqdm(range(1,24)):
    other_communities = list(dyn_comm_df[year].dropna())
    G = community_graphs[year].copy()
    X = nx.to_numpy_array(G)
    X = np.tril(X, k = -1)
    int_by_year[year] = X[X > 0].flatten()

In [None]:
# Box plot of the non-zero inter-community citation counts per year (int_by_year), with
# each detected merge overlaid as a red '+'.
# NOTE(review): the `merges` records built above only carry 'comm1', 'comm2' and 'year';
# no visible cell sets 'interactions', so merge['interactions'] raises KeyError on a
# fresh kernel (presumably computed in a deleted cell). TODO restore.
fig, ax = plt.subplots(1,1)
pd.concat([pd.Series(int_by_year[year]) for year in int_by_year], axis = 1).boxplot(ax = ax)
for merge in merges:
    ax.scatter(merge['year'],merge['interactions'], color = 'red', marker = '+', zorder = 3)

In [None]:
# Box plot of pairwise content similarity between same-year communities (sim_by_year),
# with each detected merge overlaid as a red '+'.
# NOTE(review): no visible cell sets merge['sim'] (merges only hold 'comm1', 'comm2',
# 'year'), so this raises KeyError on a fresh kernel. TODO restore.
fig, ax = plt.subplots(1,1)
pd.concat([pd.Series(sim_by_year[year]) for year in sim_by_year], axis = 1).boxplot(ax = ax)
for merge in merges:
    ax.scatter(merge['year'],merge['sim'], color = 'red', marker = '+', zorder = 3)

In [None]:
# Standardise each merge's content similarity, and its log interaction strength, against
# the distribution over all community pairs of the same year (z-scores).
# NOTE(review): relies on merge['sim'] / merge['interactions'], which no visible cell
# sets (KeyError on a fresh kernel). np.log(x) is -inf if a merge's interactions is 0.
for merge in merges:
    year = merge['year']
    
    x = merge['sim']
    data = sim_by_year[year]
    std = np.std(data)
    mean = np.mean(data)
    merge['sim_z'] = (x - mean) / std
    
    x = merge['interactions']
    data = int_by_year[year]
    data = np.log(data)
    std = np.std(data)
    mean = np.mean(data)
    merge['int_z'] = (np.log(x) - mean) / std

In [None]:
def merge_mean_disp(comm1, comm2, year):
    """Harmonic mean of two dynamic communities' `community_disparity` values before a merge.

    For each dynamic community, its last appearance strictly before year offset `year`
    is located in dyn_comm_df, and that community's value in `community_disparity` is
    used.
    """
    def _last_disparity(dyn_comm):
        history = dyn_comm_df.loc[dyn_comm][range(int(year))].dropna().to_dict()
        last_year = max(history)
        return community_disparity[last_year][history[last_year]]

    values = [_last_disparity(comm1), _last_disparity(comm2)]
    return len(values) / sum(1 / value for value in values)

In [None]:
merge_df = pd.DataFrame(merges)

In [None]:
merge_df['coherence'] = merge_df.apply(lambda row: merge_mean_disp(row['comm1'],row['comm2'],row['year']), axis = 1)

In [None]:
merge_df['sim_adj'] = merge_df['sim']*merge_df['coherence']

In [None]:
merge_df['sim_adj_z'] = (merge_df['sim_adj'] - merge_df['sim_adj'].mean()) / merge_df['sim_adj'].std()

In [None]:
# For each merge with below-average adjusted similarity (most dissimilar first), print
# the density, disparity score and TF-IDF labels of the last appearance of both merging
# dynamic communities before the merge year.
# NOTE(review): merge_df is built from `merges`, whose records have no 'year1'/'year2'
# keys, so this filter raises KeyError on a fresh kernel. TODO restore those columns.
for _, merge in merge_df[(merge_df['year1'] == merge_df['year2']) & (merge_df['sim_adj_z']<0)].sort_values('sim_adj_z').iterrows():
    sub_df = dyn_comm_df.loc[[merge['comm1'],merge['comm2']]]
    for i in range(2):
        d = sub_df.iloc[i][range(int(merge['year']))].dropna().to_dict()
        year = max(d)
        comm = d[year]
        labels = label_dict[str(int(year))][str(int(comm))]
        disp = community_disparity[year][comm]
        dens = community_densities[year][comm]
        
        print(f'{dens:.2f},{disp:.2f}:{labels}')
        
    print('\n')

In [None]:
comms = list(dyn_comm_df.loc[1].dropna().items())

In [None]:
def _community_fields(members):
    """Fields of study found among the papers in `members`, most frequent first."""
    field = df.loc[list(members)].fields.explode().value_counts()
    # NOTE(review): value_counts() gives raw counts, so `> 0.5` keeps every field that
    # occurs at least once. If a majority share was intended this needs
    # value_counts(normalize=True), but downstream code takes fields[i][0], which would
    # then fail for communities without a majority field. Behaviour kept as is.
    return list(field[field > 0.5].index)


def community_stability(dyn_comm):
    """Summarise how one dynamic community changes between consecutive appearances.

    `dyn_comm` is a row of dyn_comm_df (year offset -> community id, NaN when absent).
    Returns a Series with:
      membership_sim: Jaccard similarity of member sets of consecutive appearances
      content_sim:    mean pairwise embedding similarity of consecutive appearances
      lifespan:       number of appearances
      sizes:          member count at each appearance
      fields:         fields of study at each appearance, most frequent first
    """
    comms = list(dyn_comm.dropna().items())
    # Member sets are looked up once per appearance. The original re-queried interior
    # appearances twice and raised UnboundLocalError for a single appearance.
    member_sets = [comm_lookup(year, comm) for year, comm in comms]

    content_sim = []
    membership_sim = []
    for i in range(len(comms) - 1):
        (year1, comm1), (year2, comm2) = comms[i], comms[i + 1]
        set_1, set_2 = member_sets[i], member_sets[i + 1]
        content_sim.append(compare_communities(year1, year2, comm1, comm2))
        membership_sim.append(len(set_1 & set_2) / len(set_1 | set_2))

    sizes = [len(members) for members in member_sets]
    fields = [_community_fields(members) for members in member_sets]
    return pd.Series({'membership_sim':membership_sim,'content_sim':content_sim,'lifespan':len(comms),'sizes':sizes,'fields':fields})

In [None]:
dyn_comm_stats = dyn_comm_df.apply(community_stability,axis = 1)

In [None]:
dyn_comm_stats['field'] = dyn_comm_stats['fields'].map(lambda x: pd.Series(x).explode().value_counts().index[0])

In [None]:
dyn_comm_stats['membership_stability'] = dyn_comm_stats['membership_sim'].map(np.mean)
dyn_comm_stats['content_stability'] = dyn_comm_stats['content_sim'].map(np.mean)
dyn_comm_stats['size_mean'] = dyn_comm_stats['sizes'].map(np.mean)

In [None]:
field_list = dyn_comm_stats.fields.explode().map(lambda x:x[0]).unique()
field_list

## Community Lifecycles:

In [None]:
field_list = ['Mathematics', 'Computer Science', 'Engineering', 'Physics', 'Psychology', 'Philosophy', 'Medicine', 'Biology', 'Business', 'Economics', 'Environmental Science']

field_labels = dict(zip(field_list,range(len(field_list))))

In [None]:
dyn_comm_fields = []
for comm in dyn_comm_df.index:
    fields = dyn_comm_stats.loc[comm]['fields']
    d = dyn_comm_df.loc[comm].dropna().to_dict()
    for i,k in enumerate(d):
        d[k] = field_labels[fields[i][0]]
    for i in range(24):
        if not i in d:
            d[i] = np.nan
    d['comm'] = comm
    dyn_comm_fields.append(d)
dyn_comm_fields = pd.DataFrame(dyn_comm_fields).set_index('comm')

In [None]:
recent_comms_mask = (dyn_comm_fields[range(16)].notna().sum(axis=1)==0)&(dyn_comm_fields[range(16,24)].notna().sum(axis=1)==6)

In [None]:
# Field-of-study timeline for the dynamic communities selected by recent_comms_mask:
# one row per community, one column per year, coloured by dominant field.
n = len(field_labels)  # was commented out, leaving `n` undefined on a fresh kernel
cmap = sns.color_palette("Paired", n)
fig, ax = plt.subplots(1,1,figsize = (8,8))
# Field codes run 0..n-1, one colour each.
sns.heatmap(dyn_comm_fields[recent_comms_mask], cmap=cmap, linecolor='white', linewidths=0.1,
            ax=ax, vmin=0, vmax=n - 1)
ax.set_yticklabels([])
ax.set_xticklabels([i+2000 if i%2==0 else '' for i in range(24)])
ax.set_ylabel('')
ax.set_yticks([])
# Centre one colorbar tick on each discrete field colour.
colorbar = ax.collections[0].colorbar
r = colorbar.vmax - colorbar.vmin
colorbar.set_ticks([colorbar.vmin + r / n * (0.5 + i) for i in range(n)])
colorbar.set_ticklabels(list(field_labels))
plt.subplots_adjust(hspace=0.1, wspace=0.05)
plt.show()

## Pairwise Community Interactions:

In [None]:
p = {}
s = {}
for i in range(23):
    years = (i, i+1)
    pairs = dyn_comm_df[list(years)].dropna()
    for comm1, comm2 in tqdm(list(combinations(pairs.itertuples(), r = 2)), postfix = f'{years}'):
        p2 = community_graphs_normed[years[1]][comm1[2]][comm2[2]]['weight']
        p1 = community_graphs_normed[years[0]][comm1[1]][comm2[1]]['weight'] 
        idx = (comm1[0],comm2[0]) 
        # if not comm1[1] == comm2[1]:
        if not idx in p:
            p[idx] = {}
        p[idx][years[0]] = p1
        p[idx][years[1]] = p2
        if not idx in s:
            s[idx] = {}
        s[idx][years[0]] = compare_communities(years[0],years[0],comm1[1],comm2[1])
        s[idx][years[1]] = compare_communities(years[1],years[1],comm1[2],comm2[2])

In [None]:
int_s = pd.Series(p)
sim_s = pd.Series(s)

In [None]:
int_s = pd.DataFrame(int_s)
int_s['len'] = int_s[0].map(len)
sim_s = pd.DataFrame(sim_s)
sim_s['len'] = sim_s[0].map(len)

In [None]:
int_df = int_s[0].apply(pd.Series)
sim_df = sim_s[0].apply(pd.Series)

In [None]:
# Betweenness centrality of every community node in each yearly interaction graph.
bet_lookup = {}
for year in range(24):
    bet_lookup[year] = dict(nx.betweenness_centrality(community_graphs[year]))

# One row per dynamic community, one column per year offset (NaN when absent).
# Records are collected first and the frame is built once: growing a DataFrame with
# .loc row by row is quadratic and leaves every column with object dtype.
bet_records = {
    label: {yr: bet_lookup[yr][id_c] for yr, id_c in row.dropna().items()}
    for label, row in dyn_comm_df.iterrows()
}
bet_df = (pd.DataFrame.from_dict(bet_records, orient='index')
          .reindex(index=dyn_comm_df.index, columns=range(24)))

In [None]:
# Degree of every community node in each yearly interaction graph.
deg_lookup = {}
for year in range(24):
    deg_lookup[year] = dict(nx.degree(community_graphs[year]))

# One row per dynamic community, one column per year offset (NaN when absent).
# The frame is built once from records instead of being grown row by row with .loc.
deg_records = {
    label: {yr: deg_lookup[yr][id_c] for yr, id_c in row.dropna().items()}
    for label, row in dyn_comm_df.iterrows()
}
deg_df = (pd.DataFrame.from_dict(deg_records, orient='index')
          .reindex(index=dyn_comm_df.index, columns=range(24)))

In [None]:
from matplotlib import gridspec

In [None]:
df['year'].value_counts()[range(2000,2024)].plot(kind = 'bar')

## Research Question 1: 
- Identify the foundational topics in the literature:
    long-lived, coherent communities with sustained centrality.
- Identify contemporary topics as large, recent communities
 
     

In [None]:
n = len(field_labels)
cmap = sns.color_palette("Paired", n)
# Lifespan = number of yearly appearances of each dynamic community. `lifespan` was
# otherwise never defined in this notebook (NameError on a fresh kernel).
lifespan = dyn_comm_stats['lifespan']
fig, ax = plt.subplots(1,1,figsize = (8,8))
# Field-of-study timeline of long-lived dynamic communities (>= 11 appearances).
sns.heatmap(dyn_comm_fields[(lifespan>=11)], cmap=cmap, linecolor='white', linewidths=0.1,
            ax=ax, vmin=0, vmax=n - 1)
ax.set_xticklabels([i+2000 if i%2==0 else '' for i in range(24)])
ax.set_ylabel('')
# Centre one colorbar tick on each discrete field colour.
colorbar = ax.collections[0].colorbar
r = colorbar.vmax - colorbar.vmin
colorbar.set_ticks([colorbar.vmin + r / n * (0.5 + i) for i in range(n)])
colorbar.set_ticklabels(list(field_labels))
plt.subplots_adjust(hspace=0.1, wspace=0.05)
plt.show()

In [None]:
# Generic terms that are never used in community annotations.
stop = ['data', 'model', 'network', 'disentanglement', 'models']
def annotate_dynamic_community(comm, n_chars=50, n_terms=3, y=24, reweighted=False):
    """Build a short '/'-separated label for dynamic community `comm`.

    Pools the topic terms of every appearance of `comm` before year offset `y`, drops
    `stop` words and near-duplicates, and keeps the most frequent terms. Trailing terms
    are removed while the label has more than `n_terms` terms or more than `n_chars`
    characters. `reweighted` selects the CTD labels instead of the TF-IDF ones.
    A community with no appearance before `y` gets '' (previously a TypeError).
    """
    labels = label_dict_reweighted if reweighted else label_dict
    comm = float(comm)
    appearances = dyn_comm_df.loc[comm][range(y)].dropna()
    # One label-term list per appearance (year offset -> community id), then flattened.
    terms = pd.Series(
        [labels[str(int(year))][str(int(comm_id))] for year, comm_id in appearances.items()],
        dtype=object,
    ).explode()
    terms = terms[~terms.isin(stop)]
    top_terms = terms.value_counts().iloc[:10].index

    # A top term is redundant if it equals another top term minus its last character
    # (e.g. 'network' when 'networks' is present), or is one word of a multi-word top
    # term (or that word minus its last character).
    redundant = set()
    for term in top_terms:
        redundant.add(term[:-1])
        words = term.split(' ')
        if len(words) > 1:
            redundant.update(words)
            redundant.update(word[:-1] for word in words)
    top_terms = [term for term in top_terms if term not in redundant]

    # Drop trailing terms until both the term-count and character limits hold.
    n = len(top_terms)
    label_string = '/'.join(top_terms)
    while len(label_string) > n_chars or n > n_terms:
        n -= 1
        label_string = '/'.join(top_terms[:n])
    return label_string

In [None]:
def annotate_community(comm, n=5, y=24, reweighted=False):
    """'/'-joined top label terms of a single community in year offset `y - 1`.

    `y` follows the exclusive-upper-bound convention of annotate_dynamic_community, so
    the default y=24 reads the 2023 labels. Previously the year was hard-coded to '23'
    and `y` was ignored. Stop words and near-duplicate terms are dropped and at most
    `n` terms are kept. `reweighted` selects the CTD labels instead of TF-IDF.
    """
    labels = label_dict_reweighted if reweighted else label_dict
    terms = pd.Series(labels[str(y - 1)][str(int(comm))])
    terms = terms[terms.apply(lambda x: x not in stop)]
    term_counts = terms.value_counts()
    top_terms = term_counts.iloc[:10].index
    # A top term is redundant if it equals another top term minus its last character,
    # or is one word of a multi-word top term (or that word minus its last character).
    stop_terms = []
    for term in top_terms:
        stop_terms.append(term[:-1])
        split_term = term.split(' ')
        if len(split_term) > 1:
            stop_terms.extend(split_term)
            stop_terms.extend([t[:-1] for t in split_term])
    top_terms = [term for term in top_terms if term not in stop_terms][:n]
    return '/'.join(top_terms)

In [None]:
# Number of yearly appearances per dynamic community (`lifespan` was otherwise undefined).
lifespan = dyn_comm_stats['lifespan']
# "Foundation" topics: long-lived dynamic communities (lifespan > 10) ranked by mean
# betweenness centrality over 2000-2009, with years without the community counted as 0.
most_central = bet_df[lifespan>10][range(10)].fillna(0).mean(axis = 1).sort_values(ascending = False).dropna().head(8)
for dyn_id in most_central.index:
    print(dyn_id, annotate_dynamic_community(dyn_id), most_central[dyn_id])
# Lowest mean centrality over 2000-2009, averaged only over the years the community exists.
least_central = bet_df[lifespan>12][range(10)].mean(axis = 1).sort_values().dropna().head(8)

In [None]:
# Field-of-study timelines of the most central ("foundation") dynamic communities.
n = len(field_labels)  # was commented out, leaving `n` undefined on a fresh kernel
cmap = sns.color_palette("Paired", n)
fig, ax = plt.subplots(1,1,figsize = (8,8))
sns.heatmap(dyn_comm_fields.loc[most_central.index], cmap=cmap, linecolor='white', linewidths=0.1,
            ax=ax, vmin=0, vmax=n - 1)
ax.set_xticklabels([i+2000 if i%4==0 else '' for i in range(24)])
ax.set_yticklabels([annotate_dynamic_community(i,n_chars=35) for i in most_central.index])
# Centre one colorbar tick on each discrete field colour.
colorbar = ax.collections[0].colorbar
r = colorbar.vmax - colorbar.vmin
colorbar.set_ticks([colorbar.vmin + r / n * (0.5 + i) for i in range(n)])
colorbar.set_ticklabels(list(field_labels))
plt.subplots_adjust(hspace=0.1, wspace=0.05)
plt.yticks(rotation=0)
plt.show()

In [None]:
# Side-by-side field timelines: highest- vs lowest-centrality long-lived communities.
n = len(field_labels)
cmap = sns.color_palette("Paired", n)
# Column 1 is a spacer for the left panel's y labels; column 3 holds the shared colorbar.
fig, axs = plt.subplots(1, 4, figsize=(11, 4),gridspec_kw={'width_ratios': [1, 0.65, 1, 0.04]})
sns.heatmap(dyn_comm_fields.loc[most_central.index], cmap=cmap, linecolor='white', linewidths=0.1,
            ax=axs[0], vmin=0, vmax=10, cbar=False)
axs[0].set_yticklabels([annotate_dynamic_community(i,n_chars=35) for i in most_central.index])

sns.heatmap(dyn_comm_fields.loc[least_central.index], cmap=cmap, linecolor='white', linewidths=0.1,
            ax=axs[2], vmin=0, vmax=10, cbar_ax=axs[3])
# NOTE(review): 12 labels assume seaborn's automatic tick thinning shows every 2nd year.
x_labels = [(2*i)+2000 if i%4==0 else '' for i in range(12)]
axs[2].set_xticklabels(x_labels)
axs[0].set_xticklabels(x_labels)
axs[2].set_yticklabels([annotate_dynamic_community(i,n_chars=30) for i in least_central.index])
# Centre one colorbar tick on each discrete field colour.
colorbar = axs[2].collections[0].colorbar
r = colorbar.vmax - colorbar.vmin
colorbar.set_ticks([colorbar.vmin + r / n * (0.5 + i) for i in range(n)])
colorbar.set_ticklabels(list(field_labels))
axs[0].set_yticks(axs[0].get_yticks(), axs[0].get_yticklabels(), rotation=0)
axs[2].set_yticks(axs[2].get_yticks(), axs[2].get_yticklabels(), rotation=0)
plt.subplots_adjust(hspace=0.1, wspace=0.05)
plt.yticks(rotation=0)
axs[0].set_title('High Centrality')
axs[2].set_title('Low Centrality')
axs[1].set_axis_off()
plt.savefig('figures/high_centrality_low_centrality.png', bbox_inches='tight')
plt.show()

In [None]:
modern_most_central = bet_df[bet_df[23].notna()][[21,22,23]].mean(axis = 1).sort_values(ascending=False).head(7)
modern_most_central.index.map(lambda x: print(x,annotate_dynamic_community(x)))

In [None]:
# Side-by-side field timelines: foundation topics vs currently central topics.
n = len(field_labels)
cmap = sns.color_palette("Paired", n)
# Column 1 is a spacer for the left panel's y labels; column 3 holds the shared colorbar.
fig, axs = plt.subplots(1, 4, figsize=(12, 4),gridspec_kw={'width_ratios': [1, 0.6, 1, 0.04]})
sns.heatmap(dyn_comm_fields.loc[most_central.index], cmap=cmap, linecolor='white', linewidths=0.1,
            ax=axs[0], vmin=0, vmax=10, cbar=False)
axs[0].set_yticklabels([annotate_dynamic_community(i,n_chars=30) for i in most_central.index])

sns.heatmap(dyn_comm_fields.loc[modern_most_central.index], cmap=cmap, linecolor='white', linewidths=0.1,
            ax=axs[2], vmin=0, vmax=10, cbar_ax=axs[3])
# NOTE(review): 12 labels assume seaborn's automatic tick thinning shows every 2nd year.
x_labels = [(2*i)+2000 if i%4==0 else '' for i in range(12)]
axs[2].set_xticklabels(x_labels)
axs[0].set_xticklabels(x_labels)
axs[2].set_yticklabels([annotate_dynamic_community(i,n_chars=35) for i in modern_most_central.index])
# Centre one colorbar tick on each discrete field colour.
colorbar = axs[2].collections[0].colorbar
r = colorbar.vmax - colorbar.vmin
colorbar.set_ticks([colorbar.vmin + r / n * (0.5 + i) for i in range(n)])
colorbar.set_ticklabels(list(field_labels))
axs[0].set_yticks(axs[0].get_yticks(), axs[0].get_yticklabels(), rotation=0)
axs[2].set_yticks(axs[2].get_yticks(), axs[2].get_yticklabels(), rotation=0)
plt.subplots_adjust(hspace=0.1, wspace=0.05)
plt.yticks(rotation=0)
axs[0].set_title('Foundation Topics')
axs[2].set_title('Recent Central Topics')
axs[1].set_axis_off()
plt.savefig('figures/central_topics.png', bbox_inches='tight')
plt.show()

In [None]:
def community_year(year,comm):
    """Mean publication year (df['year']) of the papers in community `comm` of year offset `year`."""
    # Assumes comm_lookup returns labels of df's (unique) paperId index. TODO confirm.
    members = list(comm_lookup(year, comm))
    return df.loc[members, 'year'].mean()

In [None]:
# `fronts` / `modern_fronts` (the 2023 community ids of the foundation topics and the
# recent central topics) are used below but were never defined in this notebook. They
# are reconstructed here from the selected tracks, in the same order as most_central /
# modern_most_central, which the heatmap tick labels rely on. map(int) raises
# ValueError if a selected track has no community in 2023.
fronts = dyn_comm_df.loc[most_central.index, 23].map(int)
modern_fronts = dyn_comm_df.loc[modern_most_central.index, 23].map(int)
[int(community_year(23,front)) for front in fronts]

In [None]:
dyn_comm_stats['year'] = dyn_comm_df[23].map(lambda x: int(community_year(23,x)), na_action = 'ignore')

In [None]:
threshold = dyn_comm_stats['content_stability'].mean()

In [None]:
mask = (dyn_comm_stats['size_mean']>50)
mask = mask&(dyn_comm_stats['content_stability']>threshold)
mask = mask&(dyn_comm_stats['year']>=2017)
mask = mask&(bet_df[23].notna())
# mask = mask&(prop_s>0.1)
mask = mask&(lifespan>2)

In [None]:
recent_comms = dyn_comm_stats[mask].sort_values('year', ascending=False).iloc[:20]

In [None]:
recent_fronts = dyn_comm_df.loc[recent_comms.index][23].map(int)#.unique()

In [None]:
fronts, most_central

In [None]:
n_chars = 35
# Interaction strengths (x100) from the normalised 2023 block-model graph between the
# recent fronts (rows) and the foundation / recent central topics (columns).
graph_23 = community_graphs_normed[23]
A = nx.to_numpy_array(graph_23).T
# Matrix rows/columns follow the graph's node order, not community ids: nodes were
# added as 1..K, so community k sits at position k - 1. Indexing by raw community id
# was off by one, and out of range for the largest id.
node_pos = {node: i for i, node in enumerate(graph_23)}
recent_rows = [node_pos[f] for f in recent_fronts]

X = A[recent_rows][:, [node_pos[f] for f in fronts]] * 100
foundation_order = X.mean(axis=0).argsort()
# Both heatmaps order their rows by mean interaction with the foundation topics.
row_order = X.mean(axis=1).argsort()
sorted_X = X[row_order][:, foundation_order]

# Compute heatmap data for the second heatmap (X_modern)
X_modern = A[recent_rows][:, [node_pos[f] for f in modern_fronts]] * 100
modern_order = X_modern.mean(axis=0).argsort()
sorted_X_modern = X_modern[row_order][:, modern_order]

row_labels = [annotate_dynamic_community(front, n_chars=n_chars) for front in recent_comms.index[row_order]]

# Two heatmaps side by side plus a narrow colorbar axis.
fig, axs = plt.subplots(1, 3, figsize=(10, 8),gridspec_kw={'width_ratios': [1, 1, 0.05]})

sns.heatmap(sorted_X, annot=True, vmax=2, cbar=False, linecolor='white', linewidths=0.1, cmap='crest',
            xticklabels=[annotate_dynamic_community(front, n_chars=n_chars) for front in most_central.index[foundation_order]],
            yticklabels=row_labels,
            ax=axs[0])

sns.heatmap(sorted_X_modern, annot=True, vmax=2, cbar_ax=axs[2], linecolor='white', linewidths=0.1, cmap='crest',
            xticklabels=[annotate_dynamic_community(front, n_chars=n_chars) for front in modern_most_central.index[modern_order]],
            yticklabels=row_labels,
            ax=axs[1])
axs[1].set_yticklabels([])
axs[0].set_xticks(axs[0].get_xticks(), axs[0].get_xticklabels(), rotation=45, ha='right')
axs[1].set_xticks(axs[1].get_xticks(), axs[1].get_xticklabels(), rotation=45, ha='right')
axs[0].set_title('Foundation Topics', fontsize=16)
axs[1].set_title('Recent Central Topics', fontsize=16)

plt.savefig('figures/topic_interactions.png', bbox_inches='tight')
plt.show()

In [None]:
n_terms = 3
# Same interaction heatmaps, split into four hand-picked subsets of recent fronts.
# `ls` holds dynamic-community ids. The column orders foundation_order / modern_order
# come from the full heatmap cell above.
graph_23 = community_graphs_normed[23]
A = nx.to_numpy_array(graph_23).T
# Matrix positions follow node order (community k at position k - 1), not raw ids;
# indexing by raw community id was off by one.
node_pos = {node: i for i, node in enumerate(graph_23)}
foundation_cols = [node_pos[f] for f in fronts]
modern_cols = [node_pos[f] for f in modern_fronts]

# 4 rows of (foundation heatmap, recent heatmap, spacer) plus one shared colorbar axis.
fig, axs = plt.subplots(4, 3, figsize=(10, 8),gridspec_kw={'width_ratios': [1, 1, 0.04]})
ax_cbar = fig.add_subplot(1,40,40)

ls = [np.array([472,332,258]), np.array([602,451,331]),  np.array([543,262,266]), np.array([458,493,301])]
topic_subsets = ['Fairness\n', 'Natural Language \n Processing', 'Computer Vision\n', 'Adversarial ML\n']

for i in range(len(axs)):
    l = ls[i]
    rows = [node_pos[recent_fronts[l_x_i]] for l_x_i in l]
    X = A[rows][:, foundation_cols] * 100
    X_modern = A[rows][:, modern_cols] * 100
    # Apply one row order to both heatmaps and to their tick labels. Previously only
    # the labels were re-ordered, so they did not match the rows drawn next to them.
    row_order = X.mean(axis=1).argsort()
    row_labels = [annotate_dynamic_community(f, n_chars=35) for f in l[row_order]]

    sns.heatmap(X[row_order][:, foundation_order], annot = True,
                vmax=2, ax=axs[i][0], cbar = False, linecolor='white', linewidths=0.1, cmap='crest',
                xticklabels=[annotate_dynamic_community(front, n_chars=35) for front in most_central.index[foundation_order]],
                yticklabels=row_labels)
    # Only the bottom row keeps (rotated) x tick labels.
    if i < (len(axs)-1):
        axs[i][0].set_xticklabels([])
        axs[i][0].set_xticks([])
    else:
        axs[i][0].set_xticks(axs[i][0].get_xticks(), axs[i][0].get_xticklabels(), rotation=45, ha='right')
    axs[i][0].yaxis.set_label_coords(-0.6,0.55)

    sns.heatmap(X_modern[row_order][:, modern_order], annot = True,
                vmax=2, ax=axs[i][1], cbar_ax = ax_cbar, linecolor='white', linewidths=0.1, cmap='crest',
                xticklabels=[annotate_dynamic_community(front, n_chars=35) for front in modern_most_central.index[modern_order]],
                yticklabels=row_labels)
    if i < (len(axs)-1):
        axs[i][1].set_xticklabels([])
        axs[i][1].set_xticks([])
    else:
        axs[i][1].set_xticks(axs[i][1].get_xticks(), axs[i][1].get_xticklabels(), rotation=45, ha='right')
    axs[i][1].set_yticklabels([])
    axs[i][2].set_axis_off()

axs[0][0].set_title('Foundation Topics', fontsize=16)
axs[0][1].set_title('Recent Central Topics', fontsize=16)

plt.savefig('figures/topic_interactions_sub.png', bbox_inches='tight')

## Research Question 2:

- Model the relationship between content similarity, second-order network proximity, and citation-based community interactions.

In [None]:
dyn_comm_df[23].dropna().nunique()

In [None]:
rel_df = pd.concat([int_df[23].dropna(),sim_df[23].dropna()], axis = 1)
rel_df.columns = ['interactions', 'content_sim']
rel_df['interactions_log'] = rel_df['interactions'].map(lambda x: np.log(10*(x+1)))
rel_df['content_sim_std'] = (rel_df['content_sim']-rel_df['content_sim'].mean()) / rel_df['content_sim'].std()

In [None]:
from sklearn.linear_model import TweedieRegressor, LinearRegression, GammaRegressor
from sklearn.model_selection import cross_val_score

In [None]:
X = nx.to_numpy_array(community_graphs_normed[23])

In [None]:
def compare_neighborhoods(comm1, comm2, adjacency=None):
    """Cosine similarity of the 2023 interaction profiles of two dynamic communities.

    Each dynamic community is mapped to its 2023 community id via dyn_comm_df, and the
    corresponding rows of `adjacency` are compared. `adjacency` defaults to the global
    `X`, expected to be nx.to_numpy_array(community_graphs_normed[23]). `X` is
    reassigned by later cells, so pass it explicitly when calling out of order.
    """
    if adjacency is None:
        adjacency = X
    f1, f2 = int(dyn_comm_df[23].loc[comm1]), int(dyn_comm_df[23].loc[comm2])
    # Nodes of the block-model graph were added as 1..K in order, so community id k is
    # row k - 1. Indexing by the raw id was off by one and raised IndexError for the
    # largest id.
    return cosine_similarity([adjacency[f2 - 1]], [adjacency[f1 - 1]])[0][0]

In [None]:
rel_df['neighborhood_sim'] = rel_df.index.map(lambda comms: compare_neighborhoods(*comms))
rel_df['neighborhood_sim_std'] = (rel_df['neighborhood_sim'] - rel_df['neighborhood_sim'].mean()) / rel_df['neighborhood_sim'].std()

In [None]:
model = GammaRegressor()
X,y = rel_df[['content_sim_std','neighborhood_sim_std']],rel_df['interactions_log']
np.mean(-cross_val_score(model, X, y, cv=5, scoring='neg_mean_squared_error'))

In [None]:
model.fit(X,y)
rel_df['pred'] = model.predict(X)

In [None]:
rel_df['residual'] = rel_df['pred'] - rel_df['interactions_log'] 

In [None]:
rel_df[rel_df['neighborhood_sim'] < 0.99]['residual'].plot(kind = 'hist', bins = 50)

In [None]:
node_map = dict(zip(dyn_comm_df[23].dropna().unique(),range(len(dyn_comm_df[23].dropna().unique()))))
inverse_node_map = {node_map[k]:k for k in node_map}
X = np.zeros((len(node_map),len(node_map)))
# for edges, res in rel_df[rel_df['pred'] <=np.log(10)]['residual'].items():
for edges, res in rel_df['residual'].items():
    e1, e2 = dyn_comm_df[23].loc[edges[0]], dyn_comm_df[23].loc[edges[1]]
    e1, e2 = node_map[e1], node_map[e2]
    if e1 != e2:
        X[e1,e2] = res
        X[e2,e1] = res

In [None]:
mean_res = pd.Series(X.mean(axis=1).argsort())

In [None]:
mean_res_df = pd.concat([pd.Series(X.mean(axis = 1)[mean_res.values]),mean_res], axis = 1)

In [None]:
mean_res_df['comm'] = mean_res_df[1].map(inverse_node_map)

In [None]:
dyn_to_front = dyn_comm_df[23].drop_duplicates().dropna().to_dict()
front_to_dyn = {dyn_to_front[k]:k for k in dyn_to_front}

In [None]:
dyn_comm_df[dyn_comm_df[23] == dyn_to_front[613]]

In [None]:
def edge_annotations(x, asc = False, n = 3):
    """Print and return the topic of the 2023 community at matrix position `x` and its top neighbours.

    Neighbours are ranked by row `x` of the residual matrix `X`: the `n` largest
    residuals by default, or the `n` smallest when `asc` is True. The node itself is
    never reported as its own neighbour. Returns {0: own topic, 1..n: neighbour topics}.
    """
    # Drop x itself: its diagonal entry is 0 and could otherwise rank among the top n.
    ranked = [j for j in X[x].argsort() if j != x]
    neighbours = ranked[:n] if asc else ranked[::-1][:n]

    comm_id = inverse_node_map[x]
    node_topic = annotate_community(comm_id, n=3)
    print(comm_id, node_topic, community_year(23, comm_id))
    neighbour_list = {0: node_topic}
    # `rank` / `j` are distinct names; the original reused `n` as the loop variable.
    for rank, j in enumerate(neighbours, start=1):
        neighbour_id = inverse_node_map[j]
        neighbour_topic = annotate_community(neighbour_id)
        print('\t\t', neighbour_id, neighbour_topic, community_year(23, neighbour_id))
        neighbour_list[rank] = neighbour_topic
    return neighbour_list

In [None]:
import math

In [None]:
knowledge_gaps = [edge_annotations(x, asc=False, n = 5) for x in mean_res_df.sort_values(0,ascending=False).iloc[:5][1]]

In [None]:
print(pd.concat([pd.Series(d) for d in knowledge_gaps]).to_latex())

## Research Question 3:

In [None]:
silos = pd.Series(X.sum(axis = 1)).sort_values().iloc[:10]
silo_info = [{#'id':x,
  'label':annotate_dynamic_community(front_to_dyn[x], reweighted=False, n_chars = 100, n_terms = 5),
  # 'label_reweighted':annotate_dynamic_community(front_to_dyn[x], reweighted=True, n_chars = 100, n_terms = 5),
  'size':len(comm_lookup(23, x)),
  # 'z_score':total_interactions_z[x],
  'degree': f'{deg_df[23][front_to_dyn[x]]:.0f}',
  'density':f'{dens_df[23][front_to_dyn[x]]:.2f}'} for x in silos.index]
silo_df = pd.DataFrame(silo_info)

In [None]:
print(silo_df.iloc[1:6].to_latex())