In [1]:
import random

import numpy as np
import torch
from torch.utils.data.dataloader import default_collate

from settings import EXPERIMENTS_DIR
from experiment import Experiment
from utils import to_device, load_weights, load_embeddings, create_embeddings_matrix
from vocab import Vocab
from train import create_model
from preprocess import load_dataset, create_dataset_reader

In [3]:
# Id of the training experiment to inspect; it is resolved against
# EXPERIMENTS_DIR by Experiment.load in the next cell.
exp_id = 'train.0d1b9t6u'

# Load everything

In [4]:
exp = Experiment.load(EXPERIMENTS_DIR, exp_id)

In [5]:
exp.config

TrainConfig(model_class=<class 'models.Seq2SeqMeaningStyle'>, preprocess_exp_id='preprocess.pb25misv', embedding_size=300, hidden_size=256, dropout=0.2, scheduled_sampling_ratio=0.5, pretrained_embeddings=True, trainable_embeddings=False, meaning_size=32, style_size=32, lr=0.001, weight_decay=1e-07, grad_clipping=5, D_num_iterations=10, D_loss_multiplier=1, P_loss_multiplier=10, P_bow_loss_multiplier=1, use_discriminator=True, use_predictor=False, use_predictor_bow=True, use_motivator=True, use_gauss=False, num_epochs=5, batch_size=256, best_loss='loss')

In [6]:
# The training config records the id of the preprocessing experiment it used;
# load that experiment and read the datasets, vocabularies and embedding
# matrix from it.
preprocess_exp = Experiment.load(EXPERIMENTS_DIR, exp.config.preprocess_exp_id)
dataset_train, dataset_val, dataset_test, vocab, style_vocab, W_emb = load_dataset(preprocess_exp)

Dataset: 21176, val: 10000, test: 10000
Vocab: 28106, style vocab: 2
W_emb: (28106, 300)


In [7]:
dataset_reader = create_dataset_reader(preprocess_exp.config)

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/josephcappadona/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [8]:
# Build the model from the training config, both vocabularies, the training
# set's max_len and the embedding matrix. The checkpoint weights are loaded
# separately by load_weights.
model = create_model(exp.config, vocab, style_vocab, dataset_train.max_len, W_emb)

In [9]:
# Load the 'best.th' checkpoint stored in this experiment's directory into the model.
load_weights(model, exp.experiment_dir.joinpath('best.th'))

In [44]:
# Put the model in evaluation mode before running inference.
# NOTE(review): presumably a torch.nn.Module, in which case this disables
# dropout (the config uses dropout=0.2) -- confirm against create_model.
model = model.eval()

## Predict

In [10]:
def create_inputs(instances):
    """Build a collated model batch from raw sentences or dataset instances.

    `instances` may be a single item or a list of items. The type of the first
    item decides how the whole batch is handled:

    * not a dict -- every item is treated as raw text. It is passed through
      `dataset_reader.clean_sentence`, `dataset_reader.spacy` and
      `dataset_reader.preprocess_sentence`, paired with the first style key of
      `style_vocab` as a placeholder label, and encoded with
      `dataset_train.encode_instance`.
    * dict -- items are used as-is and must already contain the
      'sentence_enc' and 'style_enc' keys.

    Returns the output of `default_collate` over the
    {'sentence', 'style'} records, passed through `to_device`.
    """
    batch = instances if isinstance(instances, list) else [instances]

    if not isinstance(batch[0], dict):
        # Raw text path: run each string through the reader's preprocessing.
        tokenized = []
        for text in batch:
            cleaned = dataset_reader.clean_sentence(text)
            parsed = dataset_reader.spacy(cleaned)
            tokenized.append(dataset_reader.preprocess_sentence(parsed))

        # Raw text carries no style label, so the first known style is used.
        placeholder_style = list(style_vocab.token2id.keys())[0]

        encoded_batch = []
        for tokens in tokenized:
            record = {'sentence': tokens, 'style': placeholder_style}
            # encode_instance supplies the encoded fields read below.
            record.update(dataset_train.encode_instance(record))
            encoded_batch.append(record)
        batch = encoded_batch

    # The model consumes the encoded fields under the plain key names.
    model_records = []
    for record in batch:
        model_records.append({
            'sentence': record['sentence_enc'],
            'style': record['style_enc'],
        })

    return to_device(default_collate(model_records))

In [11]:
def get_sentences(outputs):
    """Convert predicted token ids into lists of token strings.

    Reads `outputs["predictions"]`, which is either a NumPy array or an object
    supporting `.detach().cpu().numpy()`. Each row is decoded with
    `vocab.id2token`, stopping before the first occurrence of the id of
    `Vocab.END_TOKEN`; rows without that id are decoded in full.

    Returns a list with one list of tokens per row.
    """
    id_batch = outputs["predictions"]
    stop_id = vocab[Vocab.END_TOKEN]

    if not isinstance(id_batch, np.ndarray):
        id_batch = id_batch.detach().cpu().numpy()

    decoded = []
    for row in id_batch:
        tokens = []
        for token_id in list(row):
            # Everything from the first end token onwards is dropped.
            if token_id == stop_id:
                break
            tokens.append(vocab.id2token[token_id])
        decoded.append(tokens)

    return decoded

In [12]:
# Validation instances store their sentence as a sequence of tokens; join them
# back into one string so create_inputs takes its raw-text path.
sentence =  ' '.join(dataset_val.instances[1]['sentence'])

