In this notebook, we give an example of training the cluster-level predictor introduced in the paper. For other baseline models, you can change the import (for example, import a different model class from `model` in place of `CrowdPredictor`).

In [1]:
import torch
import torch.nn as nn

from model import CrowdPredictor

import numpy as np

import pickle as pk

import os
from multiprocessing import pool

from config import didi_hidx_traj_path, num_clusters, num_timeslots, time_embedding, cluster_embedding, hidden_dim, context_dim, dT, context_mean_predictor_model_path

In [2]:
# Build the predictor from the hyper-parameters imported from `config`,
# using two layers and 'mean' pooling, then move it onto the GPU.
crowd_predictor = CrowdPredictor(
    num_timeslots,
    time_embedding,
    num_clusters,
    cluster_embedding,
    hidden_dim,
    context_dim,
    n_layers=2,
    pooling_func='mean',
)
crowd_predictor = crowd_predictor.cuda()

In [3]:
def load_cluster_data(filepath):
    """Load one pickled file and stack its values into a single tensor.

    The path is printed first so loading progress is visible. The unpickled
    object must expose `.values()` yielding tensors of identical shape
    (presumably a dict of per-trajectory cluster-index sequences -- TODO
    confirm against the preprocessing code).

    NOTE(review): `pickle.load` executes arbitrary code from the file; only
    use this on trusted, locally produced data.

    Args:
        filepath: path of the pickle file to read.

    Returns:
        A tensor created by `torch.stack` over the stored values, i.e. with
        a new leading dimension indexing the values in storage order.
    """
    print(filepath)
    with open(filepath, 'rb') as handle:
        stored = pk.load(handle)
    per_entry_tensors = list(stored.values())
    return torch.stack(per_entry_tensors)

# Candidate daily files for months 10 and 11 of 2012, in chronological order.
# Day numbers run 1..31 for both months; dates without a file on disk are
# dropped by the `os.path.isfile` filter below.
candidate_files = (
    os.path.join(didi_hidx_traj_path, f'hidx_traj_2012{m:02d}{d:02d}.pk')
    for m in range(10, 12)
    for d in range(1, 32)
)

# Keep only files that exist, and at most the first 45 of them.
path_list = [filename for filename in candidate_files if os.path.isfile(filename)][:45]

# One stacked tensor per retained day, in the same order as `path_list`.
data = [load_cluster_data(day_file) for day_file in path_list]

/data/fan/didi/processed/hidx_traj/hidx_traj_20121001.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121002.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121003.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121004.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121005.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121006.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121007.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121008.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121009.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121010.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121011.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121012.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121013.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121014.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121015.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121016.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121017.pk
/data/fan/didi/processed/hidx_t

In [4]:
# Adam over all predictor parameters, starting at a learning rate of 1e-3.
optimizer = torch.optim.Adam(
    params=crowd_predictor.parameters(),
    lr=1e-3,
)

# Each call to `optimizer_scheduler.step()` multiplies the learning rate by 0.5.
optimizer_scheduler = torch.optim.lr_scheduler.ExponentialLR(
    optimizer=optimizer,
    gamma=0.5,
)

In [5]:
import random
from itertools import product

In [6]:
batch_size = 8192
max_iter = 100000  # NOTE(review): not referenced anywhere in the visible notebook
accum_steps = 5    # mini-batches whose gradients are summed before one optimizer step

# Every (day index, start timeslot) pair for which the input window
# [t, t + dT) and the target slot t + 2 * dT - 1 both stay below num_timeslots.
day_time_list = list(product(range(len(data)), range(num_timeslots - 2 * dT + 1)))

for epoch in range(1, 11):
    avg_loss = 0.0
    random.shuffle(day_time_list)

    # Start every epoch from clean gradients so nothing left over from a
    # previous run of this cell is mixed into the first update.
    optimizer.zero_grad()

    i = 0
    for d, t in day_time_list:

        i += 1
        # Sample trajectories of day `d` without replacement. The sample size is
        # capped at the number available, because `np.random.choice` raises a
        # ValueError when asked for more unique items than exist.
        num_trajs = data[d].shape[0]
        batch_indices = torch.LongTensor(
            np.random.choice(num_trajs, min(batch_size, num_trajs), replace=False)
        )

        # Input: the dT consecutive slots starting at t (indexed in one step to
        # avoid copying full rows first). Target: the slot at t + 2 * dT - 1.
        # Assumes data[d] is a 2-D integer tensor (trajectory x timeslot) of
        # cluster indices -- TODO confirm against the preprocessing code.
        xc = data[d][batch_indices, t: t + dT].cuda()
        xt = torch.full_like(xc, t)  # start slot t broadcast to xc's shape/dtype/device
        yc = data[d][batch_indices, t + 2 * dT - 1].cuda()

        crowd_pred = crowd_predictor(xc, xt)
        loss = nn.functional.cross_entropy(crowd_pred, yc)

        # Gradients are accumulated (summed) across mini-batches.
        loss.backward()

        # Apply an update every `accum_steps` mini-batches, and also on the last
        # mini-batch of the epoch so a trailing partial group is not carried
        # into the next epoch (where the learning rate has already changed).
        if i % accum_steps == 0 or i == len(day_time_list):
            optimizer.step()
            optimizer.zero_grad()

        avg_loss += loss.item()

        # '\r' keeps the running average on a single, continuously updated line.
        print('Epoch {:03d}, {:.1f}%, avg_loss = {:.4f}'.format(epoch, i / len(day_time_list) * 100, avg_loss / i), end='\r')

    print()
    # Decay the learning rate once per epoch.
    optimizer_scheduler.step()

Epoch 001, 100.0%, avg_loss = 0.6062
Epoch 002, 100.0%, avg_loss = 0.4938
Epoch 003, 100.0%, avg_loss = 0.4874
Epoch 004, 100.0%, avg_loss = 0.4844
Epoch 005, 100.0%, avg_loss = 0.4828
Epoch 006, 100.0%, avg_loss = 0.4823
Epoch 007, 100.0%, avg_loss = 0.4816
Epoch 008, 100.0%, avg_loss = 0.4816
Epoch 009, 100.0%, avg_loss = 0.4817
Epoch 010, 100.0%, avg_loss = 0.4814


In [7]:
# Move the predictor to the CPU and serialize the whole object (not just a
# state_dict) to the path configured in `config`.
# NOTE(review): loading a whole-object checkpoint presumably requires the
# `model` module to be importable at load time -- confirm with the consumers.
torch.save(
    obj=crowd_predictor.cpu(),
    f=context_mean_predictor_model_path,
)

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
