In [19]:
import argparse
import base64
import ctypes
import collections
import itertools
import shlex
import traceback
from pathlib import Path

import numpy as np
import pandas as pd
import torch

import egg
from egg.zoo.basic_games.data_readers import AttrValClassDataset
from egg.zoo.basic_games import play_sum

In [20]:
# Number of distinct values each attribute can take; passed to the dataset reader.
N_VALUES = 5
# Presumably the number of distinct sums of two values drawn from
# range(N_VALUES) -- TODO confirm against the task definition.
N_CLASSES = 2 * N_VALUES - 1

# Data splits, given relative to the notebook's working directory.
PATH_TRAIN = "../../data/sum5.train.train"
PATH_VAL = "../../data/sum5.train.val"
PATH_TEST = "../../data/sum5.test"

In [21]:
# Build one evaluation loader per split with identical settings.
# shuffle=False keeps the batch order the same on every pass.
train_loader, test_loader = (
    torch.utils.data.DataLoader(
        AttrValClassDataset(path=data_path, n_values=N_VALUES),
        num_workers=1,
        shuffle=False,
        batch_size=32,
    )
    for data_path in (PATH_TRAIN, PATH_TEST)
)

In [22]:
def load_game(exp_dir):
    """Re-create the game of a finished experiment from its checkpoint.

    Args:
        exp_dir: experiment directory containing an ``args`` file with the
            command-line arguments of the run (shell-quoted).

    Returns:
        A ``(game, opts)`` tuple: the object returned by ``play_sum.main``
        switched to eval mode, and the parsed options after the overrides
        applied below.

    Raises:
        AssertionError: if the stored arguments define no checkpoint directory.
    """
    cli_args = shlex.split((Path(exp_dir) / "args").read_text())
    opts = play_sum.get_params(cli_args)

    # Point the run at its own final checkpoint for loading, and switch off
    # checkpoint writing and tensorboard logging.
    assert opts.checkpoint_dir
    final_checkpoint = Path(opts.checkpoint_dir) / "final.tar"
    opts.load_from_checkpoint = str(final_checkpoint)
    opts.checkpoint_dir = None
    opts.tensorboard = False

    game = play_sum.main(cli_args, opts=opts, train=False)
    game.eval()
    return game, opts

In [23]:
def interact(game, data_loader, device=None):
    """Run ``game`` on every batch of ``data_loader`` and merge the interactions.

    Args:
        game: callable invoked as ``game(sender_input, labels)``; element
            ``[1]`` of its return value is taken as the batch interaction
            and must support ``.to("cpu")``.
        data_loader: iterable yielding ``(sender_input, labels)`` tensor pairs.
        device: device the batches are moved to before calling ``game``.
            When ``None`` (default) the device of the game's first parameter
            is used; if the game exposes no parameters, CUDA is used when
            available and CPU otherwise.

    Returns:
        The result of ``egg.core.Interaction.from_iterable`` over the
        per-batch interactions, each moved to the CPU.
    """
    if device is None:
        # Follow the model instead of hard-coding CUDA so the function also
        # works on CPU-only machines and for games kept on the CPU.
        params = getattr(game, "parameters", None)
        first_param = next(iter(params()), None) if callable(params) else None
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    interactions = []
    # Inference only: no gradients are needed for any batch.
    with torch.no_grad():
        for sender_input, labels in data_loader:
            outputs = game(sender_input.to(device), labels.to(device))
            interactions.append(outputs[1].to("cpu"))
    return egg.core.Interaction.from_iterable(interactions)

In [24]:
def balanced_accuracy(interaction):
    """Return the mean of the per-class accuracies of ``interaction``.

    Args:
        interaction: object with ``labels`` (non-negative integer class ids)
            and ``aux["acc"]`` (per-sample correctness used as weights).

    Returns:
        The unweighted mean, over the classes that occur in ``labels``, of
        each class's summed ``acc`` divided by its sample count.
    """
    labels = interaction.labels
    per_class_total = np.bincount(labels)
    per_class_correct = np.bincount(labels, weights=interaction.aux["acc"])
    # Classes that never occur would divide by zero, so leave them out.
    present = per_class_total > 0
    return (per_class_correct[present] / per_class_total[present]).mean()

