In [1]:
import random
import spacy
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from collections import defaultdict
from time import time
from tqdm import tqdm

try:
    import cPickle as pickle
except ImportError:
    import pickle

import dynet_config
dynet_config.set(mem=1024, random_seed=12345)
dynet_config.set_gpu()
import dynet as dy

In [2]:
print("Loading spaCy")
nlp = spacy.load('en')
assert nlp.path is not None
print ('Done.')

Loading spaCy
Done.


In [3]:
MAX_LEN = 100
NUM_TAGS = 3883
VOCAB_CAP = 10000

UNK = '<UNK>'
START = '<S>'
END = '</S>'

# Load Data for Parsing

In [4]:
with open('huang2016_train.aligned.pkl', 'rb') as f:
    twitter_texts, twitter_tags, twitter_histories = pickle.load(f)
    
with open('huang2016_valid.aligned.pkl', 'rb') as f:
    dev_texts, dev_tags, dev_histories = pickle.load(f)
    
with open('huang2016_test.aligned.pkl', 'rb') as f:
    test_texts, test_tags, test_histories = pickle.load(f)
    
del twitter_histories, dev_histories, test_histories

FileNotFoundError: [Errno 2] No such file or directory: 'huang2016_train.aligned.pkl'

# Load Data

In [5]:
def index_tags(tags_list, tag_set, tag_dict):
    return [[tag_dict[tag] for tag in tags if tag in tag_set] for tags in tags_list]

In [6]:
# Extract tag set
tag_counts = defaultdict(int)
for t in twitter_tags:
    for x in t:
        tag_counts[x] += 1
top_k_tags = set(sorted(tag_counts, key=tag_counts.get, reverse=True)[:NUM_TAGS])

tag_set = set()
for t in twitter_tags:
    tag_set.update(set([x for x in t if x in top_k_tags]))
    
tag_set = sorted(tag_set)
print ('{} unique tags.'.format(len(tag_set)))

tag_indexes = defaultdict(lambda: len(tag_indexes))
parsed_tags = index_tags(twitter_tags, tag_set, tag_indexes)
idx_to_tag = {v: k for k, v in tag_indexes.items()}

3883 unique tags.


In [7]:
try:
    print ('Attempting to open preprecessed TRAIN data ... ', end='')
    #raise NotImplemented
    
    t0=time()
    with open('parsed_twitter_train_data_no_histories.pkl', 'rb') as f:
        vocab, parsed_texts, parsed_tags = pickle.load(f)
    print ('DONE. ({:.3f}s)'.format(time()-t0))
        
except:
    print ('FAIL.')
    
    print ('\tParsing texts ... ', end='')
    t0=time()
    parsed_texts = [[str(w) for w in t][:MAX_LEN] for t in nlp.pipe([x.encode('ascii', 'ignore').decode('ascii').lower() for x in twitter_texts], n_threads=3, batch_size=20000)]
    print ('DONE. ({:.3f}s)'.format(time()-t0))
    
    print ('\tCounting words ... ', end='')
    word_counts = defaultdict(int)
    for t in parsed_texts:
        for x in t:
            word_counts[x] += 1
    top_k_words = set(sorted(word_counts, key=word_counts.get, reverse=True)[:VOCAB_CAP-3])

    word_set = set()
    for t in parsed_texts:
        word_set.update(set([x for x in t if x in top_k_words]))
    print ('DONE. ({:.3f}s)'.format(time()-t0)) 
    
    vocab = defaultdict(lambda: len(vocab))
    print ('\tIndexing texts ... ', end='')
    t0=time()
    parsed_texts = [[vocab[START]] + [(vocab[w] if w in word_set else vocab[UNK]) for w in t] + [vocab[END]] for t in parsed_texts]
    print ('DONE. ({:.3f}s)'.format(time()-t0))
    
    unk_idx = vocab[UNK]
    sos_idx = vocab[START]
    eos_idx = vocab[END]
    
    print ('\tSAVING parsed data ... ', end='')
    t0=time()
    with open('parsed_twitter_train_data_no_histories.pkl', 'wb') as f:
        pickle.dump((dict(vocab), parsed_texts, parsed_tags), f) 
    print ('DONE. ({:.3f}s)'.format(time()-t0))

unk_idx = vocab[UNK]
sos_idx = vocab[START]
eos_idx = vocab[END]
# Set unknown words to be UNK --> note as written, the paper does not indicate that any training data is labeled as UNK...
vocab = defaultdict(lambda: unk_idx, vocab)
idx_to_vocab = {v: k for k, v in vocab.items()}

