In [1]:
import os
from glob import glob

from pyspark.sql import SparkSession

import pyspark.sql.functions as F

from pyspark.sql.types import (
    ArrayType,
    IntegerType,
    MapType,
    StringType,
    StructField,
    StructType,
    FloatType,
    TimestampType,
    BooleanType,
    DateType
)

In [3]:
%%capture
spark = SparkSession.builder.appName('BDP').getOrCreate()

In [4]:
def generate_schema():
    """
    JSON schema went through a few modifications across different versions of the dataset.
    In this work, we're finalizing the dataset version 111 and using its schema as final.
    The schema provided along with the dataset acts as a reference point however there were
    a few falsely nested structures in it which has been corrected by us. Further, there are
    quite a lot of fields but we'll extract only the ones we need for our project.
    
    Original schema:
    https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/2020-03-13/json_schema.txt
    """
    
    # Extract author information, although not needed for modelling but good to have for data analysis
    authors_schema = ArrayType(
        StructType(
            [
                StructField("first", StringType()),
                StructField("middle", ArrayType(StringType())),
                StructField("last", StringType()),
                StructField("suffix", StringType())
            ]
        )
    )
    
    # Extract different sections such as abstract and body text
    section_schema = ArrayType(
        StructType(
            [
                StructField("text", StringType()),
                StructField("section", StringType())
            ]
        )
    )

    schema = StructType(
        [
            StructField("paper_id", StringType()),
            StructField(
                "metadata",
                StructType(
                    [
                        StructField("title", StringType()),
                        StructField("authors", authors_schema)
                    ]
                ),
                True,
            ),
            StructField("abstract", section_schema),
            StructField("body_text", section_schema),
        ]
    )
    
    return schema


def jsons_to_df(spark, paths, schema):
    return spark.read.json(paths, schema=schema, multiLine=True)

In [5]:
DATA_PATH = "gs://bdp_group6_bckt_2/data/"
SAVE_PATH = "gs://bdp_group6_bckt_2/data/processed_data"

filepaths = glob(os.path.join(DATA_PATH, "document_parses", "**", "*.json"))
print("Total number of papers:", len(filepaths))

Total number of papers: 716956


In [6]:
schema = generate_schema()
df = jsons_to_df(spark, paths=filepaths, schema=schema)

print("--- Schema before processing ---")
df.printSchema()

# Get title
df = df.withColumn('json_title', F.col('metadata.title'))

# Get authors
df = df.withColumn('firstnames', F.col('metadata.authors.first'))  # First name
df = df.withColumn('middlenames', F.col('metadata.authors.middle'))  # Middle name

# Middle names are list of list. Convert it to a list of strings
@F.udf(returnType=ArrayType(StringType()))
def parse_middlenames(array):
    return [" ".join(w).strip() for w in array]

df = df.withColumn('middlenames', parse_middlenames("middlenames"))
df = df.withColumn('lastnames', F.col('metadata.authors.last'))  # Last name
df = df.withColumn('suffixes', F.col('metadata.authors.suffix'))  # Suffix name

# Concat the first, middle, last and suffix names for each author
df = df.withColumn("json_authors", F.expr(
    "transform(firstnames, (x, i) -> concat(x, ' ', middlenames[i], ' ', lastnames[i], ' ', suffixes[i]))"))

# Concat the list of authors into one "; " separated string
df = df.withColumn('json_authors', F.concat_ws("; ", F.expr("transform(json_authors, x -> trim(x))")))

# Remove additional empty spaces from the names
@F.udf()
def parse_authors(array):
    return " ".join([x for x in array.split() if x.strip()])

df = df.withColumn('json_authors', F.concat(parse_authors("json_authors")))

# Get abstract
df = df.withColumn('json_abstract', F.col('abstract.text'))
df = df.withColumn('json_abstract', F.concat_ws(". ", 'json_abstract'))

# Get body text
df = df.withColumn('body_text', F.col('body_text.text'))
df = df.withColumn('body_text', F.concat_ws(". ", 'body_text'))

df = df.drop(*["metadata", "abstract", "firstnames", "middlenames", "lastnames", "suffixes"])

print()
print("--- Schema after processing ---")
df.printSchema()

                                                                                

--- Schema before processing ---
root
 |-- paper_id: string (nullable = true)
 |-- metadata: struct (nullable = true)
 |    |-- title: string (nullable = true)
 |    |-- authors: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- first: string (nullable = true)
 |    |    |    |-- middle: array (nullable = true)
 |    |    |    |    |-- element: string (containsNull = true)
 |    |    |    |-- last: string (nullable = true)
 |    |    |    |-- suffix: string (nullable = true)
 |-- abstract: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- text: string (nullable = true)
 |    |    |-- section: string (nullable = true)
 |-- body_text: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- text: string (nullable = true)
 |    |    |-- section: string (nullable = true)


