In [None]:
import networkx as nx
import numpy as np

from pyscipopt import Model, quicksum, SCIP_PARAMSETTING
import ecole

import torch
import torch_geometric

import matplotlib.pyplot as plt
import seaborn as sns

### Environment

In [None]:
# SCIP settings forwarded to the ecole environment through `scip_params`.
SCIP_PARAMETERS = {
    # separation limits (all zero)
    'separating/maxrounds': 0,
    'separating/maxroundsroot': 0,
    'separating/maxcuts': 0,
    'separating/maxcutsroot': 0,
    # presolving limits (all zero)
    'presolving/maxrounds': 0,
    'presolving/maxrestarts': 0,
    # propagation limits (all zero)
    'propagating/maxrounds': 0,
    'propagating/maxroundsroot': 0,
    # LP algorithm selection
    'lp/initalgorithm': 'd',
    'lp/resolvealgorithm': 'd',
    # time limit
    'limits/time': 3600,
}

# Branching environment: bipartite-graph observations, LP iterations as reward, and a
# 4-tuple of information values (cumulative LP iterations, cumulative node count,
# per-step solving time, cumulative solving time).
# The StrongBranchingScores and Pseudocosts observations are intentionally left out.
env = ecole.environment.Branching(
    observation_function=ecole.observation.NodeBipartite(),
    information_function=(
        ecole.reward.LpIterations().cumsum(),
        ecole.reward.NNodes().cumsum(),
        ecole.reward.SolvingTime(),
        ecole.reward.SolvingTime().cumsum(),
    ),
    reward_function=(ecole.reward.LpIterations(),),
    scip_params=SCIP_PARAMETERS,
)

# Random set-cover instances; the first one is used to reset the environment.
instances = ecole.instance.SetCoverGenerator(n_rows=500, n_cols=1000, density=0.05)

observation, action_set, reward_offset, done, info = env.reset(next(instances))

In [None]:
# Report the shapes of the three parts of the bipartite observation, one per line.
print(
    f"observation.column_features.: {observation.column_features.shape}",
    f"observation.row_features.: {observation.row_features.shape}",
    f"observation.edge_features: {observation.edge_features.values.shape}",
    f"\tobservation.edge_features.values: {observation.edge_features.values.shape}",
    f"\tobservation.edge_features.indices: {observation.edge_features.indices.shape}",
    sep="\n",
)

### Network

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import torch_geometric.nn as gnn
from torch_geometric.data import Batch, Data
from torch_geometric.utils import softmax

First, we need to convert the bipartite-graph observation returned by ``ecole`` into a format that ``pytorch_geometric`` can handle.

In [None]:
class BipartiteData(Data):
    """
    Encode a node bipartite graph observation, as returned by the
    `ecole.observation.NodeBipartite` observation function, in a format understood by the
    pytorch geometric data handlers.

    Args:
        bipartite_observation: observation exposing `row_features`, `column_features` and
            `edge_features` (with `indices` and `values`) as numpy arrays. May be `None`,
            in which case an empty object is created (see the note in `__init__`).
        candidates: numpy array with the indices of the variables that may be branched on.

    Attributes set on a non-empty instance:
        constraint_features, variable_features, edge_features: float tensors.
        edge_index / edge_index_c2v: (2, num_edges) long tensor, row 0 = constraint index,
            row 1 = variable index.
        edge_index_v2c: the same edges with the two rows swapped.
        candidates: candidate indices, shifted by the number of variables when batched.
        raw_candidates: candidate indices with no custom batching offset.
        num_candidates, num_variables: per-graph counts.
    """
    def __init__(self, bipartite_observation=None, candidates=None):
        super().__init__()
        # torch_geometric may instantiate the data class without real arguments while
        # building / splitting batches (newer releases pass None for required arguments),
        # so an argument-free construction must yield an empty object instead of failing.
        if bipartite_observation is None:
            return

        self.constraint_features = torch.from_numpy(bipartite_observation.row_features).float()
        self.edge_index = torch.from_numpy(bipartite_observation.edge_features.indices.astype(np.int64)).long()
        self.edge_features = torch.from_numpy(bipartite_observation.edge_features.values).float()
        self.variable_features = torch.from_numpy(bipartite_observation.column_features).float()

        # Make sure edge features are 2-D: (num_edges, num_edge_features).
        if self.edge_features.dim() == 1:
            self.edge_features.unsqueeze_(-1)
        self.edge_index_c2v = self.edge_index
        self.edge_index_v2c = self.edge_index[[1, 0]]

        self.candidates = torch.from_numpy(candidates.astype(np.int64)).long()
        # Independent copy: `candidates` gets a batching offset, `raw_candidates` does not.
        self.raw_candidates = self.candidates.clone()
        self.num_candidates = self.candidates.size(0)
        self.num_variables = self.variable_features.size(0)

        self.num_nodes = self.constraint_features.size(0) + self.variable_features.size(0)

    def __inc__(self, key, value, *args, **kwargs):
        """
        We overload the pytorch geometric method that tells how to increment indices when
        concatenating graphs for those entries (edge index, candidates) for which this is
        not obvious.

        Extra positional / keyword arguments are accepted and forwarded unchanged, because
        newer torch_geometric releases call `__inc__` with additional arguments.
        """
        if key == 'edge_index' or key == 'edge_index_c2v':
            # row 0 indexes constraints, row 1 indexes variables
            return torch.tensor([[self.constraint_features.size(0)], [self.variable_features.size(0)]])
        elif key == 'edge_index_v2c':
            # row 0 indexes variables, row 1 indexes constraints
            return torch.tensor([[self.variable_features.size(0)], [self.constraint_features.size(0)]])
        elif key == 'candidates':
            return self.variable_features.size(0)
        else:
            return super().__inc__(key, value, *args, **kwargs)

