## Setup

In [None]:
import torch

In [None]:
# Install torch geometric -- for pyg
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
!pip install -q torch-geometric

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m95.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m59.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m43.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.data import Data

  import torch_geometric.typing
  import torch_geometric.typing


In [None]:
import pandas as pd
import numpy as np
from datetime import datetime, timedelta


## Data processing

In [None]:
# Install if not already installed
# pip install kagglehub[pandas-datasets]

import kagglehub
from kagglehub import KaggleDatasetAdapter

# Flights table, loaded into the global `df` that the rest of the notebook reads.
file_path = "flights.csv"
# Use `dataset_load` (the same API the airports cell uses); the older
# `load_dataset` name appears to be what triggered the warning echoed in this
# cell's saved output.
df = kagglehub.dataset_load(KaggleDatasetAdapter.PANDAS, "mahoora00135/flights", file_path)

print("First 5 records:\n", df.head())


  df = kagglehub.load_dataset(KaggleDatasetAdapter.PANDAS, "mahoora00135/flights", file_path)


Downloading from https://www.kaggle.com/api/v1/datasets/download/mahoora00135/flights?dataset_version_number=1&file_name=flights.csv...


100%|██████████| 10.3M/10.3M [00:00<00:00, 89.9MB/s]

Extracting zip of flights.csv...





First 5 records:
    id  year  month  day  dep_time  sched_dep_time  dep_delay  arr_time  \
0   0  2013      1    1     517.0             515        2.0     830.0   
1   1  2013      1    1     533.0             529        4.0     850.0   
2   2  2013      1    1     542.0             540        2.0     923.0   
3   3  2013      1    1     544.0             545       -1.0    1004.0   
4   4  2013      1    1     554.0             600       -6.0     812.0   

   sched_arr_time  arr_delay  ... flight  tailnum origin dest air_time  \
0             819       11.0  ...   1545   N14228    EWR  IAH    227.0   
1             830       20.0  ...   1714   N24211    LGA  IAH    227.0   
2             850       33.0  ...   1141   N619AA    JFK  MIA    160.0   
3            1022      -18.0  ...    725   N804JB    JFK  BQN    183.0   
4             837      -25.0  ...    461   N668DN    LGA  ATL    116.0   

   distance  hour  minute            time_hour                    name  
0      1400     5  

In [None]:
df.shape[0]  # number of flight rows loaded

336776

In [None]:
import pandas as pd

# Airport reference table (IATA code + coordinates). Loaded into its own
# variable so the flights frame `df` is left untouched.
airports_df = kagglehub.dataset_load(
    KaggleDatasetAdapter.PANDAS,
    "aravindram11/list-of-us-airports",
    "airports.csv",
)

# Fail fast if the upstream CSV schema changed.
required = {"IATA", "LATITUDE", "LONGITUDE"}
missing = required.difference(airports_df.columns)
if missing:
    raise KeyError(f"airports.csv missing columns: {missing}. Got: {list(airports_df.columns)}")

# Normalise: upper-case codes, numeric coordinates (unparseable values -> NaN).
air = airports_df.loc[:, ["IATA", "LATITUDE", "LONGITUDE"]].copy()
air["IATA"] = air["IATA"].astype(str).str.upper()
for coord_col in ("LATITUDE", "LONGITUDE"):
    air[coord_col] = pd.to_numeric(air[coord_col], errors="coerce")

# Every IATA code that appears as an origin or a destination in the flights data.
used_codes = pd.Index(df["origin"].astype(str).str.upper()).union(
    df["dest"].astype(str).str.upper()
)

# Keep only airports that flights touch and that have usable coordinates,
# one row per code, ordered by code.
airports_in_flights = (
    air.loc[lambda t: t["IATA"].isin(used_codes)]
       .dropna(subset=["LATITUDE", "LONGITUDE"])
       .drop_duplicates(subset=["IATA"])
       .sort_values("IATA")
       .reset_index(drop=True)
)

print(airports_in_flights.head())
print(f"Total airports with coords: {len(airports_in_flights)}")

Downloading from https://www.kaggle.com/api/v1/datasets/download/aravindram11/list-of-us-airports?dataset_version_number=1&file_name=airports.csv...


100%|██████████| 22.2k/22.2k [00:00<00:00, 21.9MB/s]


  IATA   LATITUDE   LONGITUDE
0  ABQ  35.040222 -106.609194
1  ACK  41.253052  -70.060181
2  ALB  42.748119  -73.802979
3  ANC  61.174320 -149.996186
4  ATL  33.640444  -84.426944
Total airports with coords: 107


In [None]:
import ee
import geemap
import time

# Initialize Earth Engine
# NOTE(review): the Cloud project id is hard-coded; other users need to replace
# it with their own Earth Engine-enabled project.
PROJECT = "adept-amp-477117-a1"
try:
    # Direct init; presumably succeeds when credentials already exist.
    ee.Initialize(project=PROJECT)
except Exception:
    # Fallback to geemap's helper (assumed to handle authentication first —
    # confirm against geemap docs).
    geemap.ee_initialize(project=PROJECT)

# Airports DataFrame -> FeatureCollection
def airports_df_to_fc(airports_df: pd.DataFrame) -> ee.FeatureCollection:
    """Convert an airports table into an Earth Engine FeatureCollection.

    One point feature per row, located at (LONGITUDE, LATITUDE), carrying the
    upper-cased IATA code and both coordinates as float properties.
    """
    def _to_feature(code, lat, lon):
        lat, lon = float(lat), float(lon)
        point = ee.Geometry.Point([lon, lat])
        return ee.Feature(point, {
            "IATA": str(code).upper(),
            "LATITUDE": lat,
            "LONGITUDE": lon,
        })

    columns = zip(airports_df["IATA"], airports_df["LATITUDE"], airports_df["LONGITUDE"])
    return ee.FeatureCollection([_to_feature(code, lat, lon) for code, lat, lon in columns])

