In [25]:
import time
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import IntegerType, BooleanType, DateType, StructType, StructField
from pyspark.sql.functions import explode, col, split, array, array_min, concat, least, collect_set, size, sum, min


BUCKET_INPUT_PATH = "gs://iasd-input-data"
NB_WORKER_NODES = 1  # -------------- to be changed at each run


spark_session = SparkSession \
    .builder \
    .appName("PySpark App by Olivier & Jean-Loulou") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()
spark_context = spark_session.sparkContext


def load_rdd(path):
    return spark_context.textFile(path)


def load_df(path):
    return spark_session.read.format("csv").option("header","false")\
                .load(path)


def preprocess_rdd(rdd):
    return rdd.filter(lambda x: "#" not in x) \
                .map(lambda x: x.split("\t")) \
                .map(lambda x: (int(x[0]), int(x[1])))


def preprocess_df(df):
    col_name = df.columns[0]
    return df.filter(f"{col_name} NOT LIKE '#%'")\
                .withColumn('k', split(df[col_name], '\t').getItem(0)) \
                .withColumn('v', split(df[col_name], '\t').getItem(1)) \
                .drop(col_name)\
                .withColumn("k",col("k").cast(IntegerType())) \
                .withColumn("v",col("v").cast(IntegerType()))


def iterate_map_rdd(rdd):
    return rdd.union(rdd.map(lambda x : (x[1], x[0])))


def iterate_map_df(df):
    return df.union(df.select(col("v").alias("k"), col("k").alias("v")))


# countnb_new_pair function to know if additional CCF iteration is needed
def count_nb_new_pair(x):
  global nb_new_pair
  k, values = x
  min, value_list = k, []
  for v in values:
    if v < min:
       min = v
    value_list.append(v)
  if min < k:
    yield((k, min))
    for v in value_list:
      if min != v:
        nb_new_pair += 1
        yield((v, min))
        

def iterate_reduce_rdd(rdd):
    return rdd.groupByKey().flatMap(lambda x: count_nb_new_pair(x)).sortByKey()


def iterate_reduce_df(df):
    global nb_new_pair

    df = df.groupBy(col("k")).agg(collect_set("v").alias("v"))\
                                            .withColumn("min", least(col("k"), array_min("v")))\
                                            .filter((col("k")!=col('min')))

    nb_new_pair += df.withColumn("count", size("v")-1).select(sum("count")).collect()[0][0]

    return df.select(col("min").alias("a_min"), concat(array(col("k")), col("v")).alias("valueList"))\
                                                    .withColumn("valueList", explode("valueList"))\
                                                    .filter((col('a_min')!=col('valueList')))\
                                                    .select(col('a_min').alias("k"), col('valueList').alias("v"))


def compute_rdd(rdd):
    nb_iteration = 0
    while True:
        nb_iteration += 1
        start_pair = nb_new_pair.value

        rdd = iterate_map_rdd(rdd)
        rdd = iterate_reduce_rdd(rdd)
        rdd = rdd.distinct()

        print(f"Number of new pairs for iteration #{nb_iteration}:\t{nb_new_pair.value}")
        if start_pair == nb_new_pair.value:
            print("\nNo new pair, end of computation")
            break
    return rdd


def compute_cc_df(df):
    nb_iteration = 0
    while True:
        nb_iteration += 1
        nb_pairs_start = nb_new_pair.value

        df = iterate_map_df(df)
        df = iterate_reduce_df(df)
        df = df.distinct()
        
        print(f"Number of new pairs for iteration #{nb_iteration}:\t{nb_new_pair.value}")
        if nb_pairs_start == nb_new_pair.value:
            print("\nNo new pair, end of computation")
            break
    return df


def workflow_rdd(path):
    rdd_raw = load_rdd(path)
    rdd = preprocess_rdd(rdd_raw)
    start_time = time.time()
    rdd = compute_rdd(rdd)
    print(f"Nb of connected components in the graph: {rdd.map(lambda x : x[1]).distinct().count()}")
    print(f"Duration in seconds: {time.time() - start_time}")


def workflow_df(path):
    df_raw = load_df(path)
    df = preprocess_df(df_raw)
    start_time = time.time()
    df = compute_cc_df(df)
    print(f"Nb of connected components in the graph: {df.select('k').distinct().count()}")
    print(f"Duration in seconds: {time.time() - start_time}")   
    

def main():
    
    dataset_paths = {
    "test_example": f"{BUCKET_INPUT_PATH}/test.txt"#,
#     "notre_dame": f"{BUCKET_INPUT_PATH}/web-NotreDame.txt",
#     "berk_stan": f"{BUCKET_INPUT_PATH}/web-BerkStan.txt",
#     "stanford": f"{BUCKET_INPUT_PATH}/web-Stanford.txt",
#     "google": f"{BUCKET_INPUT_PATH}/web-Google.txt"
    }
    computation_methods = {
        "rdd": workflow_rdd,
        "df": workflow_df
    }
    
    for dataset in dataset_paths.keys():
        for method in computation_methods.keys():
            print("\n"* 3 + "_" * 10 + 
                  f" nb of clusters' nodes: {NB_WORKER_NODES} - dataset: {dataset} - method: {method} "
                  + "_" * 10)
            nb_new_pair = sc.accumulator(0)
            computation_methods[method](dataset_paths[dataset])



if __name__ == "__main__":
    main()




__________ nb of clusters' nodes: 1 - dataset: test_example - method: rdd __________
Number of new pairs for iteration #1:	93
Number of new pairs for iteration #2:	115
Number of new pairs for iteration #3:	132


                                                                                

Number of new pairs for iteration #4:	136


                                                                                

Number of new pairs for iteration #5:	136

No new pair, end of computation


                                                                                

Nb of connected components in the graph: 2
Duration in seconds: 12.615476608276367



__________ nb of clusters' nodes: 1 - dataset: test_example - method: df __________
Number of new pairs for iteration #1:	140
Number of new pairs for iteration #2:	149
Number of new pairs for iteration #3:	153
Number of new pairs for iteration #4:	153

No new pair, end of computation
Nb of connected components in the graph: 2
Duration in seconds: 6.003725290298462
