# Copy-Generator Transformer

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from queue import PriorityQueue
import numpy as np
import torchtext
import tqdm
from torchnlp.metrics import get_moses_multi_bleu
from torchtext.data import Field, BucketIterator
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

import tensorflow as tf
import tensorflow_datasets as tfds
from tokenize import tokenize, untokenize, NUMBER, STRING, NAME, OP
from io import BytesIO

import builtins
import collections
import functools
import json
import linecache
import operator
import os
import random
import re
import sys
import time

from dotmap import DotMap

from base_transformer import TransformerModel, PositionalEncoding
from copy_gen_transformer import Transformer, TransformerDecoderLayer, TransformerDecoder
import beam_search
from IPython.core.debugger import set_trace as tr
%load_ext autoreload
%autoreload 2

## Load the config file

In [2]:
# Default config for this notebook. An externally provided CONFIG environment
# variable (e.g. when the notebook is executed headless) takes precedence.
os.environ.setdefault('CONFIG', "config_copy_gen_vcb1600_maxlen200.json")

In [3]:
# Resolve the config path: $CONFIG if set, otherwise ./config.json.
config_fp = os.environ.get('CONFIG', "config.json")

with open(config_fp, encoding="utf-8") as config_file:
    # DotMap gives attribute-style access to the JSON keys (config.vocab_size, ...).
    config = DotMap(json.load(config_file))

In [4]:
from datetime import datetime
def super_print(filename):
    """Decorator factory that logs every call's positional arguments to ``filename``.

    Each call of the wrapped function first appends one line
    ``[dd/mm/YYYY HH:MM:SS] <args joined by sep>`` to ``filename``. The file is opened
    in append mode, UTF-8, and closed again after each write. It then calls the
    original function with the same arguments and returns its result. Intended for
    wrapping ``print`` so console output is mirrored into a log file.

    Args:
        filename: path of the log file.

    Returns:
        A decorator ``wrap(func)`` producing the logging wrapper.
    """
    def wrap(func):
        """Return ``func`` wrapped so each call is logged before it runs."""
        @functools.wraps(func)
        def wrapped_func(*args, **kwargs):
            # Mirror print's own ``sep`` handling (None means a single space) so the
            # logged line matches what is printed.
            sep = kwargs.get("sep")
            if sep is None:
                sep = " "
            timestamp = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
            # `with` opens, writes and closes the file even if a write fails.
            with open(filename, 'a', encoding="utf-8") as outputfile:
                outputfile.write("[{}] ".format(timestamp))
                outputfile.write(sep.join(str(x) for x in args))
                outputfile.write("\n")
            # Now run the original function with its arguments as normal.
            return func(*args, **kwargs)
        return wrapped_func
    return wrap

# Mirror all print() output into the log file. Wrap the *builtin* print, not the
# current `print` name: otherwise re-running this cell wraps the already-wrapped
# print again and every message is logged once more per re-run.
print = super_print(config.log_file_name)(builtins.print)
print("env CONFIG:", config_fp)
print(config)

env CONFIG: config_copy_gen_vcb1600_maxlen200.json
DotMap(max_seq_length=200, vocab_size=1600, log_file_name='logs_copy_gen_vcb1600_maxlen2000.txt', out_file_name='out_copy_gen_vcb1600_maxlen2000.txt', train_steps=1000000, train_learning_rate=0.005, src_file='./datasets/all-fixed.desc', tgt_file='./datasets/all.code', eval_interval=10000, log_interval=400, eval_beam_size=1, train_batch_size=32, eval_batch_size=32, model_layers=4, model_att_heads=8, model_embed_dim=512, model_dim_feedforward=1024, model_att_mask_noise=0.0)


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Compare the device *type*: a torch.device never compares equal to the plain
# string "cuda", so `device == "cuda"` was always False and this never ran.
if device.type == "cuda":
    torch.cuda.set_device(0)  # choose GPU from nvidia-smi
print("Using:", device)

Using: cuda


In [8]:
text = "create variable student_names with string 'foo bar baz'"

