In [1]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import udf, pandas_udf, col, PandasUDFType, lit, round, array_contains, from_unixtime
from pyspark.sql.functions import col, radians, sin, cos, sqrt, atan2, array, collect_list, struct, row_number, expr
from pyspark.sql.functions import monotonically_increasing_id, row_number, col
from pyspark.sql.types import DoubleType, StructType, StructField, IntegerType
from pyspark.sql.functions import when, broadcast, split, col, concat_ws,  min, max, to_date, unix_timestamp
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.functions import col, expr, udf
from pyspark.sql.types import StructType, StructField, DoubleType, TimestampType
from pyspark.sql.functions import concat_ws, collect_list, array_sort

# Regular imports
from IPython.display import display, HTML
import os, time
import subprocess
import os,shutil
from datetime import datetime, timedelta
import pandas as pd
import numpy as np
import h3pandas
import h3
import math

# Custom functions
from datetime import datetime, date
import dateutil.relativedelta
import calendar

def generate_months(start_date, end_date):
    """Return month anchors from ``start_date`` up to and including ``end_date``.

    The first anchor is ``start_date`` itself (normally the 1st of a month);
    every subsequent anchor is the first day of the following calendar month.
    Iteration stops as soon as an anchor would be later than ``end_date``.

    Args:
        start_date (datetime.date): First anchor.
        end_date (datetime.date): Inclusive upper bound.

    Returns:
        list: ``datetime.date`` objects in ascending order; empty when
        ``start_date`` is after ``end_date``.
    """
    anchors = []
    cursor = start_date
    while cursor <= end_date:
        anchors.append(cursor)
        # year * 12 + month is the 0-based index of the *next* month, so divmod
        # yields (next_year, next_month - 1) and handles the December rollover.
        next_year, zero_based_month = divmod(cursor.year * 12 + cursor.month, 12)
        cursor = date(next_year, zero_based_month + 1, 1)
    return anchors

def get_start_end_of_month(date):
    """
    Return POSIX timestamps bounding the calendar month that contains ``date``.

    Args:
        date (datetime | datetime.date): Any day inside the month of interest;
            only ``.year`` and ``.month`` are read.

    Returns:
        tuple: ``(first, last)`` floats, where ``first`` is 00:00:00 on day 1 and
        ``last`` is 23:59:59 on the final day of the month. Both come from naive
        datetimes, so ``datetime.timestamp()`` interprets them in the local
        timezone of the process.
    """
    _, days_in_month = calendar.monthrange(date.year, date.month)
    month_start = datetime(date.year, date.month, 1)
    month_last_second = datetime(date.year, date.month, days_in_month, 23, 59, 59)
    return month_start.timestamp(), month_last_second.timestamp()

# Settings
# NOTE(review): `project`, `resolution`, `start_month`, `end_month` and `today`
# are not referenced by any other cell visible in this notebook (the H3
# resolution is hard-coded as `h3_res_7` further down) — confirm before relying
# on them.
project = "project_opdi"
resolution = 7  # H3 resolution; presumably meant to match the `h3_res_7` columns used below

start_month = date(2022, 1, 1)
end_month = date(2025, 1, 1)

# Getting today's date, formatted like '05 March 2025'
today = datetime.today().strftime('%d %B %Y')

# Spark Session Initialization
# Hive-enabled session whose default `spark_catalog` is Iceberg's
# SparkSessionCatalog (hive-backed) with its warehouse on the ABFS project path.
# getOrCreate() reuses a session already running in this kernel; per the Spark
# docs, static settings are then not re-applied, so restart the kernel after
# changing any value here.
spark = SparkSession.builder \
    .appName("OPDI Flight Table") \
    .config("spark.rpc.message.maxSize", 512) \
    .config("spark.hadoop.fs.azure.ext.cab.required.group", "eur-app-opdi") \
    .config("spark.kerberos.access.hadoopFileSystems", "abfs://storage-fs@cdpdllive.dfs.core.windows.net/data/project/opdi.db/unmanaged") \
    .config("spark.executor.extraClassPath", "/opt/spark/optional-lib/iceberg-spark-runtime-3.3_2.12-1.3.1.1.20.7216.0-70.jar") \
    .config("spark.driver.extraClassPath", "/opt/spark/optional-lib/iceberg-spark-runtime-3.3_2.12-1.3.1.1.20.7216.0-70.jar") \
    .config("spark.sql.catalog.spark_catalog.type", "hive") \
    .config("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog") \
    .config("spark.sql.iceberg.handle-timestamp-without-timezone", "true") \
    .config("spark.sql.catalog.spark_catalog.warehouse", "abfs://storage-fs@cdpdllive.dfs.core.windows.net/data/project/opdi.db/unmanaged") \
    .config("spark.driver.cores", "1") \
    .config("spark.driver.memory", "8G") \
    .config("spark.executor.memory", "8G") \
    .config("spark.executor.memoryOverhead", "3G") \
    .config("spark.executor.cores", "2") \
    .config("spark.executor.instances", "3") \
    .config("spark.dynamicAllocation.maxExecutors", "20") \
    .config("spark.network.timeout", "800s") \
    .config("spark.executor.heartbeatInterval", "400s") \
    .config("spark.driver.maxResultSize", "6g") \
    .config("spark.shuffle.compress", "true") \
    .config("spark.shuffle.spill.compress", "true") \
    .enableHiveSupport() \
    .getOrCreate()

## Query testing sample function

from pyspark.sql.functions import col, lit, from_unixtime, to_timestamp

