In [None]:
from conf import *
import random
import pyspark
import time
import math
import numpy as np
import csv
import pandas as pd
from subprocess import check_output
from sys import exit
import subprocess

from datetime import date, datetime, timedelta

from pyspark.sql import Row, DataFrame, SparkSession
from pyspark import SparkConf, SparkContext
from pyspark import StorageLevel
from operator import add
from sklearn.metrics.pairwise import haversine_distances
from interval_tree import Interval, Node, IntervalTree
from sklearn.neighbors import BallTree

from pathlib import Path
from event import event
from node2 import node2

class node:
    """Per-event record used while following parent/child links.

    Stores the event type, its refined type, its zip code, list copies of
    the child and parent event ids, and a `visited` flag that traversals
    set once the node has been reached.
    """
    # Class-level fallbacks; __init__ rebinds every one of them per instance.
    type = ''
    refined_type = ''
    zip = ''
    childs = []
    parents = []
    visited = False

    def __init__(self, type, refined_type, zip, childs, parents):
        """Initialise the node.

        `childs` and `parents` must be sized collections (len() is applied
        to them); their contents are copied into new lists so the node never
        aliases the caller's collection.
        """
        self.type, self.refined_type, self.zip = type, refined_type, zip
        self.childs = list(childs) if len(childs) > 0 else []
        self.parents = list(parents) if len(parents) > 0 else []
        # Every node starts unvisited.
        self.visited = False


def event_transformer(line: str) -> event:
    """Build an `event` from one comma-separated input line.

    Column 1 selects the layout. 'W' rows keep only the id, refined type,
    start/end time and the airport code from column 8; all location fields
    are filled with placeholders. Every other row is built as a 'T' event
    with latitude, longitude, distance and address columns parsed.
    Columns 16 and 17 carry ';'-separated child and parent event ids.

    Raises IndexError when the line has fewer than 18 columns and
    ValueError when a numeric column of a non-'W' row cannot be parsed.
    """
    cols = line.replace('\r', '').replace('\n', '').split(',')

    # An empty link column means "no links" rather than a single empty id.
    childs = set(cols[16].split(';')) if len(cols[16]) > 0 else set()
    parents = set(cols[17].split(';')) if len(cols[17]) > 0 else set()

    if cols[1] == 'W':
        return event(eventId=cols[0], type='W', refinedType=cols[2],
                     startTime=cols[3], endTime=cols[4],
                     locationLat=0, locationLng=0, distance=0,
                     airportCode=cols[8], number=0, street='NA', side='NA',
                     city='NA', county='NA', state='NA', zipCode='NA',
                     childs=childs, parents=parents)

    latitude = float(cols[5])
    longitude = float(cols[6])
    distance = float(cols[7])
    # A missing street number ('N/A' or empty) is stored as 0.
    street_number = 0 if cols[9] in ('N/A', '') else int(cols[9])
    # NOTE(review): airportCode and zipCode are both read from column 15 here,
    # while str_transformer writes airportCode to column 8 -- confirm intended.
    return event(eventId=cols[0], type='T', refinedType=cols[2],
                 startTime=cols[3], endTime=cols[4],
                 locationLat=latitude, locationLng=longitude,
                 distance=distance, airportCode=cols[15],
                 number=street_number, street=cols[10], side=cols[11],
                 city=cols[12], county=cols[13], state=cols[14],
                 zipCode=cols[15], childs=childs, parents=parents)

def str_transformer(e: event) -> str:
    """Serialise an event into an 18-column comma-separated line.

    Child and parent ids are written ';'-joined in the last two columns.
    For 'W' events only the id, type, refined type, start/end time and
    airport code are written; every other data column is the literal 'N/A'.
    """
    # join() on an empty collection already yields '', so no emptiness
    # check is required.
    child_str = ';'.join(e.childs)
    parent_str = ';'.join(e.parents)

    if e.type == 'W':
        return f'{e.eventId},{e.type},{e.refinedType},{e.startTime},{e.endTime},N/A,N/A,N/A,{e.airportCode},N/A,N/A,N/A,N/A,N/A,N/A,N/A,{child_str},{parent_str}'

    return f'{e.eventId},{e.type},{e.refinedType},{e.startTime},{e.endTime},{e.locationLat},{e.locationLng},{e.distance},{e.airportCode},{e.number},{e.street},{e.side},{e.city},{e.county},{e.state},{e.zipCode},{child_str},{parent_str}'


