In [None]:
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np 
import time
import random
from models.rnnpb import RNNPB
from misc.plots import plot_pbs
from misc.dataset import Dataset
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split


In [None]:
# Load the parallel corpus. Only the first side (data.X, vocabulary data.first) is
# used below: this notebook trains an autoencoder on it.
batch_size = 64
vocab1_max = 15000
vocab2_max = 15000
SPLIT_SEED = 42  # fixes which sentences end up in the held-out test split

data = Dataset("data/deu-tok.en", "data/deu-tok.de", batch_size, vocab1_max, vocab2_max, length_limit = 50, easy_subset=50000)
# Alternative configuration: full corpus, longer sentences allowed.
#data = Dataset("data/deu-tok.en", "data/deu-tok.de", batch_size, length_limit = 100)

# Drop duplicate sentences. np.unique also sorts the rows; train_test_split reshuffles them.
_, unique_inds = np.unique(data.X, axis = 0, return_index = True)

X = data.X[unique_inds]
X_lengths = np.array(data.X_lengths)[unique_inds]

# Optional: export the deduplicated sentences (shuffled, first/last token stripped) as text.
#with open("data/autoenc/easy50000/en", "w", encoding="utf-8") as f:
#    shuffled_inds = list(range(X.shape[0]))
#    random.shuffle(shuffled_inds)
#    for i in shuffled_inds:
#        f.write(" ".join(data.first.idx2w[w] for w in X[i, 1:X_lengths[i]-1]) + "\n")

X, X_test, X_lengths, X_test_lengths = train_test_split(
    X, X_lengths, test_size = 2000, shuffle=True, random_state=SPLIT_SEED)

num_seqs = len(X)
# A stable id per training sentence, passed to the model alongside each batch
# (presumably to select that sentence's PB vector; num_sequences=num_seqs).
ids = list(range(num_seqs))

train = tf.data.Dataset.from_tensor_slices((X, X_lengths, ids)).shuffle(len(X))
train = train.batch(batch_size, drop_remainder = True)

num_seqs_test = len(X_test)
ids_test = list(range(num_seqs_test))
# Deliberately left unbatched: the training cell pulls single (sentence, length, id)
# examples with next() and slices them as 1-D sequences. (Dataset.batch returns a new
# dataset; it never modifies `test` in place.)
test = tf.data.Dataset.from_tensor_slices((X_test, X_test_lengths, ids_test)).shuffle(len(X_test))

START = data.first.w2idx["<START>"]
END = data.first.w2idx["<END>"]

In [None]:
# Distribution of sentence lengths in the training split (log-scaled counts).
fig, ax = plt.subplots()
ax.hist(X_lengths, bins='auto', log=True)
ax.set(title="Training sentence lengths", xlabel="Length (tokens)", ylabel="Count (log scale)");

In [None]:
def sentence_to_tensor(s, w2idx):
    """Encode a whitespace-tokenised sentence as a 1-D tensor of word ids.

    Args:
        s: sentence whose tokens are separated by whitespace.
        w2idx: mapping from token to id; must contain "<UNK>".

    Returns:
        A 1-D tensor with one id per token. Tokens missing from `w2idx` map to
        the "<UNK>" id.
    """
    unk = w2idx["<UNK>"]
    # split() with no argument ignores repeated/leading/trailing whitespace, which
    # split(" ") would turn into empty tokens and therefore spurious <UNK> ids.
    ids = [w2idx.get(w, unk) for w in s.split()]
    if not ids:
        # convert_to_tensor([]) would infer float32; keep an integer id tensor.
        return tf.zeros([0], dtype=tf.int32)
    return tf.convert_to_tensor(ids)

def tensor_to_sentence(x, idx2w):
    """Decode a sequence of word ids into a space-joined string.

    Args:
        x: iterable of integer ids: a list, a 1-D NumPy array or a 1-D eager tensor.
        idx2w: mapping (dict or list) from id to token.

    Returns:
        The tokens joined with single spaces ("" for an empty sequence).
    """
    # int() normalises NumPy integer scalars and scalar eager tensors; the latter are
    # unhashable in TF2 and could not be used directly as dict keys.
    return " ".join(idx2w[int(i)] for i in x)

def translate(A, B, x, B_idx2w, eps=0.0001, return_pb = False, length_ratio=1.5):
    """Recognise `x` with model A and generate from the result with model B.

    Passing the same model as A and B gives an autoencoding round trip.

    Args:
        A: model providing `recognize`; the first element of its result
            (presumably the inferred PB) is fed to `B.generate`.
        B: model providing `generate`.
        x: 1-D sequence of word ids (tensor or array) for a single sentence.
        B_idx2w: id-to-token mapping for B's vocabulary.
        eps: forwarded to `A.recognize`.
        return_pb: if True, also return `pb[0]`.
        length_ratio: `B.generate` receives max_length = round(length_ratio * len(x)).

    Returns:
        The decoded string, or `(string, pb[0])` when `return_pb` is True.

    Relies on the module-level START and END token ids.
    """
    x = tf.expand_dims(x, 0)  # add a batch dimension of size 1

    pb = A.recognize(x, eps = eps)[0]
    max_length = round(length_ratio * len(x[0]))
    y = B.generate(pb, max_length = max_length, start=START, end=END)

    s = tensor_to_sentence(y, B_idx2w)
    return (s, pb[0]) if return_pb else s