def get_data_within_timeframe(spark, table_name, month, time_col='event_time'):
    """
    Retrieve the rows of a Spark table whose timestamp falls inside one calendar month.

    Args:
        spark (SparkSession): The active SparkSession.
        table_name (str): Table to query, e.g. 'project_opdi.osn_tracks'.
        month (datetime.date | datetime.datetime): Any date inside the month to
            load. Only ``.year`` and ``.month`` are used; a 'YYYY-MM-DD' string
            is NOT accepted.
        time_col (str): Timestamp column to filter on (default: 'event_time').

    Returns:
        pyspark.sql.dataframe.DataFrame: Lazily filtered rows with
        ``month_start <= time_col < start_of_next_month``.
    """
    # POSIX timestamps of 00:00:00 on day 1 and 23:59:59 on the last day.
    start_ts, last_second_ts = get_start_end_of_month(month)

    # Filtering with `< last_second` silently dropped everything in the month's
    # final second, including fractional times such as 23:59:59.5. One second
    # after the last second is midnight of the next month, which makes a correct
    # exclusive upper bound.
    end_exclusive_ts = last_second_ts + 1

    # Numeric literals are interpreted as seconds since the epoch.
    start_col = to_timestamp(lit(start_ts))
    end_col = to_timestamp(lit(end_exclusive_ts))

    df = spark.table(table_name)
    return df.filter((col(time_col) >= start_col) & (col(time_col) < end_col))

Setting spark.hadoop.yarn.resourcemanager.principal to quinten.goens


In [5]:
%%time
# Pull trajectories: one calendar month (January 2025) of OSN state vectors,
# filtered on `event_time` by get_data_within_timeframe.
traj_sdf = get_data_within_timeframe(spark, table_name = 'project_opdi.osn_tracks', month= datetime(2025,1,1), time_col='event_time')
#traj_sdf = traj_sdf.filter(col('track_id') == '000059f343f5301bc6aea34c66b51f445e511b319e6a0bc69961c37cadceeaab_0_2025_1')

# Pull FIR cells: H3 res-7 cells with airspace_type 'FIR' for AIRAC cycle 524,
# restricted to a lat/lon box widened by `offset` degrees on every side.
# NOTE(review): the box bounds presumably describe the study area — confirm their
# source and consider moving them (and the AIRAC 524 literal) to the settings cell.
fir_sdf = spark.table('project_opdi.opdi_h3_airspace_ref')
offset = 3
fir_sdf = fir_sdf.filter(
    (col('airspace_type') == 'FIR') &\
    (col('airac_cfmu') == 524) &\
    (col('h3_res_7_lat') >= 26.74617 - offset) &\
    (col('h3_res_7_lat') <= 70.25976 + offset) &\
    (col('h3_res_7_lon') >= -25.86653 - offset) &\
    (col('h3_res_7_lon') <= 49.65699 + offset))
# NOTE(review): dropDuplicates() runs on all columns *before* the projection, so
# rows that differ only in the dropped columns survive as duplicate
# (code, h3_res_7, min_fl, max_fl) rows and will multiply matches in the join below.
fir_sdf = fir_sdf.dropDuplicates() 
fir_sdf = fir_sdf.select('code','h3_res_7', 'min_fl','max_fl')

# Combine both
# FL = round(baro_altitude_c * 3.28084 / 100); `round` here is the Spark function
# imported from pyspark.sql.functions (it shadows the Python builtin).
# Assumes baro_altitude_c is in metres (3.28084 ft/m) — TODO confirm.
# The left join is followed by a filter on min_fl/max_fl. For points with no
# matching FIR cell those columns are null, the comparison is null and the row is
# dropped, so the result behaves like an inner join restricted to points inside a
# FIR's vertical limits.
#.withColumn('FL',col('baro_altitude_c') * 3.28084/100)\
traj_sdf = traj_sdf\
    .withColumn('FL',round(col('baro_altitude_c') * 3.28084/100))\
    .select('track_id', 'event_time', 'lat', 'lon','baro_altitude_c','h3_res_7','FL')\
    .join(fir_sdf, on='h3_res_7', how='left')\
    .filter((col('FL') >= col('min_fl')) & (col('FL') <= col('max_fl')))

# Mark for caching (lazy: materialised by the first action in a later cell).
traj_sdf = traj_sdf.cache()

CPU times: user 4.16 ms, sys: 15.8 ms, total: 20 ms
Wall time: 167 ms


In [6]:
# Collapse the rows of a state vector whose H3 cell belongs to several FIRs into
# one row. The matching FIRs are collected as (code, min_fl, max_fl) structs and
# sorted together, so the i-th entry of the comma-joined `code`, `min_fl` and
# `max_fl` strings always describes the same FIR. Sorting three separate lists
# (the previous approach) mis-paired codes with another FIR's vertical limits
# whenever more than one FIR matched.
traj_sdf = traj_sdf.groupBy("h3_res_7", "track_id", 'event_time', 'lat', 'lon', 'baro_altitude_c', 'FL').agg(
    array_sort(collect_list(struct("code", "min_fl", "max_fl"))).alias("firs")
)
traj_sdf = traj_sdf\
    .withColumn("code", concat_ws(",", col("firs.code").cast("array<string>")))\
    .withColumn("min_fl", concat_ws(",", col("firs.min_fl").cast("array<string>")))\
    .withColumn("max_fl", concat_ws(",", col("firs.max_fl").cast("array<string>")))\
    .drop("firs")