# Build 6‑hour ERA5 collection for an arbitrary [start,end)
def make_era5_6h_collection(start: ee.Date, end: ee.Date) -> ee.ImageCollection:
    """
    Creates an ImageCollection where each image represents one 6‑hour slot within [start, end):
      - minimum_2m_air_temperature (K)     [min(temperature_2m) over 6h]
      - maximum_2m_air_temperature (K)     [max(temperature_2m) over 6h]
      - total_precipitation       (m)      [sum of hourly totals over 6h]
      - u_component_of_wind_10m   (m/s)    [mean over 6h]
      - v_component_of_wind_10m   (m/s)    [mean over 6h]
      - instantaneous_10m_wind_gust (m/s)  [max over 6h]
      - mxtpr_kg_m2_s             (kg m^-2 s^-1) [max hourly total (m) * 1000 / 3600]
    Properties per image:
      - slot_start, slot_end (UTC, 'YYYY-MM-dd HH:mm:ss')

    The slot count is the whole number of 6-hour steps in [start, end); a
    trailing partial slot is dropped.
    NOTE(review): slots with no hourly images are not handled specially.
    """
    ic_hourly = (
        ee.ImageCollection('ECMWF/ERA5/HOURLY')
        .filterDate(start, end)
        .select([
            'temperature_2m',
            'total_precipitation',
            'u_component_of_wind_10m',
            'v_component_of_wind_10m',
            'instantaneous_10m_wind_gust'
        ])
    )

    nslots = end.difference(start, 'hour').divide(6).toInt()
    slots = ee.List.sequence(0, nslots.subtract(1))

    def build_slot(i):
        i = ee.Number(i)
        s = start.advance(i.multiply(6), 'hour')
        e = s.advance(6, 'hour')
        slot = ic_hourly.filterDate(s, e)

        tmin = slot.select('temperature_2m').reduce(ee.Reducer.min()) \
                   .rename('minimum_2m_air_temperature')
        tmax = slot.select('temperature_2m').reduce(ee.Reducer.max()) \
                   .rename('maximum_2m_air_temperature')
        tp   = slot.select('total_precipitation').reduce(ee.Reducer.sum()) \
                   .rename('total_precipitation')
        u10  = slot.select('u_component_of_wind_10m').reduce(ee.Reducer.mean()) \
                   .rename('u_component_of_wind_10m')
        v10  = slot.select('v_component_of_wind_10m').reduce(ee.Reducer.mean()) \
                   .rename('v_component_of_wind_10m')
        gust = slot.select('instantaneous_10m_wind_gust').reduce(ee.Reducer.max()) \
                   .rename('instantaneous_10m_wind_gust')

        # Equivalent to "Maximum total precipitation rate since previous post-processing":
        # the max hourly precip total in the block is a water depth in m per hour.
        # 1 m of water over 1 m^2 weighs 1000 kg, so
        #   (m / h) * 1000 kg m^-3 / 3600 s h^-1  ->  kg m^-2 s^-1.
        # (Dividing by 3600 alone would give m/s, not the advertised unit.)
        mxtpr = slot.select('total_precipitation').reduce(ee.Reducer.max()) \
                     .multiply(1000).divide(3600).rename('mxtpr_kg_m2_s')

        return (ee.Image.cat([tmin, tmax, tp, u10, v10, gust, mxtpr])
                .set({"slot_start": s.format("YYYY-MM-dd HH:mm:ss"),
                      "slot_end":   e.format("YYYY-MM-dd HH:mm:ss")}))

    return ee.ImageCollection(slots.map(build_slot))

def make_era5_6h_collection_2013() -> ee.ImageCollection:
    """6-hour ERA5 aggregates for calendar year 2013 (UTC), end-exclusive."""
    year_start = ee.Date('2013-01-01T00:00')
    next_year_start = ee.Date('2014-01-01T00:00')
    return make_era5_6h_collection(year_start, next_year_start)

# Sample each 6‑hour image at each airport (for a given collection)
def sample_era5_6h_at_airports_from_collection(ic_6h: ee.ImageCollection,
                                               airports_fc: ee.FeatureCollection,
                                               scale_m: int = 25000) -> ee.FeatureCollection:
    """Sample every 6-hour image at every airport point.

    Returns one flat FeatureCollection with a feature per (image, airport)
    holding the image's band values, the airport's IATA/LATITUDE/LONGITUDE,
    and the image's slot_start / slot_end properties. Geometries are dropped.
    """
    def _per_image(img):
        slot_props = {
            'slot_start': img.get('slot_start'),
            'slot_end':   img.get('slot_end'),
        }
        sampled = img.sampleRegions(
            collection=airports_fc,
            properties=['IATA', 'LATITUDE', 'LONGITUDE'],
            scale=scale_m,
            geometries=False,
        )
        return sampled.map(lambda feat: feat.set(slot_props))

    nested = ee.FeatureCollection(ic_6h.map(_per_image))
    return nested.flatten()

def sample_era5_6h_at_airports(airports_fc: ee.FeatureCollection,
                               scale_m: int = 25000) -> ee.FeatureCollection:
    """Sample the whole-year 2013 6-hour ERA5 collection at every airport in one request graph."""
    return sample_era5_6h_at_airports_from_collection(
        make_era5_6h_collection_2013(), airports_fc, scale_m
    )

_SELECTORS = [
    'IATA', 'LATITUDE', 'LONGITUDE',
    'slot_start', 'slot_end',
    'minimum_2m_air_temperature',
    'maximum_2m_air_temperature',
    'total_precipitation',
    'u_component_of_wind_10m',
    'v_component_of_wind_10m',
    'instantaneous_10m_wind_gust',
    'mxtpr_kg_m2_s'
]

def _compute_features_paged(fc: ee.FeatureCollection, selectors=None, page_size=2000, max_retries=5) -> pd.DataFrame:
    """Download a FeatureCollection's properties page by page via ee.data.computeFeatures.

    Args:
        fc: FeatureCollection to download.
        selectors: property names to keep (defaults to _SELECTORS).
        page_size: features requested per page.
        max_retries: consecutive failed requests tolerated before the last
            error is re-raised; waits 0.5 s, 1 s, 2 s, ... between attempts.

    Returns:
        DataFrame with one row per feature and exactly the `selectors` columns
        (NaN where a feature lacks a property), even when no features come back.
    """
    if selectors is None:
        selectors = _SELECTORS
    selectors = list(selectors)
    # NOTE(review): the ComputeFeatures request body documents only
    # expression/paging fields (no "selectors"), so the property projection is
    # done on the collection itself before paging.
    projected = fc.select(selectors)

    rows, token, attempt = [], None, 0
    while True:
        req = {"expression": projected, "pageSize": page_size}
        if token:
            req["pageToken"] = token
        try:
            resp = ee.data.computeFeatures(req)
        except Exception:
            if attempt >= max_retries:
                raise
            time.sleep(0.5 * (2 ** attempt))  # exponential backoff
            attempt += 1
            continue
        attempt = 0  # reset the backoff after every successful page

        rows.extend(f.get("properties", {}) for f in resp.get("features", []))
        token = resp.get("nextPageToken")
        if not token:
            break
    # columns= pins the schema so an empty result still has the expected columns.
    return pd.DataFrame(rows, columns=selectors)

