# State of Play

In this notebook we concentrate on research trends in AI:

* How has the field evolved
* Where has the field spread
* How has the field been disrupted
* What is the situation in different countries

## 0. Preamble

In [None]:
%run notebook_preamble.ipy

In [None]:
# Ignore future warnings (for when I concatenate dfs)

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

### Other imports

In [None]:
import random

### Functions

Add a bunch of exogenous variables to the analysis df

In [None]:
def save_fig(name,path='../reports/figures/paper_rev/'):
    '''
    Tidies the layout of the current matplotlib figure and writes it to disk.

    Args:
        name (str): file name (with its extension) that goes after the date prefix
        path (str): prefix (normally a directory ending in '/') the file name is appended to

    The file is saved as path + '<today_str>_<name>'. today_str is a notebook-level
    variable that is not defined in this cell (presumably set by the preamble).
    '''
    plt.tight_layout()

    target = path+f'{today_str}_{name}'
    plt.savefig(target)

    
    

In [None]:
# Put functions etc here

def flatten_list(my_list):
    '''
    Removes one level of nesting from a list of lists.

    Args:
        my_list: an iterable whose elements are themselves iterables

    Returns a single list with the elements of every inner iterable, in order.
    '''
    flat = []

    for inner in my_list:
        flat.extend(inner)

    return flat


def get_example(df,number,length):
    '''
    Prints the beginning of the abstract of a random sample of papers.

    Args:
        df is the dataframe to sample from (it needs an 'abstract' column)
        number is how many papers to sample (without replacement, from the index)
        length is the number of characters of each abstract to print

    Returns nothing. Every abstract is followed by blank lines.
    '''
    picked = random.sample(list(df.index),number)

    for text in df.loc[picked]['abstract']:
        # Same output as printing the text and then printing '\n'
        print(text[:length],'\n',sep='\n')
    

In [None]:
def trend_analysis(topic_mix,topics,year_var='year',year_lim=(2000,2019),thres=0.1):
    '''
    Counts, for every year, the papers where each topic is present.

    Args:
        topic_mix (df): rows are papers; it contains the year variable and one weight column per topic
        topics (list): names of the topic columns to analyse
        year_var (str): name of the column with the year of the paper
        year_lim (pair): first year to include and first year to exclude, i.e. the period is np.arange(year_lim[0],year_lim[1])
        thres (float): a topic is present in a paper when its weight is strictly above this value

    Returns:
        A df with one row per year in the period and one column per topic.
        Years and topics without any paper above the threshold are reported as 0.

    '''
    years = topic_mix[year_var]

    # Adding up the boolean mask inside each year gives the number of papers above the threshold.
    # Unlike indexing a crosstab with [True], this also works for a topic that never goes above it.
    topic_count = pd.concat([(topic_mix[t]>thres).groupby(years).sum() for t in topics],axis=1).fillna(0)
    topic_count.columns = topics

    # reindex keeps the years without papers as rows of zeroes;
    # selecting missing years with .loc raises a KeyError in current pandas
    period = np.arange(year_lim[0],year_lim[1])

    return topic_count.reindex(period).fillna(0)
    
    
    
def plot_trend_of_interest(trend_df,topics,ax,wind=3,norm=False,**kwargs):
    '''
    Plots the rolling mean of the topics of interest.

    Args:
        trend_df: the df where rows = years and columns = topics
        topics: topic or topics (columns of trend_df) to plot
        ax: matplotlib axis to draw on
        wind: window of the rolling mean used to smooth the series
        norm: False = plot the values as they are. Otherwise every row or column is divided by
            its total before plotting, along pandas axis norm-1: 1 divides every column (topic)
            by its total over the years, 2 divides every row (year) by its total over the topics
        kwargs: passed on to the pandas plot call

    Returns nothing; the lines are drawn on ax. Leading periods without a full window are dropped.

    '''
    if norm==False:
        to_plot = trend_df
    else:
        # The total is taken over all the columns of trend_df, not only over the topics plotted
        to_plot = trend_df.apply(lambda values: values/values.sum(),axis=norm-1).fillna(0)

    smoothed = to_plot[topics].rolling(window=wind).mean().dropna()
    smoothed.plot(ax=ax,**kwargs)
    

def trend_comparison(topic_mix,topics,var,ax,year_var='year',year_lim=(2000,2019),thres=0,norm=2):
    '''
    Compares the trends in the topics of interest between two groups of papers.

    Args:
        -topic_mix: topic mix (rows are papers); it also contains var and the year variable
        -topics: topics of interest
        -var: boolean variable that splits the papers into two groups
        -ax: two matplotlib axes; the first is used for var==False and the second for var==True
        -year_var: the year variable to consider
        -year_lim: first year to include and first year to exclude
        -thres: threshold for topic occurrence
        -norm: normalisation passed to plot_trend_of_interest

    Returns nothing; one chart per group is drawn on ax.

    '''
    # One table of yearly topic counts per group (the period and threshold are passed on,
    # so the arguments of this function actually have an effect)
    outputs = [trend_analysis(topic_mix.loc[topic_mix[var]==val],topics,
                              year_var=year_var,year_lim=year_lim,thres=thres) for val in [False,True]]

    # Plot the table of each group (not the topic names) on its own axis
    for n,out in enumerate(outputs):
        plot_trend_of_interest(out,topics,norm=norm,ax=ax[n])
    
