In [None]:
import json
import os
import subprocess

from collections import defaultdict
from sys import getsizeof

### New API

An API key is required.

In [None]:
RELEASE = 'latest'
DATASET = 'citations'
# Never commit credentials to a notebook: the key is read from the S2_API_KEY environment
# variable and falls back to an empty string when the variable is not set.
# NOTE(review): a literal key used to be stored in this cell; treat it as leaked and rotate it.
API_KEY = os.environ.get('S2_API_KEY', '')

## API -> Direct Co-Citations

In [None]:
import networkx as nx
import pandas as pd
import requests

In [None]:
# Paper-details endpoint template: {0} = paper identifier, {1} = comma-separated field list.
PAPER_URL = 'https://api.semanticscholar.org/graph/v1/paper/{0}?fields={1}'
# Fields requested together with the paper details, including the nested `citations` list.
PAPER_CITATION_FIELDS = ','.join([
    'citationCount',
    'citations.paperId',
    'citations.citationCount',
    'citations.title',
    'citations.url',
    'citations.authors',
    'citations.journal',
    'citations.year',
])
# Paginated citations endpoint template:
# {0} = paper identifier, {1} = field list, {2} = offset, {3} = limit.
CITATIONS_URL = (
    'https://api.semanticscholar.org/graph/v1/paper/{0}/citations'
    '?fields={1}&offset={2}&limit={3}'
)
# Fields requested for every citing paper returned by the citations endpoint.
CITATIONS_FIELDS = ','.join([
    'paperId', 'url', 'title', 'citationCount', 'authors', 'journal', 'year',
])
# Default page size used for the `limit` query parameter.
MAX_LIMIT = 1_000
# get_citation_data() keeps offset + limit strictly below this value.
MAX_TOTAL_COUNT = 10_000

In [None]:
# Seed papers, identified by DOI and split into groups; the cross-check further down keeps
# the papers that cite at least one seed from every group.
# Entries left as comments are alternative seeds that are currently switched off.
groups = [
    # Disabled group:
    # ['10.1016/j.neuroimage.2011.01.057'],
    [
        '10.1016/j.neuroimage.2012.03.048',
        # '10.1038/nn.3101',
    ],
    [
        '10.1103/PhysRevLett.100.234101',
        # '10.1002/hbm.20346',
        # '10.1016/j.clinph.2004.04.029',
        # '10.1016/j.neuroimage.2011.11.084',
        # '10.1016/j.neuroimage.2011.01.055',
        # '10.1002/(SICI)1097-0193(1999)8:4<194::AID-HBM4>3.0.CO;2-C',
        # '10.1098/rsta.2011.0081',
    ],
]

In [None]:
def get_paper_data(doi, fields, timeout=30):
    """Request the details of one paper and return the decoded JSON body.

    Args:
        doi: Paper identifier substituted into PAPER_URL.
        fields: Comma-separated list of fields to request.
        timeout: Seconds to wait for the server before `requests` raises a timeout
            error; without it a stalled connection would block the notebook forever.

    Returns:
        The parsed JSON response. The HTTP status is deliberately not checked here:
        build_graph() detects failed lookups through the missing `paperId` key.
    """
    response = requests.get(PAPER_URL.format(doi, fields), timeout=timeout)
    return response.json()

In [None]:
def get_citation_data(doi, fields, offset, limit=MAX_LIMIT, timeout=30):
    """Request one page of citing papers and return the decoded JSON body.

    Args:
        doi: Paper identifier substituted into CITATIONS_URL.
        fields: Comma-separated list of fields to request for each citing paper.
        offset: Index of the first citation to return.
        limit: Page size; reduced so that offset + limit stays below MAX_TOTAL_COUNT
            (presumably mirroring a restriction of the API -- confirm against its docs).
        timeout: Seconds to wait for the server before `requests` raises a timeout error.

    Returns:
        The parsed JSON response. When the clamped limit is not positive, no request
        is sent and an empty page (`{'offset': offset, 'data': []}`) is returned.
    """
    if offset + limit >= MAX_TOTAL_COUNT:
        limit = MAX_TOTAL_COUNT - offset - 1
    if limit <= 0:
        # Nothing can be requested at or beyond the cap; a zero or negative limit
        # would only produce an invalid request.
        return {'offset': offset, 'data': []}
    response = requests.get(CITATIONS_URL.format(doi, fields, offset, limit), timeout=timeout)
    return response.json()

In [None]:
def authors_compact(authors):
    """Render an author list as a short string.

    Up to three authors are joined with ', '; longer lists are abbreviated to the
    first author's name followed by ' et al.'. A missing or empty list gives ''.
    Each author is expected to be a mapping with a 'name' entry.
    """
    if not authors:
        return ''
    if len(authors) > 3:
        first_author = authors[0]['name']
        return f"{first_author} et al."
    names = (entry['name'] for entry in authors)
    return ', '.join(names)