In [25]:
# Evaluate every experiment directory matching "vs*" on the train and test
# splits and collect one row of metrics per run.
run_records = {}
# Sorted so the row order is reproducible rather than filesystem-dependent.
for exp_dir in sorted(Path(".").glob("vs*")):
    try:
        game, opts = load_game(exp_dir)
    except Exception:
        # Best effort: report runs that cannot be loaded and move on.
        # `Exception` (not a bare except) keeps Ctrl-C / kernel interrupts working.
        traceback.print_exc()
        continue
    record = {"vocab_size": opts.vocab_size}
    for split, data_loader in [("train", train_loader), ("test", test_loader)]:
        interaction = interact(game, data_loader)
        record[split + "_acc"] = balanced_accuracy(interaction)
    run_records[str(exp_dir)] = record
# One row per experiment directory, indexed by its path.
results = pd.DataFrame.from_dict(run_records, orient="index")

Namespace(batch_size=32, checkpoint_dir=None, checkpoint_freq=0, cuda=True, device=device(type='cuda'), distributed_context=DistributedContext(is_distributed=False, rank=0, local_rank=0, world_size=1, mode='none'), distributed_port=18363, fp16=False, load_from_checkpoint='vs9_20210728-011522/final.tar', lr=0.001, max_len=1, mode='rf', n_attributes=None, n_epochs=1000, n_values=5, no_cuda=False, optimizer='adam', preemptable=False, print_validation_events=False, random_seed=2125479434, receiver_cell='rnn', receiver_embedding=10, receiver_hidden=100, receiver_layers=2, rnn=False, sender_cell='rnn', sender_embedding=10, sender_entropy_coeff=0.1, sender_hidden=100, sender_layers=2, temperature=1.0, tensorboard=False, tensorboard_dir='vs9_20210728-011522', train_data='../../data/sum5.train.train', update_freq=1, validation_batch_size=32, validation_data='../../data/sum5.train.val', validation_freq=20, vocab_size=9)
# Initializing model, trainer, and optimizer from vs9_20210728-011522/final.