VOCAB_SIZE = len(vocab)
print ('Vocab size:', VOCAB_SIZE)

Attempting to open preprecessed TRAIN data ... DONE. (0.590s)
Vocab size: 10000


In [8]:
try:
    print ('Attempting to open preprecessed DEV and TEST data ... ', end='')
    #raise NotImplemented
    
    t0=time()
    with open('parsed_twitter_test_dev_data_no_histories.pkl', 'rb') as f:
        parsed_dev_texts, parsed_test_texts = pickle.load(f)
    print ('DONE. ({:.3f}s)'.format(time()-t0))
        
except:
    print ('FAIL.')
    print ('\tParsing texts ... ', end='')
    t0=time()
    parsed_dev_texts = [[vocab[START]] + [vocab[str(w)] for w in t if not w.is_stop][:MAX_LEN] + [vocab[END]] for t in nlp.pipe([x.encode('ascii', 'ignore').decode('ascii').lower() for x in dev_texts], n_threads=3, batch_size=20000)]
    parsed_test_texts = [[vocab[START]] + [vocab[str(w)] for w in t if not w.is_stop][:MAX_LEN] + [vocab[END]] for t in nlp.pipe([x.encode('ascii', 'ignore').decode('ascii').lower() for x in test_texts], n_threads=3, batch_size=20000)]
    print ('DONE. ({:.3f}s)'.format(time()-t0))
    
    print ('\tSAVING parsed data ... ', end='')
    t0=time()
    with open('parsed_twitter_test_dev_data_no_histories.pkl', 'wb') as f:
        pickle.dump((parsed_dev_texts, parsed_test_texts), f) 
    print ('DONE. ({:.3f}s)'.format(time()-t0))

Attempting to open preprecessed DEV and TEST data ... DONE. (0.033s)


# Model Parameters and Settings

In [9]:
EMBEDDING_DIM = 128
HIDDEN_DIM = 256
Q_DIM = 512
DROPOUT = 0.2
ALPHA = 0.01
EPSILON_MAX = .9
EPSILON_MIN = 0.00
KL_WEIGHT_START = 0.0

BATCH_SIZE = 16
PATIENCE = 3

In [10]:
# Initialize dynet model
model = dy.ParameterCollection()

# The paper uses AdaGrad
trainer = dy.AdamTrainer(model)

# Embedding parameters
embed = model.add_lookup_parameters((VOCAB_SIZE, EMBEDDING_DIM))

# Recurrent layers for tweet encoding
lstm_encode = dy.LSTMBuilder(1, EMBEDDING_DIM, HIDDEN_DIM, model)
lstm_decode = dy.LSTMBuilder(1, EMBEDDING_DIM, Q_DIM, model)

# Encoder MLP for tweet encoding
W_mu_tweet_p = model.add_parameters((Q_DIM, HIDDEN_DIM))
V_mu_tweet_p = model.add_parameters((HIDDEN_DIM, Q_DIM))
b_mu_tweet_p = model.add_parameters((Q_DIM))

W_sig_tweet_p = model.add_parameters((Q_DIM, HIDDEN_DIM))
V_sig_tweet_p = model.add_parameters((HIDDEN_DIM, Q_DIM))
b_sig_tweet_p = model.add_parameters((Q_DIM))

W_mu_tag_p = model.add_parameters((Q_DIM, NUM_TAGS))
V_mu_tag_p = model.add_parameters((HIDDEN_DIM, Q_DIM))
b_mu_tag_p = model.add_parameters((Q_DIM))

W_sig_tag_p = model.add_parameters((Q_DIM, NUM_TAGS))
V_sig_tag_p = model.add_parameters((HIDDEN_DIM, Q_DIM))
b_sig_tag_p = model.add_parameters((Q_DIM))

W_mu_p = model.add_parameters((Q_DIM, 2 * HIDDEN_DIM))
b_mu_p = model.add_parameters((Q_DIM))

W_sig_p = model.add_parameters((Q_DIM, 2 * HIDDEN_DIM))
b_sig_p = model.add_parameters((Q_DIM))

W_hidden_p = model.add_parameters((HIDDEN_DIM, Q_DIM))
b_hidden_p = model.add_parameters((HIDDEN_DIM))

W_tweet_softmax_p = model.add_parameters((VOCAB_SIZE, Q_DIM))
b_tweet_softmax_p = model.add_parameters((VOCAB_SIZE))

