In [1]:
import random
from collections import Counter, defaultdict
# from statistics import mean, stdev

import numpy as np
import torch

from egg.core.language_analysis import calc_entropy
from egg.zoo.referential_language.games import build_game

from IPython.core.debugger import set_trace

In [15]:
def load_model(model_args, checkpoint_path):
    model = build_game(model_args)
    model_ckpt = torch.load(checkpoint_path)
    model.load_state_dict(model_ckpt.model_state_dict)
    print(f"| Successfully loaded model from {checkpoint_path}")
    return model

# DA
checkpoint_path = "/private/home/rdessi/contextual_emcomm/da_model/53845331_0/final.tar"
interaction_path = "/private/home/rdessi/contextual_emcomm/da_model/interactions/interaction_53845331_0"
# TO
#checkpoint_path = "/private/home/rdessi/contextual_emcomm/to_model/53820379_0/final.tar"
#interaction_path = "/private/home/rdessi/contextual_emcomm/to_model/interactions/interaction_53820379_0"

In [42]:
def get_msg_counters(msg, mask):
    n_batches, bsz, max_objs = mask.shape
    msg = torch.argmax(msg.view(n_batches, bsz, max_objs, 2, -1), dim=-1)

    nb_samples = 0
    list_sym1, list_sym2, list_msg = [], [], []
    for batch_id, batch in enumerate(msg):
        for elem_id, batch_elem in enumerate(batch):
            for msg_id, msg in enumerate(batch_elem):
                if mask[batch_id, elem_id, msg_id] is False:
                    continue
                nb_samples += 1
                list_sym1.append(msg[0].item())
                list_sym2.append(msg[1].item())
                list_msg.append(tuple(msg.tolist()))
    print(f"{calc_entropy(list_sym1)}")
    print(f"{calc_entropy(list_sym2)}")
    print(f"{calc_entropy(list_msg)}")
    return Counter(list_sym1), Counter(list_sym2), Counter(list_msg)

In [36]:
def compute_accuracy(msgs, recv_input, receiver, loss, mask, counter_sym1, counter_sym2, counter_msg):
    acc = 0.0
    for batch_id in range(msgs.shape[0]):
        msg = msgs[batch_id]
        images = recv_input[batch_id]

        recv_output = receiver(msg, images)

        aux_input = {"mask": mask[batch_id]}
        _, aux = loss(None, None, None, recv_output, None, aux_input)
        acc += aux["acc"].item()

    print(f"original acc = {acc / (batch_id + 1):.4f}")

    c1, c2, c_msg = list(counter_sym1), list(counter_sym2), list(counter_msg)
    """
    acc = 0.0
    for batch_id in range(msgs.shape[0]):
        msg = msgs[batch_id]
        for m in msg:
            idx = torch.argmax(m[0], -1)
            new_idx = counter_sym1.most_common(1)[0][0]
            if idx == new_idx:
                new_idx = counter_sym1.most_common(2)[1][0]
            #new_idx = random.sample(c1, k=1)[0]
            #while new_idx == idx:
            #    new_idx = random.sample(c1, k=1)[0]
            m[0][idx] = 0
            m[0][new_idx] = 1
            
        images = recv_input[batch_id]

        recv_output = receiver(msg, images)

        aux_input = {"mask": mask[batch_id]}
        _, aux = loss(None, None, None, recv_output, None, aux_input)
        acc += aux["acc"].item()
    print(f"acc sym1 = {acc / (batch_id + 1):.4f}")
    
    acc = 0.0
    for batch_id in range(msgs.shape[0]):
        msg = msgs[batch_id]
        for m in msg:
            idx = torch.argmax(m[1], -1)
            new_idx = counter_sym2.most_common(1)[0][0]
            if idx == new_idx:
                new_idx = counter_sym2.most_common(2)[1][0]
            #new_idx = random.sample(c1, k=1)[0]
            #while new_idx == idx:
            #    new_idx = random.sample(c1, k=1)[0]
            m[1][idx] = 0
            m[1][new_idx] = 1
            
        images = recv_input[batch_id]

        recv_output = receiver(msg, images)

        aux_input = {"mask": mask[batch_id]}
        _, aux = loss(None, None, None, recv_output, None, aux_input)
        acc += aux["acc"].item()
    print(f"acc sym2 = {acc / (batch_id + 1):.4f}")
    
    """
    acc = 0.0
    for batch_id in range(msgs.shape[0]):
        msg = msgs[batch_id]
        for m in msg:
            s1 = torch.argmax(m[0], -1)
            s2 = torch.argmax(m[1], -1)
            msg = (s1, s1)
            new_msg = random.sample(c_msg, k=1)[0]
            while new_msg == msg:
                new_msg = random.sample(c_msg, k=1)[0]
            set_trace()
            m[0][s1] = 0
            m[1][s2] = 0
            m[0][new_msg[0]] = 1
            m[1][new_msg[1]] = 1
            
            #if idx == new_idx:
            #    new_idx = counter_sym2.most_common(2)[1][0]
            #new_idx = random.sample(c1, k=1)[0]
            #while new_idx == idx:
            #    new_idx = random.sample(c1, k=1)[0]
            #m[1][idx] = 0
            #m[1][new_idx] = 1
            
        images = recv_input[batch_id]

        recv_output = receiver(msg, images)

        aux_input = {"mask": mask[batch_id]}
        _, aux = loss(None, None, None, recv_output, None, aux_input)
        acc += aux["acc"].item()
    print(f"acc msg = {acc / (batch_id + 1):.4f}")

In [16]:
interaction = torch.load(interaction_path)
model = load_model(model_args=interaction.aux_input["args"], checkpoint_path=checkpoint_path)

| Successfully loaded model from /private/home/rdessi/contextual_emcomm/da_model/53845331_0/final.tar


In [43]:
msg = interaction.message
mask = interaction.aux_input["mask"]
counter_sym1, counter_sym2, counter_msg = get_msg_counters(msg, mask)

3.813432543404763
2.5219374258830447
6.135146735257421


In [None]:
counter_sym1

In [34]:
receiver = model.game.receiver
loss = model.game.loss
recv_input = interaction.aux_input["recv_img_feats"]
msg = interaction.message.clone()
mask = interaction.aux_input["mask"]

In [35]:
compute_accuracy(msg, recv_input, receiver, loss, mask, counter_sym1, counter_sym2, counter_msg)

original acc = 0.8689


TypeError: tuple indices must be integers or slices, not tuple