In [1]:
import torch 
import numpy as np

import os
import pandas as pd
import geopandas as gpd
import pickle as pickle

# from loc_predict.processing import _split_train_test

from shapely import wkt

import powerlaw
import matplotlib.pyplot as plt

from metrics.metrics import radius_gyration, jump_length, location_frquency, wait_time
from utils.utils import load_data

# Validation inputs are read from the relative folder data/validation.
_DATA_DIR_PARTS = ("data", "validation")
data_dir = os.path.join(*_DATA_DIR_PARTS)

def save_pk_file(save_path, data):
    """Serialize ``data`` with pickle and write it to ``save_path``.

    The target is opened in binary write mode, so an existing file at that
    path is overwritten. The object is dumped with ``pickle.HIGHEST_PROTOCOL``.
    """
    with open(save_path, "wb") as out_file:
        pickler = pickle.Pickler(out_file, protocol=pickle.HIGHEST_PROTOCOL)
        pickler.dump(data)


# Validation


In [None]:
# Read mobis_mhsa_generation.csv from the validation folder and display the frame.
generation_csv = os.path.join(data_dir, "mobis_mhsa_generation.csv")
simulated_sp = pd.read_csv(generation_csv)
simulated_sp

In [None]:
# NOTE(review): `loc` is not assigned in any cell above this one; within this
# notebook it is first read from locs_s2.csv in a later cell, so this cell
# fails on a fresh kernel unless that cell has been run first — confirm the
# intended cell order.
# Attach each location's `center` to the rows by matching `generated_ls`
# against the location id (left join keeps every simulated row).
location_centers = loc.reset_index()[["id", "center"]].rename(columns={"id": "location_id"})
simulated_sp = simulated_sp.merge(
    location_centers,
    how="left",
    left_on="generated_ls",
    right_on="location_id",
)

In [None]:
# Rename `center` to `geometry`, parse its WKT strings into shapely objects,
# and wrap the result as a GeoDataFrame declared in EPSG:4326.
simulated_sp = gpd.GeoDataFrame(
    simulated_sp.rename(columns={"center": "geometry"}).assign(
        geometry=lambda frame: frame["geometry"].apply(wkt.loads)
    ),
    geometry="geometry",
    crs="EPSG:4326",
)

In [None]:
# Remove the existing `user_id` and `generated_ls` columns, then expose
# `seq_id` under the name `user_id`.
simulated_sp = simulated_sp.drop(columns=["user_id", "generated_ls"]).rename(
    columns={"seq_id": "user_id"}
)

# Validate metric calculation

In [51]:
from metrics.evaluations import Metric
from utils.utils import load_data, setup_seed, load_config
from easydict import EasyDict as edict
from utils.dataloader import get_train_test, _get_valid_sequence

import pandas as pd
import numpy as np
import os
import geopandas as gpd
from shapely import wkt

In [52]:
# Load the run configuration and wrap it in an EasyDict so entries can be
# read as attributes (e.g. config.temp_save_root).
config = edict(load_config("./config/movesim.yml"))

# Read the sp and location tables from the temp folder, then preprocess
# them with load_data.
temp_root = config.temp_save_root
sp_raw = pd.read_csv(os.path.join(temp_root, "sp.csv"), index_col="id")
loc = pd.read_csv(os.path.join(temp_root, "locs_s2.csv"), index_col="id")
sp = load_data(sp_raw, loc)

# Read all locations, parse the WKT geometry column, declare the coordinates
# as EPSG:4326 and reproject them to EPSG:2056.
all_locs = pd.read_csv(os.path.join(temp_root, "all_locations.csv"), index_col="id")
all_locs = gpd.GeoDataFrame(
    all_locs.assign(geometry=all_locs["geometry"].apply(wkt.loads)),
    geometry="geometry",
    crs="EPSG:4326",
).to_crs("EPSG:2056")

# Split into train / validation / test; get_train_test also returns the
# location table, which replaces `all_locs`.
train_data, vali_data, test_data, all_locs = get_train_test(sp, all_locs=all_locs)

# Sizes are derived as "largest id + 1".
# NOTE(review): this presumes ids start at 0 and are contiguous — confirm.
max_loc_id = all_locs.loc_id.max()
max_user_id = train_data.user_id.max()
config["total_loc_num"] = int(max_loc_id + 1)
config["total_user_num"] = int(max_user_id + 1)

In [53]:
# Give every split a positional `id` column running from 0 to len - 1.
for _split in (train_data, vali_data, test_data):
    _split["id"] = np.arange(len(_split))

# Compute the valid-sequence index of each split, in train / validation /
# test order, using the verbosity and previous_day settings from the config.
train_idx, vali_idx, test_idx = [
    _get_valid_sequence(_split, print_progress=config.verbose, previous_day=config.previous_day)
    for _split in (train_data, vali_data, test_data)
]


100%|██████████| 2094/2094 [00:41<00:00, 50.11it/s]
100%|██████████| 2088/2088 [00:11<00:00, 188.71it/s]
100%|██████████| 2094/2094 [00:10<00:00, 199.15it/s]


In [5]:
# Work on a re-indexed copy of the training split (indexed by `id`), then
# collect the location ids of every window listed in train_idx. Each window
# supplies a positional start (element 0) and stop (element 1).
temp_data = train_data.set_index("id")
_train_locations = temp_data["location_id"]

train_seq = [_train_locations.iloc[window[0] : window[1]].values for window in train_idx]

In [6]:
# Work on a re-indexed copy of the test split (indexed by `id`), then collect
# the location ids of every window listed in test_idx. Each window supplies a
# positional start (element 0) and stop (element 1).
temp_data = test_data.set_index("id")
_test_locations = temp_data["location_id"]

test_seq = [_test_locations.iloc[window[0] : window[1]].values for window in test_idx]

In [7]:
# Instantiate Metric with the location table, the validation split as input
# data, and the validation split's valid-sequence index.
metrics = Metric(
    config,
    locations=all_locs,
    input_data=vali_data,
    valid_start_end_idx=vali_idx,
)

In [8]:
# Evaluate the test sequences and print the first five returned values,
# each rounded to three decimals.
jsds = metrics.get_individual_jsds(gene_data=test_seq)
report_template = "Metric: distance {:.3f}, rg {:.3f}, period {:.3f}, topk all {:.3f}, topk {:.3f}"
print(report_template.format(*(jsds[position] for position in range(5))))

Metric: distance 0.091, rg 0.081, period 0.024, topk all 0.005, topk 0.000


In [9]:
# Evaluate the training sequences and print the first five returned values,
# each rounded to three decimals.
jsds = metrics.get_individual_jsds(gene_data=train_seq)
report_template = "Metric: distance {:.3f}, rg {:.3f}, period {:.3f}, topk all {:.3f}, topk {:.3f}"
print(report_template.format(*(jsds[position] for position in range(5))))

Metric: distance 0.092, rg 0.071, period 0.021, topk all 0.000, topk 0.288
