In [1]:
import pickle
from utils.analysis_from_interaction import *
from egg.core.language_analysis import Disent
from language_analysis_local import TopographicSimilarityConceptLevel, encode_target_concepts_for_topsim
import os

# calculate metrics from stored interactions

In [2]:
# Experiment grid. Position i of every sequence below describes the same dataset;
# the dataset labels are the (n_attributes, n_values) pairs listed underneath.
datasets = ('(3,4)', '(3,8)', '(3,16)', '(4,4)', '(4,8)', '(5,4)')
n_attributes = (3, 3, 3, 4, 4, 5)
n_values = (4, 8, 16, 4, 8, 4)
epochs = 300
n_runs = 5
# vsf selects the per-dataset vocabulary sizes and the 'vsf_<n>' results folder.
# Only 0 and 3 are supported.
vsf = 0
max_mess_lens = (20, 20, 20, 20, 20, 20)
if vsf == 3:
    # NOTE(review): the last entry (52) differs from the other datasets with
    # n_values == 4 (16) -- confirm it is intended for '(5,4)'.
    vocab_sizes = [16, 28, 52, 16, 28, 52]
elif vsf == 0:
    vocab_sizes = [5, 9, 17, 5, 9, 5]
else:
    # Fail here instead of with a NameError on vocab_sizes/paths in a later cell.
    raise ValueError('unsupported vsf: ' + str(vsf) + ' (expected 0 or 3)')
# One results folder per dataset, e.g. 'results/(3,4)_game_size_10_vsf_0/'.
paths = ['results/' + d + '_game_size_10_vsf_' + str(vsf) + '/' for d in datasets]

In [3]:
# --- evaluation switches -----------------------------------------------------
context_unaware = False  # evaluate the context_unaware simulations instead of the original ones
zero_shot = False  # evaluate the zero-shot simulations
zero_shot_test = 'specific'  # either 'generic' or 'specific'
test_interactions = True  # compute the scores on test interactions
test_mode = 'test'  # one of 'test', 'test_fine', 'test_sampled_unscaled', 'test_load_train'
length_cost = True  # a length cost was applied; those runs were run with early stopping
early_stopping = True  # only with length cost and sampled context
rsa = False
rsa_test = 'testtrainmixed'  # 'testtrain': RSA tests on the test dataset, utterances sampled from training interactions
sampled_context = False
hierarchical = False
shared_context = True

# is_gumbel is switched off whenever RSA or test interactions are evaluated.
is_gumbel = not (rsa or test_interactions)

# --- sub-folder of the results directory that matches the switches above ------
setting = 'length_cost/' if length_cost else ''
if context_unaware:
    setting += 'context_unaware'
elif length_cost:
    setting += 'context_aware'
else:
    setting += 'standard'
if hierarchical:
    setting += '/hierarchical'
if shared_context:
    setting += '/shared_context'
# zero-shot takes precedence over sampled context
if zero_shot:
    setting += '/zero_shot/' + zero_shot_test
elif sampled_context:
    setting += '/sampled_context'

In [4]:
# Final training epoch of every run, grouped per dataset:
# n_epochs_all_data[d][run]. With early stopping it is the largest epoch key
# found under 'loss_train' in the run's loss_and_metrics.pkl; otherwise every
# run is assumed to have trained for the full `epochs`.
n_epochs_all_data = []
for d in range(len(datasets)):
    if early_stopping:
        n_epochs = []
        for run in range(n_runs):
            path_to_run = paths[d] + str(setting) + '/' + str(run) + '/'
            metrics_file = os.path.join(path_to_run, 'loss_and_metrics.pkl')
            with open(metrics_file, 'rb') as input_file:
                data = pickle.load(input_file)
            n_epochs.append(max(data['loss_train'].keys()))
    else:
        n_epochs = [epochs] * n_runs
    n_epochs_all_data.append(n_epochs)
            

## entropy scores: MI, effectiveness, consistency