# Track-wide window (first/last timestamp) and time-ordered window (lag/lead).
window_spec_minmax = Window.partitionBy("track_id")
window_spec = Window.partitionBy("track_id").orderBy("event_time")

# Flag track start/end and every state vector adjacent to a change of `code`.
traj_sdf = traj_sdf\
    .withColumn(
        "begin_track",
        # True on the earliest event_time of the track
        F.when(F.min(col('event_time')).over(window_spec_minmax) == col("event_time"), True).otherwise(False)
    ).withColumn(
        "end_track",
        # True on the latest event_time of the track
        F.when(F.max(col('event_time')).over(window_spec_minmax) == col("event_time"), True).otherwise(False)
    ).withColumn(
        "code_change",
        # True for every row for which the track:
        #    1) begins
        #    2) ends
        #    3) is about to cross a FIR boundary with the next state vector
        #    4) has just crossed a FIR boundary from the previous state vector
        # lag/lead are null at the track edges; the begin/end flags cover those rows.
        F.when((F.lag("code").over(window_spec) != F.col("code")) |
               (F.lead("code").over(window_spec) != F.col("code")) |
               col('begin_track') |
               col('end_track'),
               True).otherwise(False))

# Keep only the boundary points.
traj_sdf = traj_sdf.filter(col('code_change') == True)

# Pair each boundary point with the next boundary point of the same track.
traj_sdf = traj_sdf\
    .withColumn('after_lat', F.lead('lat').over(window_spec))\
    .withColumn('after_lon', F.lead('lon').over(window_spec))\
    .withColumn('after_FL', F.lead('FL').over(window_spec))\
    .withColumn('after_event_time', F.lead('event_time').over(window_spec))\
    .withColumn('after_code', F.lead('code').over(window_spec))\
    .withColumn('after_min_fl', F.lead('min_fl').over(window_spec))\
    .withColumn('after_max_fl', F.lead('max_fl').over(window_spec))

traj_sdf = traj_sdf\
    .withColumnRenamed('lat', 'before_lat')\
    .withColumnRenamed('lon', 'before_lon')\
    .withColumnRenamed('FL', 'before_FL')\
    .withColumnRenamed('event_time', 'before_event_time')\
    .withColumnRenamed('code', 'before_code')\
    .withColumnRenamed('min_fl', 'before_min_fl')\
    .withColumnRenamed('max_fl', 'before_max_fl')

columns = ['lat','lon','FL','event_time', 'code', 'min_fl', 'max_fl']

# At the first and last point of a track, point the "after" values back at the
# point itself, so the midpoint computed later coincides with the original point.
# (begin_track/end_track are never null: both come from when(...).otherwise(False).)
for col_name in columns:
    traj_sdf = traj_sdf.withColumn(
        'after_' + col_name,
        F.when(col('begin_track') | col('end_track'), col('before_' + col_name))
         .otherwise(col('after_' + col_name)))

# Keep real crossings (code differs from the next boundary) plus track start/end.
traj_sdf = traj_sdf.filter((col('before_code') != col('after_code')) | col('begin_track') | col('end_track'))

all_columns = ['track_id'] + \
    ['before_' + x for x in columns] + \
    ['after_' + x for x in columns]

traj_sdf = traj_sdf.select(all_columns)

# Define the midpoint calculation function
def compute_midpoint(lat1, lon1, fl1, lat2, lon2, fl2):
    """
    Midpoint of two (lat, lon, FL) points.

    Horizontally this is the great-circle midpoint: both points are mapped to
    unit vectors on the sphere, the vectors are averaged, and the mean vector is
    converted back to latitude/longitude. The flight level is the plain
    arithmetic mean of ``fl1`` and ``fl2``; no altitude weighting is applied.

    Args:
        lat1, lon1, lat2, lon2: Coordinates in decimal degrees.
        fl1, fl2: Flight levels of the two points.

    Returns:
        tuple: ``(mid_lat, mid_lon, mid_FL)`` as floats, or
        ``(None, None, None)`` if any input is None (e.g. a null Spark value).
    """
    if None in (lat1, lon1, fl1, lat2, lon2, fl2):
        return None, None, None

    def to_unit_vector(lat_deg, lon_deg):
        # Cartesian unit vector for a latitude/longitude given in degrees.
        phi = math.radians(lat_deg)
        lam = math.radians(lon_deg)
        return math.cos(phi) * math.cos(lam), math.cos(phi) * math.sin(lam), math.sin(phi)

    ax, ay, az = to_unit_vector(lat1, lon1)
    bx, by, bz = to_unit_vector(lat2, lon2)
    mx = (ax + bx) / 2
    my = (ay + by) / 2
    mz = (az + bz) / 2

    mid_lon_rad = math.atan2(my, mx)
    mid_lat_rad = math.atan2(mz, math.sqrt(mx ** 2 + my ** 2))
    mid_fl = (fl1 + fl2) / 2

    return float(math.degrees(mid_lat_rad)), float(math.degrees(mid_lon_rad)), float(mid_fl)

# Spark UDF around compute_midpoint; the struct field names become the
# sub-columns unpacked below.
midpoint_udf = udf(compute_midpoint, returnType="struct<mid_lat:double, mid_lon:double, mid_FL:double>")

# Midpoint (lat, lon, FL) of each before/after pair.
traj_sdf = traj_sdf.withColumn(
    "midpoint",
    midpoint_udf(
        col("before_lat"), col("before_lon"), col("before_FL"),
        col("after_lat"), col("after_lon"), col("after_FL"),
    ),
)

