In [0]:
# Load every row of the scratchpad table `all_traj_data_test` into `all_df` and
# register it as the temp view `all_traj_data_test`.
query = """
select * from main_prod.datascience_scratchpad.all_traj_data_test
"""
all_df = spark.sql(query)
all_df.createOrReplaceTempView("all_traj_data_test")
display(all_df)

In [0]:
%sql
-- Preview 10 rows of the userpiphistory table.
select * from main_prod.datascience.userpiphistory limit 10

In [0]:
from pyspark.sql import functions as F

df = spark.read.table("main_prod.datascience.userpiphistory")

# Drop rows without a creation time so a null `createdon` never takes part in the max below.
df = df.filter(F.col("createdon").isNotNull())

# Per user, take the max of struct(createdon, timezone). Structs compare field by field,
# so this picks the row with the greatest `createdon` (ties broken by `timezone`) and
# carries that row's timezone along with it.
# NOTE(review): if a user's latest row has a null timezone, `latest_timezone` is null even
# when an older row had a value -- confirm that is intended.
userid_timezone_df = (
    df.groupBy("userid")
      .agg(F.max(F.struct("createdon", "timezone")).alias("maxrow"))
      .select(
          "userid",
          F.col("maxrow.createdon").alias("latest_time"),
          F.col("maxrow.timezone").alias("latest_timezone"),
      )
)

display(userid_timezone_df)

In [0]:
from pyspark.sql.functions import col

# Keep only users whose most recent record carries a timezone.
userid_timezone_df_fil = userid_timezone_df.filter(col("latest_timezone").isNotNull())
display(userid_timezone_df_fil)

In [0]:
# Inner-join all_df with userid_timezone_df_fil on userid. Rows of all_df whose userid has
# no entry in userid_timezone_df_fil are dropped by the inner join.
df_with_tz = all_df.join(userid_timezone_df_fil, on="userid", how="inner")
display(df_with_tz)

In [0]:
# Expose the joined frame to SQL under the name `all_traj_data_test_with_tz`.
df_with_tz.createOrReplaceTempView("all_traj_data_test_with_tz")


In [0]:
# Add `localized_ts`: the point's time shifted from UTC into the user's `latest_timezone`.
# `location_timestamp/1000` suggests the column holds epoch milliseconds -- TODO confirm.
# NOTE(review): from_unixtime formats using the Spark session time zone and works at
# second resolution, so this presumably relies on a UTC session time zone and drops
# sub-second precision -- confirm both are acceptable.
query = """
select *, from_utc_timestamp(from_unixtime(location_timestamp/1000), latest_timezone) as localized_ts from all_traj_data_test_with_tz
"""
df_loc_ts = spark.sql(query)
display(df_loc_ts)

In [0]:
# Keep rows with localized_ts on or after 2022-01-01 and strictly before 2025-07-23
# (the lower bound is inclusive, the upper bound is exclusive).
df_loc_ts_fil = df_loc_ts.filter((df_loc_ts.localized_ts >= '2022-01-01') & (df_loc_ts.localized_ts < '2025-07-23'))
display(df_loc_ts_fil)

In [0]:
# Expose the date-filtered frame to SQL under the name `all_traj_data_test_loc_ts`.
df_loc_ts_fil.createOrReplaceTempView("all_traj_data_test_loc_ts")

In [0]:

# Build one row per (userid, local calendar date) holding three parallel lists:
# the timestamps, latitudes and longitudes of that user's points on that date.
# Rows with a null userid, localized_ts, latitude or longitude are excluded first.
# COLLECT_LIST does not guarantee element order; the lists are sorted by timestamp in a
# later cell (sort_by_time).
query = """
SELECT 
    userid,
    DATE(localized_ts) AS traj_date,
    COLLECT_LIST(localized_ts) AS timestamps,
    COLLECT_LIST(latitude) AS latitudes,
    COLLECT_LIST(longitude) AS longitudes
FROM 
    all_traj_data_test_loc_ts
WHERE
    userid IS NOT NULL 
    AND localized_ts IS NOT NULL 
    AND latitude IS NOT NULL 
    AND longitude IS NOT NULL
GROUP BY 
    userid, DATE(localized_ts)
ORDER BY 
    userid, traj_date
"""

