- Batch the dataframe by a few thousand and write tensors to a directory
- Use that directory to train the model, Pytorch dataset can data file names as input

In [1]:
import pandas as pd
import ast
import h3
import pickle
import numpy as np
import torch
import math
import random
import time
import multiprocessing as mp
import itertools

In [2]:
data = pd.read_csv("../data/original_data.csv")
data = data.sample(35000)
data = data.reset_index(drop=True)
#data = pd.read_csv("../data/subset_data.csv")
data.head()

Unnamed: 0,TRIP_ID,CALL_TYPE,ORIGIN_CALL,ORIGIN_STAND,TAXI_ID,TIMESTAMP,DAY_TYPE,MISSING_DATA,POLYLINE
0,1381408427620000361,C,,,20000361,1381408427,A,False,"[[-8.626437,41.152185],[-8.625753,41.152428],[..."
1,1402587497620000166,C,,,20000166,1402587497,A,False,"[[-8.617113,41.143815],[-8.615349,41.147172],[..."
2,1399759179620000664,C,,,20000664,1399759179,A,False,"[[-8.616276,41.147082],[-8.61741,41.147217],[-..."
3,1383443307620000305,B,,57.0,20000305,1383443307,A,False,"[[-8.611002,41.146056],[-8.610831,41.145993],[..."
4,1402355881620000288,B,,54.0,20000288,1402355881,A,False,"[[-8.630253,41.157342],[-8.629686,41.157036],[..."


In [3]:
num_samples = len(data)
print(num_samples)

35000


In [4]:
num_samples / 25000

1.4

https://www.kaggle.com/crailtap/taxi-trajectory

## Functions

In [5]:
def remove_repeats(vals):
    if len(vals) == 0:
        return None
    result = []
    curr = vals[0]
    for val in vals[1:]:
        if val != curr:
            result.append(curr)
        curr = val
    result.append(curr)
    return result

In [6]:
def latlon_to_h3(latlons, res=9):
    latlons = ast.literal_eval(latlons)
    result = []
    for latlon in latlons:
        h3_id = h3.geo_to_h3(latlon[0], latlon[1], res)
        result.append(h3_id)
    result = remove_repeats(result)
    return result

## Feature Engineering

### Convert to H3 Ids

In [7]:
data["H3_POLYLINE"] = data["POLYLINE"].apply(latlon_to_h3)

In [8]:
data = data.dropna(axis=0, subset=["H3_POLYLINE"])

In [9]:
data["len_h3"] = data["H3_POLYLINE"].apply(len)
data = data[data["len_h3"] > 1]
print(data["len_h3"].sum())
data = data.drop(columns=["len_h3"])

677354


### Tokenizing H3 IDs

In [10]:
unq_h3_ids = data.explode("H3_POLYLINE").H3_POLYLINE.unique()

In [11]:
h3_to_token = {}
token_to_h3 = {}
for i in range(len(unq_h3_ids)):
    h3_to_token[unq_h3_ids[i]] = i
    token_to_h3[i] = unq_h3_ids[i]

In [12]:
with open("../models/tokenizers/encode_h3_to_token.pickle", "wb") as f:
    pickle.dump(h3_to_token, f)
with open("../models/tokenizers/decode_token_to_h3.pickle", "wb") as f:
    pickle.dump(token_to_h3, f)

In [13]:
def tokenize(vals, val_to_token_dict):
    result = []
    for val in vals:
        result.append(val_to_token_dict[val])
    return result

In [14]:
def decode(tokens, token_to_val_dict):
    result = []
    for token in tokens:
        result.append(token_to_val_dict[token])
    return result

In [15]:
data["h3_tokens"] = data["H3_POLYLINE"].apply(lambda x: tokenize(x, h3_to_token))

In [16]:
token_vocab = data["h3_tokens"].explode().unique().tolist()

In [17]:
data.head()

