In [58]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import udf, pandas_udf, col, PandasUDFType, lit, round, array_contains, from_unixtime
from pyspark.sql.functions import col, radians, sin, cos, sqrt, atan2, array, collect_list, struct, row_number, expr
from pyspark.sql.functions import monotonically_increasing_id, row_number, col
from pyspark.sql.types import DoubleType, StructType, StructField, IntegerType
from pyspark.sql.functions import when, broadcast, split, col, concat_ws,  min, max, to_date, unix_timestamp
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Regular imports
from IPython.display import display, HTML
import os, time
import subprocess
import os,shutil
from datetime import datetime, timedelta
import pandas as pd
import numpy as np
import h3pandas
import h3

# Custom functions
from datetime import datetime, date
import dateutil.relativedelta
import calendar

def generate_months(start_date, end_date):
    """Generate a list of dates corresponding to the first day of each month between two dates.

    Args:
    start_date (datetime.date): The starting date.
    end_date (datetime.date): The ending date.

    Returns:
    list: A list of date objects for the first day of each month within the specified range.
    """
    current = start_date
    months = []
    while current <= end_date:
        months.append(current)
        # Increment month
        month = current.month
        year = current.year
        if month == 12:
            current = date(year + 1, 1, 1)
        else:
            current = date(year, month + 1, 1)
    return months

def get_start_end_of_month(date):
    """
    Return the Unix timestamp for the first and last second of the given month and year.

    Args:
        date (datetime): A datetime object representing any date within the desired month.

    Returns:
        tuple: A tuple containing the Unix timestamp of the first second and last second of the month.
    """
    year = date.year
    month = date.month
    
    # Calculate first and last second of the month
    first_second = datetime(year, month, 1, 0, 0, 0)
    last_day = calendar.monthrange(year, month)[1]
    last_second = datetime(year, month, last_day, 23, 59, 59)
    
    return first_second.timestamp(), last_second.timestamp()

# Settings
project = "project_opdi"
resolution = 7

start_month = date(2022, 1, 1)
end_month = date(2025, 1, 1)

# Getting today's date
today = datetime.today().strftime('%d %B %Y')

# Spark Session Initialization
spark = SparkSession.builder \
    .appName("OPDI Flight Table") \
    .config("spark.rpc.message.maxSize", 512) \
    .config("spark.hadoop.fs.azure.ext.cab.required.group", "eur-app-opdi") \
    .config("spark.kerberos.access.hadoopFileSystems", "abfs://storage-fs@cdpdllive.dfs.core.windows.net/data/project/opdi.db/unmanaged") \
    .config("spark.executor.extraClassPath", "/opt/spark/optional-lib/iceberg-spark-runtime-3.3_2.12-1.3.1.1.20.7216.0-70.jar") \
    .config("spark.driver.extraClassPath", "/opt/spark/optional-lib/iceberg-spark-runtime-3.3_2.12-1.3.1.1.20.7216.0-70.jar") \
    .config("spark.sql.catalog.spark_catalog.type", "hive") \
    .config("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog") \
    .config("spark.sql.iceberg.handle-timestamp-without-timezone", "true") \
    .config("spark.sql.catalog.spark_catalog.warehouse", "abfs://storage-fs@cdpdllive.dfs.core.windows.net/data/project/opdi.db/unmanaged") \
    .config("spark.driver.cores", "1") \
    .config("spark.driver.memory", "8G") \
    .config("spark.executor.memory", "8G") \
    .config("spark.executor.memoryOverhead", "3G") \
    .config("spark.executor.cores", "2") \
    .config("spark.executor.instances", "3") \
    .config("spark.dynamicAllocation.maxExecutors", "20") \
    .config("spark.network.timeout", "800s") \
    .config("spark.executor.heartbeatInterval", "400s") \
    .config("spark.driver.maxResultSize", "6g") \
    .config("spark.shuffle.compress", "true") \
    .config("spark.shuffle.spill.compress", "true") \
    .enableHiveSupport() \
    .getOrCreate()

## Query testing sample function

from pyspark.sql.functions import col, lit, from_unixtime, to_timestamp

