In [1]:
# Execute the shared preamble script in this kernel's namespace.
# NOTE(review): `project_dir`, used by later cells, is not defined anywhere in this
# notebook, so it presumably comes from this preamble — confirm.
%run ../notebook_preamble.ipy
# Load the autoreload extension so the `%autoreload` magic used further down is available.
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## 0) Preamble

### Set path variables

In [2]:
import sys
import os
import ratelim
from dotenv import load_dotenv,find_dotenv

# Root folder of the project. `project_dir` is not defined in this notebook; it is
# expected to already be in the namespace (presumably set by the preamble that is
# %run in the first cell — confirm).
PROJECT_PATH = project_dir

# Locate the nearest .env file and load its variables into the process environment.
load_dotenv(find_dotenv())

# Database configuration, read from the `config_path` environment variable and
# later handed to `get_engine`. This is None when the variable is not set.
sql_config = os.getenv('config_path')

# Subscription key handed to `query_mag_api`. It is read from the environment
# (e.g. an `AWS_SUBSCRIPTION_KEY=...` line in .env) so the secret never has to be
# pasted into the notebook; falls back to the previous empty string when unset.
AWS_SUBSCRIPTION_KEY = os.getenv('AWS_SUBSCRIPTION_KEY', '')

### Imports and load data

In [3]:
%matplotlib inline
%autoreload 2
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import json
import itertools
from collections import Counter, defaultdict
from cord19.transformers.utils import get_engine
from cord19.transformers.utils import contains_keyword  # Specifies keywords ('SARS-CoV-2', 'COVID-19', 'coronavirus')
from nesta.packages.mag.query_mag_api import build_expr
from nesta.packages.mag.query_mag_api import query_mag_api

ModuleNotFoundError: No module named 'nesta'

In [None]:
%%time
# Read the `arxiv_articles` table in chunks of 1000 rows and keep only the rows
# whose abstract or title satisfies `contains_keyword`.
con = get_engine(sql_config)
columns = [
    'id', 'created', 'title', 'abstract',
    'mag_id', 'citation_count', 'article_source',
]
chunks = pd.read_sql_table('arxiv_articles', con, columns=columns, chunksize=1000)

covid_chunks = []
for chunk in chunks:
    abstract_matches = chunk.abstract.apply(contains_keyword)
    title_matches = chunk.title.apply(contains_keyword)
    covid_chunks.append(chunk.loc[abstract_matches | title_matches])

# Stitch the filtered chunks back together into a single frame.
covid_df = pd.concat(covid_chunks)

In [None]:
# Write the filtered articles to the processed-data folder.
# NOTE(review): `index_label=False` still writes the index column and only omits its
# header label; confirm that `index=False` was not what was intended.
covid_csv_path = f"{project_dir}/data/processed/covid_df.csv"
covid_df.to_csv(covid_csv_path, index_label=False)

### Get MAG IDs for COVID articles (the "covid+AI" filter is currently commented out)

In [None]:
# Number of articles that matched the keyword filter.
covid_df.shape[0]

In [None]:
# Unique values of the `id` column of the raw AI-paper CSV (presumably the ids of
# papers classified as AI research — confirm against the file's producer).
ai_papers = pd.read_csv(f"{project_dir}/data/raw/ai_research/ai_paper_ids.csv")
ai_paper_ids = set(ai_papers['id'])

In [None]:
# Disabled variant that would keep only the articles whose id is in `ai_paper_ids`:
#condition = covid_df.id.apply(lambda id: id in ai_paper_ids)
#mag_ids = [int(id) for id in covid_df.mag_id.loc[condition] if not pd.isnull(id)]

# MAG ids of every filtered article that has one, converted to plain ints.
mag_ids = [int(mag_id) for mag_id in covid_df.mag_id.dropna()]

## 1) Get citation info from available MAG IDs

In [None]:
# Get the citation info: one MAG API call per expression yielded by `build_expr`,
# requesting only the 'Id' and 'CitCon' fields, keeping every raw response.
result_cont = [
    query_mag_api(expr, fields=['Id', 'CitCon'], subscription_key=AWS_SUBSCRIPTION_KEY)
    for expr in build_expr(mag_ids, 'Id')
]

In [None]:
# Flatten the 'entities' list of every response into one list of article records.
all_results = [entity for response in result_cont for entity in response['entities']]

In [None]:
# Mapping of {citing article id --> [ids of the articles it cites]}. The cited ids
# are the keys of the article's 'CitCon' entry; articles without one get an empty list.
citers = {}
for article in all_results:
    citer_id = int(article['Id'])
    citation_contexts = article['CitCon'] if 'CitCon' in article else {}
    citers[citer_id] = list(citation_contexts.keys())

# Set of ids (as ints) of every article cited by at least one citer.
citee_ids = {int(citee) for cited in citers.values() for citee in cited}

f"Number of unique citees: {len(citee_ids)}"

In [None]:
# Get full info for each article, citers and citees alike.
results = []
query_count = 1000
mag_fields = ['Id', 'J.JN', 'D', 'DN', 'DOI', 'CC', 'F.FN']
for expr in build_expr(citee_ids.union(citers), 'Id'):
    offset = 0
    # Page through the results; a page shorter than `query_count` is the last one.
    while True:
        page = query_mag_api(expr, fields=mag_fields,
                             subscription_key=AWS_SUBSCRIPTION_KEY,
                             offset=offset, query_count=query_count)['entities']
        results.extend(page)
        offset += len(page)
        if len(page) != query_count:
            break

# Data quality: count the requested ids that did not come back from the API.
returned_ids = {entity['Id'] for entity in results}
missing_citees = citee_ids - returned_ids
missing_citers = set(citers) - returned_ids
len(missing_citees), len(missing_citers)  # <-- these should be zero!

## 2) Save the citation information for later

In [None]:
# Translation table from MAG field codes to flattened output names. A string value
# is used directly as the output key; a callable receives the raw MAG value and
# returns a dict of output key(s) to merge into the flattened article.
# NOTE(review): the query also requests 'DOI', which has no entry here and is
# therefore dropped from the saved data — confirm this is intended.
field_dictionary = {
    'CC': 'citations',
    'D': 'date',
    'DN': 'title',
    'F': lambda fields: {'fields_of_study': [field['FN'] for field in fields]},
    'Id': 'mag_id',
    'J': lambda journal: {'journal_title': journal['JN']},
}

# Mapping of every article id (citers and citees) --> flattened article data
articles = {}
for entity in results:
    flattened = {}
    for mag_key, target in field_dictionary.items():
        # Only translate the MAG fields this result actually carries
        if mag_key in entity:
            raw_value = entity[mag_key]
            if isinstance(target, str):
                # Plain rename: MAG code --> readable name
                flattened[target] = raw_value
            else:
                # Callable: expands the raw value into one or more named entries
                flattened.update(target(raw_value))
    articles[entity['Id']] = flattened

# Save the flattened article data
articles_path = f'{PROJECT_PATH}/data/processed/ai_research/ai_article_mag_info.json'
with open(articles_path, 'w') as out_file:
    out_file.write(json.dumps(articles))

# Save the citer --> citees lookup. Together with `articles` you've got everything you need
lookup_path = f'{PROJECT_PATH}/data/processed/ai_research/citation_lookup.json'
with open(lookup_path, 'w') as out_file:
    out_file.write(json.dumps(citers))