In [1]:
import gzip
import json
import numpy as np
try:
    # scipy.misc.logsumexp was deprecated in SciPy 1.0 and removed in 1.3;
    # scipy.special is its current home.
    from scipy.special import logsumexp
except ImportError:
    from scipy.misc import logsumexp
import matplotlib.pyplot as plt

In [26]:
import os
# Directory of the experiment run; the input and output paths in the cells below
# are all joined onto it.
RUN_DIR = 'runs/hawkins_big_l2_perp_heldout'

In [2]:
# Read the gzip-compressed JSON-lines file: one JSON record per line.
grids_path = os.path.join(RUN_DIR, 'grids.0.jsons.gz')
grids = []
with gzip.open(grids_path, 'rb') as infile:
    for raw_line in infile:
        grids.append(json.loads(raw_line.strip()))

In [7]:
# Multiplier applied to the L0 log-scores; passed as `alpha` to compute_s1 below.
# NOTE(review): where 0.544 comes from is not shown here -- presumably tuned elsewhere; confirm.
ALPHA = 0.544

In [3]:
# Number of records read from the grids file.
len(grids)

16635

In [11]:
# Read the evaluation instances (JSON lines) and pull out their 'output' field.
data_path = os.path.join(RUN_DIR, 'data.eval.jsons')
insts = []
with open(data_path, 'r') as infile:
    for line in infile:
        insts.append(json.loads(line.strip()))
# Keep only as many outputs as there are grid records so the two line up by position.
num_grids = len(grids)
gold_outputs = np.array([record['output'] for record in insts])[:num_grids]

In [4]:
# Stack the transposed 'L0' matrix of every sample set of every grid record into
# one array: axis 0 = grid record, axis 1 = sample set, then the transposed matrix.
l0_biggrid = np.array([
    [np.transpose(np.array(sample_set['L0'])) for sample_set in record['sets']]
    for record in grids
])

In [5]:
# Length of axis 1, i.e. the number of entries taken from each record's 'sets' list.
all_ss = l0_biggrid.shape[1]
# Display the full shape as the cell output.
l0_biggrid.shape

(16635, 8, 25, 3)

In [6]:
def compute_s1(l0, alpha, axis=2):
    """Scale log-scores by ``alpha`` and renormalise them in log space.

    Parameters
    ----------
    l0 : array_like
        Log-scores to rescale (presumably L0 listener log-probabilities --
        confirm against the code that produced the grids).
    alpha : float
        Multiplier applied to the log-scores before normalising.
    axis : int, optional
        Axis along which the result is normalised. Defaults to 2, the axis
        that was previously hard-coded, so existing calls are unaffected.

    Returns
    -------
    numpy.ndarray
        ``l0 * alpha`` minus its log-sum-exp along ``axis``; ``exp`` of the
        result therefore sums to 1 along ``axis``.
    """
    # asarray lets plain nested lists be passed as well as ndarrays.
    scaled = np.asarray(l0) * alpha
    # keepdims=True keeps the reduced axis so the subtraction broadcasts.
    return scaled - logsumexp(scaled, axis=axis, keepdims=True)

In [8]:
# Rescale the stacked L0 log-scores by ALPHA and renormalise them (see compute_s1).
s1 = compute_s1(l0_biggrid, alpha=ALPHA)

In [22]:
# Random source used by sample(). No seed is set in the cells shown here, so
# draws differ from run to run.
rng = np.random


