In [1]:
import jax.numpy as jnp

channels_in = 10
groups = 2
d = jnp.arange(channels_in * groups).reshape((channels_in, groups))
d.shape

(10, 2)

In [2]:
d

DeviceArray([[ 0,  1],
             [ 2,  3],
             [ 4,  5],
             [ 6,  7],
             [ 8,  9],
             [10, 11],
             [12, 13],
             [14, 15],
             [16, 17],
             [18, 19]], dtype=int32)

In [3]:
d.reshape((channels_in * groups))

DeviceArray([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,
             15, 16, 17, 18, 19], dtype=int32)

In [4]:
d.flatten()

DeviceArray([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,
             15, 16, 17, 18, 19], dtype=int32)

Rows are flattened first, then columns are stacked (not the other way around).

In [13]:
features_in_group = 2

In [14]:
tiled_d = jnp.tile(d, [1, features_in_group])
tiled_d.shape

(10, 4)

Here, features_per_group = 2.

In [15]:
tiled_d

DeviceArray([[ 0,  1,  0,  1],
             [ 2,  3,  2,  3],
             [ 4,  5,  4,  5],
             [ 6,  7,  6,  7],
             [ 8,  9,  8,  9],
             [10, 11, 10, 11],
             [12, 13, 12, 13],
             [14, 15, 14, 15],
             [16, 17, 16, 17],
             [18, 19, 18, 19]], dtype=int32)

Before, we had 2 sets of 10 offsets (1 for each group).
offset_1, offset_2

When we tile, we end up with offset_1, offset_2, offset_1, offset_2 (vs. offset_1, offset_1, offset_2, offset_2).

In [9]:
flat_d = tiled_d.reshape(-1)
flat_d

DeviceArray([ 0,  1,  0,  1,  2,  3,  2,  3,  4,  5,  4,  5,  6,  7,  6,
              7,  8,  9,  8,  9, 10, 11, 10, 11, 12, 13, 12, 13, 14, 15,
             14, 15, 16, 17, 16, 17, 18, 19, 18, 19], dtype=int32)

When we flatten end up with the following:

- channel_0, offset_1, filter_1 (1 in group 1)
- channel_0, offset_2, filter_2 (2 in group 2)
- channel_0, offset_1, filter_3 (1 in group 2)
- channel_0, offset_2, filter_4 (2 in group 2)

---
- channel_1, offset_1, filter_1 (1 in group 1)
- channel_1, offset_2, filter_2 (2 in group 2)
- channel_1, offset_1, filter_3 (1 in group 2)
- channel_1, offset_2, filter_4 (2 in group 2)

We want to sum across each filter axis.

features_in_groups * num_groups = num_features

Total # chanels = num_features * channels_in

In [10]:
len(flat_d)

40

In [16]:
num_filters = features_in_group * groups
num_filters, num_filters * channels_in

(4, 40)

Now we pretend that we performed the filtering and reshape it back.

In [44]:
final = jnp.reshape(flat_d, [channels_in, num_filters])

In [47]:
final.transpose((-1, -2))

DeviceArray([[ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18],
             [ 1,  3,  5,  7,  9, 11, 13, 15, 17, 19],
             [ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18],
             [ 1,  3,  5,  7,  9, 11, 13, 15, 17, 19]], dtype=int32)

In [43]:
final.transpose((0, 2,  1)).reshape((num_filters, channels_in))

DeviceArray([[ 0,  4,  1,  5,  0,  6,  1,  7,  2,  6],
             [ 3,  7,  2,  8,  3,  9,  4,  8,  5,  9],
             [10, 14, 11, 15, 10, 16, 11, 17, 12, 16],
             [13, 17, 12, 18, 13, 19, 14, 18, 15, 19]], dtype=int32)

In [27]:
flat_d.shape

(40,)

In [33]:
jnp.reshape(flat_d, [num_groups, channels_in]).shape

(4, 10)

In [19]:
groups

2

In [20]:
jnp.sum(final, axis=-1)

DeviceArray([ 21,  69, 121, 169], dtype=int32)

4 filter outputs as we wanted.