In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable

import numpy as np
from itertools import chain

from model import Encoder, Decoder

In [2]:
data = np.concatenate([np.genfromtxt('../data/dis_forautoencoder_2012_dec_tokyo/day_{}.csv'.format(d), delimiter=',', dtype=np.int32)
                       for d in range(1, 32)], axis=0)

In [3]:
data = data[:, 1:]

In [4]:
pivot = 72
T = 144
data_in = data[:, :pivot]
data_out = data[:, pivot:]

In [5]:
data_size = data.shape[0]

In [6]:
LOC_NUM = 1443
EMBEDDING_DIM = 64
HIDDEN_DIM = 256
LATENT_DIM = 256
N_LAYERS = 2

encoder = Encoder(LOC_NUM, EMBEDDING_DIM, HIDDEN_DIM, LATENT_DIM, N_LAYERS).cuda()
decoder = Decoder(LOC_NUM, EMBEDDING_DIM, HIDDEN_DIM, LATENT_DIM, N_LAYERS).cuda()

In [7]:
optimizer = torch.optim.SGD(chain(encoder.parameters(), decoder.parameters()), lr=1e-2, momentum=0.9)
optim_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2, 4, 6, 8], gamma=0.1)

In [8]:
criteria = nn.CrossEntropyLoss(ignore_index=-1)

In [9]:
batch_size = 64

for epoch in range(1, 11):
    avg_loss = 0.0
    cnt = 0
    np.random.shuffle(data)
    optim_scheduler.step()
    for i in range(0, data_size - batch_size, batch_size):
        cnt += 1
        x_batch = data_in[i: i + batch_size]
        y_batch = data_out[i: i + batch_size]
        tf_batch = np.ones([batch_size, T - pivot], dtype=np.int32)
        tf_batch[:, 1:] = y_batch[:, :-1]

        x_batch = Variable(torch.LongTensor(x_batch)).cuda()
        y_batch = Variable(torch.LongTensor(y_batch)).cuda()
        tf_batch = Variable(torch.LongTensor(tf_batch)).cuda()

        latent_code = encoder(x_batch)
        pred = decoder(latent_code[:, -1], tf_batch, T - pivot)
        loss = criteria(pred.view(-1, LOC_NUM - 2), torch.clamp(y_batch - 2, min=-1).view(-1))

        optimizer.zero_grad()
        loss.backward()
        avg_loss += float(loss.data[0])
        optimizer.step()
        print('Epoch {:04d}, {:.1f}%, avg_loss={:.4f}'.format(epoch, i * 100 / data_size, avg_loss / cnt), end='\r')
    print('')
    torch.save(encoder, '../results/sadAttenSeq2Seq/encoder')
    torch.save(decoder, '../results/sadAttenSeq2Seq/decoder')

Epoch 0001, 100.0%, avg_loss=2.6239
Epoch 0002, 100.0%, avg_loss=1.8865
Epoch 0003, 100.0%, avg_loss=1.8000
Epoch 0004, 100.0%, avg_loss=1.7898
Epoch 0005, 100.0%, avg_loss=1.7832
Epoch 0006, 100.0%, avg_loss=1.7822
Epoch 0007, 100.0%, avg_loss=1.7816
Epoch 0008, 100.0%, avg_loss=1.7815
Epoch 0009, 100.0%, avg_loss=1.7813
Epoch 0010, 100.0%, avg_loss=1.7814
