In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import json
import jsonlines
from tqdm.notebook import tqdm

import torch
import sentence_splitter
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Collect the path of every ".json" file found anywhere below the dataset
# root ("/cord19") and make sure the output directory exists.

import os

datafiles = []
for dirname, _, filenames in os.walk('/cord19'):
    for filename in filenames:
        if filename.endswith(".json"):
            datafiles.append(os.path.join(dirname, filename))

# os.walk yields entries in an arbitrary, filesystem-dependent order. Sorting
# makes index-based slicing of `datafiles` (done by a later cell) reproducible
# across kernels and machines.
datafiles.sort()

# exist_ok=True keeps this cell re-runnable: a plain os.mkdir raises
# FileExistsError the second time the cell is executed.
os.makedirs("./files", exist_ok=True)  # location of output files

In [None]:
def get_affiliations(authors_json):
    """Pair each author's full name with their affiliation.

    Parameters
    ----------
    authors_json : iterable of dict
        Author records. Each record is read for the keys ``"first"``,
        ``"middle"`` (a list of strings), ``"last"`` and ``"affiliation"``
        (a dict that may hold ``"institution"`` and ``"laboratory"``).

    Returns
    -------
    list of [str, str]
        One ``[full_name, affiliation]`` pair per author that has a non-empty
        institution or, failing that, a non-empty laboratory. Authors without
        either are omitted.
    """
    affiliations = []

    for author_meta in authors_json:
        affiliation_meta = author_meta.get("affiliation")
        if not isinstance(affiliation_meta, dict):
            # Missing or malformed affiliation record: nothing to report.
            continue

        # Prefer the institution and fall back to the laboratory.
        affiliation = (affiliation_meta.get("institution")
                       or affiliation_meta.get("laboratory"))
        if not affiliation:
            continue

        middle = author_meta.get("middle") or []
        if isinstance(middle, str):
            middle = [middle]

        # Drop empty name parts so a missing middle name does not leave a
        # double space (and a missing first/last name no stray edge space).
        name_parts = [author_meta.get("first", ""), *middle, author_meta.get("last", "")]
        full_name = " ".join(part for part in name_parts if part)

        affiliations.append([full_name, affiliation])

    return affiliations

In [None]:
# Build the config and tokenizer from the SciBERT hub identifier, then load the
# three-label sequence classifier from the local 'model.pt'.
SCIBERT_NAME = 'allenai/scibert_scivocab_cased'

config = BertConfig.from_pretrained(SCIBERT_NAME)
config.num_labels = 3
config.use_bfloat16 = True

tokenizer = BertTokenizer.from_pretrained(SCIBERT_NAME, config=config)
model = BertForSequenceClassification.from_pretrained('model.pt', config=config)

In [None]:
# Parse every paper in datafiles[START:], classify the intent of each citing
# sentence, and periodically checkpoint the results under ./files.

START = 22250           # index into `datafiles` at which this run begins
CHECKPOINT_EVERY = 250  # a checkpoint is written when the loop index is a multiple of this

model.to(device)
model.eval()

label_map = {0: "Method", 1: "Background", 2: "Result"}

data = []

json_files = []  # not used by this cell; kept in case other cells rely on the name


def _load_parsed_ids(directory="./files"):
    """Return the paper ids recorded in every parsed_papers*.txt file in `directory`."""
    ids = []
    if not os.path.isdir(directory):
        return ids
    for name in sorted(os.listdir(directory)):
        # Checkpoints are written as parsed_papers_<index>.txt; also accept a
        # plain parsed_papers.txt so an existing resume file keeps working.
        if name.startswith("parsed_papers") and name.endswith(".txt"):
            with open(os.path.join(directory, name), "r") as f:
                ids.extend(line.strip() for line in f if line.strip())
    # De-duplicate while keeping first-seen order.
    return list(dict.fromkeys(ids))