result_df = spark.sql(query)
display(result_df)

In [0]:
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StructType, StructField, ArrayType, DoubleType, DateType, TimestampType
def sort_by_time(timestamps, longitudes, latitudes):
    """Order one trajectory's points chronologically.

    The three lists are zipped position by position, sorted by timestamp (stable, so
    points sharing a timestamp keep their incoming order), and split back apart.

    Returns:
        A tuple ``(sorted_ts, polylines)`` where ``sorted_ts`` is the list of timestamps
        in ascending order and ``polylines`` is the matching list of ``[lon, lat]``
        pairs with both coordinates converted to ``float``.
    """
    ordered = sorted(zip(timestamps, longitudes, latitudes), key=lambda rec: rec[0])
    sorted_ts = [rec[0] for rec in ordered]
    polylines = [[float(rec[1]), float(rec[2])] for rec in ordered]
    return sorted_ts, polylines


# Wrap sort_by_time as a Spark UDF returning a struct with two fields:
# `sorted_ts` (array of timestamps) and `polylines` (array of [lon, lat] double pairs).
sort_and_extract_udf = udf(sort_by_time, 
                           StructType([
                               StructField("sorted_ts", ArrayType(TimestampType())),
                               StructField("polylines", ArrayType(ArrayType(DoubleType())))
                           ]))


# Argument order must match sort_by_time(timestamps, longitudes, latitudes).
result_df_v2 = result_df.withColumn("sorted_data", 
                                         sort_and_extract_udf("timestamps", "longitudes", "latitudes"))

# Flatten the struct into two top-level columns and drop the temporary struct column.
result_df_v2 = result_df_v2.withColumn("sorted_ts", col("sorted_data.sorted_ts")) \
                             .withColumn("polylines", col("sorted_data.polylines")) \
                             .drop("sorted_data")
display(result_df_v2)

In [0]:
# Drop the unsorted input lists; `sorted_ts` and `polylines` replace them.
result_df_v3 = result_df_v2.drop("timestamps", "longitudes", "latitudes")
display(result_df_v3)

In [0]:
from pyspark.sql.types import IntegerType, ArrayType, DoubleType, BooleanType
from pyspark.sql.functions import udf, col

import math
def lonlat2meters(lon, lat):
    """Project a longitude/latitude pair given in degrees to planar metres.

    Uses the spherical Mercator formulas with radius 6378137.0 m:
    ``x = R * lon_rad`` and ``y = (R / 2) * ln((1 + sin(lat_rad)) / (1 - sin(lat_rad)))``
    (3189068.5 is R / 2). A 1e-5 term is added to numerator and denominator, which keeps
    the logarithm finite at the poles.

    Returns:
        Tuple ``(x, y)`` of floats.
    """
    deg_to_rad = 0.017453292519943295
    sphere_radius = 6378137.0
    sin_lat = math.sin(lat * deg_to_rad)
    x = sphere_radius * (lon * deg_to_rad)
    y = 3189068.5 * math.log((1 + sin_lat + 1e-5) / (1 - sin_lat + 1e-5))
    return x, y


# UDF that projects every [lon, lat] point of a trajectory with lonlat2meters and returns
# the projected points as an array of [x, y] double pairs.
lonlat2meters_udf = udf(lambda traj: [list(lonlat2meters(p[0], p[1])) for p in traj], ArrayType(ArrayType(DoubleType())))
# `merc_seq` holds the projected counterpart of `polylines`, point for point.
result_df_v4 = result_df_v3.withColumn("merc_seq", lonlat2meters_udf(col("polylines")))
display(result_df_v4)

In [0]:
def filter_based_on_timestamps(ts_list):
    """Tell whether a trajectory covers enough distinct hours of the day.

    Args:
        ts_list: iterable of objects exposing an ``hour`` attribute (e.g. datetimes).

    Returns:
        True when more than 7 different ``hour`` values occur, otherwise False.
    """
    distinct_hours = {stamp.hour for stamp in ts_list}
    return len(distinct_hours) > 7

