## Pipeline Parameters

In [0]:
# ETL input widgets: which subreddits to search, how results are sorted and
# time-filtered, the per-search post limit, and the bronze table to load.
dbutils.widgets.text("input_subreddit_list", "['stocks', 'investing', 'trading', 'wallstreetbets']")
dbutils.widgets.dropdown("input_sort_type", "top", ["relevance", "hot", "top", "new", "comments"])
dbutils.widgets.dropdown("input_time_filter", "month", ["day", "month", "year"])
dbutils.widgets.text("input_number_of_posts_limit", "500")
dbutils.widgets.text("bronze_table_name", "kdayno_bronze_reddit_top_posts")

# Audit Parameters
# All default to an empty string; presumably populated by the job that
# runs this notebook -- TODO confirm against the job definition.
for audit_widget_name in (
    "job_id",
    "job_name",
    "job_start_date",
    "job_start_datetime",
    "task_run_id",
    "task_name",
):
    dbutils.widgets.text(audit_widget_name, "")

In [0]:
%run ../utils/loggers

In [0]:
# Standard library imports
import ast
import asyncio
import datetime as dt
import os
import time

# Third-party library imports
import asyncpraw
import pandas as pd
from dotenv import load_dotenv
from pyspark.sql.functions import current_timestamp
from pyspark.sql.types import (
    BooleanType,
    DateType,
    FloatType,
    IntegerType,
    StringType,
    StructField,
    StructType,
    TimestampType
)

In [0]:
# ETL Inputs
# The subreddit widget holds a Python literal such as "['stocks', 'investing']".
input_subreddit_list = ast.literal_eval(dbutils.widgets.get("input_subreddit_list"))
if isinstance(input_subreddit_list, str):
    # A single quoted name (e.g. "'stocks'") parses to a plain str. Iterating
    # a str yields its characters, so wrap it to keep "one entry per subreddit".
    input_subreddit_list = [input_subreddit_list]
input_time_filter = dbutils.widgets.get("input_time_filter")
input_sort_type = dbutils.widgets.get("input_sort_type")
input_number_of_posts_limit = int(dbutils.widgets.get("input_number_of_posts_limit"))

# Load variables from a .env file (if one is found) into the process
# environment before the os.getenv lookups below.
load_dotenv()

source_table_name = 'kdayno_bronze_SP500_companies'
target_bronze_table_name = dbutils.widgets.get("bronze_table_name")
# os.getenv returns None when a variable is unset; these are interpolated
# into fully-qualified table names by later cells.
catalog_name = os.getenv('DATABRICKS_CATALOG_NAME')
schema_name = os.getenv('DATABRICKS_SCHEMA_NAME')

# Reddit API Credentials
reddit_api_client_id = os.getenv('REDDIT_API_CLIENT_ID')
reddit_api_client_secret = os.getenv('REDDIT_API_CLIENT_SECRET')
reddit_api_user_agent = os.getenv('REDDIT_API_USER_AGENT')

# Audit Variables
job_id = dbutils.widgets.get('job_id')
job_name = dbutils.widgets.get('job_name')
job_start_date = dbutils.widgets.get('job_start_date')
job_start_datetime = dbutils.widgets.get('job_start_datetime')
task_run_id = dbutils.widgets.get('task_run_id')
task_name = dbutils.widgets.get('task_name')

# Schema of one extracted Reddit post. Field order matters: rows are built
# positionally in the extract cell. Note that created_utc is a DateType, so
# only the calendar date of the post is stored.
reddit_post_schema = StructType([
    StructField("company_name", StringType(), True),
    StructField("post_id", StringType(), True),
    StructField("post_title", StringType(), True),
    StructField("subreddit_id", StringType(), True),
    StructField("subreddit", StringType(), True),
    StructField("created_utc", DateType(), True),
    StructField("score", IntegerType(), True),
    StructField("upvote_ratio", FloatType(), True),
    StructField("num_comments", IntegerType(), True),
    StructField("post_body_text", StringType(), True),
    StructField("is_self_post", BooleanType(), True),
    StructField("is_original_content", BooleanType(), True),
    StructField("permalink", StringType(), True),
    StructField("post_url", StringType(), True)
])

## Pipeline Logging

In [0]:
# Snapshot of the run-time ETL inputs, recorded with the audit entry.
input_params_at_run_time = [input_subreddit_list, input_time_filter, input_sort_type, input_number_of_posts_limit]

