In [None]:
import numpy as np
import logging
import os
import time
import numpy as np
import glob
import sys
import math
import random

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from nltk import ngrams
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

from collections import Counter
import editdistance

In [None]:
CUDA = torch.cuda.is_available()
CUDA #Boolean variable to check the presence of CUDA

False

In [None]:
class CorpusSearcher(object):
    """
      To create corpuses and retreive attributes closest to the given query.

    """

    def __init__(self, query_corpus, key_corpus, value_corpus, vectorizer, make_binary=True):
        self.vectorizer = vectorizer
        self.vectorizer.fit(key_corpus)

        self.query_corpus = query_corpus
        self.key_corpus = key_corpus
        self.value_corpus = value_corpus
        
        # rows = docs, cols = features
        self.key_corpus_matrix = self.vectorizer.transform(key_corpus)
        if make_binary:
            # make binary
            self.key_corpus_matrix = (self.key_corpus_matrix != 0).astype(int)

        
    def most_similar(self, key_idx, n=10):
        """ 
          Score the query against the key corpus and return the values corresponding to the 
          top N scores from the value corpus.
          Used for retrieving attributes from sentences with similar content to the query sentence.

          Parameter:
          - key_idx : id of the attribute in the query corpus
          - n : number of closest attributes to be returned

          Returns:
          - selected : list of closest N values in the value corpus with their index and score 

        """

        query = self.query_corpus[key_idx]
        query_vec = self.vectorizer.transform([query])

        scores = np.dot(self.key_corpus_matrix, query_vec.T)
        scores = np.squeeze(scores.toarray())
        scores_indices = zip(scores, range(len(scores)))
        selected = sorted(scores_indices, reverse=True)[:n]

        # use the retrieved index 'i' to pick examples from the VALUE corpus
        selected = [ (self.value_corpus[i], i, score) for (score, i) in selected ]

        return selected

In [None]:
def build_vocab_maps(vocab_file):
    """
      Creates and returns two dictionaries, one to map vocabulary words to unique ids and 
      one to map the unique ids back to the vocabulary words.
      
      Parameters: 
      - vocab_file : os path to vocabulary file
      
      Returns:
      - tok_to_id : dictionary which fetches id from token key
      - id_to_tok : dictionary which fetches token from id key

    """

    assert os.path.exists(vocab_file), "The vocab file %s does not exist" % vocab_file
    unk = '<unk>'
    pad = '<pad>'
    sos = '<s>'
    eos = '</s>'

    lines = [x.strip() for x in open(vocab_file)]

    assert lines[0] == unk and lines[1] == pad and lines[2] == sos and lines[3] == eos, \
        "The first words in %s are not %s, %s, %s, %s" % (vocab_file, unk, pad, sos, eos)

    tok_to_id = {}
    id_to_tok = {}
    for i, vi in enumerate(lines):
        tok_to_id[vi] = i
        id_to_tok[i] = vi

    # appending an extra vocab item for empty attribute lines
    empty_tok_idx =  len(id_to_tok)
    tok_to_id['<empty>'] = empty_tok_idx
    id_to_tok[empty_tok_idx] = '<empty>'

    return tok_to_id, id_to_tok

In [None]:
def extract_attributes(line, attribute_vocab, use_ngrams=False):
    """
      Split the given sentence into its attribute markers and attribute-independent content.
      This is the 'DELETE' process in the paper

      Parameters:
      - line: the given sentence
      - attribute_vocab: the complete vocabulary of attributes
      - use_ngrams: boolean, True if ngrams of the sentence should be checked instead of individual words

      Returns:
      - line
      - content : attribute-independent content remaining after attributes are deleted
      - attribute_markers : list of attributes from the given sentence

    """

    if use_ngrams:
        # generate all ngrams for the sentence
        grams = []
        for i in range(1, 5):
            try:
                i_grams = [ " ".join(gram) for gram in ngrams(line, i) ]
                grams.extend(i_grams)
            except RuntimeError:
                continue

        # filter ngrams by whether they appear in the attribute_vocab
        candidate_markers = [ (gram, attribute_vocab[gram]) for gram in grams if gram in attribute_vocab ]

        # sort attribute markers by score and prepare for 'deletion'
        content = " ".join(line)
        candidate_markers.sort(key=lambda x: x[1], reverse=True)

        candidate_markers = [marker for (marker, score) in candidate_markers]
        
        # seperate attributes and attribute-independent content
        attribute_markers = []
        for marker in candidate_markers:
            if marker in content:
                attribute_markers.append(marker)
                content = content.replace(marker, "")
        content = content.split()
        
    else:
        # same thing, but without the use of ngrams
        content = []
        attribute_markers = []
        for tok in line:
            if tok in attribute_vocab:
                attribute_markers.append(tok)
            else:
                content.append(tok)

    return line, content, attribute_markers




