# Copy-Generator Transformer
This architecture allows the transformer model either to generate a new token from a pre-existing vocabulary or to copy a word directly from the input.

An important note about this architecture: since words need to be copied from the source to the target language, it is easiest to have a common vocabulary and thus shared token IDs.

As a first step, the words common to the input and the output need to be identified. This is challenging because we want to limit the vocabulary size: frequent tokens become part of the vocabulary, while infrequent tokens that are common to the input and the output get a specific copyable token that is unique to the sentence.

### Imports

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from queue import PriorityQueue
import numpy as np
import torchtext
import tqdm
from torchnlp.metrics import get_moses_multi_bleu
from torchtext.data import Field, BucketIterator
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

import tensorflow as tf
import tensorflow_datasets as tfds
from tokenize import tokenize, untokenize, NUMBER, STRING, NAME, OP
from io import BytesIO

import linecache
import sys
import os
import re
import random
import time
import operator
import collections

from base_transformer import TransformerModel
from IPython.core.debugger import set_trace as tr
%load_ext autoreload
%autoreload 2

Setting the device to use: CPU or GPU

In [2]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# A torch.device does not compare equal to the plain string "cuda", so the
# device kind has to be read from its `type` attribute.
if device.type == "cuda":
    torch.cuda.set_device(0)  # choose GPU from nvidia-smi
print("Using:", device)

Using: cuda


### Helper functions

In [3]:
text = "create variable student_names with string 'foo bar baz'"

# Delimiters: a single- or double-quoted string literal (kept in one piece) or
# any single non-word character. The capture group makes re.split() return the
# delimiters as well, so punctuation and quoted strings survive as tokens.
# Underscore is a word character, so identifiers like `student_names` stay whole.
_TOKEN_SPLIT_RE = re.compile(r"""('.*?'|".*?"|\W)""")


def string_split(s):
    """Tokenize a line of text or code.

    Args:
        s: the string to tokenize.

    Returns:
        A list of tokens: words/identifiers, individual punctuation
        characters and whole quoted strings. Empty pieces and pieces made up
        only of whitespace are dropped.
    """
    return [piece for piece in _TOKEN_SPLIT_RE.split(s) if piece and not piece.isspace()]

print(string_split(text))

['create', 'variable', 'student_names', 'with', 'string', "'foo bar baz'"]


In [4]:
def corpus_to_array(src_fp, tgt_fp):
    """Read two parallel text files into a list of (source, target) line pairs.

    Args:
        src_fp: path of the file holding the source lines.
        tgt_fp: path of the file holding the target lines.

    Returns:
        A list of (src_line, tgt_line) tuples, paired by line number. Lines
        keep their trailing newline. Pairing stops at the end of the shorter
        file.
    """
    with open(src_fp, "r") as src_file, open(tgt_fp, "r") as tgt_file:
        return list(zip(src_file, tgt_file))

In [5]:
def filter_corpus(data, max_seq_length=200, tokenizer=string_split):
    """Keep only the pairs whose source and target are both short enough.

    Args:
        data: iterable of (src_string, tgt_string) pairs.
        max_seq_length: maximum number of tokens allowed on either side
            (inclusive).
        tokenizer: callable taking a string and returning a list of tokens;
            it is used to measure the length of both sides.

    Returns:
        A list with the (src, tgt) pairs that satisfy the length limit.
    """
    return [
        (src, tgt)
        for src, tgt in data
        if len(tokenizer(src)) <= max_seq_length
        and len(tokenizer(tgt)) <= max_seq_length
    ]

In [6]:
def samples_to_dataset(samples):
    """
    Build a torchtext Dataset from (source, target) pairs.

    Args:
        samples: iterable of (src_string, tgt_string) pairs
    Returns:
        a torchtext.data.Dataset whose examples have a "src" and a "tgt"
        attribute, both handled by the same Field
    """
    # NOTE(review): use_vocab=False normally means the data is already
    # numerical, yet the samples and the init/eos tokens here are strings --
    # confirm this is what the caller expects.
    TEXT_FIELD = Field(sequential=True, use_vocab=False, init_token='<sos>',eos_token='<eos>')
    example_fields = {"src":("src",TEXT_FIELD), "tgt":("tgt",TEXT_FIELD)}

    examples = [
        torchtext.data.Example.fromdict({"src":src_string, "tgt":tgt_string}, fields=example_fields)
        for src_string, tgt_string in samples
    ]

    # The dataset must use the same Field the examples were built with.
    return torchtext.data.Dataset(examples,fields={"src":TEXT_FIELD, "tgt":TEXT_FIELD})

