In [None]:
# import modules 
import argparse

import numpy as np
import torch

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import egg.core as core
from egg.core import Callback, Interaction, PrintValidationEvents
from egg.zoo.basic_games.architectures import DiscriReceiver, RecoReceiver, Sender
from egg.zoo.basic_games.data_readers import AttValRecoDataset

In [None]:
# set path and values

# input file https://github.com/franfranz/EGG/blob/main/egg/zoo/sum_game/fullset_train.txt
path="fullset_train.txt" 
frame = np.loadtxt(path, dtype="S10")
n_attributes=2
n_values=21

batch_size=20
# inspect
#frame

## Reco game

### AttValDataset - Reco

The AttValRecoDataset class is used in the reconstruction game. It takes an input file with one
space-delimited attribute-value vector per line and creates a data frame with the two mandatory fields expected in EGG games, namely sender_input and labels.
In this case, the two fields contain the same information, namely the input attribute-value vectors, represented as one-hot vectors in sender_input and in the original integer-based format in labels.


In [None]:
# AttValRecoDataset 

temp_reco =[]
frame = np.loadtxt(path, dtype="S10")    
for row in frame:
            if n_attributes == 1:
                row = row.split()
            config = list(map(int, row))
            z = torch.zeros((n_attributes, n_values))
            for i in range(n_attributes):
                z[i, config[i]] = 1
            label = torch.tensor(list(map(int, row)))
            temp_reco.append((z.view(-1), label))
            
# inspect
#temp_reco            

### Loss - Reco
In the case of the reconstruction game, for each attribute we compute a separate cross-entropy score by comparing the probability distribution produced by the Receiver over the values of that attribute with the corresponding ground truth, and then average the scores across attributes. Accuracy is instead computed by considering as a hit only those cases where, for each attribute, the Receiver assigned the largest probability to the correct value. Most of this function consists of the usual PyTorch madness needed to reshape tensors in order to perform these computations.

In [None]:
        def loss(
            sender_input, _message, _receiver_input, receiver_output, labels, _aux_input
        ):
            
           # n_attributes = opts.n_attributes
           # n_values = opts.n_values
           # batch_size = sender_input.size(0)
            receiver_output = receiver_output.view(batch_size * n_attributes, n_values)
            receiver_guesses = receiver_output.argmax(dim=1)
            correct_samples = (
                (receiver_guesses == labels.view(-1))
                .view(batch_size, n_attributes)
                .detach()
            )
            acc = (torch.sum(correct_samples, dim=-1) == n_attributes).float()
            labels = labels.view(batch_size * n_attributes)
            loss = F.cross_entropy(receiver_output, labels, reduction="none")
            loss = loss.view(batch_size, -1).mean(dim=1)
            return loss, {"acc": acc}

In [None]:
     # again, see data_readers.py in this directory for the AttValRecoDataset data reading class
        train_loader = DataLoader(
            AttValRecoDataset(
                path=path,
                n_attributes=n_attributes,
                n_values=n_values,
            ),
            #batch_size=opts.batch_size,
            shuffle=True,
            num_workers=1,
        )

        # the number of features for the Receiver (input) and the Sender (output) is given by n_attributes*n_values because
        # they are fed/produce 1-hot representations of the input vectors
        n_features = n_attributes * n_values

## Sum game

### AttValDataset - Sum

The AttValSumDataset class is used in the sum game. It takes an input file with a
space-delimited attribute-value vector per line and creates a data-frame with sender_input and labels.

sender_input is the concatenation of the two input integers, each encoded as a one-hot vector.
labels is a one-hot vector of length n_attributes * n_values whose single non-zero entry sits at the index equal to the sum of the two integers.

In [None]:
# this AttValSumDataset will substitute the Reco dataset in the Sum game

temp_sum = []
for row in frame:
            config = list(map(int, row))
            conf_sum = (config[0]+config[1])
            conrow = [config[0], config[1], conf_sum]
            z = torch.zeros((n_attributes, n_values))
            for i in range(n_attributes):
                z[i, config[i]] = 1
            s_inp = torch.cat((z[0],z[1]))
            label2=torch.zeros(n_attributes*n_values)
            label2[conf_sum] = 1
            temp_sum.append((s_inp, label2))

# inspect
#temp_sum

### Loss - Sum 
In the case of the sum game, for each input we compute a single cross-entropy score between the output of the Receiver and the ground-truth vector (labels).

In [None]:
        def loss(
            sender_input, _message, _receiver_input, receiver_output, labels, _aux_input
        ):
            
           # n_attributes = opts.n_attributes
           # n_values = opts.n_values
           # batch_size = sender_input.size(0)
           # receiver_output = receiver_output.view(batch_size * n_attributes, n_values)
            
            # acc = (torch.sum(correct_samples, dim=-1) == n_attributes).float()
            
            #labels = labels.view(batch_size * n_attributes)
            
            loss = F.cross_entropy(receiver_output, labels, reduction="none")
            return loss, {"acc": acc}
        