In [115]:
import os
from datetime import date
import graphviz
import pandas as pd
import json
from collections import defaultdict
from copy import copy
import re
import configparser
from pathlib import Path

# Module-wide settings shared by the analyzer and the visualizer below.
# NOTE: ConfigParser.read() silently skips files it cannot open, so a missing
# 'config.ini' only surfaces later as a KeyError on the first option lookup.
config = configparser.ConfigParser()
config.read('config.ini')


class RddAsNode:
    """Plain record holding an RDD's name, cache flag, usage count and computation count."""

    def __init__(self, name, is_cached, number_of_usage, number_of_computations):
        # Assign all four fields in one step; attribute order matches the signature.
        (
            self.name,
            self.is_cached,
            self.number_of_usage,
            self.number_of_computations,
        ) = (name, is_cached, number_of_usage, number_of_computations)

        
class Rdd:
    """One occurrence of an RDD inside a stage of a job.

    Stores the RDD id, its display name, the list of parent RDD ids, the
    stage and job in which this occurrence was recorded, and whether the RDD
    is flagged as cached.
    """

    def __init__(self, id, name, parents_lst, stage_id, job_id, is_cached):
        field_values = (
            ("id", id),
            ("name", name),
            ("parents_lst", parents_lst),
            ("stage_id", stage_id),
            ("job_id", job_id),
            ("is_cached", is_cached),
        )
        for field_name, field_value in field_values:
            setattr(self, field_name, field_value)

        
class Transformation:
    """Dependency edge between two RDD ids.

    Two transformations are equal when both endpoints match; `is_narrow` does
    not take part in equality, hashing or ordering.
    """

    def __init__(self, from_rdd, to_rdd, is_narrow):
        self.from_rdd = from_rdd
        self.to_rdd = to_rdd
        self.is_narrow = is_narrow

    def __eq__(self, other):
        if isinstance(other, Transformation):
            return self.from_rdd == other.from_rdd and self.to_rdd == other.to_rdd
        return False

    def __hash__(self):
        # Hash the ordered pair so that (a, b) and (b, a) do not collide by
        # construction; stays consistent with __eq__.
        return hash((self.from_rdd, self.to_rdd))

    def __lt__(self, other):
        # Order by from_rdd, breaking ties on to_rdd. The previous version
        # computed the tie-break comparison but never returned it.
        return (self.from_rdd, self.to_rdd) < (other.from_rdd, other.to_rdd)

    
class CachingPlanItem:
    """A single "cache" or "unpersist" step of the caching plan.

    Attributes:
        stage_id: stage the step is attached to.
        job_id: job the step is attached to.
        rdd_id: RDD the step applies to.
        is_cache_item: True for a cache step, False for an unpersist step.
    """

    def __init__(self, stage_id, job_id, rdd_id, is_cache_item):
        self.stage_id = stage_id
        self.job_id = job_id
        self.rdd_id = rdd_id
        self.is_cache_item = is_cache_item

    def _sort_key(self):
        # Job first, then stage; within a stage cache steps precede unpersist
        # steps (False sorts before True, hence the negation); rdd_id breaks ties.
        return (self.job_id, self.stage_id, not self.is_cache_item, self.rdd_id)

    def __lt__(self, other):
        # The previous version returned `self.rdd_id` (a truthy int) on full
        # ties, which made both `a < b` and `b < a` truthy.
        return self._sort_key() < other._sort_key()


class Utility():
    """Small stateless helpers used by the analyzer and the visualizer."""

    @staticmethod
    def get_absolute_path(path):
        """Return `path` unchanged if it is absolute, else resolve it against the current working directory."""
        if not os.path.isabs(path):
            # os.path.join picks the platform separator and avoids a doubled
            # separator when the working directory is the filesystem root.
            return os.path.join(str(Path().absolute()), path)
        return path

    @staticmethod
    def intersection(lst1, lst2):
        """Return the items of `lst1` that also occur in `lst2`, as a list in `lst1`'s iteration order."""
        return [value for value in lst1 if value in lst2]

    
