In [2]:
import torch
import torch.nn as nn

from model import CrowdPredictorGRU

import numpy as np

import pickle as pk

import os
from multiprocessing import pool

In [1]:
from config import didi_hidx_traj_path, num_clusters, num_

In [2]:
# --- Model / data hyper-parameters -------------------------------------------
# Number of time slots per day. 288 = 24 * 12, i.e. presumably 5-minute slots
# (TODO confirm against how the cluster files were built).
num_timeslots = 288

# Embedding and layer sizes handed to CrowdPredictorGRU when it is constructed.
time_embedding = 128
cluster_embedding = 256
hidden_dim = 256
context_dim = 64

# Window length in slots: the training loop feeds slots t .. t+dT-1 and the
# target is the slot dT steps after the window's last slot (t + 2*dT - 1).
dT = 12

In [3]:
multiprocess: bool = True  # load the per-day cluster files with a process pool instead of one by one

In [9]:
# Build the model on GPU `gpu_index` when that device exists, otherwise on the CPU,
# so the notebook still runs (slowly) on a machine without that GPU instead of crashing.
# NOTE(review): the saved output of this cell ("<All keys matched successfully>") comes
# from a load_state_dict call that is not in this cell any more -- hidden state; confirm
# whether pretrained weights are supposed to be loaded here.
gpu_index = 1
if torch.cuda.is_available() and gpu_index < torch.cuda.device_count():
    device = torch.device('cuda', gpu_index)
else:
    device = torch.device('cpu')

crowd_predictor = CrowdPredictorGRU(num_timeslots, time_embedding, num_clusters, cluster_embedding,
                                    hidden_dim, context_dim, n_layers=2).to(device)

<All keys matched successfully>

In [10]:
def load_cluster_data(filepath):
    """Load one per-day cluster pickle and stack its values into a single tensor.

    The pickle is expected to hold a mapping whose values are tensors of identical
    shape (``torch.stack`` requires that). They are stacked along a new leading
    dimension in the mapping's iteration order. The path is printed first as a
    progress message.

    Args:
        filepath: path of the ``.pk`` file to read.

    Returns:
        ``torch.stack`` of the mapping's values.

    NOTE: ``pickle.load`` can execute arbitrary code -- only use on trusted files.
    """
    print(filepath)
    with open(filepath, 'rb') as handle:
        by_key = pk.load(handle)
    rows = list(by_key.values())
    return torch.stack(rows)

In [11]:
cluster_dir = '/data/fan/data/sadCrowdPrediction/cluster'
num_workers = 8

# Collect the per-day cluster files for June and July 2012. Calendar days that do
# not exist (e.g. June 31) or files missing on disk are simply skipped.
path_list = []
for m in range(6, 8):
    for d in range(1, 32):
        filename = os.path.join(cluster_dir, 'cluster_2012{:02d}{:02d}.pk'.format(m, d))
        if os.path.isfile(filename):
            path_list.append(filename)

# Fail loudly instead of silently "training" on an empty dataset.
if not path_list:
    raise FileNotFoundError('no cluster_2012MMDD.pk files found in ' + cluster_dir)

if multiprocess:
    # The context manager tears the workers down even if a load raises (the old
    # close()-without-join() leaked them on error). Pool.map returns results in
    # input order, so data[i] corresponds to path_list[i].
    # NOTE: relies on the 'fork' start method (Linux default); with 'spawn'
    # (Windows/macOS default) workers cannot import a function defined in a notebook.
    with pool.Pool(num_workers) as proc_pool:
        data = proc_pool.map(load_cluster_data, path_list)
else:
    data = [load_cluster_data(filepath) for filepath in path_list]

/data/fan/data/sadCrowdPrediction/cluster/cluster_20120601.pk
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120611.pk
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120613.pk
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120603.pk
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120607.pk
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120609.pk
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120615.pk
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120605.pk
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120614.pk
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120604.pk
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120610.pk
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120606.pk
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120612.pk
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120608.pk
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120616.pk
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120602.pk
/data/fa

In [21]:
# Adam at lr 1e-4. ExponentialLR multiplies the learning rate by 0.5 every time
# optimizer_scheduler.step() is called (once per epoch in the training loop).
optimizer = torch.optim.Adam(params=crowd_predictor.parameters(), lr=1e-4)
optimizer_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.5)

In [22]:
import random
from itertools import product

In [26]:
batch_size = 8192
max_iter = 100000  # NOTE(review): not referenced by the loop below
accum_steps = 5    # mini-batches whose gradients are accumulated per optimizer step
num_epochs = 4

# One sample per (day, window start): inputs are slots t .. t+dT-1 and the target is
# slot t + 2*dT - 1, so the last valid window start is num_timeslots - 2*dT.
day_time_list = list(product(range(len(data)), range(num_timeslots - 2 * dT + 1)))

# Feed batches to whatever device the model lives on instead of a hard-coded GPU index.
device = next(crowd_predictor.parameters()).device

crowd_predictor.train()
for epoch in range(1, num_epochs + 1):
    avg_loss = 0.0
    random.shuffle(day_time_list)
    # Start each epoch from clean gradients (e.g. after an interrupted re-run of this cell).
    optimizer.zero_grad()

    for i, (d, t) in enumerate(day_time_list, start=1):
        day_data = data[d]
        # Sample rows without replacement, capped at the number available that day
        # so a small day does not make np.random.choice raise.
        n_rows = day_data.shape[0]
        batch_indices = torch.LongTensor(np.random.choice(n_rows, min(batch_size, n_rows), replace=False))
        # Index rows and window in one step instead of copying whole rows first.
        xc = day_data[batch_indices, t:t + dT].to(device)
        # Every step of the window gets the same time index t (presumably the
        # "sametime" variant named in the checkpoint path) -- intentional, keep.
        xt = torch.zeros_like(xc) + t
        yc = day_data[batch_indices, t + 2 * dT - 1].to(device)

        crowd_pred = crowd_predictor(xc, xt)
        loss = nn.functional.cross_entropy(crowd_pred, yc)

        # Average (rather than sum) the gradients over the accumulation window.
        (loss / accum_steps).backward()

        if i % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        avg_loss += loss.item()

        print('Epoch {:03d}, {:.1f}%, avg_loss = {:.4f}'.format(epoch, i / len(day_time_list) * 100, avg_loss / i), end='\r')

    # Apply the gradients of a trailing, incomplete accumulation window so they do
    # not leak into the next epoch.
    if len(day_time_list) % accum_steps:
        optimizer.step()
        optimizer.zero_grad()

    print()
    optimizer_scheduler.step()

Epoch 001, 100.0%, avg_loss = 1.2912
Epoch 002, 66.4%, avg_loss = 1.2916

KeyboardInterrupt: 

In [27]:
# Save the whole pickled module (not just a state_dict) so existing torch.load-based
# consumers of this file keep working; pickling the module appears to be what emits
# the "It won't be checked for correctness upon loading" warnings.
# The model is moved to the CPU for saving so the checkpoint loads without a GPU,
# then moved back: .cpu() works in place, and Adam's state stays on the original
# device, so leaving the model on the CPU would break further training here.
model_path = './trained_model/crowd_pred_gru_sametime_h256_l2.model'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
model_device = next(crowd_predictor.parameters()).device
torch.save(crowd_predictor.cpu(), model_path)
crowd_predictor = crowd_predictor.to(model_device)

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
