# AI topic exploration

## Preamble

In [None]:
%run ../notebook_preamble.ipy

In [None]:
from cord19.transformers.nlp_2 import *
from toolz.curried import *
import altair as alt
from altair_saver import save
from selenium import webdriver

In [None]:
def preview(x):
    """Print the first rows and the shape of ``x``, then return ``x`` unchanged.

    Because the input is handed straight back, this can be used as a
    pass-through inspection step in a ``.pipe()`` chain.
    """
    print(x.head(), x.shape, sep="\n")
    return x


## 1. Load data

In [None]:
# Labelled xiv papers; preview() prints the head and shape and returns the frame.
xiv = preview(
    pd.read_csv(f"{project_dir}/data/processed/ai_research/xiv_papers_labelled.csv")
)

In [None]:
# Keep the rows flagged is_AI == 1. reset_index() keeps the previous index as a
# column and gives the subset a fresh positional index.
ai = preview(xiv.query("is_AI == 1").reset_index())

In [None]:
# Tidy paper/topic table; the analysis below reads its 'index' and 'cluster' columns.
cov = preview(
    pd.read_csv(f"{project_dir}/data/processed/ai_research/tidy_paper_topics_ai.csv")
)

## Processing

* Pre-process AI text
* Train topic model

In [None]:
# Clean the AI abstracts ahead of tokenisation:
# swap line breaks for spaces and trim the surrounding whitespace.
abstr_clean = list(
    map(lambda text: re.sub("\n", " ", text).strip(), ai['abstract'])
)

In [None]:
# Wrap the cleaned abstracts in CleanTokenize (not defined in this notebook;
# presumably provided by the cord19.transformers.nlp_2 star import — confirm).
ct = CleanTokenize(abstr_clean)
# clean() followed by two bigram() passes. Presumably the second pass lets
# already-merged bigrams combine into longer phrases — TODO confirm in nlp_2.
ct.clean().bigram().bigram()

In [None]:
#Train a topic model with eg 100 topics
# This cell only builds the pipeline from the tokenised abstracts; the number
# of topics is passed to fit_lda() when the model is fitted.

lda = LdaPipeline(ct.tokenised)

In [None]:
# Run the pipeline's filter() and process() steps, then fit the LDA model
# with 100 topics.
lda.filter().process().fit_lda(num_topics=100)

In [None]:
# Display the fitted topics. The topic-naming cell reads the second element of
# each entry as a '+'-separated string of weight*"word" terms.
lda.lda_topics

In [None]:
# Predict topics for the documents; presumably this populates lda.predicted_df,
# which is copied into topic_df further down — confirm in LdaPipeline.
lda.predict_topics()

In [None]:
# Number of leading terms used to build each topic's label.
num_words = 5

# The second element of every lda_topics entry is split on "+" into its terms.
topic_word_mixes = [topic[1].split("+") for topic in lda.lda_topics]

# For each term keep the part after "*", trim it and drop the double quotes;
# the first `num_words` cleaned terms are joined with underscores.
topic_names = [
    "_".join(
        [term.split("*")[1].strip().replace('"', '') for term in mix][:num_words]
    )
    for mix in topic_word_mixes
]

In [None]:
# Lookup from paper id to its year.
ai_year_map = dict(zip(ai['id'], ai['year']))

In [None]:
# Take a copy of the predictions so that the renaming and extra columns added
# below are applied to topic_df rather than to lda.predicted_df
# (presumably a DataFrame — confirm in LdaPipeline).
topic_df = lda.predicted_df.copy()

In [None]:
# Name the topic columns, then attach the paper identifiers.
# NOTE(review): assigning ai['id'] / ai['mag_id'] aligns on the index, so this
# assumes topic_df's rows carry the same index, in the same order, as ai — confirm.
topic_df.columns = topic_names
topic_df['id'] = ai['id']
topic_df['mag_id'] = ai['mag_id']

# Year for each paper, looked up through the id -> year mapping.
topic_df['year'] = topic_df['id'].map(ai_year_map)

In [None]:
# Reshape to long format: one row per (paper, topic) pair with its weight.
topic_long = preview(
    topic_df.melt(
        id_vars=['id', 'mag_id', 'year'],
        var_name='topic',
        value_name='weight',
    )
)

In [None]:
# Save the long topic table.
# NOTE(review): index_label=False only suppresses the header field for the
# index; the index column itself is still written. If the intent was to omit
# the index, index=False is the usual argument — confirm before changing, since
# readers of this file may rely on the current layout.
topic_long.to_csv(f"{project_dir}/data/processed/ai_research/ai_topics.csv",index_label=False)

## 2. Analysis

In [None]:
#Extra labels for analysis
# The 'index' column of cov is matched against the paper ids in topic_long.
cov_ids = set(cov['index'])
cov_lookup = cov.set_index('index')['cluster'].to_dict()

# Flag the papers that appear in cov. Series.isin does the membership test in
# one vectorised call instead of running a Python lambda for every row.
topic_long['is_covid'] = topic_long['id'].isin(cov_ids)
# Cluster taken from cov for matched papers; ids without a match get NaN.
topic_long['cluster'] = topic_long['id'].map(cov_lookup)

In [None]:
# Minimum weight for a (paper, topic) row to be kept.
thres = 0.01

# Restrict to recent papers (year after 2019).
topic_long_recent = topic_long[topic_long['year'] > 2019]

