# Copy-Generator Transformer

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from queue import PriorityQueue
import numpy as np
import torchtext
import tqdm
from torchnlp.metrics import get_moses_multi_bleu
from torchtext.data import Field, BucketIterator
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

import tensorflow as tf
import tensorflow_datasets as tfds
from tokenize import tokenize, untokenize, NUMBER, STRING, NAME, OP
from io import BytesIO

import linecache
import sys
import os
import re
import random
import time
import operator
import collections
from pprint import pprint

from models_and_trainers.base_transformer import TransformerModel, PositionalEncoding
from models_and_trainers.copy_gen_transformer import Transformer, TransformerDecoderLayer, TransformerDecoder
from models_and_trainers.retrieval import PyLuceneRetriever
import beam_search
from utils.edit_tagger import build_matrix, single_step_edits, perform_edits, get_tags

from IPython.core.debugger import set_trace as tr
%load_ext autoreload
%autoreload 2

In [5]:
from datetime import datetime
def super_print(filename):
    '''filename is the file where output will be written'''
    def wrap(func):
        '''func is the function you are "overriding", i.e. wrapping'''
        def wrapped_func(*args,**kwargs):
            '''*args and **kwargs are the arguments supplied 
            to the overridden function'''
            #use with statement to open, write to, and close the file safely
            with open(filename,'a') as outputfile:
                now = datetime.now()
                dt_string = now.strftime("%d/%m/%Y %H:%M:%S")
                outputfile.write("[{}] ".format(dt_string))
                outputfile.write(" ".join(str(x) for x in args))
                outputfile.write("\n")
            #now original function executed with its arguments as normal
            return func(*args,**kwargs)
        return wrapped_func
    return wrap

print = super_print('logs-copy-gen.txt')(print)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device == "cuda":
    torch.cuda.set_device(0) # choose GPU from nvidia-smi 
print("Using:", device)

Using: cpu


In [7]:
text = "create variable student_names with string 'foo bar baz'"

def string_split(s):
#     return list(filter(lambda x: x != '' and x != "\n" and not x.isspace(), re.split('(_|\W)', s))) # this will chunk all code properly by plits strings with quotes
#     return list(filter(lambda x: x != '' and x != "\n" and not x.isspace(), re.split('(\\\'.*?\\\'|\\\".*?\\\"|_|\W)', s))) # this keeps the strings intact
    return list(filter(lambda x: x != '' and x != "\n" and not x.isspace(), re.split('(\\\'.*?\\\'|\\\".*?\\\"|\W)', s)))

print(string_split(text))

['create', 'variable', 'student_names', 'with', 'string', "'foo bar baz'"]


In [8]:
def corpus_to_array(src_fp, tgt_fp):
    lines = []
    with open(src_fp, "r") as src_file, open(tgt_fp, "r") as tgt_file:
        for src, tgt in zip(src_file, tgt_file):
            lines.append((src, tgt))
    return lines

In [9]:
def filter_corpus(data, max_seq_length=200, tokenizer=string_split):
    return [(src, tgt) for src, tgt in data if len(string_split(src)) <= max_seq_length and len(string_split(tgt)) <= max_seq_length]

In [10]:
def samples_to_dataset(samples):
    """
    Args:
        samples: [(src_string),(tgt_string)]
        src/tgt_tokenizer: a func that takes a string and returns an array of strings
    """
    examples = []
    TEXT_FIELD = Field(sequential=True, use_vocab=False, init_token='<sos>',eos_token='<eos>')
    
    for sample in samples:
        src_string, tgt_string = sample
        examples.append(torchtext.data.Example.fromdict({"src":src_string, "tgt":tgt_string}, 
                                        fields={"src":("src",TEXT_FIELD), "tgt":("tgt",TEXT_FIELD)}))
        
    dataset = torchtext.data.Dataset(examples,fields={"src":src_field, "tgt":tgt_field})
    return dataset

In [11]:
data = corpus_to_array("datasets/all.desc", "datasets/all.code")
# data = corpus_to_array("datasets/all.desc", "datasets/all.code")
random.shuffle(data)
print("Max src length:", max([len(string_split(src)) for src, tgt in data]))
print("Max tgt length:", max([len(string_split(tgt)) for src, tgt in data]))

print("Full dataset size:", len(data))
max_seq_length=200
data = filter_corpus(data, max_seq_length=max_seq_length, tokenizer=string_split)
data = [(" ".join(string_split(src))," ".join(string_split(tgt))) for src, tgt in data]
print("Limited dataset size:", len(data))