# Stop the SparkContext that is already bound to `sc` (presumably supplied by
# the notebook kernel or the `conf` star-import -- verify) so that a session
# with the settings below can be created.
sc.stop()
# Create new config
conf = pyspark.SparkConf().setAll([("spark.driver.maxResultSize", '16g'), ('spark.executor.memoryOverhead', '16g'), ('spark.executor.memory', '16g')])

# The builder calls are chained so that the builder carrying `conf` is the one
# that creates the session. Calling `SparkSession.builder.config(...)` as a
# separate statement relies on `builder` being a shared object, which is not
# guaranteed across PySpark versions; without it the settings are dropped.
spark = SparkSession.builder.config(conf=conf).appName('test_05_11_1').getOrCreate()
sc = spark.sparkContext
# Ship the project modules so their classes can be used inside RDD functions.
sc.addPyFile('file:/event.py')
sc.addPyFile('file:/node2.py')

storage_level = pyspark.StorageLevel.MEMORY_AND_DISK
sc.setLogLevel("OFF")

#print(spark.sparkContext.getConf().getAll())

# Ask for confirmation before running; the answer also decides which dataset
# file-name suffix is used below.
if SUBSET:
    f_name = f'_{CITY}_TTR-{trTimeThresh}'
    res = input(f'You are using a SUBSET ({CITY}) with trTimeThresh = {trTimeThresh}. Are you sure? [y/n]:')
else:
    res = input(f'You are using ENTIRE DATASET with trTimeThresh = {trTimeThresh}. Are you sure? [y/n]:')
    f_name = f'_TTR-{trTimeThresh}'

# Anything other than a literal 'y' aborts the run.
if res != 'y':
    exit()

# Wall-clock reference for the elapsed-time report at the end of the script.
start = time.time()

traffic_path = f'file:/datasets/TrafficEvents{f_name}.csv'
weather_path = f'file:/datasets/WeatherEvents{f_name}.csv'


def iterativePatternChainFinder(root, _e, node_dict):
    """Extract every sequence of event ids reachable from `root`.

    The walk proceeds level by level. Each partial path is extended by every
    child of its last node that exists in `node_dict` and has not been
    visited yet; children are flagged as visited when they are reached. A
    path that cannot be extended is emitted as a list of
    '<refined_type>_<id>' labels.

    Args:
        root: id of the starting node.
        _e: unused; kept so existing callers keep working.
        node_dict: mapping from id to an object exposing `childs`,
            `visited` and `refined_type`. Its nodes are mutated in place.

    Returns:
        A `(sequences, node_dict)` tuple, where `node_dict` is the same
        mapping that was passed in.
    """
    finalSequences = []
    frontier = [[root]]

    while frontier:
        next_frontier = []

        for path in frontier:
            tail = path[-1]
            if tail not in node_dict:
                continue

            on_path = set(path)
            extended = False

            for child_id in node_dict[tail].childs:
                if child_id not in node_dict:
                    continue
                child = node_dict[child_id]
                if child.visited:
                    continue
                # Flag the child before checking the path, so it is consumed
                # even when it cannot extend this path.
                child.visited = True

                if child_id not in on_path:
                    extended = True
                    next_frontier.append(path + [child_id])

            if not extended:
                # Dead end: label every node of the path that is known.
                finalSequences.append([
                    node_dict[node_id].refined_type + '_' + node_id
                    for node_id in path
                    if node_id in node_dict
                ])

        frontier = next_frontier

    return finalSequences, node_dict


