In [0]:
%pip install -r ../requirements.txt
%pip install tqdm

In [0]:
import torch
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import FloatType
from config_infer import InferenceConfig
from model.dual_attention import DualSTBTimeWeighted
from utils.traj import *
import pickle
from utils.cellspace import *
from torch.nn.utils.rnn import pad_sequence
import pickle
import warnings
warnings.filterwarnings("ignore")
import torch
from tqdm import tqdm
import numpy as np
from pyspark.sql.types import *
# Lazily-initialised model and device, kept in module globals so that repeated
# calls to `_get_model()` within the same Python process reuse them.
_model = None
_device = None

def _get_model():
    """Return the cached ``(model, device)`` pair, building it on first use.

    On the first call this:
      * picks CUDA when ``torch.cuda.is_available()``, otherwise CPU;
      * instantiates ``DualSTBTimeWeighted`` from ``InferenceConfig``;
      * loads the checkpoint at ``cfg.checkpoint_file`` and keeps only the
        entries whose key contains ``'encoder_q'``, with the
        ``'clmodel.encoder_q.'`` prefix removed from the key;
      * moves the model to the device and switches it to eval mode.

    Subsequent calls return the cached objects without touching the config or
    the checkpoint again.

    Returns:
        tuple: ``(model, device)``.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise (missing file,
        mismatched keys, ...).  In that case nothing is cached, so a later call
        retries from scratch instead of returning a model without weights.
    """
    global _model, _device
    if _model is not None:
        return _model, _device

    cfg = InferenceConfig()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DualSTBTimeWeighted(
        cfg.seq_embedding_dim,
        cfg.trans_hidden_dim,
        cfg.trans_attention_head,
        cfg.trans_attention_layer,
        cfg.trans_attention_dropout,
        cfg.trans_pos_encoder_dropout,
    )

    # NOTE(security): torch.load unpickles the file; only point
    # `cfg.checkpoint_file` at checkpoints from a trusted source.
    state_dict = torch.load(cfg.checkpoint_file, map_location=device)['model_state_dict']

    # Keep only the `encoder_q` weights and strip the wrapper prefix so the
    # keys line up with the bare encoder instantiated above.
    encoder_state = {
        key.replace('clmodel.encoder_q.', ''): tensor
        for key, tensor in state_dict.items()
        if 'encoder_q' in key
    }
    model.load_state_dict(encoder_state)
    model.to(device).eval()

    # Publish to the globals only once the model is fully initialised.
    _model, _device = model, device
    return _model, _device

# Cache for the pickled lookup data used by `predict_udf`; filled on first use
# by `_get_lookup_assets()` (same lazy-global approach as the model cache).
_lookup_assets = None


def _load_trusted_pickle(path):
    """Unpickle the file at `path`, closing the handle afterwards."""
    # NOTE(security): pickle.load can execute arbitrary code; the paths must
    # only ever point at trusted artefacts.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def _get_lookup_assets():
    """Return ``(embs_parent, embs_child, hier_cellspace)``, loading them once.

    The two embedding tables are moved to CPU and detached; the two pickled
    cell spaces are wrapped in a ``HirearchicalCellSpace``.
    """
    global _lookup_assets
    if _lookup_assets is None:
        cfg = InferenceConfig()
        embs_parent = _load_trusted_pickle(cfg.dataset_embs_file_parent).to('cpu').detach()
        embs_child = _load_trusted_pickle(cfg.dataset_embs_file_child).to('cpu').detach()
        hier_cellspace = HirearchicalCellSpace(
            _load_trusted_pickle(cfg.dataset_cell_file_parent),
            _load_trusted_pickle(cfg.dataset_cell_file_child),
        )
        # NOTE(review): reusing these across calls assumes `merc2cell2` and
        # `generate_spatial_features` do not mutate the cell space — TODO confirm.
        _lookup_assets = (embs_parent, embs_child, hier_cellspace)
    return _lookup_assets