def ee_fc_to_dataframe(fc: ee.FeatureCollection, selectors=None, page_size=2000) -> pd.DataFrame:
    """Convert an EE FeatureCollection to a pandas DataFrame, trying several backends.

    Backends, in order:
      1. geemap.ee_to_df (then restricted to the selected columns that exist)
      2. geemap.common.ee_to_pandas
      3. geemap.ee_to_gdf (geometry column dropped)
      4. direct paged ee.data.computeFeatures (_compute_features_paged)
    When backend 1 or 2 raises, a warning with the error is emitted before
    falling through, instead of the failure being silently discarded.

    Args:
        fc: collection to download.
        selectors: property names to keep (defaults to _SELECTORS).
        page_size: page size, used only by the direct-API fallback.
    """
    import warnings

    if selectors is None:
        selectors = _SELECTORS

    # 1) Geemap helper (returns every property; keep only the selected ones).
    if hasattr(geemap, "ee_to_df"):
        try:
            out = geemap.ee_to_df(fc)
            return out[[c for c in selectors if c in out.columns]]
        except Exception as exc:
            warnings.warn(f"geemap.ee_to_df failed ({exc!r}); trying next backend")

    # 2) geemap.common.ee_to_pandas — a missing helper is expected and silent.
    try:
        from geemap.common import ee_to_pandas
    except ImportError:
        ee_to_pandas = None
    if ee_to_pandas is not None:
        try:
            return ee_to_pandas(fc, selectors=selectors)
        except Exception as exc:
            warnings.warn(f"geemap.common.ee_to_pandas failed ({exc!r}); trying next backend")

    # 3) geemap.ee_to_gdf fallback
    if hasattr(geemap, "ee_to_gdf"):
        gdf = geemap.ee_to_gdf(fc, selectors=selectors)
        return pd.DataFrame(gdf.drop(columns="geometry"))

    # 4) API
    return _compute_features_paged(fc, selectors=selectors, page_size=page_size)

def monthly_ranges_2013():
    """Return 12 (start, end) ee.Date pairs, one per month of 2013.

    `end` is the first day of the following month, so each pair is meant as a
    half-open [start, end) range.
    """
    month_starts = [ee.Date.fromYMD(2013, m, 1) for m in range(1, 13)]
    return [(first_day, first_day.advance(1, "month")) for first_day in month_starts]

def download_weather_2013_6h(airports_fc: ee.FeatureCollection,
                             scale_m: int = 25000,
                             page_size: int = 2000) -> pd.DataFrame:
    """Pull 6-hour ERA5 aggregates for every airport over 2013, one month per request.

    Each month is sampled, downloaded, and converted:
      slot_start/slot_end -> era5_slot_utc/era5_slot_end_utc (tz-aware UTC),
      K -> °C for tmin_c/tmax_c, m -> mm for tp_mm, gust renamed to gust_ms.
    Returns the concatenation with columns in a fixed, readable order
    (columns absent from the download are skipped).
    """
    def _fetch_month(start, end):
        samples = sample_era5_6h_at_airports_from_collection(
            make_era5_6h_collection(start, end), airports_fc, scale_m
        )
        month_df = ee_fc_to_dataframe(samples, selectors=_SELECTORS, page_size=page_size)

        month_df = month_df.rename(columns={'slot_start': 'era5_slot_utc',
                                            'slot_end':   'era5_slot_end_utc'})
        for ts_col in ('era5_slot_utc', 'era5_slot_end_utc'):
            month_df[ts_col] = pd.to_datetime(month_df[ts_col], utc=True)
        month_df['tmin_c'] = month_df['minimum_2m_air_temperature'] - 273.15
        month_df['tmax_c'] = month_df['maximum_2m_air_temperature'] - 273.15
        month_df['tp_mm'] = month_df['total_precipitation'] * 1000.0  # m → mm
        return month_df.rename(columns={'instantaneous_10m_wind_gust': 'gust_ms'})

    weather_6h_df = pd.concat(
        [_fetch_month(start, end) for start, end in monthly_ranges_2013()],
        ignore_index=True,
    )

    preferred_order = [
        'IATA', 'LATITUDE', 'LONGITUDE',
        'era5_slot_utc', 'era5_slot_end_utc',
        'tmin_c', 'tmax_c', 'tp_mm', 'gust_ms', 'mxtpr_kg_m2_s',
        'u_component_of_wind_10m', 'v_component_of_wind_10m',
        'minimum_2m_air_temperature', 'maximum_2m_air_temperature', 'total_precipitation'
    ]
    return weather_6h_df[[c for c in preferred_order if c in weather_6h_df.columns]]

def ensure_full_grid_2013_6h(weather_df: pd.DataFrame,
                             airports_df: pd.DataFrame) -> pd.DataFrame:
    """
    Expand weather to the full IATA × 6-hour-slot grid for 2013 (UTC).

    Every airport in `airports_df` gets all 1460 slots (2013-01-01 00:00 through
    2013-12-31 18:00). Slots with no sample get NaN features, so later merges
    never drop airports with missing samples. `era5_slot_end_utc` is (re)set
    to slot start + 6 h.

    Args:
        weather_df: rows keyed by IATA and tz-aware UTC `era5_slot_utc`;
            (IATA, era5_slot_utc) pairs must be unique (reindex requirement).
            The caller's frame is not modified.
        airports_df: must have an `IATA` column.

    Returns:
        New DataFrame: IATA, era5_slot_utc, then weather_df's other columns,
        with era5_slot_end_utc filled in.
    """
    # 'h' is the current pandas hour alias ('H' is deprecated since pandas 2.2).
    slots = pd.date_range('2013-01-01 00:00:00', '2013-12-31 18:00:00',
                          freq='6h', tz='UTC', inclusive='both')
    codes = airports_df['IATA'].astype(str).str.upper().unique()
    idx = pd.MultiIndex.from_product([codes, slots], names=['IATA', 'era5_slot_utc'])

    # Normalise weather codes the same way as the grid keys so they line up.
    w = (weather_df
         .assign(IATA=weather_df['IATA'].astype(str).str.upper())
         .set_index(['IATA', 'era5_slot_utc'])
         .reindex(idx)
         .reset_index())
    w['era5_slot_end_utc'] = w['era5_slot_utc'] + pd.Timedelta(hours=6)
    return w

airports_fc = airports_df_to_fc(airports_in_flights)

# The monthly ERA5 pull is very slow (the saved run of this cell was
# interrupted mid-download), so cache the result and reuse it on re-runs.
# Delete the cache file to force a fresh download.
from pathlib import Path
WEATHER_CACHE = Path("weather_6h_2013.parquet")

if WEATHER_CACHE.exists():
    weather_6h_df = pd.read_parquet(WEATHER_CACHE)
else:
    # Monthly chunked pull. Alternative single whole-year pull:
    #   fc_samples = sample_era5_6h_at_airports(airports_fc, scale_m=25000)
    #   weather_6h_df = ee_fc_to_dataframe(fc_samples)
    weather_6h_df = download_weather_2013_6h(airports_fc, scale_m=25000, page_size=2000)
    weather_6h_df.to_parquet(WEATHER_CACHE)

print(weather_6h_df.head())
print(len(weather_6h_df), "rows")


KeyboardInterrupt: 

In [None]:
# Inspect the columns produced by download_weather_2013_6h.
weather_6h_df.columns

We need four graphs per day, one per 6-hour time period.

For each day and time period, the nodes of the graph represent airports.

Each edge is a flight scheduled to depart in the current time period (even if its actual departure is delayed into the next period).

