In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, DateType, FloatType, IntegerType, TimestampType, ArrayType, StructType, StructField
from pyspark.sql.functions import from_unixtime, sum, rank,lag, explode, expr,spark_partition_id, to_date, coalesce, lit, to_timestamp, col, month, concat, count, max, when, dayofweek, datediff,dense_rank, desc, date_format
import pyspark.sql.functions as F
from pyspark.sql.window import Window
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import re
from pyspark.ml.feature import Tokenizer, StopWordsRemover, Word2Vec,HashingTF,IDF, CountVectorizer,VectorAssembler
from pyspark.sql.functions import udf
from pyspark.ml.feature import Tokenizer, CountVectorizer, IDF
from pyspark.ml import Pipeline,PipelineModel
from sparknlp.base import DocumentAssembler, Finisher
from sparknlp.annotator import LemmatizerModel
from pyspark.ml.classification import LinearSVC, LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator,BinaryClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
import sparknlp
import warnings


In [2]:
# # remove all warnings
# warnings.filterwarnings('ignore')

In [3]:
# spark = SparkSession.builder.appName('SparkBasics').getOrCreate()
spark = SparkSession.builder.appName('ml').getOrCreate()

# Get the context of the Pyspark environment
spark.sparkContext.getConf().getAll()
# Store spark context as a variable
sc = spark.sparkContext

In [4]:
reddit_data_df = spark.read.parquet("gs://msca-bdp-student-gcs/Group2_Final_Project/reddit_data/",header=True, inferSchema=True)
reddit_data_df = reddit_data_df.dropna()

reddit_data_df.printSchema()
reddit_data_df.select(col("subreddit_name_prefixed")).show()

                                                                                

root
 |-- archived: string (nullable = true)
 |-- author: string (nullable = true)
 |-- author_fullname: string (nullable = true)
 |-- body: string (nullable = true)
 |-- comment_type: string (nullable = true)
 |-- controversiality: string (nullable = true)
 |-- created_utc: string (nullable = true)
 |-- edited: string (nullable = true)
 |-- gilded: string (nullable = true)
 |-- id: string (nullable = true)
 |-- link_id: string (nullable = true)
 |-- locked: string (nullable = true)
 |-- name: string (nullable = true)
 |-- parent_id: string (nullable = true)
 |-- permalink: string (nullable = true)
 |-- retrieved_on: string (nullable = true)
 |-- score: string (nullable = true)
 |-- subreddit_id: string (nullable = true)
 |-- subreddit_name_prefixed: string (nullable = true)
 |-- subreddit_type: string (nullable = true)
 |-- total_awards_received: string (nullable = true)





+-----------------------+
|subreddit_name_prefixed|
+-----------------------+
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
+-----------------------+
only showing top 20 rows



                                                                                

In [5]:
# for debug only, slice datasets
# # Assuming df is your original DataFrame
games_df = reddit_data_df.filter(col("subreddit_name_prefixed") == "r/Games").limit(100)
diy_df = reddit_data_df.filter(col("subreddit_name_prefixed") == "r/DIY").limit(100)

# # If you need to combine these two DataFrames
combined_df = games_df.union(diy_df)

# # Show the result
combined_df.groupBy(col("subreddit_name_prefixed")).agg(count("*")).show()


                                                                                

+-----------------------+--------+
|subreddit_name_prefixed|count(1)|
+-----------------------+--------+
|                  r/DIY|     100|
|                r/Games|     100|
+-----------------------+--------+



In [6]:
twitter_data_df = spark.read.parquet("gs://msca-bdp-student-gcs/Group2_Final_Project/twitter_data/",header=True, inferSchema=True)
twitter_data_df = twitter_data_df.dropna()
twitter_data_df.show()
grouped_tw_by_usr = twitter_data_df.groupBy(col("user")).agg(count("*").alias("tweet_count")).orderBy(col("tweet_count"),ascending = False)
tw_by_usr_dist = grouped_tw_by_usr.select("*").groupBy("tweet_count").agg(count("*").alias("tweet_count_dist")).orderBy(col("tweet_count"))
# # grouped_tw_by_usr.show()
tw_by_usr_dist.show()
print(tw_by_usr_dist.count())

