Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added subgraph_mode #93

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 8 additions & 5 deletions clrs/_src/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
dropout_prob: float = 0.0,
hint_teacher_forcing_noise: float = 0.0,
name: str = 'base_model',
subgraph_mode: str = None,
):
"""Constructor for BaselineModel.

Expand Down Expand Up @@ -116,7 +117,7 @@ def __init__(
self.checkpoint_path = checkpoint_path
self.name = name
self._freeze_processor = freeze_processor
self.opt = optax.adam(learning_rate)
self.opt = optax.chain(optax.clip_by_global_norm(1), optax.adam(learning_rate))

self.nb_dims = []
if isinstance(dummy_trajectory, _Feedback):
Expand All @@ -133,18 +134,20 @@ def __init__(
self.nb_dims.append(nb_dims)

self._create_net_fns(hidden_dim, encode_hints, processor_factory, use_lstm,
dropout_prob, hint_teacher_forcing_noise)
dropout_prob, hint_teacher_forcing_noise, subgraph_mode)
self.params = None
self.opt_state = None
self.opt_state_skeleton = None

def _create_net_fns(self, hidden_dim, encode_hints, processor_factory,
use_lstm, dropout_prob, hint_teacher_forcing_noise):
use_lstm, dropout_prob, hint_teacher_forcing_noise, subgraph_mode):
def _use_net(*args, **kwargs):
return nets.Net(self._spec, hidden_dim, encode_hints,
self.decode_hints, self.decode_diffs,
processor_factory, use_lstm, dropout_prob,
hint_teacher_forcing_noise, self.nb_dims)(*args, **kwargs)
hint_teacher_forcing_noise,
subgraph_mode,
self.nb_dims)(*args, **kwargs)

self.net_fn = hk.transform(_use_net)
self.net_fn_apply = jax.jit(self.net_fn.apply,
Expand Down Expand Up @@ -328,7 +331,7 @@ class BaselineModelChunked(BaselineModel):
"""

