In [1]:
import torch
import torch.nn as nn

import os
import pickle as pk
import numpy as np

from model import GRUPredictor

from multiprocessing import pool

In [2]:
# Model / data dimensions for the component predictors.
# NOTE(review): of these, only num_timeslots is referenced in this notebook (the
# training cell samples start slots from it). The rest presumably mirror the
# GRUPredictor constructor arguments used when the checkpoints were created —
# confirm against model.py.
num_clusters: int = 3600       # presumably the number of distinct cluster ids
num_timeslots: int = 288       # timeslots per daily data file (288 would be 5-minute slots — confirm)
time_embedding: int = 128      # presumably the timeslot embedding width
cluster_embedding: int = 256   # presumably the cluster embedding width
hidden_dim: int = 256          # presumably the GRU hidden size
context_dim: int = 64          # presumably a context vector width

In [3]:
multiprocess: bool = True  # NOTE(review): not referenced by any visible cell; presumably meant to toggle the (unused) multiprocessing.pool import

In [4]:
dT: int = 12  # the training cell predicts the slot at t + 2*dT - 1 and samples t so that slot stays inside the day

In [5]:
def load_cluster_data(filepath):
    """Unpickle a mapping of tensors from ``filepath`` and stack its values.

    The path is printed first as a progress marker. The file is read with
    ``pickle.load``, and every value of the resulting mapping is stacked
    along a new leading dimension with ``torch.stack``, in the mapping's
    iteration order.

    NOTE(review): ``torch.stack`` requires every value to have the same shape;
    presumably each value is one row's per-timeslot sequence — confirm.
    Only load trusted files: unpickling can execute arbitrary code.
    """
    print(filepath)
    with open(filepath, 'rb') as handle:
        records = pk.load(handle)
    tensors = [value for value in records.values()]
    return torch.stack(tensors)

In [7]:
# Fine-tune each day's component predictor on that day's cluster data and write
# the result back over the checkpoint it was loaded from.
# NOTE(review): because the checkpoint is overwritten in place, re-running this
# cell fine-tunes the already fine-tuned models again (not idempotent).
batch_size = 1024
max_iter = 40000

device = torch.device('cuda:1')
data_path_template = '/data/fan/data/sadCrowdPrediction/cluster/cluster_201206{:02d}.pk'
model_path_template = './trained_model/component_predictors/component_predictor{:02d}.pytorch'
# Offsets, relative to the sampled start slot t, of the input timeslots fed to
# the predictor; the target is the slot at offset 2 * dT - 1.
input_offsets = (0, 5, 8, 10, 11)
num_epochs = 2
log_every = 100  # refresh the progress line every N iterations rather than on every one


def _sample_batch(data):
    """Draw one random training batch from ``data`` and move it to ``device``.

    Returns ``(xc, xt, yc)``:
      - ``xc``: ``data`` at timeslots ``t + input_offsets`` for the sampled rows,
      - ``xt``: those timeslot indices broadcast to ``xc``'s shape and dtype,
      - ``yc``: ``data`` at timeslot ``t + 2 * dT - 1`` for the same rows.
    Rows are drawn without replacement; the batch is clamped to the number of rows.
    """
    # Keep every indexed slot (inputs and target) inside [0, num_timeslots).
    horizon = max(max(input_offsets), 2 * dT - 1)
    t = np.random.randint(num_timeslots - horizon)
    time_indices = [t + offset for offset in input_offsets]

    n_rows = data.shape[0]
    batch_indices = torch.LongTensor(np.random.choice(n_rows, min(batch_size, n_rows), replace=False))
    xc = data[batch_indices][:, time_indices].to(device)
    xt = torch.zeros_like(xc) + torch.LongTensor(time_indices).to(device)
    yc = data[batch_indices, t + 2 * dT - 1].to(device)
    return xc, xt, yc


def _save_atomically(module, path):
    """Save ``module`` to ``path`` through a temp file plus ``os.replace``.

    The checkpoint being written is the one that was loaded, so a save interrupted
    halfway would otherwise leave a truncated, unloadable model behind.
    """
    tmp_path = path + '.tmp'
    torch.save(module, tmp_path)
    os.replace(tmp_path, path)


def _finetune_day(d):
    """Fine-tune the day-``d`` predictor on day ``d``'s data and overwrite its checkpoint.

    Running each day inside a function lets the previous day's data, model and
    RMSprop state be freed before the next day is loaded.
    """
    data = load_cluster_data(data_path_template.format(d))
    model_path = model_path_template.format(d)

    # NOTE(review): torch.load unpickles the whole module (model.GRUPredictor must
    # be importable); newer PyTorch versions may require weights_only=False — confirm.
    predictor = torch.load(model_path).to(device)
    # The pickled module keeps whatever train/eval flag it was saved with.
    predictor.train()
    optimizer = torch.optim.RMSprop(predictor.parameters(), lr=2e-5)
    # Halve the learning rate after every epoch.
    optimizer_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)

    for epoch in range(1, 1 + num_epochs):
        avg_loss = 0.0
        for i in range(1, 1 + max_iter):
            xc, xt, yc = _sample_batch(data)

            loss = predictor(xc, xt, yc)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            avg_loss += loss.item()
            if i % log_every == 0 or i == max_iter:
                print('Epoch {:03d}, {:.1f}%, avg_loss = {:.4f}'.format(epoch, i / max_iter * 100, avg_loss / i), end='\r')

        print()
        optimizer_scheduler.step()

    _save_atomically(predictor.cpu(), model_path)


for d in range(2, 31):
    _finetune_day(d)

/data/fan/data/sadCrowdPrediction/cluster/cluster_20120602.pk
Epoch 001, 100.0%, avg_loss = 1.2664
Epoch 002, 100.0%, avg_loss = 1.2602
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120603.pk
Epoch 001, 100.0%, avg_loss = 1.1248
Epoch 002, 100.0%, avg_loss = 1.1299
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120604.pk
Epoch 001, 100.0%, avg_loss = 1.3006
Epoch 002, 100.0%, avg_loss = 1.2999
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120605.pk
Epoch 001, 100.0%, avg_loss = 1.3216
Epoch 002, 100.0%, avg_loss = 1.3177
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120606.pk
Epoch 001, 100.0%, avg_loss = 1.3064
Epoch 002, 100.0%, avg_loss = 1.3076
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120607.pk
Epoch 001, 100.0%, avg_loss = 1.3328
Epoch 002, 100.0%, avg_loss = 1.3314
/data/fan/data/sadCrowdPrediction/cluster/cluster_20120608.pk
Epoch 001, 100.0%, avg_loss = 1.3555
Epoch 002, 100.0%, avg_loss = 1.3592
/data/fan/data/sadCrowdPrediction/cluster/cluste

KeyboardInterrupt: 