In [110]:
from utils.load_results import *
from utils.plot_helpers import *

import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
plt.style.use('default')
import torch
from utils.analysis_from_interaction import *
from language_analysis_local import TopographicSimilarityConceptLevel, encode_target_concepts_for_topsim
import os
if not os.path.exists('analysis'):
    os.makedirs('analysis')
#import plotly.express as px
from collections import Counter

### Utilities

In [111]:
def objects_to_concepts(sender_input):
    """reconstruct concepts from objects in interaction"""
    n_targets = int(sender_input.shape[1]/2)
    # get target objects and fixed vectors to re-construct concepts
    target_objects = sender_input[:, :n_targets]
    target_objects = k_hot_to_attributes(target_objects, n_values[i])
    # concepts are defined by a list of target objects (here one sampled target object) and a fixed vector
    (objects, fixed) = retrieve_concepts_sampling(target_objects, all_targets=True)
    concepts = list(zip(objects, fixed))
    return concepts

In [112]:
def retrieve_messages(interaction):
    """retrieve messages from interaction"""
    messages = interaction.message.argmax(dim=-1)
    messages = [msg.tolist() for msg in messages]
    return messages

In [113]:
def count_symbols(messages):
    """counts symbols in messages"""
    all_symbols = [symbol for message in messages for symbol in message]
    symbol_counts = Counter(all_symbols)
    return symbol_counts

In [114]:
def get_unique_message_set(messages):
    """returns unique messages as a set ready for set operations"""
    return set(tuple(message) for message in messages)

In [115]:
def get_unique_concept_set(concepts):
    """returns unique concepts"""
    concept_tuples = []
    for objects, fixed in concepts:
        tuple_objects = []
        for object in objects:
            tuple_objects.append(tuple(object))
        tuple_objects = tuple(tuple_objects)
        tuple_concept = (tuple_objects, tuple(fixed))
        concept_tuples.append(tuple_concept)
    tuple(concept_tuples)
    unique_concepts = set(concept_tuples)
    return unique_concepts

### Configurations

In [116]:
datasets = ['(3,4)', '(3,8)', '(3,16)', '(4,4)', '(4,8)', '(5,4)']
n_values = [4, 8, 16, 4, 8, 4]
n_attributes = [3, 3, 3, 4, 4, 5]
n_epochs = 300
n_datasets = len(datasets)
paths = ['results/' + d + '_game_size_10_vsf_3' for d in datasets]

In [62]:
datasets = ['(3,4)', '(3,8)']
n_values = [4, 8]
n_attributes = [3, 3]
n_epochs = 300
n_datasets = len(datasets)
paths = ['results/' + d + '_game_size_10_vsf_3' for d in datasets]

In [138]:
context_unaware = False # whether original or context_unaware simulations are evaluated
zero_shot = True # whether zero-shot simulations are evaluated
zero_shot_test = 'generic' # 'generic' or 'specific'
test_interactions = True # whether scores should be calculated on test interactions (only with zero shot)
test_as = 'test_sampled_unscaled' # 'test' or 'test_sampled_unscaled' or 'test_unscaled' or 'test_fine' 
setting = ""
if context_unaware:
    setting = setting + 'context_unaware'
else:
    setting = setting + 'standard'
if zero_shot:
    setting = setting + '/zero_shot/' + zero_shot_test

### Determine vocab size and message reuse

