In [1]:
import graphSym
import torch.nn as nn
import numpy as np
import torch

from graphSym.graph_conv import GridGraphConv
from graphSym.graph_pool import GraphMaxPool2d

In [2]:
class Net(nn.Module):
    """
    Graph-convolutional classifier intended to be invariant to horizontal mirroring.

    The horizontal symmetry is expected to come from the ``underlying_graphs``
    grouping: 'left' and 'right' neighbour relations are placed in the same
    group, while 'top' and 'bottom' are kept in separate groups (presumably
    each group gets its own weights in GridGraphConv -- TODO confirm).

    Args:
        input_shape: (height, width) of the grid the first layers operate on.
        nb_class: number of output channels of the last layer (one per class).
    """
    def __init__(self, input_shape=(32, 32), nb_class=5):
        super().__init__()
        underlying_graphs = [['left', 'right'], ['top'], ['bottom']]
        # Pass input_shape explicitly so the first layers agree with pool1
        # when a non-default grid size is used.
        conv1 = GridGraphConv(3, 30, input_shape=input_shape, merge_way='cat', underlying_graphs=underlying_graphs)
        conv2 = GridGraphConv(30, 60, input_shape=input_shape, merge_way='cat', underlying_graphs=underlying_graphs)

        pool1 = GraphMaxPool2d(input_shape=input_shape)
        # Grid shape after pooling, derived from input_shape instead of being
        # fixed to (16, 16). Assumes GraphMaxPool2d halves each side with floor
        # division -- TODO confirm behaviour for odd sizes.
        out_shape = (input_shape[0] // 2, input_shape[1] // 2)

        conv3 = GridGraphConv(60, 60, input_shape=out_shape, merge_way='cat', underlying_graphs=underlying_graphs)
        conv4 = GridGraphConv(60, nb_class, input_shape=out_shape, merge_way='mean', underlying_graphs=underlying_graphs)

        self.seq = nn.Sequential(conv1, conv2, pool1, conv3, conv4)

    def forward(self, x):
        """Apply the layer stack, then average the result over dimension 2."""
        out = self.seq(x)
        # Dimension 2 is presumably the grid-node axis (inputs are flattened to
        # height * width nodes); a mean over nodes is unaffected by reordering
        # them, which is what makes the final scores mirror-invariant.
        return out.mean(2)

# Seed torch so the randomly initialised weights are identical on every
# Restart & Run All.
torch.manual_seed(0)
net = Net()

### Test that the network's output is invariant to horizontal mirroring

In [3]:
# Fixed seed so the random input, and therefore this check, is reproducible.
rng = np.random.default_rng(0)
batch, channels, height, width = 2, 3, 32, 32
dummy = torch.from_numpy(rng.standard_normal((batch, channels, height * width))).float()
# Horizontal mirror: view the flattened nodes as a (height, width) grid,
# reverse the width axis, and flatten back. flip() returns a new tensor, so
# reshape needs no explicit contiguous() call.
dummy2 = dummy.view(batch, channels, height, width).flip([3]).reshape(batch, channels, height * width)

# No gradients are needed; no_grad avoids building the autograd graph.
with torch.no_grad():
    out_original = net(dummy).numpy()
    out_mirrored = net(dummy2).numpy()

np.testing.assert_almost_equal(out_original, out_mirrored, decimal=4)