In [None]:
def get_day(D, month, day):
  """Return the rows of flights table `D` for one calendar day.

  Args:
    D: flights DataFrame with `month` and `day` columns.
    month, day: calendar date to select.

  Returns:
    The matching rows (original index kept), or the sentinel -1 when the day
    has no flights.
  """
  # Filter the frame that was passed in (the global `df` used to be read
  # instead, ignoring `D`), and build the boolean mask only once.
  day_df = D[(D['month'] == month) & (D['day'] == day)]
  if day_df.empty:
    return -1
  return day_df

In [None]:
jan_first_df = get_day(df, month=1, day=1)  # flights on January 1st

In [None]:
def get_time_periods(day_df, bounds=(600, 1200, 1800)):
  """Split one day's flights into four blocks by scheduled departure (HHMM).

  Blocks are half-open: [0000, 0600), [0600, 1200), [1200, 1800), [1800, 2400).
  A flight scheduled exactly on a boundary (e.g. 0600) therefore starts the
  later block, matching the '06:00-12:00'-style labels and the half-open
  6-hour weather slots. Previously boundary times fell into the earlier block.

  Args:
    day_df: flights DataFrame with a numeric `sched_dep_time` in HHMM form.
    bounds: the three interior block boundaries, in HHMM.

  Returns:
    List of four DataFrames (possibly empty), in chronological order.
  """
  t = day_df['sched_dep_time']
  b1, b2, b3 = bounds
  return [
      day_df[t < b1],
      day_df[(t >= b1) & (t < b2)],
      day_df[(t >= b2) & (t < b3)],
      day_df[t >= b3],
  ]

In [None]:
get_time_periods(jan_first_df)[-1]  # last (evening) block of January 1st

Unnamed: 0,id,year,month,day,dep_time,sched_dep_time,dep_delay,arr_time,sched_arr_time,arr_delay,...,flight,tailnum,origin,dest,air_time,distance,hour,minute,time_hour,name
151,151,2013,1,1,848.0,1835,853.0,1001.0,1950,851.0,...,3944,N942MQ,JFK,BWI,41.0,184,18,35,2013-01-01 18:00:00,Envoy Air
636,636,2013,1,1,1802.0,1805,-3.0,1930.0,1944,-14.0,...,1006,N359NB,LGA,BUF,61.0,292,18,5,2013-01-01 18:00:00,Delta Air Lines Inc.
637,637,2013,1,1,1802.0,1801,1.0,2125.0,2137,-12.0,...,1165,N75429,EWR,LAX,340.0,2454,18,1,2013-01-01 18:00:00,United Air Lines Inc.
642,642,2013,1,1,1806.0,1810,-4.0,2002.0,1945,17.0,...,4484,N711MQ,LGA,BNA,152.0,764,18,10,2013-01-01 18:00:00,Envoy Air
644,644,2013,1,1,1808.0,1815,-7.0,2111.0,2130,-19.0,...,7,N553AS,EWR,SEA,336.0,2402,18,15,2013-01-01 18:00:00,Alaska Airlines Inc.
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
833,833,2013,1,1,2327.0,2250,37.0,32.0,2359,33.0,...,22,N639JB,JFK,SYR,45.0,209,22,50,2013-01-01 22:00:00,JetBlue Airways
835,835,2013,1,1,2353.0,2359,-6.0,425.0,445,-20.0,...,739,N591JB,JFK,PSE,195.0,1617,23,59,2013-01-01 23:00:00,JetBlue Airways
836,836,2013,1,1,2353.0,2359,-6.0,418.0,442,-24.0,...,707,N794JB,JFK,SJU,185.0,1598,23,59,2013-01-01 23:00:00,JetBlue Airways
837,837,2013,1,1,2356.0,2359,-3.0,425.0,437,-12.0,...,727,N588JB,JFK,BQN,186.0,1576,23,59,2013-01-01 23:00:00,JetBlue Airways


In [None]:
def get_day(D, month, day):
  """Return the rows of flights table `D` for one calendar day.

  Always returns a DataFrame — empty when the day has no flights — so callers
  such as get_all_periods can test `.empty` uniformly (no -1 sentinel).

  Args:
    D: flights DataFrame with `month` and `day` columns.
    month, day: calendar date to select.
  """
  # Filter the frame that was passed in (the global `df` used to be read
  # instead, ignoring `D`), building the mask once.
  return D[(D['month'] == month) & (D['day'] == day)]

def get_time_periods(day_df, bounds=(600, 1200, 1800)):
  """Split one day's flights into four blocks by scheduled departure (HHMM).

  Blocks are half-open: [0000, 0600), [0600, 1200), [1200, 1800), [1800, 2400).
  A flight scheduled exactly on a boundary (e.g. 0600) therefore starts the
  later block, matching the '06:00-12:00'-style labels and the half-open
  6-hour weather slots. Previously boundary times fell into the earlier block.

  Args:
    day_df: flights DataFrame with a numeric `sched_dep_time` in HHMM form.
    bounds: the three interior block boundaries, in HHMM.

  Returns:
    List of four DataFrames (possibly empty), in chronological order.
  """
  t = day_df['sched_dep_time']
  b1, b2, b3 = bounds
  return [
      day_df[t < b1],
      day_df[(t >= b1) & (t < b2)],
      day_df[(t >= b2) & (t < b3)],
      day_df[t >= b3],
  ]

def get_all_periods(D):
  """Group flights table `D` by calendar day and split each day into 4 blocks.

  Returns:
    dict mapping (month, day) -> [list of 4 block DataFrames]. Note the extra
    one-element list wrapper; downstream code unwraps it with `[0]`.
    Keys are plain ints in chronological order; days without flights are absent.
  """
  total_periods = {}
  # One groupby pass over `D` instead of re-filtering the whole table for each
  # of the 12 x 31 candidate dates. Groups keep the original row order/index,
  # so each day's frame equals the boolean-mask selection.
  for (month, day), day_df in D.groupby(['month', 'day'], sort=True):
    total_periods[(int(month), int(day))] = [get_time_periods(day_df)]
  return total_periods

In [None]:
periods_map = get_all_periods(D=df)  # (month, day) -> [four 6-hour blocks]

In [None]:
def get_airport_tensor(period_df):
    """Index every airport that appears as an origin or destination.

    Returns:
        (int64 tensor of node ids 0..n-1, dict IATA -> node id), where ids are
        assigned to the codes in sorted order.
    """
    codes = sorted({*period_df['origin'], *period_df['dest']})
    airport_to_idx = dict(zip(codes, range(len(codes))))
    return torch.arange(len(codes), dtype=torch.long), airport_to_idx

In [None]:
(airport_tensor, airport_to_idx) = get_airport_tensor(period_df=df)  # node ids over all flights

In [None]:
# Airports touched by the flights data that have coordinates (one row per IATA).
airports_in_flights