def get_data_within_timeframe(spark, table_name, month, time_col='event_time'):
    """
    Retrieves records from a specified Spark table within the given timeframe.

    Args:
        spark (SparkSession): The SparkSession object.
        table_name (str): The name of the Spark table to query.
        month (str): The start date of a month in the format 'YYYY-MM-DD'.
        time_col (str): The column name containing timestamp data (default: 'event_time').

    Returns:
        pyspark.sql.dataframe.DataFrame: A DataFrame containing the records within the specified timeframe.
    """
    # Convert the start and end of the month to Unix timestamps
    start_date, end_date = get_start_end_of_month(month)

    # Convert Unix timestamps to Spark timestamp format
    start_date_ts = to_timestamp(lit(start_date))
    end_date_ts = to_timestamp(lit(end_date))

    # Load the table
    df = spark.table(table_name)

    # Filter records based on the timestamp column
    filtered_df = df.filter((col(time_col) >= start_date_ts) & (col(time_col) < end_date_ts))

    return filtered_df

In [86]:
%%time
# Pull trajectories
traj_sdf = get_data_within_timeframe(spark, table_name = 'project_opdi.osn_tracks', month= datetime(2025,1,1), time_col='event_time')
traj_sdf = traj_sdf.filter(col('track_id') == '0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049b9be02ad489bc6523d_0_2025_1')

# Pull FIR
fir_sdf = spark.table('project_opdi.opdi_h3_airspace_ref')
offset = 3
fir_sdf = fir_sdf.filter(
    (col('airspace_type') == 'FIR') &\
    (col('airac_cfmu') == 406) &\
    (col('h3_res_7_lat') >= 26.74617 - offset) &\
    (col('h3_res_7_lat') <= 70.25976 + offset) &\
    (col('h3_res_7_lon') >= -25.86653 - offset) &\
    (col('h3_res_7_lon') <= 49.65699 + offset))
fir_sdf = fir_sdf.dropDuplicates() 
fir_sdf = fir_sdf.select('code','h3_res_7', 'min_fl','max_fl')

# Combine both
traj_sdf = traj_sdf\
    .withColumn('FL',F.round(col('baro_altitude_c')/100))\
    .select('track_id', 'event_time', 'lat', 'lon','baro_altitude_c','h3_res_7','FL')\
    .join(fir_sdf, on='h3_res_7', how='left')\
    .filter((col('FL') >= col('min_fl')) & (col('FL') <= col('max_fl')))

traj_sdf = traj_sdf.cache()

CPU times: user 7.22 ms, sys: 11.9 ms, total: 19.1 ms
Wall time: 580 ms


In [95]:
# Define window specification for each track, ordered by event_time
window_spec_minmax = Window.partitionBy("track_id")
window_spec = Window.partitionBy("track_id").orderBy("event_time")

# Identify changes in 'code' over time within each track
traj_sdf = traj_sdf\
    .withColumn(
        "begin_track",
        # This columns tags the beginning of the track with True
        F.when(F.min(col('event_time')).over(window_spec_minmax) == col("event_time"), True).otherwise(False)
    ).withColumn(
        "end_track",
         # This columns tags the end of the track with True
        F.when(F.max(col('event_time')).over(window_spec_minmax) == col("event_time"), True).otherwise(False)
    ).withColumn(
        "code_change",
        # This column tags every row for which the track:
        #    1) begins 
        #    2) ends 
        #    3) is about to cross a FIR boundary with the next state vector
        #    4) has just crossed a FIR boundary from the prev state vector
        F.when((F.lag("code").over(window_spec) != F.col("code")) |\
               (F.lead("code").over(window_spec) != F.col("code")) |\
               col('begin_track') |\
               col('end_track')
               , True).otherwise(False))

traj_sdf = traj_sdf.filter(col('code_change') == True)

traj_sdf = traj_sdf.withColumn(
    'grouping_id', 
    F.sum(
        F.when(
            col('begin_track') |\
            col('end_track') |\
            (F.lead('code').over(window_spec) != F.col('code'))\
            , 1).otherwise(0))\
    .over(window_spec))

group_window = Window.partitionBy("track_id", "grouping_id").orderBy("event_time")
traj_sdf = traj_sdf.withColumn(
    "row_num", 
    F.row_number().over(group_window))

