In [None]:
import numpy as np
import networkx as nx

#%matplotlib notebook
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.axes_grid1 import make_axes_locatable
%config InlineBackend.figure_format = 'retina'

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataset import TensorDataset
from torch.utils.data import DataLoader # (testset, batch_size=4,shuffle=False, num_workers=4)
from torch.optim.lr_scheduler import ReduceLROnPlateau as RLRP
from torch.nn.parallel import DistributedDataParallel, DataParallel
from torch.nn.init import xavier_normal
from torch.nn.parameter import Parameter
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import sys
from datetime import datetime
from functools import reduce
import os
import os.path
import pandas as pd
import pickle
import importlib
from collections import Counter
from copy import deepcopy
from collections import OrderedDict

import torch_geometric as tg
import nkmodel as nk
import ppo.core as core
from ppo.ppo import PPOBuffer
from utils.utils import max_mean_clustering_network
import envs
import json
from itertools import product
from functools import reduce  

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import matplotlib.gridspec as gridspec

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration. Every value below is forwarded to the environment
# constructor through `env_kwargs` (except `trj_len`, which is the number of
# steps each baseline is rolled out for).
# NOTE(review): the meaning of E / N / K / NN / exp is defined by the project
# environment (presumably NK-landscape parameters) -- confirm against `envs`.
# ---------------------------------------------------------------------------
E = 32          # forwarded as 'E'
M = 100         # forwarded as 'M'; also the node count `n` of the graph
N = 15          # forwarded as 'N'
K = 12          # forwarded as 'K'
NN = 3          # forwarded as 'neighbor_num'
exp = 4         # forwarded as 'exp'
trj_len = 200   # rollout length (steps per trajectory)

graph_type = 'complete'
reward_type = 'indv_raw_full'
action_type = 'total'
extra_type = 'SI'

# Name of the environment class looked up in the `envs` module.
env_name = f'SL_NK_{action_type}'

# Graph generators selectable through `graph_type` ...
nx_dict = {
    'complete': nx.complete_graph,
    'ba': nx.barabasi_albert_graph,
    'er': nx.erdos_renyi_graph,
    'maxmc': max_mean_clustering_network,
}
# ... and the keyword arguments each generator is configured with.
nx_arg_dict = {
    'complete': {'n': M},
    'ba': {'n': M, 'm': 19},
    'er': {'n': M, 'p': 0.3},
    'maxmc': {'n': M},
}

# Keyword arguments shared by every environment instance.
env_kwargs = {
    'E': E,
    'M': M,
    'N': N,
    'K': K,
    'neighbor_num': NN,
    'exp': exp,
    'graph': nx_dict[graph_type],
    'graph_dict': nx_arg_dict[graph_type],
    'reward_type': reward_type,
    'action_type': action_type,
    'extra_type': extra_type,
    'corr_type': 'TT',
}

In [None]:
# Cache of baseline results, later filled with one entry per baseline name.
# The reserved 'keys' entry lists the names of the aggregated metrics that
# are stored for every baseline.
baseline_data_dict = {'keys': ['Ret', 'FinalScore']}

In [None]:
env_num = 20              # number of independent environment instances
test_ensemble_num = 100   # value passed as `E` when each environment is reset

# Resolve the environment class by name once, then build `env_num` instances
# that all share the same constructor arguments.
env_cls = vars(envs)[env_name]
env_list = [env_cls(**env_kwargs) for _ in range(env_num)]

# Reset every environment and keep a deep copy of the second value returned by
# reset(), so that each baseline can later be started from the same state.
state_list = []
for i in range(env_num):
    print(i)
    _, fixed_state = env_list[i].reset(E=test_ensemble_num, base=True)
    state_list.append(deepcopy(fixed_state))
print("Baseline construction initiated")

In [None]:
# Roll out every scripted baseline policy on the prepared environments, store
# the aggregated metrics and raw buffers in `baseline_data_dict`, and finally
# pickle the whole dictionary to disk.
# Earlier, shorter baseline selection (kept for reference):
#baselines = ['FollowBest', 'FollowBest_indv', 'FollowMajor', 'FollowMajor_indv', 'IndvLearning', 'RandomCopy']
baselines = ['FollowBest', 'FollowBest_indv', 'FollowBest_random', 'FollowBest_prob',
             'FollowMajor', 'FollowMajor_indv', 'FollowMajor_random', 'FollowMajor_prob',
            'IndvLearning', 'IndvRandom', 'IndvProb', 'RandomCopy']

