# Code Generation as a Dual Task of Code Summarization

```
@article{wei2019code,
  title={Code Generation as a Dual Task of Code Summarization},
  author={Wei, Bolin and Li, Ge and Xia, Xin and Fu, Zhiyi and Jin, Zhi},
  journal={arXiv preprint arXiv:1910.05923},
  year={2019}
}
```

<img src='https://i.imgur.com/RqN1agC.png' width='600' align='left'>

## References
- https://www.tensorflow.org/tutorials/text/nmt_with_attention
- https://blog.floydhub.com/attention-mechanism/

## Definitions

$x \; \text{: code snippets}, \; y \; \text{: comments}$

$P(x,y) = \color{#00a010}{P(x) \cdot P(y|x)} = \color{#a010a0}{P(y) \cdot P(x|y)}$

### Loss terms

$l_{xy} = -\frac{1}{m} \sum_{t=1}^{m} \log P(y_t \vert y_{\lt t}, x; \theta_{xy})$

$l_{yx} = -\frac{1}{n} \sum_{t=1}^{n} \log P(x_t \vert x_{\lt t}, y; \theta_{yx})$

$l_{dual}=\left[ \left(\color{#00a010}{\log\hat{P}(x) + \log P(y \vert x; \theta_{xy})} \right) - \left(\color{#a010a0}{\log\hat{P}(y) + \log P(x \vert y; \theta_{yx})} \right) \right]^{2} \text{ : regularization term}$

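$l_{att} = l_1 + l_2, \text{ where } l_1 = \frac{1}{n}\sum_{i=1}^{n} \mathcal{D}_{JS} \left( b_i, b_i' \right) = \frac{1}{2n}\sum_{i=1}^{n} \left[ \mathcal{D}_{KL} \left(b_i \, \| \, \frac{b_i + b_i'}{2} \right) + \mathcal{D}_{KL} \left(b_i' \, \| \, \frac{b_i + b_i'}{2} \right) \right]$

$b_i = softmax \left( A_{xy}[i, :] \right), \; b_i' = softmax \left( A_{yx}[:, i] \right), \; b_i, b_i' \in \mathbb{R}^{m}$; $l_2$ is defined symmetrically over the $m$ rows of $A_{yx}$, comparing $softmax \left( A_{yx}[j, :] \right)$ with $softmax \left( A_{xy}[:, j] \right)$.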

$A_{xy} \in \mathbb{R}^{n \times m}, \; A_{yx} \in \mathbb{R}^{m \times n} \text{ : attention weights}$

### Updates

$\text{Minibatch of } k \text{ pairs: } \langle \left(x_i, y_i\right) \rangle_{i=1}^{k}$

$
\begin{cases}
G_{xy} = \nabla_{\theta_{xy}} \frac{1}{k} \sum_{i=1}^{k} \left( l_{xy} + \lambda_{dual}^{(1)} \cdot l_{dual} + \lambda_{att}^{(1)} \cdot l_{att} \right)\\
G_{yx} = \nabla_{\theta_{yx}} \frac{1}{k} \sum_{i=1}^{k} \left( l_{yx} + \lambda_{dual}^{(2)} \cdot l_{dual} + \lambda_{att}^{(2)} \cdot l_{att} \right)
\end{cases}
$

$\text{Update } \theta_{xy} \text{ and } \theta_{yx} \text{ independently}$

### Notes
- The encoder's final hidden state is used to initialise the decoder's hidden state.

### Imports

In [None]:
import os
from argparse import Namespace

import numpy as np
import pandas as pd

from timeit import default_timer as timer
from tqdm.auto import tqdm
tqdm.pandas()

import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Embedding, Bidirectional, LSTM
from tensorflow.keras.initializers import Constant
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

from datasets import Django

### Globals

In [None]:
# Data locations. Set the EMB_DIR / DJANGO_DIR environment variables to point
# elsewhere instead of editing these machine-specific defaults.
EMB_DIR    = os.environ.get('EMB_DIR', '/home/alex/workspace/msc-research/embeddings')
DJANGO_DIR = os.environ.get('DJANGO_DIR', '/home/alex/workspace/msc-research/raw-datasets/django/')

### Hyperparameters

In [None]:
# Global hyperparameter container; later cells attach further sub-configs to it.
HP = Namespace(batch_size=5, epochs=1)

### Dataset

In [None]:
# Dataset options handed to the Django loader: train split fraction,
# padded annotation / code lengths, and the pre-trained embedding file.
HP.dataset_config = Namespace(
    p_split=0.8,
    anno_seq_maxlen=40,
    code_seq_maxlen=20,
    emb_file=os.path.join(EMB_DIR, 'glove.6B.50d.txt.pickle'),
)

django = Django(root_dir=DJANGO_DIR, config=HP.dataset_config)

### Encoder

In [None]:
class Encoder(tf.keras.Model):
    """Bidirectional LSTM encoder over frozen pre-trained embeddings.

    Args:
        emb_matrix: 2-D array (vocab_size, emb_dim) used as the constant,
            non-trainable initial value of the embedding layer.
        hidden_size: number of units in each LSTM direction.
        input_maxlen: padded input length, forwarded to ``Embedding``.
        batch_size: default batch size used by :meth:`init_hidden`.
    """

    def __init__(self, emb_matrix, hidden_size, input_maxlen, batch_size):
        super(Encoder, self).__init__()

        self.batch_size   = batch_size
        self.hidden_size  = hidden_size
        self.input_maxlen = input_maxlen

        self.vocab_size, self.emb_dim = emb_matrix.shape

        self.embedding = Embedding(input_dim=self.vocab_size,
                                   output_dim=self.emb_dim,
                                   embeddings_initializer=Constant(emb_matrix),
                                   input_length=input_maxlen,
                                   trainable=False)

        # Per-direction LSTM definition; wrapped by Bidirectional below.
        self.lstm = LSTM(self.hidden_size,
                         return_sequences=True,
                         return_state=True,
                         recurrent_initializer='glorot_uniform')

        self.bidir_lstm = Bidirectional(self.lstm, merge_mode='concat')

    def call(self, x, hidden=None):
        """Encode a batch of token-id sequences.

        Args:
            x: integer token ids, presumably (batch, input_maxlen) — TODO confirm.
            hidden: optional initial state ``(h_fw, c_fw, h_bw, c_bw)``, each of
                shape (batch, hidden_size). When omitted, zeros sized to the
                batch dimension of ``x`` are used.

        Returns:
            out: (batch, seq_len, 2 * hidden_size); forward and backward outputs
                concatenated on the last axis (``merge_mode='concat'``).
            (h0, c0): final forward-direction LSTM state.
            (h1, c1): final backward-direction LSTM state.
        """
        if hidden is None:
            # Size the zero state from the actual batch, so a batch smaller than
            # self.batch_size (e.g. the last one of an epoch) does not crash.
            hidden = self.init_hidden(tf.shape(x)[0])

        out, h0, c0, h1, c1 = self.bidir_lstm(self.embedding(x),
                                              initial_state=list(hidden))

        return out, (h0, c0), (h1, c1)

    def init_hidden(self, batch_size=None):
        """Return an all-zero initial state ``(h_fw, c_fw, h_bw, c_bw)``.

        Args:
            batch_size: number of rows in each state tensor; defaults to
                ``self.batch_size``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        z = tf.zeros((batch_size, self.hidden_size))
        return (z, z, z, z)

#### Test

In [None]:
# Smoke test: encode a single annotation and check the output / state shapes.
bs = 1
test_hidden_size = 1024
enc = Encoder(emb_matrix=django.emb_matrix,
              hidden_size=test_hidden_size,
              input_maxlen=HP.dataset_config.anno_seq_maxlen,
              batch_size=bs)

x_batch = np.array([django.x_train[i] for i in range(bs)])

o, (h_fw, c_fw), (h_bw, c_bw) = enc(x_batch)

# merge_mode='concat' stacks the forward and backward outputs on the last axis.
assert o.shape == (bs, x_batch.shape[1], 2 * test_hidden_size)
assert h_fw.shape == c_fw.shape == h_bw.shape == c_bw.shape == (bs, test_hidden_size)

### Luong's Attention

In [None]:
class LuongAttention(tf.keras.Model):
    """Luong "general" attention: score(h_t, h_s) = h_t . (W h_s).

    Args:
        rnn_size: output width of the projection W applied to the encoder
            outputs; must equal the last dimension of ``decoder_output``.
    """

    def __init__(self, rnn_size):
        super().__init__()

        # W has a bias term; h_t . b is constant across source positions,
        # so it cancels inside the softmax and does not affect the alignment.
        self.W = tf.keras.layers.Dense(rnn_size)

    def call(self, decoder_output, encoder_output):
        """Attend from one decoder step over all encoder positions.

        Args:
            decoder_output: h_t, expected shape (batch_size, 1, rnn_size).
            encoder_output: h_s, expected shape (batch_size, max_len, enc_dim).

        Returns:
            context: alignment-weighted sum of the encoder outputs,
                (batch_size, 1, enc_dim).
            alignment: softmax-normalised scores, (batch_size, 1, max_len).
        """
        keys = self.W(encoder_output)                                # (batch, max_len, rnn_size)
        scores = tf.matmul(decoder_output, keys, transpose_b=True)   # (batch, 1, max_len)
        weights = tf.nn.softmax(scores, axis=2)
        return tf.matmul(weights, encoder_output), weights

#### Test

In [None]:
# Smoke test on random inputs: 10 encoder positions, one decoder step, width 20.
test_rng = np.random.default_rng(0)  # seeded so the cell is reproducible
att = LuongAttention(20)
bs = 1
e = test_rng.random((bs, 10, 20), dtype=np.float32)
d = test_rng.random((bs, 1, 20), dtype=np.float32)
c, a = att(d, e)

assert c.shape == (bs, 1, 20)  # one context vector per decoder step
assert a.shape == (bs, 1, 10)  # one weight per encoder position
np.testing.assert_allclose(a.numpy().sum(axis=-1), 1.0, rtol=1e-5)  # softmax rows sum to 1

### Decoder

In [None]:
class Decoder(tf.keras.Model):
    """Single-step LSTM decoder with Luong ("general") attention.

    Each call consumes one input token per batch row and advances the LSTM by
    one step.

    Args:
        vocab_size: target vocabulary size (embedding rows and output logits).
        embedding_dim: width of the learned target-token embeddings.
        hidden_size: LSTM units; also the attention projection width, so the
            decoder state can be scored against the projected encoder outputs.
        batch_size: stored on the instance; not used by ``call``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, batch_size):
        super(Decoder, self).__init__()

        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)

        self.lstm = tf.keras.layers.LSTM(self.hidden_size,
                                         return_sequences=True,
                                         return_state=True,
                                         recurrent_initializer='glorot_uniform')

        self.fc = tf.keras.layers.Dense(vocab_size)

        self.attention = LuongAttention(self.hidden_size)

    def call(self, x, hidden, enc_output):
        """Run one decoding step.

        Args:
            x: previous token ids, shape (batch_size, 1).
            hidden: previous LSTM state ``(h, c)``, each (batch_size, hidden_size).
                A bare ``h`` tensor is also accepted; the cell state then
                starts from zeros.
            enc_output: encoder outputs, (batch_size, max_length, enc_dim).

        Returns:
            logits: (batch_size, vocab_size) unnormalised next-token scores.
            state: new LSTM state ``[h, c]`` to pass to the next step.
            attention_weights: (batch_size, 1, max_length) alignment.
        """
        if isinstance(hidden, (list, tuple)):
            h_prev, c_prev = hidden
        else:
            h_prev, c_prev = hidden, tf.zeros_like(hidden)

        # LuongAttention expects the query with a time axis: (batch_size, 1, hidden_size).
        context_vector, attention_weights = self.attention(tf.expand_dims(h_prev, 1), enc_output)

        # x shape after passing through embedding == (batch_size, 1, embedding_dim)
        x = self.embedding(x)

        # context_vector already has its time axis: (batch_size, 1, enc_dim), so no
        # expand_dims here. After concatenation: (batch_size, 1, embedding_dim + enc_dim)
        x = tf.concat([x, context_vector], axis=-1)

        # Continue from the previous state instead of restarting from zeros each step;
        # with return_state=True the LSTM yields (output, h, c).
        output, h, c = self.lstm(x, initial_state=[h_prev, c_prev])

        # output shape == (batch_size * 1, hidden_size)
        output = tf.reshape(output, (-1, output.shape[2]))

        # logits shape == (batch_size, vocab_size)
        logits = self.fc(output)

        return logits, [h, c], attention_weights

#### Test

**TODO:** no decoder smoke test yet.

## Training

In [None]:
# The decoder is seeded from the encoder's final state in train_step, so both
# LSTMs must share one width; keep it in a single hyperparameter.
HP.hidden_size = 1024
# NOTE(review): placeholder decoder vocabulary / embedding sizes. These presumably
# should come from the code-side vocabulary of `django` — confirm before training.
HP.dec_vocab_size = 100
HP.dec_emb_dim = 30

encoder = Encoder(emb_matrix=django.emb_matrix,
                  hidden_size=HP.hidden_size,
                  input_maxlen=HP.dataset_config.anno_seq_maxlen,
                  batch_size=HP.batch_size)

decoder = Decoder(vocab_size=HP.dec_vocab_size,
                  embedding_dim=HP.dec_emb_dim,
                  hidden_size=HP.hidden_size,
                  batch_size=HP.batch_size)

In [None]:
# Created once at cell level (not inside the tf.function) and reused by every step.
optimizer = tf.keras.optimizers.Adam()

# The decoder's output Dense layer has no activation, so it emits logits.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')


def loss_function(real, pred):
    """Mean token cross-entropy for one decoding step, ignoring padded positions.

    Args:
        real: gold token ids for this step, (batch,).
        pred: decoder logits for this step, (batch, vocab_size).
    """
    # NOTE(review): assumes token id 0 is padding (the pad_sequences / Tokenizer
    # default) — confirm against the Django dataset.
    mask = tf.cast(tf.math.not_equal(real, 0), dtype=pred.dtype)
    return tf.reduce_mean(loss_object(real, pred) * mask)


@tf.function
def train_step(inp, targ, enc_hidden):
    """One teacher-forced optimisation step over a batch.

    Args:
        inp: encoder input token ids.
        targ: target token ids, (batch, T); position 0 is used only as the
            first decoder input and is never predicted.
        enc_hidden: initial encoder state, as returned by ``encoder.init_hidden()``.

    Returns:
        The summed step losses divided by the number of decoding steps (T - 1).
    """
    loss = 0.0
    n_steps = targ.shape[1] - 1

    with tf.GradientTape() as tape:
        # Encoder returns (outputs, forward (h, c), backward (h, c)).
        enc_output, enc_fw_state, _ = encoder(inp, enc_hidden)
        # Seed the decoder with the encoder's final forward-direction (h, c).
        dec_hidden = list(enc_fw_state)

        # NOTE(review): assumes every target sequence begins with the start token.
        dec_input = tf.expand_dims(targ[:, 0], 1)

        # Teacher forcing - feeding the target as the next input
        for t in range(1, targ.shape[1]):
            predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)

            loss += loss_function(targ[:, t], predictions)

            dec_input = tf.expand_dims(targ[:, t], 1)

    # Average over the decoding steps actually taken (positions 1..T-1).
    batch_loss = loss / max(n_steps, 1)
    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return batch_loss

In [None]:
# NOTE(review): `dataset` (presumably a batched tf.data.Dataset of (annotation, code)
# id tensors) and `steps_per_epoch` are not defined in this notebook; build them
# before running this cell.
for epoch in range(HP.epochs):
    t_start = timer()

    # Zero initial encoder state, reused for every batch of the epoch.
    enc_hidden = encoder.init_hidden()
    total_loss = 0.0
    n_batches = 0

    for batch, (inp, targ) in enumerate(dataset.take(steps_per_epoch)):
        batch_loss = train_step(inp, targ, enc_hidden)
        total_loss += float(batch_loss)
        n_batches += 1

        # Progress report every 100 batches; must sit inside the batch loop.
        if batch % 100 == 0:
            print(f'Epoch {epoch+1} Batch {batch} Loss {batch_loss.numpy() :.5f}')

    # TODO: save a checkpoint every 2 epochs once `checkpoint` / `checkpoint_prefix` exist:
    #     if (epoch + 1) % 2 == 0:
    #         checkpoint.save(file_prefix=checkpoint_prefix)

    # Average over the batches actually seen (the dataset may yield fewer than steps_per_epoch).
    print(f'Epoch {epoch+1} Loss {total_loss / max(n_batches, 1) :.5f}')
    print(f'Time taken for 1 epoch {timer() - t_start :.4f} sec\n')