In [0]:
import sys
# Add ".." to the import search path; presumably this is where `config_infer` lives — confirm.
sys.path.append("..")
from config_infer import InferenceConfig
# Inference settings; later cells read `traj_emb_table_name` and `static_moving_table_name` from it.
cfg = InferenceConfig()


In [0]:
# Earnings rows whose paydate falls within the three days before today (today excluded)
# and that carry a positive total_pck_amt.
query = """
select * from main_prod.earnings_analysis.fact_user_earnings_daily where paydate >= current_date - 3 and paydate<current_date and total_pck_amt > 0
"""

df = spark.sql(query)
# display(df)

In [0]:
# Distinct (userid, employerid) pairs that appear in the recent-paydate rows.
# employername is not selected here; it comes from the history query in a later cell.


userid_emp = df.select('userid','employerid').distinct()
# display(userid_emp)

In [0]:
# userid_emp.count()

In [0]:
# Pay history up to and including yesterday (rows with a NULL paydate excluded),
# restricted to the columns used by the rest of the notebook.
query = """
select userid, employerid, employername, paydate, prev_paydate, total_pck_amt from main_prod.earnings_analysis.fact_user_earnings_daily where paydate is not NULL and paydate<= current_date -1 
"""

facts_df = spark.sql(query)
# display(facts_df)

In [0]:
# Drop rows that are exact duplicates across all selected columns.
# Note: this rebinds `facts_df`, so the cell's input is the previous cell's output.
facts_df = facts_df.distinct()
# display(facts_df)

In [0]:
# Keep only the history of (userid, employerid) pairs that had a recent qualifying paydate.
facts_df_fil= facts_df.join(userid_emp, ['userid','employerid'], "inner")
# display(facts_df_fil)

In [0]:
# facts_df_fil.count()

In [0]:
# Rank each (userid, employerid) pair's rows by paydate, newest first (rn = 1 is the most recent row).
# NOTE(review): no filter on `rn` is applied here, so the full history is carried forward even though
# compute_static_moving only reads rn values 1..5.
# NOTE(review): row_number() gives different ranks to rows that share a paydate; this assumes one
# row per (userid, employerid, paydate) after distinct() — confirm against the source table.

from pyspark.sql.window import Window
from pyspark.sql.functions import row_number
from pyspark.sql.functions import desc

window = Window.partitionBy('userid','employerid').orderBy(desc('paydate'))
facts_df_fil_2 = facts_df_fil.withColumn('rn', row_number().over(window))

In [0]:
# Trajectory-embedding table named in the config. Later cells use its
# userid, traj_date and embedding columns.
traj_emb_df = spark.read.table(cfg.traj_emb_table_name)

display(traj_emb_df)

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.functions import broadcast

# Left-join every pay-cycle row to the trajectories of the same user whose traj_date lies in the
# half-open interval [prev_paydate, paydate). Cycles without any trajectory keep a single row
# with NULL trajectory columns.
# NOTE(review): the broadcast hint is placed on the embeddings table; confirm it is small enough
# to be broadcast, otherwise drop the hint.
joined = (
    facts_df_fil_2.join(
        broadcast(traj_emb_df),  # broadcast hint; remove broadcast() if the embeddings table is large
        (facts_df_fil_2.userid == traj_emb_df.userid)
        & (traj_emb_df.traj_date >= facts_df_fil_2.prev_paydate)   # lower bound is inclusive
        & (traj_emb_df.traj_date < facts_df_fil_2.paydate),        # upper bound is exclusive; use <= for a closed interval
        "left"
    )
)


In [0]:
# Number of distinct trajectory days observed inside each pay cycle.
# countDistinct (rather than count) keeps the tally at "days with a trajectory" even when several
# fact rows fall into the same (userid, prev_paydate, paydate, rn) group — the key has no
# employerid, so two employers with identical cycles would otherwise be counted twice — or when
# the embeddings table holds more than one row for a day.
# Cycles with no matching trajectory (NULL traj_date from the left join) get a count of 0.
facts_df_count = (
    joined.groupBy(facts_df_fil_2.userid, facts_df_fil_2.prev_paydate, facts_df_fil_2.paydate, facts_df_fil_2.rn)
          .agg(F.countDistinct(traj_emb_df.traj_date).alias("count"))
          .select("userid", "prev_paydate", "paydate", "rn", "count")
)