In [7]:
# Compute the entropy-based scores for every dataset and run and pickle them
# into the run folder. The interaction that is read and the name of the output
# file both depend on the rsa / test_interactions switches.
# `torch` is not imported explicitly in this notebook; it is presumably provided
# by the star import from utils.analysis_from_interaction.
for d in range(len(datasets)):

    n_epochs = n_epochs_all_data[d]

    for run in range(0, n_runs):

        path_to_run = paths[d] + str(setting) + '/' + str(run) + '/'
        if rsa:
            path_to_interaction = (path_to_run + 'interactions/rsa_' + rsa_test + '/epoch_0/interaction_gpu0')
            path_to_scores = path_to_run + 'entropy_scores_rsa_' + rsa_test + '.pkl'
        elif test_interactions:
            path_to_interaction = (path_to_run + 'interactions/' + test_mode + '/epoch_0/interaction_gpu0')
            path_to_scores = path_to_run + 'entropy_scores_' + test_mode + '.pkl'
        else:
            # training interactions are stored per epoch; use the run's final epoch
            path_to_interaction = (path_to_run + 'interactions/train/epoch_' + str(n_epochs[run]) + '/interaction_gpu0')
            path_to_scores = path_to_run + 'entropy_scores.pkl'

        interaction = torch.load(path_to_interaction)

        attributes = n_attributes[d]
        values = n_values[d]
        # NOTE(review): max_mess_len is hard-coded to 21 although max_mess_lens
        # lists 20 for every dataset -- presumably max length plus EOS; confirm.
        scores = information_scores(interaction, attributes, values, normalizer="arithmetic", is_gumbel=is_gumbel, trim_eos=True, max_mess_len=21)

        # the context manager closes the file handle; the bare open() leaked it
        with open(path_to_scores, 'wb') as output_file:
            pickle.dump(scores, output_file)

  (m_entropy_concept_x_context + c_entropy_concept_x_context - joint_entropy_concept_x_context)
  normalized_effectiveness_conc_x_cont = ((joint_entropy_concept_x_context - m_entropy_concept_x_context)
  normalized_consistency_conc_x_cont = (joint_entropy_concept_x_context - c_entropy_concept_x_context) / m_entropy_concept_x_context


{'normalized_mutual_info': 0.5873561878008984, 'normalized_mutual_info_hierarchical': array([0.6353783 , 0.67341312, 0.55749098]), 'normalized_mutual_info_context_dep': array([0.78058389, 0.68382368, 0.6574306 ]), 'normalized_mutual_info_concept_x_context': array([0.6353783 ,        nan,        nan, 0.86464254, 0.67001922,
              nan, 0.80223225, 0.68372308, 0.6574306 ]), 'effectiveness': 0.5890766296843117, 'effectiveness_hierarchical': array([1.        , 0.82885426, 0.58970942]), 'effectiveness_context_dep': array([0.76074482, 0.69293224, 0.65699975]), 'effectiveness_concept_x_context': array([1.        ,        nan,        nan, 0.89338715, 0.81214968,
              nan, 0.827904  , 0.69229958, 0.65699975]), 'consistency': 0.5856457660087032, 'consistency_hierarchical': array([0.46560765, 0.56706688, 0.52861065]), 'consistency_context_dep': array([0.80148542, 0.67495147, 0.65786203]), 'consistency_concept_x_context': array([0.46560765,        nan,        nan, 0.83768997, 0.570

  normalized_MI_hierarchical = ((m_entropy_hierarchical + c_entropy_hierarchical - joint_entropy_hierarchical)
  normalized_effectiveness_hierarchical = ((joint_entropy_hierarchical - m_entropy_hierarchical)
  normalized_consistency_hierarchical = (joint_entropy_hierarchical - c_entropy_hierarchical) / m_entropy_hierarchical


{'normalized_mutual_info': 0.5534947082113106, 'normalized_mutual_info_hierarchical': array([0.88942926, 0.73425305, 0.58699048, 0.58063811]), 'normalized_mutual_info_context_dep': array([0.74646318, 0.66419329, 0.62971255, 0.64514988]), 'normalized_mutual_info_concept_x_context': array([0.88942926,        nan,        nan,        nan, 0.88096653,
       0.70786367,        nan,        nan, 0.77450666, 0.69581428,
       0.63944198,        nan, 0.82218651, 0.74939521, 0.69413558,
       0.64514988]), 'effectiveness': 0.47477134438401747, 'effectiveness_hierarchical': array([0.95835817, 0.75947295, 0.53919781, 0.5477809 ]), 'effectiveness_context_dep': array([0.63882862, 0.58085242, 0.53483312, 0.57451328]), 'effectiveness_concept_x_context': array([0.95835817,        nan,        nan,        nan, 0.86143425,
       0.7143569 ,        nan,        nan, 0.69752844, 0.64350079,
       0.56231875,        nan, 0.7556636 , 0.71510529, 0.63099647,
       0.57451328]), 'consistency': 0.66351411856

##  message length

In [203]:
# we evaluated message length per hierarchy level after training but
# you can also use the HierarchicalMessageLength callback and store the results
#
# For every dataset and run three pickles are written into the run folder:
# the overall message length, the message length per hierarchy level, and the
# message lengths per context condition. The file-name suffix encodes which
# interactions (train / test_mode / rsa) were evaluated.

for d in range(len(datasets)):

    n_epochs = n_epochs_all_data[d]

    for run in range(0, n_runs):

        path_to_run = paths[d] + str(setting) + '/' + str(run) + '/'
        if rsa:
            path_to_interaction = (path_to_run + 'interactions/rsa_' + rsa_test + '/epoch_0' + '/interaction_gpu0')
            suffix = '_rsa_' + rsa_test
        elif test_interactions:
            path_to_interaction = (path_to_run + 'interactions/' + test_mode + '/epoch_0/interaction_gpu0')
            suffix = '_' + test_mode
        else:
            # training interactions are stored per epoch; use the run's final epoch
            path_to_interaction = (path_to_run + 'interactions/train/epoch_' + str(n_epochs[run]) + '/interaction_gpu0')
            suffix = ''

        interaction = torch.load(path_to_interaction)

        attributes = n_attributes[d]
        ml, ml_concept = message_length_per_hierarchy_level(interaction, attributes)
        ml_context, ml_fine_context, ml_coarse_context = message_length_per_context_condition(interaction, attributes)
        scores = {'ml_over_context': ml_context, 'ml_fine_context': ml_fine_context, 'ml_coarse_context': ml_coarse_context}

        # write each result with a context manager so the file handles are closed
        for result, file_stem in ((ml, 'message_length'),
                                  (ml_concept, 'message_length_hierarchical'),
                                  (scores, 'message_length_over_context')):
            with open(path_to_run + file_stem + suffix + '.pkl', 'wb') as output_file:
                pickle.dump(result, output_file)

## lexicon properties

In [209]:
# Compute the lexicon properties (informativeness, lexicon size, number of
# concepts) for every dataset and run, print them, and pickle them into the
# run folder. The distance name is part of the output file name.
# NOTE(review): the stored output of this cell shows 'lexicon informativeness'
# as nan for '(3,4)' and the run was interrupted on '(3,8)' -- confirm
# informativeness_score before relying on these files.
distance = 'manhattan' # 'manhattan' or 'euclidean'
for d in range(len(datasets)):
    print(datasets[d])

    n_epochs = n_epochs_all_data[d]

    for run in range(n_runs):

        path_to_run = paths[d] + str(setting) + '/' + str(run) + '/'
        if rsa:
            path_to_interaction = (path_to_run + 'interactions/rsa_' + rsa_test + '/epoch_0/interaction_gpu0')
            path_to_scores = path_to_run + 'lexicon_properties_' + distance + '_rsa_' + rsa_test + '.pkl'
        elif test_interactions:
            path_to_interaction = (path_to_run + 'interactions/' + test_mode + '/epoch_0/interaction_gpu0')
            path_to_scores = path_to_run + 'lexicon_properties_' + distance + '_' + test_mode + '.pkl'
        else:
            # training interactions are stored per epoch; use the run's final epoch
            path_to_interaction = (path_to_run + 'interactions/train/epoch_' + str(n_epochs[run]) + '/interaction_gpu0')
            path_to_scores = path_to_run + 'lexicon_properties_' + distance + '.pkl'

        interaction = torch.load(path_to_interaction)

        lex_info, unique_messages, num_concepts = informativeness_score(interaction, distance=distance)
        scores = {'lexicon informativeness': lex_info, 'lexicon size': unique_messages, 'number of concepts': num_concepts}
        print(scores)

        # the context manager closes the file handle; the bare open() leaked it
        with open(path_to_scores, 'wb') as output_file:
            pickle.dump(scores, output_file)

(3,4)
{'lexicon informativeness': nan, 'lexicon size': 742, 'number of concepts': 742}
{'lexicon informativeness': nan, 'lexicon size': 742, 'number of concepts': 742}
{'lexicon informativeness': nan, 'lexicon size': 742, 'number of concepts': 742}
{'lexicon informativeness': nan, 'lexicon size': 742, 'number of concepts': 742}
{'lexicon informativeness': nan, 'lexicon size': 742, 'number of concepts': 742}
(3,8)


KeyboardInterrupt: 

##  symbol redundancy

In [11]:
# Compute symbol redundancy and the symbol / attribute-value mutual information
# for every dataset and run, and pickle both into the run folder. The output
# file name encodes which interactions (train / test_mode / rsa) were evaluated.
for d in range(len(datasets)):

    n_epochs = n_epochs_all_data[d]

    # per-dataset parameters, shared by all runs of this dataset
    attributes = n_attributes[d]
    values = n_values[d]
    vocab_size = vocab_sizes[d]
    max_mess_len = max_mess_lens[d]

    for run in range(n_runs):

        path_to_run = paths[d] + str(setting) + '/' + str(run) + '/'
        if rsa:
            path_to_interaction = (path_to_run + 'interactions/rsa_' + rsa_test + '/epoch_0/interaction_gpu0')
            path_to_scores = path_to_run + 'symbol_redundancy_rsa_' + rsa_test + '.pkl'
        elif test_interactions:
            path_to_interaction = (path_to_run + 'interactions/' + test_mode + '/epoch_0/interaction_gpu0')
            path_to_scores = path_to_run + 'symbol_redundancy_' + test_mode + '.pkl'
        else:
            # training interactions are stored per epoch; use the run's final epoch
            path_to_interaction = (path_to_run + 'interactions/train/epoch_' + str(n_epochs[run]) + '/interaction_gpu0')
            path_to_scores = path_to_run + 'symbol_redundancy.pkl'
        interaction = torch.load(path_to_interaction)
        redundancy, MI = symbol_frequency(interaction, attributes, values, vocab_size, max_mess_len, is_gumbel, trim_eos=True)

        scores = {'symbol_redundancy': redundancy, 'MI_symbol-attribute_value': MI}

        # the context manager closes the file handle; the bare open() leaked it
        with open(path_to_scores, 'wb') as output_file:
            pickle.dump(scores, output_file)

  return symbol_frequency / att_val_frequency, mutual_information
  return symbol_frequency / att_val_frequency, mutual_information
  return symbol_frequency / att_val_frequency, mutual_information
  return symbol_frequency / att_val_frequency, mutual_information
  return symbol_frequency / att_val_frequency, mutual_information
  return symbol_frequency / att_val_frequency, mutual_information
  return symbol_frequency / att_val_frequency, mutual_information
  return symbol_frequency / att_val_frequency, mutual_information
  return symbol_frequency / att_val_frequency, mutual_information
  return symbol_frequency / att_val_frequency, mutual_information
  return symbol_frequency / att_val_frequency, mutual_information
  return symbol_frequency / att_val_frequency, mutual_information
  return symbol_frequency / att_val_frequency, mutual_information
  return symbol_frequency / att_val_frequency, mutual_information
  return symbol_frequency / att_val_frequency, mutual_information