# Number of distinct recent papers in each is_covid group; these are the
# denominators used for normalisation further down.
totals = (
    topic_long_recent
    .drop_duplicates('id')['is_covid']
    .value_counts()
)

# Keep the rows above the threshold, then sum their weights per topic and
# is_covid group.
# NOTE(review): this sums weights rather than counting papers, while the chart
# below titles the normalised value '% of papers with topic' — confirm intent.
papers_with_topic = preview(
    topic_long_recent[topic_long_recent['weight'] > thres].reset_index(drop=True)
)
topic_distr = (
    papers_with_topic
    .groupby(['topic', 'is_covid'])['weight']
    .sum()
    .reset_index()
)

In [None]:
# Wide table: one row per topic, one column per is_covid value, ordered by the
# covid (True) column with the largest values first.
topic_distr_wide = preview(
    topic_distr
    .pivot_table(index='topic', columns='is_covid', values='weight')
    .sort_values(by=True, ascending=False)
)

# Divide each column by the number of papers in that group, scale by 100 and
# go back to long format for plotting.
top_distr_norm = preview(
    topic_distr_wide
    .div(totals, axis='columns')
    .mul(100)
    .reset_index()
    .melt(id_vars='topic')
)

In [None]:
# Rank topics by the absolute gap between the normalised covid (True) and
# non-covid (False) values.
top_deltas = preview(
    topic_distr_wide
    .div(totals, axis='columns')
    .assign(delta=lambda shares: (shares[True] - shares[False]).abs())
    .sort_values(by='delta', ascending=False)
)

# Larger of the two group values per topic, scaled by 100. The slice leaves out
# the last column, which is the delta added above.
top_deltas['max'] = 100 * preview(top_deltas.iloc[:, :-1].max(axis=1))

# The ten topics with the largest gaps.
top_differences = top_deltas.index[:10]

# Attach each topic's maximum to the plotting table; it positions the text labels.
top_distr_norm['max'] = top_distr_norm['topic'].map(top_deltas['max'].to_dict())

In [None]:
#Create the chart
# Order the x axis by topic name, following the row order of top_distr_norm.
# The sort list has to contain values of the encoded 'topic' field (without
# repeats); the DataFrame's integer row labels would not match any topic.
topic_order = top_distr_norm['topic'].unique().tolist()

base = alt.Chart(top_distr_norm).encode(x=alt.X('topic',sort=topic_order,
                                               axis=alt.Axis(labels=False)),tooltip=['topic'])

# One point per topic and group, coloured by is_covid
p = base.mark_point(filled=True).encode(y='value',color='is_covid:N')

# Thin dashed line per topic (detail='topic') joining that topic's points
c = base.mark_line(strokeWidth=1,color='darkgrey',strokeDash=[1,1]).encode(y='value',detail='topic')

# Text labels only for the topics in top_differences, placed at the topic's 'max' value
t = (base
     .transform_filter(alt.FieldOneOfPredicate('topic',list(top_differences)))
     .mark_text(align='left',fontSize=10,angle=0,xOffset=2,yOffset=-5,color='black',opacity=0.8)
     .encode(text='topic',y=alt.Y('max',title='% of papers with topic')))

out = (p + c + t).properties(width=400,height=450)

# NOTE(review): DRIVER is not defined in the visible cells; presumably it comes
# from notebook_preamble.ipy — confirm.
save(out,"test.png",method='selenium',
         webdriver=DRIVER,scale_factor=2)

out

## Statistical tests

In [None]:
from scipy.stats import ttest_ind

In [None]:
ttest_res = []

# Loop over the topic names and test the difference in mean weight between
# covid and non-covid papers. Results are appended in topic_names order so they
# can later be zipped back against topic_names.
for topic in topic_names:

    # Vectorised comparison selects the rows belonging to this topic
    df_in_topic = topic_long_recent.loc[topic_long_recent['topic'] == topic]

    # The covid sample is passed first, so a positive statistic means the
    # covid group has the higher mean
    ttest = ttest_ind(df_in_topic.query('is_covid == True')['weight'],
                      df_in_topic.query('is_covid == False')['weight'])

    ttest_res.append(ttest)

In [None]:
# One row per topic, labelled with the topic name: the absolute t-statistic,
# the p-value, and which group the sign of the statistic points to
# (positive -> 'covid', otherwise 'non_covid').
rest = pd.DataFrame([
    pd.Series(
        {
            't_stat_abs': abs(result[0]),
            'p_val': result[1],
            'higher_group': 'covid' if result[0] > 0 else 'non_covid',
        },
        name=topic,
    )
    for topic, result in zip(topic_names, ttest_res)
])

In [None]:
# Take the 30 rows with the largest absolute t-statistics and melt to long
# format; the former index (the topic names) becomes the 'index' column.
rest_sort = (
    rest
    .sort_values(by='t_stat_abs', ascending=False)
    .reset_index()
    .head(30)
    .melt(id_vars=['index', 'higher_group'])
)

# Bars of the absolute t-statistic only (p-value rows are filtered out),
# ordered by value descending and coloured by higher_group.
(
    alt.Chart(rest_sort)
    .transform_filter(alt.datum.variable == 't_stat_abs')
    .mark_bar()
    .encode(
        y=alt.Y('index', sort=alt.EncodingSortField('value', order='descending')),
        x=alt.X('value', title='t-statistic (absolute)'),
        color='higher_group',
    )
    .properties(width=200, height=400)
)

In [None]:
# Show the column names of the xiv table.
xiv.columns