W_tag_output_p = model.add_parameters((NUM_TAGS, HIDDEN_DIM))
b_tag_output_p = model.add_parameters((NUM_TAGS))

In [11]:
def reparameterize(mu, log_sigma_squared):
    d = mu.dim()[0][0]
    sample = dy.random_normal(d)
    covar = dy.exp(log_sigma_squared * 0.5)

    return mu + dy.cmult(covar, sample)

def mlp(x, W, V, b):
    return V * dy.tanh(W * x + b)

In [12]:
def calc_loss(sent, epsilon=0.0):
    #dy.renew_cg()
    
    

    # Transduce all batch elements with an LSTM
    src = sent[0]
    tags = sent[1]

    # initialize the LSTM
    init_state_src = lstm_encode.initial_state()

    # get the output of the first LSTM
    src_output = init_state_src.add_inputs([embed[x] for x in src])[-1].output()

    # Now compute mean and standard deviation of source hidden state.
    W_mu_tweet = dy.parameter(W_mu_tweet_p)
    V_mu_tweet = dy.parameter(V_mu_tweet_p)
    b_mu_tweet = dy.parameter(b_mu_tweet_p)

    W_sig_tweet = dy.parameter(W_sig_tweet_p)
    V_sig_tweet = dy.parameter(V_sig_tweet_p)
    b_sig_tweet = dy.parameter(b_sig_tweet_p)
    
    # Compute tweet encoding
    mu_tweet      = dy.dropout(mlp(src_output, W_mu_tweet,  V_mu_tweet,  b_mu_tweet), DROPOUT)
    log_var_tweet = dy.dropout(mlp(src_output, W_sig_tweet, V_sig_tweet, b_sig_tweet), DROPOUT)
    
    W_mu_tag = dy.parameter(W_mu_tag_p)
    V_mu_tag = dy.parameter(V_mu_tag_p)
    b_mu_tag = dy.parameter(b_mu_tag_p)

    W_sig_tag = dy.parameter(W_sig_tag_p)
    V_sig_tag = dy.parameter(V_sig_tag_p)
    b_sig_tag = dy.parameter(b_sig_tag_p)
    
    # Compute tag encoding
    tags_tensor = dy.sparse_inputTensor([tags], np.ones((len(tags),)), (NUM_TAGS,))
    
    mu_tag      = dy.dropout(mlp(tags_tensor, W_mu_tag,  V_mu_tag,  b_mu_tag), DROPOUT)
    log_var_tag = dy.dropout(mlp(tags_tensor, W_sig_tag, V_sig_tag, b_sig_tag), DROPOUT)
    
    # Combine encodings for mean and diagonal covariance
    W_mu = dy.parameter(W_mu_p)
    b_mu = dy.parameter(b_mu_p)

    W_sig = dy.parameter(W_sig_p)
    b_sig = dy.parameter(b_sig_p)
    
    # Slowly phase out getting both inputs
    if random.random() < epsilon:
        mask = dy.zeros(HIDDEN_DIM)
    else:
        mask = dy.ones(HIDDEN_DIM)
        
    if random.random() < 0.5:
        mu_tweet = dy.cmult(mu_tweet, mask)
        log_var_tweet = dy.cmult(log_var_tweet, mask)
    else:
        mu_tag = dy.cmult(mu_tag, mask)
        log_var_tag = dy.cmult(log_var_tag, mask)
    
    mu      = dy.affine_transform([b_mu,  W_mu,  dy.concatenate([mu_tweet, mu_tag])])
    log_var = dy.affine_transform([b_sig, W_sig, dy.concatenate([log_var_tweet, log_var_tag])])

    # KL-Divergence loss computation
    kl_loss = -0.5 * dy.sum_elems(1 + log_var - dy.pow(mu, dy.inputVector([2])) - dy.exp(log_var))

    z = reparameterize(mu, log_var)

    # now step through the output sentence
    all_losses = []

    current_state = lstm_decode.initial_state().set_s([z, dy.tanh(z)])
    prev_word = src[0]
    W_sm = dy.parameter(W_tweet_softmax_p)
    b_sm = dy.parameter(b_tweet_softmax_p)

    for next_word in src[1:]:
        # feed the current state into the
        
        
        current_state = current_state.add_input(embed[prev_word])
        output_embedding = current_state.output()

        s = dy.affine_transform([b_sm, W_sm, output_embedding])
        
        all_losses.append(dy.pickneglogsoftmax(s, next_word))

        # Slowly phase out teacher forcing (this may be slow??)
        if random.random() < epsilon:
            p = dy.softmax(s).npvalue()
            prev_word = np.random.choice(VOCAB_SIZE, p=p/p.sum())
        else:
            prev_word = next_word
    
    softmax_loss = dy.esum(all_losses)

    W_hidden = dy.parameter(W_hidden_p)
    b_hidden = dy.parameter(b_hidden_p)
    
    W_out = dy.parameter(W_tag_output_p)
    b_out = dy.parameter(b_tag_output_p)
    
    h = dy.dropout(dy.tanh(b_hidden + W_hidden * z), DROPOUT)
    o = dy.logistic(b_out + W_out * h)
    
    crossentropy_loss = dy.binary_log_loss(o, tags_tensor)
                               
    return kl_loss, softmax_loss, crossentropy_loss