Unnamed: 0,IATA,LATITUDE,LONGITUDE
0,ABQ,35.040222,-106.609194
1,ACK,41.253052,-70.060181
2,ALB,42.748119,-73.802979
3,ANC,61.174320,-149.996186
4,ATL,33.640444,-84.426944
...,...,...,...
102,TPA,27.975472,-82.533250
103,TUL,36.198372,-95.888242
104,TVC,44.741445,-85.582235
105,TYS,35.812487,-83.992856


In [None]:
import hashlib

def convert_tail(tail_num):
  """Map a tail-number string to a stable integer in [0, 10**8).

  Uses the MD5 digest (read big-endian) modulo 10**8, so the value is the same
  in every session (unlike Python's per-process salted hash()).
  """
  digest = hashlib.md5(tail_num.encode()).digest()
  return int.from_bytes(digest, "big") % 10**8

# Airline name -> integer code, used by clean_data.
# Names are sorted before numbering: iterating a plain set of strings follows
# Python's per-process hash randomisation, so without sorting the codes change
# between kernel restarts. key=str keeps sorting safe if a NaN name is present.
airlines = set(df['name'])
airlines_map = {name: i for i, name in enumerate(sorted(airlines, key=str))}

def clean_data(df):
  """Return a model-ready copy of the flights table; the input is not modified.

  - Drops every row with any missing value (we have enough data to just drop NaNs).
  - Replaces `tailnum` with its integer hash (convert_tail).
  - Replaces airline `name` with its integer code (airlines_map); a name not in
    the map raises KeyError.
  """
  # df.isna().sum()
  # df['dest'] = [airport_to_idx[loc] for loc in df['dest']]
  # df['origin'] = [airport_to_idx[loc] for loc in df['origin']]
  # Build the new columns with .assign on the dropna() result: writing columns
  # onto that result in place triggered SettingWithCopyWarning.
  return df.dropna().assign(
      tailnum=lambda d: d['tailnum'].map(convert_tail),
      name=lambda d: d['name'].map(lambda n: airlines_map[n]),
  )

In [None]:
# Inspect the airline-name -> integer code table.
airlines_map

{'United Air Lines Inc.': 0,
 'Hawaiian Airlines Inc.': 1,
 'AirTran Airways Corporation': 2,
 'Southwest Airlines Co.': 3,
 'Alaska Airlines Inc.': 4,
 'ExpressJet Airlines Inc.': 5,
 'Mesa Airlines Inc.': 6,
 'Virgin America': 7,
 'SkyWest Airlines Inc.': 8,
 'American Airlines Inc.': 9,
 'Endeavor Air Inc.': 10,
 'JetBlue Airways': 11,
 'US Airways Inc.': 12,
 'Envoy Air': 13,
 'Delta Air Lines Inc.': 14,
 'Frontier Airlines Inc.': 15}

In [None]:
# Mount Google Drive (Colab only); SkyNetDataset's default `root` is under
# /content/drive/Shareddrives/...
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
convert_tail(tail_num='N123')  # sample hashed tail number

67621894

In [None]:
# clean_data returns a new frame; keep it rather than discarding the result.
flights_clean = clean_data(df)
flights_clean

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['tailnum'] = [convert_tail(tail) for tail in df['tailnum']]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['name'] = [airlines_map[name] for name in df['name']]


Unnamed: 0,id,year,month,day,dep_time,sched_dep_time,dep_delay,arr_time,sched_arr_time,arr_delay,...,flight,tailnum,origin,dest,air_time,distance,hour,minute,time_hour,name
0,0,2013,1,1,517.0,515,2.0,830.0,819,11.0,...,1545,73754730,EWR,IAH,227.0,1400,5,15,2013-01-01 05:00:00,0
1,1,2013,1,1,533.0,529,4.0,850.0,830,20.0,...,1714,84949463,LGA,IAH,227.0,1416,5,29,2013-01-01 05:00:00,0
2,2,2013,1,1,542.0,540,2.0,923.0,850,33.0,...,1141,4254391,JFK,MIA,160.0,1089,5,40,2013-01-01 05:00:00,9
3,3,2013,1,1,544.0,545,-1.0,1004.0,1022,-18.0,...,725,26303597,JFK,BQN,183.0,1576,5,45,2013-01-01 05:00:00,11
4,4,2013,1,1,554.0,600,-6.0,812.0,837,-25.0,...,461,61460058,LGA,ATL,116.0,762,6,0,2013-01-01 06:00:00,14
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
336765,336765,2013,9,30,2240.0,2245,-5.0,2334.0,2351,-17.0,...,1816,44466225,JFK,SYR,41.0,209,22,45,2013-09-30 22:00:00,11
336766,336766,2013,9,30,2240.0,2250,-10.0,2347.0,7,-20.0,...,2002,64228233,JFK,BUF,52.0,301,22,50,2013-09-30 22:00:00,11
336767,336767,2013,9,30,2241.0,2246,-5.0,2345.0,1,-16.0,...,486,66259738,JFK,ROC,47.0,264,22,46,2013-09-30 22:00:00,11
336768,336768,2013,9,30,2307.0,2255,12.0,2359.0,2358,1.0,...,718,88605432,JFK,BOS,33.0,187,22,55,2013-09-30 22:00:00,11


## EDA to see how many graphs we'll get

In [None]:
# 1. Shape of periods_map: one entry per day, four 6-hour blocks each.
print(f"Total days in periods_map: {len(periods_map)}")
print(f"Expected graphs (days × 4): {len(periods_map) * 4}")
print(f"\nSample keys: {list(periods_map)[:10]}")

# 2. Flights per block, flattened across all days (reused by the next cell).
period_flight_counts = [
    len(period_df)
    for day_data in periods_map.values()
    for period_df in day_data[0]
]
total_flights = sum(period_flight_counts)
non_empty_periods = sum(count > 0 for count in period_flight_counts)

print(f"\nTotal flights across all periods: {total_flights}")
print(f"Non-empty periods: {non_empty_periods}")
print(f"Empty periods: {len(periods_map) * 4 - non_empty_periods}")

NameError: name 'periods_map' is not defined

In [None]:
# 3. Distribution of flights per period
import matplotlib.pyplot as plt

fig, (ax_nonempty, ax_all) = plt.subplots(1, 2, figsize=(12, 4))

ax_nonempty.hist([x for x in period_flight_counts if x > 0], bins=50)
ax_nonempty.set(xlabel='Flights per period', ylabel='Frequency',
                title='Distribution of flights (non-empty periods)')

ax_all.hist(period_flight_counts, bins=50)
ax_all.set(xlabel='Flights per period (including 0s)', ylabel='Frequency',
           title='All periods')

fig.tight_layout()
plt.show()

print(f"\nFlights per period stats:")
print(f"  Mean: {np.mean(period_flight_counts):.1f}")
print(f"  Median: {np.median(period_flight_counts):.1f}")
print(f"  Max: {max(period_flight_counts)}")
print(f"  Min: {min(period_flight_counts)}")