def sample(a, temperature=1.0):
    """Sample an index from an array of (possibly unnormalised) log probabilities.

    Parameters
    ----------
    a : array_like
        Log probabilities. A 1-D input is treated as one distribution; for
        higher-rank input each sub-array along the first axis is sampled
        recursively.
    temperature : float, optional
        The log probabilities are divided by this value before sampling.

    Returns
    -------
    numpy integer or numpy.ndarray
        The sampled index for 1-D input, otherwise an array of sampled indices.

    Raises
    ------
    ValueError
        If ``a`` is a scalar, or a distribution is empty or has no finite
        maximum (all ``-inf``, or containing ``nan`` / ``+inf``).
    """
    a = np.array(a)
    if len(a.shape) < 1:
        raise ValueError('scalar is not a valid probability distribution')
    elif len(a.shape) == 1:
        if a.size == 0:
            raise ValueError('cannot sample from an empty distribution')
        # Promote to float64 *before* exponentiating and shift by the maximum
        # (softmax is shift-invariant). Without this, very negative scores
        # underflow to 0 and large ones overflow to inf, giving 0/0 or inf/inf
        # = nan probabilities that multinomial rejects.
        logits = a.astype(np.float64) / temperature
        peak = np.max(logits)
        if not np.isfinite(peak):
            raise ValueError('log probabilities must have a finite maximum')
        probs = np.exp(logits - peak)
        probs /= np.sum(probs)
        # One multinomial trial yields a one-hot row; argmax recovers the index.
        return np.argmax(rng.multinomial(1, probs, 1))
    else:
        return np.array([sample(s, temperature) for s in a])

In [34]:
# Get the log probs of the speaker from sample set 0 (arbitrarily), for the true targets.
# Index breakdown: row i, sample set 0, entries 1: along axis 2 (entry 0 is dropped),
# and position gold_outputs[i] on the last axis. The row index and gold_outputs are
# paired element-wise, so the result has one row of scores per instance.
# NOTE(review): why entry 0 of axis 2 is excluded is not visible here -- confirm.
s1_true_probs = s1[np.arange(s1.shape[0]), 0, 1:, gold_outputs]
s1_true_probs.shape

(16635, 24)

In [35]:
# Collect the 'utts' list of every sample set of every grid record into one array
# (axis 0 = grid record, axis 1 = sample set, axis 2 = utterance).
s1_utts_grid = np.array([
    [np.array(sample_set['utts']) for sample_set in record['sets']]
    for record in grids
])
s1_utts_grid.shape

(16635, 8, 25)

In [36]:
# Keep only sample set 0 and drop entry 0 along the utterance axis -- the same
# selection used when indexing s1 for s1_true_probs, so columns line up.
# This rebinds the name from a 3-D to a 2-D array; the ndim guard makes the cell
# safe to re-run, since slicing the 2-D result again would raise IndexError.
if s1_utts_grid.ndim == 3:
    s1_utts_grid = s1_utts_grid[:, 0, 1:]
s1_utts_grid.shape

(16635, 24)

In [37]:
# Peek at the utterance strings kept for the first instance.
s1_utts_grid[0]

array([u'khaki', u'apple', u'orange', u'yellow', u'green',
       u'the brown one', u'the more yellow ish', u'yellow', u'red', u'red',
       u'red but pink', u'redish', u'red no rose', u'red', u'the pumpkin',
       u'red', u'grey', u'stormy', u'gray', u'grey', u'concrete', u'grey',
       u'slate', u'grey'], 
      dtype='<U373')

In [38]:
# One row index per instance, shared by both selections below.
row_idx = np.arange(s1_utts_grid.shape[0])

# Prediction: the utterance with the highest speaker log-probability per instance.
best_col = np.argmax(s1_true_probs, axis=1)
s1_preds = s1_utts_grid[row_idx, best_col]
print(len(s1_preds))

# Sample: one utterance per instance, drawn in proportion to the speaker scores.
sampled_col = sample(s1_true_probs)
s1_samples = s1_utts_grid[row_idx, sampled_col]
print(len(s1_samples))

16635
16635


In [39]:
# Write predictions and samples as JSON lines (one JSON-encoded utterance per line).
outputs = (
    ('s1_predictions_from_grids.0.jsons', s1_preds),
    ('s1_samples_from_grids.0.jsons', s1_samples),
)
for filename, utterances in outputs:
    with open(os.path.join(RUN_DIR, filename), 'w') as outfile:
        outfile.writelines(json.dumps(utt) + '\n' for utt in utterances)