In [None]:
# Wrap the first observation and its candidate set for use with pytorch geometric.
data = BipartiteData(bipartite_observation=observation, candidates=action_set)

In [None]:
# Display the action set returned by `env.reset` (passed to BipartiteData as `candidates`).
action_set

The message passing network itself is standard: an encoder, message passing rounds, and a decoder.

This is the simplest possible implementation.  For now, we are not going to do anything with the edge features.  Nor are we considering how to mask out unavailable actions, preserve state along search trees, etc.

In [None]:
class Encoder(nn.Module):
    """
    Embed raw constraint, edge and variable features into their hidden dimensions.

    Input widths are fixed inside the constructor: 5 constraint features, 1 edge feature
    and 19 variable features (presumably the NodeBipartite feature widths -- TODO confirm).

    Args:
        constraint_dim: output width of the constraint embedding.
        edge_dim: output width of the edge embedding.
        variable_dim: output width of the variable embedding.
    """

    def __init__(self,
                 constraint_dim=32,
                 edge_dim=32,
                 variable_dim=32):
        super().__init__()

        constraint_dim_in, edge_dim_in, variable_dim_in = 5, 1, 19

        def two_layer_mlp(in_features, out_features):
            # Linear -> ReLU -> Linear -> ReLU, no input normalization.
            return nn.Sequential(
                nn.Linear(in_features, out_features),
                nn.ReLU(),
                nn.Linear(out_features, out_features),
                nn.ReLU(),
            )

        # Sub-modules are created in the same order as before so that parameter names and
        # random initialization under a fixed seed stay the same.
        self.constraint_embedding = two_layer_mlp(constraint_dim_in, constraint_dim)

        # NOTE(review): with edge_dim_in == 1 the LayerNorm normalizes over a single
        # element, so its output does not depend on the edge value -- confirm this is
        # acceptable (the edge embedding is not consumed by the message passing layer).
        self.edge_embedding = nn.Sequential(
            nn.LayerNorm(edge_dim_in),
            nn.Linear(edge_dim_in, edge_dim),
            nn.ReLU(),
        )

        self.variable_embedding = two_layer_mlp(variable_dim_in, variable_dim)

    def forward(self, constraint_features, edge_features, variable_features):
        """Return the (constraint, edge, variable) embeddings as a 3-tuple."""
        embedded_constraints = self.constraint_embedding(constraint_features)
        embedded_edges = self.edge_embedding(edge_features)
        embedded_variables = self.variable_embedding(variable_features)
        return embedded_constraints, embedded_edges, embedded_variables
    