In [None]:
def journal_compact(journal):
    """Render a journal record as 'Name Volume:Pages'.

    Args:
        journal: Mapping that may contain 'name', 'volume' and 'pages', or None.

    Returns:
        '' when the record or its name is missing or empty; otherwise the name,
        followed by ' <volume>' and ':<pages>' for each of those parts that is present.
        Parts whose value is None or '' are skipped, so a null from the JSON payload
        is never rendered as the literal text 'None'.
    """
    if not journal:
        return ''
    name = journal.get('name')
    if not name:
        return ''
    result = name
    volume = journal.get('volume')
    if volume not in (None, ''):
        result += f" {volume}"
    pages = journal.get('pages')
    if pages not in (None, ''):
        result += f":{pages}"
    return result

In [None]:
def _fetch_all_citations(paper_doi, paper_data):
    """Collect the citing-paper records of one paper, following pagination.

    Starts from the `citations` list embedded in `paper_data` and requests further
    pages through get_citation_data() until the reported `citationCount` is reached
    or the response carries no usable `next` offset.

    Returns:
        A `(citations, num_cit_total)` tuple; `num_cit_total` is the `citationCount`
        reported in `paper_data` (0 when it is missing or null).
    """
    citations = list(paper_data.get('citations') or [])
    num_cit_total = paper_data.get('citationCount') or 0
    offset = len(citations)
    print(paper_doi)
    print(f'Retrieved {offset} / {num_cit_total} citations')
    while offset < num_cit_total:
        page = get_citation_data(paper_doi, CITATIONS_FIELDS, offset)
        if 'data' not in page:
            # Not a regular page (e.g. an error payload): show it instead of hiding it.
            print(f'Unexpected response for {paper_doi}: {page}')
        citations.extend(cit['citingPaper'] for cit in page.get('data') or [])
        print(f'Retrieved {len(citations)} / {num_cit_total} citations')
        next_offset = page.get('next')
        if not next_offset or next_offset <= offset:
            # No further page, or an offset that does not advance: stop paginating.
            break
        offset = next_offset

    if len(citations) < num_cit_total:
        # The reported count cannot always be retrieved in full (get_citation_data()
        # never requests past MAX_TOTAL_COUNT), so report the gap and keep the partial
        # result instead of aborting the whole graph build.
        print(f'Warning: only {len(citations)} of {num_cit_total} citations '
              f'retrieved for {paper_doi}')
    return citations, num_cit_total


def _citation_node(cit):
    """Build the `(node_id, attributes)` pair for one citing-paper record."""
    attributes = dict(citationCount=cit['citationCount'],
                      url=cit['url'], title=cit['title'],
                      year=cit.get('year', ''),
                      authors=authors_compact(cit.get('authors', [])),
                      journal=journal_compact(cit.get('journal', {})))
    return cit['paperId'], attributes


def build_graph(groups):
    """Build a directed citation graph for groups of seed papers.

    For every DOI in every group the paper details and all retrievable citations are
    fetched. Each seed paper and each citing paper becomes a node; edges point from
    the cited seed paper to the citing paper (i.e. reversed), so that a BFS started
    at the seeds reaches the papers that cite them.

    Args:
        groups: Iterable of groups, each an iterable of paper DOIs.

    Returns:
        A `(graph, node_groups)` tuple: the `nx.DiGraph` and, for each input group,
        the list of paper IDs of the seed papers that were found. Papers whose lookup
        returns no `paperId` are reported with a message and skipped.
    """
    graph = nx.DiGraph()
    node_groups = []
    for group_papers in groups:
        nodes = []
        for paper_doi in group_papers:
            paper_data = get_paper_data(paper_doi, PAPER_CITATION_FIELDS)
            paper_id = paper_data.get('paperId', None)
            if not paper_id:
                print(f'Failed to find paper: {paper_doi}')
                continue

            citations, num_cit_total = _fetch_all_citations(paper_doi, paper_data)
            # Records without a paper ID cannot be used as graph nodes.
            citing = [cit for cit in citations if cit and cit.get('paperId')]

            nodes.append(paper_id)
            graph.add_node(paper_id, citationCount=num_cit_total)
            graph.add_nodes_from([_citation_node(cit) for cit in citing])
            # Reversed edges: cited seed paper -> citing paper.
            graph.add_edges_from([(paper_id, cit['paperId']) for cit in citing])

        node_groups.append(nodes)

    return graph, node_groups

In [None]:
# Build the citation graph for the seed groups; this requests data for every seed paper.
graph, node_groups = build_graph(groups)

In [None]:
# Summary, one value per line: node count, edge count, seed paper IDs per group.
print(graph.number_of_nodes(), graph.number_of_edges(), node_groups, sep='\n')

