In [1]:
import torch
import pandas
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
# Seed torch's global RNG so the random draws made further down are repeatable.
torch.manual_seed(0)
# NOTE(review): replaces the `locks` attribute of the object returned by tqdm.get_lock()
# with an empty list -- presumably a workaround for a progress-bar locking problem in
# this environment; confirm it is still required.
tqdm.get_lock().locks = []

As a warm-up exercise, let's confirm that a quadratic-sized recurrent neural network is capable of reversing a simple sequence.
For simplicity, the sequence is one-hot encoded and passed through a network with 2 recurrent layers of $seq\_length^2$ neurons each, followed by a linear output layer that predicts the next sequence item.

TODO:
- [ ] enable teacher forcing randomization
- [ ] perform validation during the training procedure
- [ ] improve the code quality
- [ ] work in batches
- [ ] check border conditions
- [ ] improve progress reporting
- [ ] try tensorboard output
- [ ] implement sequence padding to enable variable sequence length capability in the model

In [2]:
# Dataset configuration.
SEQ_LENGTH = 10
SAMPLES = 4000
VOCAB_SIZE = 6  # payload symbols plus the SOS and EOS marker ids

# Payload ids are drawn from [0, VOCAB_SIZE - 2); the two highest ids are never sampled.
src_sequences = torch.randint(low=0, high=VOCAB_SIZE - 2, size=(SAMPLES, SEQ_LENGTH))
# Training target: every source row read back to front.
reversed_sequences = torch.flip(src_sequences, dims=(1,))
# Preview the first ten rows of each tensor.
print(src_sequences[:10], reversed_sequences[:10], sep="\n")

tensor([[0, 3, 1, 0, 3, 3, 3, 3, 1, 3],
        [1, 2, 0, 3, 2, 0, 0, 0, 2, 1],
        [2, 3, 3, 2, 0, 1, 1, 1, 1, 0],
        [1, 0, 3, 0, 3, 1, 2, 3, 3, 0],
        [2, 3, 0, 1, 3, 1, 3, 3, 2, 3],
        [0, 1, 1, 1, 3, 0, 3, 2, 0, 3],
        [3, 2, 3, 2, 3, 0, 2, 0, 0, 0],
        [1, 1, 2, 0, 0, 1, 3, 0, 1, 2],
        [2, 3, 0, 1, 1, 3, 1, 1, 3, 2],
        [3, 3, 2, 2, 3, 0, 2, 3, 1, 0]])
tensor([[3, 1, 3, 3, 3, 3, 0, 1, 3, 0],
        [1, 2, 0, 0, 0, 2, 3, 0, 2, 1],
        [0, 1, 1, 1, 1, 0, 2, 3, 3, 2],
        [0, 3, 3, 2, 1, 3, 0, 3, 0, 1],
        [3, 2, 3, 3, 1, 3, 1, 0, 3, 2],
        [3, 0, 2, 3, 0, 3, 1, 1, 1, 0],
        [0, 0, 0, 2, 0, 3, 2, 3, 2, 3],
        [2, 1, 0, 3, 1, 0, 0, 2, 1, 1],
        [2, 3, 1, 1, 3, 1, 1, 0, 3, 2],
        [0, 1, 3, 2, 0, 3, 2, 2, 3, 3]])


In [3]:
# One-hot encode every token id: (SAMPLES, SEQ_LENGTH) ints -> (SAMPLES, SEQ_LENGTH, VOCAB_SIZE) floats.
src_sequences_one_hot = F.one_hot(src_sequences, num_classes=VOCAB_SIZE).float()

In [4]:
class ReverseEncoder(nn.Module):
    """Recurrent encoder that summarises a one-hot sequence in its final hidden state.

    The network is a stack of ``rec_layers_count`` tanh RNN layers, each with
    ``seq_length ** 2`` hidden units, consuming ``vocab_size``-wide inputs in
    batch-first layout.

    Args:
        seq_length: Sequence length used to size the hidden layers.
        vocab_size: Width of each one-hot input vector.
        rec_layers_count: Number of stacked recurrent layers.
    """

    def __init__(self, seq_length, vocab_size, rec_layers_count):
        super(ReverseEncoder, self).__init__()
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.rec_layers_count = rec_layers_count
        self.rec_layers = nn.RNN(
            input_size=vocab_size,
            hidden_size=seq_length**2,
            nonlinearity="tanh",
            num_layers=rec_layers_count,
            batch_first=True,
        )
        # Sampled once, after the RNN weights, and reused as the starting state of
        # every forward pass. It is a plain tensor, not a parameter, so it is not trained.
        self.hidden_state = self.init_hidden_state()

    def init_hidden_state(self):
        """Return a random initial state of shape (rec_layers_count, 1, seq_length ** 2)."""
        return torch.randn((self.rec_layers_count, 1, self.seq_length**2))

    def forward(self, input_sequence):
        """Encode a batch of sequences.

        Args:
            input_sequence: Tensor laid out as (batch, steps, vocab_size).

        Returns:
            A ``(post_recurrent, hidden)`` pair as produced by ``nn.RNN``: the
            per-step outputs of the top layer and the final hidden state of
            every layer.
        """
        # The stored state has a batch dimension of 1; broadcast it so that batches
        # of any size start from the same initial state instead of raising a
        # hidden-size mismatch inside nn.RNN.
        batch_size = input_sequence.size(0)
        initial_hidden = self.hidden_state.expand(-1, batch_size, -1).contiguous()
        post_recurrent, hidden = self.rec_layers(input_sequence, initial_hidden)

        return post_recurrent, hidden
    