class MessagePassingLayer(nn.Module):
    """
    One round of bipartite message passing built from two GATConv layers:
    variables -> constraints first, then constraints -> variables.

    Edge features are passed through untouched; `edge_dim` is accepted but not used.
    """

    def __init__(self,
                 constraint_dim=32,
                 edge_dim=32,
                 variable_dim=32):
        super().__init__()

        self.conv_v2c = gnn.GATConv(variable_dim, constraint_dim)
        self.conv_c2v = gnn.GATConv(constraint_dim, variable_dim)

    def forward(self,
                constraint_features,
                edge_features,
                variable_features,
                edge_index_v2c,
                edge_index_c2v):
        """
        Update constraint embeddings from the variables, then update variable embeddings
        from the freshly updated constraints.

        Returns:
            (constraint_features, edge_features, variable_features) with the edge features
            returned unchanged.
        """
        n_variables = variable_features.size(0)
        n_constraints = constraint_features.size(0)

        # Half-convolution 1: variables are sources, constraints are targets.
        updated_constraints = self.conv_v2c(
            (variable_features, constraint_features),
            edge_index_v2c,
            size=(n_variables, n_constraints),
        )

        # Half-convolution 2: the updated constraints are sources, variables are targets.
        updated_variables = self.conv_c2v(
            (updated_constraints, variable_features),
            edge_index_c2v,
            size=(updated_constraints.size(0), n_variables),
        )

        return updated_constraints, edge_features, updated_variables
    
class Readout(nn.Module):
    """Map each variable embedding to a single scalar with a Linear -> ReLU -> Linear head."""

    def __init__(self,
                 variable_dim=32):
        super().__init__()

        head_layers = [
            nn.Linear(variable_dim, variable_dim),
            nn.ReLU(),
            nn.Linear(variable_dim, 1),
        ]
        self.readout = nn.Sequential(*head_layers)

    def forward(self, variable_features):
        """Return one value per row of `variable_features` (last dimension of size 1)."""
        scores = self.readout(variable_features)
        return scores

In [None]:
# Shape smoke test: run one observation through encoder -> message passing -> readout.
enc, mpnn, readout = Encoder(), MessagePassingLayer(), Readout()

constraint_features, edge_features, variable_features = enc(
    constraint_features=data.constraint_features,
    edge_features=data.edge_features,
    variable_features=data.variable_features,
)

print(*(feats.shape for feats in (constraint_features, edge_features, variable_features)))

constraint_features, edge_features, variable_features = mpnn(
    constraint_features=constraint_features,
    edge_features=edge_features,
    variable_features=variable_features,
    edge_index_v2c=data.edge_index_v2c,
    edge_index_c2v=data.edge_index_c2v,
)

print(*(feats.shape for feats in (constraint_features, edge_features, variable_features)))

out = readout(variable_features)

print(out.shape)

In [None]:
class QNetwork(nn.Module):
    """
    Q-network over the variable nodes of a bipartite graph: encoder, `num_rounds` rounds of
    message passing, then a per-variable scalar readout.

    The same `MessagePassingLayer` instance is applied in every round, so its weights are
    shared across rounds.

    Args:
        constraint_dim: hidden width of constraint embeddings.
        edge_dim: hidden width of edge embeddings.
        variable_dim: hidden width of variable embeddings.
        num_rounds: number of message passing rounds applied in `forward`.
    """

    def __init__(self,
                 constraint_dim=32,
                 edge_dim=32,
                 variable_dim=32,
                 num_rounds=1):
        super().__init__()

        self.encoder = Encoder(constraint_dim, edge_dim, variable_dim)
        self.mpnn = MessagePassingLayer(constraint_dim, edge_dim, variable_dim)
        self.readout = Readout(variable_dim)
        # Honor the constructor argument (it used to be overwritten with a hard-coded 1).
        self.num_rounds = num_rounds

    def forward(self,
                constraint_features,
                edge_features,
                variable_features,
                edge_index_v2c,
                edge_index_c2v):
        """
        Compute one Q-value per variable node.

        Returns:
            The readout applied to the variable embeddings after the final round.
        """
        constraint_features, edge_features, variable_features = self.encoder(constraint_features,
                                                                             edge_features,
                                                                             variable_features)
        for _ in range(self.num_rounds):
            constraint_features, edge_features, variable_features = self.mpnn(constraint_features,
                                                                              edge_features,
                                                                              variable_features,
                                                                              edge_index_v2c,
                                                                              edge_index_c2v)

        return self.readout(variable_features)

In [None]:
# Q-network with 32-wide constraint/variable embeddings and a 1-wide edge embedding.
qnet = QNetwork(constraint_dim=32, edge_dim=1, variable_dim=32, num_rounds=2)