class FactHub():
    """Class-level store for the facts extracted from one Spark event log.

    All state lives on the class itself; `flush()` resets it before another
    log is parsed.
    """

    app_name = ""                        # application name from the log
    job_info_dect = {}                   # job id -> list of stage-info records
    stage_info_dect = {}                 # stage id -> stage-info record
    stage_job_dect = {}                  # stage id -> owning job id
    stage_name_dect = {}                 # stage id -> stage name
    submitted_stage_last_rdd_dect = {}   # submitted stage id -> highest RDD id in it
    job_last_rdd_dect = {}               # job id -> (highest RDD id, stage name)
    submitted_stages = set()             # ids of stages that were actually submitted
    rdds_lst = []                        # one Rdd record per (RDD, stage) occurrence

    @staticmethod
    def flush():
        """Reset every stored fact to its empty value."""
        FactHub.app_name = ""
        FactHub.job_info_dect = {}
        FactHub.stage_info_dect = {}
        FactHub.stage_job_dect = {}
        FactHub.stage_name_dect = {}
        FactHub.submitted_stage_last_rdd_dect = {}
        FactHub.job_last_rdd_dect = {}
        FactHub.submitted_stages.clear()
        FactHub.rdds_lst = []

        
class AnalysisHub():
    """Class-level store for the results derived from the facts in FactHub."""

    transformations_set = set()                   # Transformation edges between RDDs
    rdd_num_of_computations = defaultdict(int)    # RDD id -> number of stages computing it
    rdd_num_of_usage = defaultdict(int)           # RDD id -> number of stages using it
    anomalies_dict = {}                           # RDD id -> "unneeded cache" / "recomputation"
    stage_computed_rdds = {}                      # stage id -> set of RDD ids computed there
    stage_used_rdds = {}                          # stage id -> set of RDD ids used there
    computed_rdds = set()                         # cached RDD ids already materialised
    rdd_usage_lifetime_dict = {}                  # RDD id -> (stage, job, stage, job) bounds
    caching_plan_lst = []                         # CachingPlanItem steps
    memory_footprint_lst = []                     # (job id, stage id, set of cached RDD ids)
    cached_rdds_set = set()                       # RDD ids the user forced to "cached"
    non_cached_rdds_set = set()                   # RDD ids the user forced to "not cached"

    @staticmethod
    def flush():
        """Reset all derived analysis results.

        The user overrides `cached_rdds_set` and `non_cached_rdds_set` are
        deliberately kept: they must survive the flush that precedes every
        re-analysis.
        """
        AnalysisHub.transformations_set.clear()
        AnalysisHub.rdd_num_of_computations = defaultdict(int)
        AnalysisHub.rdd_num_of_usage = defaultdict(int)
        AnalysisHub.anomalies_dict = {}
        AnalysisHub.stage_computed_rdds = {}
        AnalysisHub.stage_used_rdds = {}
        AnalysisHub.computed_rdds.clear()
        AnalysisHub.rdd_usage_lifetime_dict = {}
        # Also drop the plan and footprint so results of a previously analysed
        # application cannot linger after a flush.
        AnalysisHub.caching_plan_lst = []
        AnalysisHub.memory_footprint_lst = []

        
