# Section 3: Geographical analysis
 * Where is AI research happening?
 * Who is doing it?
 * Do we find any differences in the topics that different countries focus on?

## Preamble

In [None]:
%run ../notebook_preamble.ipy

In [None]:
import altair as alt
import random
from toolz.curried import *
from ast import literal_eval
from ai_covid_19.utils.utils import *


## 1. Read data

In [None]:
#Full arXiv metadata table: ids are read as strings and the AI / Covid flags as booleans
rxiv_dtypes = {'id':str,'is_ai':bool,'is_covid':bool,'mag_id':str}

#`preview` is a helper from outside this notebook; whatever it returns is bound to `rxiv`
rxiv = preview(pd.read_csv(f"{data_path}/processed/rxiv_metadata.csv",dtype=rxiv_dtypes))

In [None]:
#Semantic table for the Covid papers (it carries the `cluster` labels used further down);
#article ids are kept as strings
topics_file = f"{data_path}/processed/covid_semantic.csv"
topics = preview(pd.read_csv(topics_file,dtype={'article_id':str}))

In [None]:
#Geocoded institution rows, with both identifier columns read as strings
geo_file = f"{data_path}/processed/rxiv_geo.csv"
geo = preview(pd.read_csv(geo_file,dtype={'article_id':str,'mag_id':str}))

In [None]:
#Keep only the rows flagged as Covid papers and renumber them from zero
covid_mask = rxiv['is_covid'] == True
cov = rxiv.loc[covid_mask].reset_index(drop=True)

## 2. Analyse data

In [None]:
#Merge the rxiv metadata with the geocoded info. pd.merge defaults to an inner join, so only
#papers with at least one row in `geo` are kept.
rxiv_geo = pd.merge(rxiv,geo,left_on='id',right_on='article_id')

#Label rows without an institution country as 'Unmatched'.
#The result is assigned back instead of calling fillna(..., inplace=True) on the column
#selection: that form acts on an intermediate object and does not update the frame when
#pandas copy-on-write is active (and triggers a FutureWarning in pandas 2.x).
rxiv_geo['institute_country'] = rxiv_geo['institute_country'].fillna('Unmatched')

#### 1. Geography of activity

##### Country frequencies



In [None]:
#Restrict the geocoded corpus to papers from 2019 onwards
is_recent = rxiv_geo['year'] >= 2019
rxiv_geo = rxiv_geo.loc[is_recent]

How active are different countries in Covid research?

In [None]:
#Number of rows per country over the whole (recent) corpus, labelled 'all_arxiv'
country_freqs = rxiv_geo['institute_country'].value_counts()
country_freqs = country_freqs.rename('all_arxiv')

In [None]:
#Each category of activity is defined by a query over the geocoded corpus;
#`names` gives the column label for the query in the same position
queries = ["is_covid == True","is_ai == True","(is_covid ==1) & (is_ai ==True)"]
names = ['covid','ai','covid_ai']

#One country-frequency series per category
category_freqs = [
    rxiv_geo.query(q)['institute_country'].value_counts().rename(n)
    for n,q in zip(names,queries)]

#Side-by-side table of counts (overall + per category); countries missing from a category get 0
category_table = pd.concat(category_freqs,axis=1)
all_acts = pd.concat([country_freqs,category_table],axis=1,sort=True).fillna(0)

#The 25 countries with most Covid + AI activity
top_countries = all_acts.sort_values('covid_ai',ascending=False)[:25].index.tolist()

In [None]:
#Share (in %) of each category's activity accounted for by each country, for the top
#countries only, reshaped to long format (one row per country / category pair).
geo_activity_long_norm = (
    (100*all_acts.apply(lambda x: x/x.sum()))
    #.loc with the list of labels both selects the top countries and fixes their order,
    #so no prior sorting is needed
    .loc[top_countries]
    #Pin the index name before resetting it: from pandas 2.0 `value_counts` names the index
    #after the counted column, so reset_index would otherwise not produce the 'index'
    #column that melt (and the later rename to 'Country') relies on
    .rename_axis('index')
    .reset_index(drop=False)
    .melt(id_vars=['index'])
    .pipe(preview))