# Flatten the struct into three columns and discard it.
traj_sdf = traj_sdf.withColumns({
    "mid_lat": col("midpoint.mid_lat"),
    "mid_lon": col("midpoint.mid_lon"),
    "mid_FL": col("midpoint.mid_FL"),
}).drop("midpoint")

before_epoch = col("before_event_time").cast("long")
after_epoch = col("after_event_time").cast("long")
traj_sdf = traj_sdf.withColumns({
    # Mean of the two epoch-second values, turned back into a timestamp.
    "mid_event_time": ((before_epoch + after_epoch) / 2).cast(TimestampType()),
    # Seconds elapsed between the two points.
    "mid_time_range": after_epoch - before_epoch,
})

# Render the three timestamps as strings.
traj_sdf = traj_sdf.withColumns({
    ts_col: col(ts_col).cast("string")
    for ts_col in ("before_event_time", "after_event_time", "mid_event_time")
})

# Define the Haversine formula using PySpark functions
def haversine_distance(df, lat1, lon1, lat2, lon2, output_col, radius_km=6371.0):
    """
    Append a great-circle (haversine) distance column to a Spark DataFrame.

    Args:
        df (pyspark.sql.DataFrame): Input frame.
        lat1, lon1 (str): Column names of the first point, in decimal degrees.
        lat2, lon2 (str): Column names of the second point, in decimal degrees.
        output_col (str): Name of the distance column to add.
        radius_km (float): Sphere radius (default 6371.0, the mean Earth radius
            in km); the output is in the same unit.

    Returns:
        pyspark.sql.DataFrame: ``df`` with ``output_col`` added. The coordinate
        columns are left untouched (still in degrees).
    """
    # Convert to radians inside the expressions only. The previous version
    # overwrote the four input columns with their radian values, so every
    # consumer of the returned frame saw before_/after_ lat/lon in radians.
    phi1 = F.radians(F.col(lat1))
    phi2 = F.radians(F.col(lat2))
    delta_phi = phi2 - phi1
    delta_lambda = F.radians(F.col(lon2)) - F.radians(F.col(lon1))

    # Haversine formula components
    a = F.pow(F.sin(delta_phi / 2), 2) + \
        F.cos(phi1) * F.cos(phi2) * F.pow(F.sin(delta_lambda / 2), 2)

    c = 2 * F.atan2(F.sqrt(a), F.sqrt(1 - a))

    return df.withColumn(output_col, radius_km * c)

# Great-circle distance (km) between each row's before/after points; cache the
# result for the entry/exit pairing in the next cell.
traj_sdf = haversine_distance(
    traj_sdf, "before_lat", "before_lon", "after_lat", "after_lon", 'mid_distance_range'
).cache()

In [7]:

# Define the window specification
# NOTE(review): `mid_event_time` was cast to string in the previous cell, so this
# ordering is lexicographic. That matches chronological order for Spark's
# 'yyyy-MM-dd HH:mm:ss[.fraction]' rendering, but ordering on a timestamp column
# would be more robust.
window_spec = Window.partitionBy("track_id").orderBy("mid_event_time")

# Each row is a boundary point (track start, FIR-crossing midpoint, or track end).
# Pair each boundary with the previous one in the same track: the previous
# midpoint becomes the entry of the airspace segment, the current midpoint its
# exit, and the airspace is the previous row's `after_code`.
traj_sdf = traj_sdf.withColumns({
    "AIRSPACE_ID": F.lag("after_code").over(window_spec),
    "entry_lon": F.lag("mid_lon").over(window_spec),
    "entry_lat": F.lag("mid_lat").over(window_spec),
    "entry_FL": F.lag("mid_FL").over(window_spec),
    "entry_time": F.lag("mid_event_time").over(window_spec),
    "entry_time_range": F.lag("mid_time_range").over(window_spec),
    "entry_distance_range": F.lag("mid_distance_range").over(window_spec) 
})

# Rename columns for exit values
traj_sdf = traj_sdf.withColumnRenamed("mid_lon", "exit_lon")\
    .withColumnRenamed("mid_lat","exit_lat")\
    .withColumnRenamed("mid_FL", "exit_FL")\
    .withColumnRenamed("mid_event_time", "exit_time")\
    .withColumnRenamed("mid_time_range", "exit_time_range")\
    .withColumnRenamed("mid_distance_range", "exit_distance_range")
    

# Select the final columns. Rows without a predecessor in their track (the first
# boundary of each track) have a null AIRSPACE_ID and are dropped.
traj_sdf = traj_sdf.select(
    "track_id", "AIRSPACE_ID", "entry_time", "entry_lon", "entry_lat", "entry_FL", "entry_time_range", "entry_distance_range",
    "exit_time", "exit_lon", "exit_lat", "exit_FL", "exit_time_range", "exit_distance_range"
).filter(col("AIRSPACE_ID").isNotNull())

# Collect the full result to the driver (bounded by spark.driver.maxResultSize).
traj_df = traj_sdf.toPandas()

                                                                                

In [8]:
traj_df.to_parquet(path='crossings_opdi_rounded.parquet')

In [29]:
# Legacy pandas version of the entry/exit pairing (the Spark cell above now does
# this). NOTE(review): it expects the earlier schema with before_code/after_code
# and mid_* columns, which the current Spark output no longer contains.
cols = [
    'track_id',
    'before_code',
    'after_code',
    'mid_lon',
    'mid_lat',
    'mid_FL',
    'mid_event_time'
]