In [None]:
# Collate the single observation into a batch and score every variable node.
batch = Batch.from_data_list([data])
preds = qnet(
    constraint_features=batch.constraint_features,
    edge_features=batch.edge_features,
    variable_features=batch.variable_features,
    edge_index_v2c=batch.edge_index_v2c,
    edge_index_c2v=batch.edge_index_c2v,
)

In [None]:
# Running total of the per-graph candidate counts in the batch.
torch.cumsum(batch.num_candidates, dim=0)

In [None]:
# Position of the best-scoring candidate within each graph's own candidate list.
torch.stack(list(map(
    torch.argmax,
    preds[batch.candidates].split_with_sizes(tuple(batch.num_candidates)),
)))

### Learner

In [None]:
from collections import namedtuple, deque

# One replay-buffer entry:
#   state     -- BipartiteData
#   action    -- Int
#   reward    -- Float
#   done      -- Bool
#   new_state -- BipartiteData
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'reward', 'done', 'new_state'],
)

class ReplayBuffer:
    """
    Replay Buffer for storing past experiences allowing the agent to learn from them.

    Once `capacity` transitions are stored, appending a new one discards the oldest.

    Args:
        capacity: size of the buffer
    """

    def __init__(self, capacity: int) -> None:
        self.buffer = deque(maxlen=capacity)

    def __len__(self) -> int:
        """Number of transitions currently stored."""
        return len(self.buffer)

    def append(self, transition) -> None:
        """
        Add transition to the buffer
        Args:
            transition: tuple (state, action, reward, done, new_state)
        """
        self.buffer.append(transition)

    def sample(self, batch_size):
        """
        Draw `batch_size` distinct transitions uniformly at random and collate them.

        Args:
            batch_size: number of transitions to draw; must be between 1 and `len(self)`.

        Returns:
            Tuple `(states, actions, rewards, dones, next_states)`: states and next states
            are `Batch` objects, actions is an integer tensor, rewards and dones are
            float tensors.

        Raises:
            ValueError: if `batch_size` is smaller than 1 or larger than the buffer.
        """
        if batch_size < 1 or batch_size > len(self.buffer):
            raise ValueError(
                f"Cannot sample {batch_size} transitions from a buffer holding {len(self.buffer)}"
            )

        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, dones, next_states = zip(*(self.buffer[idx] for idx in indices))

        return (Batch.from_data_list(list(states)),
                torch.tensor(actions),
                # Pin the dtype so integer or float64 rewards cannot clash with float32 Q-values.
                torch.tensor(rewards, dtype=torch.float32),
                torch.tensor(dones).float(),
                Batch.from_data_list(list(next_states)))

In [None]:
from copy import deepcopy
from collections import defaultdict
from collections.abc import Iterable

class Actor:
    """
    Holds a policy Q-network and a frozen target copy, and turns Q-values into actions.

    Args:
        q_network: module called as
            `q_network(constraint_features, edge_features, variable_features,
            edge_index_v2c, edge_index_c2v)` and returning one value per variable node.
    """

    def __init__(self,
                 q_network):
        self.policy_network = q_network
        self.target_network = deepcopy(q_network)
        self.update_target_network()

    def update_target_network(self):
        """Copy the policy weights into the target network, freeze it, and return it."""
        self.target_network.load_state_dict(self.policy_network.state_dict())
        for param in self.target_network.parameters():
            param.requires_grad = False
        return self.target_network

    def calc_Q_values(self, observation, use_target_network=False):
        """
        Evaluate the policy network (default) or the target network on `observation`.

        Args:
            observation: object exposing `constraint_features`, `edge_features`,
                `variable_features`, `edge_index_v2c` and `edge_index_c2v`.
            use_target_network: evaluate the frozen target network instead of the policy.
        """
        network = self.target_network if use_target_network else self.policy_network
        return network(observation.constraint_features,
                       observation.edge_features,
                       observation.variable_features,
                       observation.edge_index_v2c,
                       observation.edge_index_c2v)

    def action_select(self, observation_batch, epsilon=0):
        """
        Pick one action per graph: the candidate with the highest policy Q-value, replaced
        with probability `epsilon` by a uniformly random candidate.

        Args:
            observation_batch: a single observation or a batch of them. `candidates` indexes
                into the Q-value tensor, `raw_candidates` holds the values returned as
                actions, `num_candidates` is an int (single graph) or an iterable of
                per-graph counts (batch).
            epsilon: exploration probability applied independently to every graph.

        Returns:
            1-D tensor with one entry of `raw_candidates` per graph.
        """
        preds = self.calc_Q_values(observation_batch)
        valid_preds = preds[observation_batch.candidates]
        num_candidates = observation_batch.num_candidates
        if isinstance(num_candidates, Iterable):
            # Plain ints keep the split independent of how the counts are stored.
            sizes = [int(n) for n in num_candidates]
            valid_preds = valid_preds.split_with_sizes(sizes)
            action_idxs = observation_batch.raw_candidates.split_with_sizes(sizes)
        else:
            valid_preds = [valid_preds]
            action_idxs = [observation_batch.raw_candidates]
        actions = torch.stack([idxs[q.argmax()] for q, idxs in zip(valid_preds, action_idxs)])

        if epsilon > 0:
            act_randomly = (np.random.rand(len(actions)) < epsilon)
            rand_actions = [np.random.choice(acts) for acts, rand in zip(action_idxs, act_randomly) if rand]
            if rand_actions:
                actions[act_randomly] = torch.LongTensor(rand_actions)

        return actions

    def parameters(self):
        """Trainable parameters, i.e. those of the policy network."""
        return self.policy_network.parameters()
            
                    
