# Seq2Seq - Encoder/Decoder networks
In this exercise we'll take a deeper look at using multiple RNNs to infer and generate sequences of data.
Specifically, we will implement an RNN-based encoder-decoder for a simple sequence-to-sequence translation task.
This type of model has shown impressive performance in neural machine translation and image caption generation.

In the encoder-decoder structure one RNN (blue) encodes the input into a hidden representation, and a second RNN (red) uses this representation to predict the target values.
An essential step is deciding how the encoder and decoder should communicate.
In the simplest approach you use the last hidden state of the encoder to initialize the decoder.
This is what we will do in this notebook, as shown here:

![](./images/enc-dec.png)

In this exercise we will translate the spelled-out words of a number (e.g. 'nine') into the number itself (e.g. '9').
The input to the Encoder RNN is the sequence of words describing the number, and the encoder's final hidden state is used to initialize the Decoder RNN, which aims to generate the number.
Our dataset is generated synthetically and consists of numbers and an End-of-Sentence (EOS) character ('#'). The data we want to generate should look as follows:

```
Examples: 
prediction  |  input
991136#00 	 nine nine one one three six
81771#000 	 eight one seven seven one
3519614#0 	 three five one nine six one four
26656#000 	 two six six five six
60344#000 	 six zero three four four
162885#00 	 one six two eight eight five
78612625# 	 seven eight six one two six two five
9464710#0 	 nine four six four seven one zero
191306#00 	 one nine one three zero six
10160378# 	 one zero one zero six three seven eight
```
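
A minimal sketch of how the prediction column above relates to the input column, assuming `max_len=8` plays the role of `MAX_SEQ_LEN` and that padding is index 0 (which is why it shows up as the digit '0'). `WORD_TO_DIGIT` and `words_to_target` are hypothetical helpers for illustration, not part of `data_generator`.

```python
# Hypothetical helper: reproduces the target format of the examples above
WORD_TO_DIGIT = {
    "zero": "0", "one": "1", "two": "2", "three": "3", "four": "4",
    "five": "5", "six": "6", "seven": "7", "eight": "8", "nine": "9",
}


def words_to_target(text, max_len=8):
    """Turn spelled-out digits into a digit string, append '#', and pad to max_len + 1."""
    digits = "".join(WORD_TO_DIGIT[word] for word in text.split())
    target = digits + "#"  # '#' marks the end of the number
    # Padding uses index 0, which numbers_to_text later prints as the digit '0'
    return target + "0" * (max_len + 1 - len(target))


print(words_to_target("nine nine one one three six"))           # 991136#00
print(words_to_target("seven eight six one two six two five"))  # 78612625#
```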

Let us define the space of characters and numbers to be learned with the networks:

```
Number of valid characters: 27
'0'=0,	'1'=1,	'2'=2,	'3'=3,	'4'=4,	'5'=5,	'6'=6,	'7'=7,	'8'=8,	'9'=9,	'#'=10,	' '=11,	'e'=12,	'g'=13,	'f'=14,	'i'=15,	'h'=16,	'o'=17,	'n'=18,	's'=19,	'r'=20,	'u'=21,	't'=22,	'w'=23,	'v'=24,	'x'=25,	'z'=26,	
Stop/start character = #
```

Every character, digits included, is mapped to an integer from 0 to 26, so the space of valid characters has 27 elements.
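
The mapping can be reproduced with a plain dictionary. This is a sketch under the assumption that the index order is exactly the one printed above (digits, `'#'`, space, then the letters in that printed order); `VOCAB`, `encode` and `decode` are hypothetical names, not functions from `data_generator`.

```python
# Index order copied from the printout above: digits, '#', ' ', then letters
VOCAB = "0123456789# " + "egfihonsrutwvxz"
CHAR_TO_IDX = {ch: i for i, ch in enumerate(VOCAB)}
IDX_TO_CHAR = {i: ch for ch, i in CHAR_TO_IDX.items()}


def encode(text):
    """Map a string to the list of integer indices fed to the encoder."""
    return [CHAR_TO_IDX[ch] for ch in text]


def decode(indices):
    """Inverse of encode."""
    return "".join(IDX_TO_CHAR[i] for i in indices)


print(len(VOCAB))       # 27
print(encode("nine"))   # [18, 15, 18, 12]
```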

In [293]:
from data_generator import generate
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device in use:", device)

NUM_INPUTS = 27  # No. of possible input characters
NUM_OUTPUTS = 11  # (0-9 + '#')

### Hyperparameters and general configs
MAX_SEQ_LEN = 8
MIN_SEQ_LEN = 5
BATCH_SIZE = 80
TRAINING_SIZE = 8000
LEARNING_RATE = 0.003

# The hidden sizes of encoder and decoder must be equal if the encoder's last hidden state becomes the decoder's initial hidden state.
# Otherwise we would need e.g. a linear layer to map it to a space with the correct dimension.
NUM_UNITS_ENC = NUM_UNITS_DEC = 96
TEST_SIZE = 200
EPOCHS = 100
TEACHER_FORCING = True  # feed the ground-truth previous symbols to the decoder

assert TRAINING_SIZE % BATCH_SIZE == 0

Device in use: cuda
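
The comment in the configuration cell notes that an extra linear layer would be needed if the encoder and decoder had different hidden sizes. A minimal sketch of such a bridge, assuming LSTM states `(h, c)` as used in this notebook; `HiddenBridge` is a hypothetical module, and applying `tanh` to `h` only is one reasonable design choice, not the only one.

```python
import torch
import torch.nn as nn


class HiddenBridge(nn.Module):
    """Hypothetical helper: maps an encoder LSTM state (h, c) to the decoder's hidden size."""

    def __init__(self, enc_hidden, dec_hidden):
        super().__init__()
        self.h_proj = nn.Linear(enc_hidden, dec_hidden)
        self.c_proj = nn.Linear(enc_hidden, dec_hidden)

    def forward(self, state):
        h, c = state  # each of shape [num_layers, batch, enc_hidden]
        # nn.Linear acts on the last dimension, so [num_layers, batch] is preserved.
        # tanh keeps h in (-1, 1), the range of an LSTM hidden state; c is left unbounded.
        return torch.tanh(self.h_proj(h)), self.c_proj(c)


# Usage: an encoder with 128 units feeding a decoder with 96 units
bridge = HiddenBridge(enc_hidden=128, dec_hidden=96)
enc_state = (torch.zeros(1, 80, 128), torch.zeros(1, 80, 128))
dec_h, dec_c = bridge(enc_state)
print(dec_h.shape, dec_c.shape)  # torch.Size([1, 80, 96]) torch.Size([1, 80, 96])
```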


For this exercise we won't worry about data generation, but use a provided function for this purpose. The function generates random data restricted to the 27 characters described above.

The encoder takes as input the integer-encoded text strings produced by the *generate* function, using the mapping given above; e.g. 'nine' becomes [18 15 18 12].
Sequences are generated at random, with the number of digits between a minimum and a maximum length (`MIN_SEQ_LEN` and `MAX_SEQ_LEN` above).
We can visualise a subset of the generated data by running the command below:

In [294]:
!python ./data_generator.py

Generated batch length 3 from 3 iterations
input types: int32 int32 int32 int32 int32
Number of valid characters: 27
'0'=0,	'1'=1,	'2'=2,	'3'=3,	'4'=4,	'5'=5,	'6'=6,	'7'=7,	'8'=8,	'9'=9,	'#'=10,	' '=11,	'e'=12,	'g'=13,	'f'=14,	'i'=15,	'h'=16,	'o'=17,	'n'=18,	's'=19,	'r'=20,	'u'=21,	't'=22,	'w'=23,	'v'=24,	'x'=25,	'z'=26,	
Stop/start character = #

SAMPLE 0
TEXT INPUTS:			 three four eight two
ENCODED INPUTS:			 [22 16 20 12 12 11 14 17 21 20 11 12 15 13 16 22 11 22 23 17]
INPUTS SEQUENCE LENGTH:	 20
TEXT TARGETS INPUT:		 #3482
TEXT TARGETS OUTPUT:	 3482#
ENCODED TARGETS INPUT:	 [10  3  4  8  2]
ENCODED TARGETS OUTPUT:	 [ 3  4  8  2 10]
TARGETS SEQUENCE LENGTH: 5
TARGETS MASK:			 [ 1.  1.  1.  1.  1.]

SAMPLE 1
TEXT INPUTS:			 four eight
ENCODED INPUTS:			 [14 17 21 20 11 12 15 13 16 22  0  0  0  0  0  0  0  0  0  0]
INPUTS SEQUENCE LENGTH:	 10
TEXT TARGETS INPUT:		 #48
TEXT TARGETS OUTPUT:	 48#
ENCODED TARGETS INPUT:	 [10  4  8  0  0]
ENCODED TARGETS OUTPUT:	 [ 4  8 10  0  0]
TARGETS S
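
The decoder input and output in the printout are the same digits shifted by one position, with '#' used both as start and end symbol. A sketch that rebuilds them, assuming pad index 0 and '#' index 10 as printed above; `make_decoder_targets` is a hypothetical helper, and since sample 1's mask is cut off in the printout, zeros on padded positions are an assumption.

```python
PAD, EOS = 0, 10  # padding index and index of '#', as in the printout above


def make_decoder_targets(digits, max_len):
    """Build (targets_in, targets_out, mask) for a digit string such as "3482"."""
    seq = [int(d) for d in digits]
    targets_in = [EOS] + seq   # '#' first: it acts as the decoder's start symbol
    targets_out = seq + [EOS]  # '#' last: it marks the end of the number
    n_pad = max_len - len(targets_out)
    mask = [1.0] * len(targets_out) + [0.0] * n_pad  # 1 on real symbols, 0 on padding
    return targets_in + [PAD] * n_pad, targets_out + [PAD] * n_pad, mask


print(make_decoder_targets("3482", 5))
# ([10, 3, 4, 8, 2], [3, 4, 8, 2, 10], [1.0, 1.0, 1.0, 1.0, 1.0])
print(make_decoder_targets("48", 5))
# ([10, 4, 8, 0, 0], [4, 8, 10, 0, 0], [1.0, 1.0, 1.0, 0.0, 0.0])
```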

## Let's define the two RNNs



In [295]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, self.hidden_size)
        rnn = nn.LSTM
        self.rnn = rnn(self.hidden_size, self.hidden_size, batch_first=True)

    def forward(self, inputs, hidden):
        # Input shape [batch, seq_in_len]
        inputs = inputs.long()

        # Embedded shape [batch, seq_in_len, embed]
        embedded = self.embedding(inputs)
        
        # Output shape [batch, seq_in_len, hidden]
        # Hidden is a tuple (h_n, c_n), each of shape [1, batch, hidden]: the last hidden
        # and cell states of the LSTM. We will feed this last state into the decoder.
        output, hidden = self.rnn(embedded, hidden)
        return output, hidden

    def init_hidden(self, batch_size):
        init = torch.zeros(1, batch_size, self.hidden_size, device=device)
        return (init, init)


In [296]:
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)
        rnn = nn.LSTM
        self.rnn = rnn(self.hidden_size, self.hidden_size, batch_first=True)

    def forward(self, inputs, hidden, output_len, teacher_forcing=False):
        # Input shape: [batch, output_len]
        # Hidden: tuple (h, c), each of shape [num_layers=1, batch, hidden_dim] (the last state of the encoder)
    
        if teacher_forcing:
            dec_input = inputs
            embed = self.embedding(dec_input)   # shape [batch, output_len, hidden_dim]
            out, hidden = self.rnn(embed, hidden)
            out = self.out(out)  # linear layer, out has now shape [batch, output_len, output_size]
            output = F.log_softmax(out, -1)
        else:
            # Take only the start symbol '#' for the whole batch, unsqueezed to shape [batch, 1].
            # This is the first input; afterwards we feed in the decoder's own prediction from the previous step.
            
            dec_input = inputs[:, 0].unsqueeze(1)

            output = []
            for i in range(output_len):
                out, hidden = self.rnn(self.embedding(dec_input), hidden)
                out = self.out(out)  # linear layer, out has now shape [batch, 1, output_size]
                out = F.log_softmax(out, -1)
                output.append(out.squeeze(1))
                out_symbol = torch.argmax(out, dim=2)   # shape [batch, 1]
                dec_input = out_symbol   # feed the decoded symbol back into the recurrent unit at next step

            output = torch.stack(output).permute(1, 0, 2)  # [batch_size x seq_len x output_size]

        return output

The learned representation from the *Encoder* is propagated to the *Decoder*: the final hidden state of the *Encoder* is used as the initial hidden state of the *Decoder*.
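
To make the hand-off concrete, here is a minimal sketch with two bare `nn.LSTM` layers, assuming random tensors in place of the embedded inputs and the notebook's hidden size of 96. It shows that for an LSTM the state passed on is the tuple `(h_n, c_n)`, each of shape `[1, batch, hidden]` even with `batch_first=True`, and that for a single-layer, unidirectional LSTM `h_n` equals the last time step of the encoder output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, src_len, tgt_len, hidden = 4, 20, 9, 96

encoder_rnn = nn.LSTM(hidden, hidden, batch_first=True)
decoder_rnn = nn.LSTM(hidden, hidden, batch_first=True)

src = torch.randn(batch, src_len, hidden)  # stands in for the embedded encoder input
tgt = torch.randn(batch, tgt_len, hidden)  # stands in for the embedded decoder input

# No initial state passed, so the encoder starts from zeros
enc_out, (h_n, c_n) = encoder_rnn(src)

# The decoder starts from the encoder's final (h, c) instead of zeros
dec_out, _ = decoder_rnn(tgt, (h_n, c_n))

print(enc_out.shape, h_n.shape, dec_out.shape)
# torch.Size([4, 20, 96]) torch.Size([1, 4, 96]) torch.Size([4, 9, 96])
print(torch.allclose(enc_out[:, -1], h_n[0]))  # True
```

Note that the notebook's encoder runs over the whole padded input, so for shorter numbers its final state has also processed the trailing padding symbols.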

In [297]:
def forward_pass(encoder, decoder, x, t, t_in, criterion, max_t_len, teacher_forcing):
    """
    Executes a forward pass through the whole model.

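    :param encoder: EncoderRNN instance
    :param decoder: DecoderRNN instance
    :param x: input to the encoder, shape [batch, seq_in_len]
    :param t: target output predictions for decoder, shape [batch, seq_t_len]
    :param t_in: decoder input ('#' followed by the target digits), shape [batch, seq_t_len]
    :param criterion: loss function
    :param max_t_len: maximum target length
    :param teacher_forcing: if True, feed the ground-truth previous symbols to the decoder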

    :return: output (after log-softmax), loss, accuracy (per-symbol)
    """
    # Run encoder and get last hidden state (and output)
    batch_size = x.size(0)

    enc_h = encoder.init_hidden(batch_size)
    enc_out, enc_h = encoder(x, enc_h)

    dec_h = enc_h  # Init hidden state of decoder as hidden state of encoder
    dec_input = t_in
    
    out = decoder(dec_input, dec_h, max_t_len, teacher_forcing)
    out = out.permute(0, 2, 1)
    # Shape: [batch_size x num_classes x out_sequence_len], with second dim containing log probabilities

    loss = criterion(out, t)
    pred = get_pred(log_probs=out)

    accuracy = (pred == t).type(torch.FloatTensor).mean()

    return out, loss, accuracy
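
`forward_pass` averages the per-symbol accuracy over every target position, including the padding after '#', which the network learns to predict easily. A sketch of a masked variant, assuming per-sample target lengths such as the `targets_seqlen` returned by `generate` (the printout suggests these count the digits plus '#'); `masked_accuracy` is a hypothetical helper, not used by the notebook.

```python
import torch


def masked_accuracy(pred, target, lengths):
    """Per-symbol accuracy that ignores padded positions.

    pred, target: integer tensors of shape [batch, seq_len]
    lengths: tensor of shape [batch] with the number of valid target symbols (digits plus '#')
    """
    positions = torch.arange(target.size(1), device=target.device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)  # [batch, seq_len], True on real symbols
    correct = (pred == target) & mask
    return correct.sum().float() / mask.sum().float()


target = torch.tensor([[4, 8, 10, 0, 0]])  # "48#" padded to length 5
pred = torch.tensor([[4, 3, 10, 0, 0]])    # one wrong digit, padding predicted correctly
lengths = torch.tensor([3])

print((pred == target).float().mean().item())         # ~0.8: padding counts as correct
print(masked_accuracy(pred, target, lengths).item())  # ~0.667: only the 3 real symbols count
```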


In [298]:
def train(encoder, decoder, inputs, targets, targets_in, criterion, enc_optimizer, dec_optimizer, epoch, max_t_len):
    encoder.train()
    decoder.train()
    for batch_idx, (x, t, t_in) in enumerate(zip(inputs, targets, targets_in)):
        
        
        # getting the data on the device
        x = x.to(device)
        t = t.long().to(device)
        t_in = t_in.long().to(device)
        
        # compute the network's outputs with a forward pass
        # the forward pass calls the encoder and then the decoder, and returns the decoder's output
        out, loss, accuracy = forward_pass(encoder, decoder, x, t, t_in, criterion, max_t_len, teacher_forcing=TEACHER_FORCING)
        
        # zero the gradient of optimizers
        enc_optimizer.zero_grad()
        dec_optimizer.zero_grad()
        
        # compute new gradients
        loss.backward()
        
        # finally update network
        enc_optimizer.step()
        dec_optimizer.step()
        
        
        if batch_idx % 20 == 0:
            print('Epoch {} [{}/{} ({:.0f}%)]\tTraining loss: {:.4f} \tTraining accuracy: {:.1f}%'.format(
                epoch, batch_idx * len(x), TRAINING_SIZE,
                100. * batch_idx * len(x) / TRAINING_SIZE, loss.item(),
                100. * accuracy.item()))


In [299]:
def test(encoder, decoder, inputs, targets, targets_in, criterion, max_t_len):
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        inputs = inputs.to(device)
        targets = targets.long().to(device)
        targets_in = targets_in.long().to(device)
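        # Note: when TEACHER_FORCING is True, the decoder also receives the ground-truth previous
        # symbols here, so this accuracy is optimistic compared with free-running decoding.
        out, loss, accuracy = forward_pass(encoder, decoder, inputs, targets, targets_in, criterion, max_t_len,
                                           teacher_forcing=TEACHER_FORCING)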
    return out, loss, accuracy

In [300]:
def numbers_to_text(seq):
    return "".join([str(to_np(i)) if to_np(i) != 10 else '#' for i in seq])

def to_np(x):
    return x.cpu().numpy()

def get_pred(log_probs):
    """
    Get class prediction (digit prediction) from the net's output (the log_probs)
    :param log_probs: Tensor of shape [batch_size x n_classes x sequence_len]
    :return: Tensor of shape [batch_size x sequence_len] with the predicted class indices
    """
    return torch.argmax(log_probs, dim=1)

In [301]:
encoder = EncoderRNN(NUM_INPUTS, NUM_UNITS_ENC).to(device)
decoder = DecoderRNN(NUM_UNITS_DEC, NUM_OUTPUTS).to(device)
enc_optimizer = optim.RMSprop(encoder.parameters(), lr=LEARNING_RATE)
dec_optimizer = optim.RMSprop(decoder.parameters(), lr=LEARNING_RATE)
criterion = nn.NLLLoss()

# Get training set
inputs, _, targets_in, targets, targets_seqlen, _, _, _, text_targ = generate(TRAINING_SIZE, min_len=MIN_SEQ_LEN, max_len=MAX_SEQ_LEN)
max_target_len = max(targets_seqlen)
inputs = torch.tensor(inputs)
targets = torch.tensor(targets)
targets_in = torch.tensor(targets_in)
unique_text_targets = set(text_targ)

# Get validation set
val_inputs, _, val_targets_in, val_targets, val_targets_seqlen, _, val_text_in, _, val_text_targ = \
    generate(TEST_SIZE, min_len=MIN_SEQ_LEN, max_len=MAX_SEQ_LEN, invalid_set=unique_text_targets)
val_inputs = torch.tensor(val_inputs)
val_targets = torch.tensor(val_targets)
val_targets_in = torch.tensor(val_targets_in)
max_val_target_len = max(val_targets_seqlen)




# Split training set in batches

inputs = [inputs[i * BATCH_SIZE: (i + 1) * BATCH_SIZE] for i in range(TRAINING_SIZE // BATCH_SIZE)]
targets = [targets[i * BATCH_SIZE: (i + 1) * BATCH_SIZE] for i in range(TRAINING_SIZE // BATCH_SIZE)]
targets_in = [targets_in[i * BATCH_SIZE: (i + 1) * BATCH_SIZE] for i in range(TRAINING_SIZE // BATCH_SIZE)]



# Quick and dirty - just loop over training set without reshuffling
for epoch in range(1, EPOCHS + 1):
    train(encoder, decoder, inputs, targets, targets_in, criterion, enc_optimizer, dec_optimizer, epoch, max_target_len)
    _, loss, accuracy = test(encoder, decoder, val_inputs, val_targets, val_targets_in, criterion, max_val_target_len)
    print('\nTest set: Average loss: {:.4f} \tAccuracy: {:.3f}%\n'.format(loss, accuracy.item()*100.))

    # Show examples
    print("Examples: prediction | input")
    out, _, _ = test(encoder, decoder, val_inputs[:10], val_targets[:10], val_targets_in[:10], criterion, max_val_target_len)
    pred = get_pred(out)
    pred_text = [numbers_to_text(sample) for sample in pred]
    for i in range(10):
        print(pred_text[i], "\t", val_text_in[i])
    print()

Generated batch length 8000 from 8000 iterations
Generated batch length 200 from 200 iterations

Test set: Average loss: 1.7846 	Accuracy: 35.833%

Examples: prediction | input
0878377#0 	 four seven four zero six eight nine
330706#00 	 six three one four nine eight
22222#000 	 five nine two seven six
313046000 	 zero nine one zero nine zero
313011#00 	 zero two three zero six three
22266#000 	 six one three one nine
666620000 	 four three six eight zero
222202#00 	 two nine two five five four
22222##00 	 five five eight one two two
222222##0 	 six one nine five two six six


Test set: Average loss: 1.6691 	Accuracy: 39.000%

Examples: prediction | input
0069171#0 	 four seven four zero six eight nine
337337#00 	 six three one four nine eight
62290#000 	 five nine two seven six
333337#00 	 zero nine one zero nine zero
333333#00 	 zero two three zero six three
33779#000 	 six one three one nine
33337#000 	 four three six eight zero
222000#00 	 two nine two five five four
222222#00 	 fiv


Test set: Average loss: 0.6109 	Accuracy: 79.944%

Examples: prediction | input
4440689#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
091090#00 	 zero nine one zero nine zero
023663#00 	 zero two three zero six three
61719#000 	 six one three one nine
43680#000 	 four three six eight zero
255554#00 	 two nine two five five four
558122#00 	 five five eight one two two
6996666#0 	 six one nine five two six six


Test set: Average loss: 0.5075 	Accuracy: 83.722%

Examples: prediction | input
4440689#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
090090#00 	 zero nine one zero nine zero
023063#00 	 zero two three zero six three
61199#000 	 six one three one nine
43680#000 	 four three six eight zero
255554#00 	 two nine two five five four
558122#00 	 five five eight one two two
6995666#0 	 six one nine five two six six


Test set: Average loss: 0.4


Test set: Average loss: 0.3183 	Accuracy: 91.333%

Examples: prediction | input
4440689#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
091090#00 	 zero nine one zero nine zero
023063#00 	 zero two three zero six three
61319#000 	 six one three one nine
43680#000 	 four three six eight zero
292554#00 	 two nine two five five four
558222400 	 five five eight one two two
6195266#0 	 six one nine five two six six


Test set: Average loss: 0.3400 	Accuracy: 89.444%

Examples: prediction | input
4440889#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
011190#00 	 zero nine one zero nine zero
023063#00 	 zero two three zero six three
61319#000 	 six one three one nine
43680#000 	 four three six eight zero
292554#00 	 two nine two five five four
558122#00 	 five five eight one two two
6195262#0 	 six one nine five two six six


Test set: Average loss: 0.2


Test set: Average loss: 0.2124 	Accuracy: 93.778%

Examples: prediction | input
4440689#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
091090#00 	 zero nine one zero nine zero
023063#00 	 zero two three zero six three
61319#000 	 six one three one nine
43680#000 	 four three six eight zero
292554#00 	 two nine two five five four
558122600 	 five five eight one two two
6195266#0 	 six one nine five two six six


Test set: Average loss: 0.1877 	Accuracy: 94.611%

Examples: prediction | input
4740689#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
091090#00 	 zero nine one zero nine zero
023063#00 	 zero two three zero six three
61319#000 	 six one three one nine
43680#000 	 four three six eight zero
292554#00 	 two nine two five five four
558122#00 	 five five eight one two two
6195266#0 	 six one nine five two six six


Test set: Average loss: 0.1


Test set: Average loss: 0.1817 	Accuracy: 95.500%

Examples: prediction | input
4740689#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
091099#00 	 zero nine one zero nine zero
023063#00 	 zero two three zero six three
61319#000 	 six one three one nine
43680#000 	 four three six eight zero
292554#00 	 two nine two five five four
558122#00 	 five five eight one two two
6195266#0 	 six one nine five two six six


Test set: Average loss: 0.1846 	Accuracy: 95.389%

Examples: prediction | input
4740689#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
091090#00 	 zero nine one zero nine zero
023063#00 	 zero two three zero six three
61319#000 	 six one three one nine
43680#000 	 four three six eight zero
292554#00 	 two nine two five five four
558122#00 	 five five eight one two two
6195266#0 	 six one nine five two six six


Test set: Average loss: 0.1


Test set: Average loss: 0.1718 	Accuracy: 95.056%

Examples: prediction | input
4740689#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
091099#00 	 zero nine one zero nine zero
023063#00 	 zero two three zero six three
61319#000 	 six one three one nine
43680#000 	 four three six eight zero
292554#00 	 two nine two five five four
558122#00 	 five five eight one two two
6195266#0 	 six one nine five two six six


Test set: Average loss: 0.1692 	Accuracy: 95.278%

Examples: prediction | input
4740689#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
091099#00 	 zero nine one zero nine zero
023063#00 	 zero two three zero six three
61319#000 	 six one three one nine
43680#000 	 four three six eight zero
292554#00 	 two nine two five five four
558122#00 	 five five eight one two two
6195266#0 	 six one nine five two six six


Test set: Average loss: 0.1


Test set: Average loss: 0.1760 	Accuracy: 95.333%

Examples: prediction | input
4440689#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
091090#00 	 zero nine one zero nine zero
023063#00 	 zero two three zero six three
61319#000 	 six one three one nine
43680#000 	 four three six eight zero
292554#00 	 two nine two five five four
558122#00 	 five five eight one two two
6195266#0 	 six one nine five two six six


Test set: Average loss: 0.1659 	Accuracy: 95.611%

Examples: prediction | input
4740689#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
091090#00 	 zero nine one zero nine zero
023063#00 	 zero two three zero six three
61319#000 	 six one three one nine
43680#000 	 four three six eight zero
292554#00 	 two nine two five five four
558122#00 	 five five eight one two two
6195266#0 	 six one nine five two six six


Test set: Average loss: 0.1


Test set: Average loss: 0.2397 	Accuracy: 93.944%

Examples: prediction | input
4740689#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
091090#00 	 zero nine one zero nine zero
023063#00 	 zero two three zero six three
61319#000 	 six one three one nine
43680#000 	 four three six eight zero
292554#00 	 two nine two five five four
558122200 	 five five eight one two two
6195266#0 	 six one nine five two six six


Test set: Average loss: 0.1664 	Accuracy: 95.556%

Examples: prediction | input
4740689#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
091090#00 	 zero nine one zero nine zero
023063#00 	 zero two three zero six three
61319#000 	 six one three one nine
43680#000 	 four three six eight zero
292554#00 	 two nine two five five four
558122#00 	 five five eight one two two
6195266#0 	 six one nine five two six six


Test set: Average loss: 0.1


Test set: Average loss: 0.1681 	Accuracy: 95.833%

Examples: prediction | input
4740689#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
091090#00 	 zero nine one zero nine zero
023068#00 	 zero two three zero six three
61319#000 	 six one three one nine
43680#000 	 four three six eight zero
292554#00 	 two nine two five five four
558122#00 	 five five eight one two two
6195266#0 	 six one nine five two six six


Test set: Average loss: 0.1631 	Accuracy: 95.722%

Examples: prediction | input
4740689#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
091090#00 	 zero nine one zero nine zero
023063#00 	 zero two three zero six three
61319#000 	 six one three one nine
43680#000 	 four three six eight zero
292554#00 	 two nine two five five four
558122#00 	 five five eight one two two
6195266#0 	 six one nine five two six six


Test set: Average loss: 0.1


Test set: Average loss: 0.1491 	Accuracy: 95.944%

Examples: prediction | input
4740689#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
091090#00 	 zero nine one zero nine zero
023061#00 	 zero two three zero six three
61319#000 	 six one three one nine
43680#000 	 four three six eight zero
292554#00 	 two nine two five five four
558122#00 	 five five eight one two two
6195262#0 	 six one nine five two six six


Test set: Average loss: 0.1638 	Accuracy: 95.333%

Examples: prediction | input
4740689#0 	 four seven four zero six eight nine
631498#00 	 six three one four nine eight
59276#000 	 five nine two seven six
091090#00 	 zero nine one zero nine zero
023061#00 	 zero two three zero six three
61319#000 	 six one three one nine
43680#000 	 four three six eight zero
292554#00 	 two nine two five five four
558122#00 	 five five eight one two two
6195266#0 	 six one nine five two six six


Test set: Average loss: 0.1

# Exercise:

1. Implement the missing code for the network in the *train* function. Your validation accuracy is expected to be <20% at this point.
2. These networks use GRU gates. Implement an alternative that relies on a memory cell (hint: LSTM). What do you observe?
3. Some parameters of the model may be optimized further; which ones? Achieve >90% validation accuracy.

### Assignment answers

With the "vanilla" network I reached around $55\%$ accuracy. Then I replaced the GRU gates with LSTM ones and got over $75\%$ accuracy. After that, some parameters could still be optimized further, such as the batch size, the number of hidden units in the encoder and decoder, the learning rate, the number of epochs and teacher forcing.

First, I switched `TEACHER_FORCING` to **True**, which was a big improvement: I reached an accuracy of $83\%$. I had to add a few epochs to get a full picture of the performance at this point.
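
The decoder cell implements both modes; the difference is easiest to see with a toy loop. In this sketch a hypothetical lookup table stands in for one decoder step, and it knows the sequence `3482#` except for one wrong transition (4 to 9). With teacher forcing the mistake costs a single symbol, because the next input is still the ground truth; without it, the wrong '9' is fed back and the rest of the sequence derails. For the same reason, evaluating with teacher forcing tends to overstate accuracy.

```python
def run_decoder(predict_next, targets_in, teacher_forcing):
    """Toy decoding loop over symbols; returns (symbols fed in, symbols predicted).

    predict_next is a stand-in for one decoder step: previous symbol -> next symbol.
    """
    fed, predicted = [], []
    prev = targets_in[0]  # always start from the '#' start symbol
    for step in range(len(targets_in)):
        if teacher_forcing:
            prev = targets_in[step]  # ground-truth previous symbol
        fed.append(prev)
        pred = predict_next(prev)
        predicted.append(pred)
        prev = pred  # only used at the next step when teacher forcing is off
    return "".join(fed), "".join(predicted)


# Hypothetical "model": correct transitions for 3482# except 4 -> 9
transitions = {"#": "3", "3": "4", "4": "9", "8": "2", "2": "#"}


def step_fn(prev):
    return transitions.get(prev, "#")


print(run_decoder(step_fn, "#3482", teacher_forcing=True))   # ('#3482', '3492#')
print(run_decoder(step_fn, "#3482", teacher_forcing=False))  # ('#349#', '349#3')
```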

Finally, I tweaked some parameters a bit further, such as the batch size and the number of hidden units, and managed to reach an accuracy of $94.6\%$ after 100 epochs, which seems quite good. I should have stopped training earlier, because the network was not really improving after epoch 41, although the accuracy peak was reached at epoch 89 ($96.3\%$ accuracy).

Final values of the parameters (and answer to question 3):

* `BATCH_SIZE = 80`
* `NUM_UNITS_ENC = NUM_UNITS_DEC = 96`
* `TEACHER_FORCING = True`
* `EPOCHS = 100`
* `LEARNING_RATE = 0.003`



### Exercise from Michael Nielsen's book

We have to prove that the standard deviation of $z=\sum_j w_jx_j + b$ is $\sqrt{3/2}$ if the weights are randomly initialized from a distribution that has a standard deviation of $\frac{1}{\sqrt{n_{in}}}$ where $n_{in}$ is the number of input weights. Bias should still be initialized from a distribution with a standard deviation of $1$ though. 

First we compute the variance of $z$.

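$$ V(z) = V\Big(\sum_j w_jx_j + b\Big) = \sum_j x_j^2\, V(w_j) + V(b) $$

The last equality holds because the $w_j$ and the bias are independent random variables, with the inputs $x_j$ treated as fixed constants: the variance of a sum of independent variables is the sum of their variances, and $V(aX) = a^2 V(X)$ for a constant $a$.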

The standard deviation of $z$ is to be computed for $n_{in}=1000$ inputs, half of which are $0$ while the other half are $1$ (so $x_j^2 = x_j$), and with $V(w_j) = \frac{1}{n_{in}} = \frac{1}{1000}$.

This gives:

$$V(z) = \sum_{j=1}^{500}V(w_j) + V(b) = 500 \cdot \frac{1}{1000} + 1 = \frac{3}{2}$$

Finally, $$\sigma(z) = \sqrt{V(z)} = \sqrt{3/2}$$.
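
A quick numerical sanity check of this result (not a proof), assuming Gaussian weights and bias, as in Nielsen's book; the derivation itself only uses their variances. Only the 500 active inputs are sampled, since inputs equal to $0$ contribute nothing to $z$.

```python
import math
import random


def sample_z(rng, n_in=1000, n_active=500):
    """Draw one z = sum_j w_j x_j + b, with n_active inputs equal to 1 and the rest 0."""
    w_std = 1.0 / math.sqrt(n_in)
    # Inputs with x_j = 0 contribute nothing, so only the active weights need sampling
    weighted_sum = sum(rng.gauss(0.0, w_std) for _ in range(n_active))
    return weighted_sum + rng.gauss(0.0, 1.0)  # bias with standard deviation 1


rng = random.Random(0)
samples = [sample_z(rng) for _ in range(4000)]
mean = sum(samples) / len(samples)
std = math.sqrt(sum((z - mean) ** 2 for z in samples) / (len(samples) - 1))
print(f"empirical std = {std:.3f}, predicted = {math.sqrt(1.5):.3f}")
```

The empirical value should land close to $\sqrt{3/2} \approx 1.22$, with the small gap shrinking as more samples are drawn.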