<a href="https://colab.research.google.com/github/m3yrin/code2seq/blob/master/code2seq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A PyTorch re-implementation of "code2seq: Generating Sequences from Structured Representations of Code"

* Paper(Arxiv) : https://arxiv.org/abs/1808.01400  
* Official Github : https://github.com/tech-srl/code2seq

Apr 30, 2019 : v3

In [0]:
# Load dataset before run.
# Downloads code2seq's preprocessed java-small archive and extracts it;
# later cells read the dictionary and the train/test/val .c2s files from ./java-small.
!wget https://s3.amazonaws.com/code2seq/datasets/java-small-preprocessed.tar.gz
!tar -xvzf java-small-preprocessed.tar.gz
# list the extracted files as a sanity check
!ls java-small

In [0]:
import random
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

import pickle
from tqdm import tqdm_notebook as tqdm

from nltk import bleu_score


import torch
from torch import einsum
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

# Seed every RNG the notebook uses: torch (weight init, dropout), Python's
# `random` (AST-path sampling in DataLoader, teacher-forcing decisions in
# compute_loss) and numpy, so that Restart & Run All is reproducible.
torch.manual_seed(1)
random_state = 42
random.seed(random_state)
np.random.seed(random_state)

import warnings
# NOTE: this silences *all* warnings, including PyTorch deprecation warnings
# and the notice that RNN dropout has no effect with a single layer.
warnings.filterwarnings('ignore')

In [0]:
class Vocab(object):
    """Bidirectional mapping between tokens and integer ids.

    Attributes:
        word2id: dict token (str) -> id (int)
        id2word: dict id (int) -> token (str), kept consistent with word2id
    """

    def __init__(self, word2id=None):
        """
        word2id: optional initial dict str -> int (e.g. the special tokens).
                 It is copied, so the caller's dict is never modified.
        """
        self.word2id = dict(word2id or {})
        self.id2word = {v: k for k, v in self.word2id.items()}

    def build_vocab(self, sentences, min_count=1):
        """Add the tokens of `sentences` to the vocabulary, most frequent first.

        sentences: iterable of tokens; every element is counted as one word.
        min_count: only words occurring at least `min_count` times are added.

        New words get the next id (len(word2id)). Words that are already in the
        vocabulary keep their id and do not consume a new one, so word2id and
        id2word stay the same size.
        """
        # count occurrences of each word
        word_counter = {}
        for word in sentences:
            word_counter[word] = word_counter.get(word, 0) + 1

        # sorted() is stable, so words with equal counts keep first-seen order;
        # stop at the first word whose count is below min_count
        for word, count in sorted(word_counter.items(), key=lambda x: -x[1]):
            if count < min_count:
                break
            if word in self.word2id:
                continue
            _id = len(self.word2id)
            self.word2id[word] = _id
            self.id2word[_id] = word

In [0]:
# Special tokens shared by all vocabularies and their fixed ids.
# PAD must be 0: it is used as the embedding padding_idx and the loss ignore_index.
PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN = '<PAD>', '<S>', '</S>', '<UNK>'
PAD, BOS, EOS, UNK = range(4)

In [0]:
# load vocab dict
# The file stores several pickled objects back to back; the loads below must
# stay in the order the objects were dumped.
# NOTE(review): pickle.load can execute arbitrary code, and this file is
# downloaded from the network in the first cell; only load it from a trusted source.
with open('java-small/java-small.dict.c2s', 'rb') as file:
    subtoken_to_count = pickle.load(file)  # only its keys are used (vocabulary construction)
    node_to_count = pickle.load(file) 
    target_to_count = pickle.load(file)
    max_contexts = pickle.load(file)  # not used in the rest of this notebook as shown
    num_training_examples = pickle.load(file)  # used as the number of training examples
    print('Dictionaries loaded.')

In [0]:
# making vocab dicts for terminal subtoken, nonterminal node and target.

# One Vocab each for terminal subtokens, AST node symbols and target subtokens.
# All three start from the same special-token ids.
word2id = {
    PAD_TOKEN: PAD,
    BOS_TOKEN: BOS,
    EOS_TOKEN: EOS,
    UNK_TOKEN: UNK,
}

vocab_subtoken = Vocab(word2id=word2id)
vocab_nodes = Vocab(word2id=word2id)
vocab_target = Vocab(word2id=word2id)

for _vocab, _counts in ((vocab_subtoken, subtoken_to_count),
                        (vocab_nodes, node_to_count),
                        (vocab_target, target_to_count)):
    # min_count=0 keeps every key of the preprocessing dictionary
    _vocab.build_vocab(list(_counts.keys()), min_count=0)

vocab_size_subtoken, vocab_size_nodes, vocab_size_target = (
    len(v.id2word) for v in (vocab_subtoken, vocab_nodes, vocab_target))

print('vocab_size_subtoken：', vocab_size_subtoken)
print('vocab_size_nodes：', vocab_size_nodes)
print('vocab_size_target：', vocab_size_target)

num_length_train = num_training_examples
print('num_examples : ', num_length_train)

In [0]:
def sentence_to_ids(vocab, sentence):
    """Map each token to its id in `vocab` (UNK when unknown) and append EOS."""
    lookup = vocab.word2id
    return [lookup.get(token, UNK) for token in sentence] + [EOS]
  
  
def pad_seq(seq, max_length):
    """Return a new list: `seq` right-padded with PAD up to `max_length` (no padding if already that long)."""
    shortfall = max_length - len(seq)
    return seq + [PAD] * shortfall

def ids_to_sentence(vocab, ids):
    """Translate ids back to tokens via `vocab.id2word` (KeyError for unknown ids)."""
    return list(map(vocab.id2word.__getitem__, ids))

def trim_eos(ids):
    """Return the prefix of `ids` before the first EOS, or `ids` itself if it has no EOS."""
    try:
        cut = ids.index(EOS)
    except ValueError:
        return ids
    return ids[:cut]

In [0]:
class DataLoader(object):
    """Iterate over a code2seq `.c2s` file in padded, length-sorted torch batches.

    The file is processed in `sector_size` consecutive chunks ("sectors"), and
    only the current sector is kept in memory. One pass over all sectors is one
    epoch. Iteration then raises StopIteration and rewinds to the first sector,
    so the same instance can be iterated again for the next epoch.
    """

    def __init__(self, data_path, num_examples, batch_size, num_k, vocab_subtoken, vocab_nodes, vocab_target, sector_size = 1):

        """
        data_path : path of the .c2s data file
        num_examples : number of lines (examples) of the data file to use
        batch_size : batch size
        num_k : max number of AST paths kept per example; examples with more
                paths are uniformly sub-sampled
        vocab_subtoken : Vocab of terminal subtokens
        vocab_nodes : Vocab of AST node symbols
        vocab_target : Vocab of target subtokens
        sector_size : number of chunks the data is split into; only one chunk is
                      held in memory at a time. Use 1 if RAM allows
                      (10 is used on Google Colaboratory).
        """

        self.batch_size = batch_size
        self.start_index = 0
        self.num_examples = num_examples
        self.num_k = num_k
        self.data_path = data_path

        self.vocab_subtoken = vocab_subtoken
        self.vocab_nodes = vocab_nodes
        self.vocab_target = vocab_target

        self.sector_size = sector_size
        self.sector = 0

        print('Sector size :', sector_size)
        print(self.num_examples//self.sector_size//self.batch_size, 'iter / sector')

        self.reset()

    def _sector_bounds(self):
        """Return the [begin, end) line range of the current sector.

        Each sector covers num_examples // sector_size lines. The last sector
        also takes the remainder, so no example is dropped.
        """
        per_sector = self.num_examples // self.sector_size
        begin = per_sector * self.sector
        if self.sector == self.sector_size - 1:
            end = self.num_examples
        else:
            end = per_sector * (self.sector + 1)
        return begin, end

    def reset(self):
        """(Re)load the current sector and rewind to its first example."""
        self.data = self.data_sampler()
        self.start_index = 0

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch.

        Returns the tuple
            (batch_S, batch_N, batch_E, batch_Y,
             lengths_S, lengths_N, lengths_E, lengths_Y,
             max_length_S, max_length_N, max_length_E, max_length_Y,
             lengths_k, index_N)
        where the per-path tensors/lengths are sorted by decreasing node length,
        index_N[p] is the original (flattened) position of sorted path p, and
        lengths_k holds the number of paths of each example.
        """
        # Move on once the current sector is exhausted. The bound is the number
        # of examples actually loaded, so a file shorter than num_examples (or
        # an empty sector) never yields an empty batch.
        while self.start_index >= len(self.data[0]):
            print('Sector', self.sector + 1, ' / ', self.sector_size, ' done.')
            self.sector += 1

            if self.sector >= self.sector_size:
                # epoch finished: preload the first sector for the next epoch
                self.sector = 0
                self.reset()
                raise StopIteration()

            self.reset()

        seqs_S, seqs_N, seqs_E, seqs_Y = self.data
        window = slice(self.start_index, self.start_index + self.batch_size)

        batch_seqs_S = seqs_S[window]
        batch_seqs_N = seqs_N[window]
        batch_seqs_E = seqs_E[window]
        batch_seqs_Y = seqs_Y[window]

        # lengths_k : number of AST paths of each example (one entry per example)
        lengths_k = [len(ex) for ex in batch_seqs_N]

        # flatten (batch_size, k, l) to (batch_size * k, l) so paths can become tensors
        batch_seqs_S = [symbol for k in batch_seqs_S for symbol in k]
        batch_seqs_N = [symbol for k in batch_seqs_N for symbol in k]
        batch_seqs_E = [symbol for k in batch_seqs_E for symbol in k]

        # Padding
        lengths_S = [len(s) for s in batch_seqs_S]
        lengths_N = [len(s) for s in batch_seqs_N]
        lengths_E = [len(s) for s in batch_seqs_E]
        lengths_Y = [len(s) for s in batch_seqs_Y]

        max_length_S = max(lengths_S)
        max_length_N = max(lengths_N)
        max_length_E = max(lengths_E)
        max_length_Y = max(lengths_Y)

        padded_S = [pad_seq(s, max_length_S) for s in batch_seqs_S]
        padded_N = [pad_seq(s, max_length_N) for s in batch_seqs_N]
        padded_E = [pad_seq(s, max_length_E) for s in batch_seqs_E]
        padded_Y = [pad_seq(s, max_length_Y) for s in batch_seqs_Y]

        # after sorting, index_N[p] is the original position of sorted path p;
        # it is used to restore the original order of the (batch_size * k) paths
        index_N = range(len(lengths_N))

        # Sort paths by decreasing node-sequence length (required for packing).
        # The per-path length lists are permuted too, so they stay aligned with
        # the rows of batch_S / batch_E.
        seq_pairs = sorted(zip(lengths_N, index_N, padded_N, padded_S, padded_E, lengths_S, lengths_E),
                           key=lambda p: p[0], reverse=True)
        lengths_N, index_N, padded_N, padded_S, padded_E, lengths_S, lengths_E = zip(*seq_pairs)

        batch_S = torch.tensor(padded_S, dtype=torch.long, device=device)
        batch_E = torch.tensor(padded_E, dtype=torch.long, device=device)

        # time-major for the RNNs: (l_N, B*k) and (l_Y, B)
        batch_N = torch.tensor(padded_N, dtype=torch.long, device=device).transpose(0, 1)
        batch_Y = torch.tensor(padded_Y, dtype=torch.long, device=device).transpose(0, 1)

        self.start_index += self.batch_size

        return batch_S, batch_N, batch_E, batch_Y, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N

    def data_sampler(self):
        """Read and encode the examples of the current sector.

        Each line is `target context context ...` separated by spaces. Each
        context is `start,path,end`, and `|` separates the subtokens / nodes
        inside each field. Every id sequence gets EOS appended by sentence_to_ids.

        Returns (seqs_S, seqs_N, seqs_E, seqs_Y), parallel per-example lists.
        The first three hold one id list per kept AST path; seqs_Y holds the
        target ids.
        """
        seqs_S, seqs_N, seqs_E, seqs_Y = [], [], [], []
        begin, end = self._sector_bounds()

        with open(self.data_path, 'r') as f:
            for i, line in enumerate(f):
                if i < begin:
                    continue
                if i >= end:
                    break

                # Strip the newline before splitting, so it cannot end up inside
                # the target (line without contexts) or the last context's end terminal.
                target, *syntax_path = line.rstrip('\n').split(' ')
                target = sentence_to_ids(self.vocab_target, target.split('|'))

                # java-small lines contain many empty fields (consecutive spaces)
                syntax_path = [s for s in syntax_path if s != '']

                # if there are more than k AST paths, uniformly sample k of them, as described in the paper
                if len(syntax_path) > self.num_k:
                    sampled_path_index = random.sample(range(len(syntax_path)), self.num_k)
                else:
                    sampled_path_index = range(len(syntax_path))

                seq_S, seq_N, seq_E = [], [], []
                for j in sampled_path_index:
                    terminal1, ast_path, terminal2 = syntax_path[j].split(',')
                    seq_S.append(sentence_to_ids(self.vocab_subtoken, terminal1.split('|')))
                    seq_N.append(sentence_to_ids(self.vocab_nodes, ast_path.split('|')))
                    seq_E.append(sentence_to_ids(self.vocab_subtoken, terminal2.split('|')))

                seqs_S.append(seq_S)
                seqs_E.append(seq_E)
                seqs_N.append(seq_N)
                seqs_Y.append(target)

        return seqs_S, seqs_N, seqs_E, seqs_Y


In [0]:
class Encoder(nn.Module):
    """Encode each AST path context (start terminal, node path, end terminal) into one vector."""

    def __init__(self, input_size_subtoken, input_size_node, token_size, hidden_size, rnn_dropout = 0.5, embeddings_dropout = 0.25):

        """
        input_size_subtoken : number of unique subtokens (subtoken vocabulary size)
        input_size_node : number of unique AST node symbols (node vocabulary size)
        token_size : embedding size d of subtokens / nodes (also the LSTM hidden size)
        hidden_size : size of each encoded path (= size of the initial decoder state)
        rnn_dropout : dropout passed to the path LSTM. NOTE: nn.LSTM applies
                      dropout only between stacked layers, so it has no effect
                      on this single-layer LSTM.
        embeddings_dropout : dropout ratio applied to each encoded path
        """

        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.token_size = token_size

        self.embedding_subtoken = nn.Embedding(input_size_subtoken, token_size, padding_idx=PAD)
        self.embedding_node = nn.Embedding(input_size_node, token_size, padding_idx=PAD)

        self.lstm = nn.LSTM(token_size, token_size, bidirectional=True, dropout=rnn_dropout)
        self.out = nn.Linear(token_size * 4, hidden_size)

        self.dropout = nn.Dropout(embeddings_dropout)

    def forward(self, batch_S, batch_N, batch_E, lengths_k, index_N, hidden=None, lengths_N=None):

        """
        batch_S : (B*k, l) start-terminal subtoken ids of each AST path (sorted order)
        batch_N : (l, B*k) node ids of each path, padded with PAD; paths are
                  sorted by decreasing length, as pack_padded_sequence requires
        batch_E : (B*k, l) end-terminal subtoken ids of each AST path (sorted order)
        lengths_k : list with the number of paths of each example (sums to B*k)
        index_N : index_N[p] is the original (pre-sort) position of sorted path p
        hidden : optional initial LSTM state
        lengths_N : optional node-sequence lengths in sorted order. When omitted,
                    they are recomputed as the number of non-PAD ids in each
                    column of batch_N (so no notebook global is needed).

        Returns:
            output_bag : tuple of B tensors, the i-th of shape (lengths_k[i], hidden_size),
                         with paths in their original order
            hidden_0 : (1, B, hidden_size) mean of each example's encoded paths
        """

        if lengths_N is None:
            lengths_N = (batch_N != PAD).sum(dim=0)
        # pack_padded_sequence expects the lengths as a CPU int64 tensor / list
        lengths = torch.as_tensor(lengths_N, dtype=torch.long).cpu()

        # terminals: sum of subtoken embeddings (PAD embeds to zero) -> (B*k, d)
        encode_S = self.embedding_subtoken(batch_S).sum(1)
        encode_E = self.embedding_subtoken(batch_E).sum(1)

        # node path: bidirectional LSTM over the node embeddings (l, B*k, d)
        emb_N = self.embedding_node(batch_N)
        packed = pack_padded_sequence(emb_N, lengths)
        output, _ = self.lstm(packed, hidden)
        output, _ = pad_packed_sequence(output, total_length=batch_N.size(0))

        # (l, B*k, 2d) -> (B*k, l, 2, d); direction 0 is forward, 1 is backward
        num_paths = batch_N.size(1)
        output = output.transpose(0, 1).reshape(num_paths, batch_N.size(0), 2, self.token_size)

        # Forward direction: state at each path's last real node.
        # Backward direction: state at position 0 (it has read the whole path).
        rows = torch.arange(num_paths, device=output.device)
        last = (lengths - 1).to(output.device)
        output_normal = output[rows, last, 0]
        output_reverse = output[:, 0, 1]

        # encode_N : (B*k, 2d)
        encode_N = torch.cat([output_normal, output_reverse], dim=1)

        # encode_SNE : (B*k, 4d) -> (B*k, hidden_size)
        encode_SNE = torch.cat([encode_N, encode_S, encode_E], dim=1)
        encode_SNE = self.out(encode_SNE)

        # Undo the length sort. index_N maps sorted -> original position, so the
        # gather index is its inverse permutation (argsort). Gathering with
        # index_N itself would scramble paths across examples.
        order = torch.as_tensor(index_N, dtype=torch.long, device=encode_SNE.device)
        encode_SNE = encode_SNE.index_select(0, torch.argsort(order))

        # as is in  https://github.com/tech-srl/code2seq/blob/ec0ae309efba815a6ee8af88301479888b20daa9/model.py#L511
        encode_SNE = self.dropout(encode_SNE)

        # output_bag : B tensors of shape (k_i, hidden_size)
        output_bag = torch.split(encode_SNE, list(lengths_k), dim=0)

        # hidden_0 : (1, B, hidden_size)
        hidden_0 = torch.stack([bag.mean(0) for bag in output_bag]).unsqueeze(0)

        return output_bag, hidden_0


In [0]:
class Decoder(nn.Module):
    """Single-layer GRU decoder producing target-vocabulary logits."""

    def __init__(self, hidden_size, output_size, rnn_dropout):
        """
        hidden_size : GRU hidden size (also the target embedding size)
        output_size : target vocabulary size (number of logits per step)
        rnn_dropout : dropout passed to nn.GRU; with a single layer PyTorch
                      applies no dropout
        """
        super(Decoder, self).__init__()
        self.hidden_size, self.output_size = hidden_size, output_size

        # module creation order (embedding, gru, out) fixes parameter
        # initialisation for a given seed
        self.embedding = nn.Embedding(output_size, hidden_size, padding_idx=PAD)
        self.gru = nn.GRU(hidden_size, hidden_size, dropout=rnn_dropout)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, seqs, hidden):
        """
        seqs : (T, B) token ids (time-major)
        hidden : (1, B, hidden_size) initial GRU state
        Returns (logits of shape (T, B, output_size), final GRU state).
        """
        gru_out, next_hidden = self.gru(self.embedding(seqs), hidden)
        return self.out(gru_out), next_hidden

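#### Memo: attention and its score function

In this implementation, the score of an encoded AST path $\bar{h}_s$ against the decoder state $h_t$ is:
$$
   \mathrm{score}(\bar{h}_s, h_t) = \bar{h}_s^{\mathrm{T}} W_a h_t .
$$

The attention weights are a softmax over the $S$ paths of the same example:
$$
    a_t(s) = \frac{\exp(\mathrm{score}(\bar{h}_s, h_t))}{\sum^S_{s'=1}\exp(\mathrm{score}(\bar{h}_{s'}, h_t))} .
$$

The context vector and the attended state are then:
$$
    c_t = \sum^S_{s=1} a_t(s) \bar{h}_s
$$
$$
    \tilde{h}_t = \tanh(W_h h_t + W_c c_t + b)
$$
($W_c$ and $W_h$ are `W_cc` and `W_ch` in the code.) The logits of step $t$ come from the GRU output $h_t = \mathrm{GRU}(y_{t-1}, \tilde{h}_{t-1})$, and $\tilde{h}_t$ is fed back as the GRU state for the next step:
$$
    y_t = \mathrm{softmax}(W_{out} h_t + b_{out})
$$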

In [0]:
class EncoderDecoder_with_Attention(nn.Module):

    """Combine Encoder and Decoder. After every decoder step, the state is
    recomputed by attending over the example's encoded AST paths."""

    def __init__(self, input_size_subtoken, input_size_node, token_size, output_size, hidden_size, rnn_dropout = 0.5, embeddings_dropout = 0.25):
        """
        input_size_subtoken : subtoken vocabulary size
        input_size_node : AST node vocabulary size
        token_size : embedding size used inside the encoder
        output_size : target vocabulary size
        hidden_size : encoded-path size and decoder hidden size
        rnn_dropout : dropout passed to the encoder LSTM and the decoder GRU
        embeddings_dropout : dropout applied to the encoded paths
        """

        super(EncoderDecoder_with_Attention, self).__init__()
        self.encoder = Encoder(input_size_subtoken, input_size_node, token_size, hidden_size, rnn_dropout, embeddings_dropout)
        self.decoder = Decoder(hidden_size, output_size, rnn_dropout)

        # Registered as nn.Parameter; plain tensors with requires_grad=True are
        # not module parameters. As parameters they are returned by
        # model.parameters() (so the optimizer trains them), included in
        # state_dict() (so checkpoints save/restore them), and moved by model.to().
        self.W_a = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.W_cc = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.W_ch = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b = nn.Parameter(torch.rand(hidden_size))

        nn.init.xavier_uniform_(self.W_a)
        nn.init.xavier_uniform_(self.W_cc)
        nn.init.xavier_uniform_(self.W_ch)

    def forward(self, batch_S, batch_N, batch_E, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S, max_length_N,max_length_E,max_length_Y, lengths_k, index_N, terget_max_length, batch_Y=None, use_teacher_forcing=False):
        """Decode `terget_max_length` steps, greedily or with teacher forcing.

        The arguments follow the DataLoader batch tuple. Only batch_S, batch_N,
        batch_E, lengths_k and index_N are passed to the encoder; the other
        length arguments are accepted for call compatibility and are unused.
        `terget_max_length` keeps its original spelling so keyword callers
        still work. Teacher forcing feeds batch_Y[t] as the next input and
        requires batch_Y.

        Returns decoder logits of shape (terget_max_length, B, output_size).
        """

        # Encoder
        encoder_output_bag, encoder_hidden = \
        self.encoder(batch_S, batch_N, batch_E, lengths_k, index_N)
        _batch_size = len(encoder_output_bag)
        run_device = encoder_hidden.device

        # initial decoder state: attention over the paths, queried with their mean
        decoder_hidden = self.attention(encoder_output_bag, encoder_hidden, lengths_k)

        # first decoder input: BOS for every example, shape (1, batch_size)
        decoder_input = torch.full((1, _batch_size), BOS, dtype=torch.long, device=run_device)

        # output holder
        decoder_outputs = torch.zeros(terget_max_length, _batch_size, self.decoder.output_size, device=run_device)

        for t in range(terget_max_length):

            # one GRU step; the logits come from the pre-attention GRU output
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)

            # next hidden state re-attends over the encoded paths
            decoder_hidden = self.attention(encoder_output_bag, decoder_hidden, lengths_k)

            decoder_outputs[t] = decoder_output

            # Teacher Forcing
            if use_teacher_forcing and batch_Y is not None:
                decoder_input = batch_Y[t].unsqueeze(0)
            else:
                decoder_input = decoder_output.max(-1)[1]

        return decoder_outputs

    def attention(self, encoder_output_bag, hidden, lengths_k):
        """Attend over each example's encoded paths and return the new decoder state.

        encoder_output_bag : sequence of B tensors, the i-th of shape (lengths_k[i], hidden_size)
        hidden : (1, B, hidden_size) current decoder state h_t
        lengths_k : list with the number of paths of each example

        Returns tanh(c_t W_cc + h_t W_ch + b) of shape (1, B, hidden_size). Here
        c_t is the softmax-weighted sum of the example's paths, with
        score(e, h_t) = (e W_a) . h_t.
        """

        # (1, B, H) -> (B, H)
        d_hidden = hidden.squeeze(0)

        # repeat each example's state once per path -> (B*k, H)
        counts = torch.as_tensor(list(lengths_k), dtype=torch.long, device=d_hidden.device)
        d_hidden = torch.repeat_interleave(d_hidden, counts, dim=0)

        # all encoded paths stacked -> (B*k, H)
        e_output = torch.cat(list(encoder_output_bag), dim=0)

        # score_i = (e_i W_a) . h_i -> (B*k, 1)
        score = torch.einsum('ij,ij->i', d_hidden, e_output @ self.W_a).unsqueeze(1)

        # softmax over the paths of each example separately
        attn_weights = [F.softmax(s, dim=0) for s in torch.split(score, list(lengths_k), dim=0)]

        # context vector per example: sum_j a_j e_j -> (B, H)
        context_vector = torch.cat(
            [torch.einsum('ij,ik->jk', aw, eo) for aw, eo in zip(attn_weights, encoder_output_bag)],
            dim=0)

        # (B, H); the einsum drops the leading size-1 axis of hidden
        decoder_hidden = torch.tanh(
            context_vector @ self.W_cc
            + torch.einsum('ijk,kl->jl', hidden, self.W_ch)
            + self.b)

        # (1, B, H)
        return decoder_hidden.unsqueeze(0)
    


In [0]:
# Token-level cross entropy summed (not averaged) over all non-PAD targets.
# reduction='sum' replaces the deprecated size_average=False with the same result.
mce = nn.CrossEntropyLoss(reduction='sum', ignore_index=PAD)
def masked_cross_entropy(logits, target):
    """
    logits : (..., vocab_size) unnormalised scores, e.g. (T, B, V)
    target : integer ids with the same leading shape as logits, e.g. (T, B)
    Returns a 0-dim tensor: the summed cross entropy over non-PAD target positions.
    """
    # reshape (unlike view) also accepts non-contiguous inputs
    return mce(logits.reshape(-1, logits.size(-1)), target.reshape(-1))

In [0]:
#
# make dataloader instances
#

BATCH_SIZE = 512
DEV_RATIO = 0.1  # NOTE(review): not used anywhere in this notebook
NUM_K = 200  # max AST paths sampled per example

SECTOR_SIZE = 10

DATA_PATH_TRAIN = 'java-small/java-small.train.c2s'
# NOTE: the *test* split serves as the per-epoch validation / model-selection set;
# java-small.val.c2s is used for the final evaluation at the end of the notebook.
DATA_PATH_VALID = 'java-small/java-small.test.c2s'

# Count the examples instead of hard-coding `wc -l` output (57088 for java-small).
with open(DATA_PATH_VALID, 'r') as _valid_file:
    num_length_valid = sum(1 for _ in _valid_file)

train_loader_param = {
    'data_path' : DATA_PATH_TRAIN,
    'num_examples' : num_length_train,
    'batch_size' : BATCH_SIZE,
    'num_k' : NUM_K,
    'vocab_subtoken' : vocab_subtoken,
    'vocab_nodes' : vocab_nodes,
    'vocab_target' : vocab_target,
    'sector_size' : SECTOR_SIZE
}

valid_loader_param = {
    'data_path' : DATA_PATH_VALID,
    'num_examples' : num_length_valid,
    'batch_size' : BATCH_SIZE,
    'num_k' : NUM_K,
    'vocab_subtoken' : vocab_subtoken,
    'vocab_nodes' : vocab_nodes,
    'vocab_target' : vocab_target,
    'sector_size' : SECTOR_SIZE
}

train_dataloader = DataLoader(**train_loader_param)
valid_dataloader = DataLoader(**valid_loader_param)

In [0]:
#
# hyper-parameters, model and optimizer
#

NUM_EPOCHS = 30

# SGD with Nesterov momentum; the learning rate is multiplied by DECAY_RATIO ** epoch
INITIAL_RL = 0.01
WEIGHT_DECAY = 0.01
MOMENTUM = 0.95
NESTEROV = True
DECAY_RATIO = 0.95

# dropout ratios
RNN_DROPOUT = 0.5
EMBEDDING_DROPOUT = 0.25

# sizes
TOKEN_SIZE = 128
HIDDEN_SIZE = 320

model_args = dict(
    input_size_subtoken=vocab_size_subtoken,
    input_size_node=vocab_size_nodes,
    output_size=vocab_size_target,
    hidden_size=HIDDEN_SIZE,
    token_size=TOKEN_SIZE,
    rnn_dropout=RNN_DROPOUT,
    embeddings_dropout=EMBEDDING_DROPOUT,
)

model = EncoderDecoder_with_Attention(**model_args).to(device)


def _lr_decay(epoch):
    """Multiplicative learning-rate factor for LambdaLR."""
    return DECAY_RATIO ** epoch


optimizer = optim.SGD(
    model.parameters(),
    lr=INITIAL_RL,
    weight_decay=WEIGHT_DECAY,
    momentum=MOMENTUM,
    nesterov=NESTEROV,
)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_lr_decay)


In [0]:
def compute_loss(batch_S, batch_N, batch_E, batch_Y, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N, model, optimizer=None, is_train=True):
    """Run `model` on one batch and return (loss, gold ids, predicted ids).

    The first 14 arguments are the DataLoader batch tuple. They are passed
    through to the model together with target_max_length = batch_Y.size(0).
    When is_train is True, the whole batch is decoded with teacher forcing with
    probability `teacher_forcing_rate` (a notebook global that must be defined
    before the first training call), and one optimizer step is taken. When
    is_train is False, the forward pass runs with autograd disabled, so no
    graph is built.

    Returns:
        loss : float value of masked_cross_entropy for the batch
        batch_Y : list of B target id lists (one per example)
        pred : list of B greedy (argmax) prediction id lists
    """

    model.train(is_train)

    # `and` short-circuits: `random` is consumed (and the global read) only when training
    use_teacher_forcing = is_train and (random.random() < teacher_forcing_rate)

    target_max_length = batch_Y.size(0)

    # build the autograd graph only when backward() will use it
    with torch.set_grad_enabled(is_train):
        pred_Y = model(batch_S, batch_N, batch_E, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N, target_max_length, batch_Y, use_teacher_forcing)
        loss = masked_cross_entropy(pred_Y.contiguous(), batch_Y.contiguous())

    if is_train:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # (T, B) -> B lists of T ids
    batch_Y = batch_Y.transpose(0, 1).detach().cpu().tolist()
    pred = pred_Y.detach().max(dim=-1)[1].cpu().numpy().T.tolist()

    return loss.item(), batch_Y, pred

In [0]:
def calc_bleu(refs, hyps):
    """Corpus BLEU (scaled to 0-100) between references and hypotheses cut at EOS.

    refs : list of reference id lists; each must contain EOS (ValueError otherwise)
    hyps : list of hypothesis id lists; cut at the first EOS when present
    """
    references = []
    for ref in refs:
        # corpus_bleu takes a list of alternative references per hypothesis
        references.append([ref[:ref.index(EOS)]])
    hypotheses = [hyp[:hyp.index(EOS)] if EOS in hyp else hyp for hyp in hyps]
    return bleu_score.corpus_bleu(references, hypotheses) * 100

In [0]:
def calculate_results(refs, preds):
    #calc precision, recall and F1
    #same as https://github.com/tech-srl/code2seq/blob/ec0ae309efba815a6ee8af88301479888b20daa9/model.py#L239
    
    filterd_refs = [ref[:ref.index(EOS)] for ref in refs]
    filterd_preds = [pred[:pred.index(EOS)] if EOS in pred else pred for pred in preds]
    
    true_positive, false_positive, false_negative = 0, 0, 0

    for filterd_pred, filterd_ref in zip(filterd_preds, filterd_refs):

        if filterd_pred == filterd_ref:
            true_positive += len(filterd_pred)
            continue

        for fp in filterd_pred:
            if fp in filterd_ref:
                true_positive += 1
            else:
                false_positive += 1

        for fr in filterd_ref:
            if not fr in filterd_pred:
                false_negative += 1

    if true_positive + false_positive > 0:
        precision = true_positive / (true_positive + false_positive) 
    else:
        precision = 0

    if true_positive + false_negative > 0:
        recall = true_positive / (true_positive + false_negative)
    else:
        recall = 0

    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0
    
    return precision, recall, f1


In [0]:

#
# Training Loop
# 
# Each epoch: one training pass and one validation pass over the sector-wise
# DataLoaders, then BLEU and subtoken precision/recall/F1 on both; the model is
# checkpointed whenever validation F1 improves (after start_saving_epoch).

# Per-batch probability of decoding with teacher forcing during training;
# compute_loss reads this notebook-level global.
teacher_forcing_rate = 0.2
save_path = 'model.pth'
best_valid_f1 = 0.

# At first several iters, the score is unstable and would be not good for saving.
# Saving requires `epoch > start_saving_epoch` (epoch is 0-based), so the first
# possible checkpoint is written at the end of the 4th epoch.
# NOTE(review): if validation F1 stays 0 for every such epoch (or NUM_EPOCHS <= 3),
# model.pth is never written and loading it in the Evaluation section fails.
start_saving_epoch = 2


for epoch in tqdm(range(NUM_EPOCHS), desc='EPOCH'):
    train_loss = 0.
    train_refs = []
    train_hyps = []
    valid_loss = 0.
    valid_refs = []
    valid_hyps = []
    
    # train
    for batch in tqdm(train_dataloader, total=train_dataloader.num_examples // train_dataloader.batch_size , desc='TRAIN'):
        # NOTE(review): this unpacking runs at notebook (module) scope, so these names
        # (including lengths_N) become notebook globals that model code may read;
        # don't move this loop into a function without checking.
        batch_S, batch_N, batch_E, batch_Y, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N = batch
        
        loss, gold, pred = compute_loss(
            batch_S, batch_N, batch_E, batch_Y, 
            lengths_S, lengths_N, lengths_E, lengths_Y, 
            max_length_S,max_length_N,max_length_E,max_length_Y, 
            lengths_k, index_N, model, optimizer,
            is_train=True
            )
        train_loss += loss
        train_refs += gold
        train_hyps += pred
    
    # valid (no optimizer step: is_train=False)
    for batch in tqdm(valid_dataloader, total=valid_dataloader.num_examples // valid_dataloader.batch_size , desc='VALID'):

        batch_S, batch_N, batch_E, batch_Y, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N = batch

        loss, gold, pred = compute_loss(
            batch_S, batch_N, batch_E, batch_Y, 
            lengths_S, lengths_N, lengths_E, lengths_Y, 
            max_length_S,max_length_N,max_length_E,max_length_Y, 
            lengths_k, index_N, model, optimizer,
            is_train=False
            )
        valid_loss += loss
        valid_refs += gold
        valid_hyps += pred
            

    # loss per example: sum of the per-batch losses / number of examples
    train_loss = np.sum(train_loss) / train_dataloader.num_examples
    valid_loss = np.sum(valid_loss) / valid_dataloader.num_examples
    
    # BLEU
    train_bleu = calc_bleu(train_refs, train_hyps)
    valid_bleu = calc_bleu(valid_refs, valid_hyps)
    
    # F1 etc
    train_precision, train_recall, train_f1 = calculate_results(train_refs, train_hyps)
    valid_precision, valid_recall, valid_f1 = calculate_results(valid_refs, valid_hyps)

    
    # model selection uses validation F1 (not BLEU)
    if valid_f1 > best_valid_f1 and epoch > start_saving_epoch:
        tlpt = model.state_dict()
        torch.save(tlpt, save_path)
        best_valid_f1 = valid_f1
        print('Best valid F1, model saved.')
    
    print('Epoch {}: train_loss: {:5.2f}  train_bleu: {:2.4f}  train_f1: {:2.4f}  valid_loss: {:5.2f}  valid_bleu: {:2.4f}  valid_f1: {:2.4f}'.format(
            epoch, train_loss, train_bleu, train_f1, valid_loss, valid_bleu, valid_f1))
    
    # show the first few validation predictions next to their references
    print('-- Prediction example --')
    for i, (ref, pred) in enumerate(zip(valid_refs[:5], valid_hyps[:5])):
        print(i, 'REF  :',ids_to_sentence(vocab_target, trim_eos(ref)))
        print(i, 'PRED :',ids_to_sentence(vocab_target, trim_eos(pred)))
        print('-'*80)
    
    print('-'*80)
    
    # per-epoch learning-rate decay
    scheduler.step()

## Evaluation

In [0]:
# load best model (written by the training loop when validation F1 improved)
# map_location lets a checkpoint saved on GPU be restored on a CPU-only runtime.
tlpt = torch.load(save_path, map_location=device)
model.load_state_dict(tlpt)
model.eval()

In [0]:
# release the training / validation loaders and the sector data they hold
del train_dataloader
del valid_dataloader

In [0]:

# The final evaluation uses java-small.val.c2s (the test split was used for
# model selection during training). Count its lines instead of hard-coding
# `wc -l` output (23844 for java-small).
DATA_PATH_TEST = 'java-small/java-small.val.c2s'
with open(DATA_PATH_TEST, 'r') as _test_file:
    num_length_test = sum(1 for _ in _test_file)

test_loader_param = {
    'data_path' : DATA_PATH_TEST,
    'num_examples' : num_length_test,
    'batch_size' : 1,
    'num_k' : NUM_K,  # same path budget as training / validation
    'vocab_subtoken' : vocab_subtoken,
    'vocab_nodes' : vocab_nodes,
    'vocab_target' : vocab_target,
    'sector_size' : 1
}

test_dataloader = DataLoader(**test_loader_param)

In [0]:
# Greedy decoding of the evaluation set, one example per batch (batch_size = 1).
# NOTE(review): the number of decoding steps is taken from the reference
# (batch_Y.size(0)), so predictions never run longer than the gold name;
# this may bias the reported metrics.
refs_list = []
hyp_list = []

# inference only: don't build autograd graphs
with torch.no_grad():
    for batch in tqdm(test_dataloader,
                      total=test_dataloader.num_examples // test_dataloader.batch_size,
                      desc='TEST'):

        # A `with` block opens no new scope, so these names (including
        # lengths_N) stay notebook globals, as in the training loop.
        batch_S, batch_N, batch_E, batch_Y, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N = batch
        target_max_length = batch_Y.size(0)
        use_teacher_forcing = False

        pred_Y = model(batch_S, batch_N, batch_E, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N, target_max_length, batch_Y, use_teacher_forcing)

        # (T, 1) -> the single example's id list
        refs = batch_Y.transpose(0, 1).cpu().tolist()[0]
        pred = pred_Y.max(dim=-1)[1].cpu().numpy().T.tolist()[0]

        refs_list.append(refs)
        hyp_list.append(pred)


In [0]:
# Show the first few decoded examples next to their references.
print('-- Prediction example --')
for i in range(min(5, len(refs_list), len(hyp_list))):
    for label, ids in (('REF  :', refs_list[i]), ('PRED :', hyp_list[i])):
        print(i, label, ids_to_sentence(vocab_target, trim_eos(ids)))
    print('-'*80)

test_precision, test_recall, test_f1 = calculate_results(refs_list, hyp_list)

for metric_name, metric_value in (('Precision :', test_precision),
                                  ('Recall :', test_recall),
                                  ('F1 : ', test_f1)):
    print(metric_name, metric_value)