In [None]:
def read_nmt_data(src, config, tgt, attribute_vocab, train_src=None, train_tgt=None, ngram_attributes=False):
    """
      Initializer function to read data from files and store it in 'src' and 'tgt'

      Parameters:
      - src, tgt : os paths to files containging source and target sentences respectively
      - config : contains path to vocab files
      - attribute_vocab : path to attribute vocabulary 
      - train_src, train_tgt : 
      - ngram_attributes : boolean, if True then attributes are ngrams instead of direct mappings

      Returns:
      - src, tgt: dictionaries containing source and target date respectively

    """
    
    if ngram_attributes:
        # read attribute vocab as a dictionary mapping attributes to scores
        pre_attr = {}
        post_attr = {}
        with open(attribute_vocab) as attr_file:
            next(attr_file) # skip header
            for line in attr_file:
                parts = line.strip().split()
                pre_salience = float(parts[-2])
                post_salience = float(parts[-1])
                attr = ' '.join(parts[:-2])
                pre_attr[attr] = pre_salience
                post_attr[attr] = post_salience
    else:
        pre_attr = post_attr = set([x.strip() for x in open(attribute_vocab)])

    src_lines = [l.strip().lower().split() for l in open(src, 'r')]
    src_lines, src_content, src_attribute = list(zip(*[extract_attributes(line, pre_attr, pre_attr) for line in src_lines]))
    src_tok2id, src_id2tok = build_vocab_maps(config['data']['src_vocab'])

    # during train time, just pick attributes that are close to the current (using word distance)
    # we don't need to do the TFIDF thing with the source because test is strictly in the src => tgt direction. 
    # But we still measure both src and tgt dist because training is bidirectional
    # (i.e., we're autoencoding src and tgt sentences during training)

    src_dist_measurer = CorpusSearcher(
        query_corpus=[' '.join(x) for x in src_attribute],
        key_corpus=[' '.join(x) for x in src_attribute],
        value_corpus=[' '.join(x) for x in src_attribute],
        vectorizer=CountVectorizer(vocabulary=src_tok2id),
        make_binary=True
    )
    src = {
        'data': src_lines, 'content': src_content, 'attribute': src_attribute,
        'tok2id': src_tok2id, 'id2tok': src_id2tok, 'dist_measurer': src_dist_measurer
    }

    tgt_lines = [l.strip().lower().split() for l in open(tgt, 'r')] if tgt else None
    tgt_lines, tgt_content, tgt_attribute = list(zip(*[extract_attributes(line, post_attr, post_attr) for line in tgt_lines]))
    tgt_tok2id, tgt_id2tok = build_vocab_maps(config['data']['tgt_vocab'])

    # during train time, just pick attributes that are close to the current (using word distance)
    # since this is only used to noise the inputs

    if train_src is None or train_tgt is None:
        tgt_dist_measurer = CorpusSearcher(
            query_corpus=[' '.join(x) for x in tgt_attribute],
            key_corpus=[' '.join(x) for x in tgt_attribute],
            value_corpus=[' '.join(x) for x in tgt_attribute],
            vectorizer=CountVectorizer(vocabulary=tgt_tok2id),
            make_binary=True
        )

    # during test time, scan through train content (using tfidf) and retrieve corresponding attributes
    
    else:
        tgt_dist_measurer = CorpusSearcher(
            query_corpus=[' '.join(x) for x in src_content],
            key_corpus=[' '.join(x) for x in train_tgt['content']],
            value_corpus=[' '.join(x) for x in train_tgt['attribute']],
            vectorizer=TfidfVectorizer(vocabulary=tgt_tok2id),
            make_binary=False
        )
    tgt = {
        'data': tgt_lines, 'content': tgt_content, 'attribute': tgt_attribute,
        'tok2id': tgt_tok2id, 'id2tok': tgt_id2tok, 'dist_measurer': tgt_dist_measurer
    }

    return src, tgt


In [None]:
def sample_replace(lines, dist_measurer, sample_rate, corpus_idx):
    """
      Replace sample_rate * batch_size lines with nearby examples (according to dist_measurer).
      This is not exactly the same as the paper (words are shared during train) but its essentially the same idea and easier to implement.

      Parameters:
      - lines : list of sentences
      - dist_measurer : object of CorpusSearcher
      - sample_rate : percentage of samples to be replaced with examples close to given one
      - corpus_idx : given sample

      Returns:
      - out: list of sentences after changing

    """

    out = [None for _ in range(len(lines))]
    for i, line in enumerate(lines):
        if random.random() < sample_rate:
            # ignore first line since top match is the current line
            sims = dist_measurer.most_similar(corpus_idx + i)[1:]
            
            try:
                line = next( (
                    tgt_attr.split() for tgt_attr, _, _ in sims
                    if set(tgt_attr.split()) != set(line[1:-1]) # and tgt_attr != ''
                ) )
            # all the matches are blanks
            except StopIteration:
                line = []
            line = ['<s>'] + line + ['</s>']

        # special empty token for empty sequences (just start/end tok)
        if len(line) == 2:
            line.insert(1, '<empty>')
        out[i] = line

    return out


