In [1]:
from utils.load_results import *
from utils.plot_helpers import *

import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
import torch
from utils.analysis_from_interaction import *
from language_analysis_local import TopographicSimilarityConceptLevel, encode_target_concepts_for_topsim
from egg.core.callbacks import Checkpoint, CheckpointSaver
from egg import core
from archs import Sender, Receiver
import torch.nn as nn
import dataset

In [2]:
# For convenince and reproducibility, we set some EGG-level command line arguments here
opts = core.init(params=['--random_seed=7', # will initialize numpy, torch, and python RNGs
                         '--lr=1e-3',   # sets the learning rate for the selected optimizer 
                         '--batch_size=32',
                         '--optimizer=adam'])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PATH = 'results/(5,4)_game_size_10_vsf_3/standard/zero_shot/specific/0/final.tar'

In [3]:
checkpoint = torch.load('results/(5,4)_game_size_10_vsf_3/standard/zero_shot/specific/0/final.tar', map_location=torch.device('cpu'))

In [4]:
checkpoint.model_state_dict.keys()

odict_keys(['sender.sos_embedding', 'sender.agent.fc1.weight', 'sender.agent.fc1.bias', 'sender.agent.fc2.weight', 'sender.agent.fc2.bias', 'sender.agent.fc3.weight', 'sender.agent.fc3.bias', 'sender.hidden_to_output.weight', 'sender.hidden_to_output.bias', 'sender.embedding.weight', 'sender.embedding.bias', 'sender.cell.weight_ih', 'sender.cell.weight_hh', 'sender.cell.bias_ih', 'sender.cell.bias_hh', 'receiver.agent.fc1.weight', 'receiver.agent.fc1.bias', 'receiver.cell.weight_ih', 'receiver.cell.weight_hh', 'receiver.cell.bias_ih', 'receiver.cell.bias_hh', 'receiver.embedding.weight', 'receiver.embedding.bias'])

In [5]:
checkpoint.model_state_dict['sender.agent.fc1.weight']

tensor([[-0.3765, -0.1887, -0.4408,  ..., -0.5134,  0.3979,  0.1921],
        [ 0.0589,  0.7363,  0.1616,  ..., -0.7936,  0.1686,  0.8655],
        [ 0.0942, -0.2178, -0.7711,  ..., -0.1211,  0.0076, -0.1230],
        ...,
        [-0.1728,  0.1262,  0.2502,  ..., -0.2342,  0.1369, -0.3680],
        [ 1.4740, -1.2592, -1.5957,  ..., -0.0123, -0.0830, -0.5855],
        [-0.2176, -0.0708,  0.0612,  ...,  0.2657,  0.3926,  0.0561]])

In [6]:
hidden_size = 128
dimensions = [4, 4, 4, 4, 4]
game_size = 10
context_unaware = False
vocab_size_factor = 3
max_mess_len = None
sender_cell = 'gru'
temperature = 2.0
receiver_cell = 'gru'
learning_rate = 0.001
device = 'cuda'
path = ''
load_dataset = 'dim(5,4)_specific.ds'

In [7]:
def loss(_sender_input, _message, _receiver_input, receiver_output, labels, _aux_input):
    """
    Loss needs to be defined for gumbel softmax relaxation.
    For a discriminative game, accuracy is computed by comparing the index with highest score in Receiver
    output (a distribution of unnormalized probabilities over target positions) and the corresponding 
    label read from input, indicating the ground-truth position of the target.
    Adaptation to concept game with multiple targets after Mu & Goodman (2021) with BCEWithLogitsLoss
        receiver_output: Tensor of shape [batch_size, n_objects]
        labels: Tensor of shape [batch_size, n_objects]
    """
    # after Mu & Goodman (2021):
    loss_fn = nn.BCEWithLogitsLoss()
    loss = loss_fn(receiver_output, labels)
    receiver_pred = (receiver_output > 0).float()
    per_game_acc = (receiver_pred == labels).float().mean(1).cpu().numpy()  # all labels have to be predicted correctly
    acc = per_game_acc.mean()
    return loss, {'acc': acc}

In [8]:
data_set = torch.load(path + 'data/' + load_dataset)
print('data loaded from: ' + 'data/' + load_dataset)

data loaded from: data/dim(5,4)_specific.ds


In [9]:
train, val, test = data_set
train = torch.utils.data.DataLoader(train, batch_size = opts.batch_size, shuffle=False)
val = torch.utils.data.DataLoader(val, batch_size = opts.batch_size, shuffle=False)
test = torch.utils.data.DataLoader(test, batch_size=opts.batch_size, shuffle=False)

In [10]:
len(test) # n batches

1600

In [11]:
sender = Sender(hidden_size, sum(dimensions), game_size, context_unaware)
receiver = Receiver(sum(dimensions), hidden_size)

minimum_vocab_size = dimensions[0] + 1  # plus one for 'any'
vocab_size = minimum_vocab_size * vocab_size_factor + 1  # multiply by factor plus add one for eos-symbol

# allow user to specify a maximum message length
if max_mess_len:
    max_len = max_mess_len
# default: number of attributes
else:
    max_len = len(dimensions)

sender = core.RnnSenderGS(sender,
                          vocab_size,
                          int(hidden_size / 2),
                          hidden_size,
                          cell=sender_cell,
                          max_len=max_len,
                          temperature=temperature)

receiver = core.RnnReceiverGS(receiver,
                              vocab_size,
                              int(hidden_size / 2),
                              hidden_size,
                              cell=receiver_cell)

game = core.SenderReceiverRnnGS(sender, receiver, loss)

optimizer = torch.optim.Adam([
        {'params': game.sender.parameters(), 'lr': learning_rate},
        {'params': game.receiver.parameters(), 'lr': learning_rate}
    ])

trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train, validation_data=val)

In [12]:
checkpoint = torch.load(PATH, map_location=torch.device('cpu'))
game.load_state_dict(checkpoint.model_state_dict)
optimizer.load_state_dict(checkpoint.optimizer_state_dict)
epoch = checkpoint.epoch

mean_loss, interaction = trainer.eval(data=test)

In [13]:
mean_loss

0.17476728558540344

In [14]:
interaction.aux['acc'].mean()

tensor(0.9474)