In [1]:
import re

from pyspark.sql.functions import array, col, collect_list, sort_array, struct, when

class Decorator:
    '''Transforms node DataFrames while building the nested hierarchy.

    Counts hierarchy levels, adds or fills the ``children`` column, packs a
    node's fields into a ``concat`` struct and groups those structs under
    their parent id.
    '''

    def calc_max_depth(self, df):
        '''Return the largest number of non-null ID columns found in one row.

        ID columns are those whose name contains ``id`` (case-insensitive).
        Returns 0 when there are no ID columns or when ``df`` has no rows.
        '''
        columns = [c for c in df.columns if 'id' in c.lower()]
        if not columns:
            # sum() over nothing is the int 0, which withColumn rejects
            # because it is not a Column.
            return 0
        # Builtin sum() adds the per-column 0/1 Column expressions together.
        levels = sum(when(col(column).isNotNull(), 1).otherwise(0)
                     for column in columns)
        max_levels = (df.withColumn("levels", levels)
                      .agg({"levels": "max"})
                      .collect()[0][0])
        # max over zero rows is null; report that as depth 0.
        return max_levels or 0

    def aggregate_child_node(self, df):
        '''Group ``concat`` structs by ``pid`` into a sorted ``children`` array.

        The result has columns ``tempID`` (renamed from ``pid``) and
        ``children``. collect_list gives no ordering guarantee after the
        groupBy shuffle, so the array is sorted explicitly; structs compare
        field by field, i.e. by label first, then ID, link and children.
        '''
        return (df.groupBy(df.pid)
                .agg(sort_array(collect_list(df.concat)).alias('children'))
                .withColumnRenamed('pid', 'tempID'))

    def concat_columns(self, df):
        '''Add a ``concat`` struct column of label, ID, link and children.'''
        return df.withColumn('concat', struct(
            col('label'), col('ID'), col('link'), col('children')))

    def remove_duplicates(self, df):
        '''Return ``df`` with duplicate rows removed.'''
        return df.distinct()

    def fill_null_children(self, df):
        '''Replace a null ``children`` value (e.g. after a left join) with an empty array.'''
        return df.withColumn(
            "children",
            when(df.children.isNull(), array()).otherwise(df.children))

    def append_empty_children(self, df):
        '''Add (or overwrite) ``children`` as an empty array on every row.'''
        return df.withColumn("children", array())

In [2]:
from pyspark.sql import SparkSession

class Initializer:
    '''Creates (or reuses) the Spark session used by this notebook.'''

    def initialize(self):
        '''Return a SparkSession named "morisson" on master "local".

        Broadcast joins are disabled (threshold -1) and executor memory is
        set to 500mb.
        NOTE(review): with master "local" the executor shares the driver JVM,
        so spark.executor.memory likely has no effect - confirm whether
        spark.driver.memory was intended.
        '''
        settings = {
            "spark.sql.autoBroadcastJoinThreshold": -1,
            "spark.executor.memory": "500mb",
        }
        builder = SparkSession.builder.master("local").appName("morisson")
        for key, value in settings.items():
            builder = builder.config(key, value)
        return builder.getOrCreate()

In [3]:
class Getter:
    '''Looks up per-level column names and slices per-level node frames.

    Level columns are expected to embed the depth number followed by a space
    (e.g. "Level 2 - ID") and to contain "name", "url" or "id"
    (case-insensitive) to identify their role.
    '''

    def get_column_names(self, df, depth):
        '''Return the ``(name, id, url)`` column names for ``depth``.

        Args:
            df: a DataFrame; only ``df.columns`` is read.
            depth: 1-based hierarchy level.

        Raises:
            ValueError: if the name, id or url column for ``depth`` is missing.
        '''
        # The depth must not be preceded by another digit, otherwise depth 1
        # would also match "Level 11 - ..." or "Level 21 - ..." columns.
        pattern = re.compile(r'(?<!\d)' + re.escape(str(depth)) + ' ')
        found = {}
        for column in df.columns:
            if not pattern.search(column):
                continue
            lowered = column.lower()
            if 'name' in lowered:
                found['name'] = column
            if 'url' in lowered:
                found['url'] = column
            if 'id' in lowered:
                found['id'] = column
        missing = [role for role in ('name', 'id', 'url') if role not in found]
        if missing:
            raise ValueError(
                f"missing {', '.join(missing)} column(s) for depth {depth}")
        return found['name'], found['id'], found['url']

    def _rename_node_columns(self, frame, name, id, url):
        '''Rename level-specific columns to the generic label / ID / link.'''
        return (frame.withColumnRenamed(name, 'label')
                .withColumnRenamed(id, 'ID')
                .withColumnRenamed(url, 'link'))

    def get_first_node(self, df, name, id, url):
        '''Return top-level nodes (non-null ``id``) as label/ID/link columns.'''
        selected = df.select(name, id, url).filter(df[id].isNotNull())
        return self._rename_node_columns(selected, name, id, url)

    def get_node(self, df, parent, name, id, url):
        '''Return nodes with non-null ``id`` as pid/label/ID/link columns.

        ``parent`` is the column holding the parent level's id; it becomes
        ``pid``.
        '''
        selected = (df.select(parent, name, id, url)
                    .withColumnRenamed(parent, 'pid')
                    .filter(df[id].isNotNull()))
        return self._rename_node_columns(selected, name, id, url)

