A time profile showed that `optimizer.step()` was taking 25% of the total time, most of which was spent in `Batch.from_data_list()`. This notebook is used to investigate why.

In [1]:
%load_ext autoreload
# %load_ext snakeviz

%autoreload
from retro_branching.environments import EcoleBranching, EcoleConfiguring
from retro_branching.agents import StrongBranchingAgent, PseudocostBranchingAgent
from retro_branching.utils import seed_stochastic_modules_globally

import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Batch, Data

import ecole

from os import path
import numpy as np
import copy
import time
from collections import defaultdict, deque, namedtuple
import itertools
from tqdm import tqdm
import random
import pickle
import gzip

import networkx as nx
import matplotlib.pyplot as plt
from networkx.drawing.nx_pydot import graphviz_layout

# Global seed: passed to the project helper that seeds the stochastic modules,
# and reused below for env.seed(...).
seed = 0
seed_stochastic_modules_globally(default_seed=seed)

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


## Looking at `buffer.sample()` time

Minimal implementation of a replay buffer.

In [None]:
# torch geometric data object for states
class BipartiteNodeData(torch_geometric.data.Data):
    """
    This class encodes a node bipartite graph observation as returned by the `ecole.observation.NodeBipartite`
    observation function in a format understood by the pytorch geometric data handlers.

    Args:
        obs: Observation whose `row_features`, `edge_features.indices`, `edge_features.values`
            and `column_features` arrays are converted to tensors. If None, an empty data
            object is created; newer PyTorch Geometric versions instantiate data classes
            without (or with None) arguments while batching/un-batching.
        candidates: Array of candidate variable indices (the action set). Ignored when
            `obs` is None.
    """
    def __init__(self, obs=None, candidates=None):
        super().__init__()
        if obs is None:
            # empty container, e.g. created internally by PyG during (un)batching
            return

        # NOTE(review): keeping the raw observation means copy.deepcopy / batching also
        # has to handle this non-tensor object - a possible source of overhead.
        self.obs = obs
        self.constraint_features = torch.FloatTensor(obs.row_features)
        self.edge_index = torch.LongTensor(obs.edge_features.indices.astype(np.int64))
        self.edge_attr = torch.FloatTensor(obs.edge_features.values).unsqueeze(1)
        self.variable_features = torch.FloatTensor(obs.column_features)

        # from_numpy on an int64 array already yields a long tensor.
        candidates_tensor = torch.from_numpy(candidates.astype(np.int64))
        # `candidates` is offset when batching (see __inc__); `raw_candidates` falls
        # through to the default increment. Keep them as separate tensors so they never alias.
        self.candidates = candidates_tensor
        self.raw_candidates = candidates_tensor.clone()

        self.num_candidates = len(candidates)
        self.num_variables = self.variable_features.size(0)
        self.num_nodes = self.constraint_features.size(0) + self.variable_features.size(0)

    def __inc__(self, key, value, *args, **kwargs):
        """
        We overload the pytorch geometric method that tells how to increment indices when concatenating graphs
        for those entries (edge index, candidates) for which this is not obvious. This
        enables batching.

        Extra positional/keyword arguments are accepted and forwarded to the parent
        implementation because newer PyTorch Geometric versions call
        `__inc__(key, value, store)`.
        """
        if key == 'edge_index':
            # row 0 indexes constraint nodes and row 1 indexes variable nodes, so each
            # row is offset by its own node count
            return torch.tensor([[self.constraint_features.size(0)], [self.variable_features.size(0)]])
        if key == 'candidates':
            # actions are variable nodes
            return self.variable_features.size(0)
        return super().__inc__(key, value, *args, **kwargs)

# One environment step as stored in the replay buffer.
Transition = namedtuple('Transition', 'state action reward done next_state')

