In [1]:
import torch
import torch.nn as nn

import os
import pickle as pk
import numpy as np

from model import GRUPredictor

from multiprocessing import pool

In [2]:
from config import ensemble_components_folder_path
from config import didi_hidx_traj_path
from config import num_clusters, num_timeslots, time_embedding, cluster_embedding, hidden_dim, context_dim, dT

In [5]:
from config import gru_predictor_model_path

In [3]:
def load_cluster_data(filepath):
    """Read a pickled mapping of tensors and stack its values into one tensor.

    Args:
        filepath: Path of the pickle file to read. The unpickled object must
            provide ``.values()`` yielding tensors that ``torch.stack`` can
            combine (i.e. all of the same shape).

    Returns:
        ``torch.stack`` of the mapping's values, in the mapping's own
        iteration order, along a new leading dimension.
    """
    # Echo the path so a long multi-file run shows which file is in progress.
    print(filepath)

    with open(filepath, 'rb') as handle:
        # pickle can execute arbitrary code on load; only use on trusted files.
        records = pk.load(handle)

    per_key_tensors = [tensor for tensor in records.values()]
    return torch.stack(per_key_tensors)

In [13]:
# load_cluster_data() is defined in the cell above. It is intentionally not
# redefined here: a second identical `def` would silently shadow the first
# and the two copies could drift apart.

# Number of daily files (in month/day order) to train ensemble members on.
max_days = 14

# Collect (month, day, path) for every daily trajectory file that exists on
# disk. Months 10-11 and days 1-31 are probed; day numbers that do not exist
# for a month (or files that are simply missing) are skipped by the isfile test.
path_list = []

for m in range(10, 12):
    for d in range(1, 32):
        filename = os.path.join(didi_hidx_traj_path, f'hidx_traj_2012{m:02d}{d:02d}.pk')
        if os.path.isfile(filename):
            path_list.append((m, d, filename))

# Keep only the earliest `max_days` files found.
path_list = path_list[:max_days]

In [14]:
# Ensure the output directory for the trained ensemble members exists.
# os.makedirs is used instead of os.mkdir so that missing parent directories
# are created as well (os.mkdir raises FileNotFoundError in that case);
# exist_ok=True avoids an error if the directory appears between the check
# and the call.
if not os.path.isdir(ensemble_components_folder_path):
    os.makedirs(ensemble_components_folder_path, exist_ok=True)
    print('Create folder', ensemble_components_folder_path)

In [15]:
# Fine-tune one copy of the saved GRU predictor per daily file and save each
# fine-tuned copy as a separate ensemble member.

batch_size = 1024
max_iter = 10000          # mini-batches per epoch
num_epochs = 3            # epochs per daily file
grad_accum_steps = 5      # mini-batches whose gradients are summed per optimizer step
device = torch.device('cuda', 1)

for m, d, filename in path_list:
    data = load_cluster_data(filename)

    # Each member starts from the same saved predictor and gets its own
    # optimizer and learning-rate schedule.
    predictor = torch.load(gru_predictor_model_path).to(device)
    optimizer = torch.optim.RMSprop(predictor.parameters(), lr=1e-5)
    optimizer_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)

    for epoch in range(1, 1 + num_epochs):
        avg_loss = 0.0  # running SUM of losses; divided by i when reported

        # `i` is the 1-based count of mini-batches processed this epoch.
        # (It must not be incremented manually inside the loop: doing so
        # skews the reported average and the accumulation schedule.)
        for i in range(1, 1 + max_iter):
            # Random window start. The upper bound keeps the target column
            # t + 2*dT - 1 below num_timeslots.
            # NOTE(review): assumes data has num_timeslots columns -- confirm.
            t = np.random.randint(num_timeslots - 2 * dT + 1)

            # Sample rows without replacement (requires data.shape[0] >= batch_size).
            batch_indices = torch.LongTensor(np.random.choice(data.shape[0], batch_size, replace=False))
            # Input window: columns t .. t+dT-1.
            xc = data[batch_indices][:, t: t + dT].to(device)
            # Time input: every position of the window carries the same value t.
            # NOTE(review): confirm the model expects the window start rather
            # than a per-step time index.
            xt = torch.zeros_like(xc) + t
            # Target: column t+2*dT-1, i.e. dT steps past the end of the window.
            yc = data[batch_indices, t + 2 * dT - 1].to(device)

            pred = predictor(xc, xt)
            loss = nn.functional.cross_entropy(pred, yc)

            # Gradients accumulate across mini-batches until the next step.
            loss.backward()

            # Step every `grad_accum_steps` batches, and also on the final
            # batch so no partial accumulation leaks into the next epoch.
            if i % grad_accum_steps == 0 or i == max_iter:
                optimizer.step()
                optimizer.zero_grad()

            avg_loss += loss.item()

            print('Epoch {:03d}, {:.1f}%, avg_loss = {:.4f}'.format(epoch, i / max_iter * 100, avg_loss / i), end='\r')

        print()
        optimizer_scheduler.step()

    # os.path.join works whether or not the configured folder path ends with a
    # separator.
    # NOTE(review): the file name uses only the day `d`, not the month `m`;
    # if path_list ever spans two months, same-numbered days overwrite each other.
    torch.save(predictor.cpu(), os.path.join(ensemble_components_folder_path, f'ensemble_gru_day_{d:02d}.pytorch'))

/data/fan/didi/processed/hidx_traj/hidx_traj_20121001.pk
Epoch 001, 100.0%, avg_loss = 0.4566
Epoch 002, 100.0%, avg_loss = 0.4503
Epoch 003, 100.0%, avg_loss = 0.4502
/data/fan/didi/processed/hidx_traj/hidx_traj_20121002.pk
Epoch 001, 100.0%, avg_loss = 0.4482
Epoch 002, 100.0%, avg_loss = 0.4438
Epoch 003, 100.0%, avg_loss = 0.4414
/data/fan/didi/processed/hidx_traj/hidx_traj_20121003.pk
Epoch 001, 100.0%, avg_loss = 0.4492
Epoch 002, 100.0%, avg_loss = 0.4413
Epoch 003, 100.0%, avg_loss = 0.4416
/data/fan/didi/processed/hidx_traj/hidx_traj_20121004.pk
Epoch 001, 100.0%, avg_loss = 0.4449
Epoch 002, 100.0%, avg_loss = 0.4405
Epoch 003, 100.0%, avg_loss = 0.4384
/data/fan/didi/processed/hidx_traj/hidx_traj_20121005.pk
Epoch 001, 100.0%, avg_loss = 0.4420
Epoch 002, 100.0%, avg_loss = 0.4392
Epoch 003, 100.0%, avg_loss = 0.4338
/data/fan/didi/processed/hidx_traj/hidx_traj_20121006.pk
Epoch 001, 100.0%, avg_loss = 0.4488
Epoch 002, 100.0%, avg_loss = 0.4416
Epoch 003, 100.0%, avg_loss =