def _create_net_fns(self, hidden_dim, encode_hints, processor_factory,
use_lstm, dropout_prob, hint_teacher_forcing_noise):
use_lstm, dropout_prob, hint_teacher_forcing_noise, subgraph_mode):
def _use_net(*args, **kwargs):
return nets.NetChunked(
self._spec, hidden_dim, encode_hints,
Expand Down
35 changes: 29 additions & 6 deletions clrs/_src/nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from clrs._src import processors
from clrs._src import samplers
from clrs._src import specs
from clrs._src import subgraphs_utils

import haiku as hk
import jax
Expand Down Expand Up @@ -86,9 +87,11 @@ def __init__(
use_lstm: bool,
dropout_prob: float,
hint_teacher_forcing_noise: float,
subgraph_mode: str,
nb_dims=None,
name: str = 'net',
):

"""Constructs a `Net`."""
super().__init__(name=name)

Expand All @@ -102,6 +105,7 @@ def __init__(
self.processor_factory = processor_factory
self.nb_dims = nb_dims
self.use_lstm = use_lstm
self.subgraph_mode = subgraph_mode

def _msg_passing_step(self,
mp_state: _MessagePassingScanState,
Expand Down Expand Up @@ -147,29 +151,38 @@ def _msg_passing_step(self,
probing.DataPoint(
name=hint.name, location=loc, type_=typ, data=hint_data))

gt_diffs = None
if hints[0].data.shape[0] > 1 and self.decode_diffs:
def get_gt_diffs(hints, first_idx, second_idx, batch_size, nb_nodes):
gt_diffs = {
_Location.NODE: jnp.zeros((batch_size, nb_nodes)),
_Location.EDGE: jnp.zeros((batch_size, nb_nodes, nb_nodes)),
_Location.GRAPH: jnp.zeros((batch_size))
}
for hint in hints:
hint_cur = jax.lax.dynamic_index_in_dim(hint.data, i, 0, keepdims=False)
hint_cur = jax.lax.dynamic_index_in_dim(hint.data, first_idx, 0, keepdims=False)
hint_nxt = jax.lax.dynamic_index_in_dim(
hint.data, i+1, 0, keepdims=False)
hint.data, second_idx, 0, keepdims=False)
if len(hint_cur.shape) == len(gt_diffs[hint.location].shape):
hint_cur = jnp.expand_dims(hint_cur, -1)
hint_nxt = jnp.expand_dims(hint_nxt, -1)
gt_diffs[hint.location] += jnp.any(hint_cur != hint_nxt, axis=-1)
for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
gt_diffs[loc] = (gt_diffs[loc] > 0.0).astype(jnp.float32) * 1.0
return gt_diffs

gt_diffs = None
if hints[0].data.shape[0] > 1 and self.decode_diffs:
gt_diffs = get_gt_diffs(hints, i, i+1, batch_size, nb_nodes)

gt_diffs_prev = None
if hints[0].data.shape[0] > 1 and self.subgraph_mode is not None:
if not first_step:
gt_diffs_prev = get_gt_diffs(hints, i, i-1, batch_size, nb_nodes)

(hiddens, output_preds_cand, hint_preds, diff_logits,
lstm_state) = self._one_step_pred(inputs, cur_hint, mp_state.hiddens,
batch_size, nb_nodes,
mp_state.lstm_state,
spec, encs, decs, diff_decs)
spec, encs, decs, diff_decs, gt_diffs_prev)

if first_step:
output_preds = output_preds_cand
Expand Down Expand Up @@ -377,6 +390,7 @@ def _one_step_pred(
encs: Dict[str, List[hk.Module]],
decs: Dict[str, Tuple[hk.Module]],
diff_decs: Dict[str, Any],
gt_diffs_prev: Dict[_Location, Any],
):
"""Generates one-step predictions."""

Expand Down Expand Up @@ -405,12 +419,21 @@ def _one_step_pred(
except Exception as e:
raise Exception(f'Failed to process {dp}') from e

msg_adj_mat = adj_mat
if gt_diffs_prev is not None:
if self.subgraph_mode == "egonets":
msg_adj_mat = subgraphs_utils.get_egonets(gt_diffs_prev[_Location.NODE], adj_mat)
elif self.subgraph_mode == "stars":
msg_adj_mat = subgraphs_utils.get_stars(gt_diffs_prev[_Location.NODE], adj_mat)
else:
raise ValueError(f"Invalid subgraph_mode {self.subgraph_mode}")

# PROCESS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
nxt_hidden = self.processor(
node_fts,
edge_fts,
graph_fts,
adj_mat,
msg_adj_mat,
hidden,
batch_size=batch_size,
nb_nodes=nb_nodes,
Expand Down
76 changes: 76 additions & 0 deletions clrs/_src/subgraphs_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import jax.numpy as jnp


# ([B, N], [B, N, N]) -> [B, N, N]
# NOTE: also keeps edges across egonets
def get_egonets(center_nodes, adj_mat):
    """Returns the adj matrix restricted to the ego nets around `center_nodes`.

    A node belongs to the ego set of its graph when it is a center node or
    when it has an edge pointing into a center node (i.e. it has a non-zero
    entry in a center node's column of `adj_mat`). Every edge whose two
    endpoints are both in the ego set is kept, so edges running between two
    different ego nets survive as well; all other entries are zeroed.

    The whole computation is expressed with broadcasted masks, so all shapes
    are static and the function can be traced by `jax.jit`.

    Args:
      center_nodes: [B, N] array; non-zero entries mark the center nodes.
      adj_mat: [B, N, N] adjacency matrix.

    Returns:
      [B, N, N] array holding `adj_mat` where both endpoints are ego nodes
      and zero elsewhere.
    """
    zero = jnp.zeros(())
    is_center = center_nodes != 0

    # [B, N]: for every node, the summed weight of its edges into center
    # nodes (only the columns that belong to centers contribute).
    into_centers = jnp.where(is_center[:, None, :], adj_mat, zero)
    neighbor_score = into_centers.sum(axis=-1)

    # A node is part of the ego set when its combined score is non-zero.
    in_ego = (neighbor_score + center_nodes) != 0

    # Keep an edge only if neither endpoint was removed.
    keep_edge = in_ego[:, :, None] & in_ego[:, None, :]
    return jnp.where(keep_edge, adj_mat, zero)

# ([B, N], [B, N, N]) -> [B, N, N]
def get_stars(center_nodes, adj_mat):
    """Returns the adj matrix consisting of star subgraphs around `center_nodes`.

    An entry (i, j) of a graph is kept when node i or node j is a center node
    of that graph; every other entry is zeroed. Each kept entry carries its
    original `adj_mat` value exactly once. In particular, self-loops on center
    nodes and edges joining two center nodes are NOT double counted, so a 0/1
    adjacency matrix stays 0/1.

    The selection is a broadcasted boolean mask, so all shapes are static and
    the function can be traced by `jax.jit`.

    Args:
      center_nodes: [B, N] array; non-zero entries mark the center nodes.
      adj_mat: [B, N, N] adjacency matrix.

    Returns:
      [B, N, N] floating-point array holding `adj_mat` on the rows and columns
      of center nodes and zero elsewhere.
    """
    is_center = center_nodes != 0

    # An edge belongs to a star if at least one of its endpoints is a center.
    touches_center = is_center[:, :, None] | is_center[:, None, :]

    # Selecting (rather than writing columns and then adding rows) guarantees
    # that entries lying on both a center row and a center column are not
    # accumulated twice.
    return jnp.where(touches_center, adj_mat, jnp.zeros(adj_mat.shape))
9 changes: 9 additions & 0 deletions clrs/examples/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@
'Path in which dataset is stored.')
flags.DEFINE_boolean('freeze_processor', False,
'Whether to freeze the processor of the model.')
flags.DEFINE_enum('subgraph_mode', 'none',
['stars', 'egonets', 'none'],
'If not `None`, then use as adjacency matrix the subgraph '
'around the nodes having hints that changed from the '
'last timestep ')

FLAGS = flags.FLAGS

Expand Down Expand Up @@ -219,6 +224,9 @@ def main(unused_argv):
else:
raise ValueError('Hint mode not in {encoded_decoded, decoded_only, none}.')

if FLAGS.subgraph_mode == 'none':
FLAGS.subgraph_mode = None

common_args = dict(folder=dataset_folder,
algorithm=FLAGS.algorithm,
batch_size=FLAGS.batch_size)
Expand Down Expand Up @@ -257,6 +265,7 @@ def main(unused_argv):
freeze_processor=FLAGS.freeze_processor,
dropout_prob=FLAGS.dropout_prob,
hint_teacher_forcing_noise=FLAGS.hint_teacher_forcing_noise,
subgraph_mode=FLAGS.subgraph_mode,
)

eval_model = clrs.models.BaselineModel(
Expand Down