In [None]:
import pickle
import pandas as pd
import numpy as np


# Load a single run's stats: the blender weights and the per-step action
# values for the environment/seed chosen below (pickles produced elsewhere).
env_name = 'kangaroo'
seed = 1

with open(f'../stats/{env_name}/weights_seed_{seed}.pkl', 'rb') as weights_file:
    weights = pickle.load(weights_file)

with open(f'../stats/{env_name}/action_values_seed_{seed}.pkl', 'rb') as values_file:
    action_values = np.array(pickle.load(values_file))

In [None]:
from matplotlib.colors import LinearSegmentedColormap
# Two-stop colormap for the weight strip: the low end of the data range maps to
# hot pink, the high end to dodger blue, interpolated over 100 levels.
colors = ['hotpink', 'dodgerblue']
bluepink_cmap = LinearSegmentedColormap.from_list(
    'bluepink',
    colors,
    N=100,
)

In [None]:
import torch
def build_weight_map(weights, n_rows=50):
    """Turn per-step blender weights into a 2-D strip suitable for ``imshow``.

    Column 1 of ``weights`` is taken. Judging by the "L / N" axis label used
    when plotting, this is presumably the "logic" entry; confirm against
    whatever produced the pickle. The column is repeated ``n_rows`` times to
    give the strip visible height, then passed through a softmax along the
    step axis.

    Parameters
    ----------
    weights : array-like
        2-D per-step weights; only ``[:, 1]`` is used.
    n_rows : int, optional
        Number of identical rows in the strip, i.e. its height when shown with
        ``imshow`` (default 50, the previously hard-coded value).

    Returns
    -------
    torch.Tensor
        Shape ``(n_rows, n_steps)``. Every row is identical and sums to 1.
    """
    logic_weights = torch.tensor(np.array(weights)[:, 1])
    # expand() repeats the single row as a view; softmax allocates the result.
    strip = logic_weights.unsqueeze(0).expand(n_rows, -1)
    # NOTE(review): dim=1 is the step axis, so this normalizes across time rather
    # than across a neural/logic pair. Because imshow autoscales colours to the
    # data range, the effect is a monotonic exp() rescaling of the weight.
    # Confirm this is intended.
    return torch.softmax(strip, dim=1)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns


def plot_weights_with_bar(weights, action_values, env_name, seed=None, out_dir='figures'):
    """Plot policy outputs above a colour strip of the blender's logic weight.

    Top panel: columns 0 ("neural") and 1 ("logic") of ``action_values``,
    plotted against the step index. Bottom panel: ``build_weight_map(weights)``,
    rendered with ``bluepink_cmap``. The figure is written to
    ``{out_dir}/{env_name}_blender_weights_seed_{seed}.pdf``, then shown and
    closed.

    Parameters
    ----------
    weights : array-like
        Per-step blender weights passed to ``build_weight_map``.
    action_values : array-like
        Per-step outputs. Converted with ``np.array`` and indexed as ``[:, 0]``
        and ``[:, 1]``, so at least two columns are required.
    env_name : str
        Environment name, used for the title and the output file name.
    seed : int, optional
        Seed written into the output file name. When omitted, falls back to the
        notebook-global ``seed`` so existing calls keep working. Prefer passing
        it explicitly to avoid depending on hidden kernel state.
    out_dir : str, optional
        Directory for the PDF (default ``'figures'``); created if missing.
    """
    import os

    if seed is None:
        # Backward compatibility: earlier versions read the global `seed` directly.
        seed = globals()['seed']

    neural_color = "deeppink"
    logic_color = "dodgerblue"
    labels = ['neural', 'logic']

    action_values = np.array(action_values)

    # seaborn styles only apply to axes created after the call. Set the style
    # before building the figure. Previously it was set afterwards, so the first
    # figure of a session used whatever style was active before, while later
    # figures used this one.
    sns.set_style("whitegrid", {'grid.linestyle': '--'})
    fig, (ax_values, ax_weights) = plt.subplots(nrows=2, figsize=(7.5, 3), sharex=True)

    # Bottom panel: the logic-weight strip, without ticks. The x-axis is shared,
    # so clearing x ticks here also clears them on the top panel.
    ax_weights.xaxis.set_ticks([])
    ax_weights.yaxis.set_ticks([])
    ax_weights.set_ylabel("L / N", fontsize=14)
    ax_weights.imshow(build_weight_map(weights), cmap=bluepink_cmap)
    ax_weights.set_xlabel("episodic steps", fontsize=16)

    # Top panel: per-step outputs of both branches.
    ax_values.set_title(f"{env_name}".capitalize(), fontsize=20)
    ax_values.set_ylabel("policy output", fontsize=14)
    ax_values.plot(action_values[:, 0], label=labels[0], alpha=0.6, color=neural_color)
    ax_values.plot(action_values[:, 1], label=labels[1], alpha=0.99, color=logic_color)
    ax_values.legend(loc='upper right', fontsize=12)

    fig.tight_layout()
    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    path = f'{out_dir}/{env_name}_blender_weights_seed_{seed}.pdf'
    fig.savefig(path, bbox_inches='tight')
    plt.show()
    plt.close(fig)


In [None]:

# Kangaroo: plot blender weights and policy outputs for each listed seed.
# Append 1 to `seeds` to also plot seed 1.
env_name = 'kangaroo'
seeds = [0]

for seed in seeds:
    weights_path = f'../stats/{env_name}/weights_seed_{seed}.pkl'
    values_path = f'../stats/{env_name}/action_values_seed_{seed}.pkl'
    with open(weights_path, 'rb') as weights_file:
        weights = pickle.load(weights_file)
    with open(values_path, 'rb') as values_file:
        action_values = pickle.load(values_file)
    plot_weights_with_bar(weights, action_values, env_name)

In [None]:

# Seaquest: plot blender weights and policy outputs for each listed seed.
env_name = 'seaquest'
seeds = [0]

for seed in seeds:
    weights_path = f'../stats/{env_name}/weights_seed_{seed}.pkl'
    values_path = f'../stats/{env_name}/action_values_seed_{seed}.pkl'
    with open(weights_path, 'rb') as weights_file:
        weights = pickle.load(weights_file)
    with open(values_path, 'rb') as values_file:
        action_values = pickle.load(values_file)
    plot_weights_with_bar(weights, action_values, env_name)