# RNN-based Encoder-Decoder实现属性槽预测

In [1]:
import torch
import torch.nn as nn

## 模型构建

In [2]:
class Encoder(nn.Module):
    """Token-id encoder: an embedding table followed by a batch-first LSTM.

    Args:
        vocab_size: number of rows in the embedding table (ids must be below it).
        embed_dim: embedding vector size, used as the LSTM input size.
        hidden_size: LSTM hidden state size.
        num_layers: number of stacked LSTM layers.

    ``forward(x)`` embeds ``x`` (assumed to be token ids of shape
    (batch, seq_len), given ``batch_first=True``) and returns the top-layer
    LSTM output for every time step, shape (batch, seq_len, hidden_size).
    The final ``(h, c)`` state is discarded.
    """

    def __init__(self, vocab_size, embed_dim, hidden_size, num_layers):
        super().__init__()
        # NOTE: "embeding" (sic) is kept on purpose: the attribute name is the
        # state_dict key prefix ("embeding.weight").
        self.embeding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, x):
        embedded = self.embeding(x)
        outputs, _final_state = self.rnn(embedded)
        return outputs

In [35]:
# Toy batch of two token-id sequences: row 0 is used as the source and row 1
# as the target. Every id is < 10 so it fits an Embedding with vocab_size=10.
# Both rows begin with 0 and end with 1 -- presumably start/end markers
# (assumption; nothing in this notebook enforces it).
full_seq = torch.tensor([[0, 2, 3, 4, 5, 1], [0, 6, 7, 8, 9, 1]])

In [36]:
# Display the token-id batch as this cell's output.
full_seq

tensor([[0, 2, 3, 4, 5, 1],
        [0, 6, 7, 8, 9, 1]])

In [37]:
# Split the two-row batch: row 0 becomes the source, row 1 the target.
# Slicing (not integer indexing) keeps the batch dimension, so each result
# has shape (1, seq_len).
source_seq, target_seq = full_seq[:1], full_seq[1:]

In [38]:
# Display the (source, target) pair as this cell's output.
source_seq, target_seq

(tensor([[0, 2, 3, 4, 5, 1]]), tensor([[0, 6, 7, 8, 9, 1]]))

In [40]:
# Encode the source batch with gradient tracking disabled.
with torch.no_grad():
    encoder = Encoder(vocab_size=10, embed_dim=10, hidden_size=10, num_layers=1)
    # Per-step encoder outputs: (batch, seq_len, hidden_size).
    hidden_seq = encoder(source_seq)
    # Last time step only, keeping the time axis: (batch, 1, hidden_size).
    hidden_final = hidden_seq[:, -1:, :]

In [41]:
# Display the per-step encoder outputs, shape (batch, seq_len, hidden_size).
hidden_seq

tensor([[[-0.2476,  0.0519, -0.0366,  0.0613,  0.0198,  0.0061,  0.1407,
           0.0809,  0.0987,  0.1103],
         [-0.1229,  0.1263, -0.1942,  0.3224,  0.0229,  0.0441,  0.0986,
           0.1420,  0.0580, -0.2052],
         [-0.1190,  0.1428, -0.1500,  0.1121, -0.0704,  0.1114,  0.1376,
          -0.0477,  0.1730, -0.2858],
         [-0.1881,  0.1444, -0.0652,  0.3063, -0.1572, -0.0243,  0.2957,
          -0.0518,  0.0281, -0.0133],
         [-0.0846,  0.0397, -0.0037,  0.0371, -0.1898,  0.0167,  0.3401,
           0.1125, -0.0787, -0.1691],
         [-0.0833,  0.0933, -0.0885,  0.2465,  0.0876, -0.1497,  0.0993,
           0.0630,  0.0852, -0.0574]]])

In [42]:
# Display the last-step encoder output, kept as (batch, 1, hidden_size).
hidden_final

tensor([[[-0.0833,  0.0933, -0.0885,  0.2465,  0.0876, -0.1497,  0.0993,
           0.0630,  0.0852, -0.0574]]])

In [43]:
class Decoder(nn.Module):
    """Single-layer LSTM decoder with a linear read-out back to ``n_features``.

    The recurrent state is stored on the module (``self.hidden`` and
    ``self.cell``), so successive ``forward`` calls continue from where the
    previous call stopped.

    Args:
        n_features: size of each input step and of the predicted output.
        hidden_size: LSTM hidden size. It must equal the last dimension of the
            encoder outputs passed to ``init_hidden``.
    """

    def __init__(self, n_features, hidden_size):
        super().__init__()
        self.hidden = None  # h state, shape (1, batch, hidden_size) once set
        self.cell = None  # c state, same shape as self.hidden once set
        # batch_first=True matches the Encoder and the ``output[:, -1]``
        # indexing in forward(), both of which treat dim 0 as the batch.
        self.rnn = nn.LSTM(n_features, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, n_features)

    def init_hidden(self, hidden_seq):
        """Seed the decoder state from encoder outputs.

        Args:
            hidden_seq: encoder outputs, (batch, seq_len, hidden_size).

        The last time step becomes the initial hidden state, reshaped to the
        (num_layers=1, batch, hidden_size) layout that nn.LSTM expects for
        ``h_0``. No encoder cell state is passed in, so the cell state is
        reset to zeros. This also prevents stale state from a previous
        sequence being carried over.
        """
        hidden_final = hidden_seq[:, -1:]  # (batch, 1, hidden_size)
        # The permuted view is non-contiguous. cuDNN rejects non-contiguous hx.
        self.hidden = hidden_final.permute(1, 0, 2).contiguous()
        self.cell = torch.zeros_like(self.hidden)

    def forward(self, x):
        """Run the LSTM over ``x`` and predict features from its last step.

        Args:
            x: input of shape (batch, seq_len, n_features).

        Returns:
            Tensor of shape (batch, n_features).
        """
        if self.hidden is None:
            state = None  # nn.LSTM then starts from a zero state
        else:
            cell = self.cell if self.cell is not None else torch.zeros_like(self.hidden)
            state = (self.hidden, cell)
        # nn.LSTM takes (input, (h_0, c_0)) and returns (output, (h_n, c_n)).
        output, (self.hidden, self.cell) = self.rnn(x, state)
        last_output = output[:, -1]  # (batch, hidden_size)
        return self.linear(last_output)

In [65]:
# Build the decoder and seed its hidden state from the encoder outputs, with
# gradient tracking disabled. Then show the seeded state.
with torch.no_grad():
    decoder = Decoder(hidden_size=10, n_features=10)
    decoder.init_hidden(hidden_seq)
    print(decoder.hidden)
    
    

tensor([[[-0.0833,  0.0933, -0.0885,  0.2465,  0.0876, -0.1497,  0.0993,
           0.0630,  0.0852, -0.0574]]])
