<a href="https://colab.research.google.com/github/heetae185/Algorithms/blob/main/A_Star_algorithm_Pyspark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Environment setup (Colab): PySpark, the GraphFrames Python package, and the
# headless OpenJDK 8 that JAVA_HOME is pointed at in the import cell.
# NOTE(review): package versions are not pinned, while the Spark session below
# requests the graphframes 0.8.2-spark3.2 build — confirm the installed pyspark
# version is compatible with it.
!pip install --quiet pyspark
!pip install --quiet graphframes
!apt install openjdk-8-jdk-headless -qq

In [None]:
!curl -L -o "/usr/local/lib/python3.7/dist-packages/pyspark/jars/graphframes0.8.2-spark3.2-s_2.12.jar" http://dl.bintray.com/sparkpackages/maven/graphframes/graphframes/0.8.2-spark3.2-s_2.12/graphframes-0.8.2-spark3.2-s_2.12.jar

In [None]:
import os
# Point PySpark at the OpenJDK 8 runtime installed with apt in the setup cell.
os.environ['JAVA_HOME'] = '/usr/lib/jvm/java-8-openjdk-amd64'

# NOTE: later cells rely on these wildcard imports for SparkSession,
# StructType/StructField and the column types, and GraphFrame.
import pyspark
from pyspark.sql import *
import pyspark.sql.functions as F
from pyspark import SparkContext, SparkConf

from pyspark.sql.types import *
from graphframes import *

In [None]:
# Local Spark session using all cores. The spark.jars.packages coordinates name
# the GraphFrames build for Spark 3.2 / Scala 2.12.
spark = (
    SparkSession.builder
    .master('local[*]')
    .config('spark.jars.packages', 'graphframes:graphframes:0.8.2-spark3.2-s_2.12')
    .getOrCreate()
)

In [None]:
# Fetch the New York road network from Google Drive:
# - the .gr file holds the arcs ('a' lines);
# - the .co file holds the vertex coordinates ('v' lines).
!gdown 1wMZdMz3m4OIHNH-b3zgEaZrNV2GdNpre # USA-road-d.NY.gr
!gdown 1P3BiMuJvMOqdhVsRI9E1EfKsd7cLrGBs # USA-road-d.NY.co

In [None]:
def text_preprocess(filename, firstletter):
  """Convert one record type of a DIMACS-style text file into a headerless CSV.

  Lines whose first character is `firstletter` are split on whitespace. Their
  2nd-4th fields are written as one comma-separated line. All other lines
  (comments, the problem line, ...) are skipped. For example, with
  firstletter 'v' the line 'v 1 -73530767 41085396' becomes
  '1,-73530767,41085396'.

  The output is written next to the input. Its name inserts '_processed'
  before the first '.' of the base name, e.g.
  'USA-road-d.NY.co' -> 'USA-road-d_processed.NY.co'.

  Args:
    filename: path of the input file.
    firstletter: record type to keep ('v' for .co coordinates, 'a' for .gr arcs).

  Returns:
    The path of the written CSV file.

  Raises:
    ValueError: if a selected record has fewer than three values.
  """
  # Split off the directory first so dots in folder names are not mistaken
  # for the extension separator.
  directory, base = os.path.split(filename)
  stem, dot, rest = base.partition('.')
  out_path = os.path.join(directory, stem + '_processed' + dot + rest)

  # Stream records straight to the output instead of building one huge string.
  with open(filename, 'r') as src, open(out_path, 'w') as dst:
    for line_no, line in enumerate(src, start=1):
      if not line.startswith(firstletter):
        continue
      fields = line.split()
      if len(fields) < 4:
        raise ValueError(
            f'{filename}:{line_no}: expected 3 values after {firstletter!r}')
      # Always terminate with '\n', even if the last input line lacked one.
      dst.write(','.join(fields[1:4]) + '\n')
  return out_path

In [None]:
# Convert the coordinate file ('v' records) first, then the arc file ('a' records).
for raw_file, record_type in (('USA-road-d.NY.co', 'v'), ('USA-road-d.NY.gr', 'a')):
  text_preprocess(raw_file, record_type)

