# Organisation exploration

Here we analyse the topic specialisation profiles of different organisations.


## Preamble

In [None]:
%run ../notebook_preamble.ipy

import altair as alt
import yaml
from narrowing_ai_research.utils.altair_utils import *
from narrowing_ai_research.paper.s9_topic_comparison import *
#from narrowing_ai_research.paper.s10_vector_embeddings import *

# Turn off pandas' chained-assignment (SettingWithCopy) warning: cells below
# add columns to filtered frames and would otherwise trigger it.
pd.options.mode.chained_assignment = None 
%config Completer.use_jedi = False

In [None]:
# Run this if you want to save charts: the returned object is handed to
# save_altair() as the driver in the saving cells below.
driv = altair_visualisation_setup()

### Read data

In [None]:
# Load every processed input used by the rest of the notebook in one call.
(
    papers,
    porgs,
    topic_mix,
    topic_category_map,
    arxiv_cat_lookup,
    topic_list,
) = read_process_data()

### Create comparison

In [None]:
# Read the "section_9" settings from the paper configuration file.
with open(f"{project_dir}/paper_config.yaml", "r") as infile:
    pars = yaml.safe_load(infile)["section_9"]

# Categories and highlighted topic labels consumed by the charts below.
cats, labels_to_display = pars["categories"], pars["topic_highlights"]

In [None]:
# Query the organisation table once per organisation type.
comp_ids = query_orgs(porgs, "org_type", "Company")
acad_ids = query_orgs(porgs, "org_type", "Education")

In [None]:
# Compare the company and academia groups; save=False keeps the result
# in memory only.
topic_comparison_chart, comp_table = make_chart_topic_comparison(
    topic_mix,
    arxiv_cat_lookup,
    [comp_ids, acad_ids],
    cats,
    ["company", "academia"],
    highlights=True,
    highlight_topics=labels_to_display,
    topic_list=topic_list,
    topic_category_map=topic_category_map,
    save=False,
)

In [None]:
# Bare expression at the end of the cell: renders the chart as the cell output.
topic_comparison_chart

### Org profiles

In [None]:
# Keep the rows flagged as AI and derive the publication year from `date`.
porgs_ai = porgs.query("is_ai==True").assign(
    year=lambda frame: frame["date"].map(lambda stamp: stamp.year)
)

In [None]:
# One topic specialisation chart per publication year.
old, mid, new = (
    make_chart_topic_spec(
        porgs_ai, "year", year, topic_category_map, cats, topic_mix
    )
    for year in (2018, 2019, 2020)
)

In [None]:
# Concatenate the three yearly charts on a shared y scale and save the result.
save_altair(
    alt.hconcat(old, mid, new, title="Topic specialisation by year")
    .configure_facet(spacing=0)
    .resolve_scale(y="shared")
    .configure_view(stroke=None)
    .configure_axis(labelFontSize=12, titleFontSize=12),
    "trend_chart",
    driver=driv,
)



In [None]:
# One topic specialisation chart per institute country.
us, china, fr, canada, germany = (
    make_chart_topic_spec(
        porgs_ai, "institute_country", country, topic_category_map, cats, topic_mix
    )
    for country in ("United States", "China", "France", "Canada", "Germany")
)

In [None]:
# Concatenate the per-country charts on a shared y scale and display them.
# The title previously read "by year" (copied from the yearly chart) even
# though the panels are split by institute country.
(
    alt.hconcat(
        us, china, fr, canada, germany, title="Topic specialisation by country"
    )
    .configure_facet(spacing=0)
    .resolve_scale(y="shared")
    .configure_view(stroke=None)
    .configure_axis(labelFontSize=12, titleFontSize=12)
)

## Another strategy for visualisation

In [None]:
def make_topic_rep_df(data, variable, value, topic_category_map, cats, topic_mix, ordered_cats=None):
    """Build the topic representation table for rows where ``variable == value``.

    Args:
        data: Table with an ``article_id`` column and a column named
            ``variable``.
        variable: Name of the column in ``data`` used to select rows.
        value: Value of ``variable`` identifying the rows of interest.
        topic_category_map: Passed through unchanged to ``topic_rep``.
        cats: Passed through unchanged to ``topic_rep``.
        topic_mix: Passed through to ``topic_rep``; its ``columns`` are also
            passed as a separate argument.
        ordered_cats: Unused. Kept (now defaulting to ``None`` rather than a
            shared mutable list) so existing call sites keep working.

    Returns:
        The first element of the ``topic_rep`` result with rows containing
        missing values dropped and the index reset, plus one extra column
        named ``variable`` that holds ``value`` on every row.
    """
    logging.info("Extracting IDs %s", value)
    # De-duplicate the article ids belonging to the selected group.
    _ids = set(data.loc[data[variable] == value, "article_id"])

    rep = (
        topic_rep(_ids, topic_mix, cats, topic_mix.columns, topic_category_map)[0]
        .dropna()
        .reset_index(drop=True)
    )

    # Tag every row so tables for several values can be concatenated later.
    rep[variable] = value
    return rep

