In [None]:
import matplotlib.pyplot as plt
import itertools
import json
import argparse
import numpy as np
import seaborn as sns

from plot_utils import extract_measures, extract_utterance, get_distributions, get_x_positions, order_lists, plot_reward, plot_training_curve 
# from plot_graphs import plot_reward  

# Apply seaborn's default theme to every matplotlib figure drawn below.
sns.set()

# Render figures inline beneath each cell.
%matplotlib inline

dir_ = 'final/new/'

# Log file for each '<sociality>_<channel>' configuration. Judging by the file
# names, 'linguistic' runs disable proposals, 'proposal' runs disable comms,
# 'none' disables both and 'both' disables neither.
filenames = {key: dir_ + log_name for key, log_name in (
    ('self_both', 'log_20191108_202416final_self.log'),
    ('self_linguistic', 'log_20191105_181425final_self_disable-proposal.log'),
    ('self_proposal', 'log_20191106_174032final_self_disable-comms.log'),
    # NOTE(review): ends in '_disable' rather than '_disable-comms' — confirm this is the intended log.
    ('self_none', 'log_20191109_114924final_self_disable-proposal_disable.log'),

    ('prosoc_both', 'log_20191103_201228final_prosoc.log'),
    ('prosoc_linguistic', 'log_20191104_193106final_prosoc_disable-proposal.log'),
    ('prosoc_proposal', 'log_20191104_145248final_prosoc_disable-comms.log'),
    ('prosoc_none', 'log_20191107_200112final_prosoc_disable-proposal_disable-comms.log'),
)}

# Logs whose names carry the 'memory-comp' tag, keyed the same way.
filenames_memory_comp = {key: dir_ + log_name for key, log_name in (
    ('prosoc_none', 'log_20191108_182251final_prosoc_disable-proposal_disable-comms_memory-comp.log'),
    ('prosoc_linguistic', 'log_20191111_090756final_prosocial_memory-comp_disable-proposal.log'),
    # NOTE(review): 'prosoc_none2' is not a key gen_iter() produces, so the gen_iter-driven helpers never read it.
    ('prosoc_none2', 'log_20191113_100238final_prosocial_memory-comp_disable-comms_disable-proposal.log'),
    ('prosoc_proposal', 'log_20191112_174358final_prosocial_memory-comp_disable-comms.log'),

    ('self_proposal', 'log_20191113_185154final_self_memory-comp_disable-comms.log'),
    # TODO: not yet included.
    # ('self_linguistic', 'log_20191114_095448final_self_memory-comp_disable-proposal.log'),
)}

output_dir = 'figs/'
    
def double_check(filenames):
    """Print the first line of every log file so its configuration can be eyeballed.

    Parameters
    ----------
    filenames : dict
        Maps a configuration key (e.g. 'self_both') to a log file path.

    For each entry, prints the key followed by the file's first line without
    its trailing newline. An empty file prints '<empty file>'. If the file
    cannot be opened or read, the exception is printed and a placeholder
    message is shown, so one bad path does not stop the check.
    """
    for key, filename in filenames.items():
        try:
            with open(filename, 'r') as f:
                # readline() returns '' for an empty file, so msg is always
                # rebound here rather than leaking from a previous key.
                first_line = f.readline()
            msg = first_line.rstrip('\n') if first_line else '<empty file>'
        except Exception as e:
            # Deliberately best-effort: report and keep checking other files.
            print(e)
            msg = 'Somethings wrong in here'
        print(key, msg)

def gen_iter(exclude_sociality=None, exclude_channel=None):
    """Iterate over (sociality, channel) experiment configurations.

    Produces every pair from ['self', 'prosoc'] x ['proposal', 'linguistic',
    'both', 'none'] in that order, optionally dropping one sociality and/or
    one channel. A truthy exclude value that is not in its list raises
    ValueError (from list.remove).

    Returns an itertools.product iterator, so it can be consumed only once.
    """
    sociality_options = ['self', 'prosoc']
    channel_options = ['proposal', 'linguistic', 'both', 'none']
    for options, excluded in ((sociality_options, exclude_sociality),
                              (channel_options, exclude_channel)):
        if excluded:
            options.remove(excluded)
    return itertools.product(sociality_options, channel_options)

