How close are our agent's Q-value predictions to the ground-truth rewards they are meant to predict?

In [None]:
%load_ext autoreload
%autoreload
from retro_branching.utils import get_most_recent_checkpoint_foldername, seed_stochastic_modules_globally
from retro_branching.networks import BipartiteGCN
from retro_branching.agents import Agent, REINFORCEAgent, PseudocostBranchingAgent, StrongBranchingAgent, RandomAgent, DQNAgent, DoubleDQNAgent
from retro_branching.environments import EcoleBranching
from retro_branching.validators import ReinforcementLearningValidator

import ecole
import torch
import numpy as np
import os
import shutil
import glob
import time
import gzip
import pickle
import copy
from collections import defaultdict
from sklearn.metrics import mean_squared_error
import sigfig

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

import ray
import psutil
# Start a local Ray runtime sized to the physical core count (ignore_reinit_error
# lets this cell be re-run without raising).
# psutil.cpu_count(logical=False) returns None when the physical count cannot be
# determined, which would make int(num_cpus) raise TypeError; fall back to the
# logical count from os.cpu_count() (itself possibly None) and finally to 1.
num_cpus = psutil.cpu_count(logical=False) or os.cpu_count() or 1
ray.init(num_cpus=int(num_cpus), ignore_reinit_error=True)

# seed global random number generators via the project helper
seed = 0
seed_stochastic_modules_globally(default_seed=seed)

In [None]:
# define agent and env params

%autoreload
learner = 'supervised_learner' # 'reinforce_learner' 'dqn_learner' 'supervised_learner'
base_name = 'gnn' # 'rl_gnn' 'dqn_gnn' 'gnn'
AgentClass = Agent # Agent REINFORCEAgent DQNAgent DoubleDQNAgent
nrows = 100
ncols = 100
threshold_difficulty = None
last_checkpoint_idx = None
max_steps = int(1e12) # int(1e12) 3
agent_reward = 'sb_scores' # 'dual_bound_frac' 'normalised_lp_gain'
gamma = 0

In [None]:
%autoreload

# load agent
device = 'cpu'

i = 326 # 569 312
checkpoint = 37 # 25 88
observation_function = '43_var_features' # 'default' '40_var_features' '43_var_features'
agent_name = f'{base_name}_{i}'
DEVICE = 'cpu'
agents, agent_paths = {}, []
path = f'/scratch/datasets/retro_branching/{learner}/{base_name}/{agent_name}/checkpoint_{checkpoint}/'
config = path + '/config.json'
agent = AgentClass(device=device, config=config)
agent.name = f'{agent_name}_checkpoint_{checkpoint}'
for network_name, network in agent.get_networks().items():
    if network is not None:
        try:
            # see if network saved under same var as 'network_name'
            agent.__dict__[network_name].load_state_dict(torch.load(path+f'/{network_name}_params.pkl', map_location=device))
        except KeyError:
            # network saved under generic 'network' var (as in Agent class)
            agent.__dict__['network'].load_state_dict(torch.load(path+f'/{network_name}_params.pkl', map_location=device))
        
agent.eval() # put in test mode

max_steps_agent = None

In [None]:
%autoreload
# env
env = EcoleBranching(observation_function=observation_function, # 'default' 40_var_features'
                      information_function='default',
                      reward_function='default',
                      scip_params='default')
env.seed(seed)

# instance
files = glob.glob(f'/scratch/datasets/retro_branching/instances/set_cover_nrows_{nrows}_ncols_{ncols}_density_005_threshold_{threshold_difficulty}/*.mps')
instances = [ecole.scip.Model.from_file(f) for f in files]
# instance = instances[10]

In [None]:
%autoreload

# init tracker
instance_stats = defaultdict(lambda: [])
step_stats = defaultdict(lambda: defaultdict(lambda: []))

# DEBUG
sb = StrongBranchingAgent()

