In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable

In [2]:
vocab_size = 5000     # 词汇表大小
embedding_size = 64   # 词向量维度
embedding = nn.Embedding(vocab_size, embedding_size)

data = [[1, 2, 4, 6, 3],  [4, 6, 7, 3, 2]]
inputs = Variable(torch.LongTensor(data))  # 输入
print('Input size:', inputs.size())

embedding_output = embedding(inputs)    
print('Output size:', embedding_output.size())

Input size: torch.Size([2, 5])
Output size: torch.Size([2, 5, 64])


In [9]:
hidden_size = 128   # RNN隐藏层维度
seq_len = 5         # 序列长度
batch_size = 2 

rnn = nn.RNN(embedding_size, hidden_size)      # 单层RNN
rnn_inputs = embedding_output.transpose(0, 1)  # 前2个维度转置，变为 (seq_len, batch, embedding_size)
hidden = Variable(torch.randn(1, batch_size, hidden_size))      # 初始化hidden
rnn_outputs, hidden = rnn(rnn_inputs, hidden)   # 输出为 (seq_len, batch, hidden_size)
print('RNN output size:', rnn_outputs.size())

RNN output size: torch.Size([5, 2, 128])


In [15]:
class RNNModel(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size):
        super(RNNModel, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.rnn = nn.RNN(embedding_size, hidden_size)
        
    def forward(self, x, hidden):
        x = self.embedding(x)     # 词嵌入, (batch, seq_len, embedding_size)
        x = x.transpose(0, 1)      # 转置, (seq_len, batch, embedding_size)
        output, hidden = self.rnn(x, hidden)   # rnn计算
        return output, hidden
    
    def init_hidden(self, batch_size, hidden_size):   # 初始化隐藏层
        return Variable(torch.zeros(1, batch_size, hidden_size))

In [16]:
rnn = RNNModel(vocab_size, embedding_size, hidden_size)

In [23]:
rnn = RNNModel(vocab_size, embedding_size, hidden_size)

data = [[1, 2, 4, 6, 3],  [4, 6, 7, 3, 2]]
inputs = Variable(torch.LongTensor(data))   # # 维度为 (batch, seq_len)
hidden = rnn.init_hidden(batch_size, hidden_size)
rnn_outputs, hidden = rnn(inputs, hidden)  # 输出为 (seq_len, batch, hidden_size)
print('RNN output size:', rnn_outputs.size())

RNN output size: torch.Size([5, 2, 128])


In [None]:
embedding_output = embedding(inputs)
embedding_output.size()

In [None]:
inputs_res = inputs.transpose(0, 1)

In [None]:
inputs_res

In [None]:
out_res = embedding(inputs_res)

In [None]:
embedding.weight

In [None]:
out_res

In [None]:
embedding_output

In [None]:
out_res[:, 0]

In [30]:
data = [[1, 2, 4, 6, 3],  [4, 6, 7, 3, 2]]
data_t = torch.LongTensor(data).transpose(0, 1)  # 在网络之外就转置完毕
inputs = Variable(data_t)   # 输入，(seq_len, batch)
print(inputs)

Variable containing:
 1  4
 2  6
 4  7
 6  3
 3  2
[torch.LongTensor of size 5x2]



In [31]:
class RNNModel(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size):
        super(RNNModel, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.rnn = nn.RNN(embedding_size, hidden_size)
        
    def forward(self, x, hidden):
        x = self.embedding(x)     # 词嵌入, (seq_len, batch, embedding_size)
        output, hidden = self.rnn(x, hidden)   # rnn计算
        return output, hidden
    
    def init_hidden(self, batch_size, hidden_size):   # 初始化隐藏层
        return Variable(torch.zeros(1, batch_size, hidden_size))

In [32]:
rnn = RNNModel(vocab_size, embedding_size, hidden_size)

hidden = rnn.init_hidden(batch_size, hidden_size)
rnn_outputs, hidden = rnn(inputs, hidden)  # 输出为 (seq_len, batch, hidden_size)
print('RNN output size:', rnn_outputs.size())

RNN output size: torch.Size([5, 2, 128])


In [34]:
import numpy as np

In [35]:
data = np.arange(60)
data

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53, 54, 55, 56, 57, 58, 59])

In [41]:
batch_size = 4
seq_len = 5
num_batch = len(data) // batch_size // seq_len
data.reshape(num_batch, batch_size, -1)

array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]]])

In [54]:
data.reshape(num_batch, batch_size, -1).swapaxes(1, 2)

array([[[ 0,  5, 10, 15],
        [ 1,  6, 11, 16],
        [ 2,  7, 12, 17],
        [ 3,  8, 13, 18],
        [ 4,  9, 14, 19]],

       [[20, 25, 30, 35],
        [21, 26, 31, 36],
        [22, 27, 32, 37],
        [23, 28, 33, 38],
        [24, 29, 34, 39]],

       [[40, 45, 50, 55],
        [41, 46, 51, 56],
        [42, 47, 52, 57],
        [43, 48, 53, 58],
        [44, 49, 54, 59]]])

In [42]:
target = np.arange(1, 61)
target

array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
       18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
       35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
       52, 53, 54, 55, 56, 57, 58, 59, 60])

In [43]:
batch_size = 4
seq_len = 5
num_batch = len(target) // batch_size // seq_len
target.reshape(num_batch, batch_size, -1)

array([[[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10],
        [11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20]],

       [[21, 22, 23, 24, 25],
        [26, 27, 28, 29, 30],
        [31, 32, 33, 34, 35],
        [36, 37, 38, 39, 40]],

       [[41, 42, 43, 44, 45],
        [46, 47, 48, 49, 50],
        [51, 52, 53, 54, 55],
        [56, 57, 58, 59, 60]]])

In [64]:
target.reshape(num_batch, batch_size, -1).swapaxes(1, 2)

array([[[ 1,  6, 11, 16],
        [ 2,  7, 12, 17],
        [ 3,  8, 13, 18],
        [ 4,  9, 14, 19],
        [ 5, 10, 15, 20]],

       [[21, 26, 31, 36],
        [22, 27, 32, 37],
        [23, 28, 33, 38],
        [24, 29, 34, 39],
        [25, 30, 35, 40]],

       [[41, 46, 51, 56],
        [42, 47, 52, 57],
        [43, 48, 53, 58],
        [44, 49, 54, 59],
        [45, 50, 55, 60]]])

In [66]:
target.reshape(num_batch, batch_size, -1).swapaxes(1, 2).reshape(num_batch, -1)

array([[ 1,  6, 11, 16,  2,  7, 12, 17,  3,  8, 13, 18,  4,  9, 14, 19,  5,
        10, 15, 20],
       [21, 26, 31, 36, 22, 27, 32, 37, 23, 28, 33, 38, 24, 29, 34, 39, 25,
        30, 35, 40],
       [41, 46, 51, 56, 42, 47, 52, 57, 43, 48, 53, 58, 44, 49, 54, 59, 45,
        50, 55, 60]])