class DQNLearner:
    """
    Double-DQN style training loop for a branching policy.

    The learner alternates between acting in the environment (filling a replay buffer)
    and taking one optimizer step on a sampled mini-batch.

    Args:
        actor: object providing `action_select`, `calc_Q_values`, `parameters` and
            `update_target_network`.
        env: environment providing `reset(instance)` and `step(action)`, both returning
            `(observation, action_set, reward, done, info)`.
        instances: iterator yielding problem instances for `env.reset`.
        buffer_capacity: maximum number of stored transitions.
        buffer_min_length: transitions collected before the first update.
        batch_size: mini-batch size sampled for each update.
        steps_per_update: environment steps taken before each update.
        lr: Adam learning rate.
        gamma: discount factor.
        update_target_frequency: number of updates between target-network syncs.
        initial_epsilon: exploration rate before any update.
        final_epsilon: exploration rate reached after `final_epsilon_epoch` updates.
        final_epsilon_epoch: number of updates over which epsilon is annealed linearly.
        log_frequency: number of updates between printed log lines.
    """

    def __init__(self,

                 actor,

                 env,

                 instances,

                 buffer_capacity=1000,
                 buffer_min_length=100,
                 batch_size=32,

                 steps_per_update = 100,
                 lr = 3e-4,
                 gamma = 0.99,
                 update_target_frequency = 100,

                # Exploration
                initial_epsilon=1,
                final_epsilon=0.05,
                final_epsilon_epoch=1000,

                log_frequency=10
                ):

        self.actor = actor
        self.env = env
        self.instances = instances

        self.buffer = ReplayBuffer(buffer_capacity)
        self.buffer_capacity = buffer_capacity
        self.buffer_min_length=buffer_min_length
        self.batch_size = batch_size

        self.steps_per_update = steps_per_update
        self.lr = lr
        self.gamma = gamma
        self.update_target_frequency = update_target_frequency

        self.optimizer = self.reset_optimizer()

        self.env_ready = False
        # Kept for backward compatibility; the exploration rate actually used comes
        # from get_epsilon().
        self.epsilon = 0.1
        self.num_steps = 0
        self.num_episodes = 0
        self.num_epochs = 0

        self.initial_epsilon = initial_epsilon
        self.final_epsilon = final_epsilon
        self.final_epsilon_epoch = final_epsilon_epoch

        self.state = None
        self.prev_state = None

        self.log = defaultdict(list)
        self.log_frequency = log_frequency

    def reset_optimizer(self):
        """Create a fresh Adam optimizer over the actor's parameters and return it."""
        self.optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.lr)
        return self.optimizer

    def reset_env(self, max_attempts=50):
        """
        Reset the environment on new instances until one is not already finished.

        Returns:
            `(state, reward, done, info)` for the first instance whose reset is not done.

        Raises:
            RuntimeError: if every one of the `max_attempts` instances was already done
                at reset, so there is no observation to build a state from.
        """
        self.env_ready = False
        for _ in range(max_attempts):
            observation, action_set, reward, done, info = self.env.reset(next(self.instances))
            if not done:
                self.env_ready = True
                return BipartiteData(observation, action_set), reward, done, info
        raise RuntimeError(
            f"No instance required branching after {max_attempts} reset attempts"
        )

    def get_epsilon(self):
        """
        Linearly anneal epsilon from `initial_epsilon` to `final_epsilon` over
        `final_epsilon_epoch` updates, then hold it at `final_epsilon`.
        """
        if self.final_epsilon_epoch <= 0:
            return self.final_epsilon
        # min (not max): progress must saturate at 1, otherwise epsilon starts at its
        # final value and keeps decreasing below it.
        progress = min(1, self.num_epochs / self.final_epsilon_epoch)
        return self.initial_epsilon - (self.initial_epsilon - self.final_epsilon) * progress

    def step_env(self, state):
        """
        Select an epsilon-greedy action for `state` and apply it to the environment.

        Returns:
            `(next_state, action, reward, done, info)`; `next_state` is None when the
            episode finished.
        """
        # Use the learner's own actor / environment rather than notebook globals.
        action = self.actor.action_select(state, self.get_epsilon()).item()
        observation, action_set, reward, done, info = self.env.step(action)
        if not done:
            state = BipartiteData(observation, action_set)
        else:
            state = None
        return state, action, reward, done, info

    def update_log(self, info):
        """Record entries 0, 1 and 3 of the environment's info tuple for a finished episode."""
        self.log['lp_iter'].append(info[0])
        self.log['num_nodes'].append(info[1])
        self.log['time'].append(info[3])

    def get_log_str(self):
        """Format the most recent episode statistics; shows 'n/a' before any episode ended."""
        def latest(key, spec):
            values = self.log.get(key)
            return format(values[-1], spec) if values else f"{'n/a':>5}"

        log_str = f"Epoch {self.num_epochs:5}"
        log_str += f" | lp_iter : {latest('lp_iter', '5.1f')}"
        log_str += f" | num_nodes : {latest('num_nodes', '5.1f')}"
        log_str += f" | sovle time : {latest('time', '5.3f')}"
        return log_str

    @torch.no_grad()
    def act(self, num_steps):
        """
        Take `num_steps` environment steps, storing one transition per step.

        Every step is recorded, including the first step of each episode. For terminal
        transitions the stored next state is a placeholder (the previous state); its
        value is masked out by `(1 - done)` in the TD target.
        """
        for _ in range(num_steps):
            if not self.env_ready:
                self.state, _, _, _ = self.reset_env()

            self.prev_state = self.state
            next_state, action, reward, done, info = self.step_env(self.prev_state)
            if done:
                next_state = self.prev_state  # placeholder so the batch can be collated
            self.buffer.append(Transition(self.prev_state, action, reward, done, next_state))
            self.state = next_state
            self.num_steps += 1

            if done:
                self.update_log(info)
                self.env_ready = False
                self.num_episodes += 1

    def update_step(self):
        """
        Sample a mini-batch and take one optimizer step on the squared TD error.

        The next action is chosen greedily by the policy network and evaluated with the
        target network. Also syncs the target network and prints a log line at their
        respective frequencies.

        Returns:
            The loss value of this update as a Python float.
        """
        def action_to_batch_idxs(action, state):
            # Shift each graph's action by the number of variables in the graphs before it.
            return torch.cat([action[[0]], action[1:] + state.num_variables[:-1].cumsum(0)])

        # Take DQN update step
        state, action, reward, done, next_state = self.buffer.sample(self.batch_size)
        action_batch_idxs = action_to_batch_idxs(action, state)
        # squeeze(-1) keeps a 1-D result even for a batch of one.
        q_value = self.actor.calc_Q_values(state)[action_batch_idxs].squeeze(-1)

        with torch.no_grad():
            action_next = self.actor.action_select(next_state, epsilon=0)
            action_next_batch_idxs = action_to_batch_idxs(action_next, next_state)
            next_q = self.actor.calc_Q_values(next_state, use_target_network=True)
            next_q = next_q[action_next_batch_idxs].squeeze(-1)
            td_target = reward.float() + (1 - done) * self.gamma * next_q

        self.optimizer.zero_grad()
        loss = F.mse_loss(q_value, td_target)
        loss.backward()
        self.optimizer.step()

        # Check to see if we should update target network
        self.num_epochs += 1
        if self.num_epochs % self.update_target_frequency == 0:
            self.actor.update_target_network()

        if self.num_epochs % self.log_frequency == 0:
            print(self.get_log_str())

        return loss.item()

    def train(self, num_epoch):
        """
        Run `num_epoch` rounds of (`steps_per_update` environment steps, one update),
        first filling the replay buffer up to `buffer_min_length` if needed.
        """
        if len(self.buffer) < self.buffer_min_length:
            # Fill replay buffer to minimum level.
            print("Waiting for replay buffer to fill", end="...")
            self.act(self.buffer_min_length - len(self.buffer))
            print("done.")
        for _ in range(num_epoch):
            self.act(self.steps_per_update)
            self.update_step()

