In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from queue import PriorityQueue
import numpy as np
import torchtext
import tqdm
from torchnlp.metrics import get_moses_multi_bleu
from torchtext.data import Field, BucketIterator
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

import tensorflow as tf
import tensorflow_datasets as tfds
from tokenize import tokenize, untokenize, NUMBER, STRING, NAME, OP
from io import BytesIO

import linecache
import sys
import os
import re
import random
import time
import operator

from base_transformer import TransformerModel

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device == "cuda":
    torch.cuda.set_device(0) # choose GPU from nvidia-smi 
print("Using:", device)

Using: cuda


In [3]:
!nvidia-smi

Tue Nov 12 11:47:26 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.26       Driver Version: 430.26       CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  TITAN RTX           Off  | 00000000:B2:00.0 Off |                  N/A |
| 41%   29C    P8     4W / 280W |   1627MiB / 24220MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
+-------

## Loading the dataset

In [4]:
try:
    os.mkdir("./datasets")
except FileExistsError:
    print("Directories already exists")

# getting descriptions
!wget https://raw.githubusercontent.com/odashi/ase15-django-dataset/master/django/all.anno -O ./datasets/all.desc

# getting code
!wget https://raw.githubusercontent.com/odashi/ase15-django-dataset/master/django/all.code -O ./datasets/all.code

Directories already exists
--2019-11-12 11:47:27--  https://raw.githubusercontent.com/odashi/ase15-django-dataset/master/django/all.anno
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.16.133
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.16.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1382085 (1.3M) [text/plain]
Saving to: './datasets/all.desc'


2019-11-12 11:47:28 (15.5 MB/s) - './datasets/all.desc' saved [1382085/1382085]

--2019-11-12 11:47:28--  https://raw.githubusercontent.com/odashi/ase15-django-dataset/master/django/all.code
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.16.133
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.16.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 906732 (885K) [text/plain]
Saving to: './datasets/all.code'