# Initializing model, trainer, and optimizer from vs9_20210728-011103/final.tar
# loading trainer state from vs9_20210728-011103/final.tar
Namespace(batch_size=32, checkpoint_dir=None, checkpoint_freq=0, cuda=True, device=device(type='cuda'), distributed_context=DistributedContext(is_distributed=False, rank=0, local_rank=0, world_size=1, mode='none'), distributed_port=18363, fp16=False, load_from_checkpoint='vs80_20210727-204454/final.tar', lr=0.001, max_len=1, mode='rf', n_attributes=None, n_epochs=1000, n_values=5, no_cuda=False, optimizer='adam', preemptable=False, print_validation_events=False, random_seed=268405226, receiver_cell='rnn', receiver_embedding=10, receiver_hidden=100, receiver_layers=2, rnn=False, sender_cell='rnn', sender_embedding=10, sender_entropy_coeff=0.1, sender_hidden=100, sender_layers=2, temperature=1.0, tensorboard=False, tensorboard_dir='vs80_20210727-204454', train_data='../../data/sum5.train.train', update_freq=1, validation_batch_size=32, validation_data

# Initializing model, trainer, and optimizer from vs25_20210728-010138/final.tar
# loading trainer state from vs25_20210728-010138/final.tar
Namespace(batch_size=32, checkpoint_dir=None, checkpoint_freq=0, cuda=True, device=device(type='cuda'), distributed_context=DistributedContext(is_distributed=False, rank=0, local_rank=0, world_size=1, mode='none'), distributed_port=18363, fp16=False, load_from_checkpoint='vs50_20210728-013111/final.tar', lr=0.001, max_len=1, mode='rf', n_attributes=None, n_epochs=1000, n_values=5, no_cuda=False, optimizer='adam', preemptable=False, print_validation_events=False, random_seed=1111318823, receiver_cell='rnn', receiver_embedding=10, receiver_hidden=100, receiver_layers=2, rnn=False, sender_cell='rnn', sender_embedding=10, sender_entropy_coeff=0.1, sender_hidden=100, sender_layers=2, temperature=1.0, tensorboard=False, tensorboard_dir='vs50_20210728-013111', train_data='../../data/sum5.train.train', update_freq=1, validation_batch_size=32, validation_d

In [26]:
# Display the per-run metrics (one row per experiment directory).
results

Unnamed: 0,vocab_size,train_acc,test_acc
vs9_20210728-011522,9,0.573016,0.142857
vs25_20210728-010145,25,0.888889,0.0
vs9_20210728-012513,9,0.748148,0.142857
vs25_20210728-011052,25,0.866667,0.0
vs9_20210728-012034,9,0.715873,0.0
vs50_20210728-014041,50,1.0,0.0
vs50_20210728-012642,50,0.748148,0.0
vs9_20210728-011103,9,0.614815,0.0
vs80_20210727-204454,80,0.977778,0.142857
vs25_20210728-010622,25,0.977778,0.0


In [27]:
# Sanity check: number of non-null entries per column for each vocabulary size.
results.groupby("vocab_size").count()

Unnamed: 0_level_0,train_acc,test_acc
vocab_size,Unnamed: 1_level_1,Unnamed: 2_level_1
9,5,5
25,5,5
50,5,5
80,5,5


In [28]:
# A run counts as "perfect" when its balanced accuracy exceeds this value
# (slightly below 1 to tolerate floating-point error).
PERFECT_ACC_THRESHOLD = 0.999


def count_perfect(vals):
    """Return how many entries of ``vals`` exceed ``PERFECT_ACC_THRESHOLD``."""
    return (vals > PERFECT_ACC_THRESHOLD).sum()


# Summarise the runs per vocabulary size. A named aggregator is used so the
# resulting column can be addressed by the function's name instead of pandas'
# auto-generated "<lambda_0>" label.
table = (
    results.groupby("vocab_size")[["train_acc", "test_acc"]]
    .agg(["max", "mean", count_perfect])
    # The perfect-run count is only reported for the training split.
    .drop(("test_acc", "count_perfect"), axis=1)
    .rename(columns={"count_perfect": "# perfect"})
    .reset_index()
)
table

Unnamed: 0_level_0,vocab_size,train_acc,train_acc,train_acc,test_acc,test_acc
Unnamed: 0_level_1,Unnamed: 1_level_1,max,mean,# perfect,max,mean
0,9,0.748148,0.66963,0,0.142857,0.057143
1,25,0.977778,0.912238,0,0.0,0.0
2,50,1.0,0.856296,1,0.0,0.0
3,80,1.0,0.951111,2,0.142857,0.028571


In [29]:
# Export the summary table as LaTeX. `float_format` pins two-decimal rendering
# explicitly rather than relying on the `display.precision` option, which
# `to_latex` is not guaranteed to consult in every pandas version.
# Integer columns are left untouched by `float_format`.
print(table.to_latex(index=False, float_format="%.2f"))

\begin{tabular}{rrrrrr}
\toprule
vocab\_size & \multicolumn{3}{l}{train\_acc} & \multicolumn{2}{l}{test\_acc} \\
           &       max & mean & \# perfect &      max & mean \\
\midrule
         9 &      0.75 & 0.67 &         0 &     0.14 & 0.06 \\
        25 &      0.98 & 0.91 &         0 &     0.00 & 0.00 \\
        50 &      1.00 & 0.86 &         1 &     0.00 & 0.00 \\
        80 &      1.00 & 0.95 &         2 &     0.14 & 0.03 \\
\bottomrule
\end{tabular}