def make_network_from_doc_term_matrix(mat,threshold,id_var):
    '''
    Creates a topic co-occurrence network from a document-term matrix.

    Args:
        mat (df): rows are documents and columns are topics. Its index becomes a column through
            reset_index, and that column has to be called id_var
        threshold (float): a topic is present in a document when its value is strictly above this
        id_var (str): name of the document identifier column once the index has been reset

    Returns:
        A networkx graph where the nodes are topics and the 'weight' of an edge is the number of
        documents where both topics are present.

    '''
    # Long format: one row per document and topic, keeping the pairs where the topic is present
    long_mat = pd.melt(mat.reset_index(drop=False),id_vars=[id_var])
    present = long_mat.loc[long_mat['value']>threshold]

    # Topics present in each document
    doc_topics = present.groupby(id_var)['variable'].apply(list)

    # Every pair of topics sharing a document is one co-occurrence.
    # NOTE(review): this relies on melt listing the topics in the column order of mat, so that
    # a given pair always comes out in the same orientation and is grouped together below.
    pairs = [pair for topic_list in doc_topics for pair in combinations(topic_list,2)]

    edge_list = pd.DataFrame(pairs,columns=['source','target'])
    edge_list['weight'] = 1

    # Add up the co-occurrences of every pair to get the weight of its edge
    edge_list_weighted = edge_list.groupby(['source','target'])['weight'].sum().reset_index(drop=False)

    return nx.from_pandas_edgelist(edge_list_weighted,edge_attr=True)
    

In [None]:
def make_highlight_plot(trends,vars_interest,ax,cmap,alpha=0.3,lab_map=False):
    '''
    Plots all the series in a trend df, colouring the variables of interest and greying out the rest.

    Args:
        trends (df): trend df where the columns are the variables. Values are multiplied by 100
            and smoothed with a 4-period rolling mean before plotting
        vars_interest (list): the topics or variables we want to focus on
        ax: the matplotlib axis to draw on
        cmap (str): name of the colour map used for the variables of interest
        alpha (float): opacity of the grey lines of the other variables
        lab_map (dict): tidy label lookup for the legend. If False (or empty), the names
            themselves are used, cut to 50 characters

    Returns nothing; the chart is drawn on ax.

    '''
    # Position of every variable of interest, used to pick its colour from the colour map
    topic_lookup = {name:val for val,name in enumerate(vars_interest)}

    # pyplot.get_cmap exists in old and new matplotlib versions (matplotlib.cm.get_cmap has been removed)
    cols = plt.get_cmap(cmap)

    # One colour per column: from the colour map if it is of interest, transparent grey otherwise
    cols_to_show = [cols(topic_lookup[v]) if v in topic_lookup else (0.5,0.5,0.5,alpha) for v in trends.columns]

    (100*trends.rolling(window=4).mean()).dropna().plot(color=cols_to_show,ax=ax,linewidth=3)

    # Keep only the variables of interest in the legend
    hand,labs = ax.get_legend_handles_labels()
    shown = [(h,l) for h,l in zip(hand,labs) if l in topic_lookup]

    if lab_map:
        # Look the tidy label up with the complete name (cutting it first breaks long names)
        labels = [lab_map[l] for _,l in shown]
    else:
        labels = [l[:50] for _,l in shown]

    ax.legend(bbox_to_anchor=(1,1),handles=[h for h,_ in shown],labels=labels)

In [None]:
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from scipy.stats import zscore
from sklearn.metrics import pairwise_distances
import seaborn as sns

def make_tidy_lookup(names_list,length=False):
    '''

    Creates a cheap lookup between names and tidy labels, replacing underscores with spaces and capitalising

    Args:
        names_list (list) is the list of names we want to tidy
        length (int) if given, the tidy labels are cut to this number of characters.
            False (the default) or None keeps them whole

    Returns:
        A dict where the keys are the original names and the values are the tidy labels

    '''
    # A bool is not a length: False is the 'keep everything' default of the function
    truncate = length is not None and not isinstance(length,bool)

    out = {}

    for name in names_list:
        tidy = name.replace('_',' ').capitalize()
        out[name] = tidy[:length] if truncate else tidy

    return(out)

def show_network(ax,net,top_edge_share,label,loc,
                 size_lookup,
                 color_lookup,norm=2000,norm_2=1.2,layout=nx.kamada_kawai_layout,ec='white',alpha=0.6):
    '''
    Draws the heaviest edges of the topic network, and the topics they connect, on an axis.

    Args:
        ax: matplotlib axis to draw on
        net: networkx graph whose edges have a 'weight' attribute
        top_edge_share (float): share of the edges, ranked by weight, to show
        label, loc: accepted but not used inside the function
        size_lookup: lookup (indexed by node) with the activity in each topic. The size of a
            node is this value to the power of norm_2
        color_lookup (dict): colour of every community name. Nodes in communities that are
            not in the lookup are white
        norm (float): the width of an edge is its weight divided by this
        norm_2 (float): exponent applied to the node sizes
        layout: networkx layout function, called with the network and center=(0.5,0.5)
        ec: colour of the edges
        alpha (float): opacity of the edges

    The community of a node comes from the notebook-level objects comms and comm_names
    (described in the data section as topic -> community index and index -> community name).

    Returns nothing.

    '''
    # Rank the edges from heaviest to lightest
    ranked_edges = sorted(net.copy().edges(data=True),key=lambda edge: edge[2]['weight'],reverse=True)

    # Keep the share we want to show and build the network to display with it
    n_keep = int(top_edge_share*len(ranked_edges))
    shown_net = nx.Graph(ranked_edges[:n_keep])

    pos = layout(shown_net,
                 center=(0.5,0.5)
                )

    # Sizes and colours follow the order of the nodes in the displayed network
    node_sizes = [size_lookup[node]**norm_2 for node in shown_net.nodes]

    node_colors = []

    for node in shown_net.nodes:
        community = comm_names[comms[node]]
        node_colors.append(color_lookup[community] if community in color_lookup.keys() else 'white')

    nx.draw_networkx_nodes(shown_net,pos,
                           node_size=node_sizes,
                           node_color=node_colors,
                           cmap='tab20c',
                           alpha=0.7,edgecolors='darkgrey',ax=ax)

    edge_widths = [edge[2]['weight']/norm for edge in shown_net.edges(data=True)]

    nx.draw_networkx_edges(shown_net,pos,width=edge_widths,edge_color=ec,ax=ax,alpha=alpha)

