# Section 2: AI topical analysis

Here we analyse the cluster membership of AI papers.

## Preamble

In [None]:
%run ../notebook_preamble.ipy

In [None]:
import re
import altair as alt
import random
from toolz.curried import *
from ai_covid_19.utils.utils import *
from ai_covid_19.hSBM_Topicmodel.sbmtm import sbmtm
from ai_covid_19.estimators.post_process_topsbm import *


In [None]:
# Set up altair; the returned object is passed as `driver=` to save_altair in save_altair_ below.
driv = altair_visualisation_setup()

In [None]:
import os

# Output directory for the figures and tables produced by this notebook.
FIG_PATH = f"{project_dir}/reports/figures/mohg_figures"
# Create it up front: DataFrame.to_csv (used for table_1.csv at the end) does not create missing directories.
os.makedirs(FIG_PATH, exist_ok=True)

In [None]:
def save_altair_(fig, name):
    """Save an altair chart under FIG_PATH with the shared driver `driv`.

    Thin wrapper around utils' ``save_altair`` so that chart cells only need
    to pass the chart object and a file name.

    Args:
        fig: the altair chart to save.
        name: file name (without directory) passed through to save_altair.
    """
    save_altair(fig, name, driver=driv, path=FIG_PATH)

## 1. Read data

### rXiv metadata

In [None]:
# Paper metadata: ids are read as strings and the AI / COVID-19 flags as booleans.
rxiv = pd.read_csv(
    f"{data_path}/processed/rxiv_papers_update.csv",
    dtype={'id': str, 'is_ai': bool, 'is_covid': bool},
)

# Ids of COVID-19 papers, used below to flag COVID-19 rows in the topic table.
is_cov = set(rxiv.loc[rxiv['is_covid'], 'id'])

In [None]:
#Load the fitted topic model for AI papers.
#NOTE(review): pickle.load can execute arbitrary code; only load model files produced by this project.
with open(f"{project_dir}/models/top_sbm/top_sbm_ai.p", 'rb') as infile:
    model = pickle.load(infile)

post = post_process_model(model, top_level=0, cl_level=1, top_thres=0.4)

#Long format: one row per (article, topic) with the topic weight and the article's cluster.
#Assumes post[0] has an unnamed article-id index, a 'cluster' column and one column per topic.
topics = (post[0]
          .reset_index()
          .melt(id_vars=['index', 'cluster'], var_name='topic', value_name='weight')
          .rename(columns={'index': 'article_id'})
          .assign(is_cov=lambda frame: frame['article_id'].isin(is_cov)))

In [None]:
#Add the topic-model cluster label to each AI paper.
#.copy() makes `ai` an independent frame, so the column assignment below does not
#raise SettingWithCopyWarning (it would otherwise write into a query() result).
ai = rxiv.query('is_ai == True').copy()

#Each article's cluster is repeated on every one of its topic rows, so keep one row per article.
ai['cluster'] = ai['id'].map(
    topics.drop_duplicates('article_id').set_index('article_id')['cluster'])

In [None]:
#The model was only trained on 2019 & 2020 papers; AI papers outside that window have
#no cluster label, so keep only the labelled ones.
ai_rec = ai.loc[ai['cluster'].notna()]

## 2. Analyse data

### 1. Cluster content

We look for salient topics in each cluster in order to interpret what the cluster is about.

#### Distribution of AI papers over clusters

In [None]:
#Percentage of COVID-19 and of non-COVID-19 AI papers that fall in each cluster
#(normalised within each category), ordered by share of the COVID-19 papers.
cluster_cov = 100 * pd.crosstab(ai_rec['cluster'], ai_rec['is_covid'], normalize='columns')
cluster_cov = cluster_cov.sort_values(True, ascending=False)

#Cleaned cluster names in that order; used to sort the x axis of the charts
bar_order = clean_cluster(cluster_cov.index)

#Long format: one row per (cluster, is_covid) with the percentage in 'value'
cluster_distr = cluster_cov.reset_index().melt(id_vars='cluster')

In [None]:
#Human-readable labels for the chart, using helper functions from utils;
#the COVID-19 flag becomes the 'Category' column.
cluster_distr = (cluster_distr
                 .assign(cluster=lambda frame: clean_cluster(frame['cluster']),
                         is_covid=lambda frame: convert_covid(frame['is_covid']),
                         value_label=lambda frame: make_pc(frame['value']))
                 .rename(columns={'is_covid': 'Category'}))

#### Salient topics in AI papers

In [None]:
#A topic counts as present in a paper when its weight exceeds w.
w = 0.1
salient_rows = topics.loc[topics['weight'] > w]

#Number of papers per (topic, cluster) in which the topic is present
topic_count = salient_rows.groupby(['topic', 'cluster']).size().reset_index(name='count')

#Readable labels: cleaned cluster names; topics as "Word1, Word2, ..." with hyphens removed
topic_count['cluster'] = clean_cluster(topic_count['cluster'])
topic_count['topic'] = [
    ', '.join(word.capitalize().replace('-', '') for word in raw_topic.split('_'))
    for raw_topic in topic_count['topic']]

#Topics ranked by the ratio of COVID-19 to non-COVID-19 papers in which they are present
#(topics never present in non-COVID-19 papers get an infinite ratio and rank first)
presence_by_cov = (salient_rows.groupby(['topic', 'is_cov']).size()
                   .reset_index(name='count')
                   .pivot_table(index='topic', columns='is_cov', values='count')
                   .fillna(0))
