In [12]:
# Create the Spark session for this notebook, or reuse the one already active
# in this kernel (that is what getOrCreate does). No master URL is set here, so
# the master comes from the launch environment / Spark defaults.
# NOTE(review): the recorded run of this cell failed with an AttributeError
# raised from getOrCreate; a stale, previously stopped session/context left in
# the kernel is a common cause — try restarting the kernel and verify against
# the full traceback.

from pyspark.sql import SparkSession

spark_session = (
    SparkSession.builder
    .appName("test_2")
    .getOrCreate()
)

AttributeError: 'NoneType' object has no attribute 'sc'

In [None]:
# Turn on the jupyter_black extension so code cells are auto-formatted with
# Black.
# NOTE(review): this cell runs after the Spark session cell, so that earlier
# cell is not covered; consider moving this cell to the top of the notebook.

import jupyter_black

jupyter_black.load()

In [None]:
# Load the raw data; Spark's JSON reader infers the schema from the JSON keys.
# multiLine=False (the JSON reader's default) expects one JSON object per line.
# Note: `header` is a CSV-only option and is ignored by the JSON reader, so it
# is not passed here.

DATA_PATH = "sample_data.json"

df_raw = spark_session.read.json(DATA_PATH, multiLine=False)
df_raw.show(10)

In [None]:
# Print the schema Spark inferred from the JSON keys (column names, types and
# nullability) so it can be checked before selecting columns in Step 1.

df_raw.printSchema()

In [None]:
# Step 1: Keep only the columns used by the later steps and drop the rest.

cols_to_keep = ["author", "body", "created_utc", "score", "subreddit", "subreddit_id"]

# select() takes column names directly. No comprehension is needed, and none
# should use `col` as a loop variable: that name is pyspark's `col` function,
# imported in Step 3.1.
df_reddit = df_raw.select(*cols_to_keep)
df_reddit.show()

In [None]:
# Step 2.1: Number of rows (posts) per subreddit, used to rank the biggest ones.
# The result has a "subreddit" column and a "count" column.

df_subred_count = (
    df_reddit
    .groupby("subreddit")
    .count()
)
df_subred_count.show()

In [None]:
# Step 2.2: Collect the names of the `subs_to_incl` largest subreddits (by row
# count) into a Python list, used as the filter in Step 3.1.

import pandas as pd

subs_to_incl = 100

# The per-subreddit counts are brought to the driver as a pandas frame and
# ranked there.
df_count_pd = df_subred_count.toPandas()
df_top_subs = df_count_pd.sort_values(by="count", ascending=False).head(subs_to_incl)
top_subs = df_top_subs["subreddit"].to_list()
print(top_subs)

In [None]:
# Step 3.1: Keep only rows whose subreddit is one of `top_subs` (Step 2.2).

from pyspark.sql.functions import col

in_top_subs = col("subreddit").isin(top_subs)
df_sub_filtered = df_reddit.where(in_top_subs)
df_sub_filtered.show()

In [None]:
# Step 4.1: Number of comments per author within the top subreddits; the
# activity threshold is applied in Step 4.2.

df_user_count = (
    df_sub_filtered
    .groupby("author")
    .count()
)
df_user_count.show()

In [None]:
# Step 4.2: Active users are authors with strictly more than
# `comment_threshold` comments in the top subreddits. Their names are also
# collected into a Python list on the driver.
# NOTE(review): the printed list grows with the number of active users and
# contains usernames — consider printing only its length before sharing.

comment_threshold = 1

df_top_users = df_user_count.where(col("count") > comment_threshold)
top_users = [row["author"] for row in df_top_users.select("author").collect()]
print(top_users)

In [None]:
# Step 4.3: Drop comments by inactive users, keeping only authors that appear
# in df_top_users (Step 4.2).
#
# A left-semi join keeps the rows of df_sub_filtered whose author is present in
# df_top_users. This avoids sending the driver-collected `top_users` list back
# to Spark as one `isin(...)` literal list, which grows with the number of
# active users. As with `isin`, rows with a null author are not kept.

df_user_filtered = df_sub_filtered.join(
    df_top_users.select("author"), on="author", how="left_semi"
)
df_user_filtered.show()

In [None]:
# Step 5.1: One row per author holding the set of distinct subreddits they
# commented in. The aggregated column keeps the name "subreddit".

from pyspark.sql.functions import collect_set

subreddit_set = collect_set(col("subreddit")).alias("subreddit")
df_user_subs = df_user_filtered.groupBy("author").agg(subreddit_set)
df_user_subs.show()

In [None]:
# Step 5.2: Create tuples from all those subreddits
#
# tuple_schema is the return type of the pair UDF: an array of structs, one
# struct per (tuple_1, tuple_2) subreddit pair. StringType must be imported for
# the struct fields; without it this cell raises NameError.

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, DataType, StringType, StructType, StructField

tuple_schema = ArrayType(
    StructType(
        [
            StructField("tuple_1", StringType(), False),
            StructField("tuple_2", StringType(), False),
        ]
    )
)


def tuple_from_list(lst):
    """Return every ordered pair ``(first, second)`` drawn from ``lst``.

    The result is the Cartesian product of ``lst`` with itself in row-major
    order. It includes self-pairs such as ``(a, a)`` and both orderings
    ``(a, b)`` and ``(b, a)``; an empty input gives an empty list.
    """
    pairs = []
    for first in lst:
        for second in lst:
            pairs.append((first, second))
    return pairs


# Wrap tuple_from_list as a Spark UDF returning tuple_schema. Then add, per
# author, the pairs built from their subreddit set as "subreddit_tuples".
tuple_udf = udf(tuple_from_list, returnType=tuple_schema)

pairs_expr = tuple_udf(col("subreddit"))
df_user_subs = df_user_subs.withColumn("subreddit_tuples", pairs_expr)
df_user_subs.show()

In [None]:
# Step 5.3: Explode and get the count of each tuple

from pyspark.sql.functions import explode, count

# One row per (author, subreddit pair); only the pair struct is kept.
df_exploded = df_user_subs.select(explode("subreddit_tuples").alias("tuple_col"))

# Rows per pair. Each author's subreddit list comes from collect_set (distinct
# values), so this is the number of authors active in both subreddits.
df_tuple_counts = df_exploded.groupBy("tuple_col").agg(count("*").alias("count"))

df_tuple_counts.show()

In [None]:
# Step 5.4: Drop self-pairs, i.e. rows whose two subreddits are identical.

is_self_pair = col("tuple_col.tuple_1") == col("tuple_col.tuple_2")
df_counts_filtered = df_tuple_counts.where(~is_self_pair)
df_counts_filtered.show()

In [None]:
# Step 5.6: Get rid of duplicates
#
# The pair UDF emits both (a, b) and (b, a) for every author, so the two
# orderings of a pair always carry the same count. Keeping only the ordering
# with tuple_1 < tuple_2 drops the mirrored duplicates (and any self-pairs) in
# Spark before collecting. Each kept pair is already in sorted order.
# The per-pair value is read as row["count"] rather than bound to a variable
# named `count`, so pyspark's `count` imported in Step 5.3 is not shadowed.

df_pairs_unique = df_counts_filtered.filter(
    col("tuple_col").getField("tuple_1") < col("tuple_col").getField("tuple_2")
)

result_no_dupes = [
    [(row["tuple_col"]["tuple_1"], row["tuple_col"]["tuple_2"]), row["count"]]
    for row in df_pairs_unique.collect()
]

df_result = pd.DataFrame(result_no_dupes, columns=["subreddits", "count"]).sort_values(
    by=["count"], ascending=False
)
df_result.head(20)