# ORTEGA Concurrent/Delayed Interaction

## Set up

In [1]:
# Set up secrets auth
from configparser import ConfigParser
parser = ConfigParser()
_ = parser.read('../../notebook.cfg')
gae_user = parser.get('Engine', 'user')
gae_pass = parser.get('Engine', 'auth')

In [2]:
# Imports
import geoanalytics
import geoanalytics.sql.functions as ST
import pyspark.sql.types as PyType
from geoanalytics.tools import SpatiotemporalJoin

from pyspark.sql.types import *
from pyspark.sql.functions import udf
from pyspark.sql import SparkSession, DataFrame, functions as F
from pyspark.sql.functions import col, count, when
import math
from pyspark.sql.window import Window
import time

In [3]:
# Sign into GeoAnalytics Enginemail
geoanalytics.auth(username=gae_user, password=gae_pass)
print(f"GeoAnalytics v.{geoanalytics.version()} awesome is enabled")

GeoAnalytics v.1.2.0.1291 awesome is enabled


## Read pre-existing parquet file with PPAs
* Note the PPAs are ```geom_PPA``` not the original ```aggr_std_ellipse```

In [4]:
# pull PPA parquet already created
data_dir = r"..\Result_parquet"
data_dir_csv = r"..\Result_csv"
FILE_NAME = "Vulture_context.parquet"

df = spark.read.format("parquet").load(f"{data_dir}\{FILE_NAME}")

In [5]:
# Global setting
longitude_field = "location-long"
latitude_field = "location-lat"
id_field = "individual-local-identifier"
num_field = "event-id"
sampling_rate_threshold = 1

join_longitude_field = "join_" + longitude_field
join_latitude_field = "join_" + latitude_field
join_id_field = "join_" + id_field
join_num_field = "join_" + num_field

In [6]:
window = Window.partitionBy(id_field).orderBy("timestamp")
df = df.withColumn("prev_angle", F.lag("angle").over(window))\
    .withColumn("prev_new_speed_mps", F.coalesce(df["prev_new_speed_mps"], df["new_speed_mps"]))
df = df.withColumn("prev_angle", F.coalesce(df["prev_angle"], df["angle"]))

In [7]:
print(f"Feature count: {df.count()}")
df.printSchema()