In [13]:
sentence

'during the period of industrial growth from 1850 to 1950 , detroit ’s population grew dramatically .'

In [14]:
inputs = create_inputs(sentence)

In [15]:
outputs = model(inputs)

In [16]:
sentences = get_sentences(outputs)

In [17]:
' '.join(sentences[0])

'the the the the the the the the the the the the the the the the the the the the'

### Swap style

In [18]:
possible_styles = list(style_vocab.token2id.keys()) # ['kids', 'scholars'] for this experiment

In [19]:
possible_styles

['kids', 'scholars']

In [20]:
# Partition the validation instances by style label, following the order of
# `possible_styles`.
sentences0 = list(filter(lambda inst: inst['style'] == possible_styles[0], dataset_val.instances))
sentences1 = list(filter(lambda inst: inst['style'] == possible_styles[1], dataset_val.instances))

In [21]:
# Preview up to five distinct sentences of the first style, printed with their
# index so one can be pinned as a swap target later.
# replace=False stops the same index from being drawn twice (the default
# samples with replacement); the size is capped so the draw stays valid when
# fewer than five sentences are available.
for i in np.random.choice(np.arange(len(sentences0)), size=min(5, len(sentences0)), replace=False):
    print(i, ' '.join(sentences0[i]['sentence']))

2855 this means that in the extreme north and south , most winds and currents run eastward , while near
470 the west indies is a group of islands that stretches from near the u.s. state of florida to the
1466 factories in the metropolitan area produce metals , chemicals , and machinery .
5709 fewer than one - quarter of the islands are populated .
1175 nearly everyone is muslim .


In [35]:
# Preview up to five distinct sentences of the second style, printed with their
# index so one can be pinned as a swap target later.
# replace=False stops the same index from being drawn twice (the default
# samples with replacement); the size is capped so the draw stays valid when
# fewer than five sentences are available.
for i in np.random.choice(np.arange(len(sentences1)), size=min(5, len(sentences1)), replace=False):
    print(i, ' '.join(sentences1[i]['sentence']))

804 it had two lenses of identical focal length — one transmitting the image to the film and the other
354 after starring in neiboku sendai hagi ( “ the disputed succession ” ) , he adopted the dynastic name
2014 highsmith , who took her stepfather ’s name , graduated from barnard college , new york city , in
1322 cugat ’s bands included violins , maracas , and bongo and conga drums and featured dancers who demonstrated the
1725 caliban , a feral , sullen , misshapen creature in shakespeare ’s the tempest .


#### Swap

In [31]:
# Indices are pinned rather than re-drawn; the trailing comments show how to
# draw a random one instead. target0 indexes sentences0 and target1 indexes
# sentences1.
target0 = 2855 # np.random.choice(np.arange(len(sentences0)))
target1 = 804 # np.random.choice(np.arange(len(sentences1)))

In [36]:
print(' '.join(sentences0[target0]['sentence']))

this means that in the extreme north and south , most winds and currents run eastward , while near


In [37]:
print(' '.join(sentences1[target1]['sentence']))

most employment is related to the gaming and tourist industry .


In [38]:
# Batch the two chosen instances together: row 0 comes from the first style,
# row 1 from the second. These are dataset dicts, so create_inputs reads their
# existing 'sentence_enc' / 'style_enc' fields.
inputs = create_inputs([
    sentences0[target0],
    sentences1[target1],
])

In [39]:
# The forward-pass output is kept as the latent container: its 'style_hidden'
# and 'meaning_hidden' entries are inspected below and the whole dict is handed
# to model.decode.
z_hidden = model(inputs)

In [40]:
z_hidden['style_hidden'].shape

torch.Size([2, 32])

In [41]:
z_hidden['meaning_hidden'].shape

torch.Size([2, 32])

In [46]:
# Decode the unmodified latents first, as the baseline to compare the swapped
# decoding against.
original_decoded = model.decode(z_hidden)

In [47]:
original_sentences = get_sentences(original_decoded)

In [48]:
# Show both baseline reconstructions, one per line.
print(
    ' '.join(original_sentences[0]),
    ' '.join(original_sentences[1]),
    sep='\n',
)

the the the the the the the the the the the the the the the the the the the the
the the the the the the the the the the the the the the the the the the the the


In [63]:
# Build the swapped latents: meaning vectors keep their row order (0, 1) while
# the style vectors are exchanged between the two sentences (1, 0).
z_hidden_swapped = {
    'meaning_hidden': torch.stack(
        [z_hidden['meaning_hidden'][row].clone() for row in (0, 1)],
        dim=0,
    ),
    'style_hidden': torch.stack(
        [z_hidden['style_hidden'][row].clone() for row in (1, 0)],
        dim=0,
    ),
}

In [64]:
swaped_decoded = model.decode(z_hidden_swapped)

In [65]:
swaped_sentences = get_sentences(swaped_decoded)

In [66]:
# Baseline reconstructions first, a blank separator line, then the decodings
# produced from the style-swapped latents.
print(
    ' '.join(original_sentences[0]),
    ' '.join(original_sentences[1]),
    '',
    ' '.join(swaped_sentences[0]),
    ' '.join(swaped_sentences[1]),
    sep='\n',
)

the rice had hard things in it .
which is awesome !

plus is really hard to it .
the rice was awesome .
