# Copy-Generator Transformer

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from queue import PriorityQueue
import numpy as np
import torchtext
import tqdm
from torchnlp.metrics import get_moses_multi_bleu
from torchtext.data import Field, BucketIterator
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

import tensorflow as tf
import tensorflow_datasets as tfds
from tokenize import tokenize, untokenize, NUMBER, STRING, NAME, OP
from io import BytesIO

import linecache
import sys
import os
import re
import random
import time
import operator
import collections

from base_transformer import TransformerModel, PositionalEncoding
from copy_gen_transformer import Transformer, TransformerDecoderLayer, TransformerDecoder
import beam_search
from IPython.core.debugger import set_trace as tr
%load_ext autoreload
%autoreload 2

In [2]:
from datetime import datetime
def super_print(filename):
    '''filename is the file where output will be written'''
    def wrap(func):
        '''func is the function you are "overriding", i.e. wrapping'''
        def wrapped_func(*args,**kwargs):
            '''*args and **kwargs are the arguments supplied 
            to the overridden function'''
            #use with statement to open, write to, and close the file safely
            with open(filename,'a') as outputfile:
                now = datetime.now()
                dt_string = now.strftime("%d/%m/%Y %H:%M:%S")
                outputfile.write("[{}] ".format(dt_string))
                outputfile.write(" ".join(str(x) for x in args))
                outputfile.write("\n")
            #now original function executed with its arguments as normal
            return func(*args,**kwargs)
        return wrapped_func
    return wrap

print = super_print('logs-copy-gen.txt')(print)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device == "cuda":
    torch.cuda.set_device(0) # choose GPU from nvidia-smi 
print("Using:", device)

Using: cuda


In [4]:
text = "create variable student_names with string 'foo bar baz'"

def string_split(s):
#     return list(filter(lambda x: x != '' and x != "\n" and not x.isspace(), re.split('(_|\W)', s))) # this will chunk all code properly by plits strings with quotes
#     return list(filter(lambda x: x != '' and x != "\n" and not x.isspace(), re.split('(\\\'.*?\\\'|\\\".*?\\\"|_|\W)', s))) # this keeps the strings intact
    return list(filter(lambda x: x != '' and x != "\n" and not x.isspace(), re.split('(\\\'.*?\\\'|\\\".*?\\\"|\W)', s)))

print(string_split(text))

['create', 'variable', 'student_names', 'with', 'string', "'foo bar baz'"]


In [5]:
def corpus_to_array(src_fp, tgt_fp):
    lines = []
    with open(src_fp, "r") as src_file, open(tgt_fp, "r") as tgt_file:
        for src, tgt in zip(src_file, tgt_file):
            lines.append((src, tgt))
    return lines

In [6]:
def filter_corpus(data, max_seq_length=200, tokenizer=string_split):
    return [(src, tgt) for src, tgt in data if len(string_split(src)) <= max_seq_length and len(string_split(tgt)) <= max_seq_length]

In [7]:
def samples_to_dataset(samples):
    """
    Args:
        samples: [(src_string),(tgt_string)]
        src/tgt_tokenizer: a func that takes a string and returns an array of strings
    """
    examples = []
    TEXT_FIELD = Field(sequential=True, use_vocab=False, init_token='<sos>',eos_token='<eos>')
    
    for sample in samples:
        src_string, tgt_string = sample
        examples.append(torchtext.data.Example.fromdict({"src":src_string, "tgt":tgt_string}, 
                                        fields={"src":("src",TEXT_FIELD), "tgt":("tgt",TEXT_FIELD)}))
        
    dataset = torchtext.data.Dataset(examples,fields={"src":src_field, "tgt":tgt_field})
    return dataset

In [8]:
data = corpus_to_array("datasets/all-fixed.desc", "datasets/all.code")
# data = corpus_to_array("datasets/all.desc", "datasets/all.code")
random.shuffle(data)
print("Max src length:", max([len(string_split(src)) for src, tgt in data]))
print("Max tgt length:", max([len(string_split(tgt)) for src, tgt in data]))

