In [127]:
import os
from getpass import getpass
from pathlib import Path

from neo4j import GraphDatabase
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd

In [128]:
# log in to neo4j db
# Connection settings are read from the environment so no password is stored in the
# notebook. URI and user default to the local instance used so far; if NEO4J_PASSWORD
# is not set, the password is prompted for interactively.
URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD") or getpass("Neo4j password: ")
driver = GraphDatabase.driver(URI, auth=(NEO4J_USER, NEO4J_PASSWORD))

In [129]:
def check_connection(driver):
    """Print every record returned by ``CALL dbms.components()``.

    Serves as a connectivity smoke test: if the driver cannot reach the server the
    query raises; otherwise one line per server component is printed.

    Args:
        driver: Neo4j driver used to open a short-lived session.
    """
    components_query = "CALL dbms.components()"
    with driver.session() as session:
        for component in session.run(components_query):
            print(component)


# Smoke test: print the server components (name, versions, edition) reported by Neo4j.
check_connection(driver)

<Record name='Neo4j Kernel' versions=['5.26.0'] edition='community'>


In [130]:
def empty_db(driver, batch_size=1_000):
    """Delete every node in the database, together with its relationships.

    Nodes are removed with ``DETACH DELETE`` in batches of at most ``batch_size``
    per query, so a large graph is never deleted by a single huge query. A tqdm
    progress bar tracks deletions against the node count taken before starting.

    Args:
        driver: Neo4j driver used to open the session.
        batch_size: Maximum number of nodes deleted per query (default 1,000).

    Raises:
        ValueError: if ``batch_size`` is smaller than 1 (a ``LIMIT 0`` batch would
            delete nothing and the function would return with nodes left behind).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    count_query = "MATCH (n) RETURN count(n) as total"
    delete_query = """
    MATCH (n)
    WITH n LIMIT $batch_size
    DETACH DELETE n
    RETURN count(n) as deleted
    """

    with driver.session() as session:
        total_nodes = session.run(count_query).single()["total"]
        if total_nodes == 0:
            return

        # The context manager closes the bar even if a delete query raises.
        with tqdm(total=total_nodes, desc="Deleting nodes", unit="nodes") as progress:
            while True:
                result = session.run(delete_query, batch_size=batch_size)
                deleted_count = result.single()["deleted"]
                # A batch that deletes nothing means the graph is empty.
                if deleted_count == 0:
                    break
                progress.update(deleted_count)


# Start from an empty graph: remove all existing nodes (and their relationships).
empty_db(driver)

Deleting nodes:   0%|          | 0/15702 [00:00<?, ?nodes/s]

In [131]:
# All input CSVs live in a single folder, resolved relative to the working directory.
DATASET_PATH = Path("dataset")

airlines_csv = DATASET_PATH.joinpath("airlines.csv")
countries_csv = DATASET_PATH.joinpath("countries.csv")
planes_csv = DATASET_PATH.joinpath("planes.csv")
routes_csv = DATASET_PATH.joinpath("routes.csv")
terminals_csv = DATASET_PATH.joinpath("terminals.csv")

In [132]:
# Read the raw CSVs. pandas' default NA list ("NA", "N/A", "NULL", "None", ...) would
# turn genuine codes into NaN -- e.g. the two-letter code "NA" (Namibia's ISO code) --
# so only empty fields are treated as missing here. The literal "\N" placeholders are
# handled in the following cleanup cell.
def read_dataset_csv(path):
    """Load one dataset CSV, treating only empty fields as missing values."""
    return pd.read_csv(path, keep_default_na=False, na_values=[""])


airlines_df = read_dataset_csv(airlines_csv)
countries_df = read_dataset_csv(countries_csv)
planes_df = read_dataset_csv(planes_csv)
routes_df = read_dataset_csv(routes_csv)
terminals_df = read_dataset_csv(terminals_csv)

In [133]:
# The CSVs mark missing values with the literal string "\N". Map those, and the NaN
# pandas uses for empty fields, to Python None: the Neo4j driver sends None as null
# (the property is simply not set), whereas NaN would be stored as a float NaN value.
def to_neo4j_nulls(df):
    """Return an object-dtype copy of ``df`` with "\\N" and NaN replaced by None."""
    as_object = df.astype(object)
    present = as_object.notna() & as_object.ne("\\N")
    return as_object.where(present, None)


airlines_df = to_neo4j_nulls(airlines_df)
countries_df = to_neo4j_nulls(countries_df)
planes_df = to_neo4j_nulls(planes_df)
routes_df = to_neo4j_nulls(routes_df)
terminals_df = to_neo4j_nulls(terminals_df)

## Countries

In [134]:
# Preview the countries table after the "\N" cleanup.
countries_df

Unnamed: 0,name,iso_code,dafif_code
0,"Bonaire, Saint Eustatius and Saba",BQ,
1,Aruba,AW,AA
2,Antigua and Barbuda,AG,AC
3,United Arab Emirates,AE,AE
4,Afghanistan,AF,AF
...,...,...,...
256,Samoa,WS,WS
257,Eswatini,SZ,WZ
258,Yemen,YE,YM
259,Zambia,ZM,ZA


In [135]:
# upload each country to the db
# Missing values arrive as NaN (or None); they are sent to Neo4j as None so the
# property is left unset instead of being stored as a float NaN. One session is
# reused for all rows rather than opening a new one per country.
COUNTRY_QUERY = "CREATE (c:Country {name: $name, iso_code: $iso_code, dafif_code: $dafif_code})"

with driver.session() as session:
    for _, row in tqdm(countries_df.iterrows(), total=len(countries_df)):
        params = {key: (None if pd.isna(value) else value) for key, value in row.items()}
        session.run(
            COUNTRY_QUERY,
            name=params["name"],
            iso_code=params["iso_code"],
            dafif_code=params["dafif_code"],
        )

  0%|          | 0/261 [00:00<?, ?it/s]

## Airlines

In [136]:
# Preview the airlines table after the "\N" cleanup.
airlines_df

Unnamed: 0,airline_id,name,alias,iata,icao,callsign,country,active
0,-1,Unknown,,-,,,,Y
1,1,Private flight,,-,,,,Y
2,2,135 Airways,,,GNL,GENERAL,United States,N
3,3,1Time Airline,,1T,RNX,NEXTIME,South Africa,Y
4,4,2 Sqn No 1 Elementary Flying Training School,,,WYT,,United Kingdom,N
...,...,...,...,...,...,...,...,...
6157,21248,GX Airlines,,,CBG,SPRAY,China,Y
6158,21251,Lynx Aviation (L3/SSX),,,SSX,Shasta,United States,N
6159,21268,Jetgo Australia,,JG,,,Australia,Y
6160,21270,Air Carnival,,2S,,,India,Y


In [137]:
# List the airline columns available as query parameters.
airlines_df.columns

Index(['airline_id', 'name', 'alias', 'iata', 'icao', 'callsign', 'country',
       'active'],
      dtype='object')

In [138]:
# Upload each airline and link it to its home country in a single query.
# - The `country` column holds country *names* (e.g. "United States"), so the link
#   matches Country.name; matching it against iso_code could never succeed.
# - The relationship starts from the node just created instead of re-matching on
#   `iata`, which is missing or "-" for many airlines and shared by others.
# - If no Country matches, the MATCH yields no rows: the airline is still created,
#   only the BASED_IN relationship is skipped.
AIRLINE_QUERY = """
CREATE (a:Airline {id: $airline_id, name: $name, alias: $alias, iata: $iata,
                   icao: $icao, callsign: $callsign, active: $active})