In [13]:
train = list(zip(parsed_texts, parsed_tags))
dev_tags = index_tags(dev_tags, tag_set, tag_indexes)
dev = list(zip(parsed_dev_texts, dev_tags))

In [None]:
print ('Using batch size of {}.'.format(BATCH_SIZE))

epsilon = EPSILON_MIN
kl_weight = KL_WEIGHT_START
steps = 0
strikes = 0
last_dev_loss = np.inf
for ITER in range(100):
    # Perform training
    random.shuffle(train)
    
    batches = [train[i:i + BATCH_SIZE] for i in range(0, len(train), BATCH_SIZE)]
    
    train_words, train_loss, train_kl_loss, train_reconstruct_loss, total_tag_loss = 0, 0.0, 0.0, 0.0, 0.0
    start = time()
    
    print ('Training ... Iteration:', ITER, 'Epsilon:', epsilon)
    for i, batch in enumerate(tqdm(batches)):
        dy.renew_cg()
        losses = []
        for sent_id, sent in enumerate(batch):
            if len(sent[1]) < 1 or len(sent[0]) < 3:
                continue
            kl_loss, softmax_loss, tag_loss = calc_loss(sent, epsilon)
            #total_loss = dy.esum([kl_loss, softmax_loss, tag_loss])
            #train_loss += total_loss.value()
            
            # Gradually increase KL-Divergence loss
            if steps < 15000:
                kl_weight = 1 / (1 + np.exp(-0.001 * steps + 5))
            else:
                kl_weight = 1.0
                
            losses.append(dy.esum([kl_weight * kl_loss, softmax_loss, tag_loss]))

            # Record the KL loss and reconstruction loss separately help you monitor the training.
            train_kl_loss += kl_loss.value()
            train_reconstruct_loss += softmax_loss.value()
            total_tag_loss += tag_loss.value()
            
            train_words += len(sent[0])
        steps += 1
   
        # Batch update
        batch_loss = dy.esum(losses)/BATCH_SIZE
        train_loss += batch_loss.value()
        batch_loss.backward()
        trainer.update()
        
        
        #total_loss.backward()
        #trainer.update()
        #if (sent_id + 1) % 1000 == 0:
        #    print("--finished %r sentences" % (sent_id + 1))

    # Gradually increase KL-Divergence loss
    if steps < 100000:
        epsilon = .9 / (1 + np.exp(-0.0001 * steps + 5))
    else:
        epsilon = EPSILON_MAX
        
    #epsilon = min(EPSILON_MAX, epsilon + 0.05)
    print("iter %r: train loss/word=%.4f, kl loss/word=%.4f, reconstruction loss/word=%.4f, ppl=%.4f, tag loss=%.4fs" % (
        ITER, train_loss / train_words, train_kl_loss / train_words, train_reconstruct_loss / train_words,
        math.exp(train_loss / train_words), total_tag_loss / len(train)))

    # Evaluate on dev set
    dev_words, dev_loss, dev_kl_loss, dev_reconstruct_loss, dev_tag_loss = 0, 0.0, 0.0, 0.0, 0.0
    start = time()
    print ('Evaluating batch ... ')
    for sent_id, sent in enumerate(tqdm(dev)):
        dy.renew_cg()
        if len(sent[1]) < 1 or len(sent[0]) < 3:
                continue
        kl_loss, softmax_loss, tag_loss = calc_loss(sent)

        dev_kl_loss += kl_loss.value()
        dev_reconstruct_loss += softmax_loss.value()
        dev_tag_loss += tag_loss.value()
        dev_loss += kl_loss.value() + softmax_loss.value() + tag_loss.value()

        dev_words += len(sent[0])
        trainer.update()

    print("iter %r: dev loss/word=%.4f, kl loss/word=%.4f, reconstruction loss/word=%.4f, ppl=%.4f, tag loss=%.2fs" % (
        ITER, dev_loss / dev_words, dev_kl_loss / dev_words, dev_reconstruct_loss / dev_words,
        math.exp(dev_loss / dev_words), dev_tag_loss / len(dev)))
    if dev_loss > last_dev_loss and ITER > 9:
        strikes += 1
    else:
        strikes = 0
        last_dev_loss = dev_loss
        model.save('tweet_tag_vae.best.weights')
        
    if strikes >= PATIENCE:
        print ('Early stopping after {} iterations.')
        break

  0%|          | 0/13304 [00:00<?, ?it/s]