In [0]:
from pyspark.sql.types import BooleanType
from pyspark.sql.functions import col

def filter_by_count(prev_date, current_date, count):
    """Return True when a pay cycle has trajectories on at least half of its days.

    Args:
        prev_date: Start of the pay cycle (date-like, supports subtraction), or None.
        current_date: End of the pay cycle (date-like), or None.
        count: Number of trajectory days observed in the cycle, or None.

    Returns:
        bool: True if ``count`` is at least 50% of the number of days between
        ``prev_date`` and ``current_date``; False otherwise, including when any
        argument is None.
    """
    # NULL dates reach this function when used as a Spark UDF (e.g. a missing prev_paydate);
    # treat such cycles as not covered instead of raising TypeError and failing the job.
    if prev_date is None or current_date is None or count is None:
        return False
    n_days = (current_date - prev_date).days
    return bool(count >= n_days * 0.5)
    
# Wrap the Python predicate as a boolean Spark UDF, then keep only the pay cycles it accepts.
coverage_ok = F.udf(filter_by_count, returnType=BooleanType())
df_filtered_by_count = facts_df_count.filter(
    coverage_ok(col("prev_paydate"), col("paydate"), col("count"))
)

In [0]:
# Keep only the pay-cycle rows that passed the coverage filter (default inner join on the shared
# key columns); the remaining fact columns and the `count` column are carried along.
# NOTE(review): the key has no employerid, so a passing cycle of one employer also admits another
# employer's row for the same user with identical dates and rn — confirm this is intended.
final_facts_df = facts_df_fil_2.join(df_filtered_by_count, ["userid", "prev_paydate", "paydate", "rn"])

In [0]:
from pyspark.sql.functions import broadcast
from pyspark.sql.functions import col
# Attach every trajectory embedding to the accepted pay cycle(s) of the same user whose
# [prev_paydate, paydate) interval contains its traj_date, and keep the columns needed for clustering.
# NOTE: this rebinds `df`, which previously held the recent-paydate query result.
df = (
    traj_emb_df.join(
        broadcast(final_facts_df),  # broadcast hint on the filtered pay-cycle rows
        (traj_emb_df.userid == final_facts_df.userid)
        & (col("traj_date") >= col("prev_paydate"))
        & (col("traj_date") < col("paydate")),
        "inner"       # use "left" to also keep trajectories without a matching pay cycle
    )
    .select(traj_emb_df.userid, final_facts_df.employerid, final_facts_df.employername, traj_emb_df.traj_date, traj_emb_df.embedding, final_facts_df.paydate, final_facts_df.prev_paydate, final_facts_df.rn)
)

In [0]:
import mlflow

# Disable autologging for all supported libraries.
# Presumably this keeps the per-group DBSCAN fits below from being logged — confirm.
mlflow.autolog(disable=True)

# Additionally disable sklearn autologging explicitly (both calls are executed).
mlflow.sklearn.autolog(disable=True)

In [0]:
from sklearn.neighbors import NearestNeighbors
import numpy as np
from sklearn.cluster import DBSCAN
import json 
def apply_dbscan(embs, target_min_similarity=0.9):
    """Fit DBSCAN with cosine distance on the rows of ``embs``.

    Args:
        embs: 2-D array of embeddings, one row per sample.
        target_min_similarity: Cosine similarity that two samples must reach to count as
            neighbours; the DBSCAN ``eps`` is ``1 - target_min_similarity``.

    Returns:
        The fitted DBSCAN estimator.
    """
    floor_min_samples = 3
    # A core point needs 20% of the samples in its neighbourhood, but never fewer than the floor.
    scaled_min_samples = int(embs.shape[0] * 0.2)
    model = DBSCAN(
        eps=1.0 - target_min_similarity,
        min_samples=max(scaled_min_samples, floor_min_samples),
        metric="cosine",
        n_jobs=-1,
    )
    return model.fit(embs)