class ReverseDecoder(nn.Module):
    """Recurrent decoder that maps inputs plus a hidden state to token log-probabilities.

    A stack of ``rec_layers_count`` tanh RNN layers with ``seq_length ** 2``
    hidden units each feeds a linear layer that projects back to ``vocab_size``.

    Args:
        seq_length: Sequence length used to size the hidden layers.
        vocab_size: Width of each one-hot input vector and of the output distribution.
        rec_layers_count: Number of stacked recurrent layers.
    """

    def __init__(self, seq_length, vocab_size, rec_layers_count):
        super().__init__()
        hidden_units = seq_length ** 2

        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.rec_layers_count = rec_layers_count

        # The recurrent stack is built before the projection so that parameter
        # initialisation consumes the RNG in the same order every time.
        self.rec_layers = nn.RNN(
            input_size=vocab_size,
            hidden_size=hidden_units,
            num_layers=rec_layers_count,
            nonlinearity="tanh",
            batch_first=True,
        )
        self.output = nn.Linear(hidden_units, vocab_size)

    def forward(self, input_sequence, hidden_state):
        """Run the decoder over ``input_sequence`` starting from ``hidden_state``.

        Returns:
            ``(item_probs, hidden)``: log-softmax scores over the vocabulary
            (normalised along dimension 2) and the updated hidden state.
        """
        recurrent_out, next_hidden = self.rec_layers(input_sequence, hidden_state)
        vocab_scores = self.output(recurrent_out)
        return F.log_softmax(vocab_scores, dim=2), next_hidden

In [5]:
# Encoder and decoder share the same dimensions; each stacks two recurrent layers.
enc_model = ReverseEncoder(seq_length=SEQ_LENGTH, vocab_size=VOCAB_SIZE, rec_layers_count=2)
dec_model = ReverseDecoder(seq_length=SEQ_LENGTH, vocab_size=VOCAB_SIZE, rec_layers_count=2)

In [6]:
# Show the layer layout of both networks, one module summary after the other.
print(enc_model, dec_model, sep="\n")

ReverseEncoder(
  (rec_layers): RNN(6, 100, num_layers=2, batch_first=True)
)
ReverseDecoder(
  (rec_layers): RNN(6, 100, num_layers=2, batch_first=True)
  (output): Linear(in_features=100, out_features=6, bias=True)
)


In [7]:
# The decoder emits log-probabilities, so negative log-likelihood is applied directly.
loss_function = nn.NLLLoss()
# A single Adam instance updates the encoder and decoder weights together.
optimizer = torch.optim.Adam(
    [*enc_model.parameters(), *dec_model.parameters()],
    lr=0.0005,
)

In [8]:
# The two highest vocabulary ids are reserved as sequence markers.
SOS = torch.tensor(VOCAB_SIZE - 2)  # start-of-sequence id
EOS = torch.tensor(VOCAB_SIZE - 1)  # end-of-sequence id
# One-hot float vectors for the markers, in the same encoding as the input tokens.
SOS_filler = F.one_hot(SOS, num_classes=VOCAB_SIZE).float()
EOS_filler = F.one_hot(EOS, num_classes=VOCAB_SIZE).float()

In [9]:
def chunks(l, n):
    """Lazily split ``l`` into consecutive slices of at most ``n`` items.

    The final slice is shorter when ``len(l)`` is not a multiple of ``n``.
    """
    start_offsets = range(0, len(l), n)
    yield from (l[offset:offset + n] for offset in start_offsets)