# Keep only trajectories whose sorted timestamps span more than 7 distinct hours of the day.
filter_based_on_timestamps_udf = udf(filter_based_on_timestamps, BooleanType())
result_df_v5 = result_df_v4.filter(filter_based_on_timestamps_udf(col("sorted_ts")))
display(result_df_v5)

In [0]:
# Rename columns for the output table: sorted_ts -> timestamps, polylines -> wgs_seq.

result_df_v6 = result_df_v5.withColumnRenamed("sorted_ts", "timestamps").withColumnRenamed("polylines", "wgs_seq")


# Persist the result; mode("overwrite") replaces any existing all_traj_test table.
result_df_v6.write.mode("overwrite").saveAsTable("main_prod.datascience_scratchpad.all_traj_test")

In [0]:
# Number of distinct users that still have at least one trajectory row.
result_df_v6.select("userid").distinct().count()

In [0]:
# Total number of (userid, traj_date) trajectory rows.
result_df_v6.count()

In [0]:
# Reload the table written above. NOTE: this rebinds `df`, which earlier in the notebook
# referred to the userpiphistory frame.
df = spark.read.table("main_prod.datascience_scratchpad.all_traj_test")

In [0]:
# Show the reloaded trajectory table.
display(df)

In [0]:

from pyspark.sql.types import StructType, StructField, ArrayType, DoubleType
from pyspark.sql.functions import col


def udf_get_min_max_lat_lon(wgs_seq):
    """Compute the latitude/longitude bounding box of one trajectory.

    Args:
        wgs_seq: sequence of ``[lon, lat]`` pairs (longitude first), or None.

    Returns:
        Tuple ``(min_lat, max_lat, min_lon, max_lon)`` of floats. For an empty or None
        sequence every element is None, so Spark stores null fields instead of the
        placeholder bounds (100000 / -100000), which would pass a bounding-box filter
        of the form ``min >= lower and max <= upper``.
    """
    if not wgs_seq:
        return None, None, None, None

    lons = []
    lats = []
    for lon, lat in wgs_seq:
        # Coerce to float so the values always match the DoubleType fields of the UDF schema.
        lons.append(float(lon))
        lats.append(float(lat))
    return min(lats), max(lats), min(lons), max(lons)

# Wrap the bounding-box helper as a UDF returning a struct of four doubles.
# NOTE: `udf` is not imported in this cell; it relies on the import made in an earlier cell.
get_min_max_lat_lon_udf = udf(udf_get_min_max_lat_lon, StructType([
    StructField("min_lat", DoubleType()),
    StructField("max_lat", DoubleType()),
    StructField("min_lon", DoubleType()),
    StructField("max_lon", DoubleType())
]))
# Compute the struct once per row, then flatten it into four columns and drop the struct.
# NOTE: `df` is reassigned in place, so re-running this cell operates on the already
# extended frame.
df = df.withColumn("min_max_lat_lon", get_min_max_lat_lon_udf(col("wgs_seq")))
df = df.withColumn("min_lat", col("min_max_lat_lon.min_lat")) \
       .withColumn("max_lat", col("min_max_lat_lon.max_lat")) \
       .withColumn("min_lon", col("min_max_lat_lon.min_lon")) \
       .withColumn("max_lon", col("min_max_lat_lon.max_lon")) \
       .drop("min_max_lat_lon")

display(df)


In [0]:
# Print the overall extent of the data: the smallest min_lat and min_lon and the largest
# max_lat and max_lon across all trajectories. Each line runs its own aggregation.
df.select("min_lat").agg({"min_lat": "min"}).show()
df.select("min_lon").agg({"min_lon": "min"}).show()
df.select("max_lat").agg({"max_lat": "max"}).show()
df.select("max_lon").agg({"max_lon": "max"}).show()


In [0]:
# Target bounding box in degrees.
# NOTE(review): these bounds look like the contiguous United States -- confirm the source.
target_min_lat = 25.11833
target_max_lat = 49.38447
target_min_lon = -124.73306
target_max_lon = -66.94978

# Keep a trajectory only when its whole bounding box lies inside the target box
# (both bounds inclusive on every side).
df_fil = df.filter((col("min_lat") >= target_min_lat) & (col("max_lat") <= target_max_lat) & (col("min_lon") >= target_min_lon) & (col("max_lon") <= target_max_lon))
df_fil.count()

