From 408b570d920182a2ec2aa7527c663ddbb2a837f5 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 13 Nov 2022 11:59:35 +0100 Subject: [PATCH] Updated SplitNetwork to latest changes in networks. Added docs --- bayesflow/summary_networks.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/bayesflow/summary_networks.py b/bayesflow/summary_networks.py index 4514fa4cd..29ed9d202 100644 --- a/bayesflow/summary_networks.py +++ b/bayesflow/summary_networks.py @@ -198,7 +198,7 @@ class SplitNetwork(tf.keras.Model): of data to provide an individual network for each split of the data. """ - def __init__(self, num_splits, split_data_configurator, network_type=InvariantNetwork, meta={}, **kwargs): + def __init__(self, num_splits, split_data_configurator, network_type=InvariantNetwork, network_kwargs={}, **kwargs): """Creates a composite network of `num_splits` sub-networks of type `network_type`, each with configuration specified by `meta`. @@ -207,14 +207,22 @@ def __init__(self, num_splits, split_data_configurator, network_type=InvariantNe num_splits : int The number if splits for the data, which will equal the number of sub-networks. split_data_configurator : callable - Function that takes the arguments `i` and `x` where `i` is the index of the network - and `x` are the inputs to the `SplitNetwork`. Should return the input for the corresponding network. + Function that takes the arguments `i` and `x` where `i` is the index of the + network and `x` are the inputs to the `SplitNetwork`. Should return the input + for the corresponding network. - For example, to achieve a network with is permutation-invariant both vertically (i.e., across rows) - and horizontally (i.e., across columns), one could to: - `def config(i, x): - TODO + For example, to achieve a network with is permutation-invariant both + vertically (i.e., across rows) and horizontally (i.e., across columns), one could to: + `def split(i, x): + selector = tf.where(x[:,:,0]==i, 1.0, 0.0) + selected = x[:,:,1] * selector + split_x = tf.stack((selector, selected), axis=-1) + return split_x ` + where `x[:,:,0]` contains an integer indicating which split the data + in `x[:,:,1]` belongs to. All values in `x[:,:,1]` that are not selected + are set to zero. The selector is passed along with the modified data, + indicating which rows belong to the split `i`. network_type : callable, optional, default: `InvariantNetowk` Type of neural network to use. meta : dict, optional, default: {} @@ -227,7 +235,7 @@ def __init__(self, num_splits, split_data_configurator, network_type=InvariantNe self.num_splits = num_splits self.split_data_configurator = split_data_configurator - self.networks = [network_type(meta) for _ in range(num_splits)] + self.networks = [network_type(**network_kwargs) for _ in range(num_splits)] def call(self, x): """Performs a forward pass through the subnetworks and concatenates their output.