WITH a
MATCH (c:Country {name: $country})
CREATE (a)-[:BASED_IN]->(c)
"""

with driver.session() as session:
    for _, row in tqdm(airlines_df.iterrows(), total=len(airlines_df)):
        # NaN/None -> None so missing values become unset properties, not NaN.
        params = {key: (None if pd.isna(value) else value) for key, value in row.items()}
        session.run(
            AIRLINE_QUERY,
            airline_id=params["airline_id"],
            name=params["name"],
            alias=params["alias"],
            iata=params["iata"],
            icao=params["icao"],
            callsign=params["callsign"],
            country=params["country"],
            active=params["active"],
        )

  0%|          | 0/6162 [00:00<?, ?it/s]

## Planes

In [139]:
# Preview the planes table after the "\N" cleanup.
planes_df

Unnamed: 0,name,iata,icao
0,Aerospatiale (Nord) 262,ND2,N262
1,Aerospatiale (Sud Aviation) Se.210 Caravelle,CRV,S210
2,Aerospatiale SN.601 Corvette,NDC,S601
3,Aerospatiale/Alenia ATR 42-300,AT4,AT43
4,Aerospatiale/Alenia ATR 42-500,AT5,AT45
...,...,...,...
241,Tupolev Tu-144,,T144
242,Tupolev Tu-154,TU5,T154
243,Tupolev Tu-204,T20,T204
244,Yakovlev Yak-40,YK4,YK40


In [140]:
# List the plane columns available as query parameters.
planes_df.columns

Index(['name', 'iata', 'icao'], dtype='object')

In [141]:
# Upload each aircraft type as a :Plane node. Some codes are missing (e.g. the
# Tu-144 has no iata code); they are sent as None so the property is left unset
# instead of being stored as a float NaN.
PLANE_QUERY = "CREATE (p:Plane {name: $name, iata: $iata, icao: $icao})"

with driver.session() as session:
    for _, row in tqdm(planes_df.iterrows(), total=len(planes_df)):
        params = {key: (None if pd.isna(value) else value) for key, value in row.items()}
        session.run(
            PLANE_QUERY,
            name=params["name"],
            iata=params["iata"],
            icao=params["icao"],
        )

  0%|          | 0/246 [00:00<?, ?it/s]

## Terminals

In [142]:
# Preview the terminals (airports) table after the "\N" cleanup.
terminals_df

Unnamed: 0,airport_id,name,city,country,iata,icao,latitude,longitude,altitude,timezone,dst,tz,type,source
0,1,Goroka Airport,Goroka,Papua New Guinea,GKA,AYGA,-6.081690,145.391998,5282,10,U,Pacific/Port_Moresby,airport,OurAirports
1,2,Madang Airport,Madang,Papua New Guinea,MAG,AYMD,-5.207080,145.789001,20,10,U,Pacific/Port_Moresby,airport,OurAirports
2,3,Mount Hagen Kagamuga Airport,Mount Hagen,Papua New Guinea,HGU,AYMH,-5.826790,144.296005,5388,10,U,Pacific/Port_Moresby,airport,OurAirports
3,4,Nadzab Airport,Nadzab,Papua New Guinea,LAE,AYNZ,-6.569803,146.725977,239,10,U,Pacific/Port_Moresby,airport,OurAirports
4,5,Port Moresby Jacksons International Airport,Port Moresby,Papua New Guinea,POM,AYPY,-9.443380,147.220001,146,10,U,Pacific/Port_Moresby,airport,OurAirports
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
12663,14107,Ulan-Ude East Airport,Ulan Ude,Russia,,XIUW,51.849998,107.737999,1670,,,,airport,OurAirports
12664,14108,Krechevitsy Air Base,Novgorod,Russia,,ULLK,58.625000,31.385000,85,,,,airport,OurAirports
12665,14109,Desierto de Atacama Airport,Copiapo,Chile,CPO,SCAT,-27.261200,-70.779198,670,,,,airport,OurAirports
12666,14110,Melitopol Air Base,Melitopol,Ukraine,,UKDM,46.880001,35.305000,0,,,,airport,OurAirports


In [143]:
# List the terminal columns available as query parameters.
terminals_df.columns

Index(['airport_id', 'name', 'city', 'country', 'iata', 'icao', 'latitude',
       'longitude', 'altitude', 'timezone', 'dst', 'tz', 'type', 'source'],
      dtype='object')

In [144]:
# Load airports as :Terminal nodes and connect each one to its Timezone, City and
# Country.
# - Timezone and City nodes are MERGEd so each name exists once. Creating them per
#   row would produce duplicates, and later MATCHes would fan out across them.
# - Links start from the terminal just created, looked up by elementId. Re-matching
#   on `iata` is unreliable because it is missing for many airports.
# - The `country` column holds country names (e.g. "Papua New Guinea"), so countries
#   are matched by Country.name.
# - MERGE rejects null keys, so a link whose key value is missing is skipped.
# NOTE(review): City is keyed on name only, so same-named cities in different
# countries collapse into one node -- consider keying on (name, country).
# NOTE(review): `timezone` holds values like 10 in the preview (presumably a UTC
# offset; the Olson name is in `tz`) -- confirm which one Timezone should use.
CREATE_TERMINAL = """
CREATE (t:Terminal {airport_id: $airport_id, name: $name, iata: $iata, icao: $icao,
                    latitude: $latitude, longitude: $longitude, altitude: $altitude,
                    type: $type, source: $source})
