In [1]:
%load_ext autoreload
%autoreload 2

import os
import json
import tabulate
from collections import Counter
from IPython.display import HTML, display

# Load wordstat json logs

In [3]:
# Enter the path to your wordstat_files directory here.
# os.listdir() and open() do not expand '~' themselves, so expand it explicitly.
models_dir = os.path.expanduser('~/ParlAI/data/controllable_dialogue/wordstat_files')
# Suffix shared by the file filter and the model-name extraction below, so a file
# that passes the filter can never make .index() raise ValueError.
wordstats_suffix = '.wordstats.json'
wordstat_files = [fname for fname in os.listdir(models_dir) if wordstats_suffix in fname]
mf2data = {} # master dict mapping model file name to its data dict

print('Loading %i files...' % len(wordstat_files), end='')
for idx, json_file in enumerate(sorted(wordstat_files)):
    # Model name = file name up to (not including) the wordstats suffix
    mf = json_file[:json_file.index(wordstats_suffix)]
    print('%i, ' % idx, end='')
    with open(os.path.join(models_dir, json_file), "r") as f:
        data = json.load(f)
    mf2data[mf] = data
print('\nFinished loading files')

Loading 54 files...0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 
Finished loading files


# Make table of automatic metrics

In [4]:
# This cell makes Table 6 from the paper

# Attribute names shown in the table, one column each, in display order.
# Each name is looked up in a model's data['sent_attrs'] dict when rows are built.
columns = [
    'extrep_2gram', 'extrep_nonstopword',
    'intrep_2gram', 'intrep_nonstopword',
    'partnerrep_2gram',
    'avg_nidf',
    'lastuttsim',
    'question',
]

# First table row: a label for the model-name column followed by the attribute names.
header_row = ['model name', *columns]

# Model file names to show in the table, one table row each, in display order.
# Every entry must be a key of mf2data, i.e. the name of a loaded wordstats file
# with its '.wordstats.json' suffix removed; the table-building loop indexes
# mf2data with each entry, so a name without a loaded file raises KeyError.
# NOTE(review): the "(WD)" / "(CT)" tags in the group comments presumably name the
# control method used in the paper — confirm the expansion against the paper.
rows = [
    # gold data and baselines
    'goldresponse',
    'convai2_finetuned_baseline.valid.usemodelreply.beam1',
    'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10',

    # repetition control (WD)
    'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-0.5',
    'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-1.25',
    'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5',
    'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-1e+20',
    'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
    
    # question control (CT)
    'control_questionb11e10.valid.usemodelreply.beam20.beamminnbest10.setcontrols:question0.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
    'control_questionb11e10.valid.usemodelreply.beam20.beamminnbest10.setcontrols:question1.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
    'control_questionb11e10.valid.usemodelreply.beam20.beamminnbest10.setcontrols:question4.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
    'control_questionb11e10.valid.usemodelreply.beam20.beamminnbest10.setcontrols:question7.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
    'control_questionb11e10.valid.usemodelreply.beam20.beamminnbest10.setcontrols:question10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
    'control_questionb11e10.valid.usemodelreply.beam20.beamminnbest10.setcontrols:question10.beamreorder_best_extrep2gram_qn.WDfeatures:extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',

    # specificity control (CT)
    'control_avgnidf10b10e.valid.usemodelreply.beam20.beamminnbest10.setcontrols:avg_nidf0.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
    'control_avgnidf10b10e.valid.usemodelreply.beam20.beamminnbest10.setcontrols:avg_nidf2.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
    'control_avgnidf10b10e.valid.usemodelreply.beam20.beamminnbest10.setcontrols:avg_nidf4.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
    'control_avgnidf10b10e.valid.usemodelreply.beam20.beamminnbest10.setcontrols:avg_nidf7.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
    'control_avgnidf10b10e.valid.usemodelreply.beam20.beamminnbest10.setcontrols:avg_nidf9.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',

    # specificity control (WD)
    'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20_nidf-10.0',
    'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20_nidf-4.0',
    'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20_nidf4.0',
    'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20_nidf6.0',
    'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20_nidf8.0',
    
    # response-related control (WD)
    'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_2gram-1e+20_intrep_nonstopword-1e+20_lastuttsim-10.0_partnerrep_2gram-1e+20',
    'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_2gram-1e+20_intrep_nonstopword-1e+20_lastuttsim0.0_partnerrep_2gram-1e+20',
    'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_2gram-1e+20_intrep_nonstopword-1e+20_lastuttsim5.0_partnerrep_2gram-1e+20',
    'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_2gram-1e+20_intrep_nonstopword-1e+20_lastuttsim10.0_partnerrep_2gram-1e+20',
    'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_2gram-1e+20_intrep_nonstopword-1e+20_lastuttsim13.0_partnerrep_2gram-1e+20',

]

def mean(l):
    """Return the arithmetic mean of the numbers in `l`.

    `l` must support both sum() and len(). An empty `l` raises
    ZeroDivisionError.
    """
    total = sum(l)
    count = len(l)
    return total / count