def string_split(s):
    """Tokenize ``s`` into words, single punctuation characters and quoted strings.

    The pattern captures a single- or double-quoted string literal (kept whole,
    spaces included) or any single non-word character. ``re.split`` keeps the
    captured delimiters, so they become tokens too. Empty and whitespace-only
    pieces are dropped. ``_`` is a word character, so identifiers such as
    ``student_names`` stay intact.

    Earlier variants, kept for reference:
      r"(_|\W)"                  also splits on "_" and breaks strings apart
      r"(\'.*?\'|\".*?\"|_|\W)"  keeps strings whole but splits on "_"
    """
    # Raw string: '\W' inside a normal string literal is an invalid escape
    # sequence (a DeprecationWarning/SyntaxWarning on recent Pythons).
    pieces = re.split(r'(\'.*?\'|\".*?\"|\W)', s)
    return [piece for piece in pieces if piece and not piece.isspace()]

print(string_split(text))

['create', 'variable', 'student_names', 'with', 'string', "'foo bar baz'"]


In [7]:
def corpus_to_array(src_fp, tgt_fp):
    """Pair up two line-aligned UTF-8 text files.

    Returns a list of ``(src_line, tgt_line)`` tuples. Each line keeps its trailing
    newline. Pairing stops silently at the end of the shorter file.
    """
    with open(src_fp, "r", encoding="utf-8") as src_file, open(tgt_fp, "r", encoding="utf-8") as tgt_file:
        return list(zip(src_file, tgt_file))

In [8]:
def filter_corpus(data, max_seq_length=200, tokenizer=string_split):
    """Keep only (src, tgt) pairs whose token counts are both <= ``max_seq_length``.

    Args:
        data: iterable of ``(src, tgt)`` string pairs.
        max_seq_length: maximum number of tokens allowed on either side.
        tokenizer: callable mapping a string to a list of tokens. This argument was
            previously ignored (``string_split`` was always used); it is now honoured.

    Returns:
        A new list with the surviving pairs, in their original order.
    """
    return [
        (src, tgt)
        for src, tgt in data
        if len(tokenizer(src)) <= max_seq_length and len(tokenizer(tgt)) <= max_seq_length
    ]

In [9]:
def samples_to_dataset(samples):
    """Wrap ``(src, tgt)`` pairs in a torchtext Dataset with "src" and "tgt" fields.

    Args:
        samples: iterable of ``(src, tgt)`` pairs.

    Returns:
        ``torchtext.data.Dataset`` whose examples have ``.src`` and ``.tgt``.

    NOTE(review): the shared Field is built with ``use_vocab=False``. Batching will
    presumably only work if the values are already numeric ids, so confirm before
    feeding raw strings.
    """
    TEXT_FIELD = Field(sequential=True, use_vocab=False, init_token='<sos>', eos_token='<eos>')
    example_fields = {"src": ("src", TEXT_FIELD), "tgt": ("tgt", TEXT_FIELD)}

    examples = [
        torchtext.data.Example.fromdict({"src": src_string, "tgt": tgt_string}, fields=example_fields)
        for src_string, tgt_string in samples
    ]

    # Use the Field the examples were built with. The old code referenced the
    # undefined names `src_field` / `tgt_field` and raised NameError here.
    dataset = torchtext.data.Dataset(examples, fields={"src": TEXT_FIELD, "tgt": TEXT_FIELD})
    return dataset

In [10]:
# Load the aligned source/target corpus and shuffle it in place.
data = corpus_to_array(config.src_file, config.tgt_file)
random.shuffle(data)
print("Max src length:", max(len(string_split(pair[0])) for pair in data))
print("Max tgt length:", max(len(string_split(pair[1])) for pair in data))

# Drop pairs longer than the configured limit on either side.
print("Full dataset size:", len(data))
max_seq_length = config.max_seq_length
data = filter_corpus(data, max_seq_length=max_seq_length, tokenizer=string_split)
print("Limited dataset size:", len(data))

Max src length: 557
Max tgt length: 527
Full dataset size: 18805
Limited dataset size: 18797