for baseline_name in baselines:
    # Only evaluate baselines that are not cached yet, so re-running this cell
    # does not recompute existing entries.
    if baseline_name not in baseline_data_dict.keys():
        print(f"Baseline : {baseline_name}")
        # Per-baseline accumulators; one element is appended per environment.
        baseline_data = {}
        baseline_data['Ret'] = []
        baseline_data['FinalScore'] = []
        baseline_data['scr_buf'] = []
        baseline_data['unq_buf'] = []

        for i in range(env_num):
            print(i)
            env_base = env_list[i]
            # Look the baseline class up by name in `ppo.core` and instantiate it
            # for this environment.
            ac_base = core.__dict__[baseline_name](env_base, action_type, extra_type, corr_type='TT')
            # scr_buf[e, m, t]: value `s` returned by env.step at step t.
            # unq_buf[e, t]: number of distinct rows in the action a[e] at step t.
            scr_buf = np.zeros((test_ensemble_num, M, trj_len), dtype=np.float32)
            unq_buf = np.zeros((test_ensemble_num, trj_len), dtype=np.float32)

            # Start from the state stored for environment i.
            # NOTE(review): presumably this restores the exact saved state so all
            # baselines share the same starting point -- confirm in `envs`.
            o, _ = env_base.reset(states=state_list[i], state_only=True, base=True)
            ep_ret, ep_len = 0, 0
            for t in range(trj_len):
                a = ac_base.step(o)
                # NOTE(review): `r` looks like the step reward and `s` the score
                # (it must broadcast to (test_ensemble_num, M)) -- confirm.
                next_o, r, s = env_base.step(a)
                ep_ret += r
                ep_len += 1
                scr_buf[..., t] = s
                # Count the unique rows of each ensemble member's action array.
                for e in range(test_ensemble_num):
                    freq = np.unique(a[e], axis=0)
                    unq_buf[e][t] = freq.shape[0]
                o = next_o

            # Mean per-step `r` over the trajectory, and mean of the last step's `s`.
            baseline_data['Ret'].append(np.mean(ep_ret / ep_len))
            baseline_data['FinalScore'].append(np.mean(s))
            baseline_data['scr_buf'].append(scr_buf)
            baseline_data['unq_buf'].append(unq_buf)
        # Collapse the per-environment metrics to scalars and stack the buffers
        # into arrays with a leading environment axis.
        baseline_data['Ret'] = np.mean(baseline_data['Ret'])
        baseline_data['FinalScore'] = np.mean(baseline_data['FinalScore'])
        baseline_data['scr_buf'] = np.array(baseline_data['scr_buf'])
        baseline_data['unq_buf'] = np.array(baseline_data['unq_buf'])
        baseline_data_dict[baseline_name] = baseline_data
        print("Baseline finished")
        print(f"{baseline_name}, {baseline_data_dict[baseline_name]['Ret']}, {baseline_data_dict[baseline_name]['FinalScore']}")
# Written once, after all baselines have been processed; the file name encodes
# the graph type and the N / K / NN / exp settings.
with open(f'baseline_{graph_type}_N{N}K{K}NN{NN}_exp{exp}.pkl', 'wb') as f:
    pickle.dump(baseline_data_dict, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Print the aggregated metrics ('Ret' and 'FinalScore') of every baseline.
for baseline_name in baselines:
    # 'keys' is the metadata entry of `baseline_data_dict`, not a baseline.
    # Compare by value: an identity test (`is not`) against a string literal
    # depends on interpreter string interning and emits a SyntaxWarning on
    # Python 3.8+.
    if baseline_name != 'keys':
        print(f"{baseline_name}, {baseline_data_dict[baseline_name]['Ret']}, {baseline_data_dict[baseline_name]['FinalScore']}")

In [None]:
# Display the complete results dictionary. Each baseline entry also holds the
# raw 'scr_buf' / 'unq_buf' arrays (stacked over all environments), so this
# output can be very large.
baseline_data_dict

In [None]:
# Figure drawing: average payoff over time, one dashed curve per baseline that
# has a legend label in `label_dict`.
fig = plt.figure(figsize=(4,4), dpi=200)
ax = fig.add_subplot(111)
color_list = ['limegreen', 'darkgreen','deepskyblue', 'royalblue', 'purple', 'gold']
label_dict = {'FollowBest':'BI', 'FollowBest_indv':'BI-I', 'FollowBest_random':'BI-R', 'FollowMajor':'CF', 'FollowMajor_indv':'CF-I', 'FollowMajor_random':'CF-R'}

# `baselines` also lists policies without a legend label (e.g.
# 'FollowBest_prob'). Looking those up in `label_dict` raised KeyError, and
# there are fewer colours than baselines, so only labelled baselines are drawn.
plotted_baselines = [name for name in baselines if name in label_dict]

counter=0
for baseline_name in plotted_baselines:
    x = baseline_data_dict[baseline_name]['scr_buf']
    # Average over every axis except the last one (time).
    reduce_axes = tuple(range(x.ndim - 1))
    avg_pf = np.mean(x, axis=reduce_axes)
    # Wrap the colour index so extra labelled baselines cannot overrun the palette.
    ax.plot(np.arange(x.shape[-1]), avg_pf, c=color_list[counter % len(color_list)], ls=(0, (3, 2)), label=label_dict[baseline_name])
    counter+=1

ax.set_xlabel('Time')
ax.set_ylabel('Average Payoff')
ax.legend(fontsize=8, loc=4)
# Optional: save the figure to disk.
#fig_name = 'st_complete_indv_raw_full_total_random_SI_TT_N15K7NN3_disc_g99_I100_L200_RST_TMT'
#plt.savefig(f'./result/figure/{fig_name}.png')