traj_sdf = traj_sdf.withColumn(
    "row_num",
    F.when(col('begin_track'), 2).otherwise(col('row_num')))

traj_sdf = traj_sdf.withColumn(
    "row_num",
    F.when(col('end_track'), 1).otherwise(col('row_num')))

traj_sdf = traj_sdf.withColumn(
    "point_type",
    F.when(F.col("row_num") == 1, "entry").otherwise("exit")
)


In [96]:
#traj_sdf = traj_sdf.filter(col('code_change') == True)

In [97]:
traj_sdf = traj_sdf.withColumn(
    "event_time", col("event_time").cast("string")
)
pd.set_option('display.max_rows', 2338)
tmp = traj_sdf.toPandas()
tmp

                                                                                

Unnamed: 0,h3_res_7,track_id,event_time,lat,lon,baro_altitude_c,FL,code,min_fl,max_fl,begin_track,end_track,code_change,test,grouping_id,row_num,point_type
0,871ecd741ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 12:01:25,41.670025,28.573587,3817.62,38.0,LTBBFIR,0,999,True,False,True,1,1,2,exit
1,871ecdc8bffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 12:07:55,41.966995,27.783924,6987.54,70.0,LTBBFIR,0,999,False,False,True,1,2,1,entry
2,871ecdc88ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 12:08:00,41.972672,27.773375,7025.64,70.0,LBSRFIR,0,999,False,False,True,0,2,2,exit
3,871eee4f6ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 12:30:25,43.686482,24.947139,10972.8,110.0,LBSRFIR,0,999,False,False,True,1,3,1,entry
4,871eee4a9ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 12:30:30,43.69249,24.935452,10972.8,110.0,LRBBFIR,0,999,False,False,True,0,3,2,exit
5,871e19c50ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 13:04:15,46.147659,20.326449,10972.8,110.0,LRBBFIR,0,999,False,False,True,1,4,1,entry
6,871e19c54ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 13:04:20,46.152878,20.316065,10972.8,110.0,LHCCFIR,0,999,False,False,True,0,4,2,exit
7,871e02590ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 13:26:30,47.531067,16.717209,10972.8,110.0,LHCCFIR,0,999,False,False,True,1,5,1,entry
8,871e02594ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 13:26:35,47.536285,16.70224,10972.8,110.0,LOVVFIR,0,999,False,False,True,0,5,2,exit
9,871e32313ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 13:43:05,48.536992,13.750185,10972.8,110.0,LOVVFIR,0,999,False,False,True,1,6,1,entry


In [11]:
# Define a new window by track and code_group
group_window = Window.partitionBy("track_id", "code_group").orderBy("event_time")
max_window = Window.partitionBy("track_id", "code_group")

# Identify first and last row in each code group
traj_sdf = traj_sdf.withColumn("row_num", F.row_number().over(group_window))
traj_sdf = traj_sdf.withColumn("max_row", F.max("row_num").over(max_window))

In [13]:
# Create a transition group that includes the first row of a new code_group and the last row of the previous group
traj_sdf = traj_sdf.withColumn(
    "transition_code_group",
    F.when(F.col("row_num") == 1, F.col("prev_code_group")).otherwise(F.col("code_group"))
)


In [14]:
traj_sdf = traj_sdf.orderBy("event_time").groupBy("track_id", "transition_code_group").pivot("point_type").agg(
    F.first("lat").alias("lat"),
    F.first("lon").alias("lon"),
    F.first("event_time").alias("event_time"),
    F.first("FL").alias("FL"),
    F.first("code").alias("code"),
    F.first("min_fl").alias("min_fl"),
    F.first("max_fl").alias("max_fl")
)

                                                                                