In [None]:
def get_minibatch(lines, tok2id, index, batch_size, max_len, sort=False, idx=None, dist_measurer=None, sample_rate=0.0):
    """
      To prepare minibatch. No sorting since we care about the order of outputs. Also acts as a helper function to implement 'Retrieve' 

      Parameters:
      - lines : list of sentences
      - tok2id : dictionary which fetches id from token key
      - index : list of indices for which output is wanted
      - batch_size : size of the minibatch
      - max_len : maximum allowed length for sentences
      - sort : boolean, False by default
      - idx : optional; if passed the fn will only return values for those ids
      - dist_measurer : distance measuring function, used for sample_replace
      - sample rate : between 0 and 1

      Returns:
      - input_lines : tokenized input lines
      - output_lines : tokenized output lines
      - lens : list of lengths of sentences
      - mask : 0/1 masking for each sentence
      - idx : idx
    
    """

    lines = [
        ['<s>'] + line[:max_len] + ['</s>']
        for line in lines[index:index + batch_size]
    ]

    if dist_measurer is not None:
        lines = sample_replace(lines, dist_measurer, sample_rate, index)

    lens = [len(line) - 1 for line in lines]
    max_len = max(lens)

    unk_id = tok2id['<unk>']
    input_lines = [
        [tok2id.get(w, unk_id) for w in line[:-1]] +
        [tok2id['<pad>']] * (max_len - len(line) + 1)
        for line in lines
    ]

    output_lines = [
        [tok2id.get(w, unk_id) for w in line[1:]] +
        [tok2id['<pad>']] * (max_len - len(line) + 1)
        for line in lines
    ]

    mask = [
        ([1] * l) + ([0] * (max_len - l))
        for l in lens
    ]

    if sort:
        idx = [x[0] for x in sorted(enumerate(lens), key=lambda x: -x[1])]

    if idx is not None:
        lens = [lens[j] for j in idx]
        input_lines = [input_lines[j] for j in idx]
        output_lines = [output_lines[j] for j in idx]
        mask = [mask[j] for j in idx]

    input_lines = Variable(torch.LongTensor(input_lines))
    output_lines = Variable(torch.LongTensor(output_lines))
    mask = Variable(torch.FloatTensor(mask))

    if CUDA:
        input_lines = input_lines.cuda()
        output_lines = output_lines.cuda()
        mask = mask.cuda()

    return input_lines, output_lines, lens, mask, idx




In [None]:
def minibatch(src, tgt, idx, batch_size, max_len, model_type, is_test=False):
    """
      Fetches the inputs, outputs and attributes using get_minibatch depending on model type.

      Parameters:
      - src, tgt : source and target data dictionaries
      - idx : list of indices for which output is wanted
      - batch_size : size of the minibatch
      - max_len : maximum allowed length for each sentence
      - model_type: determines which model is followed:
          - 'delete' for DeleleOnly
          - 'delete_retrieve' for DeleteAndRetrieve
          - 'seq2seq' for TemplateBased

      Returns:
      - inputs : input_lines from get_minibatch
      - outputs : output_lines from get_minibatch
      - attributes : attributes generated

    """
    if not is_test:
        use_src = random.random() < 0.5
        in_dataset = src if use_src else tgt
        out_dataset = in_dataset
        attribute_id = 0 if use_src else 1
    else:
        in_dataset = src
        out_dataset = tgt
        attribute_id = 1

    if model_type == 'delete':
        inputs = get_minibatch(in_dataset['content'], in_dataset['tok2id'], idx, batch_size, max_len, sort=True)
        outputs = get_minibatch(out_dataset['data'], out_dataset['tok2id'], idx, batch_size, max_len, idx=inputs[-1])

        # since true length could be less than batch_size at end of data
        batch_len = len(outputs[0])
        attribute_ids = [attribute_id for _ in range(batch_len)]
        attribute_ids = Variable(torch.LongTensor(attribute_ids))
        if CUDA:
            attribute_ids = attribute_ids.cuda()

        attributes = (attribute_ids, None, None, None, None)

    elif model_type == 'delete_retrieve':
        inputs =  get_minibatch(in_dataset['content'], in_dataset['tok2id'], idx, batch_size, max_len, sort=True)
        outputs = get_minibatch(out_dataset['data'], out_dataset['tok2id'], idx, batch_size, max_len, idx=inputs[-1])

        if is_test:
            # This dist_measurer has sentence attributes for values, so setting 
            # the sample rate to 1 means the output is always replaced with an
            # attribute. So we're still getting attributes even though
            # the method is being fed content. 
            attributes =  get_minibatch(
                in_dataset['content'], out_dataset['tok2id'], idx, 
                batch_size, max_len, idx=inputs[-1],
                dist_measurer=out_dataset['dist_measurer'], sample_rate=1.0)
        else:
            attributes =  get_minibatch(
                out_dataset['attribute'], out_dataset['tok2id'], idx, 
                batch_size, max_len, idx=inputs[-1],
                dist_measurer=out_dataset['dist_measurer'], sample_rate=0.1)
            
        attributes = (None, None, None, None, None)

    else:
        raise Exception('Unsupported model_type: %s' % model_type)

    return inputs, attributes, outputs


def unsort(arr, idx):
    """
      Unsort a list given a list of each element's original index
    """
    unsorted_arr = arr[:]
    for i, origin in enumerate(idx):
        unsorted_arr[origin] = arr[i]
    return unsorted_arr

