This notebook shows an example of training the conditional cluster-level predictor, which is introduced in the paper as the baseline (condition).

In [None]:
import torch
import torch.nn as nn

from model import GRUPredictorCondition

import numpy as np

import pickle as pk

import os
from multiprocessing import pool

from config import didi_hidx_traj_path, num_clusters, num_timeslots, time_embedding, cluster_embedding, hidden_dim, context_dim, dT, conditional_predictor_model_path

In [None]:
import pickle as pk

In [None]:
# Build the conditional cluster-level predictor from the hyper-parameters
# imported from `config`, then move it to GPU 0.
# NOTE(review): n_layers=2 presumably sets the number of stacked recurrent
# layers inside GRUPredictorCondition — confirm in model.py.
predictor = GRUPredictorCondition(
    num_timeslots,
    time_embedding,
    num_clusters,
    cluster_embedding,
    hidden_dim,
    n_layers=2,
)
predictor = predictor.cuda(0)

In [None]:
def load_cluster_data(filepath):
    """Load one daily pickle and stack its values into a single tensor.

    The path is printed first as a progress indicator.

    Args:
        filepath: Path to a pickle file holding a dict whose values are
            tensors of identical shape.

    Returns:
        A tensor built with ``torch.stack`` over the dict values, in the
        dict's iteration order (the new leading dimension indexes entries).

    NOTE(review): ``pickle.load`` executes arbitrary code from the file;
    only use this on trusted, locally produced data.
    """
    print(filepath)
    with open(filepath, 'rb') as handle:
        per_key = pk.load(handle)
        return torch.stack([tensor for tensor in per_key.values()])

# Gather the daily trajectory pickles that exist for months 10 and 11 of 2012.
# Every day number 1..31 is tried; dates with no file (including impossible
# ones such as 2012-11-31) are skipped by the isfile check.
path_list = []

for m in range(10, 12):
    for d in range(1, 32):
        filename = os.path.join(didi_hidx_traj_path, f'hidx_traj_2012{m:02d}{d:02d}.pk')
        if not os.path.isfile(filename):
            continue
        path_list.append(filename)

# Keep only the first 45 files found (chronological order of the loops above).
del path_list[45:]

# One stacked tensor per day, aligned index-for-index with path_list.
data = list(map(load_cluster_data, path_list))

In [None]:
# Adam starting at lr = 1e-3; ExponentialLR multiplies the learning rate by
# 0.5 on every optimizer_scheduler.step() call.
optimizer = torch.optim.Adam(params=predictor.parameters(), lr=0.001)
optimizer_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.5)

In [None]:
import random
from itertools import product

In [None]:
batch_size = 8192
# Gradients from this many mini-batches are summed before each optimizer step.
accumulation_steps = 5

# Every (day index, start timeslot) pair for which both the dT-slot input
# window [t, t + dT) and the target slot t + 2 * dT - 1 lie inside the day.
day_time_list = list(product(range(len(data)), range(num_timeslots - 2 * dT + 1)))

# Day indices into `data` (i.e. positions in path_list) that get the holiday
# flag xd = 1. NOTE(review): these are positional indices, so they shift if any
# daily file is missing from disk — confirm against the dataset calendar.
holidays = {0, 1, 2, 3, 4, 5, 6, 14, 15, 21, 22, 28, 29, 35, 36, 42, 43, 49, 50, 56, 57}

# Start from clean gradients (matters when this cell is re-run in the notebook).
optimizer.zero_grad()

for epoch in range(1, 11):
    avg_loss = 0.0
    random.shuffle(day_time_list)

    i = 0
    for d, t in day_time_list:

        i += 1
        num_rows = data[d].shape[0]
        # Sample rows without replacement; cap at the rows available so a day
        # with fewer than batch_size trajectories does not raise ValueError.
        batch_indices = torch.LongTensor(np.random.choice(num_rows, min(batch_size, num_rows), replace=False))
        # Index rows and slice the input window in one step instead of
        # materializing whole rows first.
        xc = data[d][batch_indices, t: t + dT].cuda(0)
        xt = torch.zeros_like(xc) + t
        xd = torch.zeros_like(xt) + (1 if d in holidays else 0)
        yc = data[d][batch_indices, t + 2 * dT - 1].cuda(0)

        pred = predictor(xc, xt, xd)
        loss = nn.functional.cross_entropy(pred, yc)

        loss.backward()

        if i % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        avg_loss += loss.item()

        print('Epoch {:03d}, {:.1f}%, avg_loss = {:.4f}'.format(epoch, i / len(day_time_list) * 100, avg_loss / i), end='\r')

    # Apply gradients left over from a trailing partial accumulation window at
    # this epoch's learning rate, instead of letting them leak into the next epoch.
    if i % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

    print()
    optimizer_scheduler.step()

In [None]:
# Move the trained model to CPU in place (nn.Module.cpu() returns the same
# module), then serialize the whole module object.
# NOTE(review): saving the full module pickles a reference to
# GRUPredictorCondition, so `model` must be importable when loading.
predictor.cpu()
torch.save(predictor, conditional_predictor_model_path)