Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions bayesflow/summary_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ class SplitNetwork(tf.keras.Model):
of data to provide an individual network for each split of the data.
"""

def __init__(self, num_splits, split_data_configurator, network_type=InvariantNetwork, meta={}, **kwargs):
def __init__(self, num_splits, split_data_configurator, network_type=InvariantNetwork, network_kwargs={}, **kwargs):
"""Creates a composite network of `num_splits` sub-networks of type `network_type`, each with configuration
specified by `meta`.

Expand All @@ -207,14 +207,22 @@ def __init__(self, num_splits, split_data_configurator, network_type=InvariantNe
num_splits : int
The number if splits for the data, which will equal the number of sub-networks.
split_data_configurator : callable
Function that takes the arguments `i` and `x` where `i` is the index of the network
and `x` are the inputs to the `SplitNetwork`. Should return the input for the corresponding network.
Function that takes the arguments `i` and `x` where `i` is the index of the
network and `x` are the inputs to the `SplitNetwork`. Should return the input
for the corresponding network.

For example, to achieve a network with is permutation-invariant both vertically (i.e., across rows)
and horizontally (i.e., across columns), one could to:
`def config(i, x):
TODO
For example, to achieve a network with is permutation-invariant both
vertically (i.e., across rows) and horizontally (i.e., across columns), one could to:
`def split(i, x):
selector = tf.where(x[:,:,0]==i, 1.0, 0.0)
selected = x[:,:,1] * selector
split_x = tf.stack((selector, selected), axis=-1)
return split_x
`
where `x[:,:,0]` contains an integer indicating which split the data
in `x[:,:,1]` belongs to. All values in `x[:,:,1]` that are not selected
are set to zero. The selector is passed along with the modified data,
indicating which rows belong to the split `i`.
network_type : callable, optional, default: `InvariantNetowk`
Type of neural network to use.
meta : dict, optional, default: {}
Expand All @@ -227,7 +235,7 @@ def __init__(self, num_splits, split_data_configurator, network_type=InvariantNe

self.num_splits = num_splits
self.split_data_configurator = split_data_configurator
self.networks = [network_type(meta) for _ in range(num_splits)]
self.networks = [network_type(**network_kwargs) for _ in range(num_splits)]

def call(self, x):
"""Performs a forward pass through the subnetworks and concatenates their output.
Expand Down