def get_cluster(df):
    """Cluster the embeddings of ``df`` and summarise the result.

    Args:
        df: pandas DataFrame with an ``embedding`` column (equal-length vectors) and a
            ``traj_date`` column whose values support ``strftime``.

    Returns:
        tuple: ``(db, cluster_dict, date_label_dict)`` where ``db`` is the fitted DBSCAN model,
        ``cluster_dict`` maps each cluster label to its number of samples, and
        ``date_label_dict`` maps ``"YYYY-MM-DD"`` to the cluster label of that day's row.
        Noise samples (label -1) appear in neither dictionary.
    """
    vectors = np.stack(df['embedding'].values)
    db = apply_dbscan(vectors)
    cluster_dict = {}
    date_label_dict = {}

    for pos, label in enumerate(db.labels_):
        if label == -1:
            # Noise point: not counted and not recorded per date.
            continue
        cluster_dict[label] = cluster_dict.get(label, 0) + 1
        day_key = df.iloc[pos]['traj_date'].strftime("%Y-%m-%d")
        date_label_dict[day_key] = label

    return db, cluster_dict, date_label_dict


def is_static(cluster_dict, traj_count):
    """Decide whether the clustered trajectories indicate a static work pattern.

    A cluster qualifies when it holds at least ``int(0.25 * traj_count)`` samples. The pattern
    is static when the qualifying clusters together hold at least ``int(0.6 * traj_count)``
    samples and at least one sample is clustered at all.

    Args:
        cluster_dict: Mapping of cluster label -> number of samples in that cluster.
        traj_count: Total number of trajectories that were clustered (including noise).

    Returns:
        bool: True for a static pattern, False otherwise.
    """
    min_cluster_size = int(0.25 * traj_count)
    clustered = sum(size for size in cluster_dict.values() if size >= min_cluster_size)
    # Without any clustered sample there is no evidence of a recurring location. This guard
    # matters for very small inputs, where int(0.6 * traj_count) truncates to 0 and the
    # threshold comparison alone would report "static" for pure noise.
    if clustered == 0:
        return False
    return clustered >= int(0.6 * traj_count)

# def cluster_exists(db):
#     labels = db.labels_
#     for label in labels:
#         if label != -1:
#             return True
#     return False


def dbscan_predict_all(db, X_train, X_new):
    """Assign DBSCAN labels to new samples from their neighbours in the training data.

    Each row of ``X_new`` receives the label of its closest training sample that lies within
    ``db.eps`` (using ``db.metric``) and is not noise; rows without such a neighbour get -1.

    Args:
        db: Fitted DBSCAN model; ``db.labels_`` is indexed in the row order of ``X_train``.
        X_train: Samples the labels in ``db.labels_`` refer to.
        X_new: Samples to label.

    Returns:
        numpy.ndarray: Integer labels, one per row of ``X_new``.
    """
    index = NearestNeighbors(radius=db.eps, metric=db.metric).fit(X_train)
    dists, idxs = index.radius_neighbors(X_new, return_distance=True)
    train_labels = db.labels_
    pred = np.full(len(X_new), -1, dtype=int)
    for row, (row_dists, row_idxs) in enumerate(zip(dists, idxs)):
        if len(row_idxs) == 0:
            # No training sample within the radius.
            continue
        neighbour_labels = train_labels[row_idxs]
        clustered = neighbour_labels != -1
        if not clustered.any():
            # Only noise points nearby.
            continue
        nearest = np.argmin(row_dists[clustered])
        pred[row] = neighbour_labels[clustered][nearest]
    return pred
    