In [None]:
import pandas as pd
# The processed file has no header row. Without header=None pandas would use
# the first vertex as column names and drop it from the preview.
co = pd.read_csv('USA-road-d_processed.NY.co', header=None,
                 names=['id', 'longitude', 'latitude'])
co.head(5)

In [None]:
# Build the GraphFrame for the road network.
def create_transport_graph(node_path="USA-road-d_processed.NY.co",
                           edge_path="USA-road-d_processed.NY.gr"):
  """Load the preprocessed road network CSVs into a GraphFrame.

  Args:
    node_path: headerless CSV of `id,longitude,latitude` rows (the output of
      text_preprocess on the .co file). Defaults to the NY file.
    edge_path: headerless CSV of `src,dst,distance` rows (the output of
      text_preprocess on the .gr file). Defaults to the NY file.

  Returns:
    A GraphFrame with vertex columns id/longitude/latitude and edge columns
    src/dst/distance.
  """
  node_schema = StructType([
      StructField("id", IntegerType(), True),
      StructField("longitude", LongType(), True),
      StructField("latitude", LongType(), True),
  ])
  edge_schema = StructType([
      StructField("src", IntegerType(), True),
      StructField("dst", IntegerType(), True),
      StructField("distance", IntegerType(), True),
  ])

  nodes = spark.read.csv(node_path, header=False, sep=',', schema=node_schema)
  rels = spark.read.csv(edge_path, header=False, sep=',', schema=edge_schema)
  return GraphFrame(nodes, rels)

In [None]:
# Build the road-network GraphFrame from the processed CSV files.
g = create_transport_graph()

In [None]:
# Preview the vertex table (id, longitude, latitude).
g.vertices.show()

In [None]:
# Preview the edge table (src, dst, distance).
g.edges.show()

In [None]:
# Vertex table under the alias 'vertices'.
new_vertices = g.vertices.alias('vertices')

In [None]:
# Preview the aliased vertex table.
new_vertices.show()

In [None]:
import math
# Straight-line distance between two nodes.
def get_distance(src_id, dst_id):
  """Euclidean distance between two vertices of the global graph `g`.

  The distance is computed directly from the raw `longitude`/`latitude`
  columns, i.e. in whatever units the .co file stores them.
  NOTE(review): these are coordinate units, not necessarily the units of the
  edge `distance` column. If they differ, the value may overestimate road
  distance when used as an A* heuristic — confirm before relying on optimality.

  Raises:
    ValueError: if either id is not a vertex of `g`.
  """
  def _coords(node_id):
    # Fetch each vertex row once (the previous version queried it twice).
    row = g.vertices.filter(g.vertices.id == node_id).first()
    if row is None:
      raise ValueError(f'unknown node id: {node_id}')
    return row.longitude, row.latitude

  src_lon, src_lat = _coords(src_id)
  dst_lon, dst_lat = _coords(dst_id)
  return math.sqrt((src_lon - dst_lon)**2 + (src_lat - dst_lat)**2)

In [None]:
# Nodes directly connected to a given node (one outgoing edge away).
def directed_list(node_id, dst_id):
  """Return the out-neighbours of `node_id`, scored for A* toward `dst_id`.

  Columns of the returned DataFrame:
    id         -- neighbour node id (the edge's `dst`)
    fscore     -- gscore + hscore
    gscore     -- weight of the single edge node_id -> id (not a path total)
    hscore     -- straight-line distance from the neighbour to `dst_id`,
                  from the raw longitude/latitude columns, as float
    parentNode -- `node_id` (the edge's `src`)

  The heuristic is evaluated inside Spark by joining the edges with the
  neighbours' coordinates. This replaces separate driver-side lookups for
  every neighbour.

  Raises:
    ValueError: if `dst_id` is not a vertex of the global graph `g`.
  """
  target = g.vertices.filter(g.vertices.id == dst_id).first()
  if target is None:
    raise ValueError(f'unknown destination node: {dst_id}')

  out_edges = g.edges.filter(g.edges.src == node_id)
  # Rename the vertex id to 'dst' so the join picks up each neighbour's coordinates.
  neighbour_coords = g.vertices.select(
      F.col('id').alias('dst'), 'longitude', 'latitude')
  d_lon = F.col('longitude') - F.lit(target.longitude)
  d_lat = F.col('latitude') - F.lit(target.latitude)

  return (
      out_edges.join(neighbour_coords, on='dst', how='inner')
      .select(
          F.col('dst').alias('id'),
          F.col('distance').alias('gscore'),
          F.sqrt(F.pow(d_lon, 2) + F.pow(d_lat, 2)).cast('float').alias('hscore'),
          F.col('src').alias('parentNode'))
      .withColumn('fscore', F.col('hscore') + F.col('gscore'))
      .select('id', 'fscore', 'gscore', 'hscore', 'parentNode'))