In [34]:
# Load the parallel description/code corpus and shuffle the pairs in place.
data = corpus_to_array("datasets/all-fixed.desc", "datasets/all.code")
random.shuffle(data)

# Longest description and longest code snippet, measured in tokens.
print("Max src length:", max(len(string_split(src)) for src, _ in data))
print("Max tgt length:", max(len(string_split(tgt)) for _, tgt in data))

print("Full dataset size:", len(data))
max_seq_length = 200
# NOTE(review): the filter below uses a limit of 50 tokens, not the
# max_seq_length assigned just above -- confirm which value is intended.
data = filter_corpus(data, tokenizer=string_split, max_seq_length=50)
print("Limited dataset size:", len(data))

Max src length: 557
Max tgt length: 527
Full dataset size: 18805
Limited dataset size: 18632


## Making a shared vocabulary
The idea of the copy-generator network is to give the model a chance to copy words from the input to the output; some of them it might already know, but others might be completely unknown to it.

In [65]:
# Token -> ID table, seeded with the four special tokens.
stoi = {"<unk>": 0, "<sos>": 1, "<eos>": 2, "<pad>": 3}
# Number of regular tokens that still fit in a 1000-entry vocabulary.
max_vocab = 1000 - len(stoi)

# Gather every token of every source and target sentence.
all_toks = []
for src, tgt in data:
    all_toks.extend(string_split(src))
    all_toks.extend(string_split(tgt))

most_freq = collections.Counter(all_toks).most_common(max_vocab)

# The most frequent tokens receive the next free IDs, in frequency order.
for tok, _count in most_freq:
    stoi[tok] = len(stoi)

# ID -> token table: the tokens ordered by their ID.
itos = sorted(stoi, key=stoi.get)

In [66]:
def encode_input(string):
    """Encode a source sentence into token IDs with per-sentence OOV slots.

    In-vocabulary words map to their `stoi` ID. Every distinct
    out-of-vocabulary (OOV) word gets one sentence-local ID,
    `len(stoi) + slot`, where `slot` is the position of the word in the
    returned OOV list. A repeated OOV word reuses its slot, so all of its
    occurrences share the ID that `encode_output` resolves through
    `OOVs.index(word)`.

    Args:
        string: the source sentence.

    Returns:
        (IDs, OOVs): the list of token IDs and the list of distinct OOV
        words in order of first appearance.
    """
    first_oov_id = len(stoi)
    oov_slots = {}  # OOV word -> slot; dicts keep insertion order
    token_ids = []
    for word in string_split(string):
        if word in stoi:
            token_ids.append(stoi[word])
        else:
            # setdefault assigns the next free slot only the first time
            # this word is seen in the sentence.
            slot = oov_slots.setdefault(word, len(oov_slots))
            token_ids.append(first_oov_id + slot)
    return token_ids, list(oov_slots)

In [67]:
# Sanity check on the sample sentence: returns (token IDs, OOV words); OOV
# words are given IDs starting at len(stoi).
encode_input(text)

([639, 1000, 1001, 12, 29, 1002],
 ['variable', 'student_names', "'foo bar baz'"])

In [68]:
def encode_output(string, OOVs):
    """Encode a target sentence using the OOV list of its source sentence.

    Args:
        string: the target sentence.
        OOVs: the out-of-vocabulary words of the matching source sentence,
            as returned by `encode_input`.

    Returns:
        A list of token IDs. In-vocabulary words use their `stoi` ID, words
        found in `OOVs` use `len(stoi) + position in OOVs` (i.e. they can be
        copied from the source), and any other word becomes `<unk>`.
    """
    encoded = []
    for token in string_split(string):
        if token in stoi:
            encoded.append(stoi[token])
        elif token in OOVs:
            # copyable: the word also appears in the source sentence
            encoded.append(len(stoi) + OOVs.index(token))
        else:
            # neither in the vocabulary nor in the source sentence
            encoded.append(stoi["<unk>"])
    return encoded

In [69]:
# Sanity check: encoding the sample sentence as a target, given its own OOV
# list, maps the OOV words to the IDs len(stoi) + their position in that list.
encode_output(text,['variable', 'student_names', "'foo bar baz'"])

[639, 1000, 1001, 12, 29, 1002]

In [54]:
def decode(ids, OOVs):
    """Turn token IDs back into a readable string.

    Args:
        ids: sequence of token IDs; IDs from len(itos) upwards refer to
            entries of `OOVs`.
        OOVs: the out-of-vocabulary words of the sentence, in slot order.

    Returns:
        The tokens joined with " | "; tokens taken from `OOVs` carry a
        "(COPY)" suffix.
    """
    copy_labels = [OOV + "(COPY)" for OOV in OOVs]
    # Sentence-specific vocabulary: the shared one followed by the OOV slots.
    extended_itos = itos + copy_labels
    return " | ".join(extended_itos[token_id] for token_id in ids)