@pandas_udf(ArrayType(FloatType()))
def predict_udf(traj: pd.Series, time_indices: pd.Series) -> pd.Series:
    """Embed each trajectory in `traj` with the cached model.

    Args:
        traj: one trajectory (sequence of points) per row.
        time_indices: per-row timestamps aligned with `traj`; each row is
            converted to a ``datetime64[ns]`` array.

    Returns:
        A Series with one list of floats per input row, in input order
        (assumes the model output has one row per trajectory in the chunk —
        TODO confirm against ``DualSTBTimeWeighted``).  An empty input yields
        an empty Series.

    Notes:
        * Rows are processed in chunks of ``BATCH_SIZE``.
        * Each trajectory and its timestamps are truncated to the first
          ``MAX_POINTS`` entries before being mapped to cells.
    """
    BATCH_SIZE = 32    # rows per forward pass
    MAX_POINTS = 800   # per-trajectory truncation length

    # np.concatenate raises on an empty list, so short-circuit empty input.
    if len(traj) == 0:
        return pd.Series([], dtype=object)

    model, device = _get_model()
    embs_parent, embs_child, hier_cellspace = _get_lookup_assets()

    batches = []
    with torch.no_grad():
        for start in range(0, len(traj), BATCH_SIZE):
            stop = start + BATCH_SIZE
            traj_chunk = traj.iloc[start:stop].values.tolist()
            time_chunk = [
                np.array(ts, dtype='datetime64[ns]')
                for ts in time_indices.iloc[start:stop].values.tolist()
            ]

            # merc2cell2 yields four aligned sequences per trajectory; zip(*)
            # regroups them into four per-chunk tuples.
            traj_cell_parent, traj_cell_child, traj_p, traj_timedelta = zip(*[
                merc2cell2(points[:MAX_POINTS], ts[:MAX_POINTS], hier_cellspace)
                for points, ts in zip(traj_chunk, time_chunk)
            ])

            # Spatial features, padded with batch_first=False (sequence dim first).
            traj_emb_p = pad_sequence(
                [torch.tensor(generate_spatial_features(p, hier_cellspace)) for p in traj_p],
                batch_first=False,
            ).to(device)

            # Cell embedding = parent-cell embedding + child-cell embedding.
            traj_emb_cell = pad_sequence(
                [
                    embs_parent[list(parent_ids)] + embs_child[list(child_ids)]
                    for parent_ids, child_ids in zip(traj_cell_parent, traj_cell_child)
                ],
                batch_first=False,
            ).to(device)

            traj_len = torch.tensor(list(map(len, traj_p)), dtype=torch.long, device=device)
            traj_timedelta = pad_sequence(
                [torch.tensor(td) for td in traj_timedelta],
                batch_first=False,
                padding_value=0,
            ).to(device)

            # (batch, max_len) boolean mask, True at padded positions.
            max_traj_len = traj_len.max().item()
            src_padding_mask = torch.arange(max_traj_len, device=device)[None, :] >= traj_len[:, None]

            traj_embs = model(
                src=traj_emb_cell,
                time_deltas=traj_timedelta,
                attn_mask=None,
                src_padding_mask=src_padding_mask,
                src_len=traj_len,
                srcspatial=traj_emb_p,
            )
            batches.append(traj_embs.detach().cpu().numpy())

    return pd.Series(np.concatenate(batches).tolist())


In [0]:
# Location of the input parquet dataset; the inference cell further down reads
# its `merc_seq` and `sorted_ts` columns.
data_dir = '/Volumes/main_prod/datascience_scratchpad/jatin/trajcl_exp/usa/backfill_traj_pre_data_rem_parquet'

# Spark DataFrame over that dataset.
df = spark.read.parquet(data_dir)
# display(df)

In [0]:
# Row count of the input DataFrame (this runs a Spark job over the dataset).
df.count()

In [0]:
# (disabled) cap the input at 1,000,000 rows for a trial run
# df = df.limit(1000000)
# display(df)

In [0]:
# df.count()

In [0]:
# out_df = df.withColumn('embedding', predict_udf('merc_seq','sorted_ts'))
# display(out_df) 

In [0]:
# Add an `embedding` column by applying `predict_udf` to the `merc_seq` and
# `sorted_ts` columns, then overwrite the output parquet dataset with the result.
out_df = df.withColumn(
    'embedding',
    predict_udf('merc_seq', 'sorted_ts'),
)
(
    out_df.write
    .mode("overwrite")
    .parquet('/Volumes/main_prod/datascience_scratchpad/jatin/trajcl_exp/usa/backfill_rem_traj_emb')
)

In [0]:
# Read back the embeddings that the inference cell wrote.
# NOTE: this rebinds `df` from the raw input to the embeddings output, so
# re-running the inference cell after this one would operate on this frame.
df = spark.read.parquet("/Volumes/main_prod/datascience_scratchpad/jatin/trajcl_exp/usa/backfill_rem_traj_emb")

display(df)

In [0]:
# Drop the UDF input columns (`merc_seq`, `sorted_ts`) and `wgs_seq` so they
# are not carried into the merge source.
df = df.drop('merc_seq', 'sorted_ts', 'wgs_seq')
display(df)


In [0]:
# Add a constant `model_version` column with the value "v1" to every row.
from pyspark.sql.functions import lit
df = df.withColumn('model_version', lit('v1'))

display(df)

In [0]:
# Expose the frame to the SQL MERGE cell below under the name `traj_data`.
df.createOrReplaceTempView("traj_data")

In [0]:
%sql
-- Upsert the rows staged in the `traj_data` temp view into the target table,
-- matching on (userid, traj_date): matched rows are overwritten with the
-- source values, unmatched source rows are inserted.
-- NOTE(review): MERGE fails when several source rows match the same target
-- row; confirm `traj_data` holds at most one row per (userid, traj_date).
MERGE INTO main_prod.datascience_scratchpad.traj_emb AS target
USING traj_data AS source
ON target.userid = source.userid
   AND target.traj_date = source.traj_date
WHEN MATCHED THEN 
  UPDATE SET *
WHEN NOT MATCHED THEN
  INSERT *

In [0]:
# df.write.mode("append").saveAsTable("main_prod.datascience_scratchpad.traj_emb")

In [0]:
# Sanity check: total row count of the target table after the merge.
# NOTE: this rebinds `df` to the target table.
df = spark.read.table("main_prod.datascience_scratchpad.traj_emb")
df.count()

In [0]:
# display(spark.sql("select * from main_prod.datascience_scratchpad.traj_emb"))