def make_time_net(ax,dataset,size_lookup,my_label,ec='darkgrey',alpha=0.8):
    '''
    Builds the topic co-occurrence network of a set of papers and draws it on an axis.

    Args:
        ax: matplotlib axis to draw on
        dataset (df) is the topic mix (papers x topics) used to create the network. Its index
            has to become a 'paper_id' column when it is reset
        size_lookup (dict) is the size (level of activity) for each topic
        my_label (str) is the title for the network
        ec is the colour of the edges
        alpha (float) is the opacity of the edges

    Topics are linked when both have a weight above 0.025 in a paper, and the top 2% of the
    edges are shown. Colours come from the notebook-level color_lookup.

    '''
    # Network of the period
    period_net = make_network_from_doc_term_matrix(dataset,0.025,'paper_id')

    # Draw it
    show_network(ax,period_net,0.02,label=my_label,norm=500,norm_2=0.9,
                 color_lookup=color_lookup,size_lookup=size_lookup,
                 layout=nx.kamada_kawai_layout,loc=(-0.29,1.1),ec=ec,
                 alpha=alpha)

    # No ticks on either axis
    for clear_ticks in (ax.set_xticks,ax.set_yticks):
        clear_ticks([])

    ax.set_title(my_label,size=18)

def plot_centrality(network,measure,cl,ax,plot_name):
    '''
    Plots the centrality of the topics in the topic network as a bar chart sorted from high to low.

    Args:
        -network is the network whose centralities we want to plot
        -measure is the centrality function; it is called as measure(network,weight='weight')
        -cl is the colour lookup (community name -> colour) for the bars. Topics in other
         communities are light grey
        -ax is the axis
        -plot_name is accepted but not used inside the function

    The community of a topic comes from the notebook-level comms and comm_names, and the
    legend uses the notebook-level patches.

    Returns nothing; the chart is drawn on ax.

    '''
    # Centrality of every topic, standardised (z-score) and sorted
    raw = pd.Series(measure(network,weight='weight'))
    ranked = pd.Series(zscore(raw),index=raw.index).sort_values(ascending=False)

    # Colour of every bar, in the order of the sorted topics
    bar_colors = []

    for node in ranked.index:
        community = comm_names[comms[node]]
        bar_colors.append(cl[community] if community in cl.keys() else 'lightgrey')

    ranked.plot.bar(color=bar_colors,ax=ax,width=1)

    # Legend, no topic names on the x axis, and the y label
    ax.legend(handles=patches,ncol=3)
    ax.set_xticklabels([])
    ax.set_xticks([])
    ax.set_ylabel('Normalised centrality')

def make_disruption_tables(df,period=np.arange(2000,2019),topic_cols=None):
    '''
    This function creates two datasets capturing inter-year changes in activity, which we consider a proxy for 'disruption'

    Arguments:
        df (df) is a dataframe with a 'year' column and the topic weights
        period is the period (sequence of years) we are interested in capturing
        topic_cols (list) are the topic columns to use. By default, the notebook-level topics

    Returns a list with:
        -a df (years x years) with the cosine similarity between the topic profiles of every pair of years
        -a series (indexed by year) with the diagonal of the 3-row rolling mean of that df, i.e.
         the mean similarity of a year with itself and with the two years before it. The first
         two years are missing because they do not have a complete window

    '''
    if topic_cols is None:
        topic_cols = topics

    #We create a vector with counts of papers with activity (weight above 0.05) in every year
    year_topics = pd.concat([(df.loc[df['year']==y,topic_cols]>0.05).sum() for y in period],axis=1)

    year_topics.columns = period

    #We standardise every topic over the years (we want relative importance, not absolute).
    #Topics where this produces missing values are dropped
    topics_years_norm = year_topics.T.apply(lambda x: zscore(x)).dropna(axis=1)

    #We calculate similarities between years
    year_sims = pd.DataFrame(1-pairwise_distances(topics_years_norm,metric='cosine'),index=period,columns=period)

    #Rolling mean of the similarities; we keep the diagonal for visualisation.
    #(A plain array is used here because np.matrix is no longer recommended by numpy)
    mean_sims = pd.Series(np.diag(year_sims.rolling(window=3).mean().values),index=period)

    return([year_sims,mean_sims])