class Parser():
    """Reads a Spark event log and fills FactHub with jobs, stages and RDDs."""

    @staticmethod
    def prepare(raw_log_file):
        """Parse the JSON-lines event log at `raw_log_file` into FactHub.

        Raises:
            IndexError: if the log has no SparkListenerApplicationStart event.
        """
        all_events_lst = pd.read_json(raw_log_file, lines=True)
        event_types = all_events_lst['Event']
        FactHub.app_name = all_events_lst[event_types == 'SparkListenerApplicationStart']['App Name'].tolist()[0]
        print(FactHub.app_name)
        # Submitted stages must be collected first: the job-start pass consults
        # FactHub.submitted_stages.
        Parser.prepare_from_stage_submitted_events(all_events_lst[event_types == 'SparkListenerStageSubmitted'])
        Parser.prepare_from_job_start_events(all_events_lst[event_types == 'SparkListenerJobStart'])

    @staticmethod
    def prepare_from_stage_submitted_events(stage_submitted_events):
        """Record the id of every stage that has a StageSubmitted event."""
        for submitted_stage in stage_submitted_events['Stage Info'].tolist():
            FactHub.submitted_stages.add(submitted_stage['Stage ID'])

    @staticmethod
    def prepare_from_job_start_events(job_start_events):
        """Register jobs, their stages and the RDDs of each stage from JobStart events."""
        job_ids_list = job_start_events['Job ID'].tolist()
        job_stage_info_list = job_start_events['Stage Infos'].tolist()
        for raw_job_id, job_rec in zip(job_ids_list, job_stage_info_list):
            job_id = int(raw_job_id)
            FactHub.job_info_dect[job_id] = job_rec
            id_of_last_rdd_in_job = -1
            # Name of the last stage listed for this job. Tracked explicitly so
            # a job without stage infos gets "" instead of a NameError or the
            # stage name leaked from the previous job.
            last_stage_name = ""
            for stage_rec in job_rec:
                stage_id = int(stage_rec['Stage ID'])
                FactHub.stage_job_dect[stage_id] = job_id
                FactHub.stage_info_dect[stage_id] = stage_rec
                FactHub.stage_name_dect[stage_id] = stage_rec['Stage Name']
                last_stage_name = stage_rec['Stage Name']
                id_of_last_rdd_in_stage = -1
                for stage_rdd_rec in stage_rec['RDD Info']:
                    rdd_id = stage_rdd_rec['RDD ID']
                    storage_level = stage_rdd_rec['Storage Level']
                    is_cached = storage_level['Use Memory'] or storage_level['Use Disk']
                    rdd_name = stage_rdd_rec['Name'] + '\n' + stage_rdd_rec['Callsite']
                    FactHub.rdds_lst.append(Rdd(rdd_id, rdd_name, stage_rdd_rec['Parent IDs'], stage_id, job_id, is_cached))
                    if id_of_last_rdd_in_job < rdd_id:
                        id_of_last_rdd_in_job = rdd_id
                    if id_of_last_rdd_in_stage < rdd_id:
                        id_of_last_rdd_in_stage = rdd_id
                # Only stages that were actually submitted take part in the analysis.
                if stage_id in FactHub.submitted_stages:
                    FactHub.submitted_stage_last_rdd_dect[stage_id] = id_of_last_rdd_in_stage
            FactHub.job_last_rdd_dect[job_id] = (id_of_last_rdd_in_job, last_stage_name)

    