# audit_logger / etl_logger come from the `%run ../utils/loggers` cell.
audit_logger(job_id, job_name, input_params_at_run_time, job_start_date, job_start_datetime, task_run_id,  task_name, source_table_name, target_bronze_table_name)

# The name `etl_logger` is rebound from the factory to the logger it returns
# (later cells call `etl_logger.info(...)`). Keep the factory under a private
# name so re-running this cell does not try to call the logger instance.
_etl_logger_factory = globals().get('_etl_logger_factory', etl_logger)
etl_logger = _etl_logger_factory()

## Get S&P 500 Company Names

In [0]:
# Read the company-name column from the S&P 500 source table and pull it to
# the driver as a plain Python list (one entry per row, in collected order).
sp500_source_table = f'{catalog_name}.{schema_name}.{source_table_name}'

SP500_companies = spark.read.table(sp500_source_table).select('company_name')

SP500_companies_list = []
for company_row in SP500_companies.collect():
    SP500_companies_list.append(company_row['company_name'])

## Extract and Load Posts data from multiple Subreddits

In [0]:
# Settings for the async Reddit client; credentials come from the environment
# variables read in the configuration cell.
reddit_client_settings = {
    'client_id': reddit_api_client_id,
    'client_secret': reddit_api_client_secret,
    'user_agent': reddit_api_user_agent,
    # Passed through to asyncpraw's `ratelimit_seconds` configuration option.
    'ratelimit_seconds': 600,
}
reddit = asyncpraw.Reddit(**reddit_client_settings)

# Search space for the extract loop: every company name is searched in
# every subreddit.
keywords = SP500_companies_list
subreddits = input_subreddit_list

In [0]:
# Remove all existing rows from the target bronze table; the extract cell
# appends the fresh rows afterwards.
# NOTE(review): if the extract fails part-way, the table stays empty or
# partially loaded until the next successful run.
bronze_table_fqn = f'{catalog_name}.{schema_name}.{target_bronze_table_name}'
spark.sql(f"DELETE FROM {bronze_table_fqn}")

In [0]:
keyword_count = 0

for subreddit in subreddits:

    etl_logger.info(f'Processing subreddit: r/{subreddit} ...')

    sr = await reddit.subreddit(f"{subreddit}")

    post_data = {'company_name':[], 'post_id':[], 'post_title':[], 'subreddit_id':[], 'subreddit':[],  
                 'created_utc':[], 'score':[], 'upvote_ratio':[],'num_comments':[], 'post_body_text':[], 
                 'is_self_post':[], 'is_original_content':[], 'permalink':[], 'post_url':[] }

    for keyword in keywords:  

        etl_logger.info(f'Getting data for company: {keyword} ...')

        posts = sr.search(keyword, sort=input_sort_type, time_filter=input_time_filter, limit=input_number_of_posts_limit)

        async for post in posts:

            etl_logger.info(f'Getting data for post: {post.title} ...')

            post_data['company_name'].append(str(keyword))
            post_data['post_id'].append(post.id)
            post_data['post_title'].append(post.title) 
            post_data['subreddit_id'].append(post.subreddit_id)
            post_data['subreddit'].append(str(post.subreddit))
            post_data['created_utc'].append(dt.datetime.fromtimestamp(post.created_utc))
            post_data['score'].append(post.score)
            post_data['upvote_ratio'].append(post.upvote_ratio)
            post_data['num_comments'].append(post.num_comments)
            post_data['post_body_text'].append(post.selftext)
            post_data['is_self_post'].append(post.is_self)
            post_data['is_original_content'].append(post.is_original_content)
            post_data['permalink'].append(post.permalink)
            post_data['post_url'].append(post.url)
        
        keyword_count += 1
        if keyword_count % 100 == 0:
            etl_logger.info("Pausing for 60 seconds to avoid hitting API rate limits...")
            time.sleep(60)

    # Unpacks the dict values, creates tuples, then converts to a list of tuples, where each tuple contains the data for a given Reddit post
    reddit_posts_df = spark.createDataFrame(list(zip(*post_data.values())), reddit_post_schema)
    
    reddit_posts_df = reddit_posts_df.withColumn('load_date_ts', current_timestamp())

    etl_logger.info(f'Loading: {reddit_posts_df.count()} rows to: {target_bronze_table_name}')

    (reddit_posts_df.write.format("delta")
                              .mode("append")
                              .partitionBy('subreddit')
                              .saveAsTable(f'{catalog_name}.{schema_name}.{target_bronze_table_name}'))