## Pipeline Parameters

In [0]:
# Notebook parameter: the time window passed as `time_filter` to the Reddit "top" search below.
dbutils.widgets.dropdown(
    name='input_time_filter',
    defaultValue='year',
    choices=['year', 'month', 'day'],
)
input_time_filter = dbutils.widgets.get('input_time_filter')

In [0]:
# Standard library imports
import asyncio
import datetime as dt
import os
import time

# Third-party library imports
import asyncpraw
import pandas as pd
from dotenv import load_dotenv
from pyspark.sql.functions import current_timestamp
from pyspark.sql.types import (
    BooleanType,
    DateType,
    FloatType,
    IntegerType,
    StringType,
    StructField,
    StructType,
    TimestampType
)

load_dotenv()

# Settings this notebook reads from the environment (optionally populated from a .env file above).
# Fail fast with a single clear error rather than passing None to the Reddit client or
# building table names such as "None.None.kdayno_bronze_reddit_top_posts" further down.
_required_env_vars = (
    'REDDIT_API_CLIENT_ID',
    'REDDIT_API_CLIENT_SECRET',
    'REDDIT_API_USER_AGENT',
    'DATABRICKS_CATALOG_NAME',
    'DATABRICKS_SCHEMA_NAME',
)
_missing_env_vars = [name for name in _required_env_vars if not os.getenv(name)]
if _missing_env_vars:
    raise RuntimeError(
        'Missing required environment variables: ' + ', '.join(_missing_env_vars)
    )

reddit_api_client_id = os.getenv('REDDIT_API_CLIENT_ID')
reddit_api_client_secret = os.getenv('REDDIT_API_CLIENT_SECRET')
reddit_api_user_agent = os.getenv('REDDIT_API_USER_AGENT')
catalog_name = os.getenv('DATABRICKS_CATALOG_NAME')
schema_name = os.getenv('DATABRICKS_SCHEMA_NAME')

# Shared asyncpraw client used by the extraction cells below.
reddit = asyncpraw.Reddit(
    client_id=reddit_api_client_id,
    client_secret=reddit_api_client_secret,
    user_agent=reddit_api_user_agent,
    # Presumably the max seconds asyncpraw will wait out a Reddit rate-limit response
    # before raising (see asyncpraw configuration docs) — confirm for the pinned version.
    ratelimit_seconds=600
)

In [0]:
# Spark schema for one Reddit post row. The field order here must match the order of the
# values supplied when the DataFrame is created in the load cell. Every column is nullable.
reddit_post_schema = StructType([
    StructField(field_name, field_type, True)
    for field_name, field_type in (
        ("company_name", StringType()),
        ("post_id", StringType()),
        ("post_title", StringType()),
        ("subreddit_id", StringType()),
        ("subreddit", StringType()),
        # DateType stores a calendar date only; any time-of-day component is not kept.
        ("created_utc", DateType()),
        ("score", IntegerType()),
        ("upvote_ratio", FloatType()),
        ("num_comments", IntegerType()),
        ("post_body_text", StringType()),
        ("is_self_post", BooleanType()),
        ("is_original_content", BooleanType()),
        ("permalink", StringType()),
        ("post_url", StringType()),
    )
])

## Get S&P 500 Company Names

In [0]:
# Company names from the bronze S&P 500 table; each name is used later as a Reddit search keyword.
SP500_companies = spark.read.table(
    f'{catalog_name}.{schema_name}.kdayno_bronze_SP500_companies'
).select('company_name')

# Collect the single-column rows to the driver and unpack each into its value.
SP500_companies_list = [company_name for (company_name,) in SP500_companies.collect()]

## Extract and Load "Top" Posts Data from Multiple Subreddits

In [0]:
# Subreddits to search, and the search terms to run in each of them (one per S&P 500 company).
subreddits = [
    'stocks',
    'investing',
    'trading',
    'wallstreetbets',
]
keywords = SP500_companies_list

In [0]:
# Empty the bronze table so the load cell can append a fresh snapshot for every subreddit.
# On a first run the table does not exist yet (the load cell's saveAsTable creates it), and
# DELETE FROM would fail with a table-not-found error, so only delete when it already exists.
reddit_top_posts_table_name = f'{catalog_name}.{schema_name}.kdayno_bronze_reddit_top_posts'
if spark.catalog.tableExists(reddit_top_posts_table_name):
    spark.sql(f"""DELETE FROM {reddit_top_posts_table_name}""")

In [0]:
keyword_count = 0

for subreddit in subreddits:

    sr = await reddit.subreddit(f"{subreddit}")

    post_data = {'company_name':[], 'post_id':[], 'post_title':[], 'subreddit_id':[], 'subreddit':[],  
                 'created_utc':[], 'score':[], 'upvote_ratio':[],'num_comments':[], 'post_body_text':[], 
                 'is_self_post':[], 'is_original_content':[], 'permalink':[], 'post_url':[] }

    for keyword in keywords:  

        print(f'Getting data for company: {keyword} ...')

        posts = sr.search(keyword, sort='top', time_filter=input_time_filter, limit=500)

        async for post in posts:

            print(f'Getting data for post: {post.title} ...')

            post_data['company_name'].append(str(keyword))
            post_data['post_id'].append(post.id)
            post_data['post_title'].append(post.title) 
            post_data['subreddit_id'].append(post.subreddit_id)
            post_data['subreddit'].append(str(post.subreddit))
            post_data['created_utc'].append(dt.datetime.fromtimestamp(post.created_utc))
            post_data['score'].append(post.score)
            post_data['upvote_ratio'].append(post.upvote_ratio)
            post_data['num_comments'].append(post.num_comments)
            post_data['post_body_text'].append(post.selftext)
            post_data['is_self_post'].append(post.is_self)
            post_data['is_original_content'].append(post.is_original_content)
            post_data['permalink'].append(post.permalink)
            post_data['post_url'].append(post.url)
        
        keyword_count += 1
        if keyword_count % 100 == 0:
            print("Pausing for 60 seconds to avoid hitting API rate limits...")
            time.sleep(60)

    # Unpacks the dict values, creates tuples, then converts to a list of tuples, where each tuple contains the data for a given Reddit post
    reddit_top_posts_df = spark.createDataFrame(list(zip(*post_data.values())), reddit_post_schema)

    reddit_top_posts_df = reddit_top_posts_df.withColumn('load_date_ts', current_timestamp())

    (reddit_top_posts_df.write.format("delta")
                              .mode("append")
                              .partitionBy('subreddit')
                              .saveAsTable(f'{catalog_name}.{schema_name}.kdayno_bronze_reddit_top_posts'))