--- Schema after processing ---
root
 |-- paper_id: string (nullable = true)
 |-- body_text: string (nulla

In [7]:
# Load the metadata
metadata = spark.read.csv(os.path.join(INPUT_PATH, "metadata.csv"), inferSchema=True, header=True, multiLine=True)
print("Total number of metadata papers:", metadata.count())

[Stage 2:>                                                          (0 + 1) / 1]

Total number of metadata papers: 1056660


                                                                                

In [8]:
metadata.printSchema()

root
 |-- cord_uid: string (nullable = true)
 |-- sha: string (nullable = true)
 |-- source_x: string (nullable = true)
 |-- title: string (nullable = true)
 |-- doi: string (nullable = true)
 |-- pmcid: string (nullable = true)
 |-- pubmed_id: string (nullable = true)
 |-- license: string (nullable = true)
 |-- abstract: string (nullable = true)
 |-- publish_time: string (nullable = true)
 |-- authors: string (nullable = true)
 |-- journal: string (nullable = true)
 |-- mag_id: string (nullable = true)
 |-- who_covidence_id: string (nullable = true)
 |-- arxiv_id: string (nullable = true)
 |-- pdf_json_files: string (nullable = true)
 |-- pmc_json_files: string (nullable = true)
 |-- url: string (nullable = true)
 |-- s2_id: string (nullable = true)



In [9]:
# Rename a few columns
metadata = metadata.withColumnRenamed("authors", "metadata_authors")
metadata = metadata.withColumnRenamed("abstract", "metadata_abstract")
metadata = metadata.withColumnRenamed("title", "metadata_title")

In [11]:
@F.udf()
def split_path(s):
    if s:
        return s.split("/")[-1].split(".json")[0].split(".xml")[0]
    return s

metadata = metadata.withColumn("pdf_json_files", split_path("pdf_json_files"))
metadata = metadata.withColumn("pmc_json_files", split_path("pmc_json_files"))

metadata = metadata.drop(*["sha", "license", "mag_id", "who_covidence_id", "url", "s2_id",
                           "arxiv_id", "doi", "pmcid", "pubmed_id"])

In [12]:
metadata.printSchema()

root
 |-- cord_uid: string (nullable = true)
 |-- source_x: string (nullable = true)
 |-- metadata_title: string (nullable = true)
 |-- metadata_abstract: string (nullable = true)
 |-- publish_time: string (nullable = true)
 |-- metadata_authors: string (nullable = true)
 |-- journal: string (nullable = true)
 |-- pdf_json_files: string (nullable = true)
 |-- pmc_json_files: string (nullable = true)



In [16]:
# Create a paper_id column consistent with the JSON files
metadata = metadata.withColumn("paper_id", F.when(F.col("pmc_json_files").isNull() | F.isnan("pmc_json_files"),
                                                 metadata.pdf_json_files).otherwise(metadata.pmc_json_files))

In [17]:
# Merge the JSON data and metadata files on the paper_id column
merged = metadata.join(df, on="paper_id", how="inner")

# Decide on which title, abstract, and authors column to trust
merged = merged.withColumn("title",
                            F.when(F.isnan("metadata_title") | F.col("metadata_title").isNull(),
                                   merged.json_title).otherwise(merged.metadata_title))
merged = merged.withColumn("abstract",
                            F.when(F.isnan("metadata_abstract") | F.col("metadata_abstract").isNull(),
                                   merged.json_abstract).otherwise(merged.metadata_abstract))
merged = merged.withColumn("authors",
                            F.when(F.isnan("metadata_authors") | F.col("metadata_authors").isNull(),
                                   merged.json_authors).otherwise(merged.metadata_authors))

# Delete redundant columns
merged = merged.drop(*["metadata_title", "json_title", "metadata_abstract", "json_abstract",
                       "metadata_authors", "json_authors", "pdf_json_files", "pmc_json_files"])

In [18]:
merged.printSchema()

root
 |-- paper_id: string (nullable = true)
 |-- cord_uid: string (nullable = true)
 |-- source_x: string (nullable = true)
 |-- publish_time: string (nullable = true)
 |-- journal: string (nullable = true)
 |-- body_text: string (nullable = false)
 |-- title: string (nullable = true)
 |-- abstract: string (nullable = true)
 |-- authors: string (nullable = true)



In [19]:
# Check for missing values anywhere
merged.select([F.count(F.when(F.isnan(c) | F.col(c).isNull(), c)).alias(c) for c in merged.columns]).show()



+--------+--------+--------+------------+-------+---------+-----+--------+-------+
|paper_id|cord_uid|source_x|publish_time|journal|body_text|title|abstract|authors|
+--------+--------+--------+------------+-------+---------+-----+--------+-------+
|       0|       0|       0|           0|      0|        0|    0|       0|      0|
+--------+--------+--------+------------+-------+---------+-----+--------+-------+



                                                                                

In [21]:
print("Final number of papers:", merged.count())

Final number of papers: 598169


In [20]:
# Save data in parquet format
merged.write.parquet(os.path.join(SAVE_PATH, "processed_data.parquet"), mode="overwrite")