In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import csv
import egg.core as core
from egg.core import RnnSenderReinforce, RnnReceiverDeterministic
import numpy as np

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# ---------------------------------------------------------------------------
# Game parameters
# ---------------------------------------------------------------------------
# Seed of the training run; in this notebook it only selects the checkpoint
# and output directories.
seed = 5
# Number of senders and, equally, number of receivers in the community.
comm_size = 4
sender_comm_size = receiver_comm_size = comm_size
# Message alphabet size and maximum message length handed to the sender wrapper.
vocab_size = 10
max_len = 5

# ---------------------------------------------------------------------------
# Agent parameters
# ---------------------------------------------------------------------------
# Dimension of one image embedding (input size of both agents' linear layer).
n_features = 512
# RNN cell type, hidden size and symbol-embedding size, per agent role.
sender_cell = 'gru'
receiver_cell = 'gru'
sender_hidden = 10
receiver_hidden = 35
sender_embedding = 5
receiver_embedding = 5

# Number of games per batch drawn from the test loader.
batch_size = 128

In [None]:
class EGG_Dataset(Dataset):
    '''
    A dataset that is read in from a txt file and is compatible with the EGG framework and PyTorch's DataLoader.

    Every non-blank line of the file describes one game. It is a comma-separated record of the form
        <embedding_1>,<embedding_2>,...,<embedding_k>,<target_class>,<target_idx>
    where each embedding is a whitespace-separated list of floats, target_class is an integer class label
    and target_idx is the position of the target among the k embeddings.

    Attributes:
        frame -- The dataset: a list with one tuple per game,
                 (target_embedding, target_idx, target_and_distractors) for training data and
                 (target_embedding, target_idx, target_and_distractors, target_class) for test/val data

    Methods:
        __init__(game_data_path, test) -- Constructor, initializes a dataset object
        get_n_features() -- Returns the dimension of an image embedding
        __len__() -- Returns the length of the dataset
        __getitem__(idx) -- Returns an element of the dataset by the given index idx
    '''
    def __init__(self, game_data_path, test = False):
        '''
        Constructor, initializes an EGG_Dataset object, i.e. a dataset that is compatible with both EGG
        and PyTorch's DataLoader.

        Parameter:
            game_data_path -- The path to the game data (train, val, or test) for which a dataset is to be produced
            test -- Boolean that indicates whether or not the dataset will be used for testing; if true, the
                    class of the target is stored as an additional fourth element of every item
        '''
        self.frame = []
        with open(game_data_path, 'r') as game_data:
            for line in game_data:
                # tolerate blank lines (e.g. a trailing newline at the end of the file)
                if not line.strip():
                    continue
                # splits row into one field per embedding plus the target class and the target idx
                # (the last two fields)
                raw_info = line.split(",")
                # parses the embeddings into a nested list of floats,
                # e.g. [[feat1, feat2, feat3], [feat1, feat2, feat3]]
                embeddings = [list(map(float, x.split())) for x in raw_info[:-2]]
                # class label of the target image
                target_class = int(raw_info[-2])
                # index of the target embedding
                target_position = int(raw_info[-1])
                target_idx = torch.tensor(target_position)
                # the target embedding, which is the sender input
                target_embedding = torch.tensor(embeddings[target_position], dtype=torch.float32)
                # the receiver input: target and distractor embeddings stacked into one
                # (n_embeddings, n_features) tensor, built in a single conversion
                target_and_distractors = torch.tensor(embeddings, dtype=torch.float32)
                if test:
                    # for test or val data: additionally store the class of the target
                    self.frame.append((target_embedding, target_idx, target_and_distractors, target_class))
                else:
                    # for training data: sender input, target index, and receiver input
                    self.frame.append((target_embedding, target_idx, target_and_distractors))

    def get_n_features(self):
        'Returns the dimension of the image embeddings'
        return self.frame[0][0].size(0)

    def __len__(self):
        'Returns the length of the dataset'
        return len(self.frame)

    def __getitem__(self, idx):
        'Returns an element of the dataset by index'
        return self.frame[idx]