In [None]:
def path_finder(src_id, dst_id, verbose=False):
  """Find a route from `src_id` to `dst_id` with A* search over the graph `g`.

  Neighbour expansion is delegated to `directed_list`. Its `gscore` is only
  the weight of one edge, so this function accumulates a real path cost
  g(n) = g(parent) + w(parent, n). The frontier is ordered by
  f(n) = g(n) + h(n), using `directed_list`'s `hscore` as h.

  The open set is a driver-side binary heap and the closed set a Python set.
  Each expansion therefore runs a single Spark query (one node's neighbours),
  instead of re-evaluating a DataFrame lineage that grows every iteration.

  NOTE(review): A* returns a shortest route only if h never overestimates the
  remaining road distance. Confirm the coordinate-based hscore is on a
  compatible scale with the edge distances.

  Args:
    src_id: start node id; an int or an integer string (e.g. from input()).
    dst_id: goal node id; same accepted forms as `src_id`.
    verbose: if True, print every node as it is closed, with its path cost.

  Returns:
    List of node ids from `src_id` to `dst_id` inclusive, or an empty list if
    `dst_id` cannot be reached from `src_id`.
  """
  import heapq
  import itertools

  # Vertex ids are integers; a string id would never compare equal to them.
  src_id, dst_id = int(src_id), int(dst_id)
  if src_id == dst_id:
    return [src_id]

  best_g = {src_id: 0.0}    # cheapest known path cost from src_id
  parent = {src_id: None}   # predecessor on that cheapest path
  closed = set()
  tie_breaker = itertools.count()  # FIFO order among equal f-scores
  frontier = [(0.0, next(tie_breaker), src_id)]

  while frontier:
    _, _, node = heapq.heappop(frontier)
    if node in closed:
      continue  # stale heap entry; node was already closed via a cheaper path
    closed.add(node)
    if verbose:
      print('closed', node, 'g =', best_g[node])

    if node == dst_id:
      route = []
      while node is not None:
        route.append(node)
        node = parent[node]
      return route[::-1]

    for edge in directed_list(node, dst_id).collect():
      if edge.id in closed:
        continue
      tentative_g = best_g[node] + edge.gscore
      if tentative_g < best_g.get(edge.id, float('inf')):
        best_g[edge.id] = tentative_g
        parent[edge.id] = node
        heapq.heappush(
            frontier, (tentative_g + edge.hscore, next(tie_breaker), edge.id))

  # Frontier exhausted without reaching dst_id: no route exists.
  return []

In [None]:
def path_find_bot(read=input, finder=None):
  """Prompt for a start and an end node id, then print the route between them.

  Args:
    read: prompt function used to obtain each id. Defaults to the built-in
      input(); injectable for non-interactive runs.
    finder: route function called as finder(start, end). Defaults to
      path_finder.

  Returns:
    The route that was printed.

  Raises:
    ValueError: if an entered id is not an integer.
  """
  if finder is None:
    finder = path_finder
  print('---------------------------------------------------------')
  # read()/input() return text, but vertex ids are integers. Passing the raw
  # strings meant the goal id never matched any node, so convert up front.
  start = int(read('출발 노드 : '))
  end = int(read('도착 노드 : '))
  route = finder(start, end)
  print('경로 : ', route)
  print('---------------------------------------------------------')
  return route

In [None]:
# Interactive run: asks for a start and an end node id and prints the route.
path_find_bot()