Feature count: 1861815
root
 |-- event-id: string (nullable = true)
 |-- individual-local-identifier: string (nullable = true)
 |-- location-lat: string (nullable = true)
 |-- location-long: string (nullable = true)
 |-- timestamp: timestamp (nullable = true)
 |-- geometry: point (nullable = true)
 |-- speed: string (nullable = true)
 |-- NDVI: double (nullable = true)
 |-- Temp: double (nullable = true)
 |-- tailwind: double (nullable = true)
 |-- crosswind: double (nullable = true)
 |-- prev_NDVI: double (nullable = true)
 |-- prev_Temp: double (nullable = true)
 |-- prev_tailwind: double (nullable = true)
 |-- prev_crosswind: double (nullable = true)
 |-- prev_lat: string (nullable = true)
 |-- prev_long: string (nullable = true)
 |-- prev_timestamp: timestamp (nullable = true)
 |-- prev_speed: string (nullable = true)
 |-- distance: double (nullable = true)
 |-- delta_time: long (nullable = true)
 |-- new_speed_mps: double (nullable = true)
 |-- prev_new_speed_mps: double (nullable

### Make list of all pairs for interaction model
* 3486 total pairs

In [9]:
names_df = df.select(id_field).distinct()
pairs = names_df.alias("id1").crossJoin((names_df.alias("id2")))\
        .filter(F.col("id1."+id_field) < F.col("id2."+id_field)).collect()
print(f"Number of birds: {names_df.count()}")
print(f"Number of pairs: {len(pairs)}")

Number of birds: 84
Number of pairs: 3486


In [10]:
# for testing - streamline to one bird of interest
BIRD_NAME = "Earl"

one_bird_pairs = []
for pair in pairs:
    if pair[0] == BIRD_NAME or pair[1] == BIRD_NAME:
        one_bird_pairs.append(pair)

print(f"Number of one_bird_pairs: {len(one_bird_pairs)}")

Number of one_bird_pairs: 83


## Concurrent Interaction detection
To identify a concurrent interaction (individuals move synchronously in space and time), we will filter time/spatial overlap

In [11]:
df= df.st.set_geometry_field("geom_PPA")
df = df.st.set_time_fields("prev_timestamp", "timestamp")

In [12]:
df.st.get_geometry_field()

'geom_PPA'

In [13]:
df.st.get_time_fields()

['prev_timestamp', 'timestamp']

### Extract_attributes function

In [14]:
def extract_attributes(df, columns):
    for col_name in columns:
        # Calculating the mean for each individual
        df = df.withColumn(f"p1_attrs_{col_name}", (col(f"prev_{col_name}") + col(f"{col_name}")) / 2) \
               .withColumn(f"p2_attrs_{col_name}", (col(f"join_prev_{col_name}") + col(f"join_{col_name}")) / 2)
    
        # Calculating the mean between individuals
        df = df.withColumn(f"attrs_mean_{col_name}", (col(f"p1_attrs_{col_name}") + col(f"p2_attrs_{col_name}")) / 2)
        
        # Calculating the difference between individuals
        df = df.withColumn(f"attrs_diff_{col_name}", col(f"p1_attrs_{col_name}") - col(f"p2_attrs_{col_name}"))

    return df

### Check continuous concurrent interaction segments and compute duration

In [15]:
def identify_continuous_segments(df, id_field, join_id_field, sampling_rate_threshold):
    # Define window specification for ordering data by timestamp
    windowSpec = Window.partitionBy(id_field, join_id_field).orderBy("prev_timestamp")

    # Add columns for the next row's times to compare with the current one for continuity and overlap
    df = df.withColumn("next_p1_start", F.lead("prev_timestamp").over(windowSpec))
    df = df.withColumn("next_p2_start", F.lead("join_prev_timestamp").over(windowSpec))
    df = df.withColumn("next_p1_end", F.lead("timestamp").over(windowSpec))
    df = df.withColumn("next_p2_end", F.lead("join_timestamp").over(windowSpec))

    # Determine if there is an actual overlap between consecutive segments
    df = df.withColumn("actual_overlap",
                       (F.col("timestamp") > F.col("prev_timestamp")) &
                       (F.col("join_timestamp") > F.col("join_prev_timestamp")) &
                       (F.col("timestamp") <= F.col("next_p1_start")) &
                       (F.col("join_timestamp") <= F.col("next_p2_start")))

    # Determine if a new segment should start
    df = df.withColumn("new_segment",
                       F.when(~F.col("actual_overlap") | 
                              (F.col("timestamp") < F.col("next_p1_end")) | 
                              (F.col("join_timestamp") < F.col("next_p2_end")),
                              1).otherwise(0))

    # Cumulatively sum the new_segment flags to identify continuous segments
    df = df.withColumn("segment_id", F.sum("new_segment").over(windowSpec))

    # Calculate the difference between start times of two individuals (in hours)
    df = df.withColumn("difference",
                       F.greatest(F.abs((F.unix_timestamp("prev_timestamp") - F.unix_timestamp("join_prev_timestamp")) / 3600),
                                  F.abs((F.unix_timestamp("timestamp") - F.unix_timestamp("join_timestamp")) / 3600)))

    # Filter out rows where difference is greater than the sampling rate threshold
    df = df.filter(F.col("difference") <= sampling_rate_threshold)

    # Group by the continuous segment identifier and calculate min and max times for each individual
    segment_df = df.groupBy("segment_id", id_field, join_id_field).agg(
        F.min("prev_timestamp").alias("p1_start"),
        F.max("timestamp").alias("p1_end"),
        F.min("join_prev_timestamp").alias("p2_start"),
        F.max("join_timestamp").alias("p2_end"),
        F.max("difference").alias("max_difference")  # Maximum difference for each segment
    )

    # Define window specification for ordering data by p1_start
    windowSpecByStart = Window.orderBy("p1_start")

    # Create lag columns to check for overlaps
    segment_df = segment_df.withColumn("prev_p1_end", F.lag("p1_end").over(windowSpecByStart))
    segment_df = segment_df.withColumn("prev_p2_end", F.lag("p2_end").over(windowSpecByStart))

    # Define the overlap condition
    overlap_condition = (
        (F.col("p1_start") <= F.col("prev_p1_end")) |
        (F.col("p2_start") <= F.col("prev_p2_end"))
    )

    # Apply the condition to determine the segment continuation
    segment_df = segment_df.withColumn("new_segment", (~overlap_condition | F.isnull(F.col("prev_p1_end")) | F.isnull(F.col("prev_p2_end"))).cast("int"))
    segment_df = segment_df.withColumn("continuous_segment_id", F.sum("new_segment").over(windowSpecByStart))

    # Aggregate continuous segments
    continuous_segments_df = segment_df.groupBy("continuous_segment_id",id_field, join_id_field).agg(
        F.min("p1_start").alias("p1_start"),
        F.max("p1_end").alias("p1_end"),
        F.min("p2_start").alias("p2_start"),
        F.max("p2_end").alias("p2_end"),
        F.max("max_difference").alias("max_difference")  # Maximum difference for each continuous segment
    )

    continuous_segments_df = continuous_segments_df\
                                .withColumnRenamed(id_field, "p1").withColumnRenamed(join_id_field, "p2")\
                                .orderBy(col("p1").asc(), col("p2").asc(), col("continuous_segment_id").asc())

    return continuous_segments_df

In [16]:
def compute_duration(continuous_segments_df):
    continuous_segments_df = continuous_segments_df.withColumn("duration",
                                       (F.greatest("p1_end", "p2_end").cast("long") - F.least("p1_start", "p2_start").cast("long")) / 60)
    return continuous_segments_df

### Concurrent interaction functions

In [17]:
# Define attribute list to annotate
attribute_list = ['new_speed_mps', 'angle', 'NDVI', 'Temp', 'tailwind', 'crosswind']
# Columns to write out
select_columns = [num_field, id_field, "prev_timestamp", "timestamp", "prev_long", "prev_lat", longitude_field, latitude_field, 
                  join_num_field, join_id_field, "join_prev_timestamp", "join_timestamp", "join_prev_long", "join_prev_lat", join_longitude_field, join_latitude_field]

# Add mean and difference columns for each attribute 
for col_name in attribute_list:
    select_columns.extend([
        F.col(f'p1_attrs_{col_name}'),
        F.col(f'p2_attrs_{col_name}'),
        F.col(f'attrs_mean_{col_name}'),
        F.col(f'attrs_diff_{col_name}')
    ])

In [18]:
# Calculate min and max timestamps for each ID
time_bounds_df = df.groupBy(id_field) \
                   .agg(F.min("timestamp").alias("min_timestamp"),
                        F.max("timestamp").alias("max_timestamp"))

# Collect the results as a list of rows
time_bounds_list = time_bounds_df.collect()

# Convert list of rows to a dictionary
time_bounds = {row['individual-local-identifier']: {'min_timestamp': row['min_timestamp'], 
                                                   'max_timestamp': row['max_timestamp']} 
               for row in time_bounds_list}

# Broadcast the time bounds dictionary
broadcast_bounds = spark.sparkContext.broadcast(time_bounds)

In [19]:
# Calculate spatial bounds for each ID
spatial_bounds_df = df.groupBy(id_field)\
                         .agg(F.min(longitude_field).alias("min_long"),
                              F.max(longitude_field).alias("max_long"),
                              F.min(latitude_field).alias("min_lat"),
                              F.max(latitude_field).alias("max_lat"))

# Collect the results as a list of rows
spatial_bounds_list = spatial_bounds_df.collect()

# Convert list of rows to a dictionary
spatial_bounds = {row['individual-local-identifier']: {'min_long': row['min_long'], 
                                                       'max_long': row['max_long'],
                                                       'min_lat': row['min_lat'], 
                                                       'max_lat': row['max_lat']}
                  for row in spatial_bounds_list}

# Broadcast the spatial bounds dictionary
broadcast_spatial_bounds = spark.sparkContext.broadcast(spatial_bounds)

# Function to check spatial overlap
def has_spatial_overlap(bounds1, bounds2):
    return not (bounds1['max_long'] < bounds2['min_long'] or 
                bounds2['max_long'] < bounds1['min_long'] or 
                bounds1['max_lat'] < bounds2['min_lat'] or 
                bounds2['max_lat'] < bounds1['min_lat'])


In [20]:
def run_spatiotemporal_join(pairs, id_field, join_id_field, sampling_rate_threshold, df, output_path, broadcast_bounds, broadcast_spatial_bounds, attribute_list, select_columns, file_name):
    """
    Runs a spatiotemporal join for each pair in the provided list and includes the population name in the output file name.
    
    :param pairs: List of ID pairs to process.
    :param df: The DataFrame containing the data.
    :param output_base_path: Base path for the output files.
    :param broadcast_bounds: Broadcasted temporal bounds.
    :param broadcast_spatial_bounds: Broadcasted spatial bounds.
    :param attribute_list: List of attributes for extraction.
    :param select_columns: Columns to select for the output.
    :param file_name: Name of the population to include in the file name.
    """

    spark.catalog.clearCache()

    filename_intersect = f"{output_path}/concurrent_intersect_PPAs_{file_name}"
    filename_duration = f"{output_path}/concurrent_events_{file_name}"

    pair_count = 0 # used to overwrite parquet for first pair; we append for the subsequent pairs
    pairs_processed = 0 # total pairs checked
    pairs_joined = 0 # total pairs joined (does not count where there is no temporal overlap)
    print("Writing status every 1000 pairs")
        
    for pair in pairs:
        start_time = time.time()  # Start timer
        
        pairs_processed += 1
        if (pairs_processed%1000 == 0):
            print(f"...{pairs_processed} finished")

        id1, id2 = pair
        bounds1 = broadcast_bounds.value.get(id1)
        bounds2 = broadcast_bounds.value.get(id2)
        spatial_bounds1 = broadcast_spatial_bounds.value.get(id1)
        spatial_bounds2 = broadcast_spatial_bounds.value.get(id2)
    
        # Check if both temporal and spatial bounds are available and overlap
        if bounds1 and bounds2 and spatial_bounds1 and spatial_bounds2:
            if not (bounds1['max_timestamp'] < bounds2['min_timestamp'] or 
                    bounds2['max_timestamp'] < bounds1['min_timestamp']) and \
               has_spatial_overlap(spatial_bounds1, spatial_bounds2):
                
                # print(f"Pair with overlap: {id1}, {id2}") #only print pairs with temporal and spatial overlaps
    
                df1 = df.filter(F.col(id_field) == id1)
                df2 = df.filter(F.col(id_field) == id2)
        
                # SpatiotemporalJoin
                join_result = SpatiotemporalJoin() \
                                    .setJoinOneToMany()\
                                    .setSpatialRelationship(spatial_relationship="Intersects") \
                                    .setTemporalRelationship("Intersects")\
                                    .run(target_dataframe=df1, join_dataframe=df2)
                # print("Finish finding intersecting PPAs")
        
                # Contextualize
                join_result = extract_attributes(join_result, attribute_list)

                pairs_joined += 1
                if (pairs_joined%100 == 0):
                    print(f"\t...{pairs_joined} joined")
        
                # Append each join_result to the Parquet file
                mode = "append" if pair_count > 0 else "overwrite"
                join_result.select(*select_columns).write.parquet(filename_intersect, mode=mode)
                pair_count += 1

                end_time = time.time()
                elapsed_time = end_time - start_time
                print(f"Processing time for pair {id1}-{id2}: {elapsed_time:.2f} seconds")

    # After processing all pairs, read the combined intersect PPAs from Parquet
    combined_intersect_ppas = spark.read.format("parquet").load(filename_intersect)

    # Sort the combined_intersect_ppas DataFrame
    combined_intersect_ppas = combined_intersect_ppas.orderBy(
            col(id_field).asc(), 
            col(join_id_field).asc(), 
            col("prev_timestamp").asc()
        )
    
    # Identify continuous segments and compute duration on the combined data
    continuous_segments = identify_continuous_segments(combined_intersect_ppas, id_field, join_id_field, sampling_rate_threshold)
    continuous_segments = compute_duration(continuous_segments)

    # Write out combined continuous segments to a single file
    continuous_segments.write.parquet(filename_duration, mode="overwrite")


### Test

In [27]:
%%time
"""
--conf spark.driver.memory=6g ^
--conf spark.executor.memory=10g
"""
run_spatiotemporal_join(one_bird_pairs, id_field, join_id_field, sampling_rate_threshold, df, data_dir, broadcast_bounds, broadcast_spatial_bounds, attribute_list, select_columns, "test_Earl_6_10")

Writing status every 100 pairs
Processing time for pair Earl-Hugh: 40.15 seconds
Processing time for pair Earl-MooMoo: 31.81 seconds
Processing time for pair Earl-Ethan: 37.23 seconds
Processing time for pair Earl-Irma: 34.05 seconds
Processing time for pair Earl-Leo: 40.31 seconds
Processing time for pair Earl-Julie: 51.97 seconds
Processing time for pair Earl-Gifford: 107.66 seconds
Processing time for pair Black Knight-Earl: 36.43 seconds
Processing time for pair David-Earl: 35.74 seconds
CPU times: total: 516 ms
Wall time: 6min 58s


In [22]:
# check the output parquet for events
df_ppa = spark.read.format("parquet").load(f"{data_dir}\concurrent_intersect_PPAs_test_Earl_10_40")
df_ppa.show(5)

+----------+---------------------------+-------------------+-------------------+---------+--------+-------------+------------+-------------+--------------------------------+-------------------+-------------------+--------------+-------------+------------------+-----------------+----------------------+----------------------+------------------------+------------------------+------------------+------------------+------------------+-------------------+-------------+-------------+-------------------+--------------------+------------------+-------------+---------------+--------------------+--------------------+-----------------+-------------------+-------------------+--------------------+------------------+--------------------+--------------------+
|  event-id|individual-local-identifier|     prev_timestamp|          timestamp|prev_long|prev_lat|location-long|location-lat|join_event-id|join_individual-local-identifier|join_prev_timestamp|     join_timestamp|join_prev_long|join_prev_lat|join_

In [23]:
# check the output parquet for events
df_result = spark.read.format("parquet").load(f"{data_dir}\concurrent_events_test_Earl_10_40")
df_result.show(5)

+---------------------+----+-----+-------------------+-------------------+-------------------+-------------------+------------------+-----------------+
|continuous_segment_id|  p1|   p2|           p1_start|             p1_end|           p2_start|             p2_end|    max_difference|         duration|
+---------------------+----+-----+-------------------+-------------------+-------------------+-------------------+------------------+-----------------+
|                    1|Earl|Ethan|2015-07-16 12:18:02|2015-07-16 14:05:29|2015-07-16 12:37:00|2015-07-16 13:38:00|            0.9975|           107.45|
|                    3|Earl|Ethan|2015-07-18 10:39:13|2015-07-18 11:39:22|2015-07-18 10:38:00|2015-07-18 11:39:00|0.9758333333333333|61.36666666666667|
|                    5|Earl|Ethan|2015-07-19 10:59:57|2015-07-19 19:31:22|2015-07-19 10:38:00|2015-07-19 18:48:00|0.9997222222222222|533.3666666666667|
|                    6|Earl|Ethan|2015-07-20 01:10:30|2015-07-20 07:32:05|2015-07-20 01:

In [24]:
# write as CSV?
df_result.coalesce(1).write.option("header", True).csv(f"{data_dir_csv}\concurrent_events_test_Earl.csv", mode="overwrite")

### All pairs

In [29]:
%%time
"""
--conf spark.driver.memory=6g ^
--conf spark.executor.memory=10g
"""
run_spatiotemporal_join(pairs, id_field, join_id_field, sampling_rate_threshold, df, data_dir, broadcast_bounds, broadcast_spatial_bounds, attribute_list, select_columns, "All_run3")

Writing status every 1000 pairs
Processing time for pair Cinco-Fenix: 32.19 seconds
Processing time for pair Cinco-Whitey: 17.97 seconds
Processing time for pair Cinco-Two-thirds: 36.66 seconds
Processing time for pair Cinco-Halfway: 32.04 seconds
Processing time for pair Cinco-Mandeb: 16.56 seconds
Processing time for pair Cinco-Young Luro: 19.20 seconds
Processing time for pair Steve-Tekoa: 39.37 seconds
Processing time for pair Steve-Tire Pile: 35.52 seconds
Processing time for pair Steve-Tintin: 36.37 seconds
Processing time for pair Steve-Superior: 33.90 seconds
Processing time for pair Steve-Versace: 49.99 seconds
Processing time for pair Steve-Thomas: 32.63 seconds
Processing time for pair Mary-Steamhouse 1: 19.03 seconds
Processing time for pair Mary-Steamhouse 2: 16.14 seconds
Processing time for pair Schaumboch-Sill: 30.74 seconds
Processing time for pair Prado-Steamhouse 1: 26.29 seconds
Processing time for pair Prado-Steamhouse 2: 24.66 seconds
Processing time for pair Prad

In [26]:
# check the output parquet for events
df_result = spark.read.format("parquet").load(f"{data_dir}\concurrent_intersect_PPAs_All")
df_result.show(10)

+-----------+---------------------------+-------------------+-------------------+----------+--------+-------------+------------+-------------+--------------------------------+-------------------+-------------------+--------------+-------------+------------------+-----------------+----------------------+----------------------+------------------------+------------------------+------------------+------------------+------------------+-------------------+-------------------+-------------------+-------------------+--------------------+------------------+------------------+------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-------------------+--------------------+--------------------+
|   event-id|individual-local-identifier|     prev_timestamp|          timestamp| prev_long|prev_lat|location-long|location-lat|join_event-id|join_individual-local-identifier|join_prev_timestamp|     join_timestamp|join_

## Delayed Interaction detection
To identify a delayed interaction (individuals visit the same location asynchronously), we will:
- setTemporalRelationship("Near", near_duration = int, near_duration_unit = "Milliseconds, Seconds, Minutes, Hours, Days, Weeks, Months, Years")
- `Near`,`NearBefore` or `NearAfter`

### Function

In [21]:
def run_delay_spatiotemporal_join(pairs, id_field, join_id_field, df, output_path, broadcast_bounds, broadcast_spatial_bounds, attribute_list, select_columns, file_name, delay_days):
    """
    Runs a spatiotemporal join for each pair in the provided list and includes the population name in the output file name.
    
    :param pairs: List of ID pairs to process.
    :param df: The DataFrame containing the data.
    :param output_base_path: Base path for the output files.
    :param broadcast_bounds: Broadcasted temporal bounds.
    :param broadcast_spatial_bounds: Broadcasted spatial bounds.
    :param attribute_list: List of attributes for extraction.
    :param select_columns: Columns to select for the output.
    :param file_name: Name of the population to include in the file name.
    """

    spark.catalog.clearCache()
    
    filename_intersect = f"{output_path}/delay_{delay_days}d_intersect_PPAs_{file_name}"
    pair_count = 0 # used to overwrite parquet for first pair; we append for the subsequent pairs
    pairs_processed = 0 # total pairs checked
    pairs_joined = 0 # total pairs joined (does not count where there is no temporal overlap)
    print("Writing status every 100 pairs")
        
    for pair in pairs:
        start_time = time.time()  # Start timer
        pairs_processed += 1
        if (pairs_processed%100 == 0):
            print(f"...{pairs_processed} finished")

        id1, id2 = pair
        bounds1 = broadcast_bounds.value.get(id1)
        bounds2 = broadcast_bounds.value.get(id2)
        spatial_bounds1 = broadcast_spatial_bounds.value.get(id1)
        spatial_bounds2 = broadcast_spatial_bounds.value.get(id2)
    
        # Check if both temporal and spatial bounds are available and overlap
        if bounds1 and bounds2 and spatial_bounds1 and spatial_bounds2:
            if not (bounds1['max_timestamp'] < bounds2['min_timestamp'] or 
                    bounds2['max_timestamp'] < bounds1['min_timestamp']) and \
               has_spatial_overlap(spatial_bounds1, spatial_bounds2):
                
                # print(f"Pair with overlap: {id1}, {id2}") #only print pairs with temporal and spatial overlaps
    
                df1 = df.filter(F.col(id_field) == id1)
                df2 = df.filter(F.col(id_field) == id2)
        
                # SpatiotemporalJoin
                join_result = SpatiotemporalJoin() \
                        .setJoinOneToMany()\
                        .setSpatialRelationship(spatial_relationship="Intersects") \
                        .setTemporalRelationship("Near", near_duration = delay_days, near_duration_unit = "Days")\
                        .run(target_dataframe=df1, join_dataframe=df2)
                # print("Finish finding intersecting PPAs")
        
                # Contextualize
                join_result = extract_attributes(join_result, attribute_list)

                pairs_joined += 1
                if (pairs_joined%100 == 0):
                    print(f"\t...{pairs_joined} joined")
        
                # Append each join_result to the Parquet file
                mode = "append" if pair_count > 0 else "overwrite"
                join_result.select(*select_columns).write.parquet(filename_intersect, mode=mode)
                pair_count += 1

                end_time = time.time()
                elapsed_time = end_time - start_time
                print(f"Processing time for pair {id1}-{id2}: {elapsed_time:.2f} seconds")

In [22]:
def identify_continuous_segments_delay(df, id_field, join_id_field):
    # Define window specification for ordering data by timestamp
    windowSpec = Window.partitionBy(id_field, join_id_field).orderBy("prev_timestamp")

    # Add columns for the next row's times to compare with the current one for continuity and overlap
    df = df.withColumn("next_p1_start", F.lead("prev_timestamp").over(windowSpec))
    df = df.withColumn("next_p2_start", F.lead("join_prev_timestamp").over(windowSpec))
    df = df.withColumn("next_p1_end", F.lead("timestamp").over(windowSpec))
    df = df.withColumn("next_p2_end", F.lead("join_timestamp").over(windowSpec))

    # Determine if there is an actual overlap between consecutive segments
    df = df.withColumn("actual_overlap",
                       (F.col("timestamp") > F.col("prev_timestamp")) &
                       (F.col("join_timestamp") > F.col("join_prev_timestamp")) &
                       (F.col("timestamp") <= F.col("next_p1_start")) &
                       (F.col("join_timestamp") <= F.col("next_p2_start")))

    # Determine if a new segment should start
    df = df.withColumn("new_segment",
                       F.when(~F.col("actual_overlap") | 
                              (F.col("timestamp") < F.col("next_p1_end")) | 
                              (F.col("join_timestamp") < F.col("next_p2_end")),
                              1).otherwise(0))

    # Cumulatively sum the new_segment flags to identify continuous segments
    df = df.withColumn("segment_id", F.sum("new_segment").over(windowSpec))

    # # Calculate the difference between start times of two individuals (in minutes)
    # df = df.withColumn("difference",
    #                    F.greatest(F.abs((F.unix_timestamp("prev_timestamp") - F.unix_timestamp("join_prev_timestamp")) / 60),
    #                               F.abs((F.unix_timestamp("timestamp") - F.unix_timestamp("join_timestamp")) / 60)))

    # Group by the continuous segment identifier and calculate min and max times for each individual
    segment_df = df.groupBy("segment_id", id_field, join_id_field).agg(
        F.min("prev_timestamp").alias("p1_start"),
        F.max("timestamp").alias("p1_end"),
        F.min("join_prev_timestamp").alias("p2_start"),
        F.max("join_timestamp").alias("p2_end"),
        F.max("difference").alias("max_difference")  # Maximum difference for each segment
    )

    # Define window specification for ordering data by p1_start
    windowSpecByStart = Window.orderBy("p1_start")

    # Create lag columns to check for overlaps
    segment_df = segment_df.withColumn("prev_p1_end", F.lag("p1_end").over(windowSpecByStart))
    segment_df = segment_df.withColumn("prev_p2_end", F.lag("p2_end").over(windowSpecByStart))

    # Define the overlap condition
    overlap_condition = (
        (F.col("p1_start") <= F.col("prev_p1_end")) |
        (F.col("p2_start") <= F.col("prev_p2_end"))
    )

    # Apply the condition to determine the segment continuation
    segment_df = segment_df.withColumn("new_segment", (~overlap_condition | F.isnull(F.col("prev_p1_end")) | F.isnull(F.col("prev_p2_end"))).cast("int"))
    segment_df = segment_df.withColumn("continuous_segment_id", F.sum("new_segment").over(windowSpecByStart))

    # Aggregate continuous segments
    continuous_segments_df = segment_df.groupBy("continuous_segment_id",id_field, join_id_field).agg(
        F.min("p1_start").alias("p1_start"),
        F.max("p1_end").alias("p1_end"),
        F.min("p2_start").alias("p2_start"),
        F.max("p2_end").alias("p2_end"),
        F.max("max_difference").alias("max_difference")  # Maximum difference for each continuous segment
    )

    continuous_segments_df = continuous_segments_df\
                                .withColumnRenamed(id_field, "p1").withColumnRenamed(join_id_field, "p2")\
                                .orderBy(col("p1").asc(), col("p2").asc(), col("continuous_segment_id").asc())

    return continuous_segments_df

In [23]:
# Filter down to shorter time lag
from pyspark.sql import DataFrame
from pyspark.sql import functions as F

def filter_delay_PPAs(df: DataFrame, id_field, join_id_field, output_dir: str, file_name: str):
    """
    Filters the DataFrame based on specified time intervals, identifies continuous segments,
    and writes each interval to a separate Parquet file.

    :param df: DataFrame containing the intersecting PPAs with a 'difference' column in hours.
    :param output_dir: Base path for the output files.
    :param file_name: Name of output file.
    """
    # Define time intervals in hours
    one_day = 24
    one_week = 7 * one_day
    two_weeks = 2 * one_week
    three_weeks = 3 * one_week
    four_weeks = 4 * one_week

    # Time interval filters
    intervals = {
        "delay_1d": (1, one_day),
        "delay_1w": (one_day, one_week),
        "delay_2w": (one_week, two_weeks),
        "delay_3w": (two_weeks, three_weeks),
        "delay_4w": (three_weeks, four_weeks)
    }

    # Add a time difference column (hour)
    df = df.withColumn("difference",
                       F.greatest(F.abs((F.unix_timestamp("prev_timestamp") - F.unix_timestamp("join_prev_timestamp")) / 3600),
                                  F.abs((F.unix_timestamp("timestamp") - F.unix_timestamp("join_timestamp")) / 3600)))

    for suffix, (lower_bound, upper_bound) in intervals.items():
        filtered_df = df.filter((F.col("difference") > lower_bound) & (F.col("difference") <= upper_bound))
        continuous_segments_df = identify_continuous_segments_delay(filtered_df)
        
        file_path_filtered = f"{output_dir}/{suffix}_intersect_PPAs_{file_name}"
        filtered_df.write.parquet(file_path_filtered, mode="overwrite")
        
        file_path_event = f"{output_dir}/{suffix}_events_{file_name}"
        continuous_segments_df.write.parquet(file_path_event, mode="overwrite")

In [24]:
def dataframe_info(spark_df):
    # Get total number of rows
    total_rows = spark_df.count()

    # Iterate over columns and count nulls
    null_counts = [count(when(col(c).isNull(), c)).alias(c) for c in spark_df.columns]
    null_info = spark_df.agg(*null_counts).collect()[0]

    # Print information
    print(f"Dataframe has {len(spark_df.columns)} columns and {total_rows} rows\n")
    print("Column\t\tNon-Null Count\tDataType")
    for c in spark_df.columns:
        non_null_count = total_rows - null_info[c]
        print(f"{c}\t\t{non_null_count}\t\t{spark_df.schema[c].dataType}")

In [25]:
# Define attribute list to annotate
attribute_list = ['new_speed_mps', 'angle', 'NDVI', 'Temp', 'tailwind', 'crosswind']
# Columns to write out
select_columns = [num_field, id_field, "prev_timestamp", "timestamp", "prev_long", "prev_lat", longitude_field, latitude_field, 
                  join_num_field, join_id_field, "join_prev_timestamp", "join_timestamp", "join_prev_long", "join_prev_lat", join_longitude_field, join_latitude_field]

# Add mean and difference columns for each attribute 
for col_name in attribute_list:
    select_columns.extend([
        F.col(f'p1_attrs_{col_name}'),
        F.col(f'p2_attrs_{col_name}'),
        F.col(f'attrs_mean_{col_name}'),
        F.col(f'attrs_diff_{col_name}')
    ])

### Test

In [26]:
df = df.st.set_geometry_field("geom_PPA")
df = df.st.set_time_fields("prev_timestamp", "timestamp")

In [26]:
%%time
"""
--conf spark.driver.memory=6g ^
--conf spark.executor.memory=10g
"""
run_delay_spatiotemporal_join(one_bird_pairs, id_field, join_id_field, df, data_dir, broadcast_bounds, broadcast_spatial_bounds, attribute_list, select_columns, "test_Earl_6_10", 28)

Writing status every 100 pairs
Processing time for pair Earl-Hugh: 116.36 seconds
Processing time for pair Earl-MooMoo: 30.20 seconds
Processing time for pair Earl-Ethan: 63.84 seconds
Processing time for pair Earl-Irma: 30.20 seconds
Processing time for pair Earl-Leo: 38.21 seconds
Processing time for pair Earl-Julie: 45.57 seconds
Processing time for pair Earl-Gifford: 507.86 seconds
Processing time for pair Black Knight-Earl: 30.60 seconds
Processing time for pair David-Earl: 33.49 seconds
CPU times: total: 609 ms
Wall time: 14min 56s


In [38]:
# check the output parquet for intersecting PPAs
delay4w_test_df = spark.read.format("parquet").load(f"{data_dir}\delay_28d_intersect_PPAs_test_Earl")
# Sort the combined_intersect_ppas DataFrame
delay4w_test_df = delay4w_test_df.orderBy(
        col(id_field).asc(), 
        col(join_id_field).asc(), 
        col("prev_timestamp").asc()
    )
delay4w_test_df.show(5)

+----------+---------------------------+-------------------+-------------------+---------+--------+-------------+------------+-------------+--------------------------------+-------------------+-------------------+--------------+-------------+------------------+-----------------+----------------------+----------------------+------------------------+------------------------+------------------+------------------+------------------+-------------------+------------------+------------------+------------------+--------------------+------------------+------------------+------------------+-------------------+-------------------+-----------------+-------------------+-------------------+-------------------+------------------+--------------------+--------------------+
|  event-id|individual-local-identifier|     prev_timestamp|          timestamp|prev_long|prev_lat|location-long|location-lat|join_event-id|join_individual-local-identifier|join_prev_timestamp|     join_timestamp|join_prev_long|join_

In [40]:
%%time
filter_delay_PPAs(delay4w_test_df, id_field, join_id_field, data_dir, "test_Earl")

CPU times: total: 219 ms
Wall time: 1min 31s


In [41]:
# check the result
delay_result = spark.read.format("parquet").load(f"{data_dir}\delay_3w_events_test_Earl")
delay_result.show(50)

+---------------------+----+-------+-------------------+-------------------+-------------------+-------------------+------------------+
|continuous_segment_id|  p1|     p2|           p1_start|             p1_end|           p2_start|             p2_end|    max_difference|
+---------------------+----+-------+-------------------+-------------------+-------------------+-------------------+------------------+
|                    1|Earl|  Ethan|2015-07-16 07:05:43|2015-11-30 11:08:31|2015-06-28 10:21:00|2015-12-15 13:25:00|503.99944444444446|
|                    2|Earl|  Ethan|2015-12-14 11:57:26|2015-12-16 14:27:38|2015-11-25 11:29:00|2015-11-25 14:47:00| 503.9605555555556|
|                    6|Earl|  Ethan|2016-02-02 10:05:15|2016-02-02 12:22:42|2016-02-18 11:20:00|2016-02-23 05:45:00|499.58916666666664|
|                   10|Earl|  Ethan|2016-03-01 10:52:33|2016-05-26 13:02:10|2016-02-10 11:16:00|2016-06-09 04:18:00|             504.0|
|                   11|Earl|  Ethan|2016-05-26 1

In [33]:
# write as CSV?
delay_result.coalesce(1).write.option("header", True).csv(f"{data_dir_csv}\delay_3w_events_test_Earl.csv", mode = "overwrite")

### All Pairs

In [26]:
df = df.st.set_geometry_field("geom_PPA")
df = df.st.set_time_fields("prev_timestamp", "timestamp")

In [27]:
%%time
"""
--conf spark.driver.memory=10g ^
--conf spark.executor.memory=40g
"""
run_delay_spatiotemporal_join(pairs, id_field, join_id_field, df, data_dir, broadcast_bounds, broadcast_spatial_bounds, attribute_list, select_columns, "All_run3", 28)

Writing status every 100 pairs
Processing time for pair Cinco-Fenix: 56.64 seconds
Processing time for pair Cinco-Whitey: 16.30 seconds
Processing time for pair Cinco-Two-thirds: 37.02 seconds
Processing time for pair Cinco-Halfway: 32.68 seconds
Processing time for pair Cinco-Mandeb: 15.40 seconds
Processing time for pair Cinco-Young Luro: 16.59 seconds
Processing time for pair Steve-Tekoa: 33.79 seconds
Processing time for pair Steve-Tire Pile: 30.91 seconds
Processing time for pair Steve-Tintin: 32.02 seconds
Processing time for pair Steve-Superior: 29.87 seconds
Processing time for pair Steve-Versace: 45.43 seconds
Processing time for pair Steve-Thomas: 33.30 seconds
Processing time for pair Mary-Steamhouse 1: 16.39 seconds
Processing time for pair Mary-Steamhouse 2: 14.68 seconds
...100 finished
Processing time for pair Schaumboch-Sill: 28.97 seconds
Processing time for pair Prado-Steamhouse 1: 25.81 seconds
Processing time for pair Prado-Steamhouse 2: 22.55 seconds
Processing tim

In [28]:
# check the output parquet for intersecting PPAs
delay4w_all_df = spark.read.format("parquet").load(f"{data_dir}\delay_28d_intersect_PPAs_All_run3")
# Sort the combined_intersect_ppas DataFrame
delay4w_all_df = delay4w_all_df.orderBy(
        col(id_field).asc(), 
        col(join_id_field).asc(), 
        col("prev_timestamp").asc()
    )
delay4w_all_df.show(5)

+----------+---------------------------+-------------------+-------------------+---------+--------+-------------+------------+-------------+--------------------------------+-------------------+-------------------+--------------+-------------+------------------+-----------------+----------------------+----------------------+------------------------+------------------------+------------------+------------------+------------------+------------------+-------------+------------------+------------------+--------------------+-------------+------------------+------------------+-------------------+-----------------+-------------------+-------------------+-------------------+------------------+--------------------+--------------------+--------------------+
|  event-id|individual-local-identifier|     prev_timestamp|          timestamp|prev_long|prev_lat|location-long|location-lat|join_event-id|join_individual-local-identifier|join_prev_timestamp|     join_timestamp|join_prev_long|join_prev_lat|j

In [29]:
delay4w_all_df.count()

55586456

In [31]:
%%time
filter_delay_PPAs(delay4w_all_df, id_field, join_id_field, data_dir, "All")

CPU times: total: 172 ms
Wall time: 7min 31s


In [32]:
# check the result
delay_result_all = spark.read.format("parquet").load(f"{data_dir}\delay_3w_events_All")
delay_result_all.show()

+---------------------+------+-------------+-------------------+-------------------+-------------------+-------------------+--------------+
|continuous_segment_id|    p1|           p2|           p1_start|             p1_end|           p2_start|             p2_end|max_difference|
+---------------------+------+-------------+-------------------+-------------------+-------------------+-------------------+--------------+
|                  317|  Airy|Artful_Dodger|2018-06-05 19:00:00|2018-06-19 11:00:00|2018-06-20 11:00:00|2018-07-04 15:00:00|         408.0|
|                  350|  Airy|Artful_Dodger|2018-07-31 15:00:00|2018-07-31 18:00:00|2018-08-19 18:00:00|2018-08-19 23:00:00|         462.0|
|                  353|  Airy|Artful_Dodger|2018-08-01 10:00:00|2018-08-01 13:00:00|2018-08-19 19:00:00|2018-08-19 21:00:00|         442.0|
|                  354|  Airy|Artful_Dodger|2018-08-01 17:00:00|2018-08-02 11:00:00|2018-08-19 18:00:00|2018-08-20 13:00:00|         448.0|
|                  3

In [33]:
# write as CSV?
delay_result_all.coalesce(1).write.option("header", True).csv(f"{data_dir_csv}\delay_3w_events_All_run3.csv", mode = "overwrite")