def works_on_weekends_fn(db, weekday_df, weekend_df):
    """Report whether weekend trajectories fall into the clusters found on weekdays.

    Args:
        db: DBSCAN model fitted on the embeddings of ``weekday_df``.
        weekday_df: pandas DataFrame with an ``embedding`` column (the training samples).
        weekend_df: pandas DataFrame with an ``embedding`` column (the samples to label).

    Returns:
        bool: True when more than 40% of the weekend samples are assigned to a cluster;
        False otherwise, and always False when either frame is empty.
    """
    if not len(weekday_df) or not len(weekend_df):
        return False
    weekday_embs = np.stack(weekday_df['embedding'].values)
    weekend_embs = np.stack(weekend_df['embedding'].values)
    pred = dbscan_predict_all(db, weekday_embs, weekend_embs)
    # Weekend samples that landed in some weekday cluster (label other than -1).
    matched = pred[pred != -1]
    return matched.size > 0.4 * pred.size


from pyspark.sql import functions as F, types as T
import pandas as pd

# -----------------------
# Helper functions that compute_static_moving calls; they must be importable on the workers:
#   get_cluster(pdf) -> (db, cluster_dict, date_label_dict)
#   is_static(cluster_dict, traj_count) -> bool
#   works_on_weekends_fn(db, weekday_pdf, weekend_pdf) -> bool
# (cluster_exists(db) -> bool is currently commented out and is not used.)
# Ensure they are defined in the same file or available on PYTHONPATH for executors.
# -----------------------

# Schema of the rows returned by compute_static_moving (adjust types if the real ones differ).
# Each entry is (column name, Spark type, nullable).
out_schema = T.StructType([
    T.StructField(field_name, field_type, nullable)
    for field_name, field_type, nullable in (
        ("userid", T.IntegerType(), False),
        ("employerid", T.IntegerType(), False),
        ("employername", T.StringType(), True),
        ("predicted_work_type", T.StringType(), True),
        ("predicted_on", T.DateType(), True),
        ("works_on_weekends", T.BooleanType(), True),
        ("paydate", T.DateType(), True),
        ("old_latest_date_label_dict", T.MapType(T.StringType(), T.IntegerType()), True),
        ("latest_date_label_dict", T.MapType(T.StringType(), T.IntegerType()), True),
        ("old_latest_weekday_df_length", T.LongType(), True),
        ("latest_weekday_df_length", T.LongType(), True),
    )
])

def convert_dict_keys_to_str(d):
    """Return a copy of ``d`` with every key passed through ``str``.

    Values are kept as they are. Anything that is not a dict is returned unchanged.
    """
    if not isinstance(d, dict):
        return d
    return dict((str(key), value) for key, value in d.items())