+--------------+-------------------+--------------------+-------+--------+-----+------+--------------------+
|          user|                 id|               tweet|replies|retweets|likes|quotes|                date|
+--------------+-------------------+--------------------+-------+--------+-----+------+--------------------+
|MarketOne_Intl|1633518817476321299|Are you strugglin...|      0|       1|    3|     0|2023-03-08T17:23:...|
|MarketOne_Intl|1606666027252719616|Wishing all our c...|      0|       0|    0|     0|2022-12-24T15:00:...|
|MarketOne_Intl|1572643083279388673|So you've built y...|      0|       0|    0|     0|2022-09-21T17:45:...|
|MarketOne_Intl|1565372152945278978|Is your #Salesdev...|      0|       1|    1|     1|2022-09-01T16:12:...|
|MarketOne_Intl|1562786915916849152|Is your sales tea...|      0|       1|    0|     0|2022-08-25T13:00:...|
|MarketOne_Intl|1554461167313108995|Join Oktopost and...|      0|       0|    0|     0|2022-08-02T13:36:...|
|MarketOne_Intl|155

                                                                                

+-----------+----------------+
|tweet_count|tweet_count_dist|
+-----------+----------------+
|          1|            1947|
|          2|            1265|
|          3|             977|
|          4|             878|
|          5|             791|
|          6|             640|
|          7|             610|
|          8|             559|
|          9|             554|
|         10|             529|
|         11|             516|
|         12|             474|
|         13|             447|
|         14|             431|
|         15|             425|
|         16|             409|
|         17|             376|
|         18|             367|
|         19|             342|
|         20|             337|
+-----------+----------------+
only showing top 20 rows





1001


                                                                                

In [8]:
# Tokenize and stop word removal
def clean_text(text):
    # Deal with component words
    re.sub(r'(?<=[a-z])(?=[A-Z])', ' ', text)
    # Convert to lowercase
    text = text.lower()
    # Remove Http / Https links in the text
    text = re.sub(r'http\S+', '', text)
    text = re.sub(r'https\S+', '', text)
    # Remove special characters and numbers
    text = re.sub(r'[^a-zA-Z\s]', '', text)
    # Handling repeated characters (more than 2)
    text = re.sub(r'(.)\1+', r'\1\1', text)
    # Remove extra spaces
    text = re.sub(r'\s+', ' ', text).strip()
    return text

In [10]:
# For the reddit datasets, could use subreddits as labels and predict how likely this comment belongs to what subreddits

reddit_count_df = reddit_data_df.groupby(col("subreddit_name_prefixed")).agg(count("*").alias("reddit_count")).orderBy("reddit_count", ascending=False)
reddit_count_df.show()




+-----------------------+------------+
|subreddit_name_prefixed|reddit_count|
+-----------------------+------------+
|   r/relationship_ad...|    13305512|
|               r/gaming|    11472113|
|   r/Damnthatsintere...|     9767618|
|    r/mildlyinteresting|     5795046|
|        r/todayilearned|     5059976|
|           r/technology|     4938743|
|       r/Showerthoughts|     4313694|
|             r/buildapc|     3037405|
|      r/personalfinance|     2801977|
|          r/LifeProTips|     2621564|
|         r/changemyview|     2386746|
|                 r/tifu|     2372659|
|                r/Games|     2159863|
|              r/science|     2039652|
|    r/explainlikeimfive|     2005588|
|                r/books|     1964390|
|       r/suggestmeabook|     1283406|
|            r/gardening|     1273715|
|                r/space|     1270348|
|              r/Fantasy|     1123160|
+-----------------------+------------+
only showing top 20 rows



                                                                                

In [11]:
# for debug only -- slice count datasets
reddit_count_df = combined_df.groupby(col("subreddit_name_prefixed")).agg(count("*").alias("reddit_count")).orderBy("reddit_count", ascending=False)
reddit_count_df.show()



+-----------------------+------------+
|subreddit_name_prefixed|reddit_count|
+-----------------------+------------+
|                  r/DIY|         100|
|                r/Games|         100|
+-----------------------+------------+



                                                                                

In [12]:
def encode_by_tags(df,labelcol):
    
    tags_lst = [row[labelcol] for row in df.select(col(labelcol)).distinct().collect()]
    df_res = df
    for tag in tags_lst:
        print(tag)
        df_res= df_res.withColumn(tag, when(col(labelcol) == tag, 1).otherwise(0))

    return tags_lst, df_res

In [13]:

def clean_df(df, inputcol):
    clean_text_udf = udf(clean_text, StringType())
    df_cleaned = df.withColumn(inputcol, clean_text_udf(df[inputcol]))
    return df_cleaned

def train_test_val_split(df, train_prob = 0.7, test_prob=0.2, val_prob= 0.1):
    train_df, test_df, validation_df = df.randomSplit([train_prob, test_prob, val_prob])
    return train_df, test_df, validation_df
    

