In [0]:
# Standard library imports
import os

# Third-party library imports
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType, StructType, StructField, FloatType

load_dotenv()

In [0]:
# Unity Catalog location of the bronze input tables and the silver output tables.
# Both values are read from the environment (which load_dotenv() may have populated from a .env file).
catalog_name = os.getenv('DATABRICKS_CATALOG_NAME')
schema_name = os.getenv('DATABRICKS_SCHEMA_NAME')

# Fail fast: if either variable is unset or empty, os.getenv yields None/"" and every
# table name below silently becomes e.g. "None.None.kdayno_bronze_reddit_top_posts",
# which only surfaces later as a confusing table-not-found error.
if not (catalog_name and schema_name):
    raise EnvironmentError(
        "DATABRICKS_CATALOG_NAME and DATABRICKS_SCHEMA_NAME must both be set "
        f"(got catalog={catalog_name!r}, schema={schema_name!r})"
    )

In [0]:
# Load the FinBERT tokenizer and sequence-classification model once in the notebook
# session and wrap them in a text-classification pipeline. The sentiment UDF calls
# `classifier` for every row it scores.
model_name = "ProsusAI/finbert"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
classifier = pipeline(task='text-classification', model=model, tokenizer=tokenizer)

# Create a UDF for sentiment analysis.
# The UDF returns a struct so that both the predicted label and its score come back
# from a single model call per row.
# NOTE(review): this is a row-at-a-time Python UDF; a pandas_udf that classifies
# batches of titles would likely be much faster on large tables.
@udf(returnType=StructType([
    StructField("sentiment", StringType(), True),
    StructField("sentiment_score", FloatType(), True)
]))
def sentiment_analysis_udf(text):
    """Classify the sentiment of ``text`` with the FinBERT pipeline.

    Args:
        text: The string to classify (here, a Reddit post title). May be None
            when the source column is null.

    Returns:
        A ``(sentiment, sentiment_score)`` tuple: the top label returned by the
        pipeline and its score as a float. Returns ``(None, None)`` for a null
        input, so a single missing title does not fail the whole Spark task.
    """
    # The pipeline cannot classify None, so map null input to null struct fields
    # instead of letting the exception abort the task.
    if text is None:
        return (None, None)
    # truncation=True is forwarded to the tokenizer. Inputs longer than the model's
    # maximum sequence length are then clipped instead of raising.
    result = classifier(text, truncation=True)[0]
    return (result['label'], float(result['score']))

## Sentiment Analysis of Reddit "Top" Posts

In [0]:
# Score every "Top" post title with the sentiment UDF, then expand the returned struct
# so its `sentiment` and `sentiment_score` fields become top-level columns. The struct
# column itself is kept alongside them.
reddit_top_posts_df = spark.table(f"{catalog_name}.{schema_name}.kdayno_bronze_reddit_top_posts")

reddit_top_posts_sentiment_df = (
    reddit_top_posts_df
    .withColumn("post_title_sentiment_result", sentiment_analysis_udf(col("post_title")))
    .select("*", "post_title_sentiment_result.*")
)

In [0]:
# Persist the scored "Top" posts as a Delta table partitioned by subreddit,
# fully overwriting the previous contents on every run.
reddit_top_posts_sentiment_df.write.saveAsTable(
    f"{catalog_name}.{schema_name}.kdayno_silver_reddit_top_posts_sentiment",
    format="delta",
    mode="overwrite",
    partitionBy="subreddit",
)

## Sentiment Analysis of Reddit "Hot" Posts

In [0]:
# Score every "Hot" post title with the sentiment UDF, then expand the returned struct
# so its `sentiment` and `sentiment_score` fields become top-level columns. The struct
# column itself is kept alongside them.
reddit_hot_posts_df = spark.table(f"{catalog_name}.{schema_name}.kdayno_bronze_reddit_hot_posts")

reddit_hot_posts_sentiment_df = (
    reddit_hot_posts_df
    .withColumn("post_title_sentiment_result", sentiment_analysis_udf(col("post_title")))
    .select("*", "post_title_sentiment_result.*")
)

In [0]:
# Persist the scored "Hot" posts as a Delta table partitioned by subreddit,
# fully overwriting the previous contents on every run.
reddit_hot_posts_sentiment_df.write.saveAsTable(
    f"{catalog_name}.{schema_name}.kdayno_silver_reddit_hot_posts_sentiment",
    format="delta",
    mode="overwrite",
    partitionBy="subreddit",
)