# Sort explicitly, then shift *within each track*. A plain .shift(1) carried the
# last boundary of the previous track into the first row of the next one.
traj_df = traj_df[cols].sort_values(['track_id', 'mid_event_time'])
shifted = traj_df.groupby('track_id', sort=False)[
    ['after_code', 'mid_lon', 'mid_lat', 'mid_FL', 'mid_event_time']
].shift(1)

# assign() returns a new frame, avoiding chained assignment on a slice.
traj_df = traj_df.assign(
    AIRSPACE_ID=shifted['after_code'],
    entry_lon=shifted['mid_lon'],
    entry_lat=shifted['mid_lat'],
    entry_FL=shifted['mid_FL'],
    entry_time=shifted['mid_event_time'],
)

traj_df = traj_df.rename(columns={
    'mid_lon': 'exit_lon',
    'mid_lat': 'exit_lat',
    'mid_FL': 'exit_FL',
    'mid_event_time': 'exit_time'})

traj_df = traj_df[['track_id', 'entry_time', 'entry_lon', 'entry_lat', 'entry_FL', 'exit_time','exit_lon','exit_lat', 'exit_FL','AIRSPACE_ID']]
# The first boundary of each track has no predecessor and therefore no airspace.
traj_df = traj_df[~traj_df.AIRSPACE_ID.isna()]

In [13]:
# Number of crossing rows per track, largest first
traj_df['track_id'].value_counts()

track_id
1a450c69faf934d7a4e369cb67eeb17bd0d4658e9dccd325bccff2414f6ae8a6_4_2025_1     149
9986e2b7d038c99eb75facbb0c91325f752be48e5cf7a0f69726309362373d85_2_2025_1     142
9986e2b7d038c99eb75facbb0c91325f752be48e5cf7a0f69726309362373d85_9_2025_1     126
f03a46721d35a8ed6c8a394c923214a721da4eef746f44b35dab4236b4d24892_0_2025_1     119
1a1f3790a6d19ac8826fc03a1e532b25ba04f3d945578e6a7e41b8dea987b57b_0_2025_1     110
                                                                             ... 
012d40b81e0d94ac516680e04faf8c9a58d0779e928d6da497e6484174b50b12_98_2025_1      1
0132efb712207971964151cb4dbb5abafca55809c49002d52dd9b0db5907eb22_19_2025_1      1
01a43a73319c665ebd039a916378e38cde882c8db6cf4cf93e98319dd8cb99c2_48_2025_1      1
01a50ab8ab96f6aed78028653abf47de06829840ed1702939b563a533e884aef_27_2025_1      1
01b1adfa33899456f094bfe7e0c99d50f2ef8a8b6793ff7a80c64ea1322fb550_3_2025_1       1
Name: count, Length: 1304571, dtype: int64

In [14]:
example_track_id = '1a450c69faf934d7a4e369cb67eeb17bd0d4658e9dccd325bccff2414f6ae8a6_4_2025_1'
traj_df.loc[traj_df['track_id'] == example_track_id]

Unnamed: 0,track_id,before_lat,before_lon,before_FL,before_event_time,before_code,before_min_fl,before_max_fl,after_lat,after_lon,after_FL,after_event_time,after_code,after_min_fl,after_max_fl,mid_lat,mid_lon,mid_FL,mid_event_time
2305113,1a450c69faf934d7a4e369cb67eeb17bd0d4658e9dccd3...,,,,,,,,50.941136,0.930023,0.0,2025-01-12 17:02:30,EGTTFIR,0,245,,,,
2305114,1a450c69faf934d7a4e369cb67eeb17bd0d4658e9dccd3...,51.003609,1.456451,6.0,2025-01-12 17:13:00,EGTTFIR,0,245,51.005425,1.461639,6.0,2025-01-12 17:13:05,LFFFFIR,0,195,51.004517,1.459045,6.0,
2305115,1a450c69faf934d7a4e369cb67eeb17bd0d4658e9dccd3...,51.009754,1.470947,6.0,2025-01-12 17:13:25,LFFFFIR,0,195,51.013199,1.479263,6.0,2025-01-12 17:13:30,EGTTFIR,0,245,51.011477,1.475105,6.0,
2305116,1a450c69faf934d7a4e369cb67eeb17bd0d4658e9dccd3...,51.175323,1.999215,2.0,2025-01-12 17:30:05,EGTTFIR,0,245,51.171066,1.999289,2.0,2025-01-12 17:30:10,EBBUFIR,0,195,51.173195,1.999252,2.0,
2305117,1a450c69faf934d7a4e369cb67eeb17bd0d4658e9dccd3...,51.138890,1.998520,2.0,2025-01-12 17:31:40,EBBUFIR,0,195,51.121200,1.995773,2.0,2025-01-12 17:31:45,EGTTFIR,0,245,51.130045,1.997146,2.0,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2305257,1a450c69faf934d7a4e369cb67eeb17bd0d4658e9dccd3...,51.011078,1.497926,2.0,2025-01-12 19:25:00,LFFFFIR,0,195,51.010173,1.493607,2.0,2025-01-12 19:25:05,EGTTFIR,0,245,51.010626,1.495766,2.0,
2305258,1a450c69faf934d7a4e369cb67eeb17bd0d4658e9dccd3...,51.009201,1.489390,2.0,2025-01-12 19:25:10,EGTTFIR,0,245,51.009009,1.488495,2.0,2025-01-12 19:25:15,LFFFFIR,0,195,51.009105,1.488942,2.0,
2305259,1a450c69faf934d7a4e369cb67eeb17bd0d4658e9dccd3...,51.001923,1.462295,2.0,2025-01-12 19:25:45,LFFFFIR,0,195,50.999588,1.459771,2.0,2025-01-12 19:25:50,EGTTFIR,0,245,51.000755,1.461033,2.0,
2305260,1a450c69faf934d7a4e369cb67eeb17bd0d4658e9dccd3...,50.902267,1.459029,2.0,2025-01-12 19:28:45,EGTTFIR,0,245,50.898331,1.459177,2.0,2025-01-12 19:28:50,LFFFFIR,0,195,50.900299,1.459103,2.0,