In [None]:
# --- Model architecture (all passed to the RNNPB constructor) ---
embedding_size = 64
units = 256
num_layers = 2
num_PB = 32            # size of the parametric-bias (PB) vector

# --- Optimisation ---
learning_rate = 0.001  # Adam learning rate used by the training loop
recog_lr = 1           # forwarded to RNNPB as recog_lr
gradient_clip = 1.0    # threshold for tf.clip_by_global_norm

# --- Schedule ---
warmup_epochs = 0      # while epoch < warmup_epochs, only A.variables_expt_pb() are trained
reset_pbs_every = 0    # call reset_pbs() every N epochs; 0 disables resetting

In [None]:
# Build the RNNPB model over data.first's vocabulary. num_sequences equals the
# number of training sentences, matching the ids 0..num_seqs-1 fed during training.
A = RNNPB(
    vocab_size=data.first.vocab_size,
    embedding_size=embedding_size,
    units=units,
    num_layers=num_layers,
    num_PB=num_PB,
    num_sequences=num_seqs,
    recog_lr=recog_lr,
    gradient_clip=gradient_clip,
    batch_size=batch_size,
)

In [None]:
# Training

optimizer = tf.keras.optimizers.Adam(learning_rate)

epochs = 100

# Repeat the (unbatched) test set so next() never raises StopIteration, however
# many evaluation rounds run.
test_iter = iter(test.repeat())
for epoch in range(epochs):
    print("====== EPOCH {} ======".format(epoch+1))
    start = time.time()
    epoch_loss = 0.0

    # Call reset_pbs() every `reset_pbs_every` epochs (0 disables it). Only model A
    # exists in this notebook.
    if epoch != 0 and reset_pbs_every != 0 and epoch % reset_pbs_every == 0:
        A.reset_pbs()

    for (batch, (A_inp, A_lengths, ids)) in enumerate(train):

        # Trim the batch to its longest sentence so no step runs over all-padding columns.
        max_A_length = np.max(A_lengths)
        A_inp = A_inp[:, :max_A_length]

        with tf.GradientTape() as tape:
            # Teacher forcing: inputs are tokens 0..T-2, targets are tokens 1..T-1.
            A_inputs = A_inp[:, :-1]
            A_targets = A_inp[:, 1:]

            A_outputs, _, A_pb = A(A_inputs, ids)

            A_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels = A_targets, logits=A_outputs)
            # Zero the loss at target id 0 (assumed padding). Built with TF ops and cast
            # to the loss dtype instead of mixing an int64 NumPy array into the tape.
            A_mask = tf.cast(tf.not_equal(A_targets, 0), A_loss.dtype)
            A_loss = A_loss * A_mask

            # NOTE(review): the mean is over all positions, masked ones included, so
            # batches with more padding report a smaller loss.
            batch_loss = tf.reduce_mean(A_loss)

        # Selected per batch rather than hoisted: a Keras model's variable list may be
        # empty until its first call (confirm for RNNPB before moving this out).
        if epoch < warmup_epochs:
            variables = A.variables_expt_pb()
        else:
            variables = A.variables

        gradients = tape.gradient(batch_loss, variables)

        gradients, _ = tf.clip_by_global_norm(gradients, gradient_clip)

        optimizer.apply_gradients(zip(gradients, variables))

        epoch_loss += batch_loss
        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch+1, batch, batch_loss.numpy()))
    # float() handles both the accumulated tensor and the plain 0.0 left when no batch ran.
    print('Epoch loss {:.8f}'.format(float(epoch_loss)))
    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

    if epoch >= warmup_epochs and epoch % 5 == 0:
        print("Prediction test")
        # Argmax prediction for the first sentence of the last training batch.
        A_preds = tf.argmax(A_outputs, axis = 2)[0, :A_lengths[0]-1].numpy()
        print(tensor_to_sentence(A_inp[0, :A_lengths[0]].numpy(), data.first.idx2w), "   =>   ", end="")
        print("<START> " + tensor_to_sentence(A_preds, data.first.idx2w))

        print("\nPB stats:")
        A.print_pb_stats()

        print("\nAutoencoding test")

        # One held-out example (unbatched): cut it to its true length, then round-trip it.
        x, x_l, _ = next(test_iter)
        x = x[:x_l].numpy()

        trans = translate(A, A, x, data.first.idx2w)

        print(tensor_to_sentence(x, data.first.idx2w), "   =>   ", trans)

In [None]:
# Autoencode a hand-written sentence with model A. Tokens are looked up in
# data.first's vocabulary; unknown ones become <UNK>.
s = "<START> What exactly did you do ? <END>"
x = sentence_to_tensor(s, data.first.w2idx)

print(translate(A, A, x, data.first.idx2w))
# A cross-lingual call (translate(A, B, ...)) would need a second model `B` and the
# other side's vocabulary. Neither is defined in this notebook, so that call is omitted.