In [0]:
import os
os.system("pip install -r requirements.txt")
os.system("pip install tqdm")

from config_infer import InferenceConfig
from model.dual_attention import DualSTBTimeWeighted
from utils.traj import *
import pickle
from utils.cellspace import *
from torch.nn.utils.rnn import pad_sequence
import pickle
import warnings
warnings.filterwarnings("ignore")
import torch
from tqdm import tqdm
import numpy as np

cfg = InferenceConfig()
device = cfg.device

print(cfg.checkpoint_file)

# Query encoder of the contrastive model; only this branch is used for inference.
encoder_q = DualSTBTimeWeighted(
    cfg.seq_embedding_dim,
    cfg.trans_hidden_dim,
    cfg.trans_attention_head,
    cfg.trans_attention_layer,
    cfg.trans_attention_dropout,
    cfg.trans_pos_encoder_dropout,
).to(cfg.device)

print(encoder_q)

# Load the full training checkpoint, then keep only the entries whose key mentions
# 'encoder_q', stripping the 'clmodel.encoder_q.' prefix so the names match the
# standalone module's state_dict.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if this
# checkpoint stores non-tensor objects the call may need weights_only=False
# (only acceptable for trusted files) — confirm with the installed torch version.
checkpoint = torch.load(cfg.checkpoint_file, map_location=cfg.device)['model_state_dict']
new_checkpoint = {
    key.replace('clmodel.encoder_q.', ''): tensor
    for key, tensor in checkpoint.items()
    if 'encoder_q' in key
}

encoder_q.load_state_dict(new_checkpoint)
encoder_q.eval()  # disable dropout for deterministic inference
print("Model loaded from checkpoint.")

