In [None]:
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np 
import time
import random
from models.rnnpb import RNNPB
from misc.plots import plot_pbs
from misc.dataset import Dataset
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
tf.enable_eager_execution()

In [None]:
# Mini-batch size and the two vocabulary caps handed to the Dataset loader.
batch_size = 64
vocab1_max = 15000
vocab2_max = 15000

# Load the parallel corpus (the file names suggest tokenised English and German text).
data = Dataset(
    "data/deu-tok.en",
    "data/deu-tok.de",
    batch_size,
    vocab1_max,
    vocab2_max,
    length_limit=50,
    easy_subset=50000,
)

# An integer test_size makes train_test_split hold out exactly that many samples,
# so only 5 sentence pairs are reserved as the dev/test set.
(
    X, X_test,
    X_lengths, X_test_lengths,
    Y, Y_test,
    Y_lengths, Y_test_lengths,
) = train_test_split(data.X, data.X_lengths, data.Y, data.Y_lengths, test_size=5)

# Every training pair gets an integer id; the ids travel with each batch and
# are passed to the models alongside the inputs.
num_seqs = len(X)
ids = list(range(num_seqs))

# Shuffle over the whole training set, then cut into full batches only.
train = (
    tf.data.Dataset.from_tensor_slices((X, X_lengths, Y, Y_lengths, ids))
    .shuffle(len(X))
    .batch(batch_size, drop_remainder=True)
)

# Sentence delimiter indices, looked up in the first vocabulary.
START = data.first.w2idx["<START>"]
END = data.first.w2idx["<END>"]

In [None]:
# Distribution of the training sequence lengths on the X side (counts on a log axis).
plt.hist(X_lengths, bins='auto', log=True)
plt.title("Training sequence lengths: X")
plt.xlabel("Sequence length")
plt.ylabel("Number of sequences (log scale)");

In [None]:
# Distribution of the training sequence lengths on the Y side (counts on a log axis).
plt.hist(Y_lengths, bins='auto', log=True)
plt.title("Training sequence lengths: Y")
plt.xlabel("Sequence length")
plt.ylabel("Number of sequences (log scale)");

In [None]:
def sentence_to_tensor(s, w2idx):
    """Encode a sentence string as a tensor of vocabulary indices.

    Args:
        s: Sentence to encode. It is split on single space characters, so
            consecutive spaces produce empty-string tokens.
        w2idx: Mapping from word to integer index. It must contain the key
            "<UNK>", whose index is used for every word not in the mapping.

    Returns:
        The result of ``tf.convert_to_tensor`` applied to the list of indices.
    """
    indices = []
    for word in s.split(" "):
        # Unknown words fall back to the "<UNK>" index.
        indices.append(w2idx.get(word, w2idx["<UNK>"]))
    return tf.convert_to_tensor(indices)

def tensor_to_sentence(x, idx2w):
    """Decode a sequence of vocabulary indices into a space-separated sentence.

    Args:
        x: Iterable of indices.
        idx2w: Lookup (anything supporting ``idx2w[i]``) from index to word.

    Returns:
        The looked-up words joined by single spaces; an empty string when
        ``x`` is empty.
    """
    words = [idx2w[index] for index in x]
    return " ".join(words)

def translate(A, B, x, B_idx2w, eps=0.0001, return_pb = False):
    """Translate ``x`` by recognising its PB vector with ``A`` and generating from it with ``B``.

    Relies on the notebook-level ``START`` and ``END`` token indices and on
    ``tensor_to_sentence``.

    Args:
        A: Model exposing ``recognize(x, eps=...)``.
        B: Model exposing ``generate(pb, max_length=..., start=..., end=...)``.
        x: Encoded input sequence; its length bounds the generated length.
        B_idx2w: Index-to-word lookup used to decode B's output.
        eps: Forwarded unchanged to ``A.recognize``.
        return_pb: When true, also return the first element of the
            recognised PB value.

    Returns:
        The decoded sentence string, or the tuple ``(sentence, pb[0])`` when
        ``return_pb`` is true.
    """
    pb = A.recognize(x, eps=eps)

    # Allow the output to be up to 1.5x as long as the input.
    y = B.generate(pb, max_length=round(1.5 * len(x)), start=START, end=END)

    s = tensor_to_sentence(y, B_idx2w)

    # The parentheses matter: without them the conditional binds only to the
    # second tuple element, so a tuple comes back even when return_pb is False.
    return (s, pb[0]) if return_pb else s