def model2row(mf, data):
    """Build one table row for a model: its name, then one formatted cell per column.

    For each attribute name in the module-level `columns` list, the values stored
    under data['sent_attrs'][attr] are averaged with `mean`. 'avg_nidf' and
    'lastuttsim' are rendered as plain numbers with four decimals; every other
    attribute is rendered as a percentage with two decimals. An attribute that is
    absent from data['sent_attrs'] produces an empty-string cell.
    """
    raw_value_attrs = ('avg_nidf', 'lastuttsim')

    def format_cell(attr):
        # Format the averaged value of a single attribute, or '' if it is missing.
        per_sentence = data['sent_attrs']
        if attr not in per_sentence:
            return ''
        average = mean(per_sentence[attr])
        if attr in raw_value_attrs:
            return "%.4f" % average
        return "%.2f%%" % (average * 100)

    return [mf] + [format_cell(attr) for attr in columns]

# Build table: the header row first, then one row per model file listed in `rows`.
table = [header_row] 
for mf in rows:
    # A KeyError here means no data was loaded into mf2data under this model name.
    data = mf2data[mf]
    table.append(model2row(mf, data))
# Render the table as HTML. tabulate is asked for stralign='center', and the
# "text-align: center;" style in the generated markup is then rewritten to
# "text-align: left;" before display.
# NOTE(review): presumably a workaround because stralign='left' alone does not
# left-align cells in the rendered notebook — confirm.
html = HTML(tabulate.tabulate(table, tablefmt='html', stralign='center'))
html.data = html.data.replace("text-align: center;", "text-align: left;") # fix left-alignment 
display(html)

0,1,2,3,4,5,6,7,8
model name,extrep_2gram,extrep_nonstopword,intrep_2gram,intrep_nonstopword,partnerrep_2gram,avg_nidf,lastuttsim,question
goldresponse,4.65%,9.62%,0.38%,0.97%,5.10%,0.2119,0.1691,28.80%
convai2_finetuned_baseline.valid.usemodelreply.beam1,35.88%,36.31%,8.08%,10.59%,12.20%,0.1688,0.1850,6.46%
convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10,46.85%,44.15%,0.32%,0.61%,12.90%,0.1662,0.0957,80.87%
convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-0.5,19.70%,16.85%,0.26%,0.62%,11.93%,0.1730,0.1348,73.04%
convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-1.25,4.62%,4.79%,0.40%,0.89%,10.61%,0.1763,0.1504,61.22%
convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5,0.75%,4.61%,0.47%,0.94%,9.89%,0.1771,0.1681,48.89%
convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-1e+20,0.00%,4.74%,0.51%,1.05%,9.56%,0.1780,0.1711,45.98%
convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20,0.73%,0.00%,0.17%,0.00%,9.55%,0.1766,0.1676,49.98%
control_questionb11e10.valid.usemodelreply.beam20.beamminnbest10.setcontrols:question0.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20,0.06%,0.00%,0.19%,0.00%,9.20%,0.1871,0.1753,2.01%


# Show predictions of a model

In [7]:
# Model whose predictions are shown by the show_preds call below; must be a key of mf2data.
mf = 'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10' # beam search baseline
num_show = 100  # Show the top 100 most common utterances

def show_preds(mf, num_show=None):
    """Print how many of a model's predicted utterances are unique, then its most common ones.

    Args:
        mf: model file name; must be a key of the module-level `mf2data` dict.
        num_show: number of most frequent utterances to list; None lists all of them.

    An utterance counts as unique when it occurs exactly once in the model's
    'pred_list'. An empty 'pred_list' is reported as 0.00% (0/0) instead of
    raising ZeroDivisionError.
    """
    preds = mf2data[mf]['word_statistics']['pred_list'] # this is the normalized version; use pure_pred_list for unnormalized
    counter = Counter(preds)
    num_preds = sum(counter.values())  # total utterances, duplicates included
    num_unique = sum(1 for count in counter.values() if count == 1)
    # Guard the percentage so a model with no predictions does not divide by zero
    pct_unique = num_unique * 100 / num_preds if num_preds else 0.0
    print("%% of utterances that are unique: %.2f%% (%i/%i)\n" % (pct_unique, num_unique, num_preds))
    print("COUNT   UTTERANCE")
    for p, count in counter.most_common(num_show):
        print("%5i   %s" % (count, p))

# Print the uniqueness summary and the num_show most common utterances for the chosen model.
show_preds(mf, num_show)

% of utterances that are unique: 14.77% (1152/7801)

COUNT   UTTERANCE
 2945   what city are you from
 1245   what do you do for living
  205   i am good how are you
  190   do you have any pets
  104   what kind of music do you like
   97   hi how are you today
   80   do you have any hobbies
   71   i am great how are you
   51   no i do not do you
   44   do you play any instruments
   42   what kind of music do you play
   41   hello how are you today
   34   i am in cali
   31   i am good thanks for asking
   30   i am stay at home mom
   29   that sounds like lot of fun
   28   i am well how are you
   27   i am doing well how are you
   27   what kind of work do you do
   26   what do you do for work
   25   what is your favorite food
   23   what is your favorite color
   22   i do not have any pets
   22   i am from los angeles
   21   what languages do you speak
   19   what kind of dog do you have
   19   what kind of dogs do you have
   19   what kind of car do you drive
  