In [11]:
import torch
import torch.nn as nn
from ceconv2d import CEConv2D

# Seed PyTorch's global RNG so the dummy inputs below (and anything else drawn
# from the global RNG afterwards) are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Generate dummy inputs:
#   x        - a single 3-channel 32x32 image for the lifting layer,
#   x_hidden - a single 16-channel feature map with an extra axis of size 3,
#              fed directly to the hidden-layer models.
x = torch.rand(1, 3, 32, 32)
x_hidden = torch.rand(1, 16, 3, 32, 32)

# Add extra "temporal" dimension to input so nn.Conv3d sees one input channel
# and treats the three colour channels as the depth axis.
x_3d = x.unsqueeze(1)

print("Vanilla input: \t", x.shape)
print("3D input: \t", x_3d.shape)

Vanilla input: 	 torch.Size([1, 3, 32, 32])
3D input: 	 torch.Size([1, 1, 3, 32, 32])


## Lifting layer

The lifting layer of a Color Equivariant Convolution (CEConv) is equivalent to a 3D convolution over the color channels with circular padding.

In [12]:
# Reference ("vanilla") color equivariant model: a single CEConv2D lifting
# layer configured with 1 input rotation and 3 output rotations.
lifting_layer = CEConv2D(
    in_rotations=1,
    out_rotations=3,
    in_channels=3,
    out_channels=16,
    kernel_size=3,
    padding=0,
)
model_vanilla = nn.Sequential(lifting_layer)
print("Weight: \t", lifting_layer.weight.shape)

# Run the reference model on the plain RGB input.
y_vanilla = model_vanilla(x)
print("Output tensor: \t", y_vanilla.shape)



Weight: 	 torch.Size([16, 3, 1, 3, 3])
Output tensor: 	 torch.Size([1, 16, 3, 30, 30])


In [13]:
# Conv3d counterpart of the lifting layer. The image enters as one input
# channel whose depth ("temporal") axis holds the three colour channels.
conv3d_lifting = nn.Conv3d(
    in_channels=1,
    out_channels=16,
    kernel_size=(3, 3, 3),  # depth of 3 because the input is RGB
    padding=(1, 0, 0),  # pad only along the depth (colour) axis
    padding_mode='circular',  # wrap around so the colour axis is cyclic
)
model_3d = nn.Sequential(conv3d_lifting)
print("Weight: \t", conv3d_lifting.weight.shape)

# Transfer the CEConv2D parameters into the Conv3d layer.
ce_layer = model_vanilla[0]
# Swap axes 1 and 2: (16, 3, 1, 3, 3) -> (16, 1, 3, 3, 3), the Conv3d layout.
w = ce_layer.weight.data.permute(0, 2, 1, 3, 4)
# Reorder the depth axis: CEConv starts with RGB, whereas Conv3d (with one
# step of circular padding) starts with BRG.
w = w[:, :, (2, 0, 1), :, :]
conv3d_lifting.weight = nn.parameter.Parameter(w)
conv3d_lifting.bias = ce_layer.bias

y_3d = model_3d(x_3d)
print("Output tensor \t", y_3d.shape)

Weight: 	 torch.Size([16, 1, 3, 3, 3])
Output tensor 	 torch.Size([1, 16, 3, 30, 30])


In [14]:
# The CEConv2D lifting layer and its Conv3d counterpart should agree
# element-wise up to a small absolute tolerance.
torch.allclose(input=y_vanilla, other=y_3d, atol=1e-7)

True

## Hidden layer

In [15]:
# Reference ("vanilla") color equivariant model: a single non-separable
# CEConv2D hidden layer with 3 input rotations and 3 output rotations.
hidden_layer = CEConv2D(
    in_rotations=3,
    out_rotations=3,
    in_channels=16,
    out_channels=16,
    kernel_size=3,
    padding=0,
    separable=False,
)
model_vanilla = nn.Sequential(hidden_layer)
print("Weight: \t", hidden_layer.weight.shape)

# Run the reference model on the hidden-layer dummy input.
y_vanilla = model_vanilla(x_hidden)
print("Output tensor: \t", y_vanilla.shape)

Weight: 	 torch.Size([16, 16, 3, 3, 3])
Output tensor: 	 torch.Size([1, 16, 3, 30, 30])


In [16]:
# Conv3d counterpart of the hidden layer: 16 -> 16 channels, with circular
# padding applied only along the depth axis.
conv3d_hidden = nn.Conv3d(
    in_channels=16,
    out_channels=16,
    kernel_size=(3, 3, 3),
    padding=(1, 0, 0),
    padding_mode='circular',
)
model_3d = nn.Sequential(conv3d_hidden)
print("Weight: \t", conv3d_hidden.weight.shape)