In [None]:
MAX_RESET_ATTEMPTS = 10

# Training environment. DQNLearner.update_log reads entries 0, 1 and 3 of the information
# tuple, so the order below matters.
# The reward is the negated LP iteration count; -NNodes() and -IsDone() are alternatives
# that are currently not used. StrongBranchingScores / Pseudocosts are likewise left out.
env = ecole.environment.Branching(
    observation_function=ecole.observation.NodeBipartite(),
    information_function=(
        ecole.reward.LpIterations().cumsum(),
        ecole.reward.NNodes().cumsum(),
        ecole.reward.SolvingTime(),
        ecole.reward.SolvingTime().cumsum(),
    ),
    reward_function=-ecole.reward.LpIterations(),
    scip_params=SCIP_PARAMETERS,
)

# Smaller set-cover instances than in the exploration cells above.
instances = ecole.instance.SetCoverGenerator(n_rows=100, n_cols=100, density=0.05)

actor = Actor(
    QNetwork(constraint_dim=64, edge_dim=1, variable_dim=64, num_rounds=1)
)

In [None]:
# Learner configured with a deliberately tiny replay buffer and batch.
learner = DQNLearner(
    actor,
    env,
    instances,
    # Replay buffer
    buffer_capacity=2,
    buffer_min_length=2,
    batch_size=2,
    # Optimization
    steps_per_update=2,
    lr=1e-3,
    gamma=0.99,
    update_target_frequency=50,
    # Exploration
    initial_epsilon=1,
    final_epsilon=0.025,
    final_epsilon_epoch=250,
    # Logging
    log_frequency=5,
)

