# Homework and bake-off: pragmatic color descriptions

In [1]:
__author__ = "Christopher Leung"
__version__ = "CS224u, Stanford, Spring 2020"

## Set-up

See [colors_overview.ipynb](colors_overview.ipynb) for set-up in instructions and other background details.

In [2]:
from colors import ColorsCorpusReader
import os
from sklearn.model_selection import train_test_split
from torch_color_selector import (
    ColorizedNeuralListener, create_example_dataset)
from torch_color_describer import ColorizedInputDescriber
import utils
from utils import START_SYMBOL, END_SYMBOL, UNK_SYMBOL
import numpy as np
import torch

In [3]:
utils.fix_random_seeds()

In [4]:
COLORS_SRC_FILENAME = os.path.join(
    "data", "colors", "filteredCorpus.csv")

## All two-word examples as a dev corpus

So that you don't have to sit through excessively long training runs during development, I suggest working with the two-word-only subset of the corpus until you enter into the late stages of system testing.

In [5]:
dev_corpus = ColorsCorpusReader(
    COLORS_SRC_FILENAME, 
    word_count=None, 
    normalize_colors=True)

In [6]:
dev_examples = list(dev_corpus.read())

This subset has about one-third the examples of the full corpus:

In [7]:
len(dev_examples)

46994

We __should__ worry that it's not a fully representative sample. Most of the descriptions in the full corpus are shorter, and a large proportion are longer. So this dataset is mainly for debugging, development, and general hill-climbing. All findings should be validated on the full dataset at some point.

## Dev dataset

Let's load the saved training and test data.

In [8]:
def load_from_pickle():
    import pickle 
    
    with open('dev_vocab.pickle', 'rb') as handle:
        dev_vocab = pickle.load(handle)
    with open('dev_seqs_test.pickle', 'rb') as handle:
        dev_seqs_test = pickle.load(handle)
    with open('dev_seqs_train.pickle', 'rb') as handle:
        dev_seqs_train = pickle.load(handle)
    with open('dev_cols_test.pickle', 'rb') as handle:
        dev_cols_test = pickle.load(handle)
    with open('dev_cols_train.pickle', 'rb') as handle:
        dev_cols_train = pickle.load(handle)
    with open('embedding.pickle', 'rb') as handle:
        embedding = pickle.load(handle)
    return dev_vocab, dev_seqs_test, dev_seqs_train, dev_cols_test, dev_cols_train, embedding
dev_vocab, dev_seqs_test, dev_seqs_train, dev_cols_test, dev_cols_train, embedding = load_from_pickle()

At this point, our preprocessing steps are complete, and we can fit a first model.

## GloVe embeddings

We also load the GloVe embedding that was used by the speaker.

In [9]:
def load_glove_from_pickle():
    import pickle 
    with open('dev_glove_vocab.pickle', 'rb') as handle:
        dev_glove_vocab = pickle.load(handle)
    with open('dev_glove_embedding.pickle', 'rb') as handle:
        dev_glove_embedding = pickle.load(handle)
    return dev_glove_vocab, dev_glove_embedding
dev_glove_vocab, dev_glove_embedding = load_glove_from_pickle()

The above might dramatically change your vocabulary, depending on how many items from your vocab are in the Glove space:

## Load the Literal Listener

In [10]:
literal_listener = ColorizedNeuralListener(
    dev_vocab, 
    #embedding=dev_glove_embedding, 
    embed_dim=100,
    embedding=embedding,
    hidden_dim=100, 
    max_iter=100,
    batch_size=256,
    dropout_prob=0.,
    eta=0.001,
    lr_rate=0.96,
    warm_start=True,
    device='cuda')
literal_listener.load_model("literal_listener_with_attention.pt")

Using cuda


In [11]:
test_preds = literal_listener.predict(dev_cols_test, dev_seqs_test)
train_preds = literal_listener.predict(dev_cols_train, dev_seqs_train)

  color_seqs = torch.FloatTensor(color_seqs)


In [12]:
correct = sum([1 if x == 2 else 0 for x in test_preds])
print("test", correct, "/", len(test_preds), correct/len(test_preds))
correct = sum([1 if x == 2 else 0 for x in train_preds])
print("train", correct, "/", len(train_preds), correct/len(train_preds))

test 9405 / 11749 0.8004936590348115
train 31783 / 35245 0.901773301177472


## Load the Literal Speaker

In [13]:
literal_speaker = ColorizedInputDescriber(
    dev_glove_vocab, 
    embedding=dev_glove_embedding, 
    hidden_dim=100, 
    max_iter=40, 
    eta=0.005,
    batch_size=128)
literal_speaker.load_model("literal_speaker.pt")

Using cuda


In [14]:
literal_speaker.listener_accuracy(dev_cols_test, dev_seqs_test)

  color_seqs = torch.FloatTensor(color_seqs)
  perp = [np.prod(s)**(-1/len(s)) for s in scores]


0.8149629755723892

## Hallucinating Pragmatic Speaker

First we set up the Hallucinating Pragmatic Speaker