In [125]:
# go through all datasets
for i, d in enumerate(datasets):
    print(d)
    for run in range(5):
        path_to_run = paths[i] + '/' + str(setting) +'/' + str(run) + '/'
        path_to_interaction_train = (path_to_run + 'interactions/train/epoch_' + str(n_epochs) + '/interaction_gpu0')
        path_to_interaction_val = (path_to_run + 'interactions/validation/epoch_' + str(n_epochs) + '/interaction_gpu0')
        path_to_interaction_test = (path_to_run + 'interactions/' + str(test_as) +'/epoch_0/interaction_gpu0')
        interaction_train = torch.load(path_to_interaction_train)
        interaction_val = torch.load(path_to_interaction_val)
        interaction_test = torch.load(path_to_interaction_test)
        
        concepts_train = objects_to_concepts(interaction_train.sender_input)
        concepts_val = objects_to_concepts(interaction_val.sender_input)
        concepts_test = objects_to_concepts(interaction_test.sender_input)
        
        messages_train = retrieve_messages(interaction_train)
        messages_val = retrieve_messages(interaction_val)
        messages_test = retrieve_messages(interaction_test)
    
        symbol_counts_train = count_symbols(messages_train)
        symbol_counts_val = count_symbols(messages_val)
        symbol_counts_test = count_symbols(messages_test)
        symbol_counts = [symbol_counts_train, symbol_counts_val, symbol_counts_test]
        pickle.dump(symbol_counts, open(path_to_run + 'symbol_counts_' + str(test_as) + '.pkl', 'wb'))
        
        # consider train and validation messages together
        messages_train_val = messages_train +  messages_val
        # consider only unique messages
        messages_train_val_unique = get_unique_message_set(messages_train_val)
        #print("messages train val", len(messages_train_val), len(messages_train_val_unique))
        messages_test_unique = get_unique_message_set(messages_test)
        #print("messages test", len(messages_test), len(messages_test_unique))
        # total messages
        messages_total = messages_train_val +  messages_test
        messages_total_unique = get_unique_message_set(messages_total)
        
        # concepts
        concepts_train_unique = get_unique_concept_set(concepts_train)
        concepts_val_unique = get_unique_concept_set(concepts_val)
        concepts_test_unique = get_unique_concept_set(concepts_test)
        #print("concepts", len(concepts_test), len(concepts_test_unique))
        concepts_total = concepts_train + concepts_val + concepts_test
        concepts_total_unique = get_unique_concept_set(concepts_total)
        num_of_concepts = [len(concepts_train_unique), len(concepts_val_unique), len(concepts_test_unique), len(concepts_total_unique), len(concepts_total)]
        pickle.dump(num_of_concepts, open(path_to_run + 'num_of_concepts_' + str(test_as) + '.pkl', 'wb'))
        
        # messages reused in testing:
        intersection = messages_train_val_unique & messages_test_unique
        
        # messages only used in training:
        difference_train = messages_train_val_unique - messages_test_unique
        
        # messages only used in testing:
        difference_test = messages_test_unique - messages_train_val_unique
        print(len(difference_test), "novel messages used for the", len(concepts_test_unique), "novel concepts")
        
        message_reuse = [len(intersection), len(difference_train), len(difference_test), len(concepts_test_unique), (len(difference_test)/len(concepts_test_unique)), len(messages_test_unique)]
        pickle.dump(message_reuse, open(path_to_run + 'message_reuse_' + str(test_as) + '.pkl', 'wb'))

(3,4)
31 novel messages used for the 64 novel concepts
19 novel messages used for the 64 novel concepts
19 novel messages used for the 64 novel concepts
28 novel messages used for the 64 novel concepts
28 novel messages used for the 64 novel concepts
(3,8)
218 novel messages used for the 512 novel concepts
180 novel messages used for the 512 novel concepts
130 novel messages used for the 512 novel concepts
239 novel messages used for the 512 novel concepts
195 novel messages used for the 512 novel concepts
(3,16)
590 novel messages used for the 4096 novel concepts
422 novel messages used for the 4096 novel concepts
520 novel messages used for the 4096 novel concepts
701 novel messages used for the 4096 novel concepts
400 novel messages used for the 4096 novel concepts
(4,4)
127 novel messages used for the 256 novel concepts
168 novel messages used for the 256 novel concepts
123 novel messages used for the 256 novel concepts
143 novel messages used for the 256 novel concepts
99 novel me

In [139]:
message_reuse_dict = {'intersection': [], 'difference train': [], 'difference test': [], 'concepts test unique': [], 'test ratio': [], 'messages test unique': [],
                      'reuse rate': [], 'novelty rate': [], 'total ratio': []}
for i, d in enumerate(datasets):
    intersection, train_difference, test_difference, test_concepts, test_ratio, test_messages, reuse_rate, novelty_rate, total_ratio = [], [], [], [], [], [], [], [], []
    for run in range(5):
        path_to_run = paths[i] + '/' + str(setting) +'/' + str(run) + '/'
        message_reuse = pickle.load(open(path_to_run + 'message_reuse_' + str(test_as) + '.pkl', 'rb'))
        intersection.append(message_reuse[0]) # messages reused in testing
        train_difference.append(message_reuse[1]) # messages only used in training
        test_difference.append(message_reuse[2]) # messages only used in testing
        test_concepts.append(message_reuse[3]) # total number of concepts used in testing
        test_ratio.append(message_reuse[4]) 
        test_messages.append(message_reuse[5]) # total number of messages used in testing
        reuse_rate.append(message_reuse[0]/message_reuse[5])
        novelty_rate.append(message_reuse[2]/message_reuse[5])
        total_ratio.append(message_reuse[5]/message_reuse[3]) # test_messages / test_concepts (novel unique messages & concepts)
    message_reuse_dict['intersection'].append(intersection)
    message_reuse_dict['difference train'].append(train_difference)
    message_reuse_dict['difference test'].append(test_difference)
    message_reuse_dict['concepts test unique'].append(test_concepts)
    message_reuse_dict['test ratio'].append(test_ratio)
    message_reuse_dict['messages test unique'].append(test_messages)
    message_reuse_dict['reuse rate'].append(reuse_rate)
    message_reuse_dict['novelty rate'].append(novelty_rate)
    message_reuse_dict['total ratio'].append(total_ratio)

