In [7]:
import torch as th
import itertools
from pprint import pprint


class SimpleAgent(th.nn.Module):
    """Single-layer agent usable as a sender ("object") or receiver ("message").

    The module holds two independent linear layers:

    * ``fc1`` maps a flattened object of size ``n_attributes * n_values`` to
      ``vocab_size`` scores (sender direction).
    * ``fc2`` maps a ``vocab_size``-dimensional message back to
      ``n_attributes * n_values`` scores (receiver direction).

    Args:
        n_attributes: Number of attributes per object.
        n_values: Number of possible values per attribute.
        vocab_size: Number of distinct symbols the sender can emit.
    """

    def __init__(self, n_attributes: int, n_values: int, vocab_size: int):
        super().__init__()
        self.n_attributes = n_attributes
        self.n_values = n_values
        self.vocab_size = vocab_size

        # Sender head: flattened object -> one score per vocabulary symbol.
        self.fc1 = th.nn.Linear(n_attributes * n_values, vocab_size)
        # Receiver head: message -> one score per (attribute, value) slot.
        self.fc2 = th.nn.Linear(vocab_size, n_attributes * n_values)

    def forward(
        self, x: th.Tensor, input_type: str, train: bool = True
    ) -> "tuple[th.Tensor, th.Tensor | None]":
        """Run the agent in sender or receiver mode.

        Args:
            x: Input whose last dimension is ``n_attributes * n_values`` for
                ``"object"`` or ``vocab_size`` for ``"message"``.
            input_type: ``"object"`` to produce a message from an object, or
                ``"message"`` to produce object scores from a message.
            train: Only used in ``"object"`` mode. When true the symbol is
                sampled from a categorical distribution; otherwise the
                highest-scoring symbol is taken.

        Returns:
            ``(output, log_prob)``. In ``"object"`` mode ``output`` is a float
            one-hot message over ``vocab_size`` symbols and ``log_prob`` is
            the mean log-probability of the sampled symbols (``None`` when
            ``train`` is false). In ``"message"`` mode ``output`` is the
            ReLU-activated result of ``fc2`` and ``log_prob`` is ``None``.

        Raises:
            ValueError: If ``input_type`` is neither ``"object"`` nor
                ``"message"``.
        """
        if input_type == "object":
            # NOTE(review): the ReLU clamps the logits to be non-negative
            # before they parameterise the distribution; kept as in the
            # original design, but worth confirming this is intended.
            logits = th.nn.functional.relu(self.fc1(x))

            log_prob = None
            if train:
                dist = th.distributions.Categorical(logits=logits)
                symbol = dist.sample()
                # Averaged over all sampled symbols, so the result is a scalar.
                log_prob = dist.log_prob(symbol).mean()
            else:
                symbol = logits.argmax(dim=-1)

            message = th.nn.functional.one_hot(symbol, self.vocab_size).float()
            return message, log_prob

        if input_type == "message":
            return th.nn.functional.relu(self.fc2(x)), None

        # Previously an unknown mode fell through and returned ``x`` untouched,
        # which hid typos until a later, unrelated shape error.
        raise ValueError(
            f"input_type must be 'object' or 'message', got {input_type!r}"
        )


# Hyperparameters for the signalling game.
N_EPOCHS = 10000  # number of passes over the dataset
BATCH_SIZE = 32
N_ATTRIBUTES = 1  # attributes per object
N_VALUES = 4  # possible values per attribute
VOCAB_SIZE = 10  # symbols available to the sender
DEVICE = "cpu"
PRINT_RATE = 1000  # report the mean reward every PRINT_RATE epochs

# Enumerate every possible object as a row of attribute indices, giving a
# (N_VALUES ** N_ATTRIBUTES, N_ATTRIBUTES) integer tensor. Built directly as
# int64 from plain Python ints: one_hot requires an integer index tensor, and
# this avoids the legacy ``th.Tensor`` constructor's float round trip.
dataset = th.tensor(
    list(itertools.product(range(N_VALUES), repeat=N_ATTRIBUTES)),
    dtype=th.long,
)
# One-hot encode each attribute, then flatten the attributes of each object
# into a single vector of length N_ATTRIBUTES * N_VALUES.
dataset = th.nn.functional.one_hot(dataset, N_VALUES).float()
dataset = dataset.view(-1, N_ATTRIBUTES * N_VALUES)
dataset = dataset.to(DEVICE)

