In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, DateType, FloatType, IntegerType, TimestampType, ArrayType, StructType, StructField
from pyspark.sql.functions import from_unixtime, sum, rank,lag, explode, expr,spark_partition_id, to_date, coalesce, lit, to_timestamp, col, month, concat, count, max, when, dayofweek, datediff,dense_rank, desc, date_format
import pyspark.sql.functions as F
from pyspark.sql.window import Window
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import re
from pyspark.ml.feature import Tokenizer, StopWordsRemover, Word2Vec,HashingTF,IDF, CountVectorizer,VectorAssembler
from pyspark.sql.functions import udf
from pyspark.ml.feature import Tokenizer, CountVectorizer, IDF
from pyspark.ml import Pipeline,PipelineModel
from sparknlp.base import DocumentAssembler, Finisher
from sparknlp.annotator import LemmatizerModel
from pyspark.ml.classification import LinearSVC, LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator,BinaryClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
import sparknlp
import warnings
from google.cloud import storage

In [2]:
# # remove all warnings
# warnings.filterwarnings('ignore')

In [3]:
# spark = SparkSession.builder.appName('SparkBasics').getOrCreate()
spark = SparkSession.builder.appName('ml').getOrCreate()
stoarge_client = storage.Client()
# Get the context of the Pyspark environment
spark.sparkContext.getConf().getAll()
# Store spark context as a variable
sc = spark.sparkContext

:: loading settings :: url = jar:file:/usr/lib/spark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /root/.ivy2/cache
The jars for the packages stored in: /root/.ivy2/jars
com.johnsnowlabs.nlp#spark-nlp_2.12 added as a dependency
graphframes#graphframes added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-8eb4bcf7-cecc-4e71-9be9-59df7f016802;1.0
	confs: [default]
	found com.johnsnowlabs.nlp#spark-nlp_2.12;4.4.0 in central
	found com.typesafe#config;1.4.2 in central
	found org.rocksdb#rocksdbjni;6.29.5 in central
	found com.amazonaws#aws-java-sdk-bundle;1.11.828 in central
	found com.github.universal-automata#liblevenshtein;3.0.0 in central
	found com.google.protobuf#protobuf-java-util;3.0.0-beta-3 in central
	found com.google.protobuf#protobuf-java;3.0.0-beta-3 in central
	found com.google.code.gson#gson;2.3 in central
	found it.unimi.dsi#fastutil;7.0.12 in central
	found org.projectlombok#lombok;1.16.8 in central
	found com.google.cloud#google-cloud-storage;2.16.0 in central
	found com.google.guava#guava;31.1-jre in centra

In [4]:
reddit_data_df = spark.read.parquet("gs://msca-bdp-student-gcs/Group2_Final_Project/reddit_data/",header=True, inferSchema=True)
reddit_data_df = reddit_data_df.dropna()

reddit_data_df.printSchema()
reddit_data_df.select(col("subreddit_name_prefixed")).show()

                                                                                

root
 |-- archived: string (nullable = true)
 |-- author: string (nullable = true)
 |-- author_fullname: string (nullable = true)
 |-- body: string (nullable = true)
 |-- comment_type: string (nullable = true)
 |-- controversiality: string (nullable = true)
 |-- created_utc: string (nullable = true)
 |-- edited: string (nullable = true)
 |-- gilded: string (nullable = true)
 |-- id: string (nullable = true)
 |-- link_id: string (nullable = true)
 |-- locked: string (nullable = true)
 |-- name: string (nullable = true)
 |-- parent_id: string (nullable = true)
 |-- permalink: string (nullable = true)
 |-- retrieved_on: string (nullable = true)
 |-- score: string (nullable = true)
 |-- subreddit_id: string (nullable = true)
 |-- subreddit_name_prefixed: string (nullable = true)
 |-- subreddit_type: string (nullable = true)
 |-- total_awards_received: string (nullable = true)





+-----------------------+
|subreddit_name_prefixed|
+-----------------------+
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
|                  r/DIY|
+-----------------------+
only showing top 20 rows



                                                                                