In [140]:
message_reuse = [message_reuse_dict['concepts test unique'], message_reuse_dict['messages test unique'], message_reuse_dict['total ratio'], message_reuse_dict['reuse rate'], message_reuse_dict['novelty rate']]

# Convert the list to a NumPy array
mess_reuse_array = np.array(message_reuse)

# Compute means and standard deviations over the five runs
means = np.mean(mess_reuse_array, axis=-1)
std_devs = np.std(mess_reuse_array, axis=-1)

# Row names and column names
row_names = ["D(3,4)", "D(3,8)", "D(3,16)", "D(4,4)", "D(4,8)", "D(5,4)"]
col_names = ["test concepts", "test messages", "message-concept ratio", "reuse rate","novelty rate"]

# Prepare the data for the DataFrames
data = []

# iterate over datasets
for i in range(means.shape[1]):
    row = []
    # iterate over conditions
    for j in range(means.shape[0]):
        if j > 1:
            formatted_value = f"{means[j, i]:.2f} $\\pm$ {std_devs[j, i]:.2f}"
        elif j == 0:
            formatted_value = f"{int(means[j, i])}"
        else:
            formatted_value = f"{means[j, i]:.1f} $\\pm$ {std_devs[j, i]:.1f}"
        row.append(formatted_value)
    data.append(row)

# Create DataFrames
df = pd.DataFrame(data, index=row_names, columns=col_names)

# Convert DataFrames to LaTeX tables
latex_table = df.to_latex(index=True, escape=False)

print(latex_table)

\begin{tabular}{llllll}
\toprule
{} & test concepts &   test messages & message-concept ratio &       reuse rate &     novelty rate \\
\midrule
D(3,4)  &            12 &  11.8 $\pm$ 0.4 &       0.98 $\pm$ 0.03 &  0.80 $\pm$ 0.19 &  0.20 $\pm$ 0.19 \\
D(3,8)  &            24 &  22.2 $\pm$ 0.4 &       0.93 $\pm$ 0.02 &  0.83 $\pm$ 0.07 &  0.17 $\pm$ 0.07 \\
D(3,16) &            48 &  44.4 $\pm$ 0.8 &       0.93 $\pm$ 0.02 &  0.75 $\pm$ 0.06 &  0.25 $\pm$ 0.06 \\
D(4,4)  &            16 &  15.6 $\pm$ 0.8 &       0.97 $\pm$ 0.05 &  0.74 $\pm$ 0.14 &  0.26 $\pm$ 0.14 \\
D(4,8)  &            32 &  28.8 $\pm$ 3.1 &       0.90 $\pm$ 0.10 &  0.93 $\pm$ 0.06 &  0.07 $\pm$ 0.06 \\
D(5,4)  &            20 &  19.6 $\pm$ 0.5 &       0.98 $\pm$ 0.02 &  1.00 $\pm$ 0.00 &  0.00 $\pm$ 0.00 \\
\bottomrule
\end{tabular}


  latex_table = df.to_latex(index=True, escape=False)


### Symbol reuse
Also in "to generic" condition, all symbols are reused during testing, i.e. they all encode relevant information. This is why a qualitative analysis of messages makes more sense.

In [102]:
def symbol_frequency(interaction, n_attributes, n_values, vocab_size, is_gumbel=True):
    messages = interaction.message.argmax(dim=-1) if is_gumbel else interaction.message
    messages = messages[:, :-1] # without EOS
    sender_input = interaction.sender_input
    n_objects = sender_input.shape[1]
    n_targets = int(n_objects / 2)
    # k_hots = sender_input[:, :-n_attributes]
    # objects = k_hot_to_attributes(k_hots, n_values)
    target_objects = sender_input[:, :n_targets]
    target_objects = k_hot_to_attributes(target_objects, n_values)
    # intentions = sender_input[:, -n_attributes:]  # (0=same, 1=any)
    (objects, fixed) = retrieve_concepts_sampling(target_objects)

    objects[fixed == 1] = np.nan

    objects = objects
    messages = messages
    favorite_symbol = {}
    mutual_information = {}
    for att in range(n_attributes):
        for val in range(n_values):
            object_labels = (objects[:, att] == val).astype(int)
            max_MI = 0
            for symbol in range(vocab_size):
                symbol_indices = np.argwhere(messages == symbol)[0]
                symbol_labels = np.zeros(len(messages))
                symbol_labels[symbol_indices] = 1
                MI = normalized_mutual_info_score(symbol_labels, object_labels)
                if MI > max_MI:
                    max_MI = MI
                    max_symbol = symbol
            favorite_symbol[str(att) + str(val)] = max_symbol
            mutual_information[str(att) + str(val)] = max_MI

    return favorite_symbol, mutual_information