In [None]:
def plot_reward_all(filenames, prefix=''):
    """Save a reward plot for every configuration that has a log file.

    Parameters
    ----------
    filenames : dict
        Maps '<sociality>_<channel>' keys (as produced by gen_iter) to log
        paths. Configurations without an entry are skipped, and keys that
        gen_iter does not produce are ignored.
    prefix : str
        Optional tag appended to each output name after an underscore, e.g.
        'memory-comp' -> '<output_dir>reward_prosoc_none_memory-comp.png'.
        With the default '' the file is '<output_dir>reward_<key>.png'.
    """
    # Only add the separator when there is a tag, so untagged plots are not
    # saved as 'reward_<key>_.png' with a dangling underscore.
    suffix = '_' + prefix if prefix else ''
    for sociality, channel in gen_iter():
        key = sociality + '_' + channel
        print(key)
        filename = filenames.get(key)
        if not filename:
            continue
        output = '{}reward_{}{}.png'.format(output_dir, key, suffix)
        plot_reward(filename, 0, 1, key, 200000, labels=None, output_file=output)

In [None]:
plot_reward_all(filenames=filenames, prefix='')
plot_reward_all(filenames=filenames_memory_comp, prefix='memory-comp')

In [None]:
def joint_reward_success(filenames):
    """
    FOR TABLE 2

    keys:
    self_proposal, self_linguistic, self_both, self_none,
    prosoc_proposal, prosoc_linguistic, prosoc_both, prosoc_none

    values:
    filenames

    Joint reward success and average number of turns taken for paired agents negotiating
    with random game termination, varying the agent reward scheme and communication channel.

    Only the joint reward is computed here: the mean of the 'test_reward'
    measure extracted from each log. It is printed next to the value reported
    in the paper. Configurations without a log get -1 as a sentinel.

    Returns
    -------
    dict
        Maps each '<sociality>_<channel>' key to {'joint_reward': value}.
    """
    from_paper = {'self_proposal': 0.87, 'self_linguistic': 0.75, 'self_both': 0.87, 'self_none': 0.77,
                  'prosoc_proposal': 0.93, 'prosoc_linguistic': 0.99, 'prosoc_both': 0.92, 'prosoc_none': 0.95}

    data = {}
    # gen_iter() is the single source of the configuration grid; the old
    # second loop referred to gen_iter's local lists and raised NameError.
    for sociality, channel in gen_iter():
        key = sociality + '_' + channel
        filename = filenames.get(key)

        if filename:
            extracted = extract_measures(filename, ['test_reward'])
            joint_reward = np.mean(extracted['test_reward'])
        else:
            joint_reward = -1  # sentinel: no log for this configuration
        data[key] = {'joint_reward': joint_reward}

        print(sociality + ' ' + channel)
        print('\tour:       {}'.format(joint_reward))
        print('\tfrom paper: {}'.format(from_paper[key]))
    return data
    
    

In [None]:
joint_reward_success(filenames=filenames)
joint_reward_success(filenames=filenames_memory_comp)

In [None]:
def plot_training_curve_all(filenames, prefix=''):
    """
    FOR FIGURE 2a

    Training curves for agents learning to negotiate under the various communication channels.

    The paper's figure shows the self-interested agents. This function draws a
    curve for every '<sociality>_<channel>' configuration that has an entry in
    `filenames`, and saves each to '<output_dir>training_curve_<key><prefix>.png'.
    `prefix` is appended verbatim, so callers include any separator themselves
    (e.g. '_memory-comp').
    """
    for sociality, channel in gen_iter():
        key = '_'.join((sociality, channel))
        log_path = filenames.get(key, None)
        if not log_path:
            continue
        destination = '{}training_curve_{}{}.png'.format(output_dir, key, prefix)
        plot_training_curve(log_path, min_y=0, max_y=1, title='', max_x=200000,
                            labels=None, output=destination)
        

In [None]:
plot_training_curve_all(filenames=filenames_memory_comp, prefix='_memory-comp')
plot_training_curve_all(filenames=filenames, prefix='')