In [None]:
class Sender(nn.Module):
    '''
    Core network of a sender agent. A single linear layer turns the embedding of the target image into
    a vector of size n_hidden, which serves as the initial hidden state of the message-producing RNN
    that an EGG wrapper places around this module.

    Attributes:
        fc1 -- Linear layer mapping n_features inputs to n_hidden outputs (the only layer of the core)

    Methods:
        __init__(n_hidden, n_features) -- Constructor, builds the linear layer
        forward(x, _aux_input) -- Applies the linear layer to x
    '''
    def __init__(self, n_hidden, n_features):
        '''
        Builds the sender's core.

        Parameter:
            n_hidden -- The output size, i.e. the size of the hidden state of the message-producing RNN
            n_features -- The input size, i.e. the number of elements in an image embedding

        Output:
            None
        '''
        super().__init__()
        self.fc1 = nn.Linear(in_features=n_features, out_features=n_hidden)

    def forward(self, x, _aux_input):
        '''
        Maps the target image embedding x to the initial hidden state of the wrapper RNN and returns it.
        _aux_input is accepted for compatibility with the wrapper's call signature and is not used.
        '''
        # Purely linear mapping. A non-linearity (e.g. tanh or a leaky ReLU) could be applied to the
        # result here; none is used at the moment.
        initial_hidden = self.fc1(x)
        return initial_hidden



class Receiver(nn.Module):
    '''
    The core of any receiver agent. It takes two inputs: a message embedding from its wrapper GRU and the image
    embeddings of the target and distractors. Target and distractor embeddings are mapped to the same dimension
    as the message embedding; the core then computes the dot product between the message embedding and each of
    the mapped image embeddings. The resulting dot products are interpreted as a non-normalized score over the
    possible target positions, i.e. the position with the highest dot product is the receiver's guess.

    Attributes:
        fc1 -- The only layer of the receiver's core

    Methods:
        __init__(n_features, n_hidden) -- Constructor, initializes a receiver's core
        forward(x, _input, _aux_input) -- Performs a forward pass through the receiver's core
    '''
    def __init__(self, n_features, n_hidden):
        '''
        Constructor, initializes the receiver's core.

        Parameter:
            n_features -- The input sizes of the target and distractor embeddings
            n_hidden -- The size of the wrapper-GRU's hidden state

        Output:
            None
        '''
        super(Receiver, self).__init__()
        self.fc1 = nn.Linear(n_features, n_hidden)

    def forward(self, x, _input, _aux_input):
        '''
        Performs a forward pass through the receiver's core, i.e. maps all embeddings of target and distractor
        images to the dimension of the wrapper-GRU's message embedding and then computes and returns the dot
        products between all target/distractor mappings and the message embedding.

        Parameter:
            x -- The message embedding produced by the wrapper-GRU, shape (..., n_hidden)
            _input -- The embeddings of target and distractor images, shape (..., n_candidates, n_features)
            _aux_input -- Unused; accepted for compatibility with the wrapper's call signature

        Output:
            dots -- The dot products between mapped image embeddings and the message embedding, shape
                    (..., n_candidates); the position with the highest dot product acts as the receiver's
                    prediction of the target position
        '''
        # the rationale for the non-linearity here is that the RNN output (x) will also be the outcome of a non-linearity
        embedded_input = self.fc1(_input).tanh()
        # (..., n_candidates, n_hidden) @ (..., n_hidden, 1) -> (..., n_candidates, 1)
        dots = torch.matmul(embedded_input, torch.unsqueeze(x, dim=-1))
        # Drop only the trailing singleton axis. An argument-less squeeze() would also remove the batch
        # axis for a batch of one sample (and the candidate axis for a single candidate).
        return dots.squeeze(-1)

In [None]:
# Test split of the game data; test=True makes every item additionally carry the class of the target.
test_dataset = EGG_Dataset(game_data_path='./Data/Game_Data/test.txt', test=True)
# Batches are served in file order (no shuffling) by a single worker process.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=1,
)

## Loading the Agents

In [None]:
# Build comm_size senders and comm_size receivers. Each agent is a small core network wrapped by the
# matching EGG RNN wrapper and moved to the selected device. All senders are created before all
# receivers, as in the training setup.
sender_comm = []
receiver_comm = []
for _ in range(comm_size):
    sender_core = Sender(n_hidden=sender_hidden, n_features=n_features)
    sender_comm.append(core.RnnSenderReinforce(
                            sender_core,
                            vocab_size=vocab_size,
                            embed_dim=sender_embedding,
                            hidden_size=sender_hidden,
                            cell=sender_cell,
                            max_len=max_len,
                            ).to(device)
    )