def _sentence_spans(text, sentences):
    """Return (start, end, sentence) character offsets of each sentence within `text`."""
    spans = []
    cursor = 0
    for sent in sentences:
        start = text.find(sent, cursor)
        if start == -1:
            # NOTE(review): presumably the splitter normalised whitespace so the
            # sentence is no longer a verbatim substring - confirm. Fall back to
            # assuming it starts at the next non-space character.
            start = cursor
            while start < len(text) and text[start].isspace():
                start += 1
        end = start + len(sent)
        spans.append((start, end, sent))
        cursor = end
    return spans


def _find_citing_sentence(spans, cite_start, cite_end):
    """Return the sentence whose span fully contains the citation, or "" if none does."""
    for start, end, sent in spans:
        if sent and start <= cite_start and cite_end <= end:
            return sent
    return ""


def _classify_intent(sentence):
    """Return the label_map name of the highest-scoring class for `sentence`."""
    tokens = tokenizer.encode_plus(sentence, truncation=True, max_length=100,
                                   padding="max_length", return_tensors="pt")
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**tokens.to(device))[0]
    # Softmax is monotonic, so the argmax of the raw logits picks the same class.
    return label_map[int(torch.argmax(logits, dim=1))]


def _extract_citations(paper_json):
    """Return one record per in-text citation that can be tied to a sentence."""
    citations = []
    for text_json in paper_json["body_text"]:
        cite_spans = text_json.get("cite_spans")
        if not cite_spans:
            # Paragraph has no citations: nothing to split or classify.
            continue

        text = text_json["text"]
        section = text_json["section"]
        sentences = sentence_splitter.split_text_into_sentences(text, language="en")
        spans = _sentence_spans(text, sentences)

        # Several citations often share one sentence; classify each sentence once.
        intent_by_sentence = {}

        for cite_span in cite_spans:
            citing_sentence = _find_citing_sentence(spans, cite_span["start"], cite_span["end"])
            if citing_sentence == "":
                continue

            if citing_sentence not in intent_by_sentence:
                intent_by_sentence[citing_sentence] = _classify_intent(citing_sentence)

            citations.append({"ref_id": cite_span["ref_id"],
                              "section": section,
                              "intent": intent_by_sentence[citing_sentence],
                              "sentence": citing_sentence})
    return citations


def _write_checkpoint(index, records, paper_ids):
    """Write `records` and the cumulative `paper_ids` to the files tagged with `index`."""
    # NOTE(review): write() stores the whole list as a single JSON line; kept
    # as-is because readers of these files may rely on that layout.
    with jsonlines.open(f"./files/parsed_data_{index}.jsonl", "w") as writer:
        writer.write(records)

    # The id list is cumulative, so overwrite rather than append.
    with open(f"./files/parsed_papers_{index}.txt", "w") as f:
        for pid in paper_ids:
            f.write(pid + "\n")


parsed_papers = _load_parsed_ids()
parsed_ids = set(parsed_papers)  # O(1) membership tests inside the loop
last_parsed_index = None

for j, path in enumerate(tqdm(datafiles[START:])):
    with open(path, "r") as infile:
        paper_json = json.load(infile)

    paper_id = paper_json["paper_id"]

    if paper_id in parsed_ids:
        continue

    paper_title = paper_json["metadata"]["title"]
    affiliations = get_affiliations(paper_json["metadata"]["authors"])
    bib_entries = paper_json["bib_entries"]

    citations = _extract_citations(paper_json)

    # Attach the bibliography entry to each citation that has one.
    for citation in citations:
        ref_id = citation["ref_id"]
        if ref_id in bib_entries:
            citation["bib_item"] = bib_entries[ref_id]

    data.append({"paper_id": paper_id,
                 "paper_title": paper_title,
                 "citations": citations,
                 "affiliations": affiliations})
    parsed_papers.append(paper_id)
    parsed_ids.add(paper_id)
    last_parsed_index = START + j

    if j % CHECKPOINT_EVERY == 0:
        _write_checkpoint(START + j, data, parsed_papers)
        data = []

# Flush whatever was collected after the last periodic checkpoint; without this
# the tail of the run is never written. The file is tagged with the index of
# the last paper actually parsed in this run.
if data:
    _write_checkpoint(last_parsed_index, data, parsed_papers)
    data = []