# 4. Sanity-check the raw flights table.
print(f"\n--- Original flights_df ---")
print(f"Total flights in df: {len(df)}")
both_endpoints_present = df['origin'].notna() & df['dest'].notna()
print(f"Flights with valid origin/dest: {both_endpoints_present.sum()}")
print(f"Unique months: {sorted(df['month'].unique())}")
print(f"Unique days range: {df['day'].min()} to {df['day'].max()}")

# 5. Every flight should land in exactly one block: compare the totals.
flights_in_periods = sum(
    len(block) for day_data in periods_map.values() for block in day_data[0]
)
print(f"\nFlights in periods_map: {flights_in_periods}")
print(f"Flights in original df: {len(df)}")
print(f"Match: {flights_in_periods == len(df)}")

# 6. Which calendar days of 2013 have no flights at all?
from datetime import datetime, timedelta
year_start, next_year_start = datetime(2013, 1, 1), datetime(2014, 1, 1)
all_dates_2013 = []
for offset in range((next_year_start - year_start).days):
    date_ = year_start + timedelta(days=offset)
    all_dates_2013.append((date_.month, date_.day))

missing_dates = set(all_dates_2013).difference(periods_map)
print(f"\nTotal days in 2013: {len(all_dates_2013)}")
print(f"Days in periods_map: {len(periods_map)}")
print(f"Missing dates: {len(missing_dates)}")
if len(missing_dates) < 20:
    print(f"Missing: {sorted(missing_dates)}")

# 7. Per-block breakdown for a single day (January 1st).
sample_day = (1, 1)
block_labels = ('00:00-06:00', '06:00-12:00', '12:00-18:00', '18:00-24:00')
if sample_day in periods_map:
    day_periods = periods_map[sample_day][0]
    print(f"\n--- January 1st breakdown ---")
    for i, (time_range, period_df) in enumerate(zip(block_labels, day_periods)):
        print(f"Period {i} ({time_range}): {len(period_df)} flights")
        if not period_df.empty:
            print(f"  Dep time range: {period_df['sched_dep_time'].min()} - {period_df['sched_dep_time'].max()}")

## Creating PyG graph


We use PyG's `InMemoryDataset` to build the final dataset, which is saved to disk as a single `.pt` file.

**Note:** we may want to remove the conversion of `df['dest']` and `df['origin']` to indices, since indexing airports inside the dataset itself is probably cleaner.