Max src length: 557
Max tgt length: 527
Full dataset size: 18805
Limited dataset size: 18794


In [12]:
stoi = {"<unk>":0, "<sos>":1, "<eos>":2, "<pad>":3, "<gen>":4}
max_vocab = 10000 - len(stoi)

all_toks = []
for (src, tgt) in data:
    all_toks += string_split(src)
    all_toks += string_split(tgt)

most_freq = collections.Counter(all_toks).most_common(max_vocab)

for tok, count in most_freq:
    stoi[tok] = len(stoi)
    
itos = [k for k,v in sorted(stoi.items(), key=lambda kv: kv[1])]

In [13]:
print(f"Vocabulary size: {len(stoi)}")

Vocabulary size: 9178


In [14]:
def encode_input(string, from_list=False):
    OOVs = []
    IDs = []
    if not from_list:
        words = string_split(string)
    else:
        words = string
    for word in words:
        try:
            id = stoi[word]
            IDs.append(id)
        except KeyError as e:
            # word is OOV
            IDs.append(len(stoi) + len(OOVs))
            OOVs.append(word)
    return IDs, OOVs

In [15]:
def encode_output(string, OOVs, from_list=False):
    IDs = []
    if not from_list:
        words = string_split(string)
    else:
        words = string
    for word in words:
        try:
            id = stoi[word]
            IDs.append(id)
        except KeyError as e:
            # word is OOV
            try:
                IDs.append(len(stoi) + OOVs.index(word))
            except ValueError as e:
                IDs.append(stoi["<unk>"])
    return IDs

In [16]:
def decode(ids, OOVs):
    extended_itos = itos.copy()
    extended_itos += [OOV+"(COPY)" for OOV in OOVs]
    return " ".join([extended_itos[id] for id in ids if id<len(extended_itos)])

## Retrieval 
We want to find the most appropriate code to the input description. For this we use PyLucene to provide a BM25 search over the english descriptions.

In [17]:
data[0]

("return value of the npath function with string '.mo' appended to the base_path as argument , and return value of the npath function with string '.po' appended to the base_path as argument , substitute it for args . call the popen_wrapper with args as the argument , assign the result to the output , errors and status , respectively .",
 'output , errors , status = popen_wrapper ( args )')

In [18]:
retriever = PyLuceneRetriever()

src_data = [src for src,tgt in data]
retriever.add_multiple_docs(src_data)

In [19]:
doc_ranking = retriever.BM25_search("is greater than")
doc_ids = [x.doc for x in doc_ranking]
retrieved_samples = [data[i] for i in doc_ids]
print("10 best matched samples")
for doc in retrieved_samples[:10]:
    print(f"Description: {doc[0]}")
    print(f"Code       : {doc[1]}")
    print()

10 best matched samples
Description: if age is greater than max_age ,
Code       : if age > max_age :

Description: if start is greater than upto ,
Code       : if start > upto :

Description: if doublecolon_len is greater than best_doublecolon_len ,
Code       : if doublecolon_len > best_doublecolon_len :



## Creating Edit steps
While there is a vocabulary for the shared english and code. The editing tokens also need to be converted to integers.

In [53]:
edit_stoi = {"K":0, "D":1, "R":2, "<pad>":3}
edit_itos = {0:"K", 1:"D", 2:"R", 3:"<pad>"}
max_insertions = 20

In [54]:
def encode_commands(commands):
    return [edit_stoi[command] for command in commands]

def decode_commands(commands):
    return [edit_itos[command] for command in commands]

### A single edit example
Let's make a complete sample from the dataset.
![alt text](./images/edit_transformer_diagram.png)

In [55]:
x, y = data[0]
doc_ranking = retriever.BM25_search(x)
top_2_doc = data[doc_ranking[2].doc] # get the second best document since the best doc is the same from the description
x_, y_ = top_2_doc

# convert to token array
x = string_split(x)
y = ["<sos>"] + string_split(y)
x_ = string_split(x_)
y_ = ["<sos>"] + string_split(y_)

print(y, y_)