def compute_static_moving(group_pdf: pd.DataFrame) -> pd.DataFrame:
    """
    Classify one (userid, employerid) group as "static" or "moving".

    This function executes on a Spark worker (it is passed to ``applyInPandas``).

    The weekday trajectories of the latest pay cycle (``rn == 1``) are first clustered together
    with those of a few preceding cycles; if that combined set is static the group is "static".
    Otherwise the latest cycle is clustered on its own and decides between "static" and "moving".
    For static groups, weekend trajectories are checked against the weekday clusters.

    Args:
        group_pdf: Rows of a single group. Columns read: userid, employerid, rn, paydate,
            prev_paydate, traj_date, embedding; optionally employername and weekday
            (weekday is derived from traj_date when absent).

    Returns:
        A DataFrame with the columns of ``out_schema``: one row when a prediction was made,
        zero rows when the latest cycle has no weekday trajectory or there are too few weekday
        trajectories overall.

    Raises:
        ValueError: If neither a 'weekday' nor a 'traj_date' column is present.
    """
    import datetime
    today = datetime.date.today()
    out_columns = [f.name for f in out_schema]

    # All rows belong to one group, so the keys can be read from the first row.
    userid = group_pdf["userid"].iloc[0]
    employerid = group_pdf["employerid"].iloc[0]
    employername = group_pdf.get("employername", pd.Series([None])).iloc[0]

    # Ensure 'weekday' exists (Monday = 0 ... Sunday = 6); derive it from traj_date if needed.
    if "weekday" not in group_pdf.columns:
        if "traj_date" in group_pdf.columns:
            group_pdf = group_pdf.copy()
            group_pdf["weekday"] = pd.to_datetime(group_pdf["traj_date"]).dt.weekday
        else:
            raise ValueError("Missing 'weekday' column and no 'traj_date' to compute it from.")

    # Measure the pay-cycle length on the most recent cycle (smallest rn). The rows of a group
    # arrive in no guaranteed order, so reading an arbitrary row (e.g. the last one) could pick a
    # cycle of a different length from run to run and change the look-back window below.
    ref_pos = int(group_pdf["rn"].to_numpy().argmin())
    paycycle_len = (group_pdf["paydate"].iloc[ref_pos] - group_pdf["prev_paydate"].iloc[ref_pos]).days

    # Shorter cycles look back over more past cycles.
    if paycycle_len <= 8:
        past_cycle_to_consider = 4
    elif paycycle_len <= 15:
        past_cycle_to_consider = 2
    else:
        past_cycle_to_consider = 1
    target_n_cycles = 1
    outputs = []

    # Row masks that do not depend on the target cycle.
    is_weekday = group_pdf["weekday"].isin([0, 1, 2, 3, 4])
    is_weekend = group_pdf["weekday"].isin([5, 6])

    # Loop over target cycles (currently only the latest one, since target_n_cycles = 1).
    for target_rn in range(target_n_cycles):
        target_paycycle = target_rn + 1

        in_latest = group_pdf["rn"] == target_paycycle
        # Cycles strictly older than the target, limited to the look-back window.
        in_old = (group_pdf["rn"] > target_paycycle) & (
            group_pdf["rn"] <= target_paycycle + past_cycle_to_consider
        )

        latest_weekday_df = group_pdf.loc[is_weekday & in_latest].reset_index(drop=True)
        latest_weekend_df = group_pdf.loc[is_weekend & in_latest].reset_index(drop=True)
        old_weekday_df = group_pdf.loc[is_weekday & in_old].reset_index(drop=True)
        old_weekend_df = group_pdf.loc[is_weekend & in_old].reset_index(drop=True)

        if len(latest_weekday_df) == 0:
            continue
        old_latest_df = pd.concat([latest_weekday_df, old_weekday_df], ignore_index=True)
        # Require a minimum number of weekday trajectories relative to the cycle length.
        if len(old_latest_df) < 0.7 * paycycle_len:
            continue

        old_latest_cluster, old_latest_cluster_dict, old_latest_date_label_dict = get_cluster(old_latest_df)
        # Stays empty unless the latest cycle has to be clustered on its own.
        latest_date_label_dict = {}
        if is_static(old_latest_cluster_dict, len(old_latest_df)):
            work_type = "static"
            works_on_weekends = works_on_weekends_fn(
                old_latest_cluster,
                old_latest_df,
                pd.concat([latest_weekend_df, old_weekend_df], ignore_index=True),
            )
        else:
            latest_cluster, latest_cluster_dict, latest_date_label_dict = get_cluster(latest_weekday_df)
            if is_static(latest_cluster_dict, len(latest_weekday_df)):
                work_type = "static"
                works_on_weekends = works_on_weekends_fn(
                    latest_cluster, latest_weekday_df, latest_weekend_df
                )
            else:
                work_type = "moving"
                works_on_weekends = False

        # Representative paydate: first row of the latest cycle (non-empty, checked above).
        paydate_value = latest_weekday_df["paydate"].iloc[0]

        outputs.append({
            "userid": userid,
            "employerid": employerid,
            "employername": employername,
            "predicted_work_type": work_type,
            "predicted_on": today,
            "works_on_weekends": works_on_weekends,
            "paydate": paydate_value,
            "old_latest_date_label_dict": old_latest_date_label_dict,
            "latest_date_label_dict": latest_date_label_dict,
            "old_latest_weekday_df_length": len(old_latest_df),
            "latest_weekday_df_length": len(latest_weekday_df)
        })

    if outputs:
        return pd.DataFrame(outputs, columns=out_columns)
    # No prediction for this group: return an empty frame with the expected columns.
    return pd.DataFrame(columns=out_columns)