In [11]:
# Tiny hand-written parallel corpus for quick sanity checks. It is unused unless
# the two commented-out lines at the bottom are re-enabled (they replace `data`).
inputs, outputs = (list(column) for column in zip(*[
    ("my favourite foods are banana and toast", "would you like banana and toast ?"),
    ("my favourite foods are eggs and bacon and beans", "would you like eggs and bacon and beans ?"),
    ("my favourite food is chocolate", "would you like chocolate ?"),
    ("my favourite food is avocado", "would you like avocado ?"),
]))
# max_seq_length = 9
# data = list(zip(inputs, outputs))

In [12]:
# Closed vocabulary: the 4 special tokens first, then the most frequent corpus
# tokens, so that len(stoi) is at most config.vocab_size.
stoi = {"<unk>": 0, "<sos>": 1, "<eos>": 2, "<pad>": 3}
max_vocab = config.vocab_size - len(stoi)

token_counts = collections.Counter()
for (src, tgt) in data:
    token_counts.update(string_split(src))
    token_counts.update(string_split(tgt))

most_freq = token_counts.most_common(max_vocab)

for tok, _count in most_freq:
    stoi[tok] = len(stoi)

# Inverse mapping: position i holds the token whose id is i.
itos = sorted(stoi, key=stoi.__getitem__)

In [13]:
%%capture
stoi  # inspect the vocabulary mapping; %%capture suppresses its (large) display

In [14]:
def encode_input(string):
    """Encode a source string into ids for the copy-generator model.

    In-vocabulary tokens map to their ``stoi`` id. Each *distinct*
    out-of-vocabulary token gets the temporary "extended" id ``len(stoi) + k``,
    where ``k`` is its position in the returned OOV list. That lets the model
    copy it and ``encode_output`` / ``decode`` refer to it.

    Args:
        string: raw source text; tokenized with ``string_split``.

    Returns:
        ``(IDs, OOVs)``: the id for every token, and the distinct OOV strings in
        order of first appearance.
    """
    OOVs = []
    oov_position = {}  # OOV string -> index k in OOVs
    IDs = []
    for word in string_split(string):
        if word in stoi:
            IDs.append(stoi[word])
            continue
        # Reuse the extended id of an OOV token seen earlier in this string.
        # Giving every occurrence a fresh id split the copy probability across
        # several ids, while encode_output only ever targets the first one.
        if word not in oov_position:
            oov_position[word] = len(OOVs)
            OOVs.append(word)
        IDs.append(len(stoi) + oov_position[word])
    return IDs, OOVs

In [15]:
def encode_output(string, OOVs):
    """Encode a target string against the vocabulary plus a source's OOV list.

    Tokens in ``stoi`` get their vocabulary id. Tokens found in ``OOVs`` get the
    extended id ``len(stoi) + OOVs.index(token)``, so they can be copied from the
    source. Anything else becomes ``<unk>``.
    """
    encoded = []
    for token in string_split(string):
        if token in stoi:
            encoded.append(stoi[token])
        elif token in OOVs:
            encoded.append(len(stoi) + OOVs.index(token))
        else:
            encoded.append(stoi["<unk>"])
    return encoded

In [16]:
def decode(ids, OOVs):
    """Render a sequence of ids as a space-separated string.

    Ids index into ``itos`` extended with this example's OOV strings, each
    suffixed with "(COPY)". Ids at or beyond the end of that extended list are
    skipped.
    """
    extended_vocab = itos + [oov + "(COPY)" for oov in OOVs]
    limit = len(extended_vocab)
    return " ".join(extended_vocab[i] for i in ids if i < limit)

In [17]:
# Fields for pre-numericalised data: encode_input/encode_output already produce
# ids, so no torchtext vocab is built (use_vocab=False). The special-token ids
# match stoi: 0=<unk>, 1=<sos>, 2=<eos>, 3=<pad>.
TEXT_FIELD = Field(sequential=True, use_vocab=False, unk_token=0, init_token=1, eos_token=2, pad_token=3)
OOV_TEXT_FIELD = Field(sequential=True, use_vocab=False, pad_token=3)

# Corpus-wide registry of OOV strings, so each example can carry its OOV list as
# ids (mapped back with OOV_itos). Ids are assigned from OOV_starter_count + 1 upwards.
OOV_stoi = {}
OOV_itos = {}
OOV_starter_count = 30000
OOV_count = OOV_starter_count