In [0]:

# Persist the bounding-box-filtered trajectories; overwrites any existing all_traj_test_v2 table.
df_fil.write.mode("overwrite").saveAsTable("main_prod.datascience_scratchpad.all_traj_test_v2")

In [0]:
# Number of distinct users remaining after the bounding-box filter.
df_fil.select("userid").distinct().count()

In [0]:
# Randomly pick about 4% of the distinct userids in df_fil for the validation split.
# `sample` keeps each row independently with the given probability, so the number of
# users returned is approximate rather than an exact count.
# A fixed seed makes the draw repeatable across reruns, as far as Spark's row ordering
# after `distinct()` stays the same.

val_userids = df_fil.select("userid").distinct().sample(False, 0.04, seed=42).collect()

In [0]:
# Keep only the trajectories that belong to the sampled validation users.
# `val_userids` is a list of collected Rows, so the ids are read via `x.userid`.
val_df = df_fil.filter(col("userid").isin([x.userid for x in val_userids]))
display(val_df)

In [0]:
# Randomly keep about 20% of the validation users' trajectories (fraction 0.2, without
# replacement). A fixed seed keeps the subset repeatable across reruns.
# NOTE: `val_df` is reassigned in place, so re-running this cell samples the already
# reduced frame again.
val_df = val_df.sample(False, 0.2, seed=42)
val_df.count()

In [0]:
# Persist the validation subset; overwrites any existing all_traj_val table.
val_df.write.mode("overwrite").saveAsTable("main_prod.datascience_scratchpad.all_traj_val")

In [0]:
# Row count of the validation frame (recomputed from the lazy plan, not read from the table).
val_df.count()

In [0]:
# Load the training table all_traj_train_v2.
# NOTE(review): this cell only loads and displays the table; no check of whether the
# val_df userids also occur in df_train is performed here.

df_train = spark.read.table("main_prod.datascience_scratchpad.all_traj_train_v2")
display(df_train)


In [0]:
from pyspark.sql.functions import udf, col
from pyspark.sql.types import IntegerType

import math
def lonlat2meters(lon, lat):
    """Convert a lon/lat pair in degrees to spherical-Mercator metres.

    This repeats the definition given earlier in the notebook (same constants and
    formula) and rebinds the name when the cell runs.

    ``x`` is the longitude in radians scaled by the sphere radius 6378137.0; ``y`` is
    half that radius (3189068.5) times ``ln((1 + sin(lat)) / (1 - sin(lat)))``, with
    1e-5 added to both numerator and denominator.

    Returns:
        Tuple ``(x, y)`` of floats.
    """
    radians_per_degree = 0.017453292519943295
    radius_m = 6378137.0
    lon_rad = lon * radians_per_degree
    lat_rad = lat * radians_per_degree
    s = math.sin(lat_rad)
    ratio = (1 + s + 1e-5) / (1 - s + 1e-5)
    return radius_m * lon_rad, 3189068.5 * math.log(ratio)

In [0]:

def get_xyidx_by_point(x, y):
    """Map projected coordinates to integer grid indices.

    The grid origin is (-13885220.72428684, 2890285.81199535) and each cell is
    150000 x 150000 units. The offset from the origin is truncated toward zero with
    ``int`` and then floor-divided by the cell size.

    NOTE(review): the origin appears to be the projection of the target bounding box's
    south-west corner -- confirm.

    Returns:
        Tuple ``(i_x, i_y)`` of ints.
    """
    origin_x, origin_y = -13885220.72428684, 2890285.81199535
    cell_w = cell_h = 150000
    col_idx = int(x - origin_x) // cell_w
    row_idx = int(y - origin_y) // cell_h
    return (col_idx, row_idx)

def get_cellid_by_xyidx(i_x: int, i_y: int):
    """Flatten a pair of grid indices into a single cell id.

    The number of rows per column is ``ceil((y_max - y_min) / 150000)`` for
    ``y_min = 2890285.81199535`` and ``y_max = 6340351.496013149``; the id is
    ``i_x * rows + i_y``.

    Returns:
        The cell id as an int.
    """
    lower_y, upper_y = 2890285.81199535, 6340351.496013149
    cell_h = 150000
    rows_per_column = int(math.ceil((upper_y - lower_y) / cell_h))
    return i_x * rows_per_column + i_y