In [None]:
# Returns a lazily evaluated DataFrame; displaying it only shows the schema.
flight_list_2025_sql = 'SELECT * FROM project_opdi.opdi_flight_list WHERE year(dof) == 2025;'
spark.sql(flight_list_2025_sql)

In [7]:
# Crossings per track, largest first
traj_df['track_id'].value_counts()

track_id
38d0deaa5832fa8e5a181726450165d972807ef673404ff95ef2cb232e04b631_6_2025_1      29
3b2918e367250d918f894a4bab2fcb57c50f2f13b64adfa3ef956541d92bfd92_0_2025_1      28
1177b98d81f34ee3224e3b6a4765a394a6cfacbc676dc139ea6924ebc759acbf_0_2025_1      25
3d246b2cde8bf536a721d94618baa364b8da6d661dccb52a15039ca7a3a4392e_0_2025_1      25
25562f47ee4dd92018f19ca0e7fb4a29f5d52a12b221802de90be7f9b1a26aa3_0_2025_1      24
                                                                               ..
0b989d13c8c46b5900cf968d6e4c2ae1bb0b0c41593cb9bef0817257f90e2a5e_25_2025_1      1
0bb817090dd642660a0b7e64c120d768931943bac52c8b8046abf868be6a23d0_32_2025_1      1
0bef62801da3d094805b08b31636bab7531d1118480c9ad6d779df17f09a4276_104_2025_1     1
51368e48ccfab6b6fbe41b1847528dd38e04279aa4e9a8353a753959d27bc6a6_0_2025_1       1
51408f7731e2dc1dd674dd0ace20101a090f2937db8e8c1cf2bb67026b7da403_0_2025_1       1
Name: count, Length: 2458, dtype: int64

In [15]:
traj_df.to_parquet(path='crossings_example.parquet')

In [None]:
import pandas as pd
import numpy as np
from math import radians, sin, cos, sqrt, atan2

def haversine(lat1, lon1, lat2, lon2):
    """
    Great-circle distance between two points on a spherical Earth (haversine formula).

    Parameters:
    lat1, lon1 : float - Latitude and longitude of the first point in degrees
    lat2, lon2 : float - Latitude and longitude of the second point in degrees

    Returns:
    float - Distance in kilometres (mean Earth radius 6371 km); NaN in, NaN out.
    """
    R = 6371.0  # mean Earth radius in km

    lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])

    dlat = lat2 - lat1
    dlon = lon2 - lon1

    a = sin(dlat / 2)**2 + cos(lat1) * cos(lat2) * sin(dlon / 2)**2
    # Rounding can push `a` a hair above 1 for (near-)antipodal points, and
    # math.sqrt(1 - a) would then raise ValueError. Clamp with a comparison rather
    # than min(): in this notebook `min` is pyspark.sql.functions.min (imported in
    # the first cell), and a comparison also leaves NaN untouched.
    if a > 1.0:
        a = 1.0
    c = 2 * atan2(sqrt(a), sqrt(1 - a))

    return R * c

def _haversine_km(lat1, lon1, lat2, lon2, radius_km=6371.0):
    """Vectorised haversine distance for NumPy arrays in degrees; NaN in -> NaN out."""
    lat1, lon1, lat2, lon2 = map(np.radians, (lat1, lon1, lat2, lon2))
    a = np.sin((lat2 - lat1) / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2
    # Guard sqrt(1 - a) against rounding just above 1; np.clip propagates NaN.
    a = np.clip(a, 0.0, 1.0)
    return radius_km * (2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a)))


# Great-circle distance between each row's before/after points. This is
# vectorised instead of a row-wise DataFrame.apply, which runs Python code once
# per row and is very slow on millions of rows. Missing coordinates (None/NaN/NA)
# become NaN and yield a NaN distance, as before.
_coord_arrays = [
    traj_df[c].to_numpy(dtype=float, na_value=np.nan)
    for c in ("before_lat", "before_lon", "after_lat", "after_lon")
]
traj_df["haversine_distance_km"] = _haversine_km(*_coord_arrays)


In [None]:
display(traj_df)

In [None]:
# FIXME: this assignment was left incomplete and raised SyntaxError, breaking
# Restart & Run All. The great-circle distance computed in the previous cell is
# available as traj_df['haversine_distance_km']; complete the line once the
# intended definition of 'distance' is settled.
# traj_df['distance'] = ...

In [None]:
# NOTE(review): exploratory/legacy cell. It expects the per-state-vector columns
# `begin_track`, `end_track` and `code`, plus a `window_spec` ordered by event
# time. After the crossing cells above, `traj_sdf` only holds the entry_*/exit_*
# selection and `window_spec` orders by `mid_event_time`, so this cell fails on
# Restart & Run All.
# grouping_id_lead: running sum (the window has an ORDER BY, so the frame is
# cumulative) of rows that begin/end the track or are about to change FIR code.
# Rows sharing a value form one segment.
# grouping_id_lag: running sum of rows whose code differs from the previous row.
traj_sdf = traj_sdf.withColumn(
    'grouping_id_lead', 
    F.sum(
        F.when(
            col('begin_track') |\
            col('end_track') |\
            (F.lead('code').over(window_spec) != F.col('code'))\
            , 1).otherwise(0))\
    .over(window_spec)).withColumn(
    'grouping_id_lag', # Added because consecutive points can both be crossings; these are treated separately and duplicates are dropped later.
    F.sum(
        F.when(F.lag('code').over(window_spec) != F.col('code'), 1).otherwise(0))\
    .over(window_spec))