def make_disruption_plot(disr_inputs,ax):
    '''
    This function creates a disruption plot: a heatmap with the similarity between years on
    top and a line with the rolling mean similarity below.

    Arguments:
        -disr_inputs (list) is the output of make_disruption_tables: the df with the similarity
         between years and the series with the rolling mean similarity
        -ax are two matplotlib axes of the same figure: the first for the heatmap and the
         second for the line

    Returns nothing.

    '''
    import copy

    year_sims,mean_sims = disr_inputs[0],disr_inputs[1]

    #Get the period
    period = year_sims.columns

    #Keep the lower triangle of the matrix. The zeroes (including the upper triangle that
    #np.tril has blanked) become NaN so that they can be coloured as missing
    lower = pd.DataFrame(np.tril(year_sims,k=0),index=period,columns=period)
    year_sims_2 = lower.where(lower!=0)

    #Get a private copy of the colourmap, so that set_bad does not change it for other plots
    my_map = copy.copy(plt.get_cmap('seismic'))

    #Set missing values to white
    my_map.set_bad('white')

    #Create the heatmap
    im = ax[0].imshow(year_sims_2,cmap=my_map,aspect='auto')

    #You always have to add the ticks and the ticklabels when doing imshow
    ax[0].set_xticks([])
    ax[0].set_xticklabels([])
    ax[0].set_yticks(np.arange(0,len(period)))
    ax[0].set_yticklabels(period)

    #We remove the top and right-hand lines of the frame
    ax[0].spines['top'].set_edgecolor('white')
    ax[0].spines['right'].set_edgecolor('white')

    #Plot the mean 'disruption'
    mean_sims.plot(ax=ax[1])

    #Add a y label
    ax[1].set_ylabel('Year-on-year \n similarity \n (rolling mean)')

    #One tick per year of the series, labelled with that year. The previous hard-coded version
    #set 20 positions (1999-2018) and 19 labels (2000-2018): current matplotlib rejects the
    #mismatch, and older versions printed every label one year to the left of its data
    years = list(mean_sims.index)
    ax[1].set_xticks(years)
    ax[1].set_xticklabels(years)

    #Remove the top line
    ax[1].spines['top'].set_edgecolor('white')

    #Add a vertical grid to connect both charts
    ax[1].grid(which='both', axis='x', linestyle='--')

    #And finally the colour bar. We take the figure from the axis instead of relying on a
    #notebook-level fig. The list describes the position of the new axis
    fig = ax[0].figure
    cbaxes = fig.add_axes([0.55, 0.8, 0.3, 0.02])

    #We draw the colour bar
    cb = fig.colorbar(im, cax = cbaxes,orientation='horizontal')

    #And name it
    cb.ax.set_title('Inter-year cosine similarity')

In [None]:
def change_shares(reps,drop_china=False):
    '''
    Returns a df with the change in representation of every country in each dataset.

    Args:
        reps: list of three dfs (in the order All arXiv, All AI, SoTA AI) where the rows are
            countries, the first column is the share of activity in the first period and the
            second column is the share in the second period
        drop_china: whether we want to drop China from the analysis

    Returns:
        A df with countries in the rows and one column per dataset. The change is
        1 - (second period / first period), so it is negative for countries that gained share.

    We can use this to calculate means, variances, etc. The inputs are not modified.

    '''
    changes = []

    # Use the argument, not the notebook-level all_reps
    for rep in reps:

        shares = rep.drop('China',axis=0) if drop_china else rep

        change = 1-(shares.iloc[:,1]/shares.iloc[:,0])

        changes.append(change.rename('change'))

    return(pd.DataFrame(changes,index=['All arXiv','All AI','SoTA AI']).T)

def share_all(df,countries,years,n):
    '''
    Compares the share of activity of a set of countries before and after two threshold years.

    Unlike share_comp, this works with data with one country per row ('institute_country').

    Args:
        df (df) with a 'year' and an 'institute_country' column
        countries (list) countries we want to focus on
        years (list) of threshold years: the first period is strictly before years[0] and the
            second period is strictly after years[1]
        n (str) name of the dataset, used in the column names

    Returns a df with the countries in the rows and one column of shares per period. The
    shares are over all the rows of the period, not only over the countries selected.

    '''
    country = 'institute_country'

    #Share of every country in the rows of each period
    before = df.loc[df['year']<years[0],country].value_counts(normalize=True)
    after = df.loc[df['year']>years[1],country].value_counts(normalize=True)

    act_share = pd.concat([before,after],axis=1)
    act_share.columns = [f'{n} before {years[0]}',f' {n} after {years[1]}']

    return act_share.loc[countries]
    
    

In [None]:
def flatten_freq(nested_list):
    '''
    Counts how many times every element appears in a nested list.

    Args:
        nested_list: an iterable of iterables (it is flattened with flatten_list)

    Returns a series with the count of every element, with the most frequent first.

    '''
    flat = pd.Series(flatten_list(nested_list))

    return flat.value_counts()

def share_comp(df,countries,years,n):
    '''
    Compares the share of papers with presence of each country before and after two threshold years.

    Arguments:
        df (df) is the usual dataframe, with a 'year' column and a 'country_list' column that
            has the collection of countries of every paper
        countries (list) is the list of top countries
        years (list) are the thresholds: period 1 is strictly before years[0] and period 2 is
            strictly after years[1]
        n (str) is the name of the dataset, used in the column names

    Returns a df with the countries in the rows and one column of shares per period. A paper
    counts for all its countries, so the shares do not need to add up to 1.

    '''
    period_shares = []

    for papers in (df.loc[df['year']<years[0]],df.loc[df['year']>years[1]]):

        total = len(papers)

        #Papers of the period where the country is in the list, over all papers of the period
        shares = [sum(1 for c_list in papers['country_list'] if c in c_list)/total for c in countries]

        period_shares.append(pd.Series(shares,index=list(countries)))

    ai_share = pd.concat(period_shares,axis=1)

    ai_share.columns = [f'{n} before {years[0]}',f' {n} after {years[1]}']

    return ai_share


