In [1]:
import json
import os
import time

import requests
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, StructField, StructType

In [2]:
# Start (or reuse, via getOrCreate) the Spark application on the standalone
# cluster: at most 2 cores, 512m per executor, event logs kept in the workspace.
spark = (
    SparkSession.builder
    .appName("HuggingFaceAPICall")
    .master("spark://spark-master:7077")
    .config("spark.cores.max", "2")
    .config("spark.executor.memory", "512m")
    .config("spark.eventLog.enabled", "true")
    .config("spark.eventLog.dir", "file:///opt/workspace/events")
    .getOrCreate()
)

# Load the papers dataset. The relative path resolves against Spark's default
# filesystem (assumed to be HDFS on this cluster -- confirm fs.defaultFS).
df = spark.read.json("all_papers.json")

# Hugging Face API token, read once on the driver. Never hard-code it: the token
# that was previously pasted into this cell was exposed and should be revoked.
# PySpark serializes notebook-defined functions with cloudpickle, which captures
# this global by value, so executors do not need the variable in their own env.
HF_API_TOKEN = os.environ.get("HF_API_TOKEN")


# Define function to make Hugging Face API call
def call_hugging_face_api(text, model="oracat/bert-paper-classifier", max_retries=3, timeout=30):
    """Send ``text`` to a model on the Hugging Face Inference API and return the JSON reply.

    Failed requests are retried with exponential backoff (1s, 2s, 4s, ...).

    Parameters
    ----------
    text : str
        Sent as the ``inputs`` field of the JSON request body.
    model : str
        Hub model id appended to the inference endpoint URL.
    max_retries : int
        Total number of attempts (default 3, the previous hard-coded value).
    timeout : float
        Per-request timeout in seconds, so a stalled connection cannot hang
        the calling Spark task indefinitely. Timeouts count as failed attempts.

    Returns
    -------
    The decoded JSON response body (whatever ``response.json()`` yields).

    Raises
    ------
    ValueError
        If ``max_retries`` is less than 1 (the loop would otherwise silently
        return ``None``).
    RuntimeError
        If the ``HF_API_TOKEN`` environment variable was not set when this
        cell ran.
    requests.exceptions.RequestException
        Re-raised from the last attempt once every attempt has failed.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")
    if not HF_API_TOKEN:
        raise RuntimeError("HF_API_TOKEN environment variable is not set")

    api_url = f"https://api-inference.huggingface.co/models/{model}"
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    wait_time = 1
    for attempt in range(max_retries):
        try:
            response = requests.post(
                api_url, headers=headers, json={"inputs": text}, timeout=timeout
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            print(f"Attempt {attempt + 1} failed: {e}")
            if attempt < max_retries - 1:
                print(f"Waiting for {wait_time} seconds before retrying...")
                time.sleep(wait_time)
                wait_time *= 2  # Exponential backoff
            else:
                print("All attempts failed. Raising exception.")
                raise

24/05/23 19:54:52 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/23 19:54:55 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
                                                                                

In [3]:
# Define transformation function
def transform_row(row):
    """Map one paper row to an ``(id, theme)`` tuple.

    The row's ``title`` is classified with ``call_hugging_face_api``. The theme
    is the ``label`` of the first prediction in the response.

    Parameters
    ----------
    row : pyspark.sql.Row (or any mapping)
        Must provide ``'id'`` and ``'title'``.

    Returns
    -------
    tuple
        ``(row['id'], theme)``. ``theme`` is ``'Not detected'`` when the title
        is empty/None, when the API still fails after its retries, or when the
        response is not shaped like ``[[{'label': ...}, ...]]``.
    """
    json_id = row['id']
    text = row['title']
    if not text:
        # Nothing to classify; skip a pointless (and likely failing) API round-trip.
        return (json_id, 'Not detected')

    try:
        result = call_hugging_face_api(text)
    except requests.exceptions.RequestException:
        # call_hugging_face_api already retried with backoff; don't let a single
        # unreachable row abort the whole Spark job.
        return (json_id, 'Not detected')

    # Expected response for a single input: [[{'label': ..., 'score': ...}, ...]].
    # Non-list payloads and empty/mis-shaped lists are all reported as undetected.
    try:
        first_label = result[0][0]['label']
    except (IndexError, KeyError, TypeError):
        first_label = 'Not detected'

    return (json_id, first_label)

# Apply transformation function to each row. map() is lazy: the API calls run
# only when an action (the CSV write) evaluates the RDD.
transformed_data = df.rdd.map(transform_row)

# Convert RDD to DataFrame with an explicit schema. Passing only column names
# makes Spark infer the types by evaluating rows (rdd.first(), and up to 100
# rows when it sees nulls). That means an extra job and extra paid API calls
# before the real write, plus a ValueError when there are no papers at all.
# The id field is copied from the source schema, so it keeps its original type.
theme_schema = StructType([
    df.schema["id"],
    StructField("theme", StringType(), True),
])
transformed_df = transformed_data.toDF(theme_schema)

                                                                                

In [4]:
# Persist the (id, theme) table. This write is the action that actually runs the
# Hugging Face calls. coalesce(1) funnels everything into one task so that the
# output directory holds a single part-*.csv file (with a header row); any
# previous output is replaced.
(
    transformed_df
    .coalesce(1)
    .write
    .mode("overwrite")
    .option("header", True)
    .csv("tabular/paper_theme")
)

                                                                                

In [5]:
# Shut down the SparkContext so this application releases its cores and
# executors on the standalone cluster.
spark.stop()