Using batch size of 16.
Training ... Iteration: 0 Epsilon: 0.0


100%|██████████| 13304/13304 [1:31:24<00:00,  2.43it/s]
  0%|          | 7/25817 [00:00<06:42, 64.08it/s]

iter 0: train loss/word=0.3311, kl loss/word=0.8457, reconstruction loss/word=4.3154, ppl=1.3925, tag loss=13.8994s
Evaluating batch ... 


100%|██████████| 25817/25817 [06:11<00:00, 69.58it/s]
  0%|          | 0/13304 [00:00<?, ?it/s]

iter 0: dev loss/word=5.6335, kl loss/word=0.1024, reconstruction loss/word=4.6220, ppl=279.6253, tag loss=10.81s
Training ... Iteration: 1 Epsilon: 0.022367912868


100%|██████████| 13304/13304 [1:30:50<00:00,  2.44it/s]
  0%|          | 7/25817 [00:00<06:39, 64.54it/s]

iter 1: train loss/word=0.3108, kl loss/word=0.0321, reconstruction loss/word=4.0666, ppl=1.3645, tag loss=13.9287s
Evaluating batch ... 


100%|██████████| 25817/25817 [06:10<00:00, 69.64it/s]


iter 1: dev loss/word=5.7202, kl loss/word=0.0414, reconstruction loss/word=4.6131, ppl=304.9764, tag loss=12.67s


  0%|          | 0/13304 [00:00<?, ?it/s]

Training ... Iteration: 2 Epsilon: 0.079135245945


100%|██████████| 13304/13304 [1:32:28<00:00,  2.40it/s]
  0%|          | 7/25817 [00:00<06:47, 63.40it/s]

iter 2: train loss/word=0.3176, kl loss/word=0.0576, reconstruction loss/word=4.0661, ppl=1.3739, tag loss=15.2875s
Evaluating batch ... 


100%|██████████| 25817/25817 [06:09<00:00, 69.80it/s]


iter 2: dev loss/word=5.8946, kl loss/word=0.0771, reconstruction loss/word=4.5660, ppl=363.0699, tag loss=14.88s


  0%|          | 0/13304 [00:00<?, ?it/s]

Training ... Iteration: 3 Epsilon: 0.240493282516


100%|██████████| 13304/13304 [1:37:10<00:00,  2.28it/s]
  0%|          | 7/25817 [00:00<06:46, 63.53it/s]

iter 3: train loss/word=0.3496, kl loss/word=0.0808, reconstruction loss/word=4.3932, ppl=1.4186, tag loss=17.8629s
Evaluating batch ... 


100%|██████████| 25817/25817 [06:09<00:00, 69.78it/s]
  0%|          | 0/13304 [00:00<?, ?it/s]

iter 3: dev loss/word=6.1584, kl loss/word=0.1127, reconstruction loss/word=4.5299, ppl=472.6801, tag loss=18.02s
Training ... Iteration: 4 Epsilon: 0.521742721359


100%|██████████| 13304/13304 [1:45:21<00:00,  2.10it/s]
  0%|          | 7/25817 [00:00<06:37, 64.91it/s]

iter 4: train loss/word=0.3920, kl loss/word=0.0936, reconstruction loss/word=4.9153, ppl=1.4800, tag loss=20.1474s
Evaluating batch ... 


100%|██████████| 25817/25817 [06:09<00:00, 69.92it/s]
  0%|          | 0/13304 [00:00<?, ?it/s]

iter 4: dev loss/word=6.4051, kl loss/word=0.1595, reconstruction loss/word=4.5832, ppl=604.9246, tag loss=19.77s
Training ... Iteration: 5 Epsilon: 0.755245055665