In [None]:
#Tidy the long table for plotting:
# - relabel the category keys with the project helper `convert_group`
# - give the columns display names
# - add a '% of activity' column built by the project helper `make_pc` from the numeric share
geo_activity_long_norm = (
    geo_activity_long_norm
    .assign(variable=lambda df: convert_group(df['variable']))
    .rename(columns={'variable':'Category','index':'Country'})
    .assign(**{'% of activity': lambda df: make_pc(df['value'])}))

##### Cluster representation by country

In [None]:
#Covid papers that are also AI papers (the previous index is kept as a column)
cov_geo = rxiv_geo.query("(is_ai == True) & (is_covid == True)").reset_index(drop=False)

#One row per article in the topic table; reused for the lookup and for the ranking below
topics_unique = topics.drop_duplicates('article_id')

#article_id -> cluster lookup, applied to the paper ids
cluster_mapping = topics_unique.set_index('article_id')['cluster'].to_dict()
cov_geo['cluster'] = cov_geo['id'].map(cluster_mapping)

#The eight clusters with the most AI articles
ai_cluster_sizes = topics_unique.groupby(['is_ai','cluster']).size()[True]
top_ai_clusters = ai_cluster_sizes.sort_values(ascending=False)[:8].index

#Paper counts per country and cluster; clusters outside the top eight are labelled "Other"
country_cluster = (cov_geo
                   .groupby(['institute_country','cluster'])
                   .size()
                   .reset_index(name='count'))
country_cluster['cluster_short'] = country_cluster['cluster'].map(
    lambda name: name if name in top_ai_clusters else 'Other')

#Tidy the cluster labels with the project helper `clean_cluster` and use display names for the columns
country_cluster['cluster_short'] = clean_cluster(country_cluster['cluster_short'])
country_cluster = country_cluster.rename(
    columns={'institute_country':'Country',
             'cluster_short':'Cluster',
             'count':'Number of papers'})


##### Create chart

In [None]:
#Figure 6: two panels side by side.
# Left: for each top country, its share of activity in every category (points joined by a line).
# Right: number of Covid + AI papers per country, stacked by cluster.

#Components of first chart
base = (alt.Chart(geo_activity_long_norm)
        .encode(
            y=alt.Y('Country',sort=top_countries,title=''),
            x=alt.X('value',title='% of all activity in category')))

#One marker per country / category pair
points = (base.mark_point(filled=True,
                 size=100,opacity=0.75,stroke='black',strokeWidth=1)
          .encode(
              color=alt.Color('Category'),
              shape=alt.Shape('Category',scale=alt.Scale(range=['circle','cross','circle','cross'])),
              tooltip = ['Category','Country','% of activity']))

#Black segment linking the categories of a single country
points_line = (base.mark_line(strokeWidth=1.5,color='black')
               .encode(detail='Country'))

#Dashed line through the Covid + AI points of all countries.
#The long table no longer has a 'variable' column holding the raw 'covid_ai' key: the column was
#renamed to 'Category' and its values were relabelled with `convert_group`, so filtering on
#`datum.variable == 'covid_ai'` matched no rows and the layer was empty. The filter now uses the
#'Category' column and the label `convert_group` produces for 'covid_ai'.
#NOTE(review): this assumes `convert_group` relabels element by element - confirm against the helper.
covid_ai_label = list(convert_group(pd.Series(['covid_ai'])))[0]

rel_line = (base
            .transform_filter(alt.datum.Category == covid_ai_label)
            .mark_line(strokeWidth=1,color='steelblue',opacity=0.8,strokeDash=[2,1])
            .encode())