def findSequences(events):
    """Group parent->child event sequences by zip code.

    Args:
        events: a pair whose first item is an iterable of collections of
            traffic events and whose second item is a list of weather
            events. Events must be hashable, because the traffic
            collections are merged through a set union.

    Returns:
        A list of `(zip, sequences)` tuples. Every sequence is a list of
        '<refinedType>_<eventId>' labels that starts at an event with
        children and no parents.

    A sequence rooted at a traffic ('T') event is filed under that event's
    zip code. A root of any other type is filed under the zip code of the
    child whose traversal produced the sequence. That child is resolved at
    the moment it is traversed; a positional index into the child collection
    would go out of step as soon as one child is skipped.
    """
    new_events = list(set().union(*events[0]))
    new_events = new_events + events[1]

    node_dict = {
        e.eventId: node(e.type, e.refinedType, e.zipCode, e.childs, e.parents)
        for e in new_events
    }
    zip_to_sequences = {}

    for e in new_events:
        # Only events with children and no parents start a traversal; an
        # event with a parent is reached from that parent's traversal.
        if len(e.parents) > 0 or len(e.childs) == 0:
            continue

        root_label = e.refinedType + '_' + e.eventId

        for c in e.childs:
            if c not in node_dict:
                continue
            child_node = node_dict[c]
            if child_node.visited:
                continue
            child_node.visited = True

            seqSet, node_dict = iterativePatternChainFinder(c, e, node_dict)

            z = e.zipCode if e.type == 'T' else child_node.zip
            for s in seqSet:
                zip_to_sequences.setdefault(z, []).append([root_label] + s)

    return list(zip_to_sequences.items())

def createBasicUnorderedRootedTreeStructures(data):
    """Merge the sequences of one zip code into a forest of `node2` objects.

    Args:
        data: `(zipCode, sequences)`; when only the zip code is present the
            sequence list is treated as empty.

    Returns:
        A one-element list holding `[zipCode, nodes, roots]`, where `nodes`
        maps each sequence label to its `node2` and `roots` lists the labels
        whose node was created with parent -1.
    """
    zipCode = data[0]
    sequences = data[1] if len(data) > 1 else []
    nodes, roots = {}, []

    for s in sequences:
        last = len(s) - 1
        for i, key in enumerate(s):
            has_next = i < last

            if key in nodes:
                # Known node: only record the additional child, if any.
                if has_next:
                    nodes[key].childs.add(s[i + 1])
                continue

            child_set = {s[i + 1]} if has_next else set()
            # The first label of a sequence has no predecessor, marked by -1.
            parent = s[i - 1] if i > 0 else -1
            # The node label is the part of the key before the first '_'.
            n = node2(key.split('_')[0], parent, child_set)
            if n.parent == -1:
                roots.append(key)
            nodes[key] = n

    return [[zipCode, nodes, roots]]


def createLabelToCode():
    """Return a new dict mapping each refined event label to its code string.

    Labels listed together share one code; every code is a decimal string.
    """
    code_groups = (
        ('1', ('Snow-Light', 'Snow-Moderate', 'Snow-Heavy')),
        ('2', ('Rain-Light', 'Rain-Moderate', 'Rain-Heavy')),
        ('3', ('Construction', 'Construction-Other', 'Construction-Short')),
        ('4', ('Congestion', 'Congestion-Fast', 'Congestion-Moderate',
               'Congestion-Slow')),
        ('5', ('Event-Short', 'Event-Long', 'Event')),
        ('6', ('Fog-Moderate', 'Fog-Severe')),
        ('7', ('Lane-Blocked',)),
        ('8', ('Cold-Severe',)),
        ('9', ('Other',)),
        ('10', ('Storm-Severe',)),
        ('11', ('Broken-Vehicle',)),
        ('12', ('Incident-Weather',)),
        ('13', ('Precipitation-UNK',)),
        ('14', ('Hail-Other',)),
        ('15', ('Incident-Other',)),
        ('16', ('Incident-Flow', 'Flow-Incident')),
        ('17', ('Accident',)),
    )
    return {label: code for code, labels in code_groups for label in labels}
    