In [None]:
#Encoder
class Encoder(nn.Module):
    """ 
      Bi-directional LSTM to encode sentence+attributes.
    """

    def __init__(self, emb_dim, hidden_dim, layers, bidirectional, dropout, pack=True):
        super(Encoder, self).__init__()

        self.num_directions = 2 if bidirectional else 1

        self.lstm = nn.LSTM(
            emb_dim,
            hidden_dim // self.num_directions,
            layers,
            bidirectional=bidirectional,
            batch_first=True,
            dropout=dropout)

        self.pack = pack

    def init_state(self, input):
        batch_size = input.size(0) 
        h0 = Variable(torch.zeros(
            self.lstm.num_layers * self.num_directions,
            batch_size,
            self.lstm.hidden_size
        ), requires_grad=False)
        c0 = Variable(torch.zeros(
            self.lstm.num_layers * self.num_directions,
            batch_size,
            self.lstm.hidden_size
        ), requires_grad=False)

        if CUDA:
            return h0.cuda(), c0.cuda()
        else:
            return h0, c0


    def forward(self, src_embedding, srclens, srcmask, temp=1):
        h0, c0 = self.init_state(src_embedding)

        if self.pack:
            inputs = pack_padded_sequence(src_embedding, srclens, batch_first=True)
        else:
            inputs = src_embedding

        outputs, (h_final, c_final) = self.lstm(inputs, (h0, c0))

        if self.pack:
            outputs, _ = pad_packed_sequence(outputs, batch_first=True)

        return outputs, (h_final, c_final)


In [None]:
#Decoders
class BilinearAttention(nn.Module):
    """ 
      Bilinear attention layer: score(H_j, q) = H_j^T W_a q (where W_a = self.in_projection)
    """

    def __init__(self, hidden):
        super(BilinearAttention, self).__init__()
        self.in_projection = nn.Linear(hidden, hidden, bias=False)
        self.softmax = nn.Softmax()
        self.out_projection = nn.Linear(hidden * 2, hidden, bias=False)
        self.tanh = nn.Tanh()

    def forward(self, query, keys, srcmask=None, values=None):
        """
            query: [batch, hidden]
            keys: [batch, len, hidden]
            values: [batch, len, hidden] (optional, if none will = keys)

            compare query to keys, use the scores to find weighted sum of values
            if no value is specified, then values = keys
        """
        
        if values is None:
            values = keys
    
        # [Batch, Hidden, 1]
        decoder_hidden = self.in_projection(query).unsqueeze(2)
        # [Batch, Source length]
        attn_scores = torch.bmm(keys, decoder_hidden).squeeze(2)
        if srcmask is not None:
            attn_scores = attn_scores.masked_fill(srcmask, -float('inf'))
            
        attn_probs = self.softmax(attn_scores)
        # [Batch, 1, source length]
        attn_probs_transposed = attn_probs.unsqueeze(1)
        # [Batch, hidden]
        weighted_context = torch.bmm(attn_probs_transposed, values).squeeze(1)

        context_query_mixed = torch.cat((weighted_context, query), 1)
        context_query_mixed = self.tanh(self.out_projection(context_query_mixed))

        return weighted_context, context_query_mixed, attn_probs