In [4]:
import os
import pandas
from pyspark.sql.types import ArrayType, StringType, StructField, StructType
from pyspark.sql.functions import col
import json

# The helper classes are defined in the cells above. In a split module
# layout they would instead be imported from decorator, initializer and getter.
dec, init, getter = Decorator(), Initializer(), Getter()

def handle(path, spark):
    '''Read the CSV at ``path`` and return the nested label/ID/link/children frame.

    Inputs with no rows or no ID columns (max depth 0/None) produce the empty
    result frame built by ``create_result`` instead of failing further down.
    '''
    df = spark.read.csv(path, header="true")
    max_depth = dec.calc_max_depth(df)
    if not max_depth:
        # Nothing to nest: create_result builds an empty frame for None.
        return create_result(None, spark)
    depth_node_dict = create_depth_node_dict(df, max_depth)
    if max_depth == 1:
        # Only top-level nodes exist, so each gets an empty children list.
        node = depth_node_dict[1]
        result = dec.concat_columns(dec.append_empty_children(node))
    else:
        result = combine_nodes(depth_node_dict, max_depth)
    return create_result(result, spark)

def create_depth_node_dict(df, max_depth):
    '''Map each depth (from max_depth down to 1) to its de-duplicated node frame.

    Depth-1 nodes carry label/ID/link; deeper nodes also carry ``pid``, read
    from the "Level <depth-1> - ID" column.
    '''
    nodes_by_depth = {}
    for depth in range(max_depth, 0, -1):
        name_col, id_col, url_col = getter.get_column_names(df, depth)
        if depth == 1:
            level_nodes = getter.get_first_node(df, name_col, id_col, url_col)
        else:
            parent_col = 'Level ' + str(depth - 1) + ' - ID'
            level_nodes = getter.get_node(
                df, parent_col, name_col, id_col, url_col)
        nodes_by_depth[depth] = dec.remove_duplicates(level_nodes)
    return nodes_by_depth

def join_parent_child(left, right):
    '''Left-join aggregated children onto parent nodes.

    Rows match when ``left.ID`` equals ``right.tempID``. Every parent row is
    kept; right-side columns are null for parents without a match.
    '''
    condition = left.ID == right.tempID
    return left.join(right, condition, 'left')

def combine_nodes(depth_node_dict, max_depth):
    '''Nest node frames bottom-up and return the top-level frame.

    Args:
        depth_node_dict: depth -> node frame, as built by create_depth_node_dict.
        max_depth: deepest level present; must be >= 1.

    Returns:
        The depth-1 frame with a ``children`` column holding nested structs.

    Raises:
        ValueError: if ``max_depth`` is below 1.
    '''
    if max_depth < 1:
        raise ValueError(f"max_depth must be >= 1, got {max_depth}")
    if max_depth == 1:
        # Depth-1 nodes have no pid to group by; they simply have no children.
        return dec.concat_columns(dec.append_empty_children(depth_node_dict[1]))

    # Leaves get an empty children array, then are grouped under their parent.
    # Note: Spark does not guarantee this orderBy survives the groupBy inside
    # aggregate_child_node.
    leaves = dec.concat_columns(
        dec.append_empty_children(depth_node_dict[max_depth])
    ).orderBy(col("concat"))
    children_by_parent = dec.aggregate_child_node(leaves)

    # Intermediate levels: attach the collected children, then aggregate again.
    for depth in range(max_depth - 1, 1, -1):
        joined = join_parent_child(depth_node_dict[depth], children_by_parent)
        formatted = dec.concat_columns(
            dec.fill_null_children(joined)).orderBy(col("concat"))
        children_by_parent = dec.aggregate_child_node(formatted)

    # Top level: attach children; parents with none get an empty array.
    roots = join_parent_child(depth_node_dict[1], children_by_parent)
    return dec.fill_null_children(roots)