# Transfer the CEConv2D parameters into the Conv3d layer. The weight already
# has the Conv3d layout, so only the depth axis needs reordering.
ce_layer = model_vanilla[0]
w = ce_layer.weight.data[:, :, (2, 0, 1), :, :]
conv3d_hidden.weight = nn.parameter.Parameter(w)
conv3d_hidden.bias = ce_layer.bias

y_3d = model_3d(x_hidden)
print("Output tensor \t", y_3d.shape)


Weight: 	 torch.Size([16, 16, 3, 3, 3])
Output tensor 	 torch.Size([1, 16, 3, 30, 30])


In [17]:
# The CEConv2D hidden layer and its Conv3d counterpart should agree
# element-wise up to a small absolute tolerance.
torch.allclose(input=y_vanilla, other=y_3d, atol=1e-6)

True

## Multilayer model

In [18]:
# Reference ("vanilla") color equivariant model: a lifting layer followed by
# two identical non-separable hidden layers, with ReLU in between.
hidden_kwargs = dict(
    in_rotations=3,
    out_rotations=3,
    in_channels=16,
    out_channels=16,
    kernel_size=3,
    padding=0,
    separable=False,
)
model_vanilla = nn.Sequential(
    CEConv2D(
        in_rotations=1,
        out_rotations=3,
        in_channels=3,
        out_channels=16,
        kernel_size=3,
        padding=0,
    ),
    nn.ReLU(),
    CEConv2D(**hidden_kwargs),
    nn.ReLU(),
    CEConv2D(**hidden_kwargs),
)
# The convolutions sit at even indices; the odd indices are the ReLUs.
print("Weight layer 1: \t", model_vanilla[0].weight.shape)
print("Weight layer 2: \t", model_vanilla[2].weight.shape)
print("Weight layer 3: \t", model_vanilla[4].weight.shape)

# Run the reference model on the plain RGB input.
y_vanilla = model_vanilla(x)
print("Output tensor: \t\t", y_vanilla.shape)

Weight layer 1: 	 torch.Size([16, 3, 1, 3, 3])
Weight layer 2: 	 torch.Size([16, 16, 3, 3, 3])
Weight layer 3: 	 torch.Size([16, 16, 3, 3, 3])
Output tensor: 		 torch.Size([1, 16, 3, 26, 26])


In [19]:
# Define Conv3d implementation of color equivariant model.
# All three convolutions share the same kernel/padding configuration and
# differ only in the number of input channels.
conv3d_kwargs = dict(
    out_channels=16,
    kernel_size=(3, 3, 3),
    padding=(1, 0, 0),  # pad only along the depth (colour) axis
    padding_mode='circular',
)
model_3d = nn.Sequential(
    nn.Conv3d(in_channels=1, **conv3d_kwargs),
    nn.ReLU(),
    nn.Conv3d(in_channels=16, **conv3d_kwargs),
    nn.ReLU(),
    nn.Conv3d(in_channels=16, **conv3d_kwargs),
)

# Copy weights from vanilla model to 3D model.
# Every copied weight is cloned so that each Conv3d layer owns its storage
# instead of potentially aliasing the CEConv2D weight it was derived from.

# Lifting layer: swap axes 1 and 2 to obtain the Conv3d layout, then reorder
# the depth axis (CEConv starts with RGB, Conv3d with BRG).
w = model_vanilla[0].weight.data
w = torch.permute(w, (0, 2, 1, 3, 4))
w = w[:, :, (2, 0, 1), :, :]
model_3d[0].weight = nn.parameter.Parameter(w.clone())
model_3d[0].bias = model_vanilla[0].bias

# Hidden layers (indices 2 and 4; 1 and 3 are ReLUs): the weight already has
# the Conv3d layout, so only the depth axis is reordered.
for layer_idx in (2, 4):
    w = model_vanilla[layer_idx].weight.data
    w = w[:, :, (2, 0, 1), :, :]
    model_3d[layer_idx].weight = nn.parameter.Parameter(w.clone())
    model_3d[layer_idx].bias = model_vanilla[layer_idx].bias

y_3d = model_3d(x_3d)
print("Output tensor \t", y_3d.shape)

Output tensor 	 torch.Size([1, 16, 3, 26, 26])


In [20]:
# The three-layer CEConv2D model and its Conv3d counterpart should agree
# element-wise up to a small absolute tolerance.
torch.allclose(input=y_vanilla, other=y_3d, atol=1e-5)

True

Remarks:

* `torch.allclose` only returns `True` with a relatively loose absolute tolerance (`atol` between 1e-7 and 1e-5 in the checks above, versus the default of 1e-8), so there are some subtle numerical differences between the two implementations.
* It would be interesting to see if the differences in the two implementations lead to different benchmark results.
* The Conv3d implementation may be faster than the CEConv implementation, so it may be worth using Conv3d under the hood.
* I still have to implement the separable version of CEConv as Conv3d, but that should be straightforward.