In [12]:
import argparse
import base64
import ctypes
import collections
import itertools
import shlex
import traceback
from pathlib import Path

import numpy as np
import pandas as pd
import torch

import egg
from egg.zoo.basic_games.data_readers import AttrValClassDataset
from egg.zoo.basic_games import play_sum

In [13]:
# --- Data files (paths are relative to the notebook's working directory) ---
PATH_TRAIN = "../../data/sum50.train.train"
PATH_VAL = "../../data/sum50.train.val"  # not used by the evaluation cells below
PATH_TEST = "../../data/sum50.test"

# --- Task size ---
# Number of distinct values each input can take.
N_VALUES = 50
# Number of distinct sums: assumes two inputs each in [0, N_VALUES), so sums span
# 0 .. 2*N_VALUES-2. TODO confirm against the dataset. Not used by the cells below.
N_CLASSES = 2 * N_VALUES - 1

In [14]:
def _make_eval_loader(path):
    """Wrap the AttrValClassDataset stored at ``path`` in an unshuffled DataLoader."""
    dataset = AttrValClassDataset(path=path, n_values=N_VALUES)
    return torch.utils.data.DataLoader(
        dataset, batch_size=32, shuffle=False, num_workers=1
    )


# Evaluation only: batches are produced in dataset order (shuffle=False).
train_loader = _make_eval_loader(PATH_TRAIN)
test_loader = _make_eval_loader(PATH_TEST)

In [15]:
def load_game(exp_dir):
    """Rebuild a trained sum game from its experiment directory, without training.

    Reads the shell-quoted argument string saved in ``<exp_dir>/args``, parses it
    with ``play_sum.get_params``, and redirects the trainer so that it *loads*
    ``<checkpoint_dir>/final.tar`` instead of writing new checkpoints.

    Args:
        exp_dir: experiment directory (str or Path) containing the ``args`` file.

    Returns:
        ``(game, opts)``: the game (switched to eval mode) and the modified options.

    Raises:
        ValueError: if the saved arguments do not specify a checkpoint directory.
    """
    with open(Path(exp_dir) / "args") as f:
        args = shlex.split(f.read())
    opts = play_sum.get_params(args)

    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would turn this into an obscure TypeError from Path(None) below.
    if not opts.checkpoint_dir:
        raise ValueError(
            f"{exp_dir}: saved args have no checkpoint_dir; cannot locate final.tar"
        )

    # Make the trainer load the checkpoint instead of writing it
    opts.load_from_checkpoint = str(Path(opts.checkpoint_dir) / "final.tar")
    opts.checkpoint_dir = None
    opts.tensorboard = False  # no event logging during evaluation

    game = play_sum.main(args, opts=opts, train=False)
    game.eval()
    return game, opts

In [16]:
def interact(game, data_loader, device=None):
    """Run ``game`` over every batch of ``data_loader`` and collect the interactions.

    Args:
        game: the game module returned by ``load_game`` (already in eval mode).
        data_loader: iterable yielding ``(sender_input, labels)`` tensor batches.
        device: device each batch is moved to before the forward pass. Defaults to
            the device of the game's first parameter, so a game loaded on CPU
            (e.g. with ``--no_cuda``) also works; CPU if the game has no parameters.

    Returns:
        One ``egg.core.Interaction`` concatenating the per-batch interactions,
        with every batch's interaction moved to CPU.
    """
    if device is None:
        first_param = next(game.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    interactions = []
    # Gradients are never needed for evaluation; disable them for the whole pass.
    with torch.no_grad():
        for sender_input, labels in data_loader:
            # Element [1] of the game's output is the Interaction for this batch.
            output = game(sender_input.to(device), labels.to(device))
            interactions.append(output[1].to("cpu"))
    return egg.core.Interaction.from_iterable(interactions)

In [17]:
def balanced_accuracy(interaction):
    """Macro-averaged accuracy: the mean, over classes, of each class's accuracy.

    For every class that appears in ``interaction.labels``, the values of
    ``interaction.aux["acc"]`` are summed and divided by the number of examples
    with that label. The result is the unweighted mean of those per-class ratios,
    so rare and frequent classes count equally. Classes absent from the labels
    are ignored.

    Assumes ``labels`` is a 1-D array-like of non-negative ints (required by
    ``np.bincount``) and ``aux["acc"]`` holds one per-example accuracy value
    aligned with it. TODO confirm 0/1 semantics of ``acc`` upstream.
    """
    labels = interaction.labels
    per_class_total = np.bincount(labels)
    per_class_correct = np.bincount(labels, weights=interaction.aux["acc"])
    present = per_class_total > 0
    return (per_class_correct[present] / per_class_total[present]).mean()

In [18]:
# Evaluate each trained run on the train and test splits. A run whose args or
# checkpoint fails to load is reported (traceback printed) and skipped.
accuracy_by_run = {}
for exp_dir in ["20210728-021922", "20210725-003558"]:
    try:
        game, opts = load_game(exp_dir)
    except Exception:
        # `except Exception`, not a bare `except:`, so KeyboardInterrupt/SystemExit
        # still stop the cell instead of being silently skipped.
        traceback.print_exc()
        continue
    run_accuracies = {}
    for split, data_loader in [("train", train_loader), ("test", test_loader)]:
        interaction = interact(game, data_loader)
        run_accuracies[f"{split}_acc"] = balanced_accuracy(interaction)
    accuracy_by_run[str(exp_dir)] = run_accuracies
# One row per experiment directory, columns train_acc / test_acc.
results = pd.DataFrame.from_dict(accuracy_by_run, orient="index")

Namespace(batch_size=32, checkpoint_dir=None, checkpoint_freq=0, cuda=True, device=device(type='cuda'), distributed_context=DistributedContext(is_distributed=False, rank=0, local_rank=0, world_size=1, mode='none'), distributed_port=18363, fp16=False, load_from_checkpoint='20210728-021922/final.tar', lr=0.001, max_len=1, mode='rf', n_attributes=None, n_epochs=500, n_values=50, no_cuda=False, optimizer='adam', preemptable=False, print_validation_events=False, random_seed=706341019, receiver_cell='rnn', receiver_embedding=10, receiver_hidden=200, receiver_layers=2, rnn=False, sender_cell='rnn', sender_embedding=10, sender_entropy_coeff=0.1, sender_hidden=200, sender_layers=2, temperature=1.0, tensorboard=False, tensorboard_dir='20210728-021922', train_data='../../data/sum50.train.train', update_freq=1, validation_batch_size=32, validation_data='../../data/sum50.train.val', validation_freq=20, vocab_size=5000)
# Initializing model, trainer, and optimizer from 20210728-021922/final.tar
# lo

In [19]:
# Balanced (per-class averaged) train/test accuracy, one row per experiment directory.
results

Unnamed: 0,train_acc,test_acc
20210728-021922,0.328195,0.081806
20210725-003558,0.972613,0.401741


In [21]:
# Export the accuracy table as LaTeX with 2 decimals. The precision is passed via
# `float_format` rather than `display.precision`: the Styler-based `to_latex` in
# recent pandas reads `styler.*` options, so `display.precision` may be ignored.
print(results.to_latex(index=False, float_format="{:.2f}".format))

\begin{tabular}{rr}
\toprule
 train\_acc &  test\_acc \\
\midrule
      0.33 &      0.08 \\
      0.97 &      0.40 \\
\bottomrule
\end{tabular}