In [None]:
# --- Model configuration (passed to both RNNPB constructors) ---
embedding_dim = 64
units = 256
num_layers = 2
num_PB = 32
bind_hard = True

# --- Optimisation ---
learning_rate = 0.001   # step size of the Adam optimizer used by the training loop
recog_lr = 0.001        # passed to RNNPB as `recog_lr`
gradient_clip = 1.0     # global-norm clipping threshold; also passed to RNNPB
binding_strength = 1    # weight of the L2 term between A's and B's PB vectors

# --- Schedule (both measured in epochs) ---
warmup_epochs = 0       # epochs below this train only `variables_expt_pb()`
reset_pbs_every = 0     # call `reset_pbs()` every N epochs; 0 disables it

In [None]:
# Both models share every setting except the vocabulary size, so the common
# keyword arguments are collected once and reused.
_shared_rnnpb_kwargs = dict(
    embedding_dim=embedding_dim,
    units=units,
    num_layers=num_layers,
    num_PB=num_PB,
    num_sequences=num_seqs,
    recog_lr=recog_lr,
    gradient_clip=gradient_clip,
    batch_size=batch_size,
    bind_hard=bind_hard,
)

# A is sized for the first vocabulary, B for the second.
A = RNNPB(vocab_size=data.first.vocab_size, **_shared_rnnpb_kwargs)
B = RNNPB(vocab_size=data.second.vocab_size, **_shared_rnnpb_kwargs)

In [None]:
# Training
#
# Jointly trains the two RNNPB models A and B on paired sequences. Each model
# is trained to predict the next token of its own sequence; unless hard
# binding is enabled, an extra L2 term pulls the two models' PB vectors for
# the same sequence id towards each other. Every 5th epoch the last batch is
# used for a quick qualitative check (autoencoding, PB statistics, translation).

optimizer = tf.train.AdamOptimizer(learning_rate)

epochs = 100