In [None]:
# Run 10 rounds of training; each round takes `steps_per_update` environment steps
# followed by one optimizer update.
learner.train(10)

In [None]:
# Number of logged entries; `update_log` appends one entry each time an episode finishes.
len(learner.log['num_nodes'])

In [None]:
# Moving average of the node count per finished episode.
window = 100
num_nodes_log = np.asarray(learner.log['num_nodes'], dtype=float)
# With mode='valid', a window longer than the series does not give a moving average
# (np.convolve swaps its arguments), and an empty series raises. Clamp the window to
# the number of logged episodes and skip smoothing when nothing was logged.
window = max(1, min(window, num_nodes_log.size))
data = (np.convolve(num_nodes_log, np.ones(window) / window, mode='valid')
        if num_nodes_log.size else num_nodes_log)

with sns.plotting_context('paper'):
    plt.plot(data, linewidth=0.5)
    plt.xlabel("Episodes")
    plt.ylabel("num_nodes")
sns.despine()

In [None]:
# Moving average of the LP iteration count per finished episode.
window = 100
lp_iter_log = np.asarray(learner.log['lp_iter'], dtype=float)
# With mode='valid', a window longer than the series does not give a moving average
# (np.convolve swaps its arguments), and an empty series raises. Clamp the window to
# the number of logged episodes and skip smoothing when nothing was logged.
window = max(1, min(window, lp_iter_log.size))
data = (np.convolve(lp_iter_log, np.ones(window) / window, mode='valid')
        if lp_iter_log.size else lp_iter_log)

with sns.plotting_context('paper'):
    plt.plot(data, linewidth=0.5)
    plt.xlabel("Episodes")
    plt.ylabel("lp_iter")
sns.despine()

In [None]:
# Moving average of the solve time per finished episode.
window = 100
time_log = np.asarray(learner.log['time'], dtype=float)
# With mode='valid', a window longer than the series does not give a moving average
# (np.convolve swaps its arguments), and an empty series raises. Clamp the window to
# the number of logged episodes and skip smoothing when nothing was logged.
window = max(1, min(window, time_log.size))
data = (np.convolve(time_log, np.ones(window) / window, mode='valid')
        if time_log.size else time_log)

with sns.plotting_context('paper'):
    plt.plot(data, linewidth=0.5)
    plt.xlabel("Episodes")
    plt.ylabel("Solve time")
sns.despine()