100%|██████████| 13304/13304 [1:57:03<00:00,  1.89it/s]
  0%|          | 6/25817 [00:00<07:14, 59.39it/s]

iter 5: train loss/word=0.4162, kl loss/word=0.1014, reconstruction loss/word=5.2154, ppl=1.5162, tag loss=21.4078s
Evaluating batch ... 


100%|██████████| 25817/25817 [06:17<00:00, 68.34it/s]
  0%|          | 0/13304 [00:00<?, ?it/s]

iter 5: dev loss/word=6.8197, kl loss/word=0.1818, reconstruction loss/word=4.7896, ppl=915.7017, tag loss=21.98s
Training ... Iteration: 6 Epsilon: 0.856595388894


100%|██████████| 13304/13304 [1:58:46<00:00,  1.87it/s]
  0%|          | 7/25817 [00:00<06:51, 62.66it/s]

iter 6: train loss/word=0.4198, kl loss/word=0.1046, reconstruction loss/word=5.2904, ppl=1.5217, tag loss=21.0838s
Evaluating batch ... 


100%|██████████| 25817/25817 [06:19<00:00, 67.96it/s]
  0%|          | 0/13304 [00:00<?, ?it/s]

iter 6: dev loss/word=6.8885, kl loss/word=0.1950, reconstruction loss/word=4.9372, ppl=980.9103, tag loss=20.88s
Training ... Iteration: 7 Epsilon: 0.888102982863


100%|██████████| 13304/13304 [1:58:48<00:00,  1.87it/s] 
  0%|          | 7/25817 [00:00<07:05, 60.67it/s]

iter 7: train loss/word=0.4180, kl loss/word=0.1023, reconstruction loss/word=5.3051, ppl=1.5190, tag loss=20.4317s
Evaluating batch ... 


100%|██████████| 25817/25817 [06:16<00:00, 68.54it/s]
  0%|          | 0/13304 [00:00<?, ?it/s]

iter 7: dev loss/word=6.9935, kl loss/word=0.2157, reconstruction loss/word=4.9875, ppl=1089.4894, tag loss=21.29s
Training ... Iteration: 8 Epsilon: 0.9


100%|██████████| 13304/13304 [1:56:56<00:00,  1.90it/s]
  0%|          | 6/25817 [00:00<07:16, 59.08it/s]

iter 8: train loss/word=0.4137, kl loss/word=0.1003, reconstruction loss/word=5.3105, ppl=1.5125, tag loss=19.2809s
Evaluating batch ... 


100%|██████████| 25817/25817 [06:16<00:00, 68.59it/s]


iter 8: dev loss/word=6.7943, kl loss/word=0.1870, reconstruction loss/word=5.0366, ppl=892.7160, tag loss=18.68s


  0%|          | 0/13304 [00:00<?, ?it/s]

Training ... Iteration: 9 Epsilon: 0.9


100%|██████████| 13304/13304 [1:57:18<00:00,  1.89it/s]
  0%|          | 7/25817 [00:00<06:55, 62.07it/s]

iter 9: train loss/word=0.4038, kl loss/word=0.0906, reconstruction loss/word=5.3091, ppl=1.4976, tag loss=16.9326s
Evaluating batch ... 


 81%|████████  | 20968/25817 [05:07<01:11, 68.26it/s]

In [26]:
model.save('trained_vae_joint_multimodal.weights.x')

In [13]:
model.populate('trained_vae_joint_multimodal.weights')