example_fields = {"src": ("src", TEXT_FIELD), "tgt": ("tgt", TEXT_FIELD), "OOVs": ("OOVs", OOV_TEXT_FIELD)}
examples = []

for src, tgt in data:
    src_ids, OOVs = encode_input(src)
    tgt_ids = encode_output(tgt, OOVs)

    OOV_ids = []
    for OOV in OOVs:
        if OOV not in OOV_stoi:
            OOV_count += 1
            OOV_stoi[OOV] = OOV_count
            OOV_itos[OOV_count] = OOV
        OOV_ids.append(OOV_stoi[OOV])

    example = torchtext.data.Example.fromdict({"src": src_ids, "tgt": tgt_ids, "OOVs": OOV_ids},
                                              fields=example_fields)
    examples.append(example)

## Optional separate evaluation set (e.g. Hearthstone)

In [29]:
# Optional separate evaluation corpus (e.g. Hearthstone). It is only loaded when the
# config names both files. Configs without eval_src_file / eval_tgt_file do not
# provide usable paths, and this cell used to fail for them.
_eval_cfg = config.toDict()
if _eval_cfg.get("eval_src_file") and _eval_cfg.get("eval_tgt_file"):
    test_data = corpus_to_array(config.eval_src_file, config.eval_tgt_file)
    # NOTE(review): this also drops long *evaluation* pairs, which changes the set
    # BLEU is reported on.
    test_data = filter_corpus(test_data, max_seq_length=max_seq_length, tokenizer=string_split)

    test_examples = []

    for (src, tgt) in test_data:
        src_ids, OOVs = encode_input(src)
        tgt_ids = encode_output(tgt, OOVs)
        OOV_ids = []

        # Register unseen OOV strings in the shared OOV_stoi / OOV_itos tables.
        for OOV in OOVs:
            try:
                idx = OOV_stoi[OOV]
                OOV_ids.append(idx)
            except KeyError as e:
                OOV_count += 1
                OOV_stoi[OOV] = OOV_count
                OOV_itos[OOV_count] = OOV
                OOV_ids.append(OOV_count)

        test_examples.append(torchtext.data.Example.fromdict({"src":src_ids, "tgt":tgt_ids, "OOVs":OOV_ids},
                                                        fields={"src":("src",TEXT_FIELD), "tgt":("tgt",TEXT_FIELD), "OOVs":("OOVs", OOV_TEXT_FIELD)}))

    val_dataset = torchtext.data.Dataset(test_examples,fields={"src":TEXT_FIELD, "tgt":TEXT_FIELD, "OOVs":OOV_TEXT_FIELD})

In [18]:
dataset = torchtext.data.Dataset(examples, fields={"src": TEXT_FIELD, "tgt": TEXT_FIELD, "OOVs": OOV_TEXT_FIELD})
train_dataset, heldout_dataset = dataset.split([0.9, 0.1])
# train_dataset = dataset

# When a separate evaluation corpus is configured, the eval-corpus cell already set
# val_dataset. Previously this cell overwrote it under Restart & Run All, so the
# configured eval set was silently ignored. Otherwise use the random 10% split.
_split_cfg = config.toDict()
if not (_split_cfg.get("eval_src_file") and _split_cfg.get("eval_tgt_file")):
    val_dataset = heldout_dataset
# train_dataset = dataset

In [19]:
batch_size = config.train_batch_size

# BucketIterator groups examples of similar total length to minimise padding.
# Every batch tensor is laid out as [seq_length, batch_size]; repeat=True makes
# the iterator endless.
train_iterator = BucketIterator(
    train_dataset,
    device=device,
    batch_size=batch_size,
    sort_key=lambda example: len(example.src) + len(example.tgt),
    shuffle=True,
    repeat=True,
)

# Sanity check: decode one example of the first batch back to text.
batch = next(iter(train_iterator))
idx = 3  # arbitrary example within the batch

# 3 is the <pad> id of the OOV field; the remaining ids map back via OOV_itos.
OOVs = [OOV_itos[oov_id] for oov_id in batch.OOVs.cpu()[:, idx].tolist() if oov_id != 3]