def preprocess_pipeline(labelcol, inputcol,finfeaturecol):
    # tokenize the comments into words
    tokenizer = Tokenizer(inputCol=inputcol, outputCol="token")
    
    # remove stop words
    remover = StopWordsRemover(inputCol="token", outputCol="filtered_token")
    
    # vecotorize the words
    vectorizer = CountVectorizer(inputCol="filtered_token", outputCol="features")
    idf = IDF(inputCol="features", outputCol="tfidf_features")
    
    # assemble all features into 1 column
    assembler = VectorAssembler(inputCols=["features","tfidf_features"], outputCol=finfeaturecol)

    # Create the preprocessing piplines for the tweets
    pipeline = Pipeline().setStages([
        tokenizer,
        remover,
        vectorizer,
        idf,
        assembler
    ])
    return pipeline

def model_training_pipeline(featurecol, labelcol):
    svm = LinearSVC(labelCol=labelcol, featuresCol=featurecol)
    pipeline = Pipeline(stages=[svm])
    return svm,pipeline

In [23]:
def find_best_svm_hyperparameters(df,featurecol,labelCol):
    
    svm,pipeline = model_training_pipeline(featurecol, labelCol)
    # Set up the parameter grid
    paramGrid = ParamGridBuilder() \
        .addGrid(svm.maxIter, [10, 100]) \
        .addGrid(svm.regParam, [0.01, 0.1, 1.0]) \
        .build()

    # Evaluator
    evaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction", labelCol=labelCol)
    df.select(col(featurecol)).show()

    # Cross Validator
    crossval = CrossValidator(estimator=svm,
                              estimatorParamMaps=paramGrid,
                              evaluator=evaluator,
                              numFolds=3)

    # Run cross-validation and choose the best model
    print("start fitting")
    cvModel = crossval.fit(df)
    print(cvModel)

    return cvModel.bestModel

In [None]:

# clean dataframe
def get_model_specs(df,labelcol,inputcol,finfeaturecol,model_train):
    # model_dict
    model_lst = {}
    
    # model_spec
    model_spec = {}
    
    # encode the reddit_data_df into subreddits one-hot encoding df
    tags_lst, encoded_df = encode_by_tags(df,labelcol)
    
    # train the models by different tags
    for tag in tags_lst:
    
        f1_evaluator = MulticlassClassificationEvaluator(labelCol=tag, predictionCol="prediction", metricName="f1")
        accuracy_evaluator = MulticlassClassificationEvaluator(labelCol=tag, predictionCol="prediction", metricName="accuracy")

        # use clean_df to clean up special characters and spaces
        df_cleaned = clean_df(encoded_df,inputcol)

        # train, test, val random split dataset
        train_df, test_df, val_df = train_test_val_split(df_cleaned)
        df_cleaned.select(tag)

        #Preprocess the dataframes
        preprocess = preprocess_pipeline(tag,inputcol,finfeaturecol)
        preprocess = preprocess.fit(train_df)
        train_processed = preprocess.transform(train_df)
        test_processed = preprocess.transform(test_df)#.select(finfeaturecol, tag)
        validation_processed = preprocess.transform(val_df)#.select(finfeaturecol, tag)

        # find the best model by training through different hyperparameters of SVM
        train_processed.show()
        
        print("start predicting best model")
        best_model = find_best_svm_hyperparameters(train_processed,finfeaturecol,tag)
        print("done predicting best model")

        svm, pipeline = model_train(finfeaturecol,tag)
         # Predict the model
        model = pipeline.fit(train_processed)
        predictions = best_model.transform(test_processed)
        val_predictions = best_model.transform(validation_processed)

        # get the acc and f1-score
        f1_score = f1_evaluator.evaluate(predictions)
        accuracy = accuracy_evaluator.evaluate(predictions)
        val_accuracy = accuracy_evaluator.evaluate(val_predictions)
        
        #save model and model spec
        model_lst[tag] = model
        model_spec[tag] = {"f1_score": f1_score, "accuracy": accuracy, "val_accuracy": val_accuracy}
        model.write().overwrite().save("msca-bdp-student-gcs/Group2_Final_Project/model"+ tag +"_model")
    return model_lst, model_spec

# Transform the data
labelcol = "subreddit_name_prefixed"
inputcol = "body"
finfeaturecol = "final_features"
model_lst, model_spec = get_model_specs(combined_df,labelcol,inputcol,finfeaturecol, model_training_pipeline)



                                                                                

r/DIY
r/Games


                                                                                

