In [1]:
import argparse
import collections
import itertools
import shlex
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas as pd
import seaborn as sns
import torch

import egg
from egg.zoo.basic_games.data_readers import AttrValClassDataset
from egg.zoo.basic_games import play_sum

In [2]:
# Data splits (relative paths, resolved against the kernel's working directory).
PATH_TRAIN, PATH_VAL, PATH_TEST = (
    "../../data/sum5.train.train",
    "../../data/sum5.train.val",
    "../../data/sum5.test",
)
# Operands take values 0..N_VALUES-1, so the largest reachable sum is 2 * (N_VALUES - 1).
N_VALUES = 5
MAX_RESULT = 2 * (N_VALUES - 1)

In [3]:
def entropy(x_samples):
    """Empirical Shannon entropy H(X), in bits, of a sized sequence of hashable samples.

    Probabilities are the relative frequencies of each distinct value.
    An empty input yields 0.
    """
    total = len(x_samples)
    probs = [count / total for count in collections.Counter(x_samples).values()]
    return -sum([p * np.log2(p) for p in probs])

def cond_entropy(y_samples, x_samples):
    """Empirical conditional entropy H(Y | X), in bits.

    Note the argument order: the predicted variable Y comes first and the
    conditioning variable X second. Both sequences must have equal length.
    An empty input yields 0.
    """
    assert len(x_samples) == len(y_samples)
    total = len(x_samples)
    x_counts = collections.Counter(x_samples)
    joint_counts = collections.Counter(zip(x_samples, y_samples))
    terms = []
    for (x, _y), c_xy in joint_counts.items():
        # p(x, y) * log2 p(y | x), where p(y | x) = c(x, y) / c(x)
        terms.append(c_xy / total * np.log2(c_xy / x_counts[x]))
    return -sum(terms)

def pointwise_mutual_information(x_samples, y_samples):
    """Per-sample pointwise mutual information, in bits.

    For each sample i, returns log2(p(x_i, y_i) / (p(x_i) * p(y_i))). All
    probabilities are estimated from the given samples. The result is a
    numpy array with one entry per sample, in input order.
    """
    assert len(x_samples) == len(y_samples)
    total = len(x_samples)
    x_counts = collections.Counter(x_samples)
    y_counts = collections.Counter(y_samples)
    pairs = list(zip(x_samples, y_samples))
    joint_counts = collections.Counter(pairs)
    return np.array([
        np.log2(joint_counts[pair] * total / (x_counts[pair[0]] * y_counts[pair[1]]))
        for pair in pairs
    ])

def mutual_information(x_samples, y_samples):
    """Empirical mutual information I(X; Y), in bits.

    Computed as the average of the per-sample pointwise mutual information,
    which equals sum over (x, y) of p(x, y) * log2(p(x, y) / (p(x) * p(y))).
    """
    per_sample = pointwise_mutual_information(x_samples, y_samples)
    return np.mean(per_sample)

In [4]:
def update_legend(ax, **kwargs):
    """Re-create the legend of ``ax`` with the same handles and labels but new options.

    Handy for changing properties (``loc``, ``ncol``, ...) of an existing legend,
    e.g. one drawn by seaborn. The current legend title is kept unless a
    ``title`` is passed in ``kwargs``.

    Raises:
        ValueError: if ``ax`` has no legend yet.
    """
    old_legend = ax.get_legend()
    if old_legend is None:
        raise ValueError("update_legend: the axes has no legend to update")
    # Matplotlib 3.7 deprecated ``Legend.legendHandles`` in favour of
    # ``legend_handles`` and 3.9 removed it; support both spellings.
    handles = getattr(old_legend, "legend_handles", None)
    if handles is None:
        handles = old_legend.legendHandles
    labels = [t.get_text() for t in old_legend.get_texts()]
    if "title" not in kwargs:
        kwargs["title"] = old_legend.get_title().get_text()
    ax.legend(handles, labels, **kwargs)