print("Full dataset size:", len(data))
max_seq_length=200
data = filter_corpus(data, max_seq_length=50, tokenizer=string_split)
print("Limited dataset size:", len(data))

Max src length: 557
Max tgt length: 527
Full dataset size: 18805
Limited dataset size: 18632


In [9]:
stoi = {"<unk>":0, "<sos>":1, "<eos>":2, "<pad>":3}
max_vocab = 1000 - len(stoi)

all_toks = []
for (src, tgt) in data:
    all_toks += string_split(src)
    all_toks += string_split(tgt)

most_freq = collections.Counter(all_toks).most_common(max_vocab)

for tok, count in most_freq:
    stoi[tok] = len(stoi)
    
itos = [k for k,v in sorted(stoi.items(), key=lambda kv: kv[1])]

In [10]:
def encode_input(string):
    OOVs = []
    IDs = []
    words = string_split(string)
    for word in words:
        try:
            id = stoi[word]
            IDs.append(id)
        except KeyError as e:
            # word is OOV
            IDs.append(len(stoi) + len(OOVs))
            OOVs.append(word)
    return IDs, OOVs

In [11]:
def encode_output(string, OOVs):
    IDs = []
    words = string_split(string)
    for word in words:
        try:
            id = stoi[word]
            IDs.append(id)
        except KeyError as e:
            # word is OOV
            try:
                IDs.append(len(stoi) + OOVs.index(word))
            except ValueError as e:
                IDs.append(stoi["<unk>"])
    return IDs

In [12]:
def decode(ids, OOVs):
    extended_itos = itos.copy()
    extended_itos += [OOV+"(COPY)" for OOV in OOVs]
    return " ".join([extended_itos[id] for id in ids if id<len(extended_itos)])

In [13]:
TEXT_FIELD = Field(sequential=True, use_vocab=False, unk_token=0, init_token=1,eos_token=2, pad_token=3)
OOV_TEXT_FIELD = Field(sequential=True, use_vocab=False, pad_token=3)

OOV_stoi = {}
OOV_itos = {}
OOV_starter_count = 30000
OOV_count = OOV_starter_count

examples = []

for (src, tgt) in data:
    src_ids, OOVs = encode_input(src)
    tgt_ids = encode_output(tgt, OOVs)
    OOV_ids = []
    
    for OOV in OOVs:
        try:
            idx = OOV_stoi[OOV]
            OOV_ids.append(idx)
        except KeyError as e:
            OOV_count += 1
            OOV_stoi[OOV] = OOV_count
            OOV_itos[OOV_count] = OOV
            OOV_ids.append(OOV_count)
            
    examples.append(torchtext.data.Example.fromdict({"src":src_ids, "tgt":tgt_ids, "OOVs":OOV_ids}, 
                                                    fields={"src":("src",TEXT_FIELD), "tgt":("tgt",TEXT_FIELD), "OOVs":("OOVs", OOV_TEXT_FIELD)}))

In [14]:
dataset = torchtext.data.Dataset(examples,fields={"src":TEXT_FIELD, "tgt":TEXT_FIELD, "OOVs":OOV_TEXT_FIELD})
train_dataset, val_dataset = dataset.split([0.9,0.1])

In [15]:
batch_size = 16

train_iterator = BucketIterator(
    train_dataset,
    batch_size = batch_size,
    repeat=True,
    shuffle=True,
    sort_key = lambda x: len(x.src)+len(x.tgt),
    device = device)

# The iterator generates batches with padded length for sequences with similar sizes, a batch is [seq_length, batch_size]

for i, batch in enumerate(train_iterator):
    idx = 5
#     print([SRC_TEXT.vocab.itos[id] for id in batch.src.cpu().numpy()[:,idx]])
    OOVs = [OOV_itos[OOV] for OOV in batch.OOVs.cpu()[:,idx].tolist() if OOV != 3] # 3 is the <pad> token
    src_ids = batch.src.cpu()[:,idx].tolist()
    src_ids = src_ids[:src_ids.index(2)+1]
    tgt_ids = batch.tgt.cpu()[:,idx].tolist()
    tgt_ids = tgt_ids[:tgt_ids.index(2)+1]
    
    print(batch.src.shape)
    print(batch.tgt.shape)
    
    print("SOURCE:",decode(src_ids, OOVs))
    print()
    print("TARGET:",decode(tgt_ids, OOVs))
    break