presence_by_cov = presence_by_cov.assign(share=presence_by_cov[True] / presence_by_cov[False])
covid_topics = clean_topics(list(presence_by_cov.sort_values('share', ascending=False).index))

#### Chart

In [None]:
#Top panel: % of COVID-19 / non-COVID-19 AI papers in each cluster (overlaid, not stacked)
bar_b = alt.Chart(cluster_distr).mark_bar(
    opacity=0.5, width=5, stroke='black', strokeWidth=1
).encode(
    x=alt.X('cluster', sort=list(bar_order),
            axis=alt.Axis(labels=False, title="", ticks=False)),
    y=alt.Y('value', title=['% of category', 'in cluster'], stack=False),
    color='Category:N',
    tooltip=['value_label:N', 'Category:N'],
)
bar = bar_b.properties(height=100)

#Bottom panel: papers per cluster in which each topic is present, restricted to the
#40 topics most over-represented in COVID-19 papers (the first 40 of covid_topics)
hm_b = alt.Chart(topic_count).transform_filter(
    alt.FieldOneOfPredicate('topic', covid_topics[:40]))
hm = hm_b.mark_rect(stroke='black').encode(
    x=alt.X('cluster', sort=list(bar_order)),
    y=alt.Y('topic', sort=list(covid_topics), title='Salient terms in topic'),
    color=alt.Color('count:Q', title=['Number of papers', 'with topic']),
    tooltip=['topic', 'cluster'],
)

#Stack the two panels so they share the cluster ordering on the x axis
topic_chart = alt.vconcat(
    bar.properties(width=600),
    hm.properties(height=600, width=600),
    spacing=0,
).configure_axisX(grid=True)

save_altair_(topic_chart, 'fig_4_topics')

topic_chart

### 2. Cluster provenance

What is the origin (in terms of article source) of the papers in each cluster?

In [None]:
#Number of AI papers per (article source, COVID-19 flag, cluster)
source_cluster_ai = (ai_rec
                     .groupby(['article_source', 'is_covid', 'cluster'])
                     .size()
                     .pipe(preview)
                     .reset_index(name='Number of Papers'))

In [None]:
#Human-readable copies of the grouping columns, used as chart fields below
source_cluster_ai['Source'] = convert_source(source_cluster_ai['article_source'])
source_cluster_ai['Category'] = convert_covid(source_cluster_ai['is_covid'])
source_cluster_ai['Cluster'] = clean_cluster(source_cluster_ai['cluster'])

In [None]:
#The count column is already named 'Number of Papers' by reset_index(name=...) when
#source_cluster_ai is built, so the former rename of 'paper_count' was a no-op.
#Just check that the field the chart below plots is present.
assert 'Number of Papers' in source_cluster_ai.columns

In [None]:
#Plot chart: papers per cluster coloured by source, one row per category (COVID-19 / not)
source_bar = alt.Chart(source_cluster_ai).mark_bar().encode(
    x=alt.X('Cluster:N', sort=list(bar_order)),
    y=alt.Y('Number of Papers:Q'),
    color='Source:N',
    tooltip=['Cluster', 'Category', 'Source', 'Number of Papers:Q'],
    row=alt.Row('Category', sort=['COVID-19', 'Not COVID-19']),
)
#Each category row gets its own y scale
source_bar = source_bar.resolve_scale(y='independent').properties(width=700, height=100)

save_altair_(source_bar, "fig_5_topic_sources")

source_bar

### 3. Export table with examples

In [None]:
#Export a table with, for each cluster, its most salient topics and example titles of
#COVID-19 and non-COVID-19 AI papers. Every paper in this notebook is an AI paper,
#so the examples are split by `is_covid` and drawn from `ai_rec`.

def _cluster_examples(cluster_name, n_topics=3, n_examples=2):
    """Build one row of the example table for a single cluster.

    Args:
        cluster_name: raw cluster label as stored in `topics['cluster']` (e.g. "cluster_0").
        n_topics: number of topics, ranked by mean weight within the cluster, to list.
        n_examples: value passed as `values=` to utils' get_examples for each group.

    Returns:
        dict with the cluster name, its top topics (newline-joined) and newline-joined
        example titles of COVID-19 and non-COVID-19 AI papers in the cluster.
    """
    rel = topics.loc[topics['cluster'] == cluster_name]

    #Topics with the highest mean weight among the cluster's papers
    sal_tops = '\n'.join(
        rel.groupby('topic')['weight'].mean().sort_values(ascending=False).index[:n_topics])

    #AI papers assigned to this cluster
    papers_in_cluster = ai_rec.loc[ai_rec['id'].isin(set(rel['article_id']))]

    #NOTE(review): confirm get_examples copes with fewer than n_examples titles
    #(e.g. a cluster without COVID-19 papers), and seed it if it samples at random.
    ex_cov, ex_non_cov = [
        '\n'.join(get_examples(
            list(papers_in_cluster.loc[papers_in_cluster['is_covid'] == val, 'title']),
            values=n_examples))
        for val in (True, False)]

    return {'cluster': cluster_name,
            'salient_topics': sal_tops,
            'example_covid_papers': ex_cov,
            'example_non_covid_papers': ex_non_cov}


#Raw cluster labels are assumed to be "cluster_0" ... "cluster_{k-1}", with k = len(bar_order)
example_table = pd.DataFrame(
    [_cluster_examples(f"cluster_{n}") for n in range(len(bar_order))])
example_table.to_csv(f"{FIG_PATH}/table_1.csv", mode='w')

In [None]:
# Preview the first rows of the exported example table
example_table.head(5)