def _load_pickle(path):
    """Unpickle the object stored at ``path`` and close the file handle afterwards.

    Security: ``pickle.load`` can execute arbitrary code — only point this at
    trusted, internally produced artifacts.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Cell embedding tables (tensors), kept on CPU and detached from any autograd graph;
# rows are indexed by cell id in infer_batch.
embs_parent = _load_pickle(cfg.dataset_embs_file_parent).to('cpu').detach()
embs_child = _load_pickle(cfg.dataset_embs_file_child).to('cpu').detach()

# Parent/child cell spaces, combined into a hierarchical lookup used for cell mapping
# and spatial features.
cellspace_parent = _load_pickle(cfg.dataset_cell_file_parent)
cellspace_child = _load_pickle(cfg.dataset_cell_file_child)
hier_cellspace = HirearchicalCellSpace(cellspace_parent, cellspace_child)

def model_forward(trajs1_emb, trajs1_emb_p, trajs1_len, time_deltas1):
    """Run the global query encoder ``encoder_q`` on one padded batch.

    Args:
        trajs1_emb: padded cell-embedding tensor for the batch (passed as ``src``).
        trajs1_emb_p: padded spatial-feature tensor (passed as ``srcspatial``).
        trajs1_len: 1-D LongTensor of true (unpadded) trajectory lengths.
        time_deltas1: padded time-delta tensor (passed as ``time_deltas``).

    Returns:
        Whatever ``encoder_q`` returns for the batch (the trajectory embeddings).
    """
    max_trajs1_len = int(trajs1_len.max())
    # Build the mask on the same device as the lengths so the comparison never
    # mixes devices. True marks padded positions (index >= real length).
    positions = torch.arange(max_trajs1_len, device=trajs1_len.device)
    src_padding_mask1 = positions.unsqueeze(0) >= trajs1_len.unsqueeze(1)
    return encoder_q(
        src=trajs1_emb,
        time_deltas=time_deltas1,
        attn_mask=None,
        src_padding_mask=src_padding_mask1,
        src_len=trajs1_len,
        srcspatial=trajs1_emb_p,
    )

def infer_batch(traj, time_indices, max_len=800):
    """Encode one batch of trajectories with the query encoder.

    Args:
        traj: per-trajectory point sequences, passed (truncated) to ``merc2cell2``.
        time_indices: per-trajectory timestamp sequences, aligned with ``traj``.
        max_len: each trajectory and its timestamps are truncated to their first
            ``max_len`` entries before cell mapping. The default of 800 is the
            limit that was previously hard-coded.

    Returns:
        Tuple ``(traj_embs, traj_cell_parent, traj_cell_child, traj_p, traj_timedelta)``:
        the encoder output, the per-trajectory outputs of ``merc2cell2``, and the
        padded log-time-delta tensor that was fed to the model.
    """
    # merc2cell2 returns a 4-tuple per trajectory; transpose into four per-field tuples.
    traj_cell_parent, traj_cell_child, traj_p, traj_timedelta = zip(*[
        merc2cell2(points[:max_len], stamps[:max_len], hier_cellspace)
        for points, stamps in zip(traj, time_indices)
    ])

    # Spatial features per trajectory, zero-padded time-major (batch_first=False).
    traj_emb_p = pad_sequence(
        [torch.tensor(generate_spatial_features(p, hier_cellspace)) for p in traj_p],
        batch_first=False,
    ).to(device)

    # Per-point cell embedding = parent-cell row + child-cell row of the embedding tables.
    traj_emb_cell = pad_sequence(
        [embs_parent[list(parent)] + embs_child[list(child)]
         for parent, child in zip(traj_cell_parent, traj_cell_child)],
        batch_first=False,
    ).to(device)

    # Unpadded length of every trajectory, used for the padding mask.
    traj_len = torch.tensor([len(p) for p in traj_p], dtype=torch.long, device=device)

    # NOTE(review): torch.log yields -inf for zero deltas. Left unchanged because it
    # must match the preprocessing the checkpoint was trained with — confirm
    # merc2cell2 never emits a 0 delta.
    traj_timedelta = pad_sequence(
        [torch.log(torch.tensor(td)) for td in traj_timedelta],
        batch_first=False,
        padding_value=0,
    ).to(device)

    traj_embs = model_forward(traj_emb_cell.float(), traj_emb_p.float(), traj_len, traj_timedelta)
    return traj_embs, traj_cell_parent, traj_cell_child, traj_p, traj_timedelta

# Number of trajectories encoded per forward pass in the inference loop.
batch_size = cfg.batch_size
    




In [0]:
# Input: per-day trajectory parquet files. (An earlier version read the Spark
# table cfg.traj_df_table_name via spark.sql instead.)
from glob import glob

TRAJ_PARQUET_GLOB = "/Volumes/main_prod/datascience_scratchpad/jatin/trajcl_exp/usa/last_20_days/*.parquet"

# glob() returns matches in arbitrary, filesystem-dependent order; sort so that
# re-runs process files in the same, predictable order.
files = sorted(glob(TRAJ_PARQUET_GLOB))
if not files:
    print(f"WARNING: no parquet files matched {TRAJ_PARQUET_GLOB}")

print(files)


In [0]:
import pandas as pd
from tqdm import tqdm

# Upsert into the embedding table, keyed on (userid, traj_date). Loop-invariant.
# NOTE(review): MERGE fails if the source view contains duplicate
# (userid, traj_date) rows — confirm uniqueness in the input files.
query = """
    MERGE INTO main_prod.datascience_scratchpad.traj_emb AS target
    USING traj_data AS source
    ON target.userid = source.userid
    AND target.traj_date = source.traj_date
    WHEN MATCHED THEN 
    UPDATE SET *
    WHEN NOT MATCHED THEN
    INSERT *
        """

for file in tqdm(files):
    df_pd = pd.read_parquet(file)
    if df_pd.empty:
        # torch.cat below raises on an empty list, and there is nothing to upsert.
        print(f"Skipping empty file: {file}")
        continue

    emb_list = []
    print("Started Inference")
    # Inference only: skip autograd graph construction to save memory.
    with torch.no_grad():
        for start in range(0, len(df_pd), batch_size):
            stop = start + batch_size
            traj = df_pd['merc_seq'].iloc[start:stop].tolist()
            time_indices = [
                np.array(ts, dtype='datetime64[ns]')
                for ts in df_pd['sorted_ts'].iloc[start:stop].tolist()
            ]
            traj_embs = infer_batch(traj, time_indices)[0]
            emb_list.append(traj_embs.detach().cpu())

    all_emb = torch.cat(emb_list, dim=0)
    if all_emb.shape[0] != len(df_pd):
        raise ValueError(
            f"{file}: got {all_emb.shape[0]} embeddings for {len(df_pd)} rows"
        )

    # Keep only the key columns and attach embeddings + model version.
    # assign() returns a new frame, avoiding SettingWithCopy on the column slice.
    df_pd = df_pd[['userid', 'traj_date']].assign(
        embedding=all_emb.tolist(),
        model_version=cfg.model_version,
    )
    print("Pandas df created")
    # `spark` is the session provided by the Databricks runtime.
    df_spark = spark.createDataFrame(df_pd)
    print("Spark df created")
    df_spark.createOrReplaceTempView('traj_data')
    print("Started update")
    _sqldf = spark.sql(query)
    print("Spark Updated")