In [None]:
display(traj_sdf)

In [None]:
# Number the rows of each (track, grouping_id_lead) segment in time order.
group_window = Window.partitionBy("track_id", "grouping_id_lead").orderBy("event_time")
traj_sdf = traj_sdf.withColumn(
    "row_num_lead",
    F.row_number().over(group_window))

# Force a track's first point to 2 ("entry") and its last point to 1 ("exit").
# end_track is checked first, so it wins when a row is both.
traj_sdf = traj_sdf.withColumn(
    "row_num_lead",
    F.when(col('end_track'), 1).when(col('begin_track'), 2).otherwise(col('row_num_lead')))

# Bug fix: the label must come from `row_num_lead`. No `row_num` column exists at
# this point, so the previous expression raised AnalysisException.
traj_sdf = traj_sdf.withColumn(
    "point_type_lead",
    F.when(F.col("row_num_lead") == 1, "exit").otherwise("entry")
)

# Explicit pivot values guarantee both entry_* and exit_* columns exist and avoid
# an extra job to discover the distinct values.
# NOTE(review): orderBy before groupBy does not fix which row F.first() picks
# after the shuffle; it is only deterministic when each pivot cell holds one row.
traj_sdf_lead = traj_sdf.orderBy("event_time").groupBy("track_id", "grouping_id_lead").pivot("point_type_lead", ["entry", "exit"]).agg(
    F.first("lat").alias("lat"),
    F.first("lon").alias("lon"),
    F.first("event_time").alias("event_time"),
    F.first("FL").alias("FL"),
    F.first("code").alias("code"),
    F.first("min_fl").alias("min_fl"),
    F.first("max_fl").alias("max_fl")
)

In [None]:
# Number the rows of each (track, grouping_id_lag) segment in time order.
group_window = Window.partitionBy("track_id", "grouping_id_lag").orderBy("event_time")
traj_sdf = traj_sdf.withColumn(
    "row_num_lag",
    F.row_number().over(group_window))

# Force a track's first point to 2 and its last point to 1 (end_track wins).
traj_sdf = traj_sdf.withColumn(
    "row_num_lag",
    F.when(col('end_track'), 1).when(col('begin_track'), 2).otherwise(col('row_num_lag')))

# Bug fix: the previous code overwrote `row_num_lag` with the label (read from a
# non-existent `row_num` column) and then pivoted on `point_type_lag`, which was
# never created. Write the label to its own column instead.
# NOTE(review): with this mapping a track's last point (1) is labelled "entry"
# and its first point (2) "exit" — confirm that is intended for the lag variant.
traj_sdf = traj_sdf.withColumn(
    "point_type_lag",
    F.when(F.col("row_num_lag") == 1, "entry").otherwise("exit")
)

# NOTE(review): orderBy before groupBy does not fix which row F.first() picks
# after the shuffle; it is only deterministic when each pivot cell holds one row.
traj_sdf_lag = traj_sdf.orderBy("event_time").groupBy("track_id", "grouping_id_lag").pivot("point_type_lag", ["entry", "exit"]).agg(
    F.first("lat").alias("lat"),
    F.first("lon").alias("lon"),
    F.first("event_time").alias("event_time"),
    F.first("FL").alias("FL"),
    F.first("code").alias("code"),
    F.first("min_fl").alias("min_fl"),
    F.first("max_fl").alias("max_fl")
)


In [None]:
# Stringify the pivoted timestamps, then pull a 10k-row sample to pandas.
traj_sdf_lag = traj_sdf_lag.withColumns({
    ts_col: col(ts_col).cast("string")
    for ts_col in ("entry_event_time", "exit_event_time")
})
df_lag = traj_sdf_lag.limit(10000).toPandas()

In [None]:
# Stringify the pivoted timestamps, then pull a 10k-row sample to pandas.
traj_sdf_lead = traj_sdf_lead.withColumns({
    ts_col: col(ts_col).cast("string")
    for ts_col in ("entry_event_time", "exit_event_time")
})
df_lead = traj_sdf_lead.limit(10000).toPandas()

In [None]:
df_lead.loc[df_lead['entry_code'].ne(df_lead['exit_code'])]

In [None]:
grouping_id_cols = ['grouping_id_lag', 'grouping_id_lead']
df = pd.concat([df_lag, df_lead], axis=0)
df = df[[name for name in df.columns if name not in grouping_id_cols]]

In [None]:
df.drop_duplicates(keep="first")

In [None]:
# Stringify the entry/exit timestamps before collecting everything to pandas.
traj_sdf = traj_sdf.withColumns({
    ts_col: col(ts_col).cast("string")
    for ts_col in ("entry_event_time", "exit_event_time")
})
pd.set_option('display.max_rows', 2338)
tmp = traj_sdf.toPandas()
tmp

In [None]:
tmp.to_parquet(path='crossing_example.parquet')

In [None]:
import pandas as pd
import plotly.graph_objects as go

# Plot the collected entry/exit pairs on an OpenStreetMap basemap.
# Work on a copy: `df = tmp` aliased the frame, so the in-place datetime
# conversion below also silently mutated `tmp`.
df = tmp.copy()