RETURN elementId(t) AS terminal_id
"""
LINK_TERMINAL_TIMEZONE = """
MATCH (t:Terminal) WHERE elementId(t) = $terminal_id
MERGE (tz:Timezone {name: $timezone})
MERGE (t)-[:USES]->(tz)
"""
LINK_TERMINAL_CITY = """
MATCH (t:Terminal) WHERE elementId(t) = $terminal_id
MERGE (c:City {name: $city})
MERGE (t)-[:LOCATED_IN]->(c)
"""
LINK_TERMINAL_COUNTRY = """
MATCH (t:Terminal) WHERE elementId(t) = $terminal_id
MATCH (co:Country {name: $country})
MERGE (t)-[:LOCATED_IN]->(co)
"""
LINK_CITY_TIMEZONE = """
MATCH (c:City {name: $city}), (tz:Timezone {name: $timezone})
MERGE (c)-[:USES]->(tz)
"""
LINK_CITY_COUNTRY = """
MATCH (c:City {name: $city}), (co:Country {name: $country})
MERGE (c)-[:LOCATED_IN]->(co)
"""


def _upload_terminal(session, row):
    """Create one :Terminal from a terminals_df row and link it to Timezone/City/Country.

    Missing values (NaN/None) are sent as None, so they become unset properties. A
    relationship whose key (timezone, city or country) is missing is skipped.
    """
    params = {key: (None if pd.isna(value) else value) for key, value in row.items()}
    terminal_id = session.run(
        CREATE_TERMINAL,
        airport_id=params["airport_id"],
        name=params["name"],
        iata=params["iata"],
        icao=params["icao"],
        latitude=params["latitude"],
        longitude=params["longitude"],
        altitude=params["altitude"],
        type=params["type"],
        source=params["source"],
    ).single()["terminal_id"]

    timezone, city, country = params["timezone"], params["city"], params["country"]
    # The Timezone must exist before the City -> Timezone link below can match it.
    if timezone is not None:
        session.run(LINK_TERMINAL_TIMEZONE, terminal_id=terminal_id, timezone=timezone)
    if city is not None:
        session.run(LINK_TERMINAL_CITY, terminal_id=terminal_id, city=city)
        if timezone is not None:
            session.run(LINK_CITY_TIMEZONE, city=city, timezone=timezone)
        if country is not None:
            session.run(LINK_CITY_COUNTRY, city=city, country=country)
    if country is not None:
        session.run(LINK_TERMINAL_COUNTRY, terminal_id=terminal_id, country=country)


with driver.session() as session:
    for _, row in tqdm(terminals_df.iterrows(), total=len(terminals_df)):
        _upload_terminal(session, row)

  0%|          | 0/12668 [00:00<?, ?it/s]

## Routes

In [145]:
# Preview the routes table after the "\N" cleanup.
routes_df

Unnamed: 0,airline,airline_id,source_airport,source_airport_id,destination_airport,destination_airport_id,codeshare,stops,equipment
0,2B,410,AER,2965,KZN,2990,,0,CR2
1,2B,410,ASF,2966,KZN,2990,,0,CR2
2,2B,410,ASF,2966,MRV,2962,,0,CR2
3,2B,410,CEK,2968,KZN,2990,,0,CR2
4,2B,410,CEK,2968,OVB,4078,,0,CR2
...,...,...,...,...,...,...,...,...,...
67658,ZL,4178,WYA,6334,ADL,3341,,0,SF3
67659,ZM,19016,DME,4029,FRU,2912,,0,734
67660,ZM,19016,FRU,2912,DME,4029,,0,734
67661,ZM,19016,FRU,2912,OSS,2913,,0,734


In [146]:
# List the route columns available as query parameters.
routes_df.columns

Index(['airline', 'airline_id', 'source_airport', 'source_airport_id',
       'destination_airport', 'destination_airport_id', 'codeshare', 'stops',
       'equipment'],
      dtype='object')

In [147]:
# Connect terminals with FLIES_TO routes and link airlines to the planes they fly.
# Without indexes, every `{iata: ...}` lookup below would scan all nodes of that label
# once per route (tens of thousands of routes), so the indexes are created first.
ROUTE_INDEXES = (
    "CREATE INDEX terminal_iata IF NOT EXISTS FOR (t:Terminal) ON (t.iata)",
    "CREATE INDEX airline_iata IF NOT EXISTS FOR (a:Airline) ON (a.iata)",
    "CREATE INDEX plane_iata IF NOT EXISTS FOR (p:Plane) ON (p.iata)",
)
# The airline only filters routes here. Matching it as a pattern variable would
# multiply rows, creating one duplicate FLIES_TO per airline that shares the code.
# An existence check keeps the filter without the duplication.
FLIES_TO_QUERY = """
MATCH (source:Terminal {iata: $source_airport}),
      (destination:Terminal {iata: $destination_airport})