for _ in range(comm_size):
    receiver_core = Receiver(n_features=n_features, n_hidden=receiver_hidden)
    receiver_comm.append(core.RnnReceiverDeterministic(
                            receiver_core,
                            vocab_size=vocab_size,
                            embed_dim=receiver_embedding,
                            hidden_size=receiver_hidden,
                            cell=receiver_cell,
                            ).to(device)
    )

# Restore the trained weights. Checkpoints are stored per community size and seed, and the files are
# numbered starting at 1.
agents_dir = f'./Agents/Comm_Size_{comm_size}/Seed_{seed}'

for agent_nr, sender in enumerate(sender_comm, start=1):
    dict_path = f'{agents_dir}/Senders/sender_{agent_nr}.pt'
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine (and vice versa)
    sender.load_state_dict(torch.load(dict_path, map_location=device))

for agent_nr, receiver in enumerate(receiver_comm, start=1):
    dict_path = f'{agents_dir}/Receivers/receiver_{agent_nr}.pt'
    receiver.load_state_dict(torch.load(dict_path, map_location=device))

## Producing the Evaluation Dataset

In [None]:
# Evaluation only: put every agent into eval mode before producing messages and predictions.
for sender in sender_comm:
    sender.eval()
for receiver in receiver_comm:
    receiver.eval()

# Writes one CSV row per test sample with the columns
#   id, sender_input, target_class, accuracy, message_1 ... message_<comm_size>
# where accuracy is the fraction of all sender/receiver pairs that identified the target.
# newline='' is what the csv module asks for; without it, rows get an extra blank line on platforms
# that translate line endings.
with open('./Data/Evaluation_Data/Comm_Size_'+ str(comm_size) +'/Seed_'+str(seed)+'/evaluation_data.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    header = ['id', 'sender_input', 'target_class', 'accuracy']
    header.extend('message_'+str(i+1) for i in range(comm_size))
    writer.writerow(header)
    row_nr = 0

    # pure inference: no autograd graph needs to be recorded
    with torch.no_grad():
        for (sender_inputs, target_idxs, receiver_inputs, target_classes) in test_loader:
            current_batch_size = len(target_idxs)
            # move every input to the device once per batch instead of once per agent / agent pair
            sender_inputs_dev = sender_inputs.to(device)
            receiver_inputs_dev = receiver_inputs.to(device)
            target_idxs_dev = target_idxs.to(device)

            # messages contains one entry for each sender and each of these entries is a batch of messages
            messages = []
            for sender in sender_comm:
                batch_messages, _batch_sen_logprobs, _batch_sen_entropies = sender(sender_inputs_dev)
                messages.append(batch_messages)

            # one boolean vector (one entry per sample) for every sender/receiver pair, telling whether
            # that pair picked the target; pairs are visited sender by sender
            pair_hits = []
            for message_batch in messages:
                for receiver in receiver_comm:
                    batch_rec_outputs, _batch_rec_logprobs, _batch_rec_entropies = receiver(message_batch, receiver_inputs_dev)
                    # Force the (batch, n_candidates) layout: a receiver core that squeezes its output
                    # drops the batch axis for a single-sample batch, which would corrupt the argmax.
                    scores = batch_rec_outputs.reshape(current_batch_size, -1)
                    pair_hits.append(scores.argmax(dim=1) == target_idxs_dev)

            # per-sample accuracy averaged over all pairs, computed in float64 on the host
            avg_sample_accs = torch.stack(pair_hits).cpu().numpy().astype(float).mean(axis=0)

            # convert everything that goes into the rows to plain Python lists once per batch
            sender_input_rows = sender_inputs.tolist()
            target_class_rows = target_classes.tolist()
            message_rows = [sender_outputs.tolist() for sender_outputs in messages]

            # one row for each element in the batch
            for i in range(current_batch_size):
                row_nr += 1
                # vectors are stored as "v1, v2, ..." in a single CSV field
                row = [
                    row_nr,
                    ', '.join(map(str, sender_input_rows[i])),
                    target_class_rows[i],
                    avg_sample_accs[i],
                ]
                row.extend(', '.join(map(str, sender_messages[i])) for sender_messages in message_rows)
                writer.writerow(row)