+--------+--------------------+---------------+--------------------+------------+----------------+-----------+------------+------+-------+---------+------+----------+----------+--------------------+------------+-----+------------+-----------------------+--------------+---------------------+-----+-------+--------------------+--------------------+--------------------+--------------------+--------------------+
|archived|              author|author_fullname|                body|comment_type|controversiality|created_utc|      edited|gilded|     id|  link_id|locked|      name| parent_id|           permalink|retrieved_on|score|subreddit_id|subreddit_name_prefixed|subreddit_type|total_awards_received|r/DIY|r/Games|               token|      filtered_token|            features|      tfidf_features|      final_features|
+--------+--------------------+---------------+--------------------+------------+----------------+-----------+------------+------+-------+---------+------+----------+----------+-

                                                                                

+--------------------+
|      final_features|
+--------------------+
|(3208,[1,24,34,48...|
|(3208,[1,12,29,30...|
|(3208,[0,1,3,4,5,...|
|(3208,[1,27,37,45...|
|(3208,[26,45,348,...|
|(3208,[182,186,17...|
|(3208,[1,11,12,22...|
|(3208,[1,13,172,2...|
|(3208,[11,74,78,1...|
|(3208,[1,45,177,4...|
|(3208,[518,2122],...|
|(3208,[12,26,95,1...|
|(3208,[26,83,84,1...|
|(3208,[31,57,143,...|
|(3208,[12,151,182...|
|(3208,[1,2,8,9,11...|
|(3208,[1,2,161,35...|
|(3208,[267,552,97...|
|(3208,[1,12,27,40...|
|(3208,[2,27,54,57...|
+--------------------+
only showing top 20 rows

start fitting


23/11/25 05:26:45 WARN com.github.fommil.netlib.BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
23/11/25 05:26:45 WARN com.github.fommil.netlib.BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS

CrossValidatorModel_9eebdc6b5f0f
done predicting best model


                                                                                

+--------+--------------------+---------------+--------------------+------------+----------------+-----------+------------+------+-------+---------+------+----------+----------+--------------------+------------+-----+------------+-----------------------+--------------+---------------------+-----+-------+--------------------+--------------------+--------------------+--------------------+--------------------+
|archived|              author|author_fullname|                body|comment_type|controversiality|created_utc|      edited|gilded|     id|  link_id|locked|      name| parent_id|           permalink|retrieved_on|score|subreddit_id|subreddit_name_prefixed|subreddit_type|total_awards_received|r/DIY|r/Games|               token|      filtered_token|            features|      tfidf_features|      final_features|
+--------+--------------------+---------------+--------------------+------------+----------------+-----------+------------+------+-------+---------+------+----------+----------+-

                                                                                

+--------------------+
|      final_features|
+--------------------+
|(3276,[4,7,175,18...|
|(3276,[4,7,15,17,...|
|(3276,[50,57,73,5...|
|(3276,[44,72,295,...|
|(3276,[4,11,33,36...|
|(3276,[4,7,33,187...|
|(3276,[139,163,24...|
|(3276,[130,259,33...|
|(3276,[36,196,251...|
|(3276,[3,9,35,36,...|
|(3276,[2,3,4,5,7,...|
|(3276,[6,9,28,164...|
|(3276,[28,40,49,8...|
|(3276,[13,84,115,...|
|(3276,[273,377,40...|
|(3276,[7,11,17,33...|
|(3276,[11,17,35,7...|
|(3276,[7,30,54,81...|
|(3276,[3,25,42,12...|
|(3276,[192,633,10...|
+--------------------+
only showing top 20 rows

start fitting




CrossValidatorModel_f44b849622c4
done predicting best model


                                                                                

In [26]:
print(model_spec)

{'r/DIY': {'f1_score': 0.8151511835722363, 'accuracy': 0.8157894736842105, 'val_accuracy': 0.7058823529411765}, 'r/Games': {'f1_score': 0.7373422054273119, 'accuracy': 0.8095238095238095, 'val_accuracy': 0.6923076923076923}}


In [29]:

# convert the model into dataframe for better visualization
model_data = [(tag, specs["f1_score"], specs["accuracy"],specs["val_accuracy"]) for tag, specs in model_spec.items()]

schema = StructType([
    StructField("model", StringType(), True),
    StructField("f1_score", FloatType(), True),
    StructField("accuracy", FloatType(), True),
    StructField("val_accuracy", FloatType(), True)
])

# Create DataFrame
model_df = spark.createDataFrame(model_data, schema)

model_df.show()

+-------+----------+----------+------------+
|  model|  f1_score|  accuracy|val_accuracy|
+-------+----------+----------+------------+
|  r/DIY|0.81515115|0.81578946|   0.7058824|
|r/Games| 0.7373422| 0.8095238|   0.6923077|
+-------+----------+----------+------------+