def make_comp_plot_2(comp_list,ax):
    '''
    Creates a plot comparing the shares of activity of countries in two periods, for three datasets.

    Args:
        comp_list is a list of three dfs with the same countries in the rows, the share in the
            first period in the first column and the share in the second period in the second
        ax is the matplotlib axis to draw on

    For every country and dataset, a circle marks the first period, a triangle marks the second
    (pointing up if the share grew, down otherwise) and a line joins them. Shares are shown as
    percentages. The three datasets are drawn side by side inside the slot of each country.

    Returns nothing.

    '''

    s=70

    #Colours of the first and second period for each of the three datasets
    colors = [['mistyrose','lightblue'],['salmon','cornflowerblue'],['red','blue']]

    for offs,element,color in zip([-0.2,0,0.2],comp_list,colors):

        positions = [x+offs for x in np.arange(0,len(element))]

        #First period
        ax.scatter(positions,100*element.iloc[:,0],color=color[0],marker='o',
                   edgecolors='darkgrey',
                   s=s)

        for n,x in enumerate(positions):

            #Rows and values are taken by position: using [] with integers on a series
            #indexed by country name is deprecated in pandas
            row = element.iloc[n]
            grew = row.iloc[1]>row.iloc[0]

            #Second period
            ax.scatter(x,100*row.iloc[1],
                       color=color[1],
                       marker='^' if grew else 'v',
                       edgecolors='darkgrey',
                       s=s+20)

            #Line between both periods, with the colour of the period with the highest share
            ax.vlines(x=x,ymin=100*row.min(),ymax=100*row.max(),
                      color=color[1] if grew else color[0],
                      linewidth=1)


    n_countries = len(comp_list[0])

    ax.set_xticks(np.arange(0,n_countries))
    ax.set_xticklabels(comp_list[0].index,rotation=90)

    #Dotted dividers on both sides of every country (each drawn once). They go up to 40%.
    ax.vlines(x=np.arange(0,n_countries+1)-0.5,ymin=0,
              ymax=40,color='darkgrey',linestyle=':')

    ax.set_ylabel('% of all activity with presence')

## 1. Load data

`analysis_pack` contains the metadata and data that we serialised at the end of the `06` data integration notebook.

This includes:

* Community names for the communities (`index->community name`)
* Community indices for topics (`topic -> community index`)
* Filtered topic names (`topic names`)
* Network object with topic co-occurrences
* Analysis df
* arx is the enriched arXiv dataset



In [None]:
# Load the analysis pack serialised at the end of the data integration notebook.
# NOTE(review): unpickling runs code stored in the file, so only load pickles created by this project.
with open('../data/processed/24_8_2019_analysis_pack.p','rb') as infile:
    analysis_pack = pickle.load(infile)

In [None]:
# Unpack the analysis pack by position. The descriptions follow the list in the text above;
# NOTE(review): confirm the order against the notebook that writes the pack.
comm_names = analysis_pack[0]  # community index -> community name
comms = analysis_pack[1]  # topic -> community index
topics = analysis_pack[2]  # filtered topic names
network = analysis_pack[3]  # network object with topic co-occurrences
data = analysis_pack[4]  # analysis df
arx = analysis_pack[5]  # enriched arXiv dataset

In [None]:
# Number of papers where 'has_female' is not missing
len(arx['has_female'].dropna())

In [None]:
#We load this to consider overall research trends.
#The file is zip-compressed and article ids are read as strings.
arx_geo = pd.read_csv('../data/external/17_8_2019_papers_institution_ucl_cleaned.csv',compression='zip',dtype={'article_id':str})

## 2. Analysis


### a. Trends

In [None]:
# Make sure that years are integers, so they can be compared and selected as numbers below
arx['year'] = [int(x) for x in arx['year']]

#### i. Total activity

In [None]:
# Papers by year and AI status, as percentages.
# normalize=1 normalises every column: each value is the share of the papers in a category
# (AI / not AI) that came out in that year, so every column adds up to 100.
fig_1_data = (100*pd.crosstab(arx['year'],arx['is_ai'],normalize=1))

ax = fig_1_data.plot(figsize=(8,5),linewidth=3)

ax.legend(title='AI paper',labels=['Not AI','AI'])

ax.set_ylabel('% of all papers in year')
#ax.set_title('AI research trends (overall)')

save_fig('fig_1_trends_overall.pdf')

In [None]:
# Cumulative research trends
# Rows are reversed before accumulating, so the first five rows of the result are the
# cumulative shares going back from the last year in the table

fig_1_data.iloc[::-1].cumsum()[:5]

Three quarters of the AI papers in the data have been written in the last 5 years.

#### ii. Activity by field

In [None]:
#These are the field names (columns of arx, compared with numeric thresholds below)
field_names = ['field_astrophysics',
 'field_biological',
 'field_complex_systems',
 'field_informatics',
 'field_machine_learning_data',
 'field_materials_quantum',
 'field_mathematical_physics',
 'field_mathematics_1',
 'field_mathematics_2',
 'field_optimisation',
 'field_particle_physics',
 'field_physics_education',
 'field_societal',
 'field_statistics_probability']

#Create tidy field names for legend etc: drop the 'field_' prefix (6 characters),
#replace underscores with spaces and capitalise
tidy_field_lookup = {x:re.sub('_',' ',x[6:]).capitalize() for x in field_names}

In [None]:
#We will not plot maths as we have two categories.
#This removes every field with a '1' or a '2' anywhere in its name.
fields_to_plot = [x for x in field_names if not any(num in x for num in ['1','2'])]