In [5]:
# Codebook of two-letter strings ("aa".."zz") shown instead of raw message ids
# for easier viewing. It is shuffled with a fixed seed, so the id -> codeword
# mapping is reproducible across runs.
codewords = [
    first + second
    for first, second in itertools.product(
        [chr(code) for code in range(ord('a'), ord('z') + 1)], repeat=2
    )
]
np.random.default_rng(1).shuffle(codewords)

In [6]:
# Experiment to analyse: a run directory containing the saved command-line
# arguments ("args") of a previous play_sum run.
exp_dir = "vs50_20210728-014041"
exp_path = Path(exp_dir)

# Re-parse the exact options that run was launched with.
args = shlex.split((exp_path / "args").read_text())
opts = play_sum.get_params(args)

# Make the trainer load the final checkpoint instead of writing one, and turn
# off the tensorboard option for this analysis run.
assert opts.checkpoint_dir
opts.load_from_checkpoint = str(Path(opts.checkpoint_dir) / "final.tar")
opts.checkpoint_dir = None
opts.tensorboard = False

# Build the game with train=False (presumably skipping training; see
# play_sum.main) and put it in evaluation mode. The module is displayed below.
game = play_sum.main(args, opts=opts, train=False)
game.eval()

Namespace(batch_size=32, checkpoint_dir=None, checkpoint_freq=0, cuda=True, device=device(type='cuda'), distributed_context=DistributedContext(is_distributed=False, rank=0, local_rank=0, world_size=1, mode='none'), distributed_port=18363, fp16=False, load_from_checkpoint='vs50_20210728-014041/final.tar', lr=0.001, max_len=1, mode='rf', n_attributes=None, n_epochs=1000, n_values=5, no_cuda=False, optimizer='adam', preemptable=False, print_validation_events=False, random_seed=2134625598, receiver_cell='rnn', receiver_embedding=10, receiver_hidden=100, receiver_layers=2, rnn=False, sender_cell='rnn', sender_embedding=10, sender_entropy_coeff=0.1, sender_hidden=100, sender_layers=2, temperature=1.0, tensorboard=False, tensorboard_dir='vs50_20210728-014041', train_data='../../data/sum5.train.train', update_freq=1, validation_batch_size=32, validation_data='../../data/sum5.train.val', validation_freq=20, vocab_size=50)
# Initializing model, trainer, and optimizer from vs50_20210728-014041/fi

SymbolGameReinforce(
  (sender): ReinforceWrapper(
    (agent): SumSender(
      (net): Sequential(
        (0): Linear(in_features=10, out_features=100, bias=True)
        (1): ELU(alpha=1.0)
        (2): Linear(in_features=100, out_features=50, bias=True)
      )
    )
  )
  (receiver): ReinforceDeterministicWrapper(
    (agent): SymbolReceiverWrapper(
      (agent): SumReceiver(
        (net): Sequential(
          (0): Linear(in_features=100, out_features=100, bias=True)
          (1): ELU(alpha=1.0)
          (2): Linear(in_features=100, out_features=9, bias=True)
        )
      )
      (embedding): RelaxedEmbedding(50, 100)
    )
  )
)

In [7]:
# Evaluate on the training split (PATH_TRAIN) in a fixed order (no shuffling).
train_dataset = AttrValClassDataset(path=PATH_TRAIN, n_values=opts.n_values)
data_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=opts.validation_batch_size,
    shuffle=False,
    num_workers=1,
)

In [8]:
# Run the trained game over the whole split without gradients, then merge the
# per-batch interactions into a single Interaction.
# Inputs are moved to the device the game's parameters live on instead of a
# hard-coded `.cuda()`, so this also works on CPU-only machines / --no_cuda runs.
device = next(game.parameters()).device
batch_interactions = []
with torch.no_grad():
    for sender_input, labels in data_loader:
        # game(...) returns a sequence whose element 1 is the batch Interaction
        batch_interaction = game(sender_input.to(device), labels.to(device))[1]
        batch_interactions.append(batch_interaction.to("cpu"))
interaction = egg.core.Interaction.from_iterable(batch_interactions)