class Analyzer():
    """Derives transformations, usage/computation counts, anomalies and the caching plan.

    Reads from FactHub and the module-level `config`; writes to AnalysisHub.
    """

    @staticmethod
    def is_narrow_transformation(rdd_id, parent_id):
        """Return True when `rdd_id` and `parent_id` occur together in at least one stage."""
        rdd_stages_set = set()
        parent_stages_set = set()
        for rdd in FactHub.rdds_lst:
            if rdd.id == rdd_id:
                rdd_stages_set.add(rdd.stage_id)
            elif rdd.id == parent_id:
                parent_stages_set.add(rdd.stage_id)
        return not rdd_stages_set.isdisjoint(parent_stages_set)

    @staticmethod
    def prepare_transformations_lst():
        """Add one Transformation (RDD -> parent) per parent link to AnalysisHub.transformations_set."""
        for rdd in FactHub.rdds_lst:
            for parent_id in rdd.parents_lst:
                AnalysisHub.transformations_set.add(Transformation(rdd.id, parent_id, Analyzer.is_narrow_transformation(rdd.id, parent_id)))

    @staticmethod
    def add_rdd_and_its_parents_if_it_is_computed_in_stage(rdd_id, stage_id):  # recursive
        """Record the use of `rdd_id` in `stage_id` and walk up its narrow parents.

        Expects AnalysisHub.stage_used_rdds[stage_id] and
        AnalysisHub.stage_computed_rdds[stage_id] to exist already.
        """
        # Count at most one usage per (RDD, stage) pair.
        if rdd_id not in AnalysisHub.stage_used_rdds[stage_id]:
            AnalysisHub.rdd_num_of_usage[rdd_id] += 1
            AnalysisHub.stage_used_rdds[stage_id].add(rdd_id)
        for rdd in FactHub.rdds_lst:
            if rdd.id != rdd_id:
                continue
            if rdd.is_cached:
                # Lifetime tuple layout: (start stage, start job, end stage, end job).
                # NOTE(review): the bounds are compared against `stage_id` but
                # updated with this record's own stage/job - confirm this is intended.
                if rdd_id not in AnalysisHub.rdd_usage_lifetime_dict:
                    AnalysisHub.rdd_usage_lifetime_dict[rdd.id] = (rdd.stage_id, rdd.job_id, rdd.stage_id, rdd.job_id)
                if AnalysisHub.rdd_usage_lifetime_dict[rdd_id][0] > stage_id:
                    AnalysisHub.rdd_usage_lifetime_dict[rdd.id] = (rdd.stage_id, rdd.job_id, AnalysisHub.rdd_usage_lifetime_dict[rdd_id][2], AnalysisHub.rdd_usage_lifetime_dict[rdd_id][3])
                if AnalysisHub.rdd_usage_lifetime_dict[rdd_id][2] < stage_id:
                    AnalysisHub.rdd_usage_lifetime_dict[rdd.id] = (AnalysisHub.rdd_usage_lifetime_dict[rdd_id][0], AnalysisHub.rdd_usage_lifetime_dict[rdd_id][1], rdd.stage_id, rdd.job_id)
            if rdd.stage_id == stage_id:
                if rdd.is_cached:
                    if rdd_id in AnalysisHub.computed_rdds:  # already cached
                        return
                    AnalysisHub.computed_rdds.add(rdd_id)  # cached for the first time
                    AnalysisHub.stage_computed_rdds[stage_id].add(rdd_id)
                else:
                    if rdd_id in AnalysisHub.computed_rdds:  # handling unpersistence
                        AnalysisHub.computed_rdds.remove(rdd_id)
                    AnalysisHub.stage_computed_rdds[stage_id].add(rdd_id)
                # Only narrow parents are followed within the same stage.
                for parent_id in rdd.parents_lst:
                    if Analyzer.is_narrow_transformation(rdd.id, parent_id):
                        Analyzer.add_rdd_and_its_parents_if_it_is_computed_in_stage(parent_id, stage_id)

    @staticmethod
    def calc_num_of_computations_of_rdds():
        """Walk every submitted stage in ascending id order and count per-RDD computations."""
        AnalysisHub.rdd_usage_lifetime_dict = {}
        for stage_id in sorted(FactHub.submitted_stage_last_rdd_dect):
            id_of_last_rdd_in_stage = FactHub.submitted_stage_last_rdd_dect[stage_id]
            AnalysisHub.stage_computed_rdds[stage_id] = set()
            AnalysisHub.stage_used_rdds[stage_id] = set()
            Analyzer.add_rdd_and_its_parents_if_it_is_computed_in_stage(id_of_last_rdd_in_stage, stage_id)
        for stage_id in AnalysisHub.stage_computed_rdds:
            for rdd_id in AnalysisHub.stage_computed_rdds[stage_id]:
                AnalysisHub.rdd_num_of_computations[rdd_id] += 1

    @staticmethod
    def prepare_anomalies_dict():
        """Flag cached RDDs used fewer times than the threshold and uncached RDDs computed at least that often."""
        # Parse the threshold once instead of on every loop iteration.
        threshold = int(config['Caching_Anomalies']['rdds_computation_tolerance_threshold'])
        for rdd in FactHub.rdds_lst:
            if rdd.is_cached and AnalysisHub.rdd_num_of_usage[rdd.id] < threshold:
                AnalysisHub.anomalies_dict[rdd.id] = "unneeded cache"
            elif not rdd.is_cached and AnalysisHub.rdd_num_of_computations[rdd.id] >= threshold:
                AnalysisHub.anomalies_dict[rdd.id] = "recomputation"

    @staticmethod
    def prepare_caching_plan():
        """Build the cache/unpersist plan and the resulting memory footprint timeline."""
        include_anomalies = config['Caching_Anomalies']['include_caching_anomalies_in_caching_plan'] == "true"
        AnalysisHub.caching_plan_lst = []
        for rdd_id, rdd_usage_lifetime in AnalysisHub.rdd_usage_lifetime_dict.items():
            if include_anomalies or rdd_id not in AnalysisHub.anomalies_dict:
                AnalysisHub.caching_plan_lst.append(CachingPlanItem(rdd_usage_lifetime[0], rdd_usage_lifetime[1], rdd_id, True))
                AnalysisHub.caching_plan_lst.append(CachingPlanItem(rdd_usage_lifetime[2], rdd_usage_lifetime[3], rdd_id, False))
        AnalysisHub.memory_footprint_lst = []
        incremental_rdds_set = set()
        # Replay the plan in order, snapshotting the set of cached RDDs after each step.
        for caching_plan_item in sorted(AnalysisHub.caching_plan_lst):
            if caching_plan_item.is_cache_item:
                incremental_rdds_set.add(caching_plan_item.rdd_id)
            else:
                incremental_rdds_set.remove(caching_plan_item.rdd_id)
            AnalysisHub.memory_footprint_lst.append((caching_plan_item.job_id, caching_plan_item.stage_id, incremental_rdds_set.copy()))

    @staticmethod
    def analyze_caching_anomalies():
        """Apply the user's cache overrides, then recompute counts, anomalies and the plan."""
        for rdd in FactHub.rdds_lst:
            if rdd.id in AnalysisHub.cached_rdds_set:
                rdd.is_cached = True
            if rdd.id in AnalysisHub.non_cached_rdds_set:
                rdd.is_cached = False
        Analyzer.calc_num_of_computations_of_rdds()
        Analyzer.prepare_anomalies_dict()
        Analyzer.prepare_caching_plan()


