
An empirical study of rotation equivariant neural networks

Charles Gallay, Michaël Defferrard, Nathanaël Perraudin

This work is a semester project carried out at EPFL as part of the master's program in data science.

Goal

The goal of the project is to test different configurations in order to measure how important rotation equivariance is for a CNN.

As a first measure, a classical CNN and a rotation-equivariant CNN are compared on an image classification task, specifically on the CIFAR10 dataset.
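To make the notion of rotation equivariance concrete, here is a minimal sketch that is independent of the code in this repository. It uses plain NumPy, and the helper names (`correlate2d_valid`, `is_rot90_equivariant`) are hypothetical, chosen only for illustration. It checks whether filtering a rotated image gives the same result as rotating the filtered image, for rotations of 90 degrees only.

```python
import numpy as np

def correlate2d_valid(image, kernel):
    """Plain 'valid' 2D cross-correlation, the core operation of a CNN layer."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

def is_rot90_equivariant(kernel, image):
    """Check whether filtering commutes with a 90-degree rotation of the input."""
    rotate_then_filter = correlate2d_valid(np.rot90(image), kernel)
    filter_then_rotate = np.rot90(correlate2d_valid(image, kernel))
    return np.allclose(rotate_then_filter, filter_then_rotate)

rng = np.random.default_rng(0)
image = rng.standard_normal((8, 8))  # square image, so rotated shapes match

# An unconstrained kernel, as learned by a classical CNN layer.
generic_kernel = rng.standard_normal((3, 3))

# A kernel that is unchanged by a 90-degree rotation (illustrative values).
symmetric_kernel = np.array([[0.0, 1.0, 0.0],
                             [1.0, 4.0, 1.0],
                             [0.0, 1.0, 0.0]])

generic_ok = is_rot90_equivariant(generic_kernel, image)
symmetric_ok = is_rot90_equivariant(symmetric_kernel, image)
print("generic kernel equivariant:", generic_ok)
print("symmetric kernel equivariant:", symmetric_ok)
```

The check passes for the kernel that is itself invariant under the rotation and fails for the unconstrained one, which is the kind of difference between the two families of networks that the project sets out to study.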

Setup

  1. Clone this repository.

    git clone git@github.com:cgallay/GraphSymmetries.git
    cd GraphSymmetries
  2. Install the dependencies.

    pip install -r requirements.txt
    pip install -e .

Example

import torch.nn as nn

from graphSym.graph_conv import GridGraphConv
from graphSym.graph_pool import GraphMaxPool2d

class Net(nn.Module):
    """
    Network with horizontal symmetry.
    """
    def __init__(self, input_shape=(32,32), nb_class=5):
        super().__init__()
        underlying_graphs = [['left', 'right'], ['top'], ['bottom']]
        conv1 = GridGraphConv(3, 30, merge_way='cat', underlying_graphs=underlying_graphs)
        conv2 = GridGraphConv(30, 60, merge_way='cat', underlying_graphs=underlying_graphs)
        
        pool1 = GraphMaxPool2d(input_shape=input_shape)
        # Grid shape used by the layers after pooling (hard-coded for the default 32x32 input).
        out_shape = (16, 16)
        
        conv3 = GridGraphConv(60, 60, input_shape=out_shape, merge_way='cat', underlying_graphs=underlying_graphs)
        conv4 = GridGraphConv(60, nb_class, input_shape=out_shape, merge_way='mean', underlying_graphs=underlying_graphs)
        
        self.seq = nn.Sequential(conv1, conv2, pool1, conv3, conv4)
        
    def forward(self, x):
        out = self.seq(x)
        out = out.mean(2)
        return out

net = Net()

License

The content of this repository is released under the terms of the MIT license.