In [9]:
# Decode the (presumably one-hot) sender input: the first n_values columns give
# operand a, the remaining columns give operand b. The receiver's prediction
# is the argmax of its output.
all_inputs = interaction.sender_input
operand_cols = opts.n_values
aa = all_inputs[:, :operand_cols].argmax(dim=-1).numpy()
bb = all_inputs[:, operand_cols:].argmax(dim=-1).numpy()
rr = interaction.receiver_output.argmax(dim=-1).numpy()
correct_mask = (aa + bb) == rr
# Map each message id to its human-readable codeword.
messages = np.array([codewords[symbol] for symbol in interaction.message.tolist()])

In [10]:
# One row per sample: message codeword, operands, predicted sum, whether the
# prediction is correct, and the operand pair (ordered and unordered).
all_df = pd.DataFrame(
    {
        "msg": messages,
        "a": aa,
        "b": bb,
        "r": rr,
        "acc": aa + bb == rr,
        "ab": list(zip(aa, bb)),
        "ab_unordered": [tuple(sorted(pair)) for pair in zip(aa, bb)],
    }
)
correct_df = all_df[correct_mask]

In [14]:
# Normalized excess information: H(msg | r) / H(ab | r), i.e. how much of the
# operand uncertainty that remains once the sum is known is still reflected in the message.
msg_given_sum = cond_entropy(correct_df["msg"], correct_df["r"])
ab_given_sum = cond_entropy(correct_df["ab"], correct_df["r"])
msg_given_sum / ab_given_sum

0.9128491025779536

In [15]:
# Show messages as a categorical; the "Categories" line lists the distinct messages used.
correct_df.loc[:, "msg"].astype("category")

0      mr
1      nc
2      sz
3      ix
4      vo
       ..
395    nc
396    nc
397    xn
398    ro
399    go
Name: msg, Length: 400, dtype: category
Categories (16, object): ['ay', 'cz', 'el', 'fw', ..., 'vo', 'xn', 'zb', 'zn']

In [16]:
# For each message, the distinct (a, b) operand pairs it was used for.
unique_msg_pairs = correct_df[["msg", "ab"]].drop_duplicates()
unique_msg_pairs.groupby("msg")["ab"].agg(list).to_dict()

{'ay': [(0, 3)],
 'cz': [(0, 4)],
 'el': [(3, 1), (1, 3)],
 'fw': [(2, 3)],
 'go': [(2, 4)],
 'ix': [(4, 2)],
 'jb': [(4, 1)],
 'kk': [(2, 2)],
 'mr': [(1, 1)],
 'nc': [(0, 0)],
 'ro': [(1, 2)],
 'sz': [(0, 1)],
 'vo': [(4, 4)],
 'xn': [(3, 4)],
 'zb': [(2, 1)],
 'zn': [(1, 4)]}

In [17]:
# Commutative property: does the sender use the same message for (a, b) and (b, a)?
# Keep only pairs that are in the dataset both ways
ab_set = set(correct_df["ab"].values)
# Series.map tests each pair directly instead of building a row Series per
# sample (apply(axis=1)), and always yields a boolean mask, even if empty.
both_ways = correct_df["ab"].map(lambda ab: tuple(reversed(ab)) in ab_set).astype(bool)
comm_df = correct_df[both_ways]
# If several messages were seen for the same (a, b), keep the first one.
# groupby sorts by (a, b), so no extra sort is needed.
comm_df = comm_df[["a", "b", "r", "msg"]].drop_duplicates().groupby(["a", "b"]).first()
# Message used for the swapped pair (b, a): relabel the index levels and let
# pandas align on the (a, b) index.
comm_df["msg_ba"] = comm_df.reset_index().rename(columns={"a": "b", "b": "a"}).set_index(["a", "b"])["msg"]
comm_df["comm"] = (comm_df["msg"] == comm_df["msg_ba"])
comm_df = comm_df.reset_index()
comm_df = comm_df[comm_df["a"] < comm_df["b"]]  # Count each pair only once

print("Commutativity:", comm_df["comm"].mean())
# NOTE: .count() counts every pair tested (present in both orders) per sum,
# not only the commutative ones; use .sum() for the latter.
print("Min. number of commutative pairs per sum:", comm_df.groupby("r")["comm"].count().min())

Commutativity: 0.25
Min. number of commutative pairs per sum: 1