Unnamed: 0,TRIP_ID,CALL_TYPE,ORIGIN_CALL,ORIGIN_STAND,TAXI_ID,TIMESTAMP,DAY_TYPE,MISSING_DATA,POLYLINE,H3_POLYLINE,h3_tokens
0,1381408427620000361,C,,,20000361,1381408427,A,False,"[[-8.626437,41.152185],[-8.625753,41.152428],[...","[897b63adb8fffff, 897b63adb8bffff, 897b63adb9b...","[0, 1, 2, 3, 4, 2, 5]"
1,1402587497620000166,C,,,20000166,1402587497,A,False,"[[-8.617113,41.143815],[-8.615349,41.147172],[...","[897b63adb23ffff, 897b63adb3bffff]","[6, 7]"
2,1399759179620000664,C,,,20000664,1399759179,A,False,"[[-8.616276,41.147082],[-8.61741,41.147217],[-...","[897b63adb3bffff, 897b63adb07ffff, 897b63adbab...","[7, 8, 9, 10, 0, 11, 12, 13, 14, 15, 16, 17, 18]"
3,1383443307620000305,B,,57.0,20000305,1383443307,A,False,"[[-8.611002,41.146056],[-8.610831,41.145993],[...","[897b63adb67ffff, 897b63adb6fffff, 897b63adb67...","[19, 20, 19, 21, 22, 23, 6, 23, 24, 23, 25, 26..."
4,1402355881620000288,B,,54.0,20000288,1402355881,A,False,"[[-8.630253,41.157342],[-8.629686,41.157036],[...","[897b63ad867ffff, 897b63adb9bffff, 897b63adb83...","[4, 2, 11, 0, 10, 9, 8]"


### Skipgram features

In [18]:
temp_ids = data.loc[88, "H3_POLYLINE"]

In [19]:
test_inp = list(range(1, 6))
window_size = 2

#### Positive pairs

In [20]:
def get_positive_pairs(seq, window_size=3):
    pairs = []
    for i in range(len(seq)):
        for j in reversed(range(1, window_size+1)):
            new_idx = i - j
            if new_idx >= 0:
                pairs.append([seq[i], seq[new_idx], 1])
        for k in range(1, window_size+1):
            new_idx = i + k
            if new_idx < len(seq):
                pairs.append([seq[i], seq[new_idx], 1])
    return pairs

In [21]:
start_time = time.time()
temp_pos_pairs = get_positive_pairs(test_inp)
print(time.time() - start_time)
temp_pos_pairs[0]

5.745887756347656e-05


[1, 2, 1]

In [22]:
start = time.time()
data["h3_tokens"] = data["h3_tokens"].apply(lambda x: get_positive_pairs(x, 2))
data = data.explode("h3_tokens")
print(time.time() - start)

3.369764804840088


#### Get negative pairs

In [23]:
def get_training_sample(positive_pair, vocab, num_pairs=2):
    try:
        target = positive_pair[0]
    except:
        print(positive_pair)
    contexts = [positive_pair[1]]
    labels = [1]
    while True:
        neg_context = np.random.choice(vocab, size=1)[0]
        if neg_context not in positive_pair:
            contexts.append(neg_context)
            labels.append(0)
        if len(contexts) == num_pairs + 1:
            break
            
    c_l = list(zip(contexts, labels))
    random.shuffle(c_l)
    contexts, labels = zip(*c_l)

    return target, *contexts, *labels

In [24]:
temp_pos = temp_pos_pairs[0]

In [25]:
start_time = time.time()
temp_neg_pairs = get_training_sample(temp_pos, token_vocab, 2)
print(time.time() - start_time)
temp_neg_pairs

0.0020775794982910156


(1, 4272, 2110, 2, 0, 0, 1)

### Creating training set

In [26]:
def tokens_to_skipgram_data_mp(token_pairs, token_vocab, num_neg_sample):
    pool = mp.Pool(processes=8)
    result = pool.starmap_async(get_training_sample, zip(token_pairs, itertools.repeat(token_vocab), itertools.repeat(num_neg_sample)))
    result = np.array(result.get())
    targets = np.expand_dims(result[:, 0], 1)
    contexts = result[:, 1:num_neg_sample+2]
    labels = result[:, num_neg_sample+2:]
    return targets, contexts, labels

In [27]:
def get_batch_indices(data_len, batch_sz):
    batches = []
    num_batches = math.ceil(data_len / batch_sz)
    for i in range(num_batches):
        batches.append((i*batch_sz, (i+1)*batch_sz))
    return batches

In [28]:
batch_idx = get_batch_indices(len(data), 30000)

In [29]:
start = time.time()
train_targets, train_contexts, train_labels = tokens_to_skipgram_data_mp(data.h3_tokens, token_vocab, 2)
print(time.time() - start)

1220.2296116352081


## Making Torch Tensor

In [None]:
train_targets_tensor = torch.tensor(train_targets)
train_contexts_tensor = torch.tensor(train_contexts)
train_labels_tensor = torch.tensor(train_labels)

In [None]:
print(train_targets_tensor.shape)
print(train_contexts_tensor.shape)
print(train_labels_tensor.shape)

In [None]:
torch.save(train_targets_tensor, "../data/subset_train_targets.pt")
torch.save(train_contexts_tensor, "../data/subset_train_contexts.pt")
torch.save(train_labels_tensor, "../data/subset_train_labels.pt")