class ReplayBuffer:
    '''
    Fixed-capacity ring buffer of transitions with uniform random sampling.
    '''
    def __init__(self,
                 capacity):
        '''
        Args:
            capacity (int): Maximum capacity of replay buffer.
        '''
        # init experience replay buffer with pre-allocated empty slots
        self.capacity = capacity
        self.buffer = [None] * capacity
        self.curr_write_idx, self.available_samples = 0, 0

    def __len__(self):
        '''Number of transitions currently stored (at most `capacity`).'''
        return self.available_samples

    def append(self, transition):
        '''
        Store `transition`, overwriting the oldest stored transition once the buffer is full.
        '''
        self.buffer[self.curr_write_idx] = transition

        # advance write idx, wrapping to the start to overwrite old experiences
        self.curr_write_idx += 1
        if self.curr_write_idx >= self.capacity:
            self.curr_write_idx = 0

        # max out the available samples at the memory buffer size
        self.available_samples = min(self.available_samples + 1, self.capacity)

    def sample(self, batch_size, per_beta=None):
        '''
        Uniformly sample `batch_size` distinct transitions.

        Args:
            batch_size (int): Number of transitions to draw without replacement. Must not
                exceed len(self); np.random.choice raises ValueError otherwise.
            per_beta: Accepted for interface compatibility; not used.

        Returns:
            tuple: (state batch, action tensor, reward tensor, done tensor as float,
            next_state batch), where the states are collated with Batch.from_data_list.
        '''
        # Slots are filled from index 0 and only overwritten after wrapping, so the valid
        # entries are always buffer[:available_samples]; sample from that range directly
        # instead of slicing (copying) the list just to take its length.
        indices = np.random.choice(self.available_samples, batch_size, replace=False)

        # collect the sampled transitions
        state, action, reward, done, next_state = zip(*[self.buffer[idx] for idx in indices])

        # States are deep-copied so the batches never share objects with stored transitions.
        # NOTE(review): this deepcopy is likely a large part of sample() time - consider
        # dropping it if nothing mutates the batched states in place.
        # torch.tensor already copies its input, so the scalar tuples need no deepcopy.
        return (Batch.from_data_list(copy.deepcopy(state)),
                torch.tensor(action),
                torch.tensor(reward),
                torch.tensor(done).float(),
                Batch.from_data_list(copy.deepcopy(next_state)))


In [None]:
# init agent, env, and instances
# Agent whose branching decisions (agent.action_select) generate the transitions
# stored in the replay buffer in the next cell.
agent = StrongBranchingAgent()

env = EcoleBranching(observation_function='43_var_features',
                      information_function='default',
                      reward_function='default',
                      scip_params='default')
# Seed the environment with the global seed (presumably for reproducible episodes).
env.seed(seed)

# Generator of random set-cover instances (n_rows=100, n_cols=100, density=0.05);
# consumed with next(instances) each time the environment is reset.
instances = ecole.instance.SetCoverGenerator(n_rows=100, n_cols=100, density=0.05)

In [None]:
# fill buffer with experiences
buffer = ReplayBuffer(capacity=128)

env_ready = False
total_env_step_time = 0
with tqdm(total=buffer.capacity, desc='Filling buffer') as pbar:
    while len(buffer) < buffer.capacity:
        while not env_ready:
            # find instance not pre-solved by the environment (reset returns obs=None otherwise)
            env.seed(seed)
            instance = next(instances)
            agent.before_reset(instance)
            obs, action_set, reward, done, info = env.reset(instance)
            if obs is not None:
                env_ready = True
                state = BipartiteNodeData(obs, action_set).to('cpu')

        # store prev transition params
        prev_obs, prev_action_set, prev_state = copy.deepcopy(obs), copy.deepcopy(action_set), copy.deepcopy(state)

        # get branching action
        action, action_idx = agent.action_select(action_set, env.model, done)

        # take step in environment, timing only the env.step call
        # (perf_counter is the clock intended for measuring intervals)
        start = time.perf_counter()
        obs, action_set, reward, done, info = env.step(action)
        total_env_step_time += time.perf_counter() - start

        if done:
            # solved instance: reuse the previous observation so next_state is not built
            # from None values, and force an env reset on the next iteration
            obs = copy.deepcopy(prev_obs)
            action_set = copy.deepcopy(prev_action_set)
            env_ready = False

        # add transition to buffer
        state = BipartiteNodeData(obs, action_set).to('cpu')
        buffer.append(Transition(prev_state, action.item(), reward['normalised_lp_gain'], done, state))
        pbar.update(1)

