Implementing a DQN brancher.

In [None]:
%load_ext autoreload
%autoreload
from retro_branching.agents import DQNAgent
from retro_branching.environments import EcoleBranching
from retro_branching.learners import DQNLearner
from retro_branching.networks import BipartiteGCN

import ecole

In [None]:
%autoreload

# DQN agent with library-default settings, placed on CUDA device index 1.
agent = DQNAgent(device='cuda:1')

# Branching environment; every configurable component is left at 'default'.
env = EcoleBranching(
    observation_function='default',
    information_function='default',
    reward_function='default',
    scip_params='default',
)
env.seed(0)

# Generator of set-cover instances: 100 rows, 100 columns, density 0.05.
instances = ecole.instance.SetCoverGenerator(
    n_rows=100,
    n_cols=100,
    density=0.05,
)

In [None]:
%autoreload

# Roll out `num_episodes` episodes with the current agent: draw an instance,
# reset the environment on it, then step until the environment reports `done`.
num_episodes = 1
for ep in range(num_episodes):
    print(f'> Episode {ep} <')
    
    # Keep drawing instances until `env.reset` returns a non-None observation
    # (a None observation is treated here as "instance was pre-solved by the environment").
    obs = None
    while obs is None:
        # The environment is re-seeded with the same seed before every reset attempt.
        env.seed(0)
        instance = next(instances)
        # The agent receives its own copy of the original instance before the reset.
        instance_before_reset = instance.copy_orig()
        agent.before_reset(instance_before_reset.copy_orig())
        obs, action_set, reward, done, info = env.reset(instance)
        
    # Step until termination. `action_idx` is returned by the agent but not used here.
    # NOTE(review): after this loop `obs`/`action_set` hold whatever the final `env.step`
    # returned; later cells read them -- confirm they are still valid at that point.
    while not done:
        action, action_idx = agent.action_select(action_set, obs)
        obs, action_set, reward, done, info = env.step(action)

In [None]:
%autoreload

device = 'cuda:1'

# Value network: bipartite GCN configured for 5 constraint, 1 edge and 19 variable features.
value_network = BipartiteGCN(
    device=device,
    emb_size=64,
    num_rounds=1,
    cons_nfeats=5,
    edge_nfeats=1,
    var_nfeats=19,
    aggregator='add',
)

# Replaces the default-constructed agent from the earlier cell with one using the network above.
agent = DQNAgent(device=device, value_network=value_network, name='rl_gnn')

# Learner wiring the agent to the environment and instance generator defined earlier.
learner = DQNLearner(
    agent=agent,
    env=env,
    instances=instances,
    # replay buffer / target network
    buffer_capacity=100,
    buffer_min_length=100,
    update_target_frequency=500,
    seed=0,
    # optimisation
    batch_size=32,
    agent_reward='num_nodes',
    lr=1e-4,
    gamma=0.99,
    # exploration schedule
    initial_epsilon=1,
    final_epsilon=0.05,
    final_epsilon_episode=50000,
    # episode control / logging
    threshold_difficulty=None,
    max_steps=int(1e5),
    episode_log_frequency=1,
    name='dqn_learner',
)

learner.train(2e5)

In [None]:
# Show which keys are available in the learner's `episodes_log`.
print(learner.episodes_log.keys())

In [None]:
import numpy as np

# Shape of the logged 'num_nodes' entries once converted to an array.
num_nodes_array = np.array(learner.episodes_log['num_nodes'])
print(num_nodes_array.shape)

In [None]:
import matplotlib.pyplot as plt

# Plot the logged 'num_nodes' values against their position in the log.
num_nodes_log = learner.episodes_log['num_nodes']
log_positions = range(len(num_nodes_log))
plt.plot(log_positions, num_nodes_log)
plt.show()

Let's try to implement batching.

In [None]:
import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Batch, Data