class SparkDataflowVisualizer():
    """Entry points that reset state, parse a log, run the analysis and render the DAG.

    All methods operate on the class-level state of FactHub / AnalysisHub and
    on the module-level `config`.
    """

    @staticmethod
    def init():
        """Clear the user's cache overrides and flush FactHub and AnalysisHub."""
        AnalysisHub.cached_rdds_set.clear()
        AnalysisHub.non_cached_rdds_set.clear()
        FactHub.flush()
        AnalysisHub.flush()

    @staticmethod
    def parse(raw_log_file):
        """Parse the event log at `raw_log_file` into FactHub."""
        Parser.prepare(raw_log_file)

    @staticmethod
    def analyze():
        """Flush previous results and rerun the full analysis."""
        AnalysisHub.flush()
        Analyzer.prepare_transformations_lst()
        Analyzer.analyze_caching_anomalies()

    @staticmethod
    def visualize_property_DAG():
        """Render the analysed application as a graphviz graph in the configured output directory."""
        drawing_cfg = config['Drawing']
        anomaly_cfg = config['Caching_Anomalies']
        dot = graphviz.Digraph(strict=True, comment='Spark-Application-Graph', format=config['Output']['selected_format'])
        dot.attr('node', shape=drawing_cfg['rdd_shape'], label='this is graph')
        dot.node_attr = {'shape': 'plaintext'}
        dot.edge_attr.update(arrowhead='normal', arrowsize='1')
        dag_rdds_set = set()
        prev_action_name = ""
        iterations_count = int(drawing_cfg['max_iterations_count'])
        for job_id, job in sorted(FactHub.job_last_rdd_dect.items()):
            action_name = job[1]
            # Consecutive jobs with the same action name are treated as
            # iterations; only the configured number of repeats is drawn.
            if action_name == prev_action_name:
                if iterations_count == 0:
                    continue
                iterations_count -= 1
            else:
                iterations_count = int(drawing_cfg['max_iterations_count'])
            for rdd in FactHub.rdds_lst:
                if rdd.job_id == job_id and rdd.id not in dag_rdds_set:
                    dag_rdds_set.add(rdd.id)
                    node_label = "\n"
                    if drawing_cfg['show_action_id'] == "true":
                        node_label = "[" + str(rdd.id) + "] "
                    if drawing_cfg['show_rdd_name'] == "true":
                        node_label = node_label + rdd.name[:int(drawing_cfg['rdd_name_max_number_of_chars'])]
                    if anomaly_cfg['show_number_of_rdd_usage'] == "true":
                        node_label = node_label + "\nused: " + str(AnalysisHub.rdd_num_of_usage[rdd.id])
                    if anomaly_cfg['show_number_of_rdd_computations'] == "true":
                        node_label = node_label + "\ncomputed: " + str(AnalysisHub.rdd_num_of_computations[rdd.id])
                    anomaly = AnalysisHub.anomalies_dict.get(rdd.id, "")
                    if anomaly_cfg['highlight_unneeded_cached_rdds'] == "true" and anomaly == "unneeded cache":
                        dot.node(str(rdd.id), penwidth='3', fillcolor=drawing_cfg['cached_rdd_bg_color'], color='red', shape=drawing_cfg['anomaly_shape'], style='filled', label=node_label)
                    elif anomaly_cfg['highlight_recomputed_rdds'] == "true" and anomaly == "recomputation":
                        dot.node(str(rdd.id), penwidth='3', fillcolor='white', color='red', shape=drawing_cfg['anomaly_shape'], style='filled', label=node_label)
                    else:
                        dot.node(str(rdd.id), fillcolor=drawing_cfg['cached_rdd_bg_color'] if rdd.is_cached else 'white', style='filled', label=node_label)
            action_lable = ""
            if drawing_cfg['show_action_id'] == "true":
                action_lable = "[" + str(job_id) + "]"
            if drawing_cfg['show_action_name'] == "true":
                action_lable = action_lable + action_name[:int(drawing_cfg['action_name_max_number_of_chars'])]
            # An exhausted iteration budget marks this as the last drawn
            # repeat, which gets the "iterative" shape and colour.
            dot.node(
                "Action_" + str(job_id),
                shape=drawing_cfg['action_shape'] if iterations_count != 0 else drawing_cfg['iterative_action_shape'],
                fillcolor=drawing_cfg['action_bg_collor'] if iterations_count != 0 else drawing_cfg['iterative_action_collor'],
                style='filled',
                label=action_lable,
            )
            dot.edge(str(job[0]), "Action_" + str(job_id), color='black', arrowhead='none', style='dashed')
            prev_action_name = action_name
        # Edges run from to_rdd to from_rdd and are drawn only when both
        # endpoints made it into the graph.
        for transformation in sorted(AnalysisHub.transformations_set):
            if transformation.to_rdd in dag_rdds_set and transformation.from_rdd in dag_rdds_set:
                dot.edge(str(transformation.to_rdd), str(transformation.from_rdd), color=drawing_cfg['narrow_transformation_color'] if transformation.is_narrow else drawing_cfg['wide_transformation_color'])
        caching_plan_label = "\nRecommended Schedule:\n"
        for caching_plan_item in sorted(AnalysisHub.caching_plan_lst):
            if caching_plan_item.is_cache_item:
                caching_plan_label += "\nCache "
            else:
                caching_plan_label += "\nUnpersist "
            caching_plan_label += "RDD[" + str(caching_plan_item.rdd_id) + "] " + ("at" if caching_plan_item.is_cache_item else "after") + " stage(" + str(caching_plan_item.stage_id) + ") in job(" + str(caching_plan_item.job_id) + ")\n"
        caching_plan_label += "\n"
        if len(AnalysisHub.caching_plan_lst) > 0 and anomaly_cfg['show_caching_plan'] == "true":
            dot.node("caching_plan", shape='note', fillcolor='lightgray', style='filled', label=caching_plan_label)
        memory_footprint_label = "\nMemory Footprint:\n"
        for memory_footprint_item in AnalysisHub.memory_footprint_lst:
            memory_footprint_label += "\n"
            if len(memory_footprint_item[2]) == 0:
                memory_footprint_label += "Free"
            else:
                memory_footprint_label += str(memory_footprint_item[2])
            memory_footprint_label += "\n"
        memory_footprint_label += "\n"
        if len(AnalysisHub.caching_plan_lst) > 0 and anomaly_cfg['show_memory_footprint'] == "true":
            dot.node("memory_footprint", shape='note', fillcolor='lightgray', style='filled', label=memory_footprint_label)
        dot.attr(labelloc="t")
        dot.attr(label=FactHub.app_name)
        dot.attr(fontsize='40')
        spark_dataflow_visualizer_output_path = Utility.get_absolute_path(config['Paths']['output_path'])
        # The output file is named after the application, reduced to
        # alphanumeric characters.
        output_file_name = re.sub('[^a-zA-Z0-9]+', '', FactHub.app_name)
        dot.render(spark_dataflow_visualizer_output_path + '/' + output_file_name, view=config['Output']['view_after_render'] == 'true')
        