torch.Size([51, 16])
torch.Size([34, 16])
SOURCE: <sos> raise an ValueError exception with an argument string '' Unable(COPY) to configure(COPY) root logger : % s ' formated with e . <eos>

TARGET: <sos> raise ValueError ( <unk> <unk> % e ) <eos>


In [16]:
class CopyModel(nn.Module):

    def __init__(self, src_vocab_size, tgt_vocab_size, embedding_size=512, dropout=0.5):
        super(CopyModel, self).__init__()
        self.model_type = 'Transformer'
        
        self.embedding_size = embedding_size
        self.pos_encoder = PositionalEncoding(embedding_size, dropout)
        self.src_encoder = nn.Embedding(src_vocab_size, embedding_size)
        self.tgt_encoder = nn.Embedding(tgt_vocab_size, embedding_size)
        
        self.transformer = Transformer(d_model=embedding_size, nhead=8, num_encoder_layers=4, num_decoder_layers=4, dim_feedforward=1024)
        self.decoder = nn.Linear(embedding_size, tgt_vocab_size)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.p_generator = nn.Linear(embedding_size,1)

        self.init_weights()
        self.tgt_mask = None

    def init_weights(self):
        initrange = 0.1
        self.src_encoder.weight.data.uniform_(-initrange, initrange)
        self.tgt_encoder.weight.data.uniform_(-initrange, initrange)
        
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, tgt):
        self.tgt_mask = self.transformer.generate_square_subsequent_mask(len(tgt)).to(self.device)
        

        src_emb = self.src_encoder(src) * math.sqrt(self.embedding_size)
        src_emb = self.pos_encoder(src_emb)
        
        tgt_emb = self.tgt_encoder(tgt) * math.sqrt(self.embedding_size)
        tgt_emb = self.pos_encoder(tgt_emb)
        
        output, atts = self.transformer(src_emb, tgt_emb, tgt_mask=self.tgt_mask)
        
        
        src_scat = src.transpose(0,1)
        src_scat = src_scat.unsqueeze(0)
        src_scat = torch.repeat_interleave(src_scat, tgt.shape[0], dim=0)
        
        p_gens = self.p_generator(output)
        atts = atts.transpose(0,1)
        atts = atts * (1 - p_gens)
        
        output = self.decoder(output)
        output = output * p_gens
        
        output = output.scatter_add_(2,src_scat,atts)
        
        return output

In [17]:
vocab_size = len(itos) + max_seq_length

model = CopyModel(vocab_size,vocab_size).to(device) 
src = torch.randint(0, vocab_size, (3,2)).to(device)
tgt = torch.randint(0, vocab_size, (5,2)).to(device)

out = model(src, tgt)
out.shape

torch.Size([5, 2, 1200])

In [18]:
def nltk_bleu(refrence, prediction):
    """
    Implementation from ReCode
    and moses multi belu script sets BLEU to 0.0 if len(toks) < 4
    """
    ngram_weights = [0.25] * min(4, len(refrence))
    return sentence_bleu([refrence], prediction, weights=ngram_weights, 
                          smoothing_function=SmoothingFunction().method3)

In [19]:
valid_iterator = BucketIterator(val_dataset,
    batch_size = 128,
    sort_key = lambda x: len(x.src)+len(x.tgt),
    device = device)

def batch_filter_ids(batch_list):
    return [[id for id in l if id not in [0,1,2,3]] for l in batch_list]