# solve each instance with branching agent
for i, instance in enumerate(instances):
    print(f'\nInstance {i+1} of {len(instances)}')
    # reset agent and env
    agent.before_reset(instance)
    sb.before_reset(instance) # DEBUG
    obs, action_set, reward, done, info = env.reset(instance)
    
    # init trackers
    predictions, rewards, returns = [], [], []
    t = 1
    
    # run branching agent
    while not done:
        # sb score target
        rewards.append(sb.extract(env.model, done)[action_set].max())
        
        # get agent q values
        q_vals = agent.calc_Q_values(obs)
        if type(q_vals) == list:
            q_vals = torch.stack(q_vals).squeeze(0).detach().cpu().numpy()
        else:
            q_vals = q_vals.squeeze(0).detach().cpu().numpy()
        # get q values of valid actions
        q_vals = q_vals[action_set]

        # get agent action
#         action, action_idx = sb.action_select(action_set=action_set, model=env.model, done=done)
        action, action_idx = agent.action_select(action_set=action_set, obs=obs, agent_idx=0)

        # get q value agent predicted
        predictions.append(q_vals[action_idx])

        # take step in environment
        obs, action_set, reward, done, info = env.step(action)

        # get true next-step reward received by agent
#         rewards.append(reward[agent_reward])


#     # calc true discounted returns
#     returns = []
#     R = 0
#     for r in rewards[::-1]:
#         R = r + (gamma * R)
#         returns.insert(0, R)
    returns = rewards

    for t in range(len(predictions)):
        print(f'Step {t} | q_val prediction: {predictions[t]} | reward: {rewards[t]} | true discounted_return: {returns[t]}')
        step_stats[t]['predictions'].append(predictions[t])
        step_stats[t]['true_targets'].append(returns[t])
    instance_stats['predictions'].append(predictions)
    instance_stats['true_targets'].append(returns)

In [None]:
# For each tracked series ('predictions' and 'true_targets'), plot its values at
# every branching step across all instances. seaborn aggregates the points at each
# x (by default with the mean) and draws a ci=68 band. A dashed horizontal line
# marks the series' overall mean. The title reports the MSE between the pooled
# targets and the pooled predictions.
fig = plt.figure()
colours = iter(sns.color_palette(palette='hls', n_colors=len(instance_stats)))
flat_data = {}
for label in instance_stats:
    colour = next(colours)
    data = defaultdict(list)
    for t, per_step in step_stats.items():
        step_values = per_step[label]
        # x-axis uses 1-based step numbers
        data['x_vals'] += [t + 1] * len(step_values)
        data['y_vals'] += step_values
    flat_data[label] = data['y_vals']
    # NOTE(review): `ci` is deprecated in seaborn >= 0.12 in favour of errorbar=('ci', 68)
    sns.lineplot(x='x_vals', y='y_vals', data=data, color=colour, ci=68, alpha=0.8, label=label)
    plt.axhline(y=np.mean(data['y_vals']), color=colour, linestyle='--', alpha=0.5)

# Both series were flattened in the same (step, instance) order, so the entries pair up.
mse = mean_squared_error(flat_data['true_targets'], flat_data['predictions'])
plt.title(f'{agent.name} MSE: {sigfig.round(mse, sigfigs=3)}')
plt.xlabel('Step')
plt.ylabel(f'gamma_{gamma}_{agent_reward}')
plt.legend()
plt.show()

In [None]:
# NOTE(review): this whole cell is disabled. If re-enabled, `flat_data[label]` below
# relies on `label` left over from the previous plotting cell's loop; set it explicitly.
# # TEMPORARY: Plotting strong branching gamma=0 case...
# fig = plt.figure()
# colours = iter(sns.color_palette(palette='hls', n_colors=len(instance_stats.keys())))
# flat_data = {}
# colour = next(colours)
# data = defaultdict(lambda: [])
# for t in step_stats.keys():
#     data['x_vals'].extend([t+1 for _ in range(len(step_stats[t]['true_targets']))])
#     data['y_vals'].extend(step_stats[t]['true_targets'])
# flat_data[label] = data['y_vals']
# sns.lineplot(x='x_vals',
#              y='y_vals',
#              data=data,
#              color=colour,
#              ci=68,
#              alpha=0.8,
#              label=f'{agent.name}_returns')
# plt.axhline(y=np.mean(data['y_vals']), color=colour, linestyle='--', alpha=0.5)
# plt.xlabel('Step')
# plt.ylabel(f'gamma_{gamma}_{agent_reward}')
# plt.legend()
# plt.show()