First, we need to convert the bipartite-graph observation returned by `ecole` into a format that `pytorch_geometric` can handle.

In [None]:
class BipartiteNodeData(torch_geometric.data.Data):
    """
    Encode a node bipartite graph observation, as returned by the
    `ecole.observation.NodeBipartite` observation function, in a format understood by the
    pytorch geometric data handlers.

    Args:
        constraint_features: Array-like of constraint-node features; stored as a float tensor.
        edge_indices: Numpy array of edge indices; cast to int64 and stored as `edge_index`.
            When graphs are batched, row 0 is offset by the number of constraint nodes and
            row 1 by the number of variable nodes (see `__inc__`).
        edge_features: Array-like of edge feature values; stored as `edge_attr` with an extra
            dimension inserted at position 1.
        variable_features: Array-like of variable-node features; stored as a float tensor.
        candidates: Indices of the candidate variables (the action set), given as a sequence,
            numpy array or tensor; stored as an int64 tensor.

    Besides the tensors above, the instance stores `num_candidates`, `num_variables` and
    `num_nodes` (constraints + variables).
    """
    def __init__(self, constraint_features, edge_indices, edge_features, variable_features,
                 candidates):
        super().__init__()
        self.constraint_features = torch.FloatTensor(constraint_features)
        self.edge_index = torch.LongTensor(edge_indices.astype(np.int64))
        self.edge_attr = torch.FloatTensor(edge_features).unsqueeze(1)
        self.variable_features = torch.FloatTensor(variable_features)

        # Candidates must be a tensor: PyG's collation only applies the `__inc__` offset to
        # tensor (or plain numeric) attributes, so a raw numpy array would be left unshifted
        # and would point at the wrong variable nodes once several graphs are batched.
        if isinstance(candidates, torch.Tensor):
            self.candidates = candidates.long()
        else:
            # `np.array` copies, which also guarantees a writable buffer for torch.
            self.candidates = torch.as_tensor(np.array(candidates, dtype=np.int64))

        self.num_candidates = len(candidates)
        self.num_variables = self.variable_features.size(0)
        self.num_nodes = self.constraint_features.size(0) + self.variable_features.size(0)

    def __inc__(self, key, value, *args, **kwargs):
        """
        Tell pytorch geometric how much to increment indices when concatenating graphs, for
        those entries (edge index, candidates) where this is not obvious. This enables batching.

        Extra positional/keyword arguments are accepted and forwarded to the parent
        implementation, because newer pytorch geometric releases pass additional arguments
        to `__inc__`.

        Returns:
            For 'edge_index', a (2, 1) tensor holding the number of constraint nodes and the
            number of variable nodes; for 'candidates', the number of variable nodes;
            otherwise whatever the parent class returns.
        """
        if key == 'edge_index':
            # constraint nodes connected via edge to variable nodes
            return torch.tensor([[self.constraint_features.size(0)], [self.variable_features.size(0)]])
        elif key == 'candidates':
            # actions are variable nodes
            return self.variable_features.size(0)
        else:
            return super().__inc__(key, value, *args, **kwargs)

In [None]:
# Shapes of each piece of the current observation, followed by the action set.
# NOTE(review): `obs` and `action_set` come from the rollout cell -- confirm they are not None here.
print(
    obs.row_features.shape,
    obs.edge_features.indices.shape,
    obs.edge_features.values.shape,
    obs.column_features.shape,
    action_set.shape,
)

In [None]:
# Wrap the current observation and action set in the custom PyG data object.
data = BipartiteNodeData(
    constraint_features=obs.row_features,
    edge_indices=obs.edge_features.indices,
    edge_features=obs.edge_features.values,
    variable_features=obs.column_features,
    candidates=action_set,
)
print(data)

In [None]:
# Collate a one-element list into a batch to check that the custom data class
# works with pytorch geometric's batching, then show the result and its type.
batch = Batch.from_data_list([data])
print(batch)
print(type(batch))