def evaluate(beam_size=1):
    model.eval() # Turn on the evaluation mode
    with torch.no_grad(), open("out.txt", "w") as out_fp:
        BLEU_scores = []
        for i, batch in enumerate(valid_iterator):
            batch_size = batch.src.shape[1]
            
            encoder_inputs = batch.src
            predictions = beam_search.beam_search_decode(model,
                              batch_encoder_ids=encoder_inputs,
                              SOS_token=stoi["<sos>"],
                              EOS_token=stoi["<eos>"],
                              PAD_token=stoi["<pad>"],
                              beam_size=beam_size,
                              max_length=20,
                              num_out=1)
            
            sources = encoder_inputs.transpose(0,1).cpu().tolist()
            sources = batch_filter_ids(sources)
            
            predictions = [t[0].view(-1).cpu().tolist() for t in predictions]
            predictions = batch_filter_ids(predictions)
            
            targets = batch.tgt.transpose(0,1).cpu().tolist()
            targets = batch_filter_ids(targets)
            
            OOVss = [[OOV_itos[OOV] for OOV in batch.OOVs.cpu()[:,idx].tolist() if OOV != 3] for idx in range(batch_size)]
            
            if i % int(len(valid_iterator)/3) == 0:
                print("| EVALUATION | {:5d}/{:5d} batches |".format(i, len(valid_iterator)))
            
            for j in range(batch_size):
                BLEU = nltk_bleu(targets[j], predictions[j])
                BLEU_scores.append(BLEU)
                
                out_fp.write("SRC  :" + decode(sources[j], OOVss[j]) + "\n")
                out_fp.write("TGT  :" + decode(targets[j], OOVss[j]) + "\n")
                out_fp.write("PRED :" + decode(predictions[j], OOVss[j]) + "\n")
                out_fp.write("BLEU :" + str(BLEU) + "\n")
                out_fp.write("\n")
        out_fp.write("\n\n| EVALUATION | BLEU: {:5.2f} |\n".format(np.average(BLEU_scores)))
        print("| EVALUATION | BLEU: {:5.2f} |".format(np.average(BLEU_scores)))

In [20]:
def train_step(batch):
    model.train() # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
    tgt_vocab_size = len(itos) + max_seq_length
    encoder_input = batch.src
    decoder_input = batch.tgt[:-1]
    targets = batch.tgt[1:]

    optimizer.zero_grad()
    output = model(encoder_input, decoder_input)

    loss = criterion(output.view(-1, tgt_vocab_size), targets.view(-1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    elapsed = time.time() - start_time
    return loss

In [21]:
criterion = nn.CrossEntropyLoss(ignore_index=stoi['<pad>'])
lr = 0.005 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.99)

In [27]:
def train(steps=10000, log_interval=200, learning_interval=4000, eval_interval=1000):
    model.train() # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
    step = 1
    for batch in train_iterator:
        loss = train_step(batch)
        total_loss += loss.item()
        
        if step % log_interval == 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| {:5d}/{:5d} steps | '
                  'lr {:02.4f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    step, steps, scheduler.get_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
        
        if step % eval_interval == 0:
            print("Evaluating model")
            evaluate()
            model.train()
        
        if step % learning_interval == 0:
            scheduler.step()
        
        step += 1
        if step >= steps:
            print("Finished training")
            return

train(steps=1000000,eval_interval=1000000,log_interval=200)

|    50/1000000 steps | lr 0.0050 | ms/batch 32.90 | loss  7.16 | ppl  1288.12
|   100/1000000 steps | lr 0.0050 | ms/batch 32.13 | loss  7.06 | ppl  1161.64
|   150/1000000 steps | lr 0.0050 | ms/batch 33.01 | loss  6.41 | ppl   608.08
|   200/1000000 steps | lr 0.0050 | ms/batch 32.26 | loss  5.78 | ppl   323.49
|   250/1000000 steps | lr 0.0050 | ms/batch 32.25 | loss  5.40 | ppl   221.94
|   300/1000000 steps | lr 0.0050 | ms/batch 32.32 | loss  5.21 | ppl   182.89
|   350/1000000 steps | lr 0.0050 | ms/batch 33.29 | loss  5.06 | ppl   156.89
|   400/1000000 steps | lr 0.0050 | ms/batch 32.82 | loss  4.98 | ppl   145.84
|   450/1000000 steps | lr 0.0050 | ms/batch 32.05 | loss  4.92 | ppl   137.39
|   500/1000000 steps | lr 0.0050 | ms/batch 33.32 | loss  4.89 | ppl   132.88


KeyboardInterrupt: 

# Evaluate

In [None]:
evaluate(beam_size=5)