print(f'Total env step time when collecting {len(buffer)} samples: {total_env_step_time:.3f} s.')

In [None]:
# sample batch of experiences and time it
batch_size = 128
# To profile with snakeviz instead, enable `%load_ext snakeviz` in the first cell and run:
# %snakeviz state, action, reward, done, next_state = buffer.sample(batch_size)
start = time.perf_counter()
state, action, reward, done, next_state = buffer.sample(batch_size)
print(f'Time to sample {batch_size} samples from buffer: {time.perf_counter()-start:.3f} s.')

## Looking at `learner.step_optimizer()` time

The snakeviz time profiler suggested that, for large 500x1000 instances, the bottleneck lies in some list-comprehension statements.

In [16]:
num_heads = 1
num_logits = 128000
num_agents = 2

# The original experiment ran on the second GPU ('cuda:1'); fall back to the first
# GPU, then the CPU, so the cell also runs on machines with fewer devices.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# logits[agent_idx][head_idx] is a 1-D tensor of num_logits random values
logits = [[torch.randn(num_logits, device=device) for _ in range(num_heads)] for _ in range(num_agents)]
print(len(logits))
print(len(logits[0]))
print(logits[0][0].shape)
print(logits[0][0].get_device())

2
1
torch.Size([128000])
1


In [3]:
def find_min_head_logits_across_agents(logits):
    '''
    Given a list of lists, where logits[agent_idx][head_idx] is a tensor of agent head outputs,
    finds the element-wise minimum across agents for each head and returns
    the minimum logits for each head as a list of tensors.

    Args:
        logits (list[list[torch.Tensor]]): Per-agent, per-head output tensors. The number
            of heads is taken from the first agent; for each head, all agents' tensors
            must have the same shape (required by torch.stack).

    Returns:
        list[torch.Tensor]: One tensor per head, with the same shape as each agent's
        tensor for that head.

    Raises:
        IndexError: If `logits` is empty, or an agent has fewer heads than the first.
    '''
    num_heads = len(logits[0])
    # stack the agents along a new last dim, then reduce over it
    return [
        torch.stack([agent_logits[head] for agent_logits in logits], dim=-1).min(dim=-1).values
        for head in range(num_heads)
    ]

In [4]:
# Time the head-wise minimum over agents, 10 loops per run (-n10).
# NOTE(review): the input tensors may live on a GPU, where PyTorch launches kernels
# asynchronously; without torch.cuda.synchronize() this may mostly measure launch
# overhead rather than kernel execution time - confirm.
%timeit -n10 find_min_head_logits_across_agents(logits)

The slowest run took 4.18 times longer than the fastest. This could mean that an intermediate result is being cached.
36.4 µs ± 27.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [17]:
# Profile one call; CUDA timings are only collected/sorted on when a GPU is present,
# so the cell also runs on CPU-only machines.
profile_cuda = torch.cuda.is_available()
with torch.autograd.profiler.profile(use_cuda=profile_cuda) as prof:
    find_min_head_logits_across_agents(logits)
print(prof.table(sort_by='cuda_time_total' if profile_cuda else 'cpu_time_total'))

--------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
--------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
         aten::stack        17.95%      34.998us        70.82%     138.042us     138.042us       6.144us        37.50%      13.312us      13.312us             1  
           aten::cat         7.02%      13.680us        45.67%      89.015us      89.015us       4.096us        25.00%       7.168us       7.168us             1  
          aten::_cat        25.59%      49.879us        38.65%      75.335us      75.335us       3.072us        18.75%       3.072us       3.072us             1  
           aten::min 