for epoch in range(epochs):
    print("====== EPOCH {} ======".format(epoch))
    start = time.time()
    epoch_loss = 0

    # Optionally reset the PB vectors every `reset_pbs_every` epochs (0 disables).
    if epoch != 0 and reset_pbs_every != 0 and epoch % reset_pbs_every == 0:
        A.reset_pbs()
        B.reset_pbs()

    # The per-batch ids are bound to `batch_ids` so that the notebook-level
    # `ids` list created in the data cell is not overwritten by this loop.
    for (batch, (A_inp, A_lengths, B_inp, B_lengths, batch_ids)) in enumerate(train):

        # Trim each batch to the length of its longest sequence.
        max_A_length = np.max(A_lengths)
        A_inp = A_inp[:, :max_A_length]

        max_B_length = np.max(B_lengths)
        B_inp = B_inp[:, :max_B_length]

        with tf.GradientTape() as tape:

            # Next-token prediction: targets are the inputs shifted by one step.
            A_inputs = A_inp[:, :-1]
            A_targets = A_inp[:, 1:]

            B_inputs = B_inp[:, :-1]
            B_targets = B_inp[:, 1:]

            A_outputs, _, A_pb = A(A_inputs, batch_ids)
            B_outputs, _, B_pb = B(B_inputs, batch_ids)

            # Zero out the loss wherever the target index is 0
            # (presumably the padding index - confirm against Dataset).
            A_mask = 1 - np.equal(A_targets, 0)
            A_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels = A_targets, logits=A_outputs) * A_mask

            B_mask = 1 - np.equal(B_targets, 0)
            B_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels = B_targets, logits=B_outputs) * B_mask

            # The mean is taken over all positions, masked ones included.
            batch_loss = tf.reduce_mean(A_loss) + tf.reduce_mean(B_loss)

            # With hard binding no explicit penalty is added; otherwise the
            # L2 distance between the two PB vectors is added (soft binding).
            bound_loss = batch_loss if bind_hard else batch_loss + binding_strength * tf.nn.l2_loss(A_pb - B_pb)

        # During warm-up only the subset returned by variables_expt_pb() is trained.
        if epoch < warmup_epochs:
            variables = A.variables_expt_pb() + B.variables_expt_pb()
        else:
            variables = A.variables + B.variables

        gradients = tape.gradient(bound_loss, variables)

        gradients, _ = tf.clip_by_global_norm(gradients, gradient_clip)

        optimizer.apply_gradients(zip(gradients, variables))

        epoch_loss += batch_loss
        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f} Bound loss {:.4f}'.format(epoch+1, batch, batch_loss.numpy(), bound_loss.numpy()))

    print('Epoch loss {:.8f}'.format(epoch_loss.numpy()))
    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

    # Periodic qualitative evaluation, using tensors left over from the last batch.
    if epoch >= warmup_epochs and epoch % 5 == 0:
        print("Autoencoding test")
        A_preds = tf.argmax(A_outputs, axis = 2)[0, :A_lengths[0]-1].numpy()
        print(tensor_to_sentence(A_inp[0, :A_lengths[0]].numpy(), data.first.idx2w), "   =>   ", end="")
        print("<START> " + tensor_to_sentence(A_preds, data.first.idx2w))
        B_preds = tf.argmax(B_outputs, axis = 2)[0, :B_lengths[0]-1].numpy()
        print(tensor_to_sentence(B_inp[0, :B_lengths[0]].numpy(), data.second.idx2w), "   =>   ", end="")
        print("<START> " + tensor_to_sentence(B_preds, data.second.idx2w))

        print("\nPB stats:")
        print("A")
        A.print_pb_stats()
        print("B")
        B.print_pb_stats()

        # PB vectors of the first 7 sequences of the last batch, for plotting.
        sample_ids = batch_ids[0:7]
        A_pbs = A.get_pbs(sample_ids)
        B_pbs = B.get_pbs(sample_ids)

        use_dev = (epoch % 10 == 0)
        if use_dev:
            print("\nTesting translation using dev set")
            test_i = np.random.randint(0, len(X_test))
            x = X_test[test_i, :X_test_lengths[test_i]]
            trans, _ = translate(A, B, x, data.second.idx2w, eps=0.001, return_pb=True)

            print(tensor_to_sentence(x, data.first.idx2w), "   =>   ", trans)

            # Only bother plotting pbs with dev example for the sake of showing soft binding
            if not bind_hard:
                plot_pbs(A_pbs, B_pbs, plot_zero = True)
            print()
        else:
            print("\nTesting translation using training set")
            x = A_inp[0, :A_lengths[0]].numpy()
            # recognize() is called with iters/step here; the last entry of
            # what it returns is used as the recognised PB.
            pb_journey = A.recognize(x, iters=1000, step = 100)
            pb = pb_journey[-1]

            print("Target dist from origin:", np.sqrt(np.sum(A_pbs[0]**2)))
            print("Recognized pb dist from origin:", np.sqrt(np.sum(pb**2)))

            trans = tensor_to_sentence(B.generate(pb, max_length = round(1.5 * len(x)), start=START, end=END), data.second.idx2w)
            print(tensor_to_sentence(x, data.first.idx2w), "   =>   ", trans)

            print("\n PB plot (red square corresponds to red point)")
            plot_pbs(A_pbs, B_pbs, pb_journey = pb_journey, plot_zero = True)
            print()

In [None]:
# Manual check on a hand-written sentence. The vocabulary lookups come from
# `data`, the same objects the training cell uses for decoding.
s = "<START> What exactly did you do ? <END>"
x = sentence_to_tensor(s, data.first.w2idx)

# A -> A regenerates the sentence in the first vocabulary; A -> B translates it.
print(translate(A, A, x, data.first.idx2w))
print(translate(A, B, x, data.second.idx2w))