In [10]:
# Training loop: accumulate the loss over one chunk of samples, then take one optimizer step.
epochs = 10
BATCH_SIZE = 64
# One entry per optimizer step: the loss summed over every decoder step of every sample in the chunk.
losses = []
# Running count of decoder steps processed across all epochs.
i = 0
for epoch in range(epochs):
    # Pair each one-hot source row with its integer target row and slice the pairs into
    # chunks of at most BATCH_SIZE; tqdm wraps the chunk list to report per-epoch progress.
    with tqdm(list(chunks(list(zip(src_sequences_one_hot, reversed_sequences)), BATCH_SIZE))) as cit:
        for chunk in cit:
            loss = 0
            # Samples still go through the networks one at a time; the chunk only decides how
            # many per-sample losses are summed before a single backward pass.
            for (sequence, sequence_y) in chunk:
                # Add a leading batch dimension of 1 for the batch-first encoder.
                X, y = sequence.unsqueeze(0), sequence_y
                # Only the returned hidden state is used below; the per-step outputs are ignored.
                single_batch_result_out, hidden = enc_model(X)
                # Decoder inputs: the SOS vector followed by the one-hot source tokens.
                # NOTE(review): the decoder is fed the *source* tokens rather than the previous
                # target tokens -- confirm this is intended before adding teacher forcing.
                Xss = torch.cat((SOS_filler.unsqueeze(0), X.squeeze(0)), dim=0)
                # Decoder targets: the reversed token ids followed by EOS, one id per row.
                yss = torch.cat((y, EOS.unsqueeze(0)), dim=0).unsqueeze(1)
                for X_char, y_char in zip(Xss, yss):
                    # Decode a single step, carrying the hidden state forward to the next step.
                    y_pred, hidden = dec_model(X_char.unsqueeze(0).unsqueeze(0), hidden)
                    cur_loss = loss_function(y_pred.squeeze(0), y_char)
                    loss += cur_loss
                    i += 1
            # Clear stale gradients on both networks before back-propagating the chunk loss.
            enc_model.zero_grad()
            dec_model.zero_grad()
            loss.backward()
            losses.append(loss.tolist())
            optimizer.step()
            # mean_loss averages over every step since training started (losses is never
            # reset between epochs); last_loss is the summed loss of the chunk just processed.
            cit.set_postfix({
                'epoch': epoch, 
                'mean_loss': sum(losses)/len(losses),
                'last_loss': losses[-1]
            })

    
    
   

100%|██████████| 63/63 [00:35<00:00,  2.08it/s, epoch=0, mean_loss=1.03e+3, last_loss=447]    
100%|██████████| 63/63 [00:35<00:00,  2.12it/s, epoch=1, mean_loss=938, last_loss=395]    
100%|██████████| 63/63 [00:34<00:00,  2.09it/s, epoch=2, mean_loss=857, last_loss=305]
100%|██████████| 63/63 [00:34<00:00,  2.12it/s, epoch=3, mean_loss=769, last_loss=195]
100%|██████████| 63/63 [00:37<00:00,  2.11it/s, epoch=4, mean_loss=675, last_loss=105]
100%|██████████| 63/63 [00:33<00:00,  2.24it/s, epoch=5, mean_loss=584, last_loss=47.6]
100%|██████████| 63/63 [00:33<00:00,  2.24it/s, epoch=6, mean_loss=508, last_loss=20.5]
100%|██████████| 63/63 [00:32<00:00,  2.22it/s, epoch=7, mean_loss=448, last_loss=10.6]
100%|██████████| 63/63 [00:32<00:00,  2.25it/s, epoch=8, mean_loss=400, last_loss=6.9] 
100%|██████████| 63/63 [00:32<00:00,  2.21it/s, epoch=9, mean_loss=361, last_loss=4.94]


In [11]:
def reverse_sequence(seq, model):
    """Predict the reversed form of ``seq`` with an encoder/decoder pair.

    Args:
        seq: Integer tensor of token ids laid out as (1, steps).
        model: Indexable pair where ``model[0]`` is the encoder and
            ``model[1]`` is the decoder.

    Returns:
        A list with one predicted token id per input step. The prediction made
        for the last decoder input is dropped.
    """
    encoder, decoder = model[0], model[1]
    # Size the one-hot buffer from the input itself. Sizing it from the global
    # SEQ_LENGTH silently padded shorter inputs with all-zero rows and produced
    # extra predictions.
    steps = seq.size(1)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        seq_one_hot = torch.zeros((1, steps, VOCAB_SIZE)).scatter(2, seq.unsqueeze(2), 1.0)
        _, hidden = encoder(seq_one_hot)
        # Decoder inputs mirror training: the SOS vector followed by the source tokens.
        decoder_inputs = torch.cat((SOS_filler.unsqueeze(0), seq_one_hot.squeeze(0)), dim=0)
        result = []
        for step_input in decoder_inputs:
            out, hidden = decoder(step_input.unsqueeze(0).unsqueeze(0), hidden)
            result.append(torch.argmax(out, dim=2).squeeze(0).squeeze(0).tolist())
    return result[:-1]

In [12]:
# Hand-written smoke tests; each case is a single-row batch of token ids.
test_cases = [
    torch.tensor([token_ids])
    for token_ids in (
        [3, 2, 1, 0, 1, 2, 3, 2, 1, 0],
        [3, 2, 1, 3, 2, 1, 3, 2, 1, 0],
        [3, 2, 1, 0, 0, 0, 3, 2, 1, 0],
    )
]
# Print the model's prediction next to the input it was asked to reverse.
for test_case in test_cases:
    print(reverse_sequence(test_case, (enc_model, dec_model)), test_case)


[0, 1, 2, 3, 2, 1, 0, 1, 2, 3] tensor([[3, 2, 1, 0, 1, 2, 3, 2, 1, 0]])
[0, 1, 2, 3, 1, 2, 3, 1, 2, 3] tensor([[3, 2, 1, 3, 2, 1, 3, 2, 1, 0]])
[0, 1, 2, 3, 0, 0, 0, 1, 2, 3] tensor([[3, 2, 1, 0, 0, 0, 3, 2, 1, 0]])
