In [0]:
!pip install folium
import folium

from folium import plugins

def display_traj(lon_lat_list, time_indices=None, save_path="map_with_arrows.html"):
    """Plot one trajectory on a folium map with start/end markers and direction arrows.

    Args:
        lon_lat_list: Sequence of points given as ``(lon, lat)`` pairs; they are
            swapped to folium's ``[lat, lon]`` order.
        time_indices: Optional sequence aligned with ``lon_lat_list`` whose values
            are shown in the marker popups. When None, popups show coordinates only.
        save_path: HTML file the map is written to (default keeps the previous
            ``map_with_arrows.html``); pass None to skip writing a file.

    Returns:
        The ``folium.Map``, which renders inline when it is the cell's last expression.

    Raises:
        ValueError: If ``lon_lat_list`` is empty (nothing to center the map on).
    """
    coordinates = [[point[1], point[0]] for point in lon_lat_list]
    if not coordinates:
        raise ValueError("lon_lat_list is empty; nothing to display")

    has_times = time_indices is not None
    m = folium.Map(location=coordinates[0], zoom_start=15)

    # Start (green) and end (red) markers. These used to index time_indices
    # unconditionally, so the default time_indices=None raised TypeError.
    for idx, color in ((0, "green"), (-1, "red")):
        folium.Marker(
            location=coordinates[idx],
            popup=time_indices[idx] if has_times else tuple(coordinates[idx]),
            icon=folium.Icon(color=color),
        ).add_to(m)

    # Intermediate points; with timestamps the popup also shows the point's sequence index.
    for idx, (lat, lon) in enumerate(coordinates[1:-1], start=1):
        popup = (idx, time_indices[idx], lat, lon) if has_times else (lat, lon)
        folium.Marker([lat, lon], popup=popup).add_to(m)

    # One line segment per consecutive pair, each decorated with arrows along its direction.
    for start, end in zip(coordinates, coordinates[1:]):
        segment = folium.PolyLine([start, end], color="blue", weight=3, opacity=0.7).add_to(m)
        plugins.PolyLineTextPath(
            segment,
            "\u27a4",  # arrow glyph; escaped because the literal had been mis-encoded as 'âž¤'
            repeat=True,
            offset=7,
            attributes={"fill": "blue", "font-weight": "bold", "font-size": "16"},
        ).add_to(m)

    if save_path is not None:
        m.save(save_path)
    return m


In [0]:
# Static/moving predictions from the relaxed-condition backfill (v2).
BACKFILL_STATIC_MOVING_PATH = "/Volumes/main_prod/datascience_scratchpad/jatin/trajcl_exp/usa/backfill_static_moving_relax_condition_v2"
df = spark.read.parquet(BACKFILL_STATIC_MOVING_PATH)

In [0]:
# Preview the predictions, then show the total row count as the cell output.
display(df)
total_rows = df.count()
total_rows


In [0]:
# Share of predictions classified as "moving".
moving_rows = df.where('predicted_work_type = "moving"').count()
moving_rows / df.count()

In [0]:
# imp points: 150862, 228036

In [0]:
moving_df = df.where('predicted_work_type = "moving"')
display(moving_df)

In [0]:
single_user_df = df.where("userid  = 6771816")
display(single_user_df)

In [0]:
%sql
-- Dwell-time rows for one user, most recent calc_date first.
SELECT *
FROM main_prod.ml_data.cm_work_dwell_time_v3_2
WHERE userid = 15776078
ORDER BY calc_date DESC

In [0]:
data_dir = '/Volumes/main_prod/datascience_scratchpad/jatin/trajcl_exp/usa/backfill_traj_data_relaxed'

# Main trajectory table plus the relaxed backfill, which is stored as Delta files.
traj_df = spark.read.table('main_prod.ml_data.traj_data')
traj_df_v2 = spark.read.load(data_dir, format='delta')
display(traj_df)

In [0]:
user_traj_df = traj_df.where('userid = 4515091').orderBy("traj_date", ascending=False)
display(user_traj_df)

In [0]:
from pyspark.sql.functions import col
def get_candidate_df(userid):
    """Return every trajectory row for ``userid`` from both trajectory sources.

    Rows come from ``traj_df`` (main table) and ``traj_df_v2`` (relaxed backfill);
    rows present in both are kept twice (no de-duplication).

    ``unionByName`` aligns columns by name. Plain ``union`` matches them by
    position and would silently mix columns if the two sources ordered them
    differently.
    """
    user_filter = col("userid") == userid
    return traj_df.filter(user_filter).unionByName(traj_df_v2.filter(user_filter))

In [0]:
userid = 17049707
traj_dates = ["2024-08-01", "2024-08-02", "2024-08-03", "2024-08-04"]
candidate_df = get_candidate_df(userid)

list_lng_lat_list = []
time_indices_list = []
for traj_date in traj_dates:
    # Read coordinates and timestamps from the SAME row. Two separate queries could
    # return different rows when a date appears in both unioned trajectory sources.
    day_row = (
        candidate_df.where("traj_date = '{}'".format(traj_date))
        .select("wgs_seq", "sorted_ts")
        .first()
    )
    if day_row is None:
        raise ValueError("No trajectory for userid={} on {}".format(userid, traj_date))
    list_lng_lat_list.append(day_row["wgs_seq"])
    time_indices_list.append(day_row["sorted_ts"])
display(candidate_df)

In [0]:
display_traj(lon_lat_list=list_lng_lat_list[0], time_indices=time_indices_list[0])

In [0]:


display_traj(lon_lat_list=list_lng_lat_list[1], time_indices=time_indices_list[1])

In [0]:
display_traj(lon_lat_list=list_lng_lat_list[2], time_indices=time_indices_list[2])

In [0]:
display_traj(lon_lat_list=list_lng_lat_list[3], time_indices=time_indices_list[3])

In [0]:
# Trajectory embeddings, used by the similarity and clustering cells below.
emb_df = spark.table('main_prod.ml_data.traj_emb')
display(emb_df)

In [0]:
import numpy as np

def cosine_similarity(vec1, vec2):
    """Return the cosine of the angle between ``vec1`` and ``vec2``.

    A zero-length vector makes the denominator 0, so numpy yields nan/inf with a
    RuntimeWarning.
    """
    magnitude_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    return np.dot(vec1, vec2) / magnitude_product

def calculate_similarity(embeddings):
    """Cosine similarity for every unordered pair of embeddings.

    Returns:
        Dict keyed by index pairs ``(i, j)`` with ``i < j``, inserted in
        row-major order: (0, 1), (0, 2), ..., (1, 2), ...
    """
    count = len(embeddings)
    return {
        (left, right): cosine_similarity(embeddings[left], embeddings[right])
        for left in range(count)
        for right in range(left + 1, count)
    }

def get_emb_df(userid, traj_dates):
    """Fetch ``userid``'s trajectory embeddings for the given dates, ordered by traj_date.

    Args:
        userid: User id to filter ``emb_df`` on.
        traj_dates: Iterable of dates (e.g. 'YYYY-MM-DD' strings) to keep.

    Returns:
        List of numpy arrays, one per matching ``emb_df`` row. The list is empty
        when ``traj_dates`` is empty or nothing matches.
    """
    from pyspark.sql.functions import col

    # Column expressions instead of a formatted SQL string: no quoting issues, and
    # an empty date list filters everything out instead of failing to parse `in ()`.
    user_emb_df = emb_df.where(
        (col("userid") == userid) & col("traj_date").isin(list(traj_dates))
    ).orderBy("traj_date")
    return [np.array(row["embedding"]) for row in user_emb_df.select("embedding").collect()]

# Pairwise cosine similarity between the selected days' embeddings.
embeddings = get_emb_df(userid=userid, traj_dates=traj_dates)
calculate_similarity(embeddings=embeddings)

In [0]:
from sklearn.metrics.pairwise import pairwise_distances
import numpy as np

def neighbor_counts(embs, target_min_similarity=0.9):
    """Count neighbors of each embedding within the cosine radius used for DBSCAN.

    The radius is ``1 - target_min_similarity`` in cosine distance, and each point
    is counted as its own neighbor (its diagonal distance).

    Returns:
        ``(counts, distances)``: per-row neighbor counts and the full pairwise
        cosine-distance matrix.
    """
    distances = pairwise_distances(embs, metric="cosine")
    within_radius = distances <= (1.0 - target_min_similarity)
    return within_radius.sum(axis=1), distances

N, D = neighbor_counts(embeddings, target_min_similarity=0.9)
max_neighbors, median_neighbors = N.max(), np.median(N)
print("Per-point neighbor counts (including self):", N)
print("Max neighbors:", max_neighbors, "Median:", median_neighbors)

In [0]:
from sklearn.neighbors import NearestNeighbors
import numpy as np
from sklearn.cluster import DBSCAN 
def apply_dbscan(embs, paycycle_len, target_min_similarity=0.9):
    """Fit cosine-distance DBSCAN on a 2-D array of embeddings.

    ``eps`` is ``1 - target_min_similarity``. ``min_samples`` is 20% of the number
    of embeddings (truncated to an int), but never below 3. ``paycycle_len`` is
    accepted but not used.

    Returns:
        The fitted ``DBSCAN`` estimator.
    """
    floor_min_samples = 3
    cosine_eps = 1.0 - target_min_similarity
    scaled_min_samples = int(embs.shape[0] * 0.2)
    model = DBSCAN(
        eps=cosine_eps,
        min_samples=max(scaled_min_samples, floor_min_samples),
        metric="cosine",
        n_jobs=-1,
    )
    return model.fit(embs)

def get_cluster(df, paycycle_len, return_details=False):
    """Fit DBSCAN on the ``embedding`` column of a pandas DataFrame.

    Args:
        df: pandas DataFrame with one embedding vector per row in ``embedding``.
            It also needs a ``traj_date`` column when ``return_details`` is True.
        paycycle_len: Passed through to ``apply_dbscan``.
        return_details: When True, also return the per-cluster member counts and
            the per-date cluster labels.

    Returns:
        The fitted DBSCAN model by default. With ``return_details=True``, the tuple
        ``(db, cluster_dict, date_label_dict)``:
        - ``cluster_dict`` maps label -> number of member rows.
        - ``date_label_dict`` maps each clustered row's ``traj_date`` -> label.
        Noise rows (label -1) are excluded from both dicts.
    """
    embeddings = np.stack(df['embedding'].values)
    db = apply_dbscan(embeddings, paycycle_len)
    if not return_details:
        return db

    cluster_dict = {}
    date_label_dict = {}
    # Pair each row's own date with its label. The previous code keyed every entry
    # by df['traj_date'].iloc[0], so the mapping collapsed to a single date.
    for traj_date, label in zip(df['traj_date'], db.labels_):
        if label == -1:
            continue
        label = int(label)
        cluster_dict[label] = cluster_dict.get(label, 0) + 1
        date_label_dict[traj_date] = label
    return db, cluster_dict, date_label_dict


def is_static(cluster_dict, traj_count, min_fraction=0.7):
    """Return True when enough trajectories fall inside some DBSCAN cluster.

    Args:
        cluster_dict: Mapping of cluster label -> number of member trajectories
            (noise excluded).
        traj_count: Total number of trajectories that were clustered.
        min_fraction: Required share of clustered trajectories. The threshold is
            ``int(min_fraction * traj_count)``, i.e. truncated as before.

    Returns:
        True if at least one trajectory is clustered and the clustered total
        reaches the threshold. The "at least one" guard matters for traj_count of
        0 or 1: the threshold truncates to 0 there, so a group with no clustered
        trajectory at all was previously reported as static.
    """
    clustered = sum(cluster_dict.values())
    return bool(clustered > 0 and clustered >= int(min_fraction * traj_count))

def cluster_exists(db):
    """Return True if the fitted DBSCAN model put at least one point in a cluster (label != -1)."""
    return any(label != -1 for label in db.labels_)


def dbscan_predict_all(db, X_train, X_new):
    """Give each row of ``X_new`` the label of its nearest clustered training point.

    ``X_train`` must be the data ``db`` was fitted on, in the same row order as
    ``db.labels_``. A new point takes the label of the closest training point
    that meets both conditions:
    - it lies within ``db.eps``, using ``db.metric``;
    - its label is not -1.
    Points with no such neighbor get -1.
    """
    train_labels = db.labels_
    searcher = NearestNeighbors(radius=db.eps, metric=db.metric).fit(X_train)
    neighbor_dists, neighbor_idxs = searcher.radius_neighbors(X_new, return_distance=True)
    predictions = np.full(len(X_new), -1, dtype=int)
    for row, (dists, idxs) in enumerate(zip(neighbor_dists, neighbor_idxs)):
        if len(idxs) == 0:
            continue
        labels = train_labels[idxs]
        clustered = labels != -1
        if not clustered.any():
            continue
        predictions[row] = labels[clustered][np.argmin(dists[clustered])]
    return predictions
    
def works_on_weekends_fn(db, weekday_df, weekend_df):
    """Decide whether weekend trajectories land in the weekday clusters of ``db``.

    ``db`` must have been fitted on ``weekday_df``'s embeddings, in the same row
    order.

    Returns:
        False when either frame has no rows. Otherwise True when strictly more
        than 40% of the weekend embeddings are assigned to a weekday cluster.
    """
    if min(len(weekday_df), len(weekend_df)) == 0:
        return False
    predicted = dbscan_predict_all(
        db,
        np.stack(weekday_df['embedding'].values),
        np.stack(weekend_df['embedding'].values),
    )
    assigned = np.count_nonzero(predicted != -1)
    return assigned > 0.4 * len(predicted)


from pyspark.sql import functions as F, types as T
import pandas as pd

# -----------------------
# Helper functions (defined above in this notebook) that must be importable on workers:
#   get_cluster(df, paycycle_len) -> fitted DBSCAN model
#   is_static(cluster_dict, traj_count) -> bool
#   cluster_exists(db) -> bool
#   works_on_weekends_fn(db, weekday_df, weekend_df) -> bool
# Ensure they are defined in the same file or available on PYTHONPATH for executors.
# -----------------------

# Output schema for compute_static_moving (adjust types if your real types differ).
# Each entry is (column name, Spark type, nullable).
_OUT_SCHEMA_FIELDS = [
    ("userid", T.IntegerType(), False),
    ("employerid", T.IntegerType(), False),
    ("employername", T.StringType(), True),
    ("predicted_work_type", T.StringType(), True),
    ("predicted_on", T.DateType(), True),
    ("works_on_weekends", T.BooleanType(), True),
    ("paydate", T.DateType(), True),
    ("old_latest_date_label_dict", T.MapType(T.DateType(), T.IntegerType()), True),
    ("new_latest_date_label_dict", T.MapType(T.DateType(), T.IntegerType()), True),
    ("old_latest_weekday_df_length", T.LongType(), True),
    ("latest_weekday_df_length", T.LongType(), True),
]
out_schema = T.StructType(
    [T.StructField(name, dtype, nullable) for name, dtype, nullable in _OUT_SCHEMA_FIELDS]
)

def _summarize_cluster_labels(traj_dates, labels):
    """Tally the DBSCAN labels of one clustering run.

    Args:
        traj_dates: Iterable of trajectory dates, row-aligned with ``labels``.
        labels: DBSCAN labels (``db.labels_``); -1 marks noise.

    Returns:
        ``(cluster_dict, date_label_dict)``: label -> member count, and
        traj_date -> label for every clustered row. Noise rows are skipped.
    """
    cluster_dict = {}
    date_label_dict = {}
    for traj_date, label in zip(traj_dates, labels):
        if label == -1:
            continue
        label = int(label)  # plain int for the IntegerType map values in out_schema
        cluster_dict[label] = cluster_dict.get(label, 0) + 1
        date_label_dict[traj_date] = label
    return cluster_dict, date_label_dict


def compute_static_moving(group_pdf: pd.DataFrame) -> pd.DataFrame:
    """
    Classify one (userid, employerid, employername) group as "static" or "moving" per pay cycle.

    Executes on a Spark worker (presumably via applyInPandas with ``out_schema``).
    For each pay cycle ``rn`` in 1..80 that has weekday trajectories:
    1. Cluster (get_cluster / DBSCAN) the weekday embeddings of that cycle together
       with the ``past_cycle_to_consider`` older cycles. If enough of them cluster,
       the cycle is "static".
    2. Otherwise cluster only that cycle's weekday embeddings; if those are not
       static either, the cycle is "moving".

    Expected columns: userid, employerid, rn, paydate, prev_paydate, embedding,
    traj_date, and optionally employername and weekday. weekday is derived from
    traj_date when absent.

    Returns:
        One row per processed cycle with the columns of ``out_schema``, or an empty
        frame with those columns when no cycle had weekday data.

    Raises:
        ValueError: if neither ``weekday`` nor ``traj_date`` is present.
    """
    import datetime
    today = datetime.date.today()

    # A single group is processed here, so the group keys are constant.
    userid = group_pdf["userid"].iloc[0]
    employerid = group_pdf["employerid"].iloc[0]
    employername = group_pdf.get("employername", pd.Series([None])).iloc[0]

    if "weekday" not in group_pdf.columns:
        if "traj_date" in group_pdf.columns:
            group_pdf = group_pdf.copy()
            group_pdf["weekday"] = pd.to_datetime(group_pdf["traj_date"]).dt.weekday
        else:
            raise ValueError("Missing 'weekday' column and no 'traj_date' to compute it from.")

    # Pay-cycle length in days, taken from the last row. Shorter cycles look back over
    # more past cycles (presumably to cover a similar span of history).
    paycycle_len = (group_pdf["paydate"].iloc[-1] - group_pdf["prev_paydate"].iloc[-1]).days
    if paycycle_len <= 8:
        past_cycle_to_consider = 4
    elif paycycle_len <= 15:
        past_cycle_to_consider = 2
    else:
        past_cycle_to_consider = 1
    target_n_cycles = 80

    weekday_mask = group_pdf["weekday"].isin([0, 1, 2, 3, 4])
    weekend_mask = group_pdf["weekday"].isin([5, 6])
    outputs = []

    for target_paycycle in range(1, target_n_cycles + 1):
        latest_mask = group_pdf["rn"] == target_paycycle
        old_mask = (group_pdf["rn"] > target_paycycle) & (
            group_pdf["rn"] <= target_paycycle + past_cycle_to_consider
        )

        latest_weekday_df = group_pdf.loc[weekday_mask & latest_mask]
        if len(latest_weekday_df) == 0:
            continue
        latest_weekend_df = group_pdf.loc[weekend_mask & latest_mask]
        old_weekday_df = group_pdf.loc[weekday_mask & old_mask]
        old_weekend_df = group_pdf.loc[weekend_mask & old_mask]

        old_latest_df = pd.concat([latest_weekday_df, old_weekday_df], ignore_index=True)
        # get_cluster(df, paycycle_len) returns the fitted model; the per-label and
        # per-date summaries are derived here from its labels.
        old_latest_cluster = get_cluster(old_latest_df, paycycle_len)
        old_latest_cluster_dict, old_latest_date_label_dict = _summarize_cluster_labels(
            old_latest_df["traj_date"], old_latest_cluster.labels_
        )

        latest_date_label_dict = {}
        if is_static(old_latest_cluster_dict, len(old_latest_df)):
            work_type = "static"
            works_on_weekends = works_on_weekends_fn(
                old_latest_cluster,
                old_latest_df,
                pd.concat([latest_weekend_df, old_weekend_df], ignore_index=True),
            )
        else:
            # pd.concat requires an iterable of frames; a single frame only needs a fresh index.
            latest_df = latest_weekday_df.reset_index(drop=True)
            latest_cluster = get_cluster(latest_df, paycycle_len)
            latest_cluster_dict, latest_date_label_dict = _summarize_cluster_labels(
                latest_df["traj_date"], latest_cluster.labels_
            )
            if is_static(latest_cluster_dict, len(latest_df)):
                work_type = "static"
                works_on_weekends = works_on_weekends_fn(
                    latest_cluster, latest_df, latest_weekend_df
                )
            else:
                work_type = "moving"
                works_on_weekends = False

        # Representative paydate: the first paydate among the latest cycle's weekday rows.
        paydate_value = latest_weekday_df["paydate"].iloc[0]

        outputs.append({
            "userid": userid,
            "employerid": employerid,
            "employername": employername,
            "predicted_work_type": work_type,
            "predicted_on": today,
            "works_on_weekends": works_on_weekends,
            "paydate": paydate_value,
            "old_latest_date_label_dict": old_latest_date_label_dict,
            # Key must match the out_schema column name, otherwise the column is dropped.
            "new_latest_date_label_dict": latest_date_label_dict,
            "old_latest_weekday_df_length": len(old_latest_df),
            "latest_weekday_df_length": len(latest_weekday_df),
        })

    # An empty `outputs` still yields a frame with the schema's columns.
    return pd.DataFrame(outputs, columns=[f.name for f in out_schema])

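# -----------------------
# Run on a Spark DataFrame, presumably grouped by (userid, employerid, employername) and
# passed to applyInPandas(compute_static_moving, schema=out_schema).
# Assume `df` is the Spark DataFrame version of df_pd and has the columns used above:
# userid, employerid, employername, weekday, rn, paydate, prev_paydate, embedding, and
# traj_date (needed for the date -> cluster-label maps, and to compute weekday if it is missing).
# -----------------------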



In [0]:
emb_window_sdf = emb_df.where('userid = 19822535 and traj_date > "2025-08-27" and traj_date <= "2025-09-27"')
emb_df_fil = emb_window_sdf.toPandas()
display(emb_df_fil)

In [0]:
paycycle_len: int = 14  # days between paydates (as computed in compute_static_moving)

In [0]:
db = get_cluster(df=emb_df_fil, paycycle_len=paycycle_len)

In [0]:
# Dates whose trajectory DBSCAN placed in a cluster (noise label -1 dropped).
new_dates = [
    traj_date
    for traj_date, label in zip(emb_df_fil['traj_date'].values, db.labels_)
    if label != -1
]

In [0]:
cluster_labels = db.labels_  # -1 marks rows not assigned to any cluster
cluster_labels

In [0]:
window_embeddings = np.stack(emb_df_fil['embedding'].values)
neighbor_counts(window_embeddings)