# Useful functions for the demonstration 

def load_file(file_name):
    """Reset all state and parse the event log `file_name` from the configured input directory."""
    input_dir = Utility.get_absolute_path(config['Paths']['input_path'])
    event_log = input_dir + '/' + file_name
    SparkDataflowVisualizer.init()
    SparkDataflowVisualizer.parse(event_log)

def draw_DAG():
    """Run the analysis on the loaded log, then render the resulting DAG."""
    pipeline = (
        SparkDataflowVisualizer.analyze,
        SparkDataflowVisualizer.visualize_property_DAG,
    )
    for step in pipeline:
        step()
    
def cache(rdd_id):
    """Force `rdd_id` to be treated as cached, dropping any opposite override, and redraw."""
    AnalysisHub.non_cached_rdds_set.discard(rdd_id)
    AnalysisHub.cached_rdds_set.add(rdd_id)
    draw_DAG()
    
def dont_cache(rdd_id):
    """Force `rdd_id` to be treated as not cached, dropping any opposite override, and redraw."""
    AnalysisHub.cached_rdds_set.discard(rdd_id)
    AnalysisHub.non_cached_rdds_set.add(rdd_id)
    draw_DAG()

In [None]:
# Reset state, parse this event log from the configured input directory, then analyse and render it.
load_file('application_1641567765635_0122')
draw_DAG()