WHERE EXISTS { MATCH (:Airline {iata: $airline}) }
CREATE (source)-[:FLIES_TO {iata: $airline, stops: $stops, equipment: $equipment}]->(destination)
"""
# MERGE: an airline flies the same plane on many routes but needs one OPERATES edge.
# NOTE(review): airlines are matched by iata, which is not unique -- consider
# matching on airline_id instead (types must agree with Airline.id).
OPERATES_QUERY = """
MATCH (p:Plane {iata: $equipment}),
      (a:Airline {iata: $airline})
MERGE (a)-[:OPERATES]->(p)
"""

with driver.session() as session:
    for index_query in ROUTE_INDEXES:
        session.run(index_query).consume()
    session.run("CALL db.awaitIndexes()").consume()

    for _, row in tqdm(routes_df.iterrows(), total=len(routes_df)):
        # NaN/None -> None so missing values become unset properties, not NaN.
        params = {key: (None if pd.isna(value) else value) for key, value in row.items()}
        session.run(
            FLIES_TO_QUERY,
            source_airport=params["source_airport"],
            destination_airport=params["destination_airport"],
            airline=params["airline"],
            stops=params["stops"],
            equipment=params["equipment"],
        )
        if params["equipment"] is None:
            continue
        # The equipment field may list several space-separated plane codes
        # (OpenFlights routes format); link each code individually.
        for plane_code in str(params["equipment"]).split():
            session.run(OPERATES_QUERY, equipment=plane_code, airline=params["airline"])

  0%|          | 0/67663 [00:00<?, ?it/s]