In [None]:
# Countries to profile; this order is also used to sort the chart rows.
countries = [
    "United States",
    "China",
    "United Kingdom",
    "Australia",
    "Germany",
    "Canada",
]

In [None]:
# Stack the per-country topic representation tables into one frame with a
# fresh 0..n-1 index.
c = pd.concat(
    [
        make_topic_rep_df(
            porgs_ai, "institute_country", country, topic_category_map, cats, topic_mix
        )
        for country in countries
    ]
).reset_index(drop=True)

In [None]:
# Categories sorted by their total `levels`, largest first. This is computed
# before the ratio filter below, so it uses every row.
ordered_cats = (
    c.groupby("cat_sel")["levels"].sum().sort_values(ascending=False).index.tolist()
)

# Constant helper columns; `centre` is the x position of the tick layer.
c["centre"] = 0
c["ruler_2"] = 0

# Drop rows with ratio >= 10. Copy so the column assignments below act on an
# independent frame rather than on a slice of the unfiltered one.
c = c.loc[c["ratio"] < 10].copy()

# For each (category, country) pair: is the mean ratio above zero?
cat_mean_status = (
    c.groupby(["cat_sel", "institute_country"])["ratio"].mean() > 0
).to_dict()

# Use plain truthiness rather than `is True`: an identity test fails for
# numpy.bool_ values and would silently label every row "Low".
c["ruler_color"] = [
    "High" if cat_mean_status[(cat, country)] else "Low"
    for cat, country in zip(c["cat_sel"], c["institute_country"])
]
c["width"] = 10

# Legend labels: category description cut to 35 characters plus an ellipsis.
c["long_names"] = [arxiv_cat_lookup[x][:35] + "..." for x in c["cat_sel"]]

ordered_cat_names = [arxiv_cat_lookup[x][:35] + "..." for x in ordered_cats]

In [None]:
# Layer 1: one circle per data row, spread horizontally by a random jitter
# value, with the `ratio` field on the y axis.
strip = (
        alt.Chart()
        .mark_circle(size=14, stroke="grey", strokeWidth=0.5)
        .encode(
            x=alt.X(
                "jitter:Q",
                title=None,
                axis=alt.Axis(values=[0], ticks=True, grid=False, labels=False),
                scale=alt.Scale(),
            ),
            y=alt.Y("ratio:Q", title="Specialisation",axis=alt.Axis(grid=True)),
            size=alt.Size(
                "levels",
                title=["Number", "of papers"],
            ),
            # Colour by the truncated category name, legend ordered like ordered_cat_names.
            color=alt.Color(
                "long_names:N", scale=alt.Scale(scheme="tableau20"),title='arXiv categories',
                sort=ordered_cat_names
            ),
            opacity=alt.Opacity("ratio:Q", legend=None),
            tooltip=["index","cat_sel"])
    .transform_calculate(
        # Box-Muller transform: combines two uniform draws into a normally
        # distributed jitter value.
        jitter="sqrt(-2*log(random()))*cos(2*PI*random())"
        ))

# Layer 2: a tick at x = `centre` marking the mean ratio of each facet,
# with its stroke coloured by the `ruler_color` field.
# NOTE(review): presumably "High" maps to red and "Low" to blue via the
# default sorted domain — confirm in the rendered legend.
ruler_1 = (
    alt.Chart()
    .mark_tick(strokeWidth=2)
    .encode(
        x=alt.X('centre'),
        y='mean(ratio)',
        stroke=alt.Stroke('ruler_color',scale=alt.Scale(range=['red','blue']),title=['Average','specialisation']),
        ))

# Layer both charts over the shared table `c`, then facet into a grid:
# one column per category (ordered by ordered_cats) and one row per country
# (ordered by countries).
country_chart = (alt.layer(strip+ruler_1,
           data=c)
 .properties(width=15,height=150)
 .facet(
     column=alt.Column('cat_sel',sort=ordered_cats,title='arXiv category'
                      # header=alt.Header(labelAngle=90,labelAlign='left')
                      ),
     row=alt.Row('institute_country',sort=countries,
                title='Country'))).configure_facet(spacing=15).configure_view(
    stroke=None).configure_axis(grid=False)

In [None]:
# Save the faceted country chart under the name 'country_profiles_comparison'.
# NOTE(review): the driver and fig_path are passed positionally here, while the
# 'trend_chart' save call in this notebook passes only driver=driv — confirm
# both call shapes match save_altair's signature.
save_altair(country_chart,'country_profiles_comparison',driv,fig_path)

In [None]:
# Bare expression at the end of the cell: renders the chart as the cell output.
country_chart