In [5]:
# for debug only, slice datasets
# # Assuming df is your original DataFrame
# games_df = reddit_data_df.filter(col("subreddit_name_prefixed") == "r/Games").limit(100)
# diy_df = reddit_data_df.filter(col("subreddit_name_prefixed") == "r/DIY").limit(100)

# # If you need to combine these two DataFrames
# combined_df = games_df.union(diy_df)

# # Show the result
# combined_df.groupBy(col("subreddit_name_prefixed")).agg(count("*")).show()


In [6]:
# twitter_data_df = spark.read.parquet("gs://msca-bdp-student-gcs/Group2_Final_Project/twitter_data/",header=True, inferSchema=True)
# twitter_data_df = twitter_data_df.dropna()
# twitter_data_df.show()
# grouped_tw_by_usr = twitter_data_df.groupBy(col("user")).agg(count("*").alias("tweet_count")).orderBy(col("tweet_count"),ascending = False)
# tw_by_usr_dist = grouped_tw_by_usr.select("*").groupBy("tweet_count").agg(count("*").alias("tweet_count_dist")).orderBy(col("tweet_count"))
# # grouped_tw_by_usr.show()
# tw_by_usr_dist.show()
# print(tw_by_usr_dist.count())

In [7]:
# Tokenize and stop word removal
def clean_text(text):
    # Deal with component words
    re.sub(r'(?<=[a-z])(?=[A-Z])', ' ', text)
    # Convert to lowercase
    text = text.lower()
    # Remove Http / Https links in the text
    text = re.sub(r'http\S+', '', text)
    text = re.sub(r'https\S+', '', text)
    # Remove special characters and numbers
    text = re.sub(r'[^a-zA-Z\s]', '', text)
    # Handling repeated characters (more than 2)
    text = re.sub(r'(.)\1+', r'\1\1', text)
    # Remove extra spaces
    text = re.sub(r'\s+', ' ', text).strip()
    return text

In [8]:
# For the reddit datasets, could use subreddits as labels and predict how likely this comment belongs to what subreddits

# reddit_count_df = reddit_data_df.groupby(col("subreddit_name_prefixed")).agg(count("*").alias("reddit_count")).orderBy("reddit_count", ascending=False)
# reddit_count_df.show()


In [9]:
# for debug only -- slice count datasets
# reddit_count_df = combined_df.groupby(col("subreddit_name_prefixed")).agg(count("*").alias("reddit_count")).orderBy("reddit_count", ascending=False)
# reddit_count_df.show()

In [10]:
def encode_by_tags(df,labelcol):
    
    tags_lst = [row[labelcol] for row in df.select(col(labelcol)).distinct().collect()]
    df_res = df
    for tag in tags_lst:
        print(tag)
        df_res= df_res.withColumn(tag, when(col(labelcol) == tag, 1).otherwise(0))

    return tags_lst, df_res

In [11]:

def clean_df(df, inputcol):
    clean_text_udf = udf(clean_text, StringType())
    df_cleaned = df.withColumn(inputcol, clean_text_udf(df[inputcol]))
    return df_cleaned

def train_test_val_split(df, train_prob = 0.7, test_prob=0.2, val_prob= 0.1):
    train_df, test_df, validation_df = df.randomSplit([train_prob, test_prob, val_prob])
    return train_df, test_df, validation_df
    

def preprocess_pipeline(labelcol, inputcol,finfeaturecol):
    # tokenize the comments into words
    tokenizer = Tokenizer(inputCol=inputcol, outputCol="token")
    
    # remove stop words
    remover = StopWordsRemover(inputCol="token", outputCol="filtered_token")
    
    # vecotorize the words
    vectorizer = CountVectorizer(inputCol="filtered_token", outputCol="features")
    idf = IDF(inputCol="features", outputCol="tfidf_features")
    
    # assemble all features into 1 column
    assembler = VectorAssembler(inputCols=["features","tfidf_features"], outputCol=finfeaturecol)

    # Create the preprocessing piplines for the tweets
    pipeline = Pipeline().setStages([
        tokenizer,
        remover,
        vectorizer,
        idf,
        assembler
    ])
    return pipeline