In [70]:
# Sanity check: decode the sample sentence wrapped in <sos> (1) and <eos> (2).
decode([1,639, 1000, 1001, 12, 29, 1002,2], ['variable', 'student_names', "'foo bar baz'"])

"<sos> | create | variable(COPY) | student_names(COPY) | with | string | 'foo bar baz'(COPY) | <eos>"

In [71]:
# Fields for already-numericalized data; the special-token IDs match `stoi`.
TEXT_FIELD = Field(sequential=True, use_vocab=False, unk_token=0, init_token=1, eos_token=2, pad_token=3)
OOV_TEXT_FIELD = Field(sequential=True, use_vocab=False, pad_token=3)

# Corpus-wide numbering of OOV words, so that the per-sentence OOV lists can be
# stored as integers. The first assigned ID is OOV_starter_count + 1.
OOV_stoi = {}
OOV_itos = {}
OOV_starter_count = 30000
OOV_count = OOV_starter_count

# The same attribute-name -> (name, Field) mapping is used for every example.
example_fields = {
    "src": ("src", TEXT_FIELD),
    "tgt": ("tgt", TEXT_FIELD),
    "OOVs": ("OOVs", OOV_TEXT_FIELD),
}

examples = []

for src, tgt in data:
    src_ids, OOVs = encode_input(src)
    tgt_ids = encode_output(tgt, OOVs)
    OOV_ids = []

    for OOV in OOVs:
        if OOV not in OOV_stoi:
            # first time this word is seen anywhere in the corpus
            OOV_count += 1
            OOV_stoi[OOV] = OOV_count
            OOV_itos[OOV_count] = OOV
        OOV_ids.append(OOV_stoi[OOV])

    examples.append(
        torchtext.data.Example.fromdict(
            {"src": src_ids, "tgt": tgt_ids, "OOVs": OOV_ids},
            fields=example_fields,
        )
    )

In [72]:
# Wrap the encoded examples in a Dataset, using the same Fields the examples
# were built with.
dataset = torchtext.data.Dataset(examples,fields={"src":TEXT_FIELD, "tgt":TEXT_FIELD, "OOVs":OOV_TEXT_FIELD})
# Split with ratios 0.9 / 0.1; the smaller part is used as the validation set.
train_dataset, val_dataset = dataset.split([0.9,0.1])

In [74]:
batch_size = 16

# Training batches. The sort key is the combined source + target length.
# NOTE(review): with repeat=True the iterator is expected to keep yielding
# batches across epochs, so the consuming loop has to decide when to stop --
# confirm against the torchtext BucketIterator documentation.
train_iterator = BucketIterator(
    train_dataset,
    batch_size = batch_size,
    repeat=True,
    shuffle=True,
    sort_key = lambda x: len(x.src)+len(x.tgt),
    device = device)

# Validation batches: same batch size and sort key, with the iterator's
# default repeat/shuffle settings.
valid_iterator = BucketIterator(val_dataset,
    batch_size = batch_size,
    sort_key = lambda x: len(x.src)+len(x.tgt),
    device = device)

# The iterator generates batches of similarly sized sequences, padded to a common length; a batch has shape [seq_length, batch_size].

# Decode and print one example of the first training batch as a sanity check.
for i, batch in enumerate(train_iterator):
    # Column of the batch to show; the batch must hold at least 6 examples.
    idx = 5
#     print([SRC_TEXT.vocab.itos[id] for id in batch.src.cpu().numpy()[:,idx]])
    # Map the corpus-wide OOV IDs of this example back to words, skipping padding.
    OOVs = [OOV_itos[OOV] for OOV in batch.OOVs.cpu()[:,idx].tolist() if OOV != 3] # 3 is the <pad> token
    src_ids = batch.src.cpu()[:,idx].tolist()
    # Keep everything up to and including the first <eos> (ID 2), dropping the padding.
    src_ids = src_ids[:src_ids.index(2)+1]
    tgt_ids = batch.tgt.cpu()[:,idx].tolist()
    tgt_ids = tgt_ids[:tgt_ids.index(2)+1]
    
    print("SOURCE:",decode(src_ids, OOVs))
    print("TARGET:",decode(tgt_ids, OOVs))
    # Only the first batch is needed.
    break

SOURCE: <sos> | if | status | is | not | equal | to | STATUS_OK(COPY) | , | <eos>
TARGET: <sos> | if | status | ! | = | STATUS_OK(COPY) | : | <eos>