def get_cellid_by_point(lon,lat):
    """Return the grid cell id of a lon/lat point given in degrees.

    Chains the three helpers: lonlat2meters projects the point, get_xyidx_by_point
    turns the projected coordinates into grid indices, and get_cellid_by_xyidx
    flattens those indices into one id.
    """
    projected = lonlat2meters(lon, lat)
    grid_idx = get_xyidx_by_point(*projected)
    return get_cellid_by_xyidx(*grid_idx)


# Wrap the cell lookup as an integer-returning UDF.
get_cellid_by_point_udf = udf(get_cellid_by_point, IntegerType())

# Assign each training trajectory to the grid cell that contains its (min_lon, min_lat)
# corner, i.e. the south-west corner of its bounding box.
df_train = df_train.withColumn("cellid", get_cellid_by_point_udf(col("min_lon"), col("min_lat")))

display(df_train)




In [0]:
# Collect the distinct cell ids to the driver as a list of Rows (the id is field 0).
cell_ids = df_train.select("cellid").distinct().collect()

In [0]:
# Number of distinct cells present in the training data.
len(cell_ids)

In [0]:
# Repartition by cellid so rows of the same cell end up in the same partition.
df_repart = df_train.repartition("cellid")
display(df_repart)

In [0]:
from tqdm import tqdm

# Write one table per cell id, named all_traj_train_v2_cell_<cellid>.
# `cell_ids` holds Rows here, so the id itself is `cell_id[0]`.
# Each iteration filters the repartitioned frame and overwrites that cell's table.
for cell_id in tqdm(cell_ids):
    df_train_fil = df_repart.filter(col("cellid") == cell_id[0])
    df_train_fil.write.mode("overwrite").saveAsTable("main_prod.datascience_scratchpad.all_traj_train_v2_cell_{}".format(cell_id[0]))


In [0]:
# Unwrap the collected Rows into a plain list of cell id values so it can be saved.

cell_ids_list = [cell_id[0] for cell_id in cell_ids]

In [0]:
import pickle

# Save the list of cell ids to the workspace file.
# The context manager guarantees the file handle is flushed and closed, even if
# pickle.dump raises; the bare open(...) used before left the handle unclosed.
with open("/Workspace/Users/jatin.agrawal@earnin.com/TrajCL/data/cell_ids_list.pkl", "wb") as fh:
    pickle.dump(cell_ids_list, fh)

In [0]:
import pickle

# Reload the saved cell ids. The context manager closes the file handle once the load
# finishes; the bare open(...) used before left it unclosed.
# NOTE: pickle.load can execute arbitrary code, so only load files this notebook wrote.
# NOTE: this rebinds `cell_ids` from a list of Rows to the plain list that was saved.
with open("/Workspace/Users/jatin.agrawal@earnin.com/TrajCL/data/cell_ids_list.pkl", "rb") as fh:
    cell_ids = pickle.load(fh)

cell_ids

In [0]:
from tqdm import tqdm
# DESTRUCTIVE: drops every per-cell table all_traj_train_v2_cell_<cellid>, i.e. the tables
# written by the per-cell write loop above. `cell_ids` is expected to be the plain list
# reloaded from the pickle file, so each element is used directly in the table name.
for cell_id in cell_ids:
    # remove the table
    print(cell_id)
    spark.sql("DROP TABLE IF EXISTS main_prod.datascience_scratchpad.all_traj_train_v2_cell_{}".format(cell_id))

In [0]:
# List the scratchpad schema's tables and keep the names starting with "all_traj_train_cell_".
# NOTE(review): the per-cell tables written and dropped above use the prefix
# "all_traj_train_v2_cell_", which this filter does not match -- confirm the prefix is intended.
tables = spark.catalog.listTables("main_prod.datascience_scratchpad")
matching_tables = [t.name for t in tables if t.name.startswith("all_traj_train_cell_")]

In [0]:
# Refresh the table listing for the scratchpad schema (same call as in the previous cell).
tables = spark.catalog.listTables("main_prod.datascience_scratchpad")