2019-11-12 11:47:29 (10.2 MB/s) - './datasets/all.code' saved [906732

## Creating a token text encoder
An encoder will take a file and a splitting function and return an object able to encode and decode a string. It will also be able to save a vocab file and retrieve from file.

In [5]:
text = " append rel_to to string 'ForeignKey, (substitute the result for field_type.)"

# looks like code split need parenthesis to be matched in the same string, if not it gives an error...
def code_split(s):
    return [x.string for x in tokenize(BytesIO(s.encode('utf-8')).readline) if x.string != '' and x.string != "\n" and not x.string.isspace()][1:]

print(code_split(text))

['append', 'rel_to', 'to', 'string', "'", 'ForeignKey', ',', '(', 'substitute', 'the', 'result', 'for', 'field_type', '.', ')']


In [6]:
text = " append rel_to to string 'ForeignKey, (subs__titute the result' for field_type."

def string_split(s):
    return list(filter(lambda x: x != '' and x != "\n" and not x.isspace(), re.split('(_|\W)', s))) # this will chunk all code properly by plits strings with quotes
#     return list(filter(lambda x: x != '' and x != "\n" and not x.isspace(), re.split('(\\\'.*?\\\'|\\\".*?\\\"|_|\W)', s))) # this keeps the strings intact

print(string_split(text))

['append', 'rel', '_', 'to', 'to', 'string', "'", 'ForeignKey', ',', '(', 'subs', '_', '_', 'titute', 'the', 'result', "'", 'for', 'field', '_', 'type', '.']


## Making the input pipeline

In [7]:
def corpus_to_array(src_fp, tgt_fp):
    lines = []
    with open(src_fp, "r") as src_file, open(tgt_fp, "r") as tgt_file:
        for src, tgt in zip(src_file, tgt_file):
            lines.append((src, tgt))
    return lines

In [8]:
def filter_corpus(data, max_seq_length=200, tokenizer=string_split):
    return [(src, tgt) for src, tgt in data if len(string_split(src)) <= max_seq_length and len(string_split(tgt)) <= max_seq_length]

In [9]:
def samples_to_dataset(samples, src_field, tgt_field):
    """
    Args:
        samples: [(src_string),(tgt_string)]
        src/tgt_tokenizer: a func that takes a string and returns an array of strings
    """
    examples = []
    
    for sample in samples:
        src_string, tgt_string = sample
        examples.append(torchtext.data.Example.fromdict({"src":src_string, "tgt":tgt_string}, 
                                        fields={"src":("src",src_field), "tgt":("tgt",tgt_field)}))
        
    dataset = torchtext.data.Dataset(examples,fields={"src":src_field, "tgt":tgt_field})
    return dataset

In [10]:
data = corpus_to_array("datasets/all.desc", "datasets/all.code")
random.shuffle(data)

In [11]:
print("Max src length:", max([len(string_split(src)) for src, tgt in data]))
print("Max tgt length:", max([len(string_split(tgt)) for src, tgt in data]))

Max src length: 586
Max tgt length: 1087


In [12]:
print("Full dataset size:", len(data))
max_seq_length=200
data = filter_corpus(data, max_seq_length=200, tokenizer=string_split)
print("Limited dataset size:", len(data))

Full dataset size: 18805
Limited dataset size: 18781


In [16]:
try:
    SRC_TEXT = torch.load("./src_vocab.vcb")
except:
    SRC_TEXT = Field(sequential=True, tokenize=string_split, init_token='<sos>',eos_token='<eos>')

try:
    TGT_TEXT = torch.load("./tgt_vocab.vcb")
except:
    TGT_TEXT = Field(sequential=True, tokenize=string_split, init_token='<sos>',eos_token='<eos>')

dataset = samples_to_dataset(data, SRC_TEXT, TGT_TEXT)

train_dataset, val_dataset = dataset.split([0.9,0.1])

In [17]:
if not hasattr(SRC_TEXT, "vocab"):
    print("creating src vocab")
    SRC_TEXT.build_vocab(train_dataset)
if not hasattr(TGT_TEXT, "vocab"):
    print("creating tgt vocab")
    TGT_TEXT.build_vocab(train_dataset)


sample = dataset[2].src
for tok, id in zip(sample, SRC_TEXT.numericalize([sample])):
    print("{} -> {}".format(tok, id.numpy()[0]))

substitute -> 21
it -> 41
for -> 13
self -> 10
. -> 4
view -> 356
_ -> 5
name -> 29
. -> 4
define -> 30
the -> 7
method -> 16
_ -> 5
_ -> 5
getitem -> 582
_ -> 5
_ -> 5
with -> 9
arguments -> 24
self -> 10
and -> 14
index -> 175
. -> 4


## Creating the dataset iterator
This will create a finction returning a different batch. The `train_iterator` is infinitely repeating. while the validation one is not.

In [18]:
batch_size = 32

train_iterator = BucketIterator(
    train_dataset,
    batch_size = batch_size,
    repeat=True,
#     shuffle=True,
    sort_key = lambda x: len(x.src)+len(x.tgt),
    device = device)

valid_iterator = BucketIterator(val_dataset,
    batch_size = batch_size,
    sort_key = lambda x: len(x.src)+len(x.tgt),
    device = device)

# The iterator generates batches with padded length for sequences with similar sizes, a batch is [seq_length, batch_size]

for i, batch in enumerate(train_iterator):
    idx = 0
    print([SRC_TEXT.vocab.itos[id] for id in batch.src.cpu().numpy()[:,idx]])
    print(batch.src.cpu().numpy()[:,idx])
    print(batch.tgt.cpu().numpy()[:,idx])
    break

['<sos>', 'if', 'errors', '_', 'on', '_', 'separate', '_', 'row', 'and', 'bf', '_', 'errors', 'are', 'both', 'true', ',', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
[   2   15  167    5  172    5 2210    5  415   14  927    5  167  207
  458   53    6    3    1    1    1    1    1    1    1    1    1    1
    1    1    1    1    1]
[   2   14  125    4  511    4 1778    4  331   51  854    4  125   11
    3    1    1    1    1    1    1    1    1    1    1    1    1    1
    1    1    1    1    1    1    1    1    1    1    1    1    1    1
    1    1    1    1    1    1    1    1    1    1    1    1    1    1
    1    1]


Sample transformer without positional encoding, it uses the built in transformer model

In [16]:
rand_transformer_model = nn.Transformer() # uses default hyperparameters
src = torch.rand((10, 32, 512)) # [src_seq_length, batch_size, embedding_size]
tgt = torch.rand((20, 32, 512)) # [tgt_seq_length, batch_size, embedding_size]
rand_transformer_model(src, tgt).shape # [tgt_seq_length, batch_size, embedding_size]

torch.Size([20, 32, 512])

## Building the model


In [19]:
src_vocab_size = len(SRC_TEXT.vocab.itos)
tgt_vocab_size = len(TGT_TEXT.vocab.itos)

model = TransformerModel(src_vocab_size, tgt_vocab_size, dropout=0.2).to(device) 

In [20]:
def greedy_decode_batch_ids(encoder_input, max_seq_length=50):
    batch_len = encoder_input.shape[1]
    sos_id = TGT_TEXT.vocab.stoi["<sos>"]
    decoder_input = torch.zeros((1, batch_len), dtype=torch.long, device=device).fill_(sos_id)

    for i in range(max_seq_length):
        output = model(encoder_input, decoder_input)
        last_pred = output[-1:].argmax(dim=2)

        decoder_input = torch.cat((decoder_input, last_pred))
    return decoder_input

In [21]:
class BeamSearchNode(object):
    def __init__(self, hiddenstate, previousNode, wordId, logProb, length):
        '''
        :param hiddenstate:
        :param previousNode:
        :param wordId:
        :param logProb:
        :param length:
        '''
        self.h = hiddenstate
        self.prevNode = previousNode
        self.wordid = wordId
        self.logp = logProb
        self.leng = length
        
    def __lt__(self, other):
        return True

    def eval(self, alpha=1.0):
        reward = 0
        beta = 4.0
        # Add here a function for shaping a reward

        return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward

In [20]:
torch.tensor([[1.,2.]]).softmax(1)

tensor([[0.2689, 0.7311]])

In [21]:
# %load_ext line_profiler

In [22]:
# %lprun -f beam_decode beam_decode(model, batch_size=1, encoder_states=src_ids)

In [23]:
def beam_decode(model, encoder_states):
    '''
    :param target_tensor: target indexes tensor of shape [B, T] where B is the batch size and T is the maximum length of the output sentence
    :param decoder_hidden: input tensor of shape [1, B, H] for start of the decoding
    :param encoder_outputs: if you are using attention mechanism you can pass encoder outputs, [T, B, H] where T is the maximum length of input sentence
    :return: decoded_batch
    '''

    beam_width = 10
    topk = 3  # how many sentence do you want to generate
    decoded_batch = []
    
    SOS_token = TGT_TEXT.vocab.stoi["<sos>"]
    EOS_token = TGT_TEXT.vocab.stoi["<eos>"]
    MAX_LENGTH = 7
    
    batch_size = encoder_states.shape[1]

    # decoding goes sentence by sentence
    for idx in range(batch_size):
        encoder_input = encoder_states[:, idx].view(-1,1)
        
        # Start with the start of the sentence token
        decoder_input = torch.LongTensor([[SOS_token]]).to(device)

        # Number of sentence to generate
        endnodes = []
        number_required = min((topk + 1), topk - len(endnodes))

        # starting node -  hidden vector, previous node, word id, logp, length
        node = BeamSearchNode(decoder_input, None, SOS_token, 0, 1)
        nodes = PriorityQueue()

        # start the queue
        nodes.put((-node.eval(), node))
        qsize = 1

        # start beam search
        while True:
            # give up when decoding takes too long
            if qsize > 400: break

            # fetch the best node
            score, n = nodes.get()
#             decoder_input = n.wordid
            decoder_input = n.h

            if n.wordid == EOS_token and n.prevNode != None:
                endnodes.append((score, n))
                # if we reached maximum # of sentences required
                if len(endnodes) >= number_required:
                    break
                else:
                    continue

            # decode for one step using decoder
#             decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_output)
#             print(encoder_input)
#             print(decoder_input)
            decoder_output = model(encoder_input, decoder_input)
            last_token_logits = decoder_output[-1]
            last_token_logs = last_token_logits.log_softmax(1)
            # PUT HERE REAL BEAM SEARCH OF TOP
            log_prob, indexes = torch.topk(last_token_logs, beam_width)
            nextnodes = []

            for new_k in range(beam_width):
                decoded_t = indexes[0][new_k]
                log_p = log_prob[0][new_k].item()
                decoder_input = torch.cat((decoder_input,decoded_t.view(1,-1)))
                node = BeamSearchNode(decoder_input, n, decoded_t.cpu().item(), n.logp + log_p, n.leng + 1)
                score = -node.eval()
                nextnodes.append((score, node))

            # put them into queue
            for i in range(len(nextnodes)):
                score, nn = nextnodes[i]
                nodes.put((score, nn))
                # increase qsize
            qsize += len(nextnodes) - 1

        # choose nbest paths, back trace them
        if len(endnodes) == 0:
            endnodes = [nodes.get() for _ in range(topk)]

        utterances = []
        for score, n in sorted(endnodes, key=operator.itemgetter(0)):
            utterance = []
            utterance.append(n.wordid)
            # back trace
            while n.prevNode != None:
                n = n.prevNode
                utterance.append(n.wordid)

            utterance = utterance[::-1]
            utterances.append(utterance)

        decoded_batch.append(utterances)

    return decoded_batch

sent1 = ["<sos>"] + SRC_TEXT.preprocess("call the options.get method with string 'CULL_FREQUENCY' and integer 3 as arguments, use the string 'cull_frequency' and previous result as the arguments for the call to the params.get method, substitute the result for cull_frequency.") + ["<eos>"] + ["<pad>"]
# sent2 = ["<sos>"] + SRC_TEXT.preprocess("if not,") + ["<eos>"]
src_ids = SRC_TEXT.numericalize([sent1], device=device)
# print("input ids:", src_ids)
outs = beam_decode(model, encoder_states=src_ids)

for b in outs:
    for sent in b:
        print([TGT_TEXT.vocab.itos[id] for id in sent])

['<sos>', 'ZeroDivisionError', 'expire', 'write', 'StringIO', 'formatstr', 'def', 'GMT', 'GMT', 'formatstr', 'permitted', 'LazyDescr', 'debugging', 'DebugLexer', 'formatstr', 'formatstr', 'DebugLexer', 'bool', 'Tags', 'zh', 'GMT', 'Tags', 'formatstr', 'DebugLexer', 'never', 'formatstr', 'xmlutils', 'lang', 'permitted', 'GMT', 'GMT', 'Negative']
['<sos>', 'ZeroDivisionError', 'expire', 'write', 'StringIO', 'formatstr', 'def', 'GMT', 'GMT', 'formatstr', 'permitted', 'LazyDescr', 'debugging', 'DebugLexer', 'formatstr', 'formatstr', 'DebugLexer', 'bool', 'Tags', 'zh', 'GMT', 'Tags', 'formatstr', 'DebugLexer', 'never', 'formatstr', 'xmlutils', 'lang', 'permitted', 'GMT', 'GMT', 'formatstr']
['<sos>', 'ZeroDivisionError', 'expire', 'write', 'StringIO', 'formatstr', 'def', 'GMT', 'GMT', 'formatstr', 'permitted', 'LazyDescr', 'debugging', 'DebugLexer', 'formatstr', 'formatstr', 'DebugLexer', 'bool', 'Tags', 'zh', 'GMT', 'Tags', 'formatstr', 'DebugLexer', 'never', 'formatstr', 'xmlutils', 'lang

In [205]:
s = "for every log in existing ,"
sent1 = ["<sos>"] + SRC_TEXT.preprocess(s) + ["<eos>"]
src_ids = SRC_TEXT.numericalize([sent1], device=device)

decode_ids = SRC_TEXT.numericalize([['<sos>', 'self', '.', 'name']], device=device)

output = model(src_ids, decode_ids)
print(output)
print([TGT_TEXT.vocab.itos[f] for f in output.argmax(dim=-1).view(-1)])

tensor([[[ 3.1062,  0.5920, -1.4175,  ...,  0.2437,  0.2524, -1.2225]],

        [[ 2.4309,  0.8538, -1.8844,  ...,  0.6236, -0.1974, -1.6820]],

        [[ 2.7053,  0.8010, -0.7450,  ...,  1.4385, -0.7941, -1.2559]],

        [[ 1.7970,  0.0554, -0.6959,  ...,  0.8729,  0.0657, -1.1797]]],
       device='cuda:0', grad_fn=<AddBackward0>)
['def', 'self', 'raise', '(']


In [172]:
output = greedy_decode_batch_ids(src_ids, max_seq_length=20)
print([TGT_TEXT.vocab.itos[f] for f in output.view(-1)])

['<sos>', 'def', '=', 'self', '.', '_', 'name', '(', 'self', ',', '*', ',', '*', ')', ':', '<eos>', ',', '*', ')', ':', '<eos>']


In [197]:
outputs = beam_search.beam_search_decode(model,TGT_TEXT,
                              batch_encoder_ids=src_ids,
                              SOS_token=SOS_token,
                              EOS_token=EOS_token,
                              PAD_token=PAD_token,
                              beam_size=3,
                              max_length=20,
                              num_out=1)

for out in outputs:
    for sent in out:
        print([TGT_TEXT.vocab.itos[id] for id in sent.view(-1).cpu().tolist()])
    print()

STEP
0 -0.0 <sos>
STEP
0 tensor(-8.8836, device='cuda:0', grad_fn=<NegBackward>) def
0 tensor(-8.3951, device='cuda:0', grad_fn=<NegBackward>) if
0 tensor(-8.2147, device='cuda:0', grad_fn=<NegBackward>) self
STEP
0 tensor(-19.3050, device='cuda:0', grad_fn=<NegBackward>) .
0 tensor(-17.2590, device='cuda:0', grad_fn=<NegBackward>) =
0 tensor(-16.6746, device='cuda:0', grad_fn=<NegBackward>) =
STEP
0 tensor(-25.6465, device='cuda:0', grad_fn=<NegBackward>) _
0 tensor(-25.6335, device='cuda:0', grad_fn=<NegBackward>) =
0 tensor(-25.1448, device='cuda:0', grad_fn=<NegBackward>) name
STEP
0 tensor(-34.4153, device='cuda:0', grad_fn=<NegBackward>) =
0 tensor(-33.5966, device='cuda:0', grad_fn=<NegBackward>) _
0 tensor(-32.6721, device='cuda:0', grad_fn=<NegBackward>) (
STEP
0 tensor(-41.5725, device='cuda:0', grad_fn=<NegBackward>) =
0 tensor(-40.9403, device='cuda:0', grad_fn=<NegBackward>) self
0 tensor(-40.8322, device='cuda:0', grad_fn=<NegBackward>) name
STEP
0 tensor(-51.1459, device

In [250]:
import beam_search
import importlib
importlib.reload(beam_search)

sent1 = ["<sos>"] + SRC_TEXT.preprocess("for every log in existing") + ["<eos>"] + ["<pad>"]
sent2 = ["<sos>"] + SRC_TEXT.preprocess("for every log in existing ,") + ["<eos>"]
src_ids = SRC_TEXT.numericalize([sent2], device=device)

SOS_token = TGT_TEXT.vocab.stoi["<sos>"]
EOS_token = TGT_TEXT.vocab.stoi["<eos>"]
PAD_token = TGT_TEXT.vocab.stoi["<pad>"]

outputs = beam_search.beam_search_decode(model,TGT_TEXT,
                              batch_encoder_ids=src_ids,
                              SOS_token=SOS_token,
                              EOS_token=EOS_token,
                              PAD_token=PAD_token,
                              beam_size=4,
                              max_length=20,
                              num_out=5)

for out in outputs:
    for sent in out:
        print([TGT_TEXT.vocab.itos[id] for id in sent.view(-1).cpu().tolist()])
    print()

FOUND, -18.929054260253906
FOUND, -18.690196990966797
FOUND, -23.49768829345703
FOUND, -23.96718978881836
FOUND, -23.51727294921875
['<sos>', 'if', 'self', '.', 'name', '_', 'name', '(', 'self', ')', '<eos>']
['<sos>', 'if', 'self', '.', 'name', '_', 'name', '(', 'self', ')', ':', '<eos>']
['<sos>', 'def', '=', 'self', '.', 'name', '_', 'name', '(', 'self', ',', 'name', ')', '<eos>']
['<sos>', 'def', '=', 'self', '.', 'name', '_', 'name', '(', 'self', ',', 'value', ')', '<eos>']
['<sos>', 'def', '=', 'self', '.', 'name', '_', 'name', '(', 'self', ',', 'name', ')', ':', '<eos>']



In [23]:
def nltk_bleu(refrence, prediction):
    """
    Implementation from ReCode
    and moses multi belu script sets BLEU to 0.0 if len(toks) < 4
    """
    ngram_weights = [0.25] * min(4, len(refrence))
    return sentence_bleu([refrence], prediction, weights=ngram_weights, 
                          smoothing_function=SmoothingFunction().method3)

nltk_bleu(np.array([1,2,3,4,5,6]), np.array([1,2,5,6]))

0.2740311596835683

In [82]:
def evaluate(beam_size=1):
    model.eval() # Turn on the evaluation mode
    total_loss = 0.
    with torch.no_grad():
        sources = []
        results = []
        targets = []
        BLEU_scores = []
        for i, batch in enumerate(valid_iterator):
            encoder_inputs = batch.src
            target = batch.tgt
            if beam_size == 1:
                predictions = greedy_decode_batch_ids(encoder_inputs, max_seq_length=20)
                results += predictions.transpose(0,1).cpu().tolist()
            else:
                predictions = beam_decode(model, encoder_inputs)
                results += [sent[0] for sent in predictions]
            
            sources += encoder_inputs.transpose(0,1).cpu().tolist()
            targets += target.transpose(0,1).cpu().tolist()
            if i % 2 == 0:
                print("| EVALUATION | {:5d}/{:5d} batches |".format(i, len(valid_iterator)))
        
        for r_ids, target in zip(results, targets):
            eos_id = TGT_TEXT.vocab.stoi["<eos>"]
            eos_index = r_ids.index(eos_id) if eos_id in r_ids else None
            cut_ids = r_ids[:eos_index]
            filtered_ids = [id for id in cut_ids if id not in [0,1,2,3]]
            filtered_target_ids = [id for id in target if id not in [0,1,2,3]]
            BLEU_scores.append(nltk_bleu(filtered_target_ids, filtered_ids))
        
        with open("out.txt", "w") as out_fp:
            for source, result, target, BLEU in zip(sources, results, targets, BLEU_scores):
                eos_id = TGT_TEXT.vocab.stoi["<eos>"]
                eos_index = result.index(eos_id) if eos_id in result else None
                cut_ids = result[:eos_index]
                filtered_ids = [id for id in cut_ids if id not in [0,1,2,3]]
                filtered_target_ids = [id for id in target if id not in [0,1,2,3]]
                filtered_source_ids = [id for id in source if id not in [0,1,2,3]]
                
                out_fp.write("SRC  :" + " ".join([SRC_TEXT.vocab.itos[id] for id in filtered_source_ids]) + "\n")
                out_fp.write("TGT  :" + " ".join([TGT_TEXT.vocab.itos[id] for id in filtered_target_ids]) + "\n")
                out_fp.write("PRED :" + " ".join([TGT_TEXT.vocab.itos[id] for id in filtered_ids]) + "\n")
                out_fp.write("BLEU :" + str(BLEU) + "\n")
                out_fp.write("\n")
            out_fp.write("\n\n| EVALUATION | BLEU: {:5.2f} |\n".format(np.average(BLEU_scores)))
                
        print("| EVALUATION | BLEU: {:5.2f} |".format(np.average(BLEU_scores)))
        

In [24]:
def train_step(batch):
    model.train() # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
    tgt_vocab_size = len(TGT_TEXT.vocab.itos)
    encoder_input = batch.src
    decoder_input = batch.tgt[:-1]
    targets = batch.tgt[1:]

    optimizer.zero_grad()
    output = model(encoder_input, decoder_input)

    loss = criterion(output.view(-1, tgt_vocab_size), targets.view(-1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    elapsed = time.time() - start_time
    return loss

In [25]:
criterion = nn.CrossEntropyLoss(ignore_index=TGT_TEXT.vocab.stoi['<pad>'])
lr = 0.005 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.99)

In [26]:
def train(steps=10000, log_interval=200, learning_interval=4000, eval_interval=1000):
    model.train() # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
    step = 1
    for batch in train_iterator:
        loss = train_step(batch)
        total_loss += loss.item()
        
        if step % log_interval == 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| {:5d}/{:5d} steps | '
                  'lr {:02.4f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    step, steps, scheduler.get_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
        
        if step % eval_interval == 0:
            print("Evaluating model")
            evaluate()
            model.train()
        
        if step % learning_interval == 0:
            scheduler.step()
        
        step += 1
        if step >= steps:
            print("Finished training")
            return

train(steps=1000000,eval_interval=8000,log_interval=200)

|   200/1000000 steps | lr 0.0050 | ms/batch 42.96 | loss  7.11 | ppl  1225.49
|   400/1000000 steps | lr 0.0050 | ms/batch 40.49 | loss  5.84 | ppl   344.99
|   600/1000000 steps | lr 0.0050 | ms/batch 42.05 | loss  5.66 | ppl   287.82
|   800/1000000 steps | lr 0.0050 | ms/batch 42.22 | loss  5.53 | ppl   252.22
|  1000/1000000 steps | lr 0.0050 | ms/batch 41.02 | loss  5.42 | ppl   225.42
|  1200/1000000 steps | lr 0.0050 | ms/batch 42.30 | loss  5.34 | ppl   209.01
|  1400/1000000 steps | lr 0.0050 | ms/batch 41.45 | loss  5.25 | ppl   189.83
|  1600/1000000 steps | lr 0.0050 | ms/batch 41.39 | loss  5.20 | ppl   181.06
|  1800/1000000 steps | lr 0.0050 | ms/batch 42.85 | loss  5.12 | ppl   167.92
|  2000/1000000 steps | lr 0.0050 | ms/batch 41.89 | loss  5.03 | ppl   153.41
|  2200/1000000 steps | lr 0.0050 | ms/batch 40.00 | loss  4.99 | ppl   147.23
|  2400/1000000 steps | lr 0.0050 | ms/batch 43.35 | loss  4.94 | ppl   139.81
|  2600/1000000 steps | lr 0.0050 | ms/batch 42.47 |

KeyboardInterrupt: 

In [79]:
torch.save(model.state_dict(), "./saved_model.pytorch")

In [80]:
src_vocab_size = len(SRC_TEXT.vocab.itos)
tgt_vocab_size = len(TGT_TEXT.vocab.itos)

model = TransformerModel(src_vocab_size, tgt_vocab_size, dropout=0.2).to(device) 
model.load_state_dict(torch.load("./saved_model.pytorch"))
# model.eval()

<All keys matched successfully>

In [103]:
evaluate(beam_size=1)

| EVALUATION |     0/   59 batches |
| EVALUATION |     2/   59 batches |
| EVALUATION |     4/   59 batches |
| EVALUATION |     6/   59 batches |
| EVALUATION |     8/   59 batches |
| EVALUATION |    10/   59 batches |
| EVALUATION |    12/   59 batches |
| EVALUATION |    14/   59 batches |
| EVALUATION |    16/   59 batches |
| EVALUATION |    18/   59 batches |
| EVALUATION |    20/   59 batches |
| EVALUATION |    22/   59 batches |
| EVALUATION |    24/   59 batches |
| EVALUATION |    26/   59 batches |
| EVALUATION |    28/   59 batches |
| EVALUATION |    30/   59 batches |
| EVALUATION |    32/   59 batches |
| EVALUATION |    34/   59 batches |
| EVALUATION |    36/   59 batches |
| EVALUATION |    38/   59 batches |
| EVALUATION |    40/   59 batches |
| EVALUATION |    42/   59 batches |
| EVALUATION |    44/   59 batches |
| EVALUATION |    46/   59 batches |
| EVALUATION |    48/   59 batches |
| EVALUATION |    50/   59 batches |
| EVALUATION |    52/   59 batches |
|

### Evaluating one sample

In [32]:
" ".join([SRC_TEXT.vocab.itos[i] for i in [ 2,21,83,13,10, 4, 5, 5,83, 4, 3]])

'<sos> substitute args for self . _ _ args . <eos>'

In [33]:
" ".join([TGT_TEXT.vocab.itos[i] for i in [ 2,12,5,4]])

'<sos> self . _'

In [117]:
def translate(s):
    src_ids = SRC_TEXT.numericalize([["<sos>"] + SRC_TEXT.preprocess(s) + ["<eos>"]], device=device)
#     src_ids = torch.tensor([ [2],[21],[83],[13],[10], [4], [5], [5],[83], [4], [3]], device=device)
    print("SRC ids shape:",src_ids)
    model.eval()
    with torch.no_grad():
        sos_id = TGT_TEXT.vocab.stoi["<sos>"]
#         decoder_input = torch.zeros((1, 1), dtype=torch.long, device=device).fill_(sos_id)
        decoder_input = torch.tensor(np.array([ [2]]), device=device)
#         print("Decoder input shape:", decoder_input.shape)
        
        for i in range(10):
#             print("Decoder input", decoder_input)
            output = model(src_ids, decoder_input)
#             print(model.tgt_mask)
#             print("output:", output)
#             print("predicted ids:", output.argmax(dim=-1))
            last_pred = output[-1:].argmax(dim=2)
#             decoder_input[i+1][0] = last_pred
#             print("last pred:", TGT_TEXT.vocab.itos[last_pred.cpu().numpy()[0][0]], last_pred.cpu().numpy()[0][0])
            print(TGT_TEXT.vocab.itos[last_pred.cpu().numpy()[0][0]],'', end = '')
            
            decoder_input = torch.cat((decoder_input, last_pred))
#             print("Decoder input", decoder_input)
#             break

translate("if PY3 is true ,")

SRC ids shape: tensor([[  2],
        [ 15],
        [533],
        [ 11],
        [ 53],
        [  6],
        [  3]], device='cuda:0')
if if if = None , name , name , 

In [37]:
np.array([torch.tensor([1.0]),torch.tensor([2.0])])

array([1., 2.], dtype=float32)

Moses Multi-BLEU perl script returns 0.0 for any sentence less than 4 tokens long.
It will be best to use a function by NLTK

In [39]:
get_moses_multi_bleu(["this is a test"], ["this is a for"])

0.0