In [15]:
def generate_listener_hallucinations(input_colors, num_hallucinations=5, alpha=0.544, k_samples=10):
    '''This method generates listener hallucinations.
    Parameters
    ----------
    input_colors:
        A list of size (n,m,p) of int where each example has a list of m colors. Each color
        is embedded in size p.
    Returns
    -------
    prag_speaker_pred:
        (n,k_samples,*) The top sentences from the speaker that maximizes the likelihood 
        that the listener will choose the target color. Each sentence can be of different
        length and is tokenized.
    '''
    print("Sampling utterances")
    utterances = literal_speaker.sample_utterances(input_colors, k_samples=k_samples)
    
    print("Preparing Data")
    # Prepare data, flatten the target utterances and repeat the input colors per k_sample
    target_utterances = [seq for seq_list in utterances for seq in seq_list]
    input_colors_extended = [item for item in input_colors for i in range(k_samples)]
    
    print("Calculating probabilities")
    # utterance_preds = literal_listener.predict(input_colors_extended, target_utterances)
    utterance_probs = literal_listener.predict(input_colors_extended, target_utterances, probabilities=True)
    utterance_probs = torch.FloatTensor([preds[2] for preds in utterance_probs]).view(-1, k_samples)
    utterance_probs = utterance_probs ** alpha
    
    total = torch.sum(utterance_probs, dim=1).unsqueeze(1)
    normalized_utterance_probs = utterance_probs/total

    print("Finding top m utterances")
    # Find the best k number of utterances that maximize the listener likelihood
    best_utter_values, best_utter_indices = torch.topk(normalized_utterance_probs, num_hallucinations, dim=1)
    
    # DEPRECATED -Then flip the index number back.
    # prag_speaker_pred_ind = normalized_utterance_probs.shape[1] - best_utter_index - 1
    
    # Index into the utterances to find the sequence candidates
    prag_speaker_pred = [[seqs[utter_index] for utter_index in best_utter_indices[ind]] for ind, seqs in enumerate(utterances)]
    return prag_speaker_pred

Let's generate the input colors needed to predict for different candidate targets.

In [16]:
top_hallucinations = []
for col_partition in [dev_cols_train[:10000], dev_cols_train[10000:20000], dev_cols_train[20000:30000], dev_cols_train[30000:]]:
    torch.cuda.empty_cache()
    third_col_speaker_pred = generate_listener_hallucinations(col_partition, num_hallucinations=5, k_samples=8)
    top_hallucinations.append([seqs[0] for seqs in third_col_speaker_pred])

Sampling utterances


  color_seqs = torch.FloatTensor(color_seqs).to(self.device)


Preparing Data
Calculating probabilities
Finding top m utterances
Sampling utterances
Preparing Data
Calculating probabilities
Finding top m utterances
Sampling utterances
Preparing Data
Calculating probabilities
Finding top m utterances
Sampling utterances
Preparing Data
Calculating probabilities
Finding top m utterances


In [17]:
top_hallucinations = [seq for seqs in top_hallucinations for seq in seqs]
top_hallucinations[:5]

[['<s>', 'grey', '</s>'],
 ['<s>', 'green', '</s>'],
 ['<s>', 'red', '</s>'],
 ['<s>', 'dull', 'blue', '</s>'],
 ['<s>', 'dark', '+est', 'green', '+ish', '</s>']]

Let's do a quick test.

In [18]:
listened_preds = literal_listener.predict(dev_cols_train, top_hallucinations)
correct = sum([1 if x == 2 else 0 for x in listened_preds])
print("test", correct, "/", len(listened_preds), correct/len(listened_preds))

test 34960 / 35245 0.9919137466307277


How do these utterances perfectly capture the space of color differences? More needs to be done to examine this and is an excellent research direction.

One other thing we can do is to train the speaker on these hallucinations.

In [19]:
literal_speaker.warm_start = True
# We only reassign the optimizer, not the graph.
literal_speaker.opt = literal_speaker.optimizer(
                literal_speaker.model.parameters(),
                lr=literal_speaker.eta,
                weight_decay=literal_speaker.l2_strength)

In [20]:
literal_speaker.fit(dev_cols_train, top_hallucinations)

Epoch 40; err = 23.515129979699856

ColorizedInputDescriber(
	hidden_dim=100,
	batch_size=128,
	max_iter=40,
	eta=0.005,
	optimizer=<class 'torch.optim.adam.Adam'>,
	l2_strength=0,
	embed_dim=100,
	embedding=[[ 0.1394268  -0.47498924 -0.22497068 ...  0.02911435  0.47107838
   0.3607797 ]
 [ 0.38472     0.49351     0.49096    ...  0.026263    0.39052
   0.52217   ]
 [-0.66099    -0.073023    0.92379    ... -0.22556     0.8148
  -0.44052   ]
 ...
 [ 0.4765     -0.14409    -0.49884    ... -1.1854     -0.88582
  -0.57597   ]
 [-0.29881     0.81797     1.002      ... -0.23776    -0.90741
   0.55244   ]
 [ 0.38433493  0.12163181  0.07975889 ... -0.19281575  0.35057666
  -0.15122421]])

Let's see how it did.

