In [1]:
from pyspark.sql import SparkSession
import requests
import json
import re
from concurrent.futures import ThreadPoolExecutor
from requests.adapters import HTTPAdapter

In [2]:
def clean_publisher(publisher):
    '''
    Function to clean the publisher name

    The cleaning steps are applied in this order:
    1. drop the literal "Published by:" prefix text
    2. drop anything inside (...) or [...]
    3. drop surrounding whitespace and quotes
    4. keep only the part before the first ',', ';' or '/'
    5. drop periods and trailing digits

    Parameters:
    publisher (str): the name of the publisher (None is treated as empty)

    Returns:
    cleaned_publisher (str): the cleaned name of the publisher, possibly ''
    '''
    if publisher is None:
        return ''

    # Remove "Published by:" text
    publisher = publisher.replace("Published by:", "")
    # Remove content within parentheses or square brackets
    publisher = re.sub(r'\([^)]*\)|\[[^\]]*\]', '', publisher)
    # Remove surrounding quotes. Whitespace has to go first: removing the
    # "Published by:" prefix usually leaves a leading space, and str.strip
    # only removes the quote characters when they sit at the very edge.
    publisher = publisher.strip().strip('"\'')
    # Keep only the text before the first separator. A single split over all
    # three separators avoids leaving a ';' tail when a ',' is also present.
    publisher = re.split(r'[,;/]', publisher, maxsplit=1)[0]
    # The split can expose a closing quote (e.g. '"Wiley", Hoboken'), so the
    # edges are cleaned once more.
    publisher = publisher.strip().strip('"\'').strip()
    # Remove periods
    publisher = publisher.replace('.', '')
    # Remove trailing numbers
    publisher = re.sub(r'\d+$', '', publisher)
    return publisher.strip()

def get_publishers(json_file):
    '''
    Function to get all 'clean' publishers from a JSON file

    The file is read as JSON Lines: every non-blank line is parsed on its own.
    Lines that are not valid JSON, or that do not hold a JSON object, are
    reported on stdout and skipped.

    Parameters:
    json_file (str): the path to the JSON file

    Returns:
    publishers (list): a list of tuples (id, publisher); the publisher is
        'Unknown' when the record has none or it cleans down to nothing
    '''
    publishers = []
    # JSON text is UTF-8 by specification, so do not rely on the locale default.
    with open(json_file, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # Blank lines are not records; skip them silently.
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                print(f"Skipping invalid JSON: {line}")
                continue
            if not isinstance(data, dict):
                # Valid JSON, but a list/number/string has no fields to read.
                print(f"Skipping non-object JSON: {line}")
                continue

            raw_publisher = data.get('publisher')
            raw_id = data.get('id')
            # A JSON null must count as "missing"; str(None) would otherwise
            # yield the literal text 'None' as a publisher name or id.
            publisher = '' if raw_publisher is None else str(raw_publisher)
            idp = '' if raw_id is None else str(raw_id)

            cleaned_publisher = clean_publisher(publisher)
            # An empty result after cleaning is recorded as 'Unknown'.
            publishers.append((idp, cleaned_publisher or 'Unknown'))
    return publishers

def get_publisher_description(publisher_name, session, timeout=10):
    '''
    Function to get the description of a publisher from Wikidata

    Runs a `wbsearchentities` search for the name and returns the description
    of the first hit. This is best-effort: any failure while fetching is
    printed and reported as 'No description' instead of being raised.

    Parameters:
    publisher_name (str): the name of the publisher
    session (requests.Session): a session object to make HTTP requests
    timeout (float): seconds to wait for the HTTP request before giving up

    Returns:
    description (str): the description of the publisher, or 'No description'
        when the name is empty/'Unknown', the lookup fails, or the first hit
        carries no description

    '''
    # Nothing meaningful to search for.
    if not publisher_name or publisher_name == 'Unknown':
        return 'No description'

    url = "https://www.wikidata.org/w/api.php"
    params = {
        "action": "wbsearchentities",
        "format": "json",
        "language": "en",
        "search": publisher_name
    }

    try:
        # Without a timeout a stalled connection would block the caller forever.
        with session.get(url, params=params, timeout=timeout) as response:
            # Report HTTP errors as such rather than as a JSON decoding failure.
            response.raise_for_status()
            data = response.json()
    except Exception as e:
        print(f"Failed to fetch description for {publisher_name}: {e}")
        return 'No description'

    # Walk the payload defensively instead of hiding errors behind a bare except.
    matches = data.get("search") if isinstance(data, dict) else None
    if not matches:
        return 'No description'
    first_match = matches[0]
    description = first_match.get("description") if isinstance(first_match, dict) else None
    return description or 'No description'
    
def save_publisher_descriptions(publishers):
    '''
    Function to save the descriptions of all the publishers to a CSV file

    Builds a Spark DataFrame from the (id, publisher) pairs, looks up a
    description for every row with get_publisher_description and writes the
    (id, publisher, description) rows as a single-part CSV with a header to
    "dataout/paper_publishers", overwriting earlier output. The Spark session
    is stopped afterwards, even if the job fails.

    Parameters:
    publishers (list): a list of tuples (id, publisher)
    '''
    from pyspark.sql.types import StringType, StructField, StructType

    def _describe_partition(rows):
        # RDD transformations are lazy, so a session opened on the driver in a
        # `with` block is already closed again when the tasks finally run.
        # Open one session per partition, inside the task that uses it.
        with requests.Session() as session:
            session.mount('https://', HTTPAdapter(max_retries=3))
            for row in rows:
                yield (
                    row['id'],
                    row['publisher'],
                    get_publisher_description(row['publisher'], session),
                )

    # Initialize Spark session
    spark = (
        SparkSession.builder
        .appName("papers_publisher")
        .master("spark://spark-master:7077")
        .config("spark.cores.max", "2")
        .config("spark.executor.memory", "512m")
        .config("spark.eventLog.enabled", "true")
        .config("spark.eventLog.dir", "file:///opt/workspace/events")
        .getOrCreate()
    )

    try:
        # Convert list of tuples to DataFrame
        if publishers:
            publishers_df = spark.createDataFrame(publishers, ['id', 'publisher'])
        else:
            # Schema inference fails on an empty list; fall back to strings.
            publishers_df = spark.createDataFrame(
                [],
                StructType([
                    StructField('id', StringType(), True),
                    StructField('publisher', StringType(), True),
                ]),
            )

        # An explicit output schema stops toDF from running an inference job,
        # which would perform HTTP lookups that are repeated at write time.
        output_schema = StructType(
            list(publishers_df.schema.fields)
            + [StructField('description', StringType(), True)]
        )

        descriptions_rdd = publishers_df.rdd.mapPartitions(_describe_partition)

        # Convert RDD back to DataFrame
        descriptions_df = descriptions_rdd.toDF(output_schema)
        descriptions_df.coalesce(1).write.csv("dataout/paper_publishers", header=True, mode="overwrite")
    finally:
        # Stop Spark session, also on failure, so cluster resources are released
        spark.stop()

In [3]:
# Read all_papers.json line by line and collect the (id, cleaned publisher) pairs.
publishers = get_publishers("all_papers.json")
# Look up a Wikidata description for each pair and write the result as CSV
# under dataout/paper_publishers (runs a Spark job and performs HTTP requests).
save_publisher_descriptions(publishers)

24/05/23 08:10:00 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
                                                                                