In [None]:
class SkyNetDataset(InMemoryDataset):
  """
  PyG dataset for flight delay prediction

  Each graph represents a 6-hour time block of the US flight network, where:
    - Nodes represent airports with geographic and weather features
    - Edges represent individual flights with scheduling and operational features
    - Labels are either delay minutes (regression) or binary delay indicators (classification)

    The dataset covers the year 2013 with 4 time blocks per day (00:00, 06:00, 12:00, 18:00 UTC)


  Args:
        periods_map (dict): Mapping of (month, day) tuples to lists of flight DataFrames for each 6-hour block
        airports_df (pd.DataFrame): Airport information with columns: IATA, LATITUDE, LONGITUDE
        weather_df (pd.DataFrame): Weather data with columns: IATA, era5_slot_utc, tmin_c, tmax_c,
                                    tp_mm, gust_ms, u_component_of_wind_10m, v_component_of_wind_10m, mxtpr_kg_m2_s
        airlines_map (dict): Mapping of airline names to integer indices
        root (str): Root directory for storing processed data
        task (str): Either "regression" (predict delay minutes) or "classification" (predict delay > 15 min)
  """
  def __init__(self, periods_map, airports_df, weather_df,airlines_map,
               root="/content/drive/Shareddrives/CS 224W Project/data/skynet_graphs", transform=None, pre_transform=None,
               verbose=True, task="regression"):
    """Store the inputs, then let InMemoryDataset build or load the processed file.

    The attributes are assigned before super().__init__ because process() and
    processed_file_names read them; PyG's Dataset.__init__ is expected to call
    process() when the processed file is missing.

    Raises:
      ValueError: if `task` is not "regression" or "classification" (any other
        value used to be silently treated as classification).
    """
    if task not in ("regression", "classification"):
      raise ValueError(f"task must be 'regression' or 'classification', got {task!r}")

    self.periods_map = periods_map
    self.airports_df = airports_df
    self.weather_df = weather_df
    self.airlines_map = airlines_map
    self.verbose = verbose
    self.task = task

    super().__init__(root, transform, pre_transform)
    # weights_only=False unpickles arbitrary objects (needed for the collated
    # Data, which carries pandas Timestamps in `time`). Only load files that
    # this class wrote itself.
    self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)

  @property
  def raw_file_names(self):
    """No raw files: every input is handed to __init__ as in-memory objects."""
    return list()

  @property
  def processed_file_names(self):
    """Name of the cached, collated dataset file.

    Edge labels depend on `task`, so the cache is keyed by it; otherwise a
    dataset processed for one task would be silently reloaded for the other.
    Regression keeps the original file name so existing caches stay valid.
    """
    task = getattr(self, "task", "regression")
    if task == "regression":
      return ['flight_graphs.pt']
    return [f'flight_graphs_{task}.pt']


  def process(self):
    """
    Generate all graphs from the periods map and save them to disk.

    Iterates through each day of 2013 present in periods_map and creates one
    graph per non-empty 6-hour block. Every block slot (empty or not) advances
    `block_idx`, so indices stay tied to (day, block) position.

    If a `pre_transform` was given to __init__, it is applied to every graph
    before collation. The collated result is saved to self.processed_paths[0].
    """
    self._setup_mappings()

    data_list = []
    block_idx = 0
    # Start hour of each of the four daily blocks produced by get_time_periods.
    # NOTE(review): blocks are split on sched_dep_time, which is presumably the
    # airport's local time, while the weather lookup below uses UTC slots —
    # confirm the intended alignment.
    period_to_hour = {0: 0, 1: 6, 2: 12, 3: 18}
    total_blocks = 4 * len(self.periods_map)

    for month in range(1, 13):
      for day in range(1, 32):
        if (month, day) not in self.periods_map:
          continue

        day_blocks = self.periods_map[(month, day)][0]
        for section_idx, section_df in enumerate(day_blocks):

          if section_df.empty:
            block_idx += 1
            continue

          section_time = pd.Timestamp(
              year=2013, month=month, day=day,
              hour=period_to_hour[section_idx], tz='UTC'
          )

          graph = self._create_graph_for_block(section_df, section_time, block_idx)
          if graph is not None:
            data_list.append(graph)

          block_idx += 1

          if self.verbose and block_idx % 100 == 0:
            print(f"Processed block: {block_idx}/{total_blocks}")

    # pre_transform was accepted by __init__ but never applied before.
    if self.pre_transform is not None:
      data_list = [self.pre_transform(d) for d in data_list]

    if self.verbose:
      print(f"Created {len(data_list)} total graphs")

    # convention for inmemorydataset
    data, slices = self.collate(data_list)
    torch.save((data, slices), self.processed_paths[0])


  def _setup_mappings(self):
    """
    Set up airport and weather mappings for O(1) lookup.

    Creates:
      - airport_to_idx: dict IATA code -> node index (codes in sorted order)
      - num_nodes: total number of airports (graph nodes)
      - static_node_features: float tensor [num_nodes, 2] of (LATITUDE, LONGITUDE)
        in node-index order; the first airports_df row per IATA code is used
      - weather_dict: dict (IATA, era5_slot_utc) -> dict of the 7 weather
        features; later duplicate keys overwrite earlier ones
    """
    # airports index mapping
    airports = sorted(self.airports_df['IATA'].unique())
    self.airport_to_idx = {airport: i for i, airport in enumerate(airports)}
    self.num_nodes = len(airports)

    # Static geographic features: first row per airport, reordered to node order
    # (replaces a full-table filter per airport).
    first_rows = self.airports_df.drop_duplicates(subset='IATA').set_index('IATA')
    coords = first_rows.loc[airports, ['LATITUDE', 'LONGITUDE']].to_numpy(dtype=float)
    self.static_node_features = torch.tensor(coords, dtype=torch.float)

    # Pre-index weather. Column-wise zip instead of iterrows() avoids building a
    # pandas Series per row of the (airport x slot) weather table.
    w = self.weather_df
    self.weather_dict = {
        (iata, slot): {
            'tmin_c': tmin,
            'tmax_c': tmax,
            'tp_mm': tp,
            'gust_ms': gust,
            'u_wind': u,
            'v_wind': v,
            'mxtpr': mx,
        }
        for iata, slot, tmin, tmax, tp, gust, u, v, mx in zip(
            w['IATA'], w['era5_slot_utc'], w['tmin_c'], w['tmax_c'], w['tp_mm'],
            w['gust_ms'], w['u_component_of_wind_10m'],
            w['v_component_of_wind_10m'], w['mxtpr_kg_m2_s'])
    }


  def _create_graph_for_block(self, block, start_time, block_idx):
    """
    Build the PyG graph for one 6-hour block.

    Args:
        block (pd.DataFrame): flights in this block
        start_time (pd.Timestamp): block start (UTC), used for the weather lookup
        block_idx (int): sequential block identifier

    Returns:
        Data with:
            - x: [num_nodes, 2 + 7] = static (lat, lon) followed by the weather features
            - edge_index, edge_attr, y: as produced by _create_edges_for_block
            - num_nodes: total number of airports
            - time: start_time
            - block_idx: block_idx
    """
    weather_x = self._get_node_weather(start_time)
    x = torch.cat((self.static_node_features, weather_x), dim=1)

    edge_index, edge_attr, y = self._create_edges_for_block(block)

    return Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=y,
        num_nodes=self.num_nodes,
        time=start_time,
        block_idx=block_idx,
    )

  def _get_node_weather(self, timestamp):
    """
    Retrieve weather features for all airports at a given timestamp.

    Returns:
        torch.Tensor: weather features [num_nodes, 7], rows in node-index
        (sorted IATA) order, columns:
            - tmin_c: Minimum temperature (°C)
            - tmax_c: Maximum temperature (°C)
            - tp_mm: Total precipitation (mm)
            - gust_ms: Wind gust speed (m/s)
            - u_wind: U-component of wind at 10m
            - v_wind: V-component of wind at 10m
            - mxtpr: Maximum total precipitation rate (kg/m²/s)

        Missing weather is filled with zeros: both airports with no entry for
        `timestamp` and individual NaN values inside an entry.
    """
    feature_keys = ('tmin_c', 'tmax_c', 'tp_mm', 'gust_ms', 'u_wind', 'v_wind', 'mxtpr')
    zeros = [0.0] * len(feature_keys)

    rows = []
    for airport in sorted(self.airport_to_idx.keys()):
      w = self.weather_dict.get((airport, timestamp))
      rows.append(zeros if w is None else [w[k] for k in feature_keys])

    weather = torch.tensor(rows, dtype=torch.float)
    # NaNs can come from the weather table (e.g. a reindexed full grid has NaN
    # rows); zero them so they follow the "missing -> 0" convention instead of
    # propagating NaN into the node features.
    return torch.where(torch.isnan(weather), torch.zeros_like(weather), weather)

  def _create_edges_for_block(self, block):
    """
    Create edges (flights) for a time block with associated features and labels.

    Args:
        block (pd.DataFrame): DataFrame containing flight records for this time block

    Returns:
        tuple: (edge_index, edge_attr, edge_labels)
            - edge_index: Tensor [2, num_edges] with source and destination node indices
            - edge_attr: Tensor [num_edges, 13] with flight features
            - edge_labels: Tensor [num_edges] with delay labels
                * Regression task: actual delay in minutes
                * Classification task: 1.0 if delay > 15 min, else 0.0

    Flights with missing origin/destination or unknown airports are skipped.
    """
    edge_list = []
    edge_features = []
    edge_labels = []

    for _, flight in block.iterrows():
        if pd.isna(flight['origin']) or pd.isna(flight['dest']):
            continue

        origin = str(flight['origin']).upper()
        dest = str(flight['dest']).upper()

        if origin not in self.airport_to_idx or dest not in self.airport_to_idx:
            continue

        origin_idx = self.airport_to_idx[origin]
        dest_idx = self.airport_to_idx[dest]

        edge_list.append([origin_idx, dest_idx])
        edge_features.append(self._extract_edge_features(flight))

        # Label -- using 15 minutes as delay info
        # DEPARTURE DELAY
        if self.task == 'regression':
            label = flight['dep_delay'] if pd.notna(flight['dep_delay']) else 0.0
        else:
            label = 1.0 if (pd.notna(flight['dep_delay']) and flight['dep_delay'] > 15) else 0.0

        edge_labels.append(label)

    if len(edge_list) == 0:
        return (torch.empty((2, 0), dtype=torch.long),
                torch.empty((0, 0), dtype=torch.float),
                torch.empty((0,), dtype=torch.float))

    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_features, dtype=torch.float)
    edge_y = torch.tensor(edge_labels, dtype=torch.float)

    return edge_index, edge_attr, edge_y


  def _extract_edge_features(self, flight):
      """
      Extract a feature vector for a single flight (edge).

      Args:
          flight (pd.Series): Row from flight DataFrame containing flight information

      Returns:
          list: Feature vector with 13 elements:
              [0-1]   sched_dep_sin, sched_dep_cos: Scheduled departure time (cyclical encoding)
              [2-3]   sched_arr_sin, sched_arr_cos: Scheduled arrival time (cyclical encoding)
              [4-5]   dep_sin, dep_cos: Actual departure time (cyclical encoding)
              [6-7]   arr_sin, arr_cos: Actual arrival time (cyclical encoding)
              [8]     tailnum_encoded: Aircraft tail number (hashed and normalized to [0,1])
              [9]     airline_encoded: Airline identifier (normalized to [0,1])
              [10]    distance: Flight distance
              [11]    air_time: Scheduled flight duration
              [12]    dep_delay: Departure delay in minutes (note: also used as label)

      Time features use sin/cos encoding to capture cyclical nature of 24-hour time.
      Missing values are encoded as 0.0.
      """
      def time_to_cyclical(time_hhmm):
          if pd.isna(time_hhmm):
              return 0.0, 0.0
          hours = int(time_hhmm) // 100
          minutes = int(time_hhmm) % 100
          fraction = (hours + minutes/60) / 24
          angle = 2 * np.pi * fraction
          return np.sin(angle), np.cos(angle)

      sched_dep_sin, sched_dep_cos = time_to_cyclical(flight['sched_dep_time'])
      sched_arr_sin, sched_arr_cos = time_to_cyclical(flight['sched_arr_time'])
      dep_sin, dep_cos = time_to_cyclical(flight.get('dep_time'))
      arr_sin, arr_cos = time_to_cyclical(flight.get('arr_time'))

      # Tailnum encoding
      tailnum_encoded = 0.0
      if pd.notna(flight['tailnum']):
          tailnum_encoded = int(hashlib.md5(str(flight['tailnum']).encode()).hexdigest(), 16) % (10**8)
          tailnum_encoded /= 1e8

      # Airline encoding
      airline_encoded = 0.0
      if pd.notna(flight['name']) and flight['name'] in self.airlines_map:
          airline_encoded = self.airlines_map[flight['name']] / len(self.airlines_map)

      return [
          sched_dep_sin, sched_dep_cos,
          sched_arr_sin, sched_arr_cos,
          dep_sin, dep_cos,
          arr_sin, arr_cos,
          tailnum_encoded,
          airline_encoded,
          flight['distance'] if pd.notna(flight['distance']) else 0.0,
          flight['air_time'] if pd.notna(flight['air_time']) else 0.0,
          flight['dep_delay'] if pd.notna(flight['dep_delay']) else 0.0,
      ]