train_loader = th.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# agent_1 acts as the sender (object -> message) and agent_2 as the receiver
# (message -> object scores). Each agent has its own optimizer.
agent_1 = SimpleAgent(N_ATTRIBUTES, N_VALUES, VOCAB_SIZE).to(DEVICE)
agent_2 = SimpleAgent(N_ATTRIBUTES, N_VALUES, VOCAB_SIZE).to(DEVICE)

agent_1_optim = th.optim.Adam(agent_1.parameters(), lr=1e-3)
agent_2_optim = th.optim.Adam(agent_2.parameters(), lr=1e-3)

# Running means of the reward. Only agent_1's baseline enters a loss below;
# agent_2's is tracked alongside it but not used by the receiver's loss.
agent_1_baseline = 0
agent_2_baseline = 0
baseline_count = 0

for epoch in range(N_EPOCHS):
    total_reward = 0
    for batch in train_loader:
        # Count every baseline update so the incremental mean stays correct
        # even when an epoch contains more than one batch.
        baseline_count += 1

        message, agent_1_log_prob = agent_1(batch, "object", train=True)
        answer, agent_2_log_prob = agent_2(message, "message", train=True)

        # Score each attribute separately: one row of N_VALUES scores per
        # (object, attribute) pair, with the true value index as the target.
        batch_size = batch.shape[0]
        batch = batch.view(batch_size * N_ATTRIBUTES, N_VALUES)
        answer = answer.view(batch_size * N_ATTRIBUTES, N_VALUES)

        # cross_entropy already averages over rows, so this is a scalar.
        reward = -th.nn.functional.cross_entropy(answer, batch.argmax(dim=-1))
        reward_value = reward.item()
        agent_1_baseline += (reward_value - agent_1_baseline) / baseline_count
        agent_2_baseline += (reward_value - agent_2_baseline) / baseline_count

        # Sender: score-function (REINFORCE-style) loss. The reward is detached
        # because it is only a scalar weight here; without the detach, backward
        # would also run through the receiver and force retain_graph=True.
        agent_1_loss = -agent_1_log_prob * (reward.detach() - agent_1_baseline)
        # Receiver: the reward is differentiable w.r.t. its parameters, so it
        # is maximised directly.
        agent_2_loss = -reward

        agent_1_optim.zero_grad()
        agent_1_loss.backward()
        agent_1_optim.step()
        agent_2_optim.zero_grad()
        agent_2_loss.backward()
        agent_2_optim.step()

        total_reward += reward_value

    if epoch % PRINT_RATE == 0:
        print(f"Epoch {epoch}: {total_reward / len(train_loader)}")


# Greedy evaluation over the whole dataset: the sender emits its
# highest-scoring symbol and the receiver decodes it. No gradients are needed.
with th.no_grad():
    message, _ = agent_1(dataset, "object", train=False)
    answer, _ = agent_2(message, "message", train=False)
# Split the receiver's scores per attribute *before* taking the argmax, so the
# printed value indices are also correct when N_ATTRIBUTES > 1.
print(answer.view(-1, N_ATTRIBUTES, N_VALUES).argmax(dim=-1).squeeze().tolist())

Epoch 0: -1.600724458694458
Epoch 1000: -1.5607264041900635
Epoch 2000: -1.4777616262435913
Epoch 3000: -1.2476773262023926
Epoch 4000: -1.4370496273040771
Epoch 5000: -0.7159523963928223
Epoch 6000: -0.842619776725769
Epoch 7000: -0.7506885528564453
Epoch 8000: -0.8220731019973755
Epoch 9000: -1.3526273965835571
[0, 1, 2, 0, 4]