In [None]:
# Reset state, parse this event log, then analyse and render it.
load_file('application_1641567765635_0023')
draw_DAG()

In [None]:
# Parse the log, re-read config.ini, then switch both anomaly highlights on
# (in memory only) before rendering.
load_file('application_1635092038229_0122')
config.read('config.ini')
config['Caching_Anomalies']['highlight_recomputed_rdds'] = 'true'
config['Caching_Anomalies']['highlight_unneeded_cached_rdds'] = 'true'
draw_DAG()

In [119]:
# Treat RDD 31 as cached and redraw the DAG.
cache(31)

In [120]:
# Treat RDD 29 as not cached and redraw the DAG.
dont_cache(29)

In [121]:
# Treat RDD 6 as not cached and redraw the DAG.
dont_cache(6)

In [None]:
# Re-read config.ini so the options it defines return to their on-file values.
config.read('config.ini')

In [None]:
# Reset state, parse this event log, then analyse and render it.
load_file('application_1641567765635_0161')
draw_DAG()

In [124]:
# Treat RDDs 35 and 53 as cached; each call re-analyses and redraws the DAG.
cache(35)
cache(53)

In [None]:
# Reset state, parse this event log, then analyse and render it.
load_file('local-1641586266617')
draw_DAG()

In [None]:
# Reset state, parse this event log, then analyse and render it.
load_file('application_1635092038229_0130')
draw_DAG()

In [None]:
# Reset state, parse this event log, then analyse and render it.
load_file('application_1635092038229_0126')
draw_DAG()

In [None]:
# Reset state, parse this event log, then analyse and render it.
load_file('application_1635092038229_0144')
draw_DAG()

In [None]:
# Reset state, parse this event log, then analyse and render it.
load_file('application_1635092038229_0140')
draw_DAG()

In [130]:
# Allow up to 5 consecutive repeats of the same action to be drawn, then redraw.
config['Drawing']['max_iterations_count'] = '5'
draw_DAG()

In [131]:
# Set the anomaly threshold to 4 (in memory only) and redraw.
config['Caching_Anomalies']['rdds_computation_tolerance_threshold'] = '4'
draw_DAG()

In [132]:
# Treat RDD 217 as not cached and redraw the DAG.
dont_cache(217)

In [133]:
# Treat RDD 6 as cached and redraw the DAG.
cache(6)

In [134]:
# Treat RDD 2 as not cached and redraw the DAG.
dont_cache(2)

In [135]:
# Treat RDD 1 as cached and redraw the DAG.
cache(1)

In [None]:
# Re-read config.ini so the options it defines return to their on-file values.
config.read('config.ini')

In [None]:
# Reset state, parse this event log, then analyse and render it.
load_file('application_1635092038229_0124')
draw_DAG()

In [138]:
# Allow up to 5 consecutive repeats of the same action to be drawn, then redraw.
config['Drawing']['max_iterations_count'] = '5'
draw_DAG()

In [139]:
# Set the anomaly threshold to 3 (in memory only) and redraw.
config['Caching_Anomalies']['rdds_computation_tolerance_threshold'] = '3'
draw_DAG()

In [140]:
# Set the anomaly threshold to 4 (in memory only) and redraw.
config['Caching_Anomalies']['rdds_computation_tolerance_threshold'] = '4'
draw_DAG()

In [None]:
# Batch mode: re-read config.ini, turn off the post-render viewer, then parse
# and render every entry found in the configured input directory.
config.read('config.ini')
config['Output']['view_after_render'] = 'false'

# NOTE(review): os.listdir returns every directory entry, so this presumably
# expects the input directory to contain only event logs - confirm.
file_list = os.listdir(config['Paths']['input_path'])
for os_file_name in file_list:
    load_file(os_file_name)
    draw_DAG()