In [107]:
context_unaware = False # whether original or context_unaware simulations are evaluated
zero_shot = True # whether zero-shot simulations are evaluated
zero_shot_test = 'generic' # 'generic' or 'specific'
test_interactions = True # whether scores should be calculated on test interactions (only with zero shot)
setting = ""
if context_unaware:
    setting = setting + 'context_unaware'
else:
    setting = setting + 'standard'
if zero_shot:
    setting = setting + '/zero_shot/' + zero_shot_test

In [108]:
for run in range(5):
    path_to_run = paths[0] + '/' + str(setting) +'/' + str(run) + '/'
    path_to_interaction_train = (path_to_run + 'interactions/train/epoch_' + str(n_epochs) + '/interaction_gpu0')
    path_to_interaction_val = (path_to_run + 'interactions/validation/epoch_' + str(n_epochs) + '/interaction_gpu0')
    path_to_interaction_test = (path_to_run + 'interactions/test/epoch_0/interaction_gpu0')
    interaction_train = torch.load(path_to_interaction_train)
    interaction_val = torch.load(path_to_interaction_val)
    interaction_test = torch.load(path_to_interaction_test)
    
    # retrieve "lexicon" based on mutual information
    # hard-code for D(3,4) for now
    favorite_symbol, mutual_information = symbol_frequency(interaction_train, n_attributes=3, n_values=4, vocab_size=13)
    print(favorite_symbol)

    messages = interaction_test.message.argmax(dim=-1)
    messages = [msg.tolist() for msg in messages]
    sender_input = interaction_test.sender_input
    print(sender_input.shape)
    n_targets = int(sender_input.shape[1]/2)
    # get target objects and fixed vectors to re-construct concepts
    target_objects = sender_input[:, :n_targets]
    target_objects = k_hot_to_attributes(target_objects, n_values[i])
    # concepts are defined by a list of target objects (here one sampled target object) and a fixed vector
    (objects, fixed) = retrieve_concepts_sampling(target_objects, all_targets=True)
    concepts = list(zip(objects, fixed))

    # get distractor objects to re-construct context conditions
    distractor_objects = sender_input[:, n_targets:]
    distractor_objects = k_hot_to_attributes(distractor_objects, n_values[i])
    context_conds = retrieve_context_condition(objects, fixed, distractor_objects)

    # get random qualitative samples
    #fixed_index = random.randint(0, n_attributes[i]-1) # define a fixed index for the concept
    #n_fixed = random.randint(1, n_attributes[i]) # how many fixed attributes?
    n_fixed = 3
    #fixed_indices = random.sample(range(0, n_attributes[i]), k=n_fixed) # select which attributes are fixed
    fixed_indices = [0, 1, 2]
    #fixed_value = random.randint(0, n_values[i]-1) # define a fixed value for this index
    fixed_values = random.choices(range(0, n_values[i]), k=n_fixed)
    fixed_values = [3, 0, 1]
    print(n_fixed, fixed_indices, fixed_values)
    #index_threshold = 20000 # optional: define some index threshold to make sure that examples are not taken from the beginning of training
    # TODO: adapt this loop such that multiple indices can be fixed
    all_for_this_concept = []
    for idx, (t_objects, t_fixed) in enumerate(concepts):
        #if sum(t_fixed) == 1 and t_fixed[fixed_index] == 1:# and idx > index_threshold:
        if sum(t_fixed) == n_fixed and all(t_fixed[fixed_index] == 1 for fixed_index in fixed_indices):
            for t_object in t_objects:
                if all(t_object[fixed_index] == fixed_values[j] for j, fixed_index in enumerate(fixed_indices)):
                    all_for_this_concept.append((idx, t_object, t_fixed, context_conds[idx], messages[idx]))
                    fixed = t_fixed
    if len(all_for_this_concept) > 0:
        #sample = random.sample(all_for_this_concept, 20)
        sample = all_for_this_concept
        column_names = ['game_nr', 'object', 'fixed indices', 'context condition', 'message']
        df = pd.DataFrame(sample, columns=column_names)
        print(df)
        #df.to_csv('analysis/quali_' + str(d) + '_' + str(setting) + '_' + str(sample[0][1]) + ',' + str(fixed) + 'all.csv', index=False)
        #print('saved ' + 'analysis/quali_' + str(d) + '_' + str(setting) + '_' + str(sample[0][1]) + ',' + str(fixed) + 'all.csv')
    else:
        raise ValueError("sample for dataset " + str(d) + " could not be generated")

{'00': 11, '01': 11, '02': 11, '03': 11, '10': 2, '11': 12, '12': 7, '13': 2, '20': 8, '21': 8, '22': 8, '23': 8}
torch.Size([60, 20, 12])
3 [0, 1, 2] [3, 0, 1]


ValueError: sample for dataset (3,4) could not be generated