# Convert the (string) timestamps to datetime; unparsable values become NaT.
df["entry_event_time"] = pd.to_datetime(df["entry_event_time"], errors="coerce")
df["exit_event_time"] = pd.to_datetime(df["exit_event_time"], errors="coerce")

# Sort by entry time so traces are added in chronological order.
df = df.sort_values(by="entry_event_time")

fig = go.Figure()

# One line trace per entry -> exit pair.
for idx, row in df.iterrows():
    # No cell in this notebook creates a plain `grouping_id` column (only
    # grouping_id_lead / grouping_id_lag), so `row['grouping_id']` raised KeyError.
    # Use whichever id is present, falling back to the row index.
    group_label = row.get("grouping_id", row.get("grouping_id_lead", row.get("grouping_id_lag", idx)))
    fig.add_trace(go.Scattermapbox(
        mode="lines",
        lat=[row["entry_lat"], row["exit_lat"]],
        lon=[row["entry_lon"], row["exit_lon"]],
        line=dict(width=2, color="black"),
        name=f"Group {group_label}: {row['entry_code']} → {row['exit_code']}",
        hovertext=f"FL {row['entry_FL']} → FL {row['exit_FL']}",
        hoverinfo="text"
    ))

# Add entry points (blue)
fig.add_trace(go.Scattermapbox(
    mode="markers",
    lat=df["entry_lat"],
    lon=df["entry_lon"],
    marker=dict(size=8, color="blue"),
    name="Entry Points",
    hovertext=df["entry_code"],
    hoverinfo="text"
))

# Add exit points (red)
fig.add_trace(go.Scattermapbox(
    mode="markers",
    lat=df["exit_lat"],
    lon=df["exit_lon"],
    marker=dict(size=8, color="red"),
    name="Exit Points",
    hovertext=df["exit_code"],
    hoverinfo="text"
))

# Update layout
fig.update_layout(
    mapbox=dict(style="open-street-map"),
    title="Entry and Exit Points with Flight Paths",
    margin=dict(l=2, r=2, t=40, b=2)
)

# Save the interactive map to HTML (not displayed inline).
fig.write_html('crossing_example.html')

In [None]:
# Define a new window by track and code_group
group_window = Window.partitionBy("track_id", "code_group").orderBy("event_time")
max_window = Window.partitionBy("track_id", "code_group")

# Identify first and last row in each code group
traj_sdf = traj_sdf.withColumn("row_num", F.row_number().over(group_window))
traj_sdf = traj_sdf.withColumn("max_row", F.max("row_num").over(max_window))

In [None]:
# Create a transition group that includes the first row of a new code_group and the last row of the previous group
traj_sdf = traj_sdf.withColumn(
    "transition_code_group",
    F.when(F.col("row_num") == 1, F.col("prev_code_group")).otherwise(F.col("code_group"))
)


In [None]:








# Keep only the first and last row of each group. row_num / max_row were computed
# over the code_group window, not per transition_code_group.
# NOTE(review): confirm which grouping is intended here.
filtered_traj_sdf = traj_sdf.filter(
    (F.col("row_num") == 1) | (F.col("row_num") == F.col("max_row"))
)

# Label each kept row "entry" when it is the group's first row, otherwise "exit".
# A single-row group (row_num == max_row == 1) is only labelled "entry".
filtered_traj_sdf = filtered_traj_sdf.withColumn(
    "point_type",
    F.when(F.col("row_num") == 1, "entry").otherwise("exit")
)


# cache() is lazy and returns the DataFrame, which is what this cell displays.
filtered_traj_sdf.cache()

In [None]:
# Pivot so each (track_id, transition_code_group) becomes one row with
# entry_* and exit_* columns (named <point_type value>_<aggregate alias>).
# NOTE(review): orderBy before groupBy does not fix which row F.first() picks
# after the shuffle; it is only deterministic when each pivot cell holds one row.
pivoted_traj_sdf = filtered_traj_sdf.orderBy("event_time").groupBy("track_id", "transition_code_group").pivot("point_type").agg(
    F.first("lat").alias("lat"),
    F.first("lon").alias("lon"),
    F.first("event_time").alias("event_time"),
    F.first("FL").alias("FL"),
    F.first("code").alias("code"),
    F.first("min_fl").alias("min_fl"),
    F.first("max_fl").alias("max_fl")
)

# Show result


In [None]:
# The pivot above uses point_type values "entry"/"exit", so the produced columns
# are entry_event_time / exit_event_time. The previous code referenced
# last_event_time / first_event_time, which do not exist (AnalysisException).
pivoted_traj_sdf = pivoted_traj_sdf.withColumns({
    ts_col: col(ts_col).cast("string")
    for ts_col in ("exit_event_time", "entry_event_time")
})
df = pivoted_traj_sdf.limit(10000).toPandas()

In [None]:
df['track_id']

In [None]:
## To-do:
# - Refresh opdi_h3_airspace_ref once Enrico fixes the AIRAC 481 file for FIRs (currently 481 == 406).

In [None]:
# List the AIRAC cycles for which FIR cells exist. `fir_sdf` was narrowed to
# (code, h3_res_7, min_fl, max_fl) earlier, so airspace_type / airac_cfmu are no
# longer available on it. Query the reference table directly instead.
spark.table('project_opdi.opdi_h3_airspace_ref')\
    .filter(col('airspace_type') == 'FIR')\
    .select(col('airac_cfmu'))\
    .dropDuplicates()\
    .show()