In [14]:
def hallucinate_tags(tweet):
    dy.renew_cg()
    
    # Transduce all batch elements with an LSTM
    src = tweet

    # initialize the LSTM
    init_state_src = lstm_encode.initial_state()

    # get the output of the first LSTM
    src_output = init_state_src.add_inputs([embed[x] for x in src])[-1].output()

    # Now compute mean and standard deviation of source hidden state.
    W_mu_tweet = dy.parameter(W_mu_tweet_p)
    V_mu_tweet = dy.parameter(V_mu_tweet_p)
    b_mu_tweet = dy.parameter(b_mu_tweet_p)

    W_sig_tweet = dy.parameter(W_sig_tweet_p)
    V_sig_tweet = dy.parameter(V_sig_tweet_p)
    b_sig_tweet = dy.parameter(b_sig_tweet_p)
    
    # Compute tweet encoding
    mu_tweet      = mlp(src_output, W_mu_tweet,  V_mu_tweet,  b_mu_tweet)
    log_var_tweet = mlp(src_output, W_sig_tweet, V_sig_tweet, b_sig_tweet)
    
    #W_mu_tag = dy.parameter(W_mu_tag_p)
    #V_mu_tag = dy.parameter(V_mu_tag_p)
    #b_mu_tag = dy.parameter(b_mu_tag_p)

    #W_sig_tag = dy.parameter(W_sig_tag_p)
    #V_sig_tag = dy.parameter(V_sig_tag_p)
    #b_sig_tag = dy.parameter(b_sig_tag_p)
    
    # Compute tag encoding
    #tags_tensor = dy.sparse_inputTensor([tags], np.ones((len(tags),)), (NUM_TAGS,))
    
    #mu_tag      = dy.dropout(mlp(tags_tensor, W_mu_tag,  V_mu_tag,  b_mu_tag), DROPOUT)
    #log_var_tag = dy.dropout(mlp(tags_tensor, W_sig_tag, V_sig_tag, b_sig_tag), DROPOUT)
    
    # Combine encodings for mean and diagonal covariance
    W_mu = dy.parameter(W_mu_p)
    b_mu = dy.parameter(b_mu_p)

    W_sig = dy.parameter(W_sig_p)
    b_sig = dy.parameter(b_sig_p)
    
    
    mu_tag = dy.zeros(HIDDEN_DIM)
    log_var_tag = dy.zeros(HIDDEN_DIM)
    
    mu      = dy.affine_transform([b_mu,  W_mu,  dy.concatenate([mu_tweet, mu_tag])])
    log_var = dy.affine_transform([b_sig, W_sig, dy.concatenate([log_var_tweet, log_var_tag])])

    # KL-Divergence loss computation
    kl_loss = -0.5 * dy.sum_elems(1 + log_var - dy.pow(mu, dy.inputVector([2])) - dy.exp(log_var))

    z = reparameterize(mu, log_var)

    # now step through the output sentence
    all_losses = []

    #current_state = lstm_decode.initial_state().set_s([z, dy.tanh(z)])
    #prev_word = src[0]
    #W_sm = dy.parameter(W_tweet_softmax_p)
    #b_sm = dy.parameter(b_tweet_softmax_p)

    #for next_word in src[1:]:
    #    # feed the current state into the
    #    current_state = current_state.add_input(embed[prev_word])
    #    output_embedding = current_state.output()

    #    s = dy.affine_transform([b_sm, W_sm, output_embedding])
    #    all_losses.append(dy.pickneglogsoftmax(s, next_word))

    #    prev_word = next_word
    
    #softmax_loss = dy.esum(all_losses)

    W_hidden = dy.parameter(W_hidden_p)
    b_hidden = dy.parameter(b_hidden_p)
    
    W_out = dy.parameter(W_tag_output_p)
    b_out = dy.parameter(b_tag_output_p)
    
    h = dy.tanh(b_hidden + W_hidden * z)
    o = dy.logistic(b_out + W_out * h)
    
    tag_ranks = o.value()
    
    gen_tags = []
    for i, p in enumerate(tag_ranks):
        if random.random() < p:
            gen_tags.append(i)
                               
    return gen_tags

In [35]:
[idx_to_vocab[i] for i in train[1][0]], [idx_to_tag[i] for i in train[1][1]]

(['young',
  ',',
  'african',
  'scientists',
  'inspiring',
  'next',
  'peer',
  'group',
  'of',
  'innovators',
  '<UNK>',
  ' ',
  '<UNK>'],
 ['women', 'Diversity'])

In [65]:
idx = 1000
pred = hallucinate_tags(train[idx][0])
pred_args = np.argsort(pred)[::-1]
[idx_to_vocab[i] for i in train[idx][0]], [idx_to_tag[i] for i in train[idx][1]], [idx_to_tag[i] for i in pred_args[:10]]

(['states',
  'working',
  'on',
  'new',
  'accounts',
  'for',
  'disabled',
  'families',
  ':',
  'see',
  '<UNK>',
  'story',
  '<UNK>'],
 ['batman'],
 ['NHS',
  'quote',
  'UFC',
  'photographer',
  'startups',
  'UK',
  'ausbiz',
  'NATO',
  'Destiny',
  'Sabres'])

In [62]:
pred[3239], pred[165]

(7.152557373046875e-06, 0.001961648464202881)