In [None]:
#AI in fields
#For every field we keep the papers with a field weight above 0.5 and calculate, for each
#year (normalize=0 normalises the rows), the share of those papers that are AI (column 1)
ai_in_fields = pd.concat([pd.crosstab(arx.loc[arx[t]>0.5]['year'],
                                     arx.loc[arx[t]>0.5]['is_ai'],normalize=0)[1] for t in fields_to_plot],axis=1).fillna(0)

ai_in_fields.columns = fields_to_plot

#Sort top fields (for the legend): the 9 fields with the highest AI share in 2018
top_ai_fields = ai_in_fields.loc[2018].sort_values().index[::-1][:9]

#All the fields, sorted by their AI share in 2018 (highest first)
top_ai_fields_all = ai_in_fields.loc[2018].sort_values().index[::-1]

In [None]:
#AI in fields share of activity: one panel per field

fig,ax = plt.subplots(nrows=6,ncols=2,figsize=(6,10),
                      sharey='row',
                      sharex=True)

#Line style of the non-AI (0) and AI (1) series
line_styles = [(0,{'linewidth':1.5,'linestyle':'--'}),(1,{'linewidth':3})]

for n,f in enumerate(top_ai_fields_all):

    #Panels are filled row by row: even positions on the left, odd positions on the right
    panel = ax[n//2,n%2]

    #Papers in the field
    rel = arx.loc[arx[f]>0.5]

    #Yearly distribution of the AI and non-AI papers in the field (normalised by column)
    year_act = pd.crosstab(rel['year'],rel['is_ai'],normalize=1).loc[np.arange(2000,2019)]

    for is_ai,style in line_styles:
        (100*year_act[is_ai].rolling(window=3).mean()).dropna().plot(ax=panel,legend=False,**style)

    panel.set_title(tidy_field_lookup[f])

plt.tight_layout()

ax[0,0].legend(labels=['Not AI','AI'])

save_fig('fig_2_trend_fields.pdf')
    

In [None]:
# AI and non-AI papers as a percentage of the papers in each top field (every column adds up to 100)
100*pd.crosstab(arx['is_ai'],arx['top_field'],normalize=1)

In [None]:
# ax = (100*ai_in_fields.loc[np.arange(2000,2019),top_ai_fields].rolling(window=3).mean()).dropna().plot(figsize=(10,6),cmap='tab10',linewidth=3)

# ax.legend(bbox_to_anchor=(1,1),title='Scientific field',labels=list(map(lambda x: tidy_field_lookup[x],top_ai_fields)))

# ax.set_title('AI intensity by scientific field')

# ax.set_ylabel('AI as % of papers in topic')
# ax.set_xlabel('')

# save_fig('fig_2_trends_field.pdf')

In [None]:
def get_example(df,number,length):
    '''
    Prints the beginning of the abstract of a random sample of papers.

    Args:
        df is the dataframe to sample from (it needs an 'abstract' column)
        number is how many papers to sample (without replacement, from the index)
        length is the number of characters of each abstract to print

    Returns nothing. Every abstract is followed by blank lines.
    '''
    picked = random.sample(list(df.index),number)

    for text in df.loc[picked]['abstract']:
        # Same output as printing the text and then printing '\n'
        print(text[:length],'\n',sep='\n')
    

In [None]:
# for x in ['field_astrophysics','field_biological','field_complex_systems','field_materials_quantum','field_societal']:
    
#     print(x)
#     print('====')
    
#     d = arx.loc[(arx['is_ai']==True) & (arx[x]>0.75)].reset_index(drop=True)
    
#     get_example(d,5,1000)

#### iii. Activity by topic community

In [None]:
# Create community names.
# We remove 'mixed' since this is an excluded community.
# The names are sorted so that the order (and the order of the columns built from them)
# is the same in every run; taking a list from a set does not guarantee that.
community_names = sorted(set(comm_names.values())-{'mixed'})

In [None]:
#Plot some of the topics: these are the communities we highlight in the chart
topics_for_plot = ['symbolic','statistics',
                   'deep_learning','computer_vision','robotics_agents','language']

#Tidy labels for all the communities (used in legends)
tidy_comm_names = make_tidy_lookup(community_names)

In [None]:
# This is to normalise the years
# Papers per year with presence (weight above 0.05) of each community
comm_trends = trend_analysis(data,community_names,thres=0.05)
# Papers per year in the data
all_years = data['year'].value_counts()
# Divide every community by the papers of the year (values are matched by year);
# years that are missing on either side are dropped
comm_norm = comm_trends.apply(lambda x: x/all_years).dropna()

In [None]:
#Sort the communities for the legend, from most to least papers in 2018
sorted_comms = [x for x in comm_trends.loc[2018].sort_values(ascending=False).index]

In [None]:
# Yearly share of papers with presence of each community, highlighting the selected ones
fig,ax = plt.subplots(figsize=(10,6))

make_highlight_plot(comm_norm[sorted_comms],
                    topics_for_plot,cmap='Dark2_r',ax=ax,alpha=0.15,lab_map=tidy_comm_names)

#ax.set_title('Evolution of activity by topic_community')

# make_highlight_plot multiplies the shares by 100, so the axis is in percentages
ax.set_ylabel('% of all AI papers with topic presence')

plt.tight_layout()

save_fig('fig_3_topic_community.pdf')

#### iv. Activity by detailed topic

In [None]:
# Detailed topics that we highlight in the topic trends chart. They are also used below to
# select the papers labelled as 'SotA AI' in the geographical analysis.
# The names have to match the topic columns of the analysis df exactly.
notable_topics = [
    'reinforcement_learning-policy-policies-reward-deep_reinforcement_learning',
    'cnn-convolutional_neural_networks-cnns-convolutional_neural_network-convolutional_neural_network_cnn',
    'training-trained-deep_learning-deep-train',
    'generator-gan-discriminator-generative_adversarial_networks_gans-gans',
    'translation-neural_machine_translation-machine_translation-translate-translations',
    'recurrent-lstm-rnn-recurrent_neural_network-recurrent_neural_networks']

In [None]:
# Papers per year with presence (weight above 0.05) of each detailed topic
topic_trends = trend_analysis(data,topics,thres=0.05)
# Papers per year in the data
all_years = data['year'].value_counts()
# Share of the papers of every year with presence of each topic (values are matched by year)
topic_trends_norm = topic_trends.apply(lambda x: x/all_years).dropna()

In [None]:
#Tedious sorting of topics: by their 3-year rolling mean share in 2018, highest first
sorted_topics = topic_trends_norm.rolling(window=3).mean().loc[2018].sort_values(ascending=False).index
#The notable topics in that same order
notable_sorted = [x for x in sorted_topics if x in notable_topics]

In [None]:
# Yearly share of papers with presence of each detailed topic (2005-2018), highlighting the notable ones
fig,ax = plt.subplots(figsize=(14,6))

make_highlight_plot(topic_trends_norm.loc[np.arange(2005,2019),sorted_topics],notable_sorted,cmap='Dark2',ax=ax,alpha=0.1,lab_map=False)

#ax.set_title('Evolution of activity by detailed topic')

# make_highlight_plot multiplies the shares by 100, so the axis is in percentages
# (same wording as the community chart)
ax.set_ylabel('% of AI papers with topic')

plt.tight_layout()

save_fig('fig_4_trending_topics.pdf')

### b. Networks

We will combine a plot of network structure and centrality




In [None]:
#We want to make the size of the nodes comparable between years.
#For every period, count the papers where each topic has a weight above 0.05
year_sets = {'all':set(np.arange(1990,2019)),
             'pre':set(np.arange(1990,2012)),
             'mid':set(np.arange(2012,2015)),
             'late':set(np.arange(2015,2019))}

size_lookup = pd.concat([(data.loc[data['year'].isin(year_set),topics]>0.05).sum() for
                         year_set in year_sets.values()],axis=1)

size_lookup.columns = list(year_sets.keys())

size_lookup_dict = size_lookup.to_dict()

In [None]:
# Colour of every topic community in the networks and centrality charts.
# Communities that are not listed here get a fallback colour in the plotting functions.
color_lookup = {
    'deep_learning':'blue',
    'robotics_agents':'cornflowerblue',
    'computer_vision':'aqua',
    'symbolic':'red',
    'health':'lime',
    'social':'forestgreen',
    'technology':'magenta',
    'statistics':'orange',
    'language':'yellow'
}

In [None]:
# Legend entries: one coloured patch per community, with its tidy name as the label
patches = [mpatches.Patch(facecolor=c, label=tidy_comm_names[l],edgecolor='black') for l,c in color_lookup.items()]


#### Main network

In [None]:
fig,ax = plt.subplots(figsize=(12,8))

#Show the network: top 5% of the edges of the network in the analysis pack, with node sizes
#based on the activity in all years
show_network(ax,network,0.05,norm=200,norm_2=0.9,color_lookup=color_lookup,size_lookup=size_lookup['all'],
             layout=nx.kamada_kawai_layout,label='All years',loc=(-0.5,1.48),ec='black',alpha=0.7)

#Draw the legend
ax.legend(handles=patches,facecolor='white',loc='upper right',title='Area')

#Remove ticks
ax.set_xticks([])
ax.set_yticks([])

plt.tight_layout()

save_fig('fig_5_network_all_years.png')


In [None]:
# color_lookup_2 = {
#     'deep_learning':'blue',
#     #'robotics_agents':'cornflowerblue',
#     'computer_vision':'aqua',
#     'symbolic':'red',
#     'statistics':'orange',
#     #'language':'yellow'
# }

In [None]:
# Legend entries (same definition as before the main network): one coloured patch per community
patches = [mpatches.Patch(facecolor=c, label=tidy_comm_names[l],edgecolor='black') for l,c in color_lookup.items()]

In [None]:
#Create the integrated network - centrality plot
fig,ax = plt.subplots(figsize=(14,8),nrows=2,ncols=2,gridspec_kw={'width_ratios':[1.5,2]})

#Subset the old period. It goes up to 2011, in line with the 'Before 2012' label and with
#the years behind size_lookup['pre'] (the previous filter, year<2011, left 2011 out)
old_period = data.loc[data['year']<2012][topics]

#Make the network
make_time_net(ax[0][0],old_period,size_lookup['pre'],my_label='Before 2012')

#Plot the centrality
plot_centrality(make_network_from_doc_term_matrix(old_period,0.025,'paper_id'),
                nx.eigenvector_centrality,cl=color_lookup,ax=ax[0][1],plot_name='Before 2012')

#Same as above but with the more modern data.
#NOTE(review): size_lookup['late'] also counts papers from 2015, which this subset excludes
late_period = data.loc[(data['year']>2015)][topics]

#Make networks
make_time_net(ax[1][0],late_period,size_lookup['late'],my_label='After 2015')

#Plot centrality
plot_centrality(
    make_network_from_doc_term_matrix(late_period,0.025,'paper_id'),nx.eigenvector_centrality,cl=color_lookup,ax=ax[1][1],plot_name='After 2015')

plt.tight_layout()

#save_fig('fig_6_network_comp.png')


In [None]:
# #Create the integrated network - centrality plot
# fig,ax = plt.subplots(figsize=(14,6),nrows=2,ncols=2,gridspec_kw={'width_ratios':[1.5,2]})

# #Subset the old period
# old_period = data.loc[data['year']<2011][topics]

# #Make the network
# make_time_net(ax[0][0],old_period,size_lookup['pre'],my_label='Before 2012')

# #Plot the centrality
# plot_centrality(make_network_from_doc_term_matrix(old_period,0.025,'paper_id'),
#                 nx.eigenvector_centrality,cl=color_lookup,ax=ax[0][1],plot_name='Before 2012')

# #Same as above but with the more modern data
# late_period = data.loc[(data['year']>2015)][topics]

# #Make networks
# make_time_net(ax[1][0],late_period,size_lookup['late'],my_label='After 2015')

# #Plot centrality
# plot_centrality(
#     make_network_from_doc_term_matrix(late_period,0.025,'paper_id'),nx.eigenvector_centrality,cl=color_lookup,ax=ax[1][1],plot_name='After 2015')

# plt.tight_layout()

# #save_fig('fig_6_network_comp.png')

# save_fig('neurips_network_comp.png')


### 3. Thematic disruption

Our final descriptive analysis considers disruption over time: what have been the changes in the composition of AI since the 2000s?

We create a matrix that compares the topic vector for every year (a normalised sum) across years.

In [None]:
#We use the function that we defined above.
#disr[0] is the table of similarities between years and disr[1] is the rolling mean series
disr = make_disruption_tables(data)

In [None]:
#We plot the results, which show quite starkly the disruption in AI research before and after 2012.

#Two stacked charts: the heatmap on top (taller) and the rolling mean below
fig,ax = plt.subplots(figsize=(10,8),nrows=2,gridspec_kw={'height_ratios':[3,1.2]})

make_disruption_plot(disr,ax=ax)

#No vertical gap between both charts
plt.subplots_adjust(hspace=0)

#Saved directly instead of with save_fig, which calls tight_layout first
#(presumably to keep the spacing set above - TODO confirm)
plt.savefig(f'../reports/figures/paper_rev/{today_str}_fig_7_trends.pdf')

Can we calculate the half life of similarity?

### 4. Spatial disruption

Here we want to calculate disruption measures for the top 10 countries by activity

* Changes in the share of the total by country and in notable topics

* Measures of disruptions like above


In [None]:
#We merge these two so we can look at geographical activity by year.
#Only articles that are in both tables are kept (default inner join on 'article_id')
arx_geo_year = pd.merge(arx_geo,arx[['article_id','year']],left_on='article_id',right_on='article_id')

In [None]:
#Focus on data with country information
data_w_countries = data.dropna(axis=0,subset=['country_list'])

#Focus on top countries - we ignore #1 because it is 'multinational'
country_freqs = flatten_freq(data_w_countries['country_list'])
top_countries = list(country_freqs.index[1:11])

#Papers with presence (weight above 0.05) of at least one of the notable topics
has_notable_topic = (data_w_countries[notable_topics]>0.05).any(axis=1)
data_w_countries_core = data_w_countries.loc[has_notable_topic]


In [None]:
#Calculate national representation in AI and its components: all the AI papers with country
#information, and those with presence of the notable topics ('SotA AI').
#The periods are before 2012 and after 2015
national_rep = [share_comp(d,top_countries,[2012,2015],n=n) for d,n in zip([data_w_countries,data_w_countries_core],['All AI','SotA AI'])]

#Calculate representation in all activity 
national_rep_all = share_all(arx_geo_year,top_countries,[2012,2015],n='All arXiv')

#Combine them. The order (all arXiv, all AI, SotA AI) matters for the plots and tables below
all_reps =  [national_rep_all]+national_rep

In [None]:
#Patches for the legend.
#The colours are the ones that make_comp_plot_2 uses for each dataset and period
patches_2 = [mpatches.Patch(facecolor=c, label=l,edgecolor='black') for l,c in 
          zip(['All arXiv before 2012','All AI before 2012','All SotA before 2012',
               'All arXiv after 2015','All AI after 2015','All SotA after 2015'],
              ['mistyrose','salmon','red','lightblue','cornflowerblue','blue'])]

In [None]:
#Plot everything

In [None]:
# Shares of activity of the top countries before 2012 and after 2015, for the three datasets
fig,ax = plt.subplots(figsize=(10,5))

make_comp_plot_2(all_reps,ax)

ax.legend(handles=patches_2,ncol=2)

save_fig('fig_8_geo_changes.pdf')

In [None]:
#What is the variance in changes in representation (including China)?
changes = change_shares(all_reps)

#Variance across countries, for each dataset
changes.var()


In [None]:
#And excluding China?
changes_nc = change_shares(all_reps,drop_china=True)

#Variance across the remaining countries, for each dataset
changes_nc.var()

Conclusion: the geography of AI research is changing faster than the geography of research overall, especially in state-of-the-art topics.

In [None]:
# Ratio between two values typed in by hand.
# NOTE(review): presumably two of the variances calculated above - confirm which ones
0.18/0.05

In [None]:
# Third table in all_reps: the shares for the 'SotA AI' papers
all_reps[2]