# Truncate each sequence just after its first <eos> (id 2).
src_ids = batch.src.cpu()[:, idx].tolist()
src_ids = src_ids[:src_ids.index(2) + 1]
tgt_ids = batch.tgt.cpu()[:, idx].tolist()
tgt_ids = tgt_ids[:tgt_ids.index(2) + 1]

print(batch.src.shape)
print(batch.tgt.shape)

print("SOURCE:", decode(src_ids, OOVs))
print()
print("TARGET:", decode(tgt_ids, OOVs))

torch.Size([43, 32])
torch.Size([24, 32])
SOURCE: <sos> if not , <eos>

TARGET: <sos> else : <eos>


## Copy-generator Transformer model

In [20]:
class CopyModel(nn.Module):
    """Transformer encoder-decoder with a pointer-generator ("copy") output head.

    At every decoder position the output distribution over the (extended)
    vocabulary is a mixture of two parts:

    * generation: ``softmax(decoder(h))`` weighted by ``p_gen``;
    * copy: the attention weights returned by the transformer, weighted by
      ``1 - p_gen`` and added onto the ids of the source tokens.

    Here ``p_gen = sigmoid(p_generator(h))``. ``forward`` returns the log of this
    mixture. Embedding size, heads, layers, feed-forward size and attention-mask
    noise are read from the module-level ``config``.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, dropout=0.5):
        super(CopyModel, self).__init__()
        self.model_type = 'Transformer'

        self.embedding_size = config.model_embed_dim
        self.pos_encoder = PositionalEncoding(self.embedding_size, dropout)
        self.src_encoder = nn.Embedding(src_vocab_size, self.embedding_size)
        self.tgt_encoder = nn.Embedding(tgt_vocab_size, self.embedding_size)

        self.transformer = Transformer(d_model=config.model_embed_dim,
                                       nhead=config.model_att_heads,
                                       num_encoder_layers=config.model_layers,
                                       num_decoder_layers=config.model_layers,
                                       dim_feedforward=config.model_dim_feedforward)
        self.decoder = nn.Linear(self.embedding_size, tgt_vocab_size)
        # Kept for callers that read `model.device`. forward() places its mask on
        # the device of its inputs instead, so the model also works after .to(cpu).
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Scalar gate p_gen deciding between generating and copying.
        self.p_generator = nn.Linear(config.model_embed_dim, 1)

        self.init_weights()
        self.tgt_mask = None

    def init_weights(self):
        """Uniform(-0.1, 0.1) init for both embeddings and the output projection; zero projection bias."""
        initrange = 0.1
        self.src_encoder.weight.data.uniform_(-initrange, initrange)
        self.tgt_encoder.weight.data.uniform_(-initrange, initrange)

        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def _generate_square_subsequent_mask(self, sz):
        """Build an additive (sz, sz) causal attention mask: 0.0 = allowed, -inf = blocked.

        Position i may attend to positions j <= i. In training mode each allowed
        entry is additionally kept only with probability ``1 - config.model_att_mask_noise``.
        Any row left with no allowed entry is re-opened at column 0, so no softmax
        row is all -inf.
        """
        noise_e = config.model_att_mask_noise if self.training else 0.0
        noise_mask = (torch.rand(sz, sz) > noise_e).float()

        # Lower-triangular (incl. diagonal) ones, randomly thinned by the noise.
        mask = (torch.triu(torch.ones(sz, sz))).transpose(0, 1)
        mask = torch.mul(mask, noise_mask)
        empty_rows = (torch.sum(mask, dim=-1) == 0).float()

        # Re-enable column 0 for rows the noise emptied completely.
        fix_mask = torch.zeros(sz, sz)
        fix_mask[:, 0] = 1.0
        empty_rows = empty_rows.repeat(sz, 1).transpose(0, 1)
        fix_mask = torch.mul(fix_mask, empty_rows)

        mask += fix_mask

        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src, tgt):
        """Return log-probabilities over the extended vocabulary for every target position.

        Args:
            src: source id tensor, sequence-first ``(src_len, batch)`` as in the
                smoke test of this notebook.
            tgt: decoder input ids, ``(tgt_len, batch)``.

        Returns:
            Tensor ``(tgt_len, batch, tgt_vocab_size)`` of log-probabilities.
        """
        # Build the mask on the inputs' device rather than a hard-coded one.
        self.tgt_mask = self._generate_square_subsequent_mask(len(tgt)).to(tgt.device)

        src_emb = self.pos_encoder(self.src_encoder(src) * math.sqrt(self.embedding_size))
        tgt_emb = self.pos_encoder(self.tgt_encoder(tgt) * math.sqrt(self.embedding_size))

        output, atts = self.transformer(src_emb, tgt_emb, tgt_mask=self.tgt_mask)

        # Source ids repeated for every decoder step -> (tgt_len, batch, src_len):
        # the vocabulary slot each copy-attention weight is added to.
        src_scat = src.transpose(0, 1).unsqueeze(0)
        src_scat = torch.repeat_interleave(src_scat, tgt.shape[0], dim=0)

        p_gens = self.p_generator(output).sigmoid()
        # assumes the transformer returns attention as (batch, tgt_len, src_len) — confirm in copy_gen_transformer
        copy_probs = atts.transpose(0, 1) * (1 - p_gens)

        gen_probs = self.decoder(output).softmax(-1) * p_gens

        # Out-of-place scatter_add: add each copy weight onto its source token's id.
        mixed = gen_probs.scatter_add(2, src_scat, copy_probs)
        return mixed.log()

In [21]:
# Extended vocabulary: the fixed vocabulary plus max_seq_length extra slots that
# encode_input hands out as per-example OOV copy ids.
vocab_size = len(itos) + max_seq_length
model = CopyModel(vocab_size, vocab_size).to(device)

# Smoke test with random ids laid out (seq_len, batch); expect (tgt_len, batch, vocab_size).
src, tgt = (torch.randint(0, vocab_size, shape).to(device) for shape in [(3, 2), (5, 2)])
out = model(src, tgt)
out.shape

torch.Size([5, 2, 1800])

In [22]:
def nltk_bleu(refrence, prediction):
    """Sentence-level BLEU of ``prediction`` against the single reference ``refrence``.

    Implementation from ReCode. Uses uniform 0.25 weights for n-gram orders up to
    4, but only as many orders as the reference has tokens. It is smoothed with
    NLTK's ``SmoothingFunction().method3``. (The Moses multi-BLEU script instead
    sets BLEU to 0.0 when len(toks) < 4.)
    """
    n_orders = 4 if len(refrence) >= 4 else len(refrence)
    weights = [0.25 for _ in range(n_orders)]
    smoother = SmoothingFunction().method3
    return sentence_bleu([refrence], prediction, weights=weights,
                         smoothing_function=smoother)

In [33]:
# Validation batches, grouped by total example length like the training iterator.
valid_iterator = BucketIterator(
    val_dataset,
    device=device,
    batch_size=config.eval_batch_size,
    sort_key=lambda example: len(example.src) + len(example.tgt),
)

def batch_filter_ids(batch_list):
    """Remove the <sos> (1), <eos> (2) and <pad> (3) ids from every sequence in a batch."""
    special_ids = (1, 2, 3)
    filtered = []
    for sequence in batch_list:
        filtered.append([token for token in sequence if token not in special_ids])
    return filtered

def evaluate(beam_size=1, log=False, max_length=30):
    """Decode the validation set with beam search and report the average sentence BLEU.

    Every example's SRC / TGT / PRED / BLEU is written to ``config.out_file_name``
    (overwritten on each call), followed by the average BLEU, which is also printed.
    The model is left in eval mode; callers that continue training must call
    ``model.train()``.

    Args:
        beam_size: beam width passed to ``beam_search.beam_search_decode`` (1 = greedy).
        log: if True, also print every example.
        max_length: maximum decoding length passed to the beam search
            (previously hard-coded to 30, which is still the default).
    """
    model.eval()  # Turn on the evaluation mode
    with torch.no_grad(), open(config.out_file_name, "w", encoding="utf-8") as out_fp:
        BLEU_scores = []
        # Print progress roughly three times per pass. max(1, ...) avoids the
        # ZeroDivisionError `i % int(len(valid_iterator)/3)` raised when the
        # validation set has fewer than three batches.
        report_every = max(1, len(valid_iterator) // 3)
        for i, batch in enumerate(valid_iterator):
            batch_size = batch.src.shape[1]

            encoder_inputs = batch.src
            predictions = beam_search.beam_search_decode(model,
                              batch_encoder_ids=encoder_inputs,
                              SOS_token=stoi["<sos>"],
                              EOS_token=stoi["<eos>"],
                              PAD_token=stoi["<pad>"],
                              beam_size=beam_size,
                              max_length=max_length,
                              num_out=1)

            # Per-example id lists (batch-major) without <sos>/<eos>/<pad>.
            sources = batch_filter_ids(encoder_inputs.transpose(0, 1).cpu().tolist())
            # assumes t[0] is the single requested hypothesis (num_out=1) — confirm in beam_search
            predictions = batch_filter_ids([t[0].view(-1).cpu().tolist() for t in predictions])
            targets = batch_filter_ids(batch.tgt.transpose(0, 1).cpu().tolist())

            # Each example's OOV strings, in the order their extended ids were
            # assigned (3 is the <pad> id of the OOV field).
            oov_ids = batch.OOVs.cpu()
            OOVss = [[OOV_itos[OOV] for OOV in oov_ids[:, idx].tolist() if OOV != 3]
                     for idx in range(batch_size)]

            if i % report_every == 0:
                print("| EVALUATION | {:5d}/{:5d} batches |".format(i, len(valid_iterator)))

            for j in range(batch_size):
                BLEU = nltk_bleu(targets[j], predictions[j])
                BLEU_scores.append(BLEU)

                src_text = decode(sources[j], OOVss[j])
                tgt_text = decode(targets[j], OOVss[j])
                pred_text = decode(predictions[j], OOVss[j])

                out_fp.write("SRC  :" + src_text + "\n")
                out_fp.write("TGT  :" + tgt_text + "\n")
                out_fp.write("PRED :" + pred_text + "\n")
                out_fp.write("BLEU :" + str(BLEU) + "\n")
                out_fp.write("\n")

                if log:
                    print("SRC  :" + src_text)
                    print("TGT  :" + tgt_text)
                    print("PRED :" + pred_text)
                    print("BLEU :" + str(BLEU))
                    print()
        out_fp.write("\n\n| EVALUATION | BLEU: {:5.2f} |\n".format(np.average(BLEU_scores)))
        print("| EVALUATION | BLEU: {:5.3f} |".format(np.average(BLEU_scores)))

In [35]:
# Baseline score of the freshly initialised (untrained) model, decoded greedily.
evaluate(log=False, beam_size=1)

| EVALUATION |     0/   59 batches |
| EVALUATION |    19/   59 batches |
| EVALUATION |    38/   59 batches |
| EVALUATION |    57/   59 batches |
| EVALUATION | BLEU: 0.002 |


In [36]:
def train_step(batch):
    """Run one teacher-forced optimisation step on ``batch`` and return its loss.

    The decoder input is the target without its last token, and the labels are the
    target shifted left by one. Uses the module-level ``model``, ``optimizer`` and
    ``criterion``. Gradients are clipped to a global norm of 0.5 before the step.

    Returns:
        The scalar loss tensor, detached from the autograd graph (callers only need
        ``.item()``), so the graph can be freed immediately.
    """
    model.train()  # Turn on the train mode
    decoder_input = batch.tgt[:-1]
    targets = batch.tgt[1:]

    optimizer.zero_grad()
    output = model(batch.src, decoder_input)

    # Flatten (tgt_len, batch, vocab) -> (tgt_len * batch, vocab). The vocabulary
    # size is read from the output instead of being recomputed from globals.
    loss = criterion(output.reshape(-1, output.size(-1)), targets.reshape(-1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    return loss.detach()

In [37]:
# CopyModel.forward returns log-probabilities. CrossEntropyLoss re-applies
# log_softmax, which leaves them unchanged as long as the generate/copy mixture
# sums to 1 (assumes the copy attention rows are normalised — confirm).
# <pad> targets do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=stoi['<pad>'])
lr = config.train_learning_rate  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Multiply the LR by 0.99 on every scheduler.step(). StepLR's step_size is an int
# (number of scheduler steps between decays); it was passed as the float 1.0.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)

In [None]:
def train(steps=10000, log_interval=200, learning_interval=4000, eval_interval=1000, eval_beam_size=1):
    """Train on batches from ``train_iterator`` for exactly ``steps`` optimisation steps.

    * every ``log_interval`` steps: print the mean loss, perplexity, current LR and ms/batch;
    * every ``eval_interval`` steps: run ``evaluate(beam_size=eval_beam_size)`` and
      switch back to train mode;
    * every ``learning_interval`` steps: advance the LR ``scheduler``.

    Args:
        steps: total number of batches to train on.
        log_interval: steps between progress lines.
        learning_interval: steps between scheduler steps.
        eval_interval: steps between evaluations.
        eval_beam_size: beam width used for the periodic evaluations (default 1,
            as before).
    """
    model.train()  # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
    for step, batch in enumerate(train_iterator, start=1):
        loss = train_step(batch)
        total_loss += loss.item()

        if step % log_interval == 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            # Read the LR actually in use. scheduler.get_lr() is deprecated and,
            # once the scheduler has stepped, may not report the current LR.
            current_lr = optimizer.param_groups[0]["lr"]
            print('| {:5d}/{:5d} steps | '
                  'lr {:02.4f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    step, steps, current_lr,
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

        if step % eval_interval == 0:
            print("Evaluating model")
            evaluate(beam_size=eval_beam_size)
            model.train()

        if step % learning_interval == 0:
            scheduler.step()

        # Stop after exactly `steps` batches. The old counter check stopped after
        # steps - 1 batches.
        if step >= steps:
            print("Finished training")
            return

# Full training run; step counts and intervals come from the config file.
train(
    steps=config.train_steps,
    log_interval=config.log_interval,
    eval_interval=config.eval_interval,
)

|   400/1000000 steps | lr 0.0050 | ms/batch 35.83 | loss  3.30 | ppl    27.23
|   800/1000000 steps | lr 0.0050 | ms/batch 36.13 | loss  3.21 | ppl    24.78
|  1200/1000000 steps | lr 0.0050 | ms/batch 35.99 | loss  3.11 | ppl    22.41
|  1600/1000000 steps | lr 0.0050 | ms/batch 36.19 | loss  3.04 | ppl    20.89
|  2000/1000000 steps | lr 0.0050 | ms/batch 35.77 | loss  2.97 | ppl    19.53
|  2400/1000000 steps | lr 0.0050 | ms/batch 36.25 | loss  2.91 | ppl    18.42
|  2800/1000000 steps | lr 0.0050 | ms/batch 36.08 | loss  2.87 | ppl    17.59
|  3200/1000000 steps | lr 0.0050 | ms/batch 35.84 | loss  2.81 | ppl    16.60
|  3600/1000000 steps | lr 0.0050 | ms/batch 36.37 | loss  2.78 | ppl    16.04
|  4000/1000000 steps | lr 0.0050 | ms/batch 36.27 | loss  2.74 | ppl    15.44
|  4400/1000000 steps | lr 0.0050 | ms/batch 37.16 | loss  2.68 | ppl    14.57
|  4800/1000000 steps | lr 0.0050 | ms/batch 37.01 | loss  2.65 | ppl    14.15
|  5200/1000000 steps | lr 0.0050 | ms/batch 36.40 |

# Save model

In [40]:
# Name the checkpoint after this run's settings instead of hard-coding them. The
# previous literal name claimed maxSeq500 and 79 BLEU, which matched neither
# config.max_seq_length nor any score this notebook reports.
checkpoint_path = "./saved_copy_gen_DJANGO_vcb{}_maxSeq{}.pytorch".format(config.vocab_size, config.max_seq_length)
# Pickles the full objects (not state_dicts). Loading needs the same class
# definitions, and unpickling untrusted checkpoint files is unsafe.
torch.save((model, optimizer, scheduler, stoi, itos), checkpoint_path)