In [None]:
# Keep the papers that cite at least one seed paper from every group.
# Edges point from a cited seed to its citing papers, so BFS layer 1 from a group's
# seeds holds the papers that directly cite that group.
crosschecked = set(graph.nodes)
for group_nodes in node_groups:
    reachable = nx.bfs_layers(graph, group_nodes)
    print(next(reachable, []))  # layer 0: the group's own seed nodes
    # Default to an empty layer: a group with no resolved seeds or no citing papers
    # then empties the intersection instead of raising StopIteration.
    crosschecked &= set(next(reachable, []))

In [None]:
# Number of papers left after intersecting the citing papers of all groups.
len(crosschecked)

In [None]:
# One flat record per cross-checked node: its ID plus the attributes stored on the node.
crosschecked_data = [
    dict(paperId=node_id, **attrs)
    for node_id, attrs in graph.nodes(data=True)
    if node_id in crosschecked
]

In [None]:
# Tabulate the records: one row per cross-checked paper.
crosschecked_df = pd.DataFrame(crosschecked_data)

In [None]:
# Show the 20 cross-checked papers with the highest citation counts.
crosschecked_df.sort_values(by='citationCount', ascending=False).head(n=20)

### Old Way

In [None]:
# Name of a decompressed corpus shard; {0} is the shard number, zero-padded to 3 digits.
FILE_NAME = 's2-corpus-{0:03d}'
# Name of the compressed shard.
FILE_NAME_GZ = FILE_NAME + '.gz'
# Download location of the compressed shard (2022-05-01 release).
DOWNLOAD_URL = (
    'https://s3-us-west-2.amazonaws.com/ai2-s2-research-public/open-corpus/2022-05-01/'
    + FILE_NAME_GZ
)

In [None]:
def to_numeric(sha1_id):
    """Return the integer ID assigned to `sha1_id`, registering it on first sight.

    New IDs are handed out sequentially (the current size of `paper_ids`); the
    forward mapping is kept in `paper_ids` and the reverse one in `paper_index`.
    """
    known_id = paper_ids.get(sha1_id)
    if known_id is not None:
        return known_id
    assigned = len(paper_ids)
    paper_ids[sha1_id] = assigned
    paper_index[assigned] = sha1_id
    return assigned

In [None]:
def update_citations(corpus_id):
    """Download one corpus shard and merge its citation links into `citations`.

    The shard is fetched with `wget`, decompressed with `gzip -d` and read line by
    line as JSON records with the keys 'id', 'inCitations' and 'outCitations'.
    For every record, the paper is added to the set of each ID in 'inCitations',
    and every ID in 'outCitations' is added to the paper's own set. All IDs are
    mapped to integers through to_numeric(). The decompressed file is deleted
    afterwards, even when parsing fails.

    Args:
        corpus_id: Shard number substituted into DOWNLOAD_URL / FILE_NAME.

    Raises:
        subprocess.CalledProcessError: if `wget` or `gzip` exits with a non-zero status.
    """
    archive = FILE_NAME_GZ.format(corpus_id)
    filename = FILE_NAME.format(corpus_id)

    # `-O` pins the output name so that a leftover archive from an earlier, failed run
    # is overwritten rather than wget saving the new download under a '.1' suffix.
    # `check=True` surfaces download/decompression failures right away.
    subprocess.run(['wget', '-O', archive, DOWNLOAD_URL.format(corpus_id)],
                   stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT, check=True)
    subprocess.run(['gzip', '-d', archive],
                   stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT, check=True)

    try:
        with open(filename, 'r', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line)
                paper_id = to_numeric(data['id'])
                for cit_id in data['inCitations']:
                    citations[to_numeric(cit_id)].add(paper_id)
                citations[paper_id].update(to_numeric(cit_id) for cit_id in data['outCitations'])
    finally:
        # Always remove the decompressed shard so that it does not pile up on disk.
        if os.path.exists(filename):
            os.remove(filename)

In [None]:
# citations: numeric paper ID -> set of numeric paper IDs (filled by update_citations()).
citations = defaultdict(set)
# paper_ids: original ID -> numeric ID; paper_index: numeric ID -> original ID.
paper_ids = {}
paper_index = {}

print('Corpus ID\tNumPapers\tSizePapers\tNumCitations\tSizeCitations')
for corpus_id in range(10):
    update_citations(corpus_id)
    num_papers = len(paper_index)
    # sys.getsizeof() is shallow: it measures the container only. Add the ID strings and
    # the per-paper sets so that the alarm below reflects where the memory actually goes.
    # This is still a lower bound, because the integer objects are not counted.
    size_papers = (getsizeof(paper_ids) + getsizeof(paper_index)
                   + sum(getsizeof(sha1_id) for sha1_id in paper_ids))
    num_citations = sum(len(cited) for cited in citations.values())
    size_citations = (getsizeof(citations)
                      + sum(getsizeof(cited) for cited in citations.values()))
    print(f'{corpus_id}\t{num_papers}\t{size_papers}\t{num_citations}\t{size_citations}')
    # Stop before the in-memory structures grow beyond roughly 500 MB.
    if size_papers + size_citations > 500_000_000:
        print('Alarm!')
        break