In [None]:
class AttentionalLSTM(nn.Module):
    """
      A LSTM cell with attention.
    """

    def __init__(self, input_dim, hidden_dim, config, attention):
        super(AttentionalLSTM, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = 1
        self.use_attention = attention
        self.config = config
        self.cell = nn.LSTMCell(input_dim, hidden_dim)

        if self.use_attention:
            self.attention_layer = BilinearAttention(hidden_dim)


    def forward(self, input, hidden, ctx, srcmask, kb=None):
        input = input.transpose(0, 1)

        output = []
        timesteps = range(input.size(0))
        for i in timesteps:
            hy, cy = self.cell(input[i], hidden)
            if self.use_attention:
                _, h_tilde, alpha = self.attention_layer(hy, ctx, srcmask)
                hidden = h_tilde, cy
                output.append(h_tilde)
            else: 
                hidden = hy, cy
                output.append(hy)

        # combine outputs, and get into [time, batch, dim]
        output = torch.cat(output, 0).view(input.size(0), *output[0].size())

        output = output.transpose(0, 1)

        return output, hidden

In [None]:
class StackedAttentionLSTM(nn.Module):
    """ 
      Stacked LSTM with input feeding.
    """

    def __init__(self, cell_class=AttentionalLSTM, config=None):
        super(StackedAttentionLSTM, self).__init__()
        self.options=config['model']

        self.dropout = nn.Dropout(self.options['dropout'])

        self.layers = []
        input_dim = self.options['emb_dim']
        hidden_dim = self.options['tgt_hidden_dim']
        for i in range(self.options['tgt_layers']):
            layer = cell_class(input_dim, hidden_dim, config, config['model']['attention'])
            self.add_module('layer_%d' % i, layer)
            self.layers.append(layer)
            input_dim = hidden_dim

    def forward(self, input, hidden, ctx, srcmask, kb=None):
        h_final, c_final = [], []
        for i, layer in enumerate(self.layers):
            output, (h_final_i, c_final_i) = layer(input, hidden, ctx, srcmask, kb)

            input = output

            if i != len(self.layers):
                input = self.dropout(input)

            h_final.append(h_final_i)
            c_final.append(c_final_i)

        h_final = torch.stack(h_final)
        c_final = torch.stack(c_final)

        return input, (h_final, c_final)

In [None]:
def get_latest_ckpt(ckpt_dir):
    """
      Fetch latest checkpoint.
    """
    ckpts = glob.glob(os.path.join(ckpt_dir, '*.ckpt'))
    # if no checkpoints are found, continue with fresh parameters
    if len(ckpts) == 0:
        return -1, None
    ckpts = map(lambda ckpt: (int(ckpt.split('.')[1]), ckpt), ckpts)
    # get most recent checkpoint
    epoch, ckpt_path = sorted(ckpts)[-1]
    return epoch, ckpt_path


def attempt_load_model(model, checkpoint_dir=None, checkpoint_path=None):
    """
      Load model from latest checkpoint (get_latest_ckpt).
    """
    assert checkpoint_dir or checkpoint_path

    if checkpoint_dir:
        epoch, checkpoint_path = get_latest_ckpt(checkpoint_dir)
    else:
        epoch = int(checkpoint_path.split('.')[-2])

    if checkpoint_path:
        model.load_state_dict(torch.load(checkpoint_path))
        print('Load from %s sucessful!' % checkpoint_path)
        return model, epoch + 1
    else:
        return model, 0



In [None]:
class SeqModel(nn.Module):
    """
      Sequential Model
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, pad_id_src, pad_id_tgt, config=None):
        super(SeqModel, self).__init__()
        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        self.pad_id_src = pad_id_src
        self.pad_id_tgt = pad_id_tgt
        self.batch_size = config['data']['batch_size']
        self.config = config
        self.options = config['model']
        self.model_type = config['model']['model_type']

        self.src_embedding = nn.Embedding(self.src_vocab_size, self.options['emb_dim'], self.pad_id_src)

        if self.config['data']['share_vocab']:
            self.tgt_embedding = self.src_embedding
        else:
            self.tgt_embedding = nn.Embedding(
                self.tgt_vocab_size,
                self.options['emb_dim'],
                self.pad_id_tgt)

        if self.options['encoder'] == 'lstm':
            self.encoder = Encoder(
                self.options['emb_dim'],
                self.options['src_hidden_dim'],
                self.options['src_layers'],
                self.options['bidirectional'],
                self.options['dropout'])
            self.ctx_bridge = nn.Linear(
                self.options['src_hidden_dim'],
                self.options['tgt_hidden_dim'])

        else:
            raise NotImplementedError('unknown encoder type')
        
        if self.model_type == 'delete':
            self.attribute_embedding = nn.Embedding(num_embeddings=2, embedding_dim=self.options['emb_dim'])
            attr_size = self.options['emb_dim']

        elif self.model_type == 'delete_retrieve':
            self.attribute_encoder = Encoder(
                self.options['emb_dim'],
                self.options['src_hidden_dim'],
                self.options['src_layers'],
                self.options['bidirectional'],
                self.options['dropout'],
                pack=False)
            attr_size = self.options['src_hidden_dim']

        else:
            raise NotImplementedError('unknown model type')

        self.c_bridge = nn.Linear(attr_size + self.options['src_hidden_dim'], self.options['tgt_hidden_dim'])
        self.h_bridge = nn.Linear(attr_size + self.options['src_hidden_dim'], self.options['tgt_hidden_dim'])

        self.decoder = StackedAttentionLSTM(config=config)

        self.output_projection = nn.Linear(self.options['tgt_hidden_dim'], tgt_vocab_size)

        self.softmax = nn.Softmax(dim=-1)

        self.init_weights()

    def init_weights(self):
        """Initialize weights."""
        initrange = 0.1
        self.src_embedding.weight.data.uniform_(-initrange, initrange)
        self.tgt_embedding.weight.data.uniform_(-initrange, initrange)
        self.h_bridge.bias.data.fill_(0)
        self.c_bridge.bias.data.fill_(0)
        self.output_projection.bias.data.fill_(0)

    def forward(self, input_src, input_tgt, srcmask, srclens, input_attr, attrlens, attrmask):
        src_emb = self.src_embedding(input_src)

        srcmask = (1-srcmask).byte()

        src_outputs, (src_h_t, src_c_t) = self.encoder(src_emb, srclens, srcmask)

        if self.options['bidirectional']:
            h_t = torch.cat((src_h_t[-1], src_h_t[-2]), 1)
            c_t = torch.cat((src_c_t[-1], src_c_t[-2]), 1)
        else:
            h_t = src_h_t[-1]
            c_t = src_c_t[-1]

        src_outputs = self.ctx_bridge(src_outputs)

        if self.model_type == 'delete':
            a_ht = self.attribute_embedding(input_attr)
            c_t = torch.cat((c_t, a_ht), -1)
            h_t = torch.cat((h_t, a_ht), -1)

        elif self.model_type == 'delete_retrieve':
            attr_emb = self.src_embedding(input_attr)
            _, (a_ht, a_ct) = self.attribute_encoder(attr_emb, attrlens, attrmask)
            if self.options['bidirectional']:
                a_ht = torch.cat((a_ht[-1], a_ht[-2]), 1)
                a_ct = torch.cat((a_ct[-1], a_ct[-2]), 1)

            h_t = torch.cat((h_t, a_ht), -1)
            c_t = torch.cat((c_t, a_ct), -1)
            
        c_t = self.c_bridge(c_t)
        h_t = self.h_bridge(h_t)

        tgt_emb = self.tgt_embedding(input_tgt)
        tgt_outputs, (_, _) = self.decoder(tgt_emb, (h_t, c_t), src_outputs, srcmask)

        tgt_outputs_reshape = tgt_outputs.contiguous().view(
            tgt_outputs.size()[0] * tgt_outputs.size()[1],
            tgt_outputs.size()[2])
        decoder_logit = self.output_projection(tgt_outputs_reshape)
        decoder_logit = decoder_logit.view(
            tgt_outputs.size()[0],
            tgt_outputs.size()[1],
            decoder_logit.size()[1])

        probs = self.softmax(decoder_logit)

        return decoder_logit, probs

    def count_params(self):
        n_params = 0
        for param in self.parameters():
            n_params += np.prod(param.data.cpu().numpy().shape)
        return n_params

In [None]:
# BLEU functions

def bleu_stats(hypothesis, reference):
    """
      Compute statistics for BLEU.
    """
    
    stats = []
    stats.append(len(hypothesis))
    stats.append(len(reference))
    for n in range(1, 5):
        s_ngrams = Counter(
            [tuple(hypothesis[i:i + n]) for i in range(len(hypothesis) + 1 - n)]
        )
        r_ngrams = Counter(
            [tuple(reference[i:i + n]) for i in range(len(reference) + 1 - n)]
        )
        stats.append(max([sum((s_ngrams & r_ngrams).values()), 0]))
        stats.append(max([len(hypothesis) + 1 - n, 0]))
    return stats

def bleu(stats):
    """
      Compute BLEU given n-gram statistics.
    """
    if len(list(filter(lambda x: x == 0, stats))) > 0:
        return 0
    (c, r) = stats[:2]
    log_bleu_prec = sum([math.log(float(x) / y) for x, y in zip(stats[2::2], stats[3::2])]) / 4.
    return math.exp(min([0, 1 - float(r) / c]) + log_bleu_prec)

def get_bleu(hypotheses, reference):
    """
      Get validation BLEU score for dev set.
    """
    stats = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
    for hyp, ref in zip(hypotheses, reference):
        stats += np.array(bleu_stats(hyp, ref))
    return 100 * bleu(stats)

def get_edit_distance(hypotheses, reference):
    ed = 0
    for hyp, ref in zip(hypotheses, reference):
        ed += editdistance.eval(hyp, ref)

    return ed * 1.0 / len(hypotheses)


def decode_minibatch(max_len, start_id, model, src_input, srclens, srcmask, aux_input, auxlens, auxmask):
    """ 
      Decoding minibatch
    """
    # Initialize target with <s> for every sentence
    tgt_input = Variable(torch.LongTensor([[start_id] for i in range(src_input.size(0))]))
    if CUDA:
        tgt_input = tgt_input.cuda()

    for i in range(max_len):
        # run input through the model
        decoder_logit, word_probs = model(src_input, tgt_input, srcmask, srclens, aux_input, auxmask, auxlens)
        decoder_argmax = word_probs.data.cpu().numpy().argmax(axis=-1)
        # select the predicted "next" tokens, attach to target-side inputs
        next_preds = Variable(torch.from_numpy(decoder_argmax[:, -1]))
        if CUDA:
            next_preds = next_preds.cuda()
        tgt_input = torch.cat((tgt_input, next_preds.unsqueeze(1)), dim=1)

    return tgt_input

def decode_dataset(model, src, tgt, config):
    """
      Evaluate model.
    """
    inputs = []
    preds = []
    auxs = []
    ground_truths = []

    for j in range(0, len(src['data']), config['data']['batch_size']):
        sys.stdout.write("\r%s/%s..." % (j, len(src['data'])))
        sys.stdout.flush()

        # get batch
        input_content, input_aux, output = minibatch(
            src, tgt, j, 
            config['data']['batch_size'], 
            config['data']['max_len'], 
            config['model']['model_type'],
            is_test=True)
        input_lines_src, output_lines_src, srclens, srcmask, indices = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_lines_tgt, output_lines_tgt, _, _, _ = output

        tgt_pred = decode_minibatch(
            config['data']['max_len'], tgt['tok2id']['<s>'], 
            model, input_lines_src, srclens, srcmask,
            input_ids_aux, auxlens, auxmask)

        # convert seqs to tokens
        def ids_to_toks(tok_seqs, id2tok):
            out = []
            # take off the gpu
            tok_seqs = tok_seqs.cpu().numpy()
            # convert to toks, cut off at </s>, delete any start tokens (preds were kickstarted w them)
            for line in tok_seqs:
                toks = [id2tok[x] for x in line]
                if '<s>' in toks: 
                    toks.remove('<s>')
                cut_idx = toks.index('</s>') if '</s>' in toks else len(toks)
                out.append( toks[:cut_idx] )
            # unsort
            out = unsort(out, indices)
            return out

        # convert inputs/preds/targets/aux to human-readable form
        inputs += ids_to_toks(output_lines_src, src['id2tok'])
        preds += ids_to_toks(tgt_pred, tgt['id2tok'])
        ground_truths += ids_to_toks(output_lines_tgt, tgt['id2tok'])
        
        if config['model']['model_type'] == 'delete':
            auxs += [[str(x)] for x in input_ids_aux.data.cpu().numpy()] 
        elif config['model']['model_type'] == 'delete_retrieve':
            auxs += ids_to_toks(input_ids_aux, tgt['id2tok'])
        elif config['model']['model_type'] == 'seq2seq':
            auxs += ['None' for _ in range(len(tgt_pred))]

    return inputs, preds, ground_truths, auxs


def inference_metrics(model, src, tgt, config):
    """ 
      Decode and evaluate BLEU scores. 
    """

    inputs, preds, ground_truths, auxs = decode_dataset(
        model, src, tgt, config)

    bleu = get_bleu(preds, ground_truths)
    edit_distance = get_edit_distance(preds, ground_truths)

    inputs = [' '.join(seq) for seq in inputs]
    preds = [' '.join(seq) for seq in preds]
    ground_truths = [' '.join(seq) for seq in ground_truths]
    auxs = [' '.join(seq) for seq in auxs]

    return bleu, edit_distance, inputs, preds, ground_truths, auxs


def evaluate_lpp(model, src, tgt, config):
    """ 
      Evaluate log perplexity WITHOUT decoding (i.e., with teacher forcing)
    """
    
    weight_mask = torch.ones(len(tgt['tok2id']))
    if CUDA:
        weight_mask = weight_mask.cuda()
    weight_mask[tgt['tok2id']['<pad>']] = 0
    loss_criterion = nn.CrossEntropyLoss(weight=weight_mask)
    if CUDA:
        loss_criterion = loss_criterion.cuda()

    losses = []
    for j in range(0, len(src['data']), config['data']['batch_size']):
        sys.stdout.write("\r%s/%s..." % (j, len(src['data'])))
        sys.stdout.flush()

        # get batch
        input_content, input_aux, output = minibatch(
            src, tgt, j, 
            config['data']['batch_size'], 
            config['data']['max_len'], 
            config['model']['model_type'],
            is_test=True)
        input_lines_src, _, srclens, srcmask, _ = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_lines_tgt, output_lines_tgt, _, _, _ = output

        decoder_logit, decoder_probs = model(
            input_lines_src, input_lines_tgt, srcmask, srclens,
            input_ids_aux, auxlens, auxmask)

        loss = loss_criterion(
            decoder_logit.contiguous().view(-1, len(tgt['tok2id'])),
            output_lines_tgt.view(-1)
        )
        losses.append(loss.item())

    return np.mean(losses)


In [None]:
Bleu = True
overfit = False

In [None]:
# config file which has our fine-tuned model
config = {
  "training": {
    "optimizer": "adam",
    "learning_rate": 0.0003,
    "max_norm": 3.0,
    "epochs": 45,
    "batches_per_report": 200,
    "batches_per_sampling": 500,
    "random_seed": 1
  },
  "data": {
    "src": "/content/drive/MyDrive/data/yelp/sentiment.train.0",
    "tgt": "/content/drive/MyDrive/data/yelp/sentiment.train.1",
    "src_test": "/content/drive/MyDrive/data/yelp/reference.test.0",
    "tgt_test": "/content/drive/MyDrive/data/yelp/reference.test.1",
    "src_vocab": "/content/drive/MyDrive/data/yelp/vocab",
    "tgt_vocab": "/content/drive/MyDrive/data/yelp/vocab",
    "share_vocab": True,
    "attribute_vocab": "/content/drive/MyDrive/data/yelp/ngram.15.attribute",
    "ngram_attributes": True,
    "batch_size": 256,
    "max_len": 50,
    "working_dir": "/content/drive/MyDrive/data/working_dir"
  },
    "model": {
        "model_type": "delete",
        "emb_dim": 128,
        "attention": False,
        "encoder": "lstm",
        "src_hidden_dim": 512,
        "src_layers": 1,
        "bidirectional": True,
        "tgt_hidden_dim": 512,
        "tgt_layers": 1,
        "decode": "greedy",
        "dropout": 0.2
    }
}


In [None]:
train_losses = []
scores = []

In [None]:
working_dir = config['data']['working_dir']

if not os.path.exists(working_dir):
    os.makedirs(working_dir)

# set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename='%s/train_log' % working_dir,
)

console = logging.StreamHandler()
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console.setFormatter(formatter)
logging.getLogger('').addHandler(console)

logging.info('Reading data ...')
src, tgt = read_nmt_data(
    src=config['data']['src'],
    config=config,
    tgt=config['data']['tgt'],
    attribute_vocab=config['data']['attribute_vocab'],
    ngram_attributes=config['data']['ngram_attributes']
)

src_test, tgt_test = read_nmt_data(
    src=config['data']['src_test'],
    config=config,
    tgt=config['data']['tgt_test'],
    attribute_vocab=config['data']['attribute_vocab'],
    ngram_attributes=config['data']['ngram_attributes'],
    train_src=src,
    train_tgt=tgt
)
logging.info('...done!')


batch_size = config['data']['batch_size']
max_length = config['data']['max_len']
src_vocab_size = len(src['tok2id'])
tgt_vocab_size = len(tgt['tok2id'])


weight_mask = torch.ones(tgt_vocab_size)
weight_mask[tgt['tok2id']['<pad>']] = 0
loss_criterion = nn.CrossEntropyLoss(weight=weight_mask)
if CUDA:
    weight_mask = weight_mask.cuda()
    loss_criterion = loss_criterion.cuda()

torch.manual_seed(config['training']['random_seed'])
np.random.seed(config['training']['random_seed'])

model = SeqModel(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    pad_id_src=src['tok2id']['<pad>'],
    pad_id_tgt=tgt['tok2id']['<pad>'],
    config=config
)

logging.info('MODEL HAS %s params' %  model.count_params())
model, start_epoch = attempt_load_model(
    model=model,
    checkpoint_dir=working_dir)
if CUDA:
    model = model.cuda()


if config['training']['optimizer'] == 'adam':
    lr = config['training']['learning_rate']
    optimizer = optim.Adam(model.parameters(), lr=lr)
elif config['training']['optimizer'] == 'sgd':
    lr = config['training']['learning_rate']
    optimizer = optim.SGD(model.parameters(), lr=lr)
elif config['training']['optimizer'] == 'adadelta':
    lr = config['training']['learning_rate']
    optimizer = optim.Adadelta(model.parameters(), lr=lr)
else:
    raise NotImplementedError("Learning method not recommend for task")

epoch_loss = []
start_since_last_report = time.time()

words_since_last_report = 0
losses_since_last_report = []
best_metric = 0.0
best_epoch = 0
cur_metric = 0.0 # log perplexity or BLEU
num_examples = min(len(src['content']), len(tgt['content']))
num_batches = num_examples / batch_size

STEP = 0
for epoch in range(start_epoch, config['training']['epochs']):
    if cur_metric > best_metric:
        # delete old checkpoint to save memory
        for ckpt_path in glob.glob(working_dir + '/model.*'):
            os.system("rm %s" % ckpt_path)
        # replace with new checkpoint
        torch.save(model.state_dict(), working_dir + '/model.%s.ckpt' % epoch)

        best_metric = cur_metric
        best_epoch = epoch - 1

    losses = []
    for i in range(0, num_examples, batch_size):

        if overfit:
            i = 50

        batch_idx = i / batch_size

        input_content, input_aux, output = minibatch(
            src, tgt, i, batch_size, max_length, config['model']['model_type'])
        input_lines_src, _, srclens, srcmask, _ = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_lines_tgt, output_lines_tgt, _, _, _ = output
        
        decoder_logit, decoder_probs = model(
            input_lines_src, input_lines_tgt, srcmask, srclens,
            input_ids_aux, auxlens, auxmask)

        optimizer.zero_grad()

        loss = loss_criterion(
            decoder_logit.contiguous().view(-1, tgt_vocab_size),
            output_lines_tgt.view(-1)
        )

        losses.append(loss.item())
        losses_since_last_report.append(loss.item())
        epoch_loss.append(loss.item())
        loss.backward()
        norm = nn.utils.clip_grad_norm_(model.parameters(), config['training']['max_norm'])


        optimizer.step()

        if overfit or batch_idx % config['training']['batches_per_report'] == 0:

            s = float(time.time() - start_since_last_report)
            eps = (batch_size * config['training']['batches_per_report']) / s
            avg_loss = np.mean(losses_since_last_report)
            info = (epoch, batch_idx, num_batches, eps, avg_loss, cur_metric)
            logging.info('EPOCH: %s ITER: %s/%s EPS: %.2f LOSS: %.4f METRIC: %.4f' % info)
            start_since_last_report = time.time()
            words_since_last_report = 0
            losses_since_last_report = []

        STEP += 1

    logging.info('EPOCH %s COMPLETE. EVALUATING...' % epoch)
    start = time.time()
    model.eval()
    dev_loss = evaluate_lpp(
            model, src_test, tgt_test, config)

    if Bleu and epoch >= config['training'].get('inference_start_epoch', 1):
        cur_metric, edit_distance, _, preds, _, _ = inference_metrics(
            model, src_test, tgt_test, config)

        with open(working_dir + '/preds.%s' % epoch, 'w') as f:
            f.write('\n'.join(preds) + '\n')


    else:
        cur_metric = dev_loss

    model.train()

    logging.info('METRIC: %s. TIME: %.2fs CHECKPOINTING...' % (
        cur_metric, (time.time() - start)))
    avg_loss = np.mean(epoch_loss)
    train_losses.append(avg_loss)
    scores.append(cur_metric)
    epoch_loss = []


In [None]:
# Delete_retrieve model
import matplotlib.pyplot as plt
plt.plot(train_losses)
plt.title('DeleteAndRetrieve model -- train and dev losses')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()

In [None]:
plt.plot(scores)
plt.title('scores')
plt.ylabel('DeleteAndRetrieve model -- scores')
plt.xlabel('epoch')
plt.show()

In [None]:
#Delete model
import matplotlib.pyplot as plt
plt.plot(train_losses)
plt.plot(dev_losses)
plt.title("Delete model -- train and dev losses")
plt.ylabel('losses')
plt.xlabel('epoch')
plt.show()

In [None]:
plt.plot(scores)
plt.title('Delete model -- scores')
plt.ylabel('scores')
plt.xlaebl('epoch')
plt.show()

In [None]:
!python eval.py