In [0]:
from pyspark.sql.functions import col, udf
from pyspark.sql.types import BooleanType, ArrayType, DoubleType

In [0]:
# Source table of raw trajectories; later cells derive every column from it.
df = spark.read.table("main_prod.datascience_scratchpad.nyc_test_df_v2")
display(df)

In [0]:
# Rename the point-sequence column to the name used below and drop rows without one.
df = (
    df.withColumnRenamed("polylines", "wgs_seq")
      .where(col("wgs_seq").isNotNull())
)
display(df)

In [0]:
from pyspark.sql.functions import expr

# Number of points in each raw trajectory.
df = df.withColumn(
    "trajlen",
    expr("size(wgs_seq)"),
)
display(df)

In [0]:
# Keep only trajectories with at least 20 raw points.
df = df.where(col("trajlen") >= 20)
display(df)

In [0]:
import math


def lonlat2meters(lon, lat):
    """Project a (lon, lat) pair given in degrees to spherical-Mercator metres.

    Computes ``x = R * lon_rad`` and
    ``y = (R / 2) * ln((1 + sin(lat_rad)) / (1 - sin(lat_rad)))``
    with ``R = 6378137.0``.

    Returns:
        Tuple ``(x, y)`` of floats.

    Raises:
        ZeroDivisionError at lat == 90 and ValueError at lat == -90, where the
        projection is unbounded.
    """
    radius = 6378137.0
    lon_rad = math.radians(lon)
    sin_lat = math.sin(math.radians(lat))
    x = radius * lon_rad
    y = (radius / 2) * math.log((1 + sin_lat) / (1 - sin_lat))
    return x, y

In [0]:
# Project every point of a trajectory to Mercator metres.
# NOTE(review): assumes each point is [lon, lat, ...] - confirm the source layout.
lonlat2meters_udf = udf(
    lambda traj: [list(lonlat2meters(pt[0], pt[1])) for pt in traj],
    ArrayType(ArrayType(DoubleType())),
)
df = df.withColumn("merc_seq", lonlat2meters_udf(col("wgs_seq")))
display(df)

In [0]:
from pyspark.sql.functions import unix_timestamp, col
from pyspark.sql.types import IntegerType

def get_distant_ts(timestamps, min_gap_seconds=600):
    """Pick indices of points that are more than ``min_gap_seconds`` apart.

    Index 0 is always kept. Walking forward, index ``i`` is kept as soon as
    the time elapsed since the most recently kept timestamp is strictly
    greater than ``min_gap_seconds``.

    Args:
        timestamps: Sequence of ``datetime`` values, one per trajectory point
            (presumably chronological - TODO confirm upstream ordering).
            ``None`` (a Spark NULL array) or an empty sequence yields ``[]``.
        min_gap_seconds: Strict lower bound, in seconds, on the gap between
            consecutive kept points. Defaults to 600 (10 minutes).

    Returns:
        List of int indices into ``timestamps``.
    """
    # NULL or empty input: nothing to keep. Returning [0] here would point
    # past the end of an empty array in the downstream index filters.
    if timestamps is None or len(timestamps) == 0:
        return []
    final_indices = [0]
    last_kept = timestamps[0]
    for i in range(1, len(timestamps)):
        # Summing consecutive deltas telescopes to t[i] - t[last_kept];
        # computing it directly avoids accumulated float rounding.
        if (timestamps[i] - last_kept).total_seconds() > min_gap_seconds:
            final_indices.append(i)
            last_kept = timestamps[i]
    return final_indices

# Spark wrapper around get_distant_ts: per row, the indices of the points to keep.
get_distant_ts_udf = udf(get_distant_ts, ArrayType(IntegerType()))

df_filtered = df.withColumn(
    "final_indices",
    get_distant_ts_udf(col("timestamps")),
)
display(df_filtered)




In [0]:
from pyspark.sql.types import ArrayType, DoubleType, TimestampType
def filter_by_index(input_list, indices):
    """Return the elements of ``input_list`` at ``indices``, in index order.

    Args:
        input_list: Sequence to pick from, or ``None`` (a Spark NULL array).
        indices: Iterable of int positions into ``input_list``, or ``None``.

    Returns:
        A new list with ``input_list[i]`` for each ``i`` in ``indices``. Returns
        ``None`` when either argument is ``None``, so a NULL array column
        produces a NULL result instead of failing the whole Spark job.

    Raises:
        IndexError: if an index is out of range. This is deliberately not
            swallowed, because it signals misaligned per-point columns.
    """
    if input_list is None or indices is None:
        return None
    return [input_list[i] for i in indices]

# One UDF per output column: each keeps only the points selected in final_indices.
wgs_seq_filter_udf = udf(filter_by_index, ArrayType(ArrayType(DoubleType())))
timestamps_filter_udf = udf(filter_by_index, ArrayType(TimestampType()))
merc_seq_filter_udf = udf(filter_by_index, ArrayType(ArrayType(DoubleType())))

df_filtered = (
    df_filtered
    .withColumn("wgs_seq_filtered", wgs_seq_filter_udf(col("wgs_seq"), col("final_indices")))
    .withColumn("timestamps_filtered", timestamps_filter_udf(col("timestamps"), col("final_indices")))
    .withColumn("merc_seq_filtered", merc_seq_filter_udf(col("merc_seq"), col("final_indices")))
)
display(df_filtered)

In [0]:
from pyspark.sql.functions import expr

# Length of the downsampled trajectory; replaces the earlier raw-length trajlen.
df_filtered = df_filtered.withColumn(
    "trajlen", expr("size(wgs_seq_filtered)")
)
display(df_filtered)

In [0]:
# Keep downsampled trajectories with 20 to 100 points (inclusive on both ends).
df_filtered = df_filtered.where(col("trajlen").between(20, 100))
display(df_filtered)

In [0]:
# Number of rows left in df_filtered after the downsampled-length filter (runs a Spark action).
df_filtered.count()

In [0]:
# Output schema: identifiers, the downsampled sequences, and their length.
df_final = df_filtered.select(
    "userid",
    "traj_id",
    "timestamps_filtered",
    "wgs_seq_filtered",
    "merc_seq_filtered",
    "trajlen",
)
display(df_final)

In [0]:
# Persist the result, replacing any existing table of the same name.
df_final.write.saveAsTable(
    "main_prod.datascience_scratchpad.nyc_test_filtered",
    mode="overwrite",
)

In [0]:
# # Repartition the DataFrame to reduce memory load on each partition
# df_final_repartitioned = df_final.repartition(200)  # Adjust the number of partitions as needed

# # Write the DataFrame to S3
# df_final_repartitioned.write.mode("overwrite").parquet(
#     "s3://earnin-prod-datalake-us-west-2-dl-scratchpad/jatin/trajcl-exp/train_filtered"
# )