In [15]:
def hallucinate_tweet(given_tags):
    dy.renew_cg()
    
    # Transduce all batch elements with an LSTM
    tags = given_tags

    # initialize the LSTM
    #init_state_src = lstm_encode.initial_state()

    # get the output of the first LSTM
    #src_output = init_state_src.add_inputs([embed[x] for x in src])[-1].output()

    # Now compute mean and standard deviation of source hidden state.
    #W_mu_tweet = dy.parameter(W_mu_tweet_p)
    #V_mu_tweet = dy.parameter(V_mu_tweet_p)
    #b_mu_tweet = dy.parameter(b_mu_tweet_p)

    #W_sig_tweet = dy.parameter(W_sig_tweet_p)
    #V_sig_tweet = dy.parameter(V_sig_tweet_p)
    #b_sig_tweet = dy.parameter(b_sig_tweet_p)
    
    # Compute tweet encoding
    #mu_tweet      = mlp(src_output, W_mu_tweet,  V_mu_tweet,  b_mu_tweet)
    #log_var_tweet = mlp(src_output, W_sig_tweet, V_sig_tweet, b_sig_tweet)
    
    W_mu_tag = dy.parameter(W_mu_tag_p)
    V_mu_tag = dy.parameter(V_mu_tag_p)
    b_mu_tag = dy.parameter(b_mu_tag_p)

    W_sig_tag = dy.parameter(W_sig_tag_p)
    V_sig_tag = dy.parameter(V_sig_tag_p)
    b_sig_tag = dy.parameter(b_sig_tag_p)
    
    # Compute tag encoding
    tags_tensor = dy.sparse_inputTensor([tags], np.ones((len(tags),)), (NUM_TAGS,))
    
    mu_tag      = dy.dropout(mlp(tags_tensor, W_mu_tag,  V_mu_tag,  b_mu_tag), DROPOUT)
    log_var_tag = dy.dropout(mlp(tags_tensor, W_sig_tag, V_sig_tag, b_sig_tag), DROPOUT)
    
    # Combine encodings for mean and diagonal covariance
    W_mu = dy.parameter(W_mu_p)
    b_mu = dy.parameter(b_mu_p)

    W_sig = dy.parameter(W_sig_p)
    b_sig = dy.parameter(b_sig_p)
    
    mu_tweet = dy.zeros(HIDDEN_DIM)
    log_var_tweet = dy.zeros(HIDDEN_DIM)
    
    mu      = dy.affine_transform([b_mu,  W_mu,  dy.concatenate([mu_tweet, mu_tag])])
    log_var = dy.affine_transform([b_sig, W_sig, dy.concatenate([log_var_tweet, log_var_tag])])

    # KL-Divergence loss computation
    kl_loss = -0.5 * dy.sum_elems(1 + log_var - dy.pow(mu, dy.inputVector([2])) - dy.exp(log_var))

    z = reparameterize(mu, log_var)

    # now step through the output sentence
    all_losses = []

    current_state = lstm_decode.initial_state().set_s([z, dy.tanh(z)])
    prev_word = vocab[START]
    W_sm = dy.parameter(W_tweet_softmax_p)
    b_sm = dy.parameter(b_tweet_softmax_p)

    gen_tweet = []
    for i in range(20):
        # feed the current state into the
        current_state = current_state.add_input(embed[prev_word])
        output_embedding = current_state.output()

        s = dy.affine_transform([b_sm, W_sm, output_embedding])
        p = dy.softmax(s).npvalue()
        next_word = np.random.choice(VOCAB_SIZE, p=p/p.sum())
        gen_tweet.append(next_word)
        prev_word = next_word
                               
    return gen_tweet

In [27]:
idx = 1000
[idx_to_tag[i] for i in train[idx][1]], [idx_to_vocab[i] for i in hallucinate_tweet(train[idx][1])]

(['savings'],
 ['shanghai',
  '<UNK>',
  'rt',
  'platforms',
  'htt',
  'surprised',
  'board',
  'get',
  'days',
  'set',
  '&',
  '5-star',
  'rt',
  'ed',
  'rt',
  '@justintrudeau',
  'in',
  'jumped',
  'rt',
  'bear'])

In [32]:
idx = 1000
[idx_to_vocab[i] for i in train[idx][0]], [idx_to_tag[i] for i in hallucinate_tags(train[idx][0])]

(['<S>',
  'states',
  'working',
  'on',
  'new',
  'accounts',
  'for',
  'disabled',
  'families',
  ':',
  'see',
  '<UNK>',
  'story',
  '<UNK>',
  '</S>'],
 ['Ankara'])