#Components of second chart
stack = (alt.Chart(country_cluster)
         .transform_filter(alt.FieldOneOfPredicate('Country',top_countries))
         .mark_bar(stroke='white',strokeWidth=0.1)
         .encode(
             y=alt.Y('Country',sort=top_countries,title=''),
             x='Number of papers',
             order=alt.Order('Number of papers',sort='descending'),
             tooltip = ['Country','Cluster','Number of papers'],
             color=alt.Color('Cluster',
                             title='Cluster',
                             sort=alt.EncodingSortField('Number of papers','mean','descending'))))

#The two panels use unrelated colour / shape encodings, hence the independent scales
comp = (alt.hconcat((points+points_line+rel_line).properties(width=250,height=500),stack.properties(height=500,width=150))
 .resolve_scale(color='independent',shape='independent'))


comp.save(f"{fig_path}/fig_6.html")

comp

### 4. Evolution of activity

Here we compare how COVID-19 research activity has evolved over time in different countries.

In [None]:
#Covid papers with geography, plus a parsed `date` column taken from `created`
covid_rows = rxiv_geo.query("is_covid == 1").reset_index(drop=False)
cov_geo_all = covid_rows.assign(date=pd.to_datetime(covid_rows['created']))

#Counts per country / AI flag / date for 2020, kept for the first twelve top countries
papers_2020 = cov_geo_all.query('year ==2020')
counts_by_date = papers_2020.groupby(['institute_country','is_ai','date']).size()
cov_geo_trend = counts_by_date.loc[top_countries[:12]].reset_index(name='count')

#Relabel the AI flag with the project helper `convert_ai`
cov_geo_trend['is_ai'] = convert_ai(cov_geo_trend['is_ai'])

##### Cumulative activity by year

When do different countries reach a critical mass of activity?

In [None]:
#Date x country table of counts (AI and non-AI rows summed); missing combinations count as 0
counts_wide = cov_geo_trend.pivot_table(
    index='date',columns='institute_country',values='count',aggfunc='sum').fillna(0)

#Smooth with a 5-row rolling mean, drop the incomplete leading windows, then accumulate
smoothed_counts = counts_wide.rolling(window=5).mean().dropna()
cov_geo_cumul = smoothed_counts.cumsum()

#Cumulative activity as a share of each country's final total
cov_geo_shares = cov_geo_cumul.div(cov_geo_cumul.iloc[-1],axis='columns')

#For every country, the first date on which its share goes above 25% of its total
country_date = {
    'country': list(cov_geo_shares.columns),
    'first_date': [cov_geo_shares[c].loc[lambda s: s > 0.25].index[0]
                   for c in cov_geo_shares.columns]}

#Countries ordered from the earliest to the latest crossing date
geo_dates_df = pd.DataFrame(country_date).sort_values('first_date',ascending=True)

countries_ordered = geo_dates_df['country'].tolist()

##### Create chart

In [None]:
#Figure 7: small multiples (one panel per country, four per row) of the smoothed number of
#research participations over time, split into AI and non-AI papers.

#Panels follow the order in which countries crossed 25% of their activity
country_facet = alt.Facet('institute_country',
                          columns=4,
                          title='Country',
                          sort=countries_ordered)

trend_chart = (
    alt.Chart(cov_geo_trend)
    #Centred moving average of the counts (3 rows either side) within each country / AI group
    .transform_window(m='mean(count)',frame=[-3,3],groupby=['institute_country','is_ai'])
    .mark_line(opacity=0.9)
    .encode(
        x='date',
        y=alt.Y('m:Q',title=['Research','participations']),
        color=alt.Color('is_ai:N',sort=['AI','Not AI']),
        facet=country_facet)
    .properties(width=100,height=85)
    #Each panel gets its own y axis
    .resolve_scale(y='independent'))


trend_chart.save(f"{fig_path}/fig_7.html")

trend_chart