In [2]:
from pathlib import Path

import gin
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch

from src.game import Game, CircleL1, CircleL2

In [3]:
def plot(logdir, savedir=None, name='error', show=True):
    """Plot sender/receiver test error for the runs under ``logdir`` and summarise it.

    The first ``*.gin`` file found (recursively) under ``logdir`` is parsed to
    obtain ``Game.bias`` and ``Game.num_points``; every ``*.json`` file found
    (recursively) is read with ``pandas.read_json`` and the runs are concatenated.

    Args:
        logdir: Directory searched recursively for one ``.gin`` config and the
            ``.json`` run logs. Each log must provide ``sender``, ``recver``
            and ``epoch`` columns.
        savedir: Optional directory for the figure. It is created if missing.
            Only used when ``show`` is true.
        name: File stem of the saved figure (``<name>.png``).
        show: When true, draw the figure (and save it if ``savedir`` is set).

    Returns:
        A list of five floats: the mean of the summed sender + receiver test
        metric over five consecutive chunks of ten rows, taken from the rows
        whose ``epoch`` is >= 20. A chunk with no rows yields ``nan``. Unlike
        the plotted curves, these values are not multiplied by 10.

    Raises:
        FileNotFoundError: No ``.gin`` config file exists under ``logdir``.
        ValueError: No ``.json`` run log exists under ``logdir``.
    """
    logpath = Path(logdir)

    # A bare next() would leak StopIteration; fail with a readable error instead.
    config_file = next(logpath.glob('**/*.gin'), None)
    if config_file is None:
        raise FileNotFoundError(f'no .gin config file found under {logpath}')
    print(f'config file {config_file}')
    gin.parse_config_file(config_file, skip_unknown=True)

    bias = gin.config.query_parameter('Game.bias')
    num_points = gin.config.query_parameter('Game.num_points')
    test_loss = CircleL1(num_points)

    # glob order is filesystem dependent, and the summary below slices rows by
    # position, so sort the paths to make the returned list reproducible.
    run_logs = []
    for path in sorted(logpath.glob('**/*.json')):
        print(f'plotting from {path}')
        with open(path, 'r') as logfile:
            run_logs.append(pd.read_json(logfile))
    if not run_logs:
        raise ValueError(f'no .json run logs found under {logpath}')

    logs = pd.concat(run_logs, ignore_index=True)
    # 'sender' / 'recver' hold one record per row; expand them into columns
    # and re-attach the epoch so they can be plotted against it.
    sender = pd.DataFrame(logs['sender'].to_list()).join(logs['epoch'])
    recver = pd.DataFrame(logs['recver'].to_list()).join(logs['epoch'])

    # Some logs only provide 'test_error'. Resolve the column name once so the
    # plot and the returned summary always read the same metric.
    metric = "test_l1_error" if "test_l1_error" in sender else "test_error"

    if show:
        # Rewards
        sns.lineplot(data=sender, x="epoch", y=sender[metric]*10, label="sender")
        sns.lineplot(data=recver, x="epoch", y=recver[metric]*10, label="receiver")

        # Baselines
        # NOTE(review): 36 / 4 and the *10 scaling are undocumented constants
        # carried over unchanged -- confirm their meaning against Game/CircleL1.
        nocomm_diff = torch.tensor(36 / 4)
        nocomm_error = test_loss(torch.tensor(0.), nocomm_diff)*10
        fair_error = test_loss(torch.tensor(0.), bias/2)*10
        plt.axhline(nocomm_error, label='no communication', color="black", dashes=(2,2,2,2))
        plt.axhline(fair_error, label='fair split', color="grey", dashes=(2,2,2,2))

        plt.ylabel(r'Test $L_1$ loss')

        if savedir:
            savepath = Path(savedir)
            # savefig does not create missing directories.
            savepath.mkdir(parents=True, exist_ok=True)
            plt.savefig(savepath / f'{name}.png',  bbox_inches='tight')

        plt.show()
        plt.clf()

    error_sum = pd.DataFrame(sender[metric] + recver[metric]).join(logs['epoch'])
    error_20 = error_sum[error_sum['epoch'] >= 20]
    # Positional chunks of ten rows each; see NOTE on sorting above.
    return [error_20[metric].iloc[i*10:(i + 1)*10].mean() for i in range(5)]
    

In [None]:
# Driver cell: calls plot_hyperparam_results on the cat-deter results directory,
# passing the 'SenderLOLA4-ReceiverLOLA4' label and the 'plots' subdirectory
# (presumably the output location -- confirm against the function's signature).
# NOTE(review): plot_hyperparam_results is not defined in the cells visible here
# (only `plot` is, and its argument order differs); confirm it is defined or
# imported before this cell, otherwise Restart & Run All fails with NameError.
resultspath = Path('../results/cat-deter')
with sns.plotting_context('paper'):
    # NOTE(review): sns.set() applies seaborn's own context settings as well, so
    # it may override the 'paper' context requested above -- confirm whether
    # sns.plotting_context('paper', font_scale=1.3) was the intent.
    sns.set(font_scale=1.3)
    plot_hyperparam_results(resultspath, 'SenderLOLA4-ReceiverLOLA4', resultspath / 'plots')