def create_result(df, spark):
    '''Project the nested frame to label/ID/link/children, or build an empty one.

    Args:
        df: the combined node frame, or None when there is nothing to output.
        spark: SparkSession, used only to build the empty frame.
    '''
    # Compare against None explicitly: the truthiness of a DataFrame object
    # says nothing about whether it has rows.
    if df is not None:
        return df.select('label', 'ID', 'link', 'children')
    schema = StructType([
        StructField('label', StringType()),
        StructField('ID', StringType()),
        StructField('link', StringType()),
        StructField('children', ArrayType(StructType()))])
    return spark.createDataFrame(spark.sparkContext.emptyRDD(), schema)

def write_output(result, output_dir="webout"):
    '''Write ``result`` as JSON files into ``output_dir``, replacing old output.

    Args:
        result: DataFrame to write.
        output_dir: target directory. A relative path is resolved against
            the current working directory.
    '''
    output_path = os.path.join(os.getcwd(), output_dir)
    # Keep null fields so every record has the same keys.
    result.write.json(path=output_path, ignoreNullFields=False, mode='overwrite')

In [5]:
# Start (or reuse) the local SparkSession shared by the cells below.
spark = init.initialize()

21/12/30 19:03:59 WARN Utils: Your hostname, LTPH008120 resolves to a loopback address: 127.0.1.1; using 172.17.226.203 instead (on interface eth0)
21/12/30 19:03:59 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
21/12/30 19:04:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [None]:
def trigger(path, spark):
    '''Build the nested hierarchy frame for the CSV at ``path``.'''
    return handle(path, spark)


# Input CSV. Set CSV_TO_JSON_INPUT to override this machine-specific default
# instead of editing the notebook.
path = os.environ.get(
    "CSV_TO_JSON_INPUT", "/home/dindo/t2/csv_to_json/files/original_data.csv")
temp = trigger(path, spark)

                                                                                

In [None]:
temp2 = temp.collect()

# Turn each Row, including nested Rows, into plain dicts for json.dumps.
new = [row.asDict(recursive=True) for row in temp2]

In [None]:
# Sanity check that `new` is a plain list before serializing it.
print(type(new))
# Serialize the list of row dicts into a single JSON string.
new2 = json.dumps(new)

<class 'list'>


In [None]:
from flask import Flask

app = Flask(__name__)


@app.route('/')
def index():
    '''Placeholder endpoint that returns a fixed string.'''
    return 'test'


if __name__ == '__main__':
    # debug=True would also enable the reloader, which re-executes the current
    # process (here the Jupyter kernel), so turn it off explicitly.
    # The port is configurable because the recorded run failed with
    # "Address already in use". Note that app.run blocks this kernel.
    app.run(debug=True, use_reloader=False,
            port=int(os.environ.get("FLASK_PORT", "5000")))

 * Serving Flask app '__main__' (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: on


OSError: [Errno 98] Address already in use

In [None]:
# Dump the full JSON string (large output; consider truncating before sharing).
print(new2)

[{"label": "THE BEST", "ID": "178974", "link": "https://groceries.morrisons.com/browse/178974", "children": [{"label": "BAKERY & CAKES", "ID": "178971", "link": "https://groceries.morrisons.com/browse/178974/178971", "children": [{"label": "BREAD & BREAD ROLLS", "ID": "179023", "link": "https://groceries.morrisons.com/browse/178974/178971/179023", "children": []}, {"label": "CAKES, PIES & TARTS", "ID": "179024", "link": "https://groceries.morrisons.com/browse/178974/178971/179024", "children": []}, {"label": "CROISSANTS & BREAKFAST BAKERY", "ID": "179025", "link": "https://groceries.morrisons.com/browse/178974/178971/179025", "children": []}, {"label": "DESSERTS & PUDDINGS", "ID": "179026", "link": "https://groceries.morrisons.com/browse/178974/178971/179026", "children": []}, {"label": "FRUITED BREAD, SCONES & HOT CROSS BUNS", "ID": "179027", "link": "https://groceries.morrisons.com/browse/178974/178971/179027", "children": []}]}, {"label": "DRINKS", "ID": "178973", "link": "https://g

21/12/30 17:13:07 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 14903460 ms exceeds timeout 120000 ms
21/12/30 17:13:07 WARN SparkContext: Killing executors is not supported by current scheduler.