In [None]:
def plot_utterance(distribution, turn, vocab_len=10, utter_len=6, annotate=False):
    """Bar-plot how often each vocabulary symbol appears at each utterance position.

    Parameters
    ----------
    distribution : array-like
        Indexed as ``distribution[turn, :, symbol]``. The middle axis is
        presumably the utterance position (length `utter_len`), with values from
        plot_utils.get_distributions — TODO confirm counts vs. proportions.
    turn : int
        Index into the first axis of `distribution` to plot.
    vocab_len : int
        Number of symbols; one bar series (and legend entry) per symbol.
    utter_len : int
        Number of utterance positions; one labelled bar group per position.
    annotate : bool
        If True, write each bar's height just above it.
    """
    labels = [str(i) for i in range(utter_len)]
    width = 4
    # After the transpose, row i holds symbol i's bar x-coordinates across the
    # positions (assumes get_x_positions returns an (utter_len, vocab_len)
    # array — TODO confirm in plot_utils).
    positions = np.asarray(get_x_positions(vocab_len, utter_len, width=width, outer_width=2)).T

    fig, ax = plt.subplots()
    rects = []
    for i in range(vocab_len):
        values = distribution[turn, :, i]
        rects.append(ax.bar(positions[i], values, width, label=str(i)))
    ax.legend()

    if annotate:
        # Adapted from https://matplotlib.org/3.1.1/gallery/lines_bars_and_markers/barchart.html
        for container in rects:
            for rect in container:
                height = rect.get_height()
                ax.annotate('{:.3g}'.format(height),
                            xy=(rect.get_x() + rect.get_width() / 2, height),
                            xytext=(0, 3),  # 3 points vertical offset
                            textcoords="offset points",
                            ha='center', va='bottom')

    ax.set_ylabel('Occurrence')
    ax.set_xlabel('Position in utterance')
    ax.set_title('Symbols distribution in position')
    # One tick at the centre of each position's bar group. Without set_xticks,
    # the labels would be attached to matplotlib's default ticks instead.
    ax.set_xticks(positions.mean(axis=0))
    ax.set_xticklabels(labels)
    fig.tight_layout()


def plot_unigram_statistics(filenames):
    """
    FOR FIGURE 3a

    Unigram statistics of symbol usage broken down by turn and by position within the utterance
    for prosocial agents communicating via the linguistic channel.

    The figure concerns the linguistic channel, but this loops over every
    prosocial channel (proposal, linguistic, both, none) that has a log in
    `filenames`. For each one it plots the symbol distribution for turn 0.
    """
    prosocial_keys = ('_'.join(pair) for pair in gen_iter(exclude_sociality='self'))
    for key in prosocial_keys:
        log_path = filenames.get(key, None)
        if not log_path:
            continue
        utterances = extract_utterance(log_path)
        plot_utterance(np.array(get_distributions(utterances)), 0)


In [None]:
plot_unigram_statistics(filenames=filenames)

In [None]:
def _position_bigrams(messages):
    """Return every adjacent-symbol bigram in `messages`, position pair by position pair.

    `messages` is a 2-D array of symbol strings, one row per message. The
    result lists the bigrams at positions (0, 1) for every row, then (1, 2),
    and so on. Each bigram is its two symbol strings concatenated.
    """
    bigrams = []
    for i in range(messages.shape[1] - 1):
        bigrams.extend(np.char.add(messages[:, i], messages[:, i + 1]))
    return bigrams


def bigram_statistics(filenames, agent='b'):
    """
    FOR FIGURE 3b

    Bigram counts for prosocial agents communicating via the linguistic channel, sorted by frequency.

    Parameters
    ----------
    filenames : dict
        Must contain a 'prosoc_linguistic' log path.
    agent : {'a', 'b'}
        Whose messages to count. 'a' uses the even-indexed entries of the
        extracted utterances and 'b' the odd-indexed ones (presumably the two
        agents' alternating turns — TODO confirm against extract_utterance).
        Defaults to 'b', the only agent the original version plotted.

    Raises
    ------
    ValueError
        If `agent` is not 'a' or 'b', or the extracted messages do not form a
        2-D (messages x positions) array, e.g. when none were found.
    """
    if agent not in ('a', 'b'):
        raise ValueError("agent must be 'a' or 'b', got {!r}".format(agent))

    extracted = extract_utterance(filenames['prosoc_linguistic'])
    turns = extracted[0::2] if agent == 'a' else extracted[1::2]
    # One row per message, one column per symbol. Symbols are strings so that
    # adjacent columns can be concatenated into bigram labels.
    messages = np.array([list(map(str, msg)) for sublist in turns for msg in sublist])
    if messages.ndim != 2:
        raise ValueError('expected a 2-D array of messages for agent {!r}, got shape {}'
                         .format(agent, messages.shape))

    unique, counts = np.unique(_position_bigrams(messages), return_counts=True)
    counts, unique = order_lists(counts, unique)

    fig, ax = plt.subplots()
    ax.bar(np.arange(len(counts)), counts)
    ax.set_xlabel('Bigram (sorted by frequency)')
    ax.set_ylabel('Count')
    ax.set_title('Bigram counts, prosocial linguistic, agent {}'.format(agent.upper()))

bigram_statistics(filenames=filenames)