In [1]:
import torch
import torch.nn as nn

from model import GRUPredictorCondition

import numpy as np

import pickle as pk

import os
from multiprocessing import pool

from config import didi_hidx_traj_path, num_clusters, num_timeslots, time_embedding, cluster_embedding, hidden_dim, context_dim, dT, conditional_predictor_model_path

In [2]:
crowd_predictor = GRUPredictorCondition(num_timeslots, time_embedding, num_clusters, cluster_embedding, hidden_dim, n_layers=2)
crowd_predictor.load_state_dict(torch.load(conditional_predictor_model_path).state_dict())
crowd_predictor = crowd_predictor.cuda(0)

In [3]:
def load_cluster_data(filepath):
    print(filepath)
    with open(filepath, 'rb') as f:
        return torch.stack(list(pk.load(f).values()))

path_list = []

for m in range(10, 12):
    for d in range(1, 32):
        filename = os.path.join(didi_hidx_traj_path, f'hidx_traj_2012{m:02d}{d:02d}.pk')
        if os.path.isfile(filename):
            path_list.append(filename)

path_list = path_list[45:]

data = [load_cluster_data(filepath) for filepath in path_list]

/data/fan/didi/processed/hidx_traj/hidx_traj_20121115.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121116.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121117.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121118.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121119.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121120.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121121.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121122.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121123.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121124.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121125.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121126.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121127.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121128.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121129.pk
/data/fan/didi/processed/hidx_traj/hidx_traj_20121130.pk


In [4]:
import random
from itertools import product

In [6]:
batch_size = 8192
max_iter = 100000

day_time_list = list(product(range(len(data)), range(num_timeslots - 2 * dT + 1)))

holidays = set([0, 1, 2, 3, 4, 5, 6, 14, 15, 21, 22, 28, 29, 35, 36, 42, 43, 49, 50, 56, 57])

avg_loss = 0.0

n = 0
with torch.no_grad():
    for d, t in day_time_list:

        for j in range(0, data[d].shape[0], batch_size):
            xc = data[d][j: j + batch_size, t: t + dT].cuda(0)
            xt = torch.zeros_like(xc) + t
            xd = torch.zeros_like(xt) + (1 if d in holidays else 0)
            yc = data[d][j: j + batch_size, t + 2 * dT - 1].cuda(0)

            crowd_pred = crowd_predictor(xc, xt, xd)
            loss = nn.functional.cross_entropy(crowd_pred, yc)

            avg_loss += loss.item() * yc.shape[0]
            n += yc.shape[0]

            print('avg_loss = {:.4f}'.format(avg_loss / n), end='\r')

avg_loss = 0.5630