In [1]:
from utils.load_results import *
from utils.plot_helpers import *

import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
plt.style.use('default')
import torch
from utils.analysis_from_interaction import *
from language_analysis_local import TopographicSimilarityConceptLevel, encode_target_concepts_for_topsim
import os
if not os.path.exists('analysis'):
    os.makedirs('analysis')
#import plotly.express as px

In [2]:
datasets = ('(3,4)', '(3,8)', '(3,16)', '(4,4)', '(4,8)', '(5,4)')
n_attributes = (3, 3, 3, 4, 4, 5)
n_values = (4, 8, 16, 4, 8, 4)
n_epochs = 300
paths = [f'results/one_hot_encoded_ds/{d}_game_size_10_vsf_3/' for d in datasets]

In [3]:
context_unaware = False # whether original or context_unaware simulations are evaluated
if context_unaware:
    setting = 'context_unaware'
else:
    setting = 'standard'

# Determine Unique Message size 

In [4]:
from collections import Counter

##### Total message size if symbol order matters:

In [5]:
# go through all datasets
for i, d in enumerate(datasets):
    path_to_run = paths[i] + '/' + str(setting) +'/' + str(0) + '/'
    path_to_interaction_train = (path_to_run + 'interactions/train/epoch_' + str(n_epochs) + '/interaction_gpu0')
    path_to_interaction_val = (path_to_run + 'interactions/validation/epoch_' + str(n_epochs) + '/interaction_gpu0')
    interaction = torch.load(path_to_interaction_train)
    print(path_to_interaction_train)
    
    messages = interaction.message.argmax(dim=-1)
    messages = [tuple(msg.tolist()) for msg in messages]  # Convert messages to tuples for hashing
    total_messages = set(messages)  # Set removes duplicate messages
    print("Number of total messages:", len(total_messages))

results/one_hot_encoded_ds/(3,4)_game_size_10_vsf_3//standard/0/interactions/train/epoch_300/interaction_gpu0
Number of total messages: 100


FileNotFoundError: [Errno 2] No such file or directory: 'results/one_hot_encoded_ds/(3,8)_game_size_10_vsf_3//standard/0/interactions/train/epoch_300/interaction_gpu0'

##### Unique message size if you consider messages with the same symbols but in different orders as the same message:

In [24]:
# go through all datasets
for i, d in enumerate(datasets):
    path_to_run = paths[i] + '/' + str(setting) +'/' + str(0) + '/'
    path_to_interaction_train = (path_to_run + 'interactions/train/epoch_' + str(n_epochs) + '/interaction_gpu0')
    path_to_interaction_val = (path_to_run + 'interactions/validation/epoch_' + str(n_epochs) + '/interaction_gpu0')
    interaction = torch.load(path_to_interaction_train)
    print(path_to_interaction_train)
    
    messages = interaction.message.argmax(dim=-1)
    
    # Convert messages to sorted tuples. Sorted orders ints in ascending order.
    sorted_messages = [tuple(sorted(msg.tolist())) for msg in messages]
    
    # Set removes duplicate messages
    unique_sorted_messages = set(sorted_messages)
    
    print("Number of unique sorted messages:", len(unique_sorted_messages))


results/backup/(3,4)_game_size_10_vsf_3/standard/0/interactions/train/epoch_300/interaction_gpu0
Number of unique sorted messages: 92


##### Comparison between unique sorted messages with the number of unique concepts:

In [30]:
sender_input = interaction.sender_input
n_targets = int(sender_input.shape[1]/2)
    
# Get target objects and fixed vectors to re-construct concepts
target_objects = sender_input[:, :n_targets]
target_objects = k_hot_to_attributes(target_objects, n_values[i])
    
# Concepts are defined by a list of target objects (here one sampled target object) and a fixed vector
(objects, fixed) = retrieve_concepts_sampling(target_objects, all_targets=True)
concepts = list(zip(objects, fixed))
    
# Convert concepts to strings
concepts_strings = [(str(obj), str(fixed_vec)) for obj, fixed_vec in concepts]
print(f'concepts strings: {concepts_strings}')

# Set removes duplicated concepts. Note: tuples are returned unordered  
unique_concepts = set(concepts_strings)
    
# Print the number of unique concepts
print("Number of unique concepts:", len(unique_concepts))
    
# Compare the number of unique sorted messages and unique concepts
print("Ratio of unique sorted messages to unique concepts:", len(unique_sorted_messages) / len(unique_concepts))

concepts strings: [('[[3. 1. 1.]\n [3. 1. 1.]\n [3. 1. 1.]\n [3. 1. 1.]\n [3. 1. 1.]\n [3. 1. 1.]\n [3. 1. 1.]\n [3. 1. 1.]\n [3. 1. 1.]\n [3. 1. 1.]]', '[1. 1. 1.]'), ('[[3. 1. 3.]\n [3. 3. 3.]\n [3. 1. 3.]\n [3. 0. 3.]\n [3. 3. 3.]\n [3. 3. 3.]\n [3. 2. 3.]\n [3. 0. 3.]\n [3. 2. 3.]\n [3. 1. 3.]]', '[1. 0. 1.]'), ('[[2. 2. 2.]\n [2. 2. 2.]\n [2. 2. 2.]\n [2. 2. 2.]\n [2. 2. 2.]\n [2. 2. 2.]\n [2. 2. 2.]\n [2. 2. 2.]\n [2. 2. 2.]\n [2. 2. 2.]]', '[1. 1. 1.]'), ('[[1. 2. 3.]\n [1. 2. 3.]\n [1. 2. 3.]\n [1. 2. 3.]\n [1. 2. 3.]\n [1. 2. 3.]\n [1. 2. 3.]\n [1. 2. 3.]\n [1. 2. 3.]\n [1. 2. 3.]]', '[1. 1. 1.]'), ('[[1. 0. 1.]\n [1. 0. 1.]\n [0. 0. 1.]\n [0. 0. 1.]\n [0. 0. 1.]\n [3. 0. 1.]\n [1. 0. 1.]\n [2. 0. 1.]\n [3. 0. 1.]\n [0. 0. 1.]]', '[0. 1. 1.]'), ('[[0. 0. 1.]\n [0. 0. 1.]\n [0. 0. 1.]\n [0. 0. 1.]\n [0. 0. 1.]\n [0. 0. 1.]\n [0. 0. 1.]\n [0. 0. 1.]\n [0. 0. 1.]\n [0. 0. 1.]]', '[1. 1. 1.]'), ('[[3. 0. 3.]\n [3. 0. 3.]\n [3. 0. 3.]\n [3. 0. 3.]\n [3. 0. 3.]\n [3. 0. 3.]\n [3. 0.