In [None]:
import pandas as pd

# Precomputed 6-hourly ERA5 weather per airport, stored on the project's Shared Drive.
# NOTE(review): assumes the Drive is mounted at /content/drive (Colab) -- confirm.
# read_pickle can execute arbitrary code embedded in the file; only load trusted files.
weather_6h_df_path = '/content/drive/Shareddrives/CS 224W Project/data/weather_6h_df.pkl'

try:
    weather_6h_df = pd.read_pickle(weather_6h_df_path)
except FileNotFoundError as err:
    # Fail here instead of printing and continuing: later cells need weather_6h_df
    # and would otherwise die with a confusing NameError.
    raise FileNotFoundError(
        f"Error: The file {weather_6h_df_path} was not found. Please ensure the path is correct and the file exists."
    ) from err

print(f"Successfully loaded weather_6h_df from {weather_6h_df_path}")
print(weather_6h_df.head())
print(f"{len(weather_6h_df)} rows loaded.")

Successfully loaded weather_6h_df from /content/drive/Shareddrives/CS 224W Project/data/weather_6h_df.pkl
  IATA   LATITUDE   LONGITUDE             era5_slot_utc  \
0  ABQ  35.040222 -106.609194 2013-01-01 00:00:00+00:00   
1  ACK  41.253052  -70.060181 2013-01-01 00:00:00+00:00   
2  ALB  42.748119  -73.802979 2013-01-01 00:00:00+00:00   
3  ANC  61.174320 -149.996186 2013-01-01 00:00:00+00:00   
4  ATL  33.640444  -84.426944 2013-01-01 00:00:00+00:00   

          era5_slot_end_utc    tmin_c     tmax_c     tp_mm    gust_ms  \
0 2013-01-01 06:00:00+00:00 -7.489691  -0.561499  0.045593   7.207325   
1 2013-01-01 06:00:00+00:00  4.496820   5.455896  0.286515  17.311811   
2 2013-01-01 06:00:00+00:00 -1.222784  -0.104865  0.032761   6.050241   
3 2013-01-01 06:00:00+00:00 -0.984564   0.565393  0.121892   9.476895   
4 2013-01-01 06:00:00+00:00  9.144800  11.294366  0.004413   4.613820   

   mxtpr_kg_m2_s  u_component_of_wind_10m  v_component_of_wind_10m  \
0   4.258214e-09              

In [None]:
# Build the PyG dataset: one graph per 6-hour time block (the recorded run
# below produced 1460 graphs).
dataset = SkyNetDataset(
    periods_map,
    airports_in_flights,
    weather_6h_df,
    airlines_map,
)

Processing...


Processed block: 100/1460
Processed block: 200/1460
Processed block: 300/1460
Processed block: 400/1460
Processed block: 500/1460
Processed block: 600/1460
Processed block: 700/1460
Processed block: 800/1460
Processed block: 900/1460
Processed block: 1000/1460
Processed block: 1100/1460
Processed block: 1200/1460
Processed block: 1300/1460
Processed block: 1400/1460
Created 1460 total graphs


Done!


In [None]:
# Reload previously saved graphs from the project Drive.
# NOTE(review): this path did not exist in the recorded run (FileNotFoundError
# below); check it against where SkyNetDataset writes its processed file.
# This also rebinds `file_path`, which earlier held the flights CSV name.
# weights_only=False lets torch.load unpickle arbitrary objects (the saved value
# is a (Data, slices) tuple, not plain tensors) -- only load files you trust.
file_path = '/content/drive/Shareddrives/CS 224W Project/data/skynet_graphs/processed/flight_graphs.pt'
loaded_data = torch.load(f=file_path, weights_only=False)


FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/Shareddrives/CS 224W Project/data/skynet_graphs/processed/flight_graphs.pt'

In [None]:
# Shows the value torch.load returned: a tuple of one collated Data object and a
# dict of per-attribute offset tensors (e.g. 'x' steps by 107 node rows per graph,
# ending at 156220; 'time'/'block_idx' end at 1460 graphs).
# NOTE(review): in the saved run the previous cell raised FileNotFoundError, so this
# output comes from an earlier kernel state -- re-run both cells on a fresh kernel.
loaded_data

(Data(x=[156220, 9], edge_index=[2, 336776], edge_attr=[336776, 13], y=[336776], time=[1460], num_nodes=156220, block_idx=[1460]),
 {'x': tensor([     0,    107,    214,  ..., 156006, 156113, 156220]),
  'edge_index': tensor([     0,     23,    312,  ..., 336345, 336642, 336776]),
  'edge_attr': tensor([     0,     23,    312,  ..., 336345, 336642, 336776]),
  'y': tensor([     0,     23,    312,  ..., 336345, 336642, 336776]),
  'time': tensor([   0,    1,    2,  ..., 1458, 1459, 1460]),
  'block_idx': tensor([   0,    1,    2,  ..., 1458, 1459, 1460])})