# -----------------------
# Run on Spark DataFrame
# Assume `df` is your Spark DataFrame version of df_pd. It must have the columns used above:
# userid, employerid, employername, rn, paydate, prev_paydate, traj_date and embedding.
# A weekday column is optional; it is computed from traj_date when missing.
# -----------------------



In [0]:
# Run compute_static_moving once per (userid, employerid) group; the returned rows must match out_schema.
result_spark_df = (
    df
    .groupBy("userid", "employerid")
    .applyInPandas(compute_static_moving, schema=out_schema)
)

In [0]:
# Persist the predictions as parquet, replacing the output of any previous run at this path.
# NOTE(review): the path is hard-coded to a personal scratch volume — consider moving it into the config.
result_spark_df.write.mode("overwrite").parquet("/Volumes/main_prod/datascience_scratchpad/jatin/trajcl_exp/usa/static_moving_dummy")

In [0]:
# Read the persisted predictions back from the same path written by the previous cell;
# the following cells (count, MERGE) work from this stored copy.
result_df = spark.read.parquet("/Volumes/main_prod/datascience_scratchpad/jatin/trajcl_exp/usa/static_moving_dummy")
display(result_df)

In [0]:
# display(result_df.select('paydate').groupBy('paydate').count())


In [0]:
# Number of prediction rows; used below to skip the MERGE when there is nothing to write.
output_count = result_df.count()

In [0]:
# Upsert the predictions into the configured target table, keyed on (userid, employerid, paydate):
# matching rows are overwritten with the new values, all other rows are inserted.
if output_count>0:
    # Expose the predictions to SQL under the name used as the MERGE source.
    result_df.createOrReplaceTempView("new_data")
    spark.sql("""MERGE INTO {} AS target USING new_data AS source ON target.userid = source.userid AND target.employerid = source.employerid AND target.paydate = source.paydate WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *""".format(cfg.static_moving_table_name))

In [0]:
%sql
-- Sanity check: number of rows in the worktype table whose paydate is yesterday.
-- NOTE(review): the table name is hard-coded here, while the MERGE uses cfg.static_moving_table_name — confirm they match.
select count(*) from main_prod.ml_data.static_moving_worktype where paydate = current_date - 1

In [0]:
%sql
-- Sanity check: total number of rows in the worktype table.
select count(*) from main_prod.ml_data.static_moving_worktype

In [0]:
# output_df.where('predicted_work_type = "moving"').count()

In [0]:
# userid = 23122754
# df_pd[df_pd['userid']==userid].sort_values("traj_date").reset_index(drop=True)

In [0]:
# test_df = df_pd[df_pd['userid']==userid].sort_values("traj_date").reset_index(drop=True)
# embs = np.stack(test_df['embedding'].to_list())
# target_min_similarity = 0.9
# eps = 1.0 - target_min_similarity    # cosine distance threshold
# n_embs = embs.shape[0]
# db = DBSCAN(eps=eps, min_samples=5, metric="cosine", n_jobs=-1).fit(embs)
# labels = db.labels_
# print(labels)

In [0]:
# # print cosine similarity between all
# from sklearn.metrics.pairwise import cosine_similarity
# for i in range(n_embs):
#     for j in range(i+1, n_embs):
#         print(test_df['traj_date'][i], test_df['traj_date'][j], cosine_similarity(embs[i].reshape(1,-1), embs[j].reshape(1,-1)))

In [0]:
# delete table if exists main_prod.ml_data.static_moving_worktype
# spark.sql("DROP TABLE IF EXISTS main_prod.ml_data.static_moving_worktype")