def convertToTreePreOrderedDfsEncoding(data):
    """Encode every tree of one zip code as a pre-order DFS string.

    Args:
        data: `[zipCode, nodes, roots]`, where `nodes` maps a key to an
            object exposing `label`, `color` and `childs`, and `roots`
            lists the keys at which a tree starts.

    Returns:
        A one-element list holding `(zipCode, encodings)`. Each encoding is
        '<k> <k> <token count> <tokens>', with k the 1-based position of the
        tree. The tokens are label codes in visiting order, with '-1'
        appended whenever the walk backtracks from a key that is not in
        `roots`.

    Label codes are read from the module-level `labelToCode` table; a label
    missing from it raises KeyError. Node colours are changed in place.
    """
    zipCode, nodes, roots = data[0], data[1], data[2]
    encodings = []

    for r in roots:
        stack = [r]
        enc = labelToCode[nodes[r].label]
        # Two colours suffice for a tree: 'w' = not yet reached, 'b' = reached.
        nodes[r].color = 'b'

        while stack:
            top = stack[-1]
            for child in nodes[top].childs:
                if nodes[child].color == 'w':
                    # Descend into the first child that has not been reached.
                    nodes[child].color = 'b'
                    stack.append(child)
                    enc += ' ' + labelToCode[nodes[child].label]
                    break
            else:
                # No unreached child is left: backtrack from this node.
                stack.pop()
                if top not in roots:
                    enc += ' -1'

        position = str(len(encodings) + 1)
        encodings.append(f"{position} {position} {len(enc.split(' '))} {enc}")

    return [(zipCode, encodings)]


def format_airport_row(r: str) -> "tuple[str, str]":
    """Turn one 'zip,airport' CSV row into a `(zip, airport)` pair.

    The zip column is passed through int(), so leading zeros and surrounding
    whitespace are dropped ('00501' becomes '501'). Columns after the second
    are ignored.

    Raises:
        ValueError: if the first column is not an integer.
        IndexError: if the row has no second column.
    """
    parts = r.split(',')
    return (str(int(parts[0])), parts[1])

print("\nLoading Traffic and Weather events...")
input_rdd = sc.textFile(traffic_path).map(event_transformer)
input_rdd2 = sc.textFile(weather_path).map(event_transformer)
# The airport file's header row (starting with 'Zip,') is dropped before parsing.
input_rdd3 = sc.textFile(airport_path).filter(lambda it: not it.startswith('Zip,')).map(format_airport_row)

print("%d Traffic events\n%d Weather events" % (input_rdd.count(),input_rdd2.count()))

traffic_rdd_groupByZip = input_rdd.map(lambda e: (e.zipCode, e)).groupByKey().mapValues(list)

traffic_rdd_join = input_rdd3.join(traffic_rdd_groupByZip) #key = zipCode, value = (airportCode, [TrafficEvents])
traffic_rdd_join_map = traffic_rdd_join.map(lambda it: (it[1][0],it[1][1])) #key = airportCode, value = [TrafficEvents]
traffic_rdd_join_map_gbk = traffic_rdd_join_map.groupByKey().mapValues(list) #key = airportCode, value = [[TrafficEvents]] 

weather_rdd_groupByAirCode = input_rdd2.map(lambda e: (e.airportCode, e)).groupByKey().mapValues(list)
traffic_weather_join = traffic_rdd_join_map_gbk.join(weather_rdd_groupByAirCode) #key = airportCode, value = ([[TrafficEvents]], [WeatherEvents])

zip_to_sequences = traffic_weather_join.values().flatMap(findSequences)

# Module-level lookup table read by convertToTreePreOrderedDfsEncoding.
labelToCode = createLabelToCode()

zipToNodesRoots = zip_to_sequences.flatMap(createBasicUnorderedRootedTreeStructures)
zipToEncodingRDD = zipToNodesRoots.flatMap(convertToTreePreOrderedDfsEncoding)

print(f'\n{CITY} zipCodes:\n')
print(zips_to_filter)

zipToEncodingCity = zipToEncodingRDD.filter(lambda it: it[0] in zips_to_filter)


### CODE TO WRITE ENCODING LIST ###
# The context manager closes the output file even when collect() or a write
# raises part-way through.
with open(f'encoding_lists/encodingList_TTR-{trTimeThresh}_{CITY}', 'w') as w:
    for t in zipToEncodingCity.collect():
        for tree in t[1]:
            w.write(tree + '\n')
print(f'\nencodingList_TTR-{trTimeThresh}_{CITY} has been created')

print("\nElapsed Time: %.2fs" % (time.time()-start))