In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
from itertools import chain

In [2]:
class AttnEncoder(nn.Module):
    def __init__(self, loc_num, embedding_dim, hidden_dim, n_layers=2):
        super(AttnEncoder, self).__init__()
        self.embedding = nn.Embedding(loc_num, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=n_layers, batch_first=True)
    
    def forward(self, x):
        embedded = self.embedding(x)
        output, hidden = self.gru(embedded, None)
        return output, hidden

In [3]:
class AttnDecoder(nn.Module):
    def __init__(self, loc_num, embedding_dim, hidden_dim, length, n_layers=2):
        super(AttnDecoder, self).__init__()
        self.embedding = nn.Embedding(loc_num, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=n_layers, batch_first=True)
        self.attn = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim * 2, loc_num - 2)
        
    def forward(self, encoder_outputs, x, hidden):
        embedded = self.embedding(x)
        output, _ = self.gru(embedded, hidden)
        weights = self.score(encoder_outputs, output)
        context = torch.bmm(weights, encoder_outputs) # B, de_s, hidden_dim
        return self.out(torch.cat([context, output], dim=2)), weights
        
    def score(self, encoder_h, decoder_h):
        return nn.functional.softmax(torch.bmm(decoder_h, encoder_h.permute(0, 2, 1)), dim=2) # B, de_s, en_s

In [4]:
LOC_NUM = 1443
EMBEDDING_DIM = 64
HIDDEN_DIM = 256
N_LAYERS = 2
pivot = 72
T = 144

encoder = AttnEncoder(LOC_NUM, EMBEDDING_DIM, HIDDEN_DIM, N_LAYERS).cuda()
decoder = AttnDecoder(LOC_NUM, EMBEDDING_DIM, HIDDEN_DIM, T - pivot).cuda()

In [5]:
optimizer = torch.optim.SGD(chain(encoder.parameters(), decoder.parameters()), lr=1e-2, momentum=0.9)
optim_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2, 4, 6, 8], gamma=0.1)

In [6]:
criteria = nn.CrossEntropyLoss(ignore_index=-1)

In [7]:
data = np.concatenate([np.genfromtxt('../data/dis_forautoencoder_2012_dec_tokyo/day_{}.csv'.format(d), delimiter=',', dtype=np.int32)
                       for d in range(1, 32)], axis=0)
data = data[:, 1:]
data_in = data[:, :pivot]
data_out = data[:, pivot:]
data_size = data.shape[0]

In [8]:
batch_size = 64

for epoch in range(1, 11):
    avg_loss = 0.0
    cnt = 0
    np.random.shuffle(data)
    optim_scheduler.step()
    for i in range(0, data_size - batch_size, batch_size):
        optimizer.zero_grad()
        cnt += 1
        x_batch = data_in[i: i + batch_size]
        y_batch = data_out[i: i + batch_size]
        encoder_outputs, hidden = encoder(Variable(torch.LongTensor(x_batch)).cuda())
        
        tf_batch = np.ones([batch_size, 1], dtype=np.int32)
        for t in range(T - pivot - 1):
            if np.random.ranf() < 0.5: # teacher forcing
                tf_batch = np.concatenate([tf_batch, y_batch[:, t + 1: t + 2]], axis=1)
            else:
                pred, _ = decoder(encoder_outputs, Variable(torch.LongTensor(tf_batch)).cuda(), hidden)
                prob = nn.functional.softmax(pred[:, t], dim=1).cpu().data.numpy()
                tf_batch = np.concatenate([tf_batch, np.array(list(map(lambda p:np.random.choice(LOC_NUM - 2, 1, p=p), prob))) + 2], axis=1)

        tf_batch = Variable(torch.LongTensor(tf_batch)).cuda()
        pred, _ = decoder(encoder_outputs, tf_batch, hidden)
        loss = criteria(pred.view(-1, LOC_NUM - 2), torch.clamp(Variable(torch.LongTensor(y_batch - 2)).cuda(), min=-1).view(-1))
        loss.backward()
        avg_loss += float(loss.data[0])
        optimizer.step()
        print('Epoch {:04d}, {:.1f}%, avg_loss={:.4f}'.format(epoch, i * 100 / data_size, avg_loss / cnt), end='\r')
    print('')
    torch.save(encoder, '../results/sadAttenSeq2Seq/attn_encoder_half_tf_sampling')
    torch.save(decoder, '../results/sadAttenSeq2Seq/attn_decoder_half_tf_sampling')

Epoch 0001, 100.0%, avg_loss=2.7738


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch 0002, 100.0%, avg_loss=1.9627
Epoch 0003, 100.0%, avg_loss=1.8711
Epoch 0004, 100.0%, avg_loss=1.8598
Epoch 0005, 100.0%, avg_loss=1.8513
Epoch 0006, 100.0%, avg_loss=1.8491
Epoch 0007, 100.0%, avg_loss=1.8489
Epoch 0008, 62.3%, avg_loss=1.8492

KeyboardInterrupt: 