In [1]:
import graphSym
import torch.nn as nn
import numpy as np
import torch

from graphSym.graph_conv import GridGraphConv
from graphSym.graph_pool import GraphMaxPool2d

In [2]:
class Net(nn.Module):
    """
    Graph-convolution network built to be symmetric under horizontal mirroring.

    The model stacks two ``GridGraphConv`` layers, one ``GraphMaxPool2d`` and
    two more ``GridGraphConv`` layers in an ``nn.Sequential``; ``forward``
    then averages the result over dimension 2.

    Args:
        input_shape: grid shape handed to the pooling layer.
        nb_class: number of output channels of the last convolution.

    NOTE(review): the grid shape after pooling is hard-coded to (16, 16) and
    the first two convolutions do not receive ``input_shape`` at all, so
    values other than the default (32, 32) are probably not supported --
    TODO confirm against GridGraphConv / GraphMaxPool2d.
    """
    def __init__(self, input_shape=(32,32), nb_class=5):
        super().__init__()
        # 'left' and 'right' are grouped together; 'top' and 'bottom' are kept
        # apart. Presumably this grouping is what makes a left/right flip
        # indistinguishable for the convolution -- verify in GridGraphConv.
        edge_groups = [['left', 'right'], ['top'], ['bottom']]
        pooled_shape = (16, 16)

        # Layers are created in the order they are applied.
        layers = [
            GridGraphConv(3, 30, merge_way='cat', underlying_graphs=edge_groups),
            GridGraphConv(30, 60, merge_way='cat', underlying_graphs=edge_groups),
            GraphMaxPool2d(input_shape=input_shape),
            GridGraphConv(60, 60, input_shape=pooled_shape, merge_way='cat', underlying_graphs=edge_groups),
            GridGraphConv(60, nb_class, input_shape=pooled_shape, merge_way='mean', underlying_graphs=edge_groups),
        ]
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        # Average over dimension 2 (assumed to be the node axis -- TODO confirm).
        return self.seq(x).mean(2)

# Build the network with its default arguments (input_shape=(32,32), nb_class=5).
net = Net()

### Test that the network is invariant to horizontal mirroring

In [3]:
def test(net, shape=(32, 32)):
    x, y = shape
    dummy = torch.from_numpy(np.random.randn(2, 3, x*y)).float()
    dummy2 = dummy.view(2, 3, x, y).flip([3]).contiguous().view(2, 3, x*y)
    np.testing.assert_almost_equal(
        net(dummy).detach().numpy(),
        net(dummy2).detach().numpy(),
        decimal=4)
# Run the mirroring check on the network built above, using the default 32x32 grid.
test(net)

In [4]:
from models import CIFARNet, AIDNet
# Repeat the mirroring check on the project models: CIFARNet on a 32x32 grid,
# AIDNet on a 200x200 grid.
# NOTE(review): the output recorded for this cell is an AssertionError
# (100% mismatch, max absolute difference ~10.13), so at least one of these
# models does not pass the check as written -- TODO identify which and why.
test(net=CIFARNet(), shape=(32, 32))
test(net=AIDNet(), shape=(200, 200))

AssertionError: 
Arrays are not almost equal to 4 decimals

Mismatch: 100%
Max absolute difference: 10.134766
Max relative difference: 57.635338
 x: array([[ -10.8202,   17.4415,   28.9768,  -70.1732, -142.3684,   95.3263,
          91.3249,   13.8723, -136.8235,  -65.718 ,   82.7342,  -94.4663,
          74.4305,   75.6022,   27.3977,    1.4254,    5.7092,   22.3228,...
 y: array([[-7.7602e+00,  1.7087e+01,  2.5187e+01, -6.4065e+01, -1.3223e+02,
         8.9457e+01,  8.9176e+01,  1.2164e+01, -1.2801e+02, -6.5104e+01,
         8.0086e+01, -9.0716e+01,  7.0685e+01,  7.6242e+01,  2.1824e+01,...