Unnamed: 0,h3_res_7,track_id,event_time,lat,lon,baro_altitude_c,FL,code,min_fl,max_fl,code_change,code_group,prev_code_group,row_num,max_row,transition_code_group
0,871ecd741ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 12:01:25,41.670025,28.573587,3817.62,38.0,LTBBFIR,0,999,0,0,,1,79,
1,871ecd741ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 12:01:30,41.674255,28.562518,3870.96,39.0,LTBBFIR,0,999,0,0,0.0,2,79,0.0
2,871ecd745ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 12:01:35,41.677613,28.553531,3909.06,39.0,LTBBFIR,0,999,0,0,0.0,3,79,0.0
3,871ecd745ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 12:01:40,41.680573,28.545602,3947.16,39.0,LTBBFIR,0,999,0,0,0.0,4,79,0.0
4,871ecd744ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 12:01:45,41.68455,28.535071,4000.5,40.0,LTBBFIR,0,999,0,0,0.0,5,79,0.0
5,871ecd744ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 12:01:50,41.687669,28.526767,4053.84,41.0,LTBBFIR,0,999,0,0,0.0,6,79,0.0
6,871ecd744ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 12:01:55,41.690277,28.519759,4099.56,41.0,LTBBFIR,0,999,0,0,0.0,7,79,0.0
7,871ecd744ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 12:02:00,41.694717,28.507899,4152.9,42.0,LTBBFIR,0,999,0,0,0.0,8,79,0.0
8,871ecd762ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 12:02:05,41.698143,28.498791,4198.62,42.0,LTBBFIR,0,999,0,0,0.0,9,79,0.0
9,871ecd771ffffff,0005461f97ff2ed81bb9eb7e529401ce81a773cd21a049...,2025-01-13 12:02:10,41.700284,28.493042,4251.96,43.0,LTBBFIR,0,999,0,0,0.0,10,79,0.0


In [4]:








# Keep only the first and last row per transition_code_group
filtered_traj_sdf = traj_sdf.filter(
    (F.col("row_num") == 1) | (F.col("row_num") == F.col("max_row"))
)

# Add a column to indicate first or last point
filtered_traj_sdf = filtered_traj_sdf.withColumn(
    "point_type",
    F.when(F.col("row_num") == 1, "entry").otherwise("exit")
)


filtered_traj_sdf.cache()

DataFrame[h3_res_7: string, track_id: string, event_time: timestamp, lat: double, lon: double, baro_altitude_c: double, FL: double, code: string, min_fl: int, max_fl: int, code_change: int, code_group: bigint, prev_code_group: bigint, row_num: int, max_row: int, transition_code_group: bigint, point_type: string]

In [5]:
# Pivot the dataset to have separate columns for first and last points
pivoted_traj_sdf = filtered_traj_sdf.orderBy("event_time").groupBy("track_id", "transition_code_group").pivot("point_type").agg(
    F.first("lat").alias("lat"),
    F.first("lon").alias("lon"),
    F.first("event_time").alias("event_time"),
    F.first("FL").alias("FL"),
    F.first("code").alias("code"),
    F.first("min_fl").alias("min_fl"),
    F.first("max_fl").alias("max_fl")
)

# Show result


[Stage 55:>                                                         (0 + 1) / 1]

+--------------------+---------------------+------------------+------------------+-------------------+--------+----------+------------+------------+------------------+------------------+-------------------+-------+---------+-----------+-----------+
|            track_id|transition_code_group|         first_lat|         first_lon|   first_event_time|first_FL|first_code|first_min_fl|first_max_fl|          last_lat|          last_lon|    last_event_time|last_FL|last_code|last_min_fl|last_max_fl|
+--------------------+---------------------+------------------+------------------+-------------------+--------+----------+------------+------------+------------------+------------------+-------------------+-------+---------+-----------+-----------+
|000059f343f5301bc...|                    4|              null|              null|               null|    null|      null|        null|        null| 43.63142239845405|1.3702338082449776|2025-01-03 22:37:00|    1.0|  LFBBFIR|          0|        195|
|000

                                                                                

In [8]:
pivoted_traj_sdf = pivoted_traj_sdf.withColumn(
    "last_event_time", col("last_event_time").cast("string")
).withColumn(
    "first_event_time", col("first_event_time").cast("string")
)
df = pivoted_traj_sdf.limit(10000).toPandas()

                                                                                

In [25]:
df.track_id

NameError: name 'df' is not defined

In [None]:
## To-do:
# - Refresh opdi_h3_airspace_ref once Enrico fixes the airac 481 file for FIRs (currently 481 == 406)
# - 

In [24]:
fir_sdf.filter((col('airspace_type') == 'FIR')).select(col('airac_cfmu')).dropDuplicates().show()

+----------+
|airac_cfmu|
+----------+
|       481|
|       406|
+----------+

