In [1]:
import numpy as np
import pandas as pd

In [2]:
DATA_PATH = 'romania_cities.xlsx'

def load_matrix(path=None):
    path = path or DATA_PATH
    adj_matrix = pd.read_excel(path, sheet_name='matrix')
    adj_matrix = adj_matrix.set_index('source')
    np.fill_diagonal(adj_matrix.values, 0)
    return adj_matrix

def load_heuristic(path=None):
    path = path or DATA_PATH
    heuristic = pd.read_excel(path, sheet_name='heuristic')
    heuristic = heuristic.set_index('source').line_distance
    return heuristic

In [3]:
adj_matrix = load_matrix()
adj_matrix

Unnamed: 0_level_0,arad,bucharest,craiova,dobreta,eforie,fagaras,giurgiu,hirsova,iasi,lugoj,mehadia,neamt,oradea,pitesti,rimnicu,sibiu,timisoara,urziceni,vaslui,zerind
source,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
arad,0.0,,,,,,,,,,,,,,,140.0,118.0,,,75.0
bucharest,,0.0,,,,,90.0,,,,,,,101.0,,,,85.0,,
craiova,,,0.0,120.0,,,,,,,,,,138.0,146.0,,,,,
dobreta,,,120.0,0.0,,,,,,,75.0,,,,,,,,,
eforie,,,,,0.0,,,86.0,,,,,,,,,,,,
fagaras,,211.0,,,,0.0,,,,,,,,,,99.0,,,,
giurgiu,,90.0,,,,,0.0,,,,,,,,,,,,,
hirsova,,,,,86.0,,,0.0,,,,,,,,,,98.0,,
iasi,,,,,,,,,0.0,,,87.0,,,,,,,92.0,
lugoj,,,,,,,,,,0.0,70.0,,,,,,111.0,,,


In [4]:
heuristic = load_heuristic()
heuristic

source
arad         366
bucharest      0
craiova      160
dobreta      242
eforie       161
fagaras      178
giurgiu       77
hirsova      151
iasi         226
lugoj        244
mehadia      241
neamt        234
oradea       380
pitesti       98
rimnicu      193
sibiu        253
timisoara    329
urziceni      80
vaslui       199
zerind       374
Name: line_distance, dtype: int64

In [5]:
def neighbor(src, adj_matrix):
    idx = np.where(adj_matrix.loc[src].notnull())
    res = adj_matrix.columns[idx]
    res = res[res != src].tolist()
    return res

print(neighbor('arad', adj_matrix))
print(neighbor('bucharest', adj_matrix))

['sibiu', 'timisoara', 'zerind']
['giurgiu', 'pitesti', 'urziceni']


In [6]:
def a_star(src, dst, adj_matrix, heuristic):
    closed = set()
    opened = pd.Series(np.nan, index=heuristic.index, name='opened')
    parent = pd.Series(np.nan, index=heuristic.index, name='parent')
    f = pd.Series(np.nan, index=heuristic.index, name='f')
    g = pd.Series(np.nan, index=heuristic.index, name='g')

    parent[src] = src
    opened[src] = heuristic[src]
    f[src] = heuristic[src]
    g[src] = 0
    found = False

    # while opened is not empty
    while not opened.isnull().all():
        q = opened.idxmin()
        opened[q] = np.nan # remove q from opened

        neighbors = neighbor(q, adj_matrix)
        for v in neighbors:
            if v == dst:
                found = True
                break

            if v in closed:
                continue

            new_g = g[q] + adj_matrix.loc[v, q]
            new_f = new_g + heuristic[v]

            if np.isnan(opened[v]) or new_f < opened[v]:
                # if v not in opened or in opened but with higher f than the new f
                print(f'Updated {q} -> {v}, new_f: {new_f}, new_g: {new_g}')
                g[v] = new_g
                f[v] = new_f
                opened[v] = new_f
                parent[v] = q
            
        closed.add(q)

    return f, g, parent

In [7]:
src = 'timisoara'
dst = 'bucharest'

f, g, parent = a_star(src=src, dst=dst, adj_matrix=adj_matrix, heuristic=heuristic)

Updated timisoara -> arad, new_f: 484.0, new_g: 118.0
Updated timisoara -> lugoj, new_f: 355.0, new_g: 111.0
Updated lugoj -> mehadia, new_f: 422.0, new_g: 181.0
Updated mehadia -> dobreta, new_f: 498.0, new_g: 256.0
Updated arad -> sibiu, new_f: 511.0, new_g: 258.0
Updated arad -> zerind, new_f: 567.0, new_g: 193.0
Updated dobreta -> craiova, new_f: 536.0, new_g: 376.0
Updated sibiu -> fagaras, new_f: 535.0, new_g: 357.0
Updated sibiu -> oradea, new_f: 789.0, new_g: 409.0
Updated sibiu -> rimnicu, new_f: 531.0, new_g: 338.0
Updated rimnicu -> pitesti, new_f: 533.0, new_g: 435.0
Updated zerind -> oradea, new_f: 644.0, new_g: 264.0


In [18]:
cur = 'fagaras'
dist = 0
while cur != src:
    print(cur, end=' -> ')
    dist += adj_matrix.loc[cur, parent[cur]]
    cur = parent[cur]
print(cur, dist)

fagaras -> sibiu -> arad -> timisoara 357.0