In [21]:
speaker_preds_train = literal_speaker.predict(dev_cols_train)
listened_preds = literal_listener.predict(dev_cols_train, speaker_preds_train)
correct = sum([1 if x == 2 else 0 for x in listened_preds])
print("test", correct, "/", len(listened_preds), correct/len(listened_preds))

  color_seqs = torch.FloatTensor(color_seqs)


test 34241 / 35245 0.9715136898850901


In [22]:
speaker_preds_test = literal_speaker.predict(dev_cols_test)
listened_preds = literal_listener.predict(dev_cols_test, speaker_preds_test)
correct = sum([1 if x == 2 else 0 for x in listened_preds])
print("test", correct, "/", len(listened_preds), correct/len(listened_preds))

test 11322 / 11749 0.9636564814026726


## Reasoning about hallucinating speakers

We use Monroe et al.'s pragmatic speaker model based off RSA to construct the pragmatic hallucinating speaker.

First, we generate permutations of the color context and feed them into the hallucinating speaker. Then, we do a prediction with the literal listener to find the probabilities of the utterances of each target.

In [23]:
first_color_target = [[col_seq[2], col_seq[1],col_seq[0]] for col_seq in dev_cols_test]
second_color_target = [[col_seq[0], col_seq[2],col_seq[1]] for col_seq in dev_cols_test]

In [24]:
first_col_speaker_pred = generate_listener_hallucinations(first_color_target)
second_col_speaker_pred = generate_listener_hallucinations(second_color_target)
third_col_speaker_pred = generate_listener_hallucinations(dev_cols_test)

Sampling utterances
Preparing Data
Calculating probabilities
Finding top m utterances
Sampling utterances
Preparing Data
Calculating probabilities
Finding top m utterances
Sampling utterances
Preparing Data
Calculating probabilities


RuntimeError: CUDA out of memory. Tried to allocate 1.84 GiB (GPU 0; 8.00 GiB total capacity; 2.27 GiB already allocated; 1.03 GiB free; 5.21 GiB reserved in total by PyTorch)

In [None]:
listened_preds = literal_listener.predict(second_color_target, first_color_preds)
correct = sum([1 if x == 2 else 0 for x in listened_preds])
print("test", correct, "/", len(listened_preds), correct/len(listened_preds))

In [None]:
listened_preds = literal_listener.predict(first_color_target, second_color_preds)
correct = sum([1 if x == 2 else 0 for x in listened_preds])
print("test", correct, "/", len(listened_preds), correct/len(listened_preds))

Finally, we use this to normalize over the best hallucinating utterances and use bayesian inference to select the right target.

In [None]:
num_hallucinations = 5
alpha = 1.
def find_utt_marginal(color_list, speaker_preds, alpha=0.01):
    probs_per_col_context = []
    total_over_all_utt = 0
    for col_target, speaker_pred in [*zip(color_list, speaker_preds)]:
        # Flatten the utterances
        target_utterances = [seq for seq_list in speaker_pred for seq in seq_list]
        input_colors_extended = [item for item in col_target for i in range(num_hallucinations)]
        # Test all the utterances
        #print(len(target_utterances), len(input_colors_extended))
        lit_preds_per_col_context = torch.FloatTensor(literal_listener.predict(input_colors_extended, target_utterances, probabilities=True))
        # Reshape the utterances and only take the prediction of the target
        probs_per_col_context.append(lit_preds_per_col_context.view(-1, num_hallucinations, 3)[:, :, 2] ** alpha)
        total_over_all_utt += torch.sum(probs_per_col_context[-1], dim=1)
    return total_over_all_utt, probs_per_col_context
    
# Find both the marginal over utterances and the probability predictions per color context
color_list = [first_color_target, second_color_target, dev_cols_test]
speaker_preds = [first_col_speaker_pred, second_col_speaker_pred, third_col_speaker_pred]

total_over_all_utt, lit_preds_per_col_context = find_utt_marginal(color_list, speaker_preds, alpha)

In [None]:
test_preds_probs = []
# Now, we calculate the probabilities of the predictions
for ind, col_target, speaker_pred in [*zip(range(3), color_list, speaker_preds)]:
    test_pred = torch.FloatTensor(literal_listener.predict(col_target, dev_seqs_test, probabilities=True))
    test_pred = test_pred[:, 2] ** alpha
    
    target_prior = torch.sum(lit_preds_per_col_context[ind], dim=1)
    totals = total_over_all_utt + test_pred
    test_preds_probs.append(test_pred/target_prior)
test_preds_probs = torch.stack(test_preds_probs, dim=1)

In [None]:
test_preds = torch.argmax(test_preds_probs, dim=1)

In [None]:
correct = sum([1 if x == 2 else 0 for x in test_preds])
print("test", correct, "/", len(test_preds), correct/len(test_preds))

Let's examine the examples that were incorrect and see if we can analyze what happened.

In [None]:
limit = 200
for i, x in enumerate(test_preds):
    if x != 2:
        print(third_col_speaker_pred[i][0], dev_seqs_test[i], x, i)
        limit -= 1
    if limit == 0:
        break