dataset_edit_samples = []
while y_ != y:
    edit_steps = single_step_edits(y_, y)
    commands, insertions, replacements = edit_steps
    sample = {
        "encoder_input": x,
        "decoder_input": y_.copy(),
        "target_commands": commands,
        "target_insertions": insertions,
        "target_replacements": replacements,
        "code_target": y
    }
    dataset_edit_samples.append(sample)
    y_ = perform_edits(y_, edit_steps, gen_tok_id="<gen>")

pprint(dataset_edit_samples)

['<sos>', 'output', ',', 'errors', ',', 'status', '=', 'popen_wrapper', '(', 'args', ')'] ['<sos>', 'msgs', ',', 'errors', ',', 'status', '=', 'popen_wrapper', '(', 'args', ')']
[{'code_target': ['<sos>',
                  'output',
                  ',',
                  'errors',
                  ',',
                  'status',
                  '=',
                  'popen_wrapper',
                  '(',
                  'args',
                  ')'],
  'decoder_input': ['<sos>',
                    'msgs',
                    ',',
                    'errors',
                    ',',
                    'status',
                    '=',
                    'popen_wrapper',
                    '(',
                    'args',
                    ')'],
  'encoder_input': ['return',
                    'value',
                    'of',
                    'the',
                    'npath',
                    'function',
                    'with',
                    'stri

### Making the dataset
The trick to modularizing models effectively is to make important functions that are necessary to them. `data2dataset()` is one suuch example. Taking in the dataset provided from the paper and converting it into the format needed to train our edit model.

In [60]:
TEXT_FIELD = Field(sequential=True, use_vocab=False, unk_token=0, pad_token=3)
OOV_TEXT_FIELD = Field(sequential=True, use_vocab=False, pad_token=3)

def data2dataset(data, desc_rank=1):
    TEXT_FIELD = Field(sequential=True, use_vocab=False, unk_token=0, pad_token=3)
    OOV_TEXT_FIELD = Field(sequential=True, use_vocab=False, pad_token=3)
    
    OOV_stoi = {}
    OOV_itos = {}
    OOV_starter_count = 30000
    OOV_count = OOV_starter_count
    
    examples = []

    for (src, tgt) in data:
        
#         print(src)
        doc_ranking = retriever.BM25_search(src)
        if len(doc_ranking) > 1:
            top_2_doc = data[doc_ranking[desc_rank].doc]
            x_, y_ = top_2_doc
        else:
            x_, y_ = "", ""
        src_ids, OOVs = encode_input(src)
        decoder_input = encode_output(y_, OOVs)
        ground_truth_code = encode_output(tgt, OOVs)
        
        decoder_input = [stoi["<sos>"]] + decoder_input
        ground_truth_code = [stoi["<sos>"]] + ground_truth_code
        
        ran_once = False
        
        while decoder_input != ground_truth_code or not ran_once:
            ran_once = True
            edit_steps = single_step_edits(decoder_input, ground_truth_code, pad_token=stoi["<pad>"], token_insertions=2)
            commands, target_insertions, target_replacements = edit_steps
            
            target_commands = encode_commands(commands)
#             target_replacements = encode_output(replacements, OOVs, from_list=True)
            
#             print(src_ids)
#             print(decoder_input)
#             print(decode(decoder_input, OOVs))
#             print(ground_truth_code)
#             print(decode(ground_truth_code, OOVs))
#             print(target_commands)
#             print(target_insertions)
#             print(target_replacements)
#             print()
            
#             print(len(decoder_input) == len(target_commands))

            
            OOV_ids = []

            for OOV in OOVs:
                try:
                    idx = OOV_stoi[OOV]
                    OOV_ids.append(idx)
                except KeyError as e:
                    OOV_count += 1
                    OOV_stoi[OOV] = OOV_count
                    OOV_itos[OOV_count] = OOV
                    OOV_ids.append(OOV_count)
                    
            
            if "<DELETE_ME>" in decoder_input:
                print(decoder_input)

            example = torchtext.data.Example.fromdict({"encoder_input":src_ids, 
                                                       "ground_truth_code":ground_truth_code, 
                                                       "OOVs":OOV_ids, 
                                                       "decoder_input":decoder_input,
                                                       "target_commands":target_commands, 
                                                       "target_insertions": target_insertions,
                                                       "target_replacements":target_replacements}, 
                                                    fields={"encoder_input":("encoder_input",TEXT_FIELD), 
                                                            "ground_truth_code":("ground_truth_code",TEXT_FIELD),
                                                            "OOVs":("OOVs", OOV_TEXT_FIELD), 
                                                            "decoder_input":("decoder_input",TEXT_FIELD),
                                                            "target_commands":("target_commands",TEXT_FIELD), 
                                                            "target_insertions": ("target_insertions",TEXT_FIELD),
                                                            "target_replacements":("target_replacements",TEXT_FIELD)})
            examples.append(example)
            decoder_input = perform_edits(decoder_input, edit_steps, gen_tok_id=stoi["<gen>"])
    return examples

examples = data2dataset(data, desc_rank=1)

In [61]:
p = ["foo","errd","d"]
    
sorted(p, key=lambda x: -len(x))

['errd', 'foo', 'd']

In [62]:
dataset = torchtext.data.Dataset(examples,fields={"encoder_input":TEXT_FIELD, 
                                                  "ground_truth_code":TEXT_FIELD, 
                                                  "OOVs":OOV_TEXT_FIELD, 
                                                  "decoder_input":TEXT_FIELD, 
                                                  "target_commands":TEXT_FIELD, 
                                                  "target_insertions":TEXT_FIELD, 
                                                  "target_replacements":TEXT_FIELD})

train_dataset, val_dataset = dataset.split([0.9,0.1])
# train_dataset = val_dataset = dataset

In [65]:
batch_size = 4

train_iterator = BucketIterator(
    train_dataset,
    batch_size = batch_size,
    repeat=True,
    shuffle=True,
    sort_key = lambda x: len(x.encoder_input)+len(x.decoder_input), # this doesn't seem to work, check it out later
    device = device)

# The iterator generates batches with padded length for sequences with similar sizes, a batch is [seq_length, batch_size]

for i, batch in enumerate(train_iterator):
    idx = 2
#     print([SRC_TEXT.vocab.itos[id] for id in batch.src.cpu().numpy()[:,idx]])
    OOVs = [OOV_itos[OOV] for OOV in batch.OOVs.cpu()[:,idx].tolist() if OOV != 3] # 3 is the <pad> token
    encoder_input = batch.encoder_input.cpu()[:,idx].tolist()
    decoder_input = batch.decoder_input.cpu()[:,idx].tolist()
    ground_truth_code = batch.ground_truth_code.cpu()[:,idx].tolist()
    target_commands = batch.target_commands.cpu()[:,idx].tolist()
    
    print("encoder_input    :",decode(encoder_input, OOVs))
    print("decoder_input    :",decode(decoder_input, OOVs))
    print("ground_truth_code:",decode(ground_truth_code, OOVs))
    print("target_commands  :", target_commands)
    print(len(decoder_input), len(target_commands))
    print()
    break

encoder_input    : substitute http_cookies . Morsel for Morsel . <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
decoder_input    : <sos> dict . __setitem__ ( self , key , http_cookies . Morsel ( ) )
ground_truth_code: <sos> Morsel = http_cookies . Morsel <pad> <pad> <pad> <pad> <pad> <pad> <pad>
target_commands  : [0, 1, 1, 1, 1, 1, 1, 2, 2, 0, 0, 0, 1, 1, 1]
15 15



# The Model

In [66]:
class CopyModel(nn.Module):

    def __init__(self, vocab_size, embedding_size=512, dropout=0.5):
        super(CopyModel, self).__init__()
        self.model_type = 'Transformer'
        
        self.embedding_size = embedding_size
        self.pos_encoder = PositionalEncoding(embedding_size, dropout)
        self.src_encoder = nn.Embedding(vocab_size, embedding_size)
        self.tgt_encoder = nn.Embedding(vocab_size, embedding_size)
        
        self.transformer = Transformer(d_model=embedding_size, nhead=8, num_encoder_layers=4, num_decoder_layers=4, dim_feedforward=1024)
        self.replacement_decoder = nn.Linear(embedding_size, vocab_size)
        self.insertion_decoder = nn.Linear(embedding_size, max_insertions)
        self.command_decoder = nn.Linear(embedding_size, len(edit_stoi))
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.p_generator = nn.Linear(embedding_size,1)

        self.init_weights()
        self.tgt_mask = None

    def init_weights(self):
        initrange = 0.1
        self.src_encoder.weight.data.uniform_(-initrange, initrange)
        self.tgt_encoder.weight.data.uniform_(-initrange, initrange)
        
        self.replacement_decoder.bias.data.zero_()
        self.replacement_decoder.weight.data.uniform_(-initrange, initrange)
        
        self.insertion_decoder.bias.data.zero_()
        self.insertion_decoder.weight.data.uniform_(-initrange, initrange)
        
        self.command_decoder.bias.data.zero_()
        self.command_decoder.weight.data.uniform_(-initrange, initrange)
        
    def _generate_square_subsequent_mask(self, sz):
#         noise_e = 0.05 if self.training else 0.0 # this is code to add noise to the decoding process during training
        noise_e = 0.0 if self.training else 0.0
        noise_mask = (torch.rand(sz,sz) > noise_e).float()

        mask = (torch.triu(torch.ones(sz,sz))).transpose(0, 1)
        mask = torch.mul(mask, noise_mask)
        v = (torch.sum(mask, dim=-1) == 0).float()

        fix_mask = torch.zeros(sz,sz)
        fix_mask[:,0] = 1.0
        v = v.repeat(sz, 1).transpose(0,1)
        fix_mask = torch.mul(fix_mask,v)

        mask += fix_mask
        
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src, tgt):
        self.tgt_mask = self._generate_square_subsequent_mask(len(tgt)).to(self.device)
        

        src_emb = self.src_encoder(src) * math.sqrt(self.embedding_size)
        src_emb = self.pos_encoder(src_emb)
        
        tgt_emb = self.tgt_encoder(tgt) * math.sqrt(self.embedding_size)
        tgt_emb = self.pos_encoder(tgt_emb)
        
        output, atts = self.transformer(src_emb, tgt_emb, tgt_mask=self.tgt_mask)
        
        
        src_scat = src.transpose(0,1)
        src_scat = src_scat.unsqueeze(0)
        src_scat = torch.repeat_interleave(src_scat, tgt.shape[0], dim=0)
#         print("src_scat.shqape", src_scat.shape)
        
        p_gens = self.p_generator(output).sigmoid()
        atts = atts.transpose(0,1)
#         print("att.shqape", atts.shape)
        atts = atts * (1 - p_gens)
                
        target_replacements = self.replacement_decoder(output)
#         output[:,:,12:] = -np.inf
        target_replacements = target_replacements.softmax(-1)
        target_replacements = target_replacements * p_gens
        
        target_replacements = target_replacements.scatter_add_(2,src_scat,atts)
        
        target_insertions = self.insertion_decoder(output)
        
        target_commands = self.command_decoder(output)
        
        return target_commands, target_insertions, target_replacements.log()

In [67]:
stoi["','"]

670

In [74]:
vocab_size = len(itos) + max_seq_length

model = CopyModel(vocab_size).to(device) 
src = torch.randint(0, vocab_size, (3,2)).to(device)
tgt = torch.randint(0, vocab_size, (5,2)).to(device)

target_commands, target_insertions, target_replacements = model(src, tgt)
target_commands.shape, target_insertions.shape, target_replacements.shape

(torch.Size([5, 2, 4]), torch.Size([5, 2, 20]), torch.Size([5, 2, 9378]))

In [75]:
def nltk_bleu(refrence, prediction):
    """
    Implementation from ReCode
    and moses multi belu script sets BLEU to 0.0 if len(toks) < 4
    """
    ngram_weights = [0.25] * min(4, len(refrence))
    return sentence_bleu([refrence], prediction, weights=ngram_weights, 
                          smoothing_function=SmoothingFunction().method3)

In [108]:
def rmpad(arr):
    pad_id = stoi["<pad>"]
    return [x for x in arr if x != pad_id]

In [142]:
def outputs2code(target_commands, target_insertions, target_replacements, decoder_input, OOVss):
    batch_size = target_commands.shape[1]
    
    edited_code_samples = []
    for i in range(batch_size):
        code_to_edit = rmpad(decoder_input[:,i].tolist())
        code_length = len(code_to_edit)
        
        sample_commands = target_commands[:code_length,i].view(-1).tolist()
#         print(sample_commands)
        sample_commands = decode_commands(sample_commands)
        sample_insertions = target_insertions[:code_length,i].view(-1).tolist()
        sample_replacements = target_replacements[:code_length,i].view(-1).tolist()
        
        
        
        edits = (sample_commands, sample_insertions, sample_replacements)
        
        edited_code = perform_edits(code_to_edit, edits, gen_tok_id=stoi["<gen>"])
        
        edited_code = decode(rmpad(edited_code), OOVss[:,i].tolist())
        edited_code_samples.append(edited_code)
    return edited_code_samples
    
        
batch = next(iter(train_iterator))
decoder_input = batch.decoder_input
encoder_input = batch.encoder_input
ground_truth_code = batch.ground_truth_code
OOVss = batch.OOVs

gt_target_commands = batch.target_commands
gt_target_insertions = batch.target_insertions
gt_target_replacements = batch.target_replacements

gt_edited_code = outputs2code(gt_target_commands, gt_target_insertions, gt_target_replacements, decoder_input, OOVss)

target_commands, target_insertions, target_replacements = model(encoder_input, decoder_input)

_, argmax_target_commands = target_commands.max(2)
_, argmax_target_insertions = target_insertions.max(2)
_, argmax_target_replacements = target_replacements.max(2)

outputs2code(argmax_target_commands, argmax_target_insertions, argmax_target_replacements, decoder_input, OOVss)

for i in range(batch_size):
    print("encoder_input      :", decode(rmpad(encoder_input[:,i].tolist()), []))
    print("ground_truth_code  :",decode(rmpad(ground_truth_code[:,i].tolist()), []))
    print("decoder_input      :",decode(rmpad(decoder_input[:,i].tolist()), []))
    print("gt_target_commands :", rmpad(gt_target_commands[:,i].tolist()))
    print("gt_target_insertions:", rmpad(gt_target_insertions[:,i].tolist()))
    print("gt_target_replacements:", decode(rmpad(gt_target_replacements[:,i].tolist()), []))
    print("gt_edited_code     :",gt_edited_code[i])
    print()

encoder_input      : if escaped is true ,
ground_truth_code  : <sos> if escaped :
decoder_input      : <sos> if escaped :
gt_target_commands : [0, 0, 0, 0]
gt_target_insertions: [0, 0, 0, 0]
gt_target_replacements: 
gt_edited_code     : <sos> if escaped :

gt_target_commands : [0, 0, 0]
gt_target_insertions: [0, 0, 0]
gt_target_replacements: 

encoder_input      : try ,
ground_truth_code  : <sos> try :
decoder_input      : <sos> try :
gt_target_commands : [0, 0, 0]
gt_target_insertions: [0, 0, 0]
gt_target_replacements: 
gt_edited_code     : <sos> try :

encoder_input      : default is a string 'DEFAULT' .
ground_truth_code  : <sos> default = 'DEFAULT'
decoder_input      : <sos> DEFAULT_CACHE_ALIAS = 'default'
gt_target_commands : [0, 2, 0, 2]
gt_target_insertions: [0, 0, 0, 0]
gt_target_replacements: default 'DEFAULT'
gt_edited_code     : <sos> default = 'DEFAULT'



In [32]:
valid_iterator = BucketIterator(val_dataset,
    batch_size = 32,
    sort_key = lambda x: len(x.src)+len(x.tgt),
    device = device)

def batch_filter_ids(batch_list):
    return [[id for id in l if id not in [1,2,3]] for l in batch_list]

def evaluate(beam_size=1, log=False):
    model.eval() # Turn on the evaluation mode
    with torch.no_grad(), open("out.txt", "w") as out_fp:
        BLEU_scores = []
        for i, batch in enumerate(valid_iterator):
            batch_size = batch.src.shape[1]
            
            encoder_inputs = batch.src
            predictions = beam_search.beam_search_decode(model,
                              batch_encoder_ids=encoder_inputs,
                              SOS_token=stoi["<sos>"],
                              EOS_token=stoi["<eos>"],
                              PAD_token=stoi["<pad>"],
                              beam_size=beam_size,
                              max_length=20,
                              num_out=1)
            
            sources = encoder_inputs.transpose(0,1).cpu().tolist()
            sources = batch_filter_ids(sources)
            
            predictions = [t[0].view(-1).cpu().tolist() for t in predictions]
            predictions = batch_filter_ids(predictions)
            
            targets = batch.tgt.transpose(0,1).cpu().tolist()
            targets = batch_filter_ids(targets)
            
#             print(batch.tgt)
            
            OOVss = [[OOV_itos[OOV] for OOV in batch.OOVs.cpu()[:,idx].tolist() if OOV != 3] for idx in range(batch_size)]
            
            if i % int(len(valid_iterator)/3) == 0:
                print("| EVALUATION | {:5d}/{:5d} batches |".format(i, len(valid_iterator)))
            
            for j in range(batch_size):
                BLEU = nltk_bleu(targets[j], predictions[j])
                BLEU_scores.append(BLEU)
                
                out_fp.write("SRC  :" + decode(sources[j], OOVss[j]) + "\n")
                out_fp.write("TGT  :" + decode(targets[j], OOVss[j]) + "\n")
                out_fp.write("PRED :" + decode(predictions[j], OOVss[j]) + "\n")
                out_fp.write("BLEU :" + str(BLEU) + "\n")
                out_fp.write("\n")
                
                if log:
                    print("SRC  :" + decode(sources[j], OOVss[j]))
                    print("TGT  :" + decode(targets[j], OOVss[j]))
                    print("PRED :" + decode(predictions[j], OOVss[j]))
                    print("BLEU :" + str(BLEU))
                    print()
        out_fp.write("\n\n| EVALUATION | BLEU: {:5.2f} |\n".format(np.average(BLEU_scores)))
        print("| EVALUATION | BLEU: {:5.3f} |".format(np.average(BLEU_scores)))

In [33]:
evaluate(beam_size=1)

AttributeError: 'Batch' object has no attribute 'src'

In [76]:
criterion = nn.CrossEntropyLoss(ignore_index=stoi['<pad>'])
lr = 0.005 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.99)

In [77]:
def train_step(batch):
    model.train() # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
    vocab_size = len(itos) + max_seq_length
    encoder_input = batch.encoder_input
    decoder_input = batch.decoder_input
    ground_truth_commands = batch.target_commands
    ground_truth_insertions = batch.target_insertions
    ground_truth_replacements = batch.target_replacements

    optimizer.zero_grad()
    target_commands, target_insertions, target_replacements = model(encoder_input, decoder_input)

    
    command_loss = criterion(target_commands.view(-1, len(edit_stoi)), ground_truth_commands.view(-1))
    insertion_loss = criterion(target_insertions.view(-1, max_insertions), ground_truth_insertions.view(-1))
    replacement_loss = criterion(target_replacements.view(-1, vocab_size), ground_truth_replacements.view(-1))
    
    loss = command_loss + insertion_loss + replacement_loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    elapsed = time.time() - start_time
    return loss

In [78]:
def train(steps=10000, log_interval=200, learning_interval=4000, eval_interval=1000):
    model.train() # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
    step = 1
    for batch in train_iterator:
        loss = train_step(batch)
        total_loss += loss.item()
        
        if step % log_interval == 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| {:5d}/{:5d} steps | '
                  'lr {:02.4f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    step, steps, scheduler.get_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
        
        if step % eval_interval == 0:
            print("Evaluating model")
            evaluate()
            model.train()
        
        if step % learning_interval == 0:
            scheduler.step()
        
        step += 1
        if step >= steps:
            print("Finished training")



            return

train(steps=1000000,eval_interval=10000,log_interval=4)

|     4/1000000 steps | lr 0.0050 | ms/batch 528.08 | loss 11.92 | ppl 150092.90
|     8/1000000 steps | lr 0.0050 | ms/batch 588.28 | loss  9.20 | ppl  9873.35
|    12/1000000 steps | lr 0.0050 | ms/batch 574.45 | loss  8.94 | ppl  7595.55
|    16/1000000 steps | lr 0.0050 | ms/batch 568.42 | loss  9.91 | ppl 20176.67
|    20/1000000 steps | lr 0.0050 | ms/batch 507.27 | loss  7.97 | ppl  2900.74
|    24/1000000 steps | lr 0.0050 | ms/batch 524.54 | loss  7.79 | ppl  2427.73
|    28/1000000 steps | lr 0.0050 | ms/batch 582.48 | loss  8.09 | ppl  3250.94
|    32/1000000 steps | lr 0.0050 | ms/batch 561.84 | loss  8.66 | ppl  5759.95
|    36/1000000 steps | lr 0.0050 | ms/batch 599.12 | loss  9.50 | ppl 13351.32
|    40/1000000 steps | lr 0.0050 | ms/batch 631.95 | loss  8.60 | ppl  5448.22
|    44/1000000 steps | lr 0.0050 | ms/batch 575.42 | loss  7.03 | ppl  1131.00
|    48/1000000 steps | lr 0.0050 | ms/batch 498.92 | loss  7.27 | ppl  1434.77
|    52/1000000 steps | lr 0.0050 | ms/

KeyboardInterrupt: 

# Evaluate