def model_training_pipeline(featurecol, labelcol):
    svm = LinearSVC(labelCol=labelcol, featuresCol=featurecol)
    pipeline = Pipeline(stages=[svm])
    return svm,pipeline

In [12]:
def find_best_svm_hyperparameters(df,featurecol,labelCol):
    
    svm,pipeline = model_training_pipeline(featurecol, labelcol)
    # Set up the parameter grid
    paramGrid = ParamGridBuilder() \
        .addGrid(svm.maxIter, [10, 100]) \
        .addGrid(svm.regParam, [0.01, 0.1, 1.0]) \
        .build()

    # Evaluator
    evaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction", labelCol=labelCol)
    df.select(col(featurecol)).show()

    # Cross Validator
    crossval = CrossValidator(estimator=svm,
                              estimatorParamMaps=paramGrid,
                              evaluator=evaluator,
                              numFolds=3)

    # Run cross-validation and choose the best model
    print("start fitting")
    cvModel = crossval.fit(df)
    print(cvModel)

    return cvModel.bestModel

In [13]:
def list_all_files(bucket_name, folder_name):
    bucket = stoarge_client.bucket(bucket_name)
    file_lst = [('r/'+ blob.name.split("/")[2].split("_")[0]) for blob in bucket.list_blobs(prefix = folder_name)]
    file_lst = list(set(file_lst))
  
    return file_lst

In [23]:
bucket_name = 'msca-bdp-student-gcs'
folder_name = 'Group2_Final_Project/modelr/'
trained_tags = list_all_files(bucket_name, folder_name)

print(trained_tags)


['r/lifehacks', 'r/Showerthoughts', 'r/Damnthatsinteresting', 'r/SkincareAddiction', 'r/Documentaries', 'r/IWantToLearn', 'r/socialskills', 'r/YouShouldKnow', 'r/tifu', 'r/Games', 'r/AskHistorians', 'r/DIY', 'r/scifi', 'r/gadgets', 'r/IAmA', 'r/space', 'r/personalfinance', 'r/science', 'r/', 'r/UpliftingNews', 'r/Fantasy', 'r/explainlikeimfive', 'r/femalefashionadvice', 'r/bodyweightfitness', 'r/podcasts', 'r/todayilearned', 'r/gardening']


In [None]:

# clean dataframe
def get_model_specs(df,labelcol,inputcol,finfeaturecol,model_train):
    # model_dict
    model_lst = {}
    
    # model_spec
    model_spec = {}
    
    # encode the reddit_data_df into subreddits one-hot encoding df
    tags_lst, encoded_df = encode_by_tags(df,labelcol)
    
    # train the models by different tags
    # 更改这个标记
    
    for tag in tags_lst[30:40]:
        if tag in trained_tags:
            continue
    
        f1_evaluator = MulticlassClassificationEvaluator(labelCol=tag, predictionCol="prediction", metricName="f1")
        accuracy_evaluator = MulticlassClassificationEvaluator(labelCol=tag, predictionCol="prediction", metricName="accuracy")

        # use clean_df to clean up special characters and spaces
        df_cleaned = clean_df(encoded_df,inputcol)

        # train, test, val random split dataset
        train_df, test_df, val_df = train_test_val_split(df_cleaned)
        df_cleaned.select(tag)

        #Preprocess the dataframes
        preprocess = preprocess_pipeline(tag,inputcol,finfeaturecol)
        preprocess = preprocess.fit(train_df)
        train_processed = preprocess.transform(train_df).select(finfeaturecol, tag)
        test_processed = preprocess.transform(test_df).select(finfeaturecol, tag)
        validation_processed = preprocess.transform(val_df).select(finfeaturecol, tag)

        # find the best model by training through different hyperparameters of SVM
        train_processed.groupBy(tag).count().show()
        
