In [33]:
import torch
import torch.nn as nn


In [34]:
class Encoder(nn.Module):
    def __init__(self, n_features, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_features = n_features
        self.hidden = None
        self.gru = nn.GRU(n_features, hidden_dim, batch_first=True)

    def forward(self, x):
        gru_out, self.hidden = self.gru(x)
        return gru_out


In [35]:
full_seq = (torch.tensor([[-1, -1], [-1, 1], [1, 1], [1, -1]])
            .float()
            .view(1, 4, 2))
source_seq = full_seq[:, :2]  # first two corners
target_seq = full_seq[:, 2:]  # last two corners


In [36]:
source_seq


tensor([[[-1., -1.],
         [-1.,  1.]]])

In [37]:
target_seq


tensor([[[ 1.,  1.],
         [ 1., -1.]]])

In [38]:
torch.manual_seed(21)
encoder = Encoder(n_features=2, hidden_dim=2)
hidden_seq = encoder(source_seq)


In [39]:
hidden_seq


tensor([[[ 0.0832, -0.0356],
         [ 0.3105, -0.5263]]], grad_fn=<TransposeBackward1>)

In [40]:
hidden_final = hidden_seq[:, -1:]
hidden_final


tensor([[[ 0.3105, -0.5263]]], grad_fn=<SliceBackward0>)

### Decoder

In [41]:


class Decoder(nn.Module):
    def __init__(self, n_features, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_features = n_features
        self.gru = nn.GRU(n_features, hidden_dim, batch_first=True)
        self.regreesion = nn.Linear(hidden_dim, n_features)

    def init_hidden(self, hidden_seq):
        self.hidden = hidden_seq[:, -1:].permute(1, 0, 2)

    def forward(self, X):
        batch_first_output, self.hidden = self.gru(X, self.hidden)
        last_output = batch_first_output[:, -1]
        out = self.regreesion(last_output)
        return out.view(-1, 1, self.n_features)


In [42]:
torch.manual_seed(21)
decoder = Decoder(n_features=2, hidden_dim=2)
decoder.init_hidden(hidden_seq)
input = source_seq[:, -1:]

target_len = 2
for i in range(target_len):
    print(f'Hidden: {decoder.hidden}')
    out = decoder(input)
    print(f'Output: {out}\n')
    input = out


Hidden: tensor([[[ 0.3105, -0.5263]]], grad_fn=<PermuteBackward0>)
Output: tensor([[[-0.2339,  0.4702]]], grad_fn=<ViewBackward0>)

Hidden: tensor([[[ 0.3913, -0.6853]]], grad_fn=<StackBackward0>)
Output: tensor([[[-0.0226,  0.4628]]], grad_fn=<ViewBackward0>)



### Teacher Forcing

In [43]:
decoder.init_hidden(hidden_seq)
inputs = source_seq[:, -1:]
target_len = 2

for i in range(target_len):
    print(f'Hidden: {decoder.hidden}')
    out = decoder(inputs)
    print(f'Output: {out}\n')
    inputs = target_seq[:, i:i+1]


Hidden: tensor([[[ 0.3105, -0.5263]]], grad_fn=<PermuteBackward0>)
Output: tensor([[[-0.2339,  0.4702]]], grad_fn=<ViewBackward0>)

Hidden: tensor([[[ 0.3913, -0.6853]]], grad_fn=<StackBackward0>)
Output: tensor([[[0.2265, 0.4529]]], grad_fn=<ViewBackward0>)



### Random use predict or teaching forcing

In [44]:
# Initial hidden state will be encoder's final hidden state
decoder.init_hidden(hidden_seq)
# Initial data point is the last element of source sequence
inputs = source_seq[:, -1:]
teacher_forcing_prob = 0.5
target_len = 2
for i in range(target_len):
    print(f'Hidden: {decoder.hidden}')
    out = decoder(inputs)
    print(f'Output: {out}\n')
    # If it is teacher forcing
    if torch.rand(1) <= teacher_forcing_prob:
        # Takes the actual element
        inputs = target_seq[:, i:i+1]
    else:
        # Otherwise uses the last predicted output
        inputs = out


Hidden: tensor([[[ 0.3105, -0.5263]]], grad_fn=<PermuteBackward0>)
Output: tensor([[[-0.2339,  0.4702]]], grad_fn=<ViewBackward0>)

Hidden: tensor([[[ 0.3913, -0.6853]]], grad_fn=<StackBackward0>)
Output: tensor([[[-0.0226,  0.4628]]], grad_fn=<ViewBackward0>)



### Encoder + Decoder

In [45]:
class Encoder_Decoder(nn.Module):
    def __init__(self, encoder, decoder, input_len, target_len, teacher_forcing_prob=0.5):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder  
        self.input_len = input_len
        self.target_len = target_len
        self.teacher_forcing_prob = teacher_forcing_prob
        self.outputs = None
    def init_outputs(self, batch_size):
        device = next(self.parameters()).device
        self.outputs[:, i:i+1] = torch.zeros(batch_size, self.target_len, self.encoder.n_features).to(device)
        
    def store_output(self, i, out):
        # Stores the output
        self.outputs[:, i:i+1, :] = out
        
    def forward(self, X):
        source_seq = X[:, :self.input_len, :]
        target_seq = X[:, self.input_len:, :]
        self.init_outputs(X.shape[0])
        
        hidden_seq = self.encoder(source_seq)

        self.decoder.init_hidden(hidden_seq)
        
        dec_inputs = source_seq[:, -1:]
        
        for i in range(self.target_len):
            out = self.decoder(dec_inputs)
            self.store_output(i, out)
            prob = self.teacher_forcing_prob
            if not self.training:
                prob = 0
            if torch.rand(1) <= prob:
                dec_inputs = target_seq[:, i:i+1, :]
            else:
                dec_inputs = out
        return self.outputs

In [46]:
enc_dec = Encoder_Decoder(encoder, decoder, input_len=2, target_len=2, teacher_forcing_prob=0.5)


In [47]:
enc_dec.train()

Encoder_Decoder(
  (encoder): Encoder(
    (gru): GRU(2, 2, batch_first=True)
  )
  (decoder): Decoder(
    (gru): GRU(2, 2, batch_first=True)
    (regreesion): Linear(in_features=2, out_features=2, bias=True)
  )
)

In [48]:
enc_dec(full_seq)

TypeError: 'NoneType' object does not support item assignment