In [20]:
from collections import defaultdict
from statistics import mean, stdev

import numpy as np
import torch

from IPython.core.debugger import set_trace

In [18]:
def entropy_dict(freq_table):
    t = torch.tensor([v for v in freq_table.values()]).float()
    if (t < 0.0).any():
        raise RuntimeError("Encountered negative probabilities")

    t /= t.sum()
    return -(torch.where(t > 0, t.log(), t) * t).sum().item() / np.log(2)

def calc_entropy(messages):
    freq_table = defaultdict(float)
    for m in messages:
        m = _hashable_tensor(m)
        freq_table[m] += 1.0

    return entropy_dict(freq_table)

def _hashable_tensor(t):
    if torch.is_tensor(t) and t.numel() > 1:
        t = tuple(t.tolist())
    elif torch.is_tensor(t) and t.numel() == 1:
        t = t.item()
    elif isinstance(t, list):
        t = tuple(t)
    return t

In [31]:
# interaction_path = "/private/home/rdessi/contextual_emcomm/previous_exp/ctx/self_attn/interactions/interaction_53233351_0"

# DA model
#interaction_path = "/private/home/rdessi/contextual_emcomm/ctx/self_attn_diff_sender_mlp_same_recv_mlp_best_setup/interactions/interaction_53355809_3"
# TO Model
interaction_path = "/private/home/rdessi/contextual_emcomm/control_no_attn_double_separate_sender_same_recv_9distractors/interactions/interaction_53561113_1"
interaction = torch.load(interaction_path)

In [32]:
top_attn_score_right_guess = []
top_attn_score_wrong_guess = []
attn_weights = interaction.aux_input["attn_weights"].squeeze()[:, :, 0, ...]
receiver_output = interaction.receiver_output

errs = 0
for batch_id, batch in enumerate(interaction.aux_input["all_accs"]):
    for elem_id, batch_elem in enumerate(batch):
        for guess_id, guess_acc in enumerate(batch_elem):
            if interaction.aux_input["mask"][batch_id, elem_id, guess_id] == False:
                # skipping masked elements
                continue
            top_attn_score_idx = torch.argmax(attn_weights[batch_id, elem_id, guess_id], dim=-1)
            top_attn_score_value = attn_weights[batch_id, elem_id, guess_id][top_attn_score_idx].item()
            
            if guess_acc.item() == 0.:
                errs += 1
                correct_guess = guess_id
                recv_guess = receiver_output[batch_id, elem_id, guess_id]
                assert correct_guess != recv_guess.item()
                top_attn_score_wrong_guess.append(top_attn_score_value)
            
            else:
                top_attn_score_right_guess.append(top_attn_score_value)
                
print(f'acc = {100 - round(errs / interaction.aux_input["mask"].sum().item(), 4) * 100}%')

all_top_attn_scores = top_attn_score_right_guess + top_attn_score_wrong_guess
print(f"mean top attn score {mean(all_top_attn_scores):.4f} ± {stdev(all_top_attn_scores):.4f}")

print(f"mean top attn score right guess {mean(top_attn_score_right_guess):.4f} ± {stdev(top_attn_score_right_guess):.4f}")
print(f"mean top attn score wrong guess {mean(top_attn_score_wrong_guess):.4f} ± {stdev(top_attn_score_wrong_guess):.4f}")

AttributeError: 'NoneType' object has no attribute 'squeeze'

In [33]:
n_batches, bsz, max_objs = interaction.aux_input["mask"].shape
messages = interaction.message.view(n_batches, bsz, max_objs, 2, -1)
first_symbol = []
second_symbol = []
two_symbol_msg = []
total = 0
for batch_id, batch in enumerate(interaction.aux_input["all_accs"]):
    for elem_id, batch_elem in enumerate(batch):
        for guess_id, guess_acc in enumerate(batch_elem):
            if interaction.aux_input["mask"][batch_id, elem_id, guess_id] == False:
                # skipping masked elements
                continue
            total += 1
            msg_idx_first_symbol = torch.argmax(messages[batch_id, elem_id, guess_id][0]).item()
            first_symbol.append(msg_idx_first_symbol)
            
            msg_idx_second_symbol = torch.argmax(messages[batch_id, elem_id, guess_id][1]).item()
            second_symbol.append(msg_idx_second_symbol)
            
            two_symbol_msg.append((msg_idx_first_symbol, msg_idx_second_symbol))


assert len(first_symbol) == len(second_symbol)

print(f"Total number of objects in the test set {total}")
lexicon1 = set(first_symbol)
lexicon2 = set(second_symbol)
print(f"entropy symbol1: {calc_entropy(first_symbol):.4f}, max={np.log2(len(lexicon1)):.4f}")
print(f"entropy symbol2: {calc_entropy(second_symbol):.4f}, max={np.log2(len(lexicon2)):.4f}")
print(f"cardinality lexicon of symb1 {len(lexicon1)}")
print(f"cardinality lexicon of symb2 {len(lexicon2)}")

print(f"intersection of first and second symbol = {len(list(size_lexicon1 & size_lexicon2))} elems")

lexicon_two_symbol_msg = set(two_symbol_msg)
print(f"entropy two symbol msgs: {calc_entropy(lexicon_two_symbol_msg):.4f}, max={np.log2(len(lexicon_two_symbol_msg)):.4f}")
print(f"cardinality lexicon of two symbol msgs {len(lexicon_two_symbol_msg)}")

Total number of objects in the test set 39949
entropy symbol1: 3.8688, max=4.3219
entropy symbol2: 3.9605, max=4.4594
cardinality lexicon of symb1 20
cardinality lexicon of symb2 22
intersection of first and second symbol = 0 elems
entropy two symbol msgs: 8.2095, max=8.2095
cardinality lexicon of two symbol msgs 296