#         print("start predicting best model")
#         best_model = find_best_svm_hyperparameters(train_processed,finfeaturecol,tag)
#         print("done predicting best model")

        svm, pipeline = model_train(finfeaturecol,tag)
         # Predict the model
        model = pipeline.fit(train_processed)
        predictions = model.transform(test_processed)
        val_predictions = model.transform(validation_processed)

        # get the acc and f1-score
        f1_score = f1_evaluator.evaluate(predictions)
        accuracy = accuracy_evaluator.evaluate(predictions)
        val_accuracy = accuracy_evaluator.evaluate(val_predictions)
        
        #save model and model spec
        model_lst[tag] = model
        model_spec[tag] = {"f1_score": f1_score, "accuracy": accuracy, "val_accuracy": val_accuracy}
        
        # 添加到 model 文件夹
        model.write().overwrite().save("gs://msca-bdp-student-gcs/Group2_Final_Project/model"+ tag +"_model")
    
    return model_lst, model_spec

# Transform the data
labelcol = "subreddit_name_prefixed"
inputcol = "body"
finfeaturecol = "final_features"
model_lst, model_spec = get_model_specs(reddit_data_df,labelcol,inputcol,finfeaturecol, model_training_pipeline)



                                                                                

r/YouShouldKnow
r/Fantasy
r/lifehacks
r/podcasts
r/UpliftingNews
r/AskHistorians
r/SkincareAddiction
r/IAmA
r/scifi
r/programming
r/Documentaries
r/todayilearned
r/gardening
r/IWantToLearn
r/science
r/explainlikeimfive
r/bodyweightfitness
r/bestof
r/Foodforthought
r/history
r/femalefashionadvice
r/Damnthatsinteresting
r/DIY
r/Showerthoughts
r/tifu
r/socialskills
r/Games
r/space
r/personalfinance
r/gadgets
r/LifeProTips
r/buildapc
r/boardgames
r/malefashionadvice
r/WritingPrompts
r/changemyview
r/philosophy
r/gaming
r/travel
r/technology
r/books
r/suggestmeabook
r/ifyoulikeblank
r/Fitness
r/GetMotivated
r/mildlyinteresting
r/EatCheapAndHealthy
r/sports
r/relationship_advice
r/askscience


23/11/26 03:53:55 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 2.9 MiB
23/11/26 03:56:03 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 2.9 MiB
23/11/26 03:56:09 WARN org.apache.spark.sql.catalyst.util.package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

+-------------+--------+
|r/LifeProTips|   count|
+-------------+--------+
|            1| 1835558|
|            0|64461618|
+-------------+--------+



23/11/26 03:57:46 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 04:00:17 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 04:00:21 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 04:02:51 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 04:02:52 WARN com.github.fommil.netlib.BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
23/11/26 04:02:52 WARN com.github.fommil.netlib.BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
23/11/26 04:02:55 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 04:03:13 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 04:03:16 WARN org.apache.spark.scheduler.DAGSc

+----------+--------+
|r/buildapc|   count|
+----------+--------+
|         1| 2126162|
|         0|64170174|
+----------+--------+



23/11/26 05:20:16 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 05:22:26 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 05:22:30 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 05:24:40 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 05:24:44 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 05:24:58 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 05:25:02 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 05:25:20 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 05:25:24 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task 

+------------+--------+
|r/boardgames|   count|
+------------+--------+
|           1|  720138|
|           0|65565597|
+------------+--------+



23/11/26 06:49:51 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 06:52:16 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 06:52:20 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 06:54:41 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 06:54:44 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 06:55:13 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 06:55:17 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 06:55:33 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 37.7 MiB
23/11/26 06:55:36 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task 

In [None]:
print(model_spec)

In [None]:

# convert the model into dataframe for better visualization
model_data = [(tag, specs["f1_score"], specs["accuracy"],specs["val_accuracy"]) for tag, specs in model_spec.items()]

schema = StructType([
    StructField("model", StringType(), True),
    StructField("f1_score", FloatType(), True),
    StructField("accuracy", FloatType(), True),
    StructField("val_accuracy", FloatType(), True)
])

# Create DataFrame
model_df = spark.createDataFrame(reddit_data_df, schema)

model_df.show()

In [None]:
# save all the models
# for key in model_lst:
#     model_lst[key].write().overwrite().save("msca-bdp-student-gcs/Group2_Final_Project/model"+ key +"_model")