What is included, and how do you use the hypercomplex layers?
1. Implementations of the well-known hypercomplex algebras for linear and convolution layers
    - Complex, Quaternion, Octonion and Sedenion
2. Flexible hypercomplex layers for less common algebras, including ones with more than 16 components
    - Just set `n_divs` to the number of components of the hypercomplex algebra
    

In [1]:
import torch

## Understanding the linear layers

In [2]:
# Let's start with the known (fixed-algebra) layers
from fast_hypercomplex import ComplexLinear, QuaternionLinear, OctonionLinear, SedenionLinear

# ...and the flexible layer, for which n_divs must be specified
from fast_hypercomplex import HyperLinear

In [3]:
# Sample input: a batch of 32 feature vectors with 128 channels.
x = torch.rand(32, 128)   # B, C

# Each fixed-algebra layer takes the usual nn.Linear options (e.g. bias).
# The trainable parameter count printed for each one halves as the number of
# hypercomplex components doubles.
known_linear_layers = [
    ('Complex', ComplexLinear),
    ('Quaternion', QuaternionLinear),
    ('Octonion', OctonionLinear),
    ('Sedenion', SedenionLinear),
]

for name, linear_cls in known_linear_layers:
    layer = linear_cls(in_features=128, out_features=256, bias=False)
    # The forward pass only demonstrates that the layer runs; no gradients are
    # needed, so skip building an autograd graph.
    with torch.no_grad():
        y = layer(x)
    print(f'{name} Linear:')
    print(sum(p.numel() for p in layer.parameters() if p.requires_grad), '\n')

Complex Linear:
16384 

Quaternion Linear:
8192 

Octonion Linear:
4096 

Sedenion Linear:
2048 



In [4]:
# The flexible HyperLinear layer: setting n_divs to 2, 4, 8 or 16 gives the
# same parameter counts as the fixed-algebra layers above.

n_divs_dict = {2: 'Complex', 4: 'Quaternion', 8: 'Octonion', 16: 'Sedenion'}

for n_divs in n_divs_dict:
    name = n_divs_dict[n_divs]
    # n_divs is the only extra argument compared to nn.Linear.
    layer = HyperLinear(
        in_features=128,
        out_features=256,
        bias=False,
        n_divs=n_divs,
    )
    y = layer(x)
    print(f'{name} Linear:')
    print(
        sum(p.numel() for p in layer.parameters() if p.requires_grad),
        '\n',
    )

Complex Linear:
16384 

Quaternion Linear:
8192 

Octonion Linear:
4096 

Sedenion Linear:
2048 



In [5]:
# n_divs is not limited to the named algebras; here an arbitrary value (64)
# is used. NOTE(review): presumably in/out features must be divisible by
# n_divs — confirm against HyperLinear.

n_divs = 64
layer = HyperLinear(
    in_features=128,
    out_features=256,
    bias=False,
    n_divs=n_divs,  # the only change needed for a different algebra size
)
y = layer(x)

print(f'{n_divs} division Hypercomplex Linear layer:')
print(
    sum(p.numel() for p in layer.parameters() if p.requires_grad),
    '\n',
)

64 division Hypercomplex Linear layer:
512 



## Understanding the convolution layers

In [6]:
# Let's start with the known (fixed-algebra) layers
from fast_hypercomplex import ComplexConv1d, QuaternionConv1d, OctonionConv1d, SedenionConv1d  # 1d convolution
from fast_hypercomplex import ComplexConv2d, QuaternionConv2d, OctonionConv2d, SedenionConv2d  # 2d convolution
from fast_hypercomplex import ComplexConv3d, QuaternionConv3d, OctonionConv3d, SedenionConv3d  # 3d convolution

# ...and the flexible layers, for which n_divs must be specified
from fast_hypercomplex import HyperConv1d, HyperConv2d, HyperConv3d

In [7]:
# Sample input: a batch of 32 feature maps with 128 channels at 64x64.
x = torch.rand(32, 128, 64, 64)   # B, C, H, W

# Each fixed-algebra layer takes the usual nn.Conv2d options
# (kernel_size, padding, bias, ...).
known_conv2d_layers = [
    ('Complex', ComplexConv2d),
    ('Quaternion', QuaternionConv2d),
    ('Octonion', OctonionConv2d),
    ('Sedenion', SedenionConv2d),
]

for name, conv_cls in known_conv2d_layers:
    layer = conv_cls(in_channels=128, out_channels=256, kernel_size=3, padding=1, bias=False)
    # The forward pass only demonstrates that the layer runs; skipping autograd
    # avoids recording a graph over this fairly large input.
    with torch.no_grad():
        y = layer(x)
    print(f'{name} Conv2d:')
    print(sum(p.numel() for p in layer.parameters() if p.requires_grad), '\n')

Complex Conv2d:
147456 

Quaternion Conv2d:
73728 

Octonion Conv2d:
36864 

Sedenion Conv2d:
18432 



In [8]:
# The flexible HyperConv2d layer: setting n_divs to 2, 4, 8 or 16 gives the
# same parameter counts as the fixed-algebra Conv2d layers above.

n_divs_dict = {2: 'Complex', 4: 'Quaternion', 8: 'Octonion', 16: 'Sedenion'}

for n_divs, name in n_divs_dict.items():
    layer = HyperConv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1, bias=False, n_divs=n_divs)  # note the use of n_divs here
    # Demo forward pass only; no gradients are needed, so don't build a graph.
    with torch.no_grad():
        y = layer(x)
    print(f'{name} Conv2d:')
    print(sum(p.numel() for p in layer.parameters() if p.requires_grad), '\n')

Complex Conv2d:
147456 

Quaternion Conv2d:
73728 

Octonion Conv2d:
36864 

Sedenion Conv2d:
18432 



## And now on to full networks: ResNet classifiers and a UNet with a ResNet backbone

### Let's see what we get with ResNet classification (10 classes)

In [9]:
# Let's see what we get with ResNet classification (10 classes)
from resnet import resnet20, resnet32, resnet44 #... 

In [10]:
# Sample input: a batch of 32 three-channel images at 64x64.
x = torch.rand(32, 3, 64, 64)   # B, C, H, W

# n_divs=1 is labelled 'Real', i.e. the ordinary real-valued baseline.
n_divs_dict = {1: 'Real', 2: 'Complex', 4: 'Quaternion', 8: 'Octonion', 16: 'Sedenion'}

for n_divs, name in n_divs_dict.items():
    net = resnet32(n_divs=n_divs)  # note the use of n_divs here
    # The forward pass only checks that the network runs. Without no_grad,
    # every intermediate activation would be kept alive for a backward pass
    # that never happens.
    with torch.no_grad():
        y = net(x)
    print(f'{name} based Resnet32')
    print(sum(p.numel() for p in net.parameters() if p.requires_grad), '\n')

Real based Resnet32
464154 

Complex based Resnet32
233784 

Quaternion based Resnet32
118698 

Octonion based Resnet32
61542 

Sedenion based Resnet32
34494 



### Let's see what we get with UNet segmentation

In [11]:
from resnet import HyperUnet  # here, we simplify with use of n_divs

In [14]:
# Sample input: a batch of 32 three-channel images at 64x64.
x = torch.rand(32, 3, 64, 64)   # B, C, H, W

n_divs_dict = {1: 'Real', 2: 'Complex', 4: 'Quaternion', 8: 'Octonion', 16: 'Sedenion'}

for n_divs, name in n_divs_dict.items():
    net = HyperUnet(
        num_blocks=[4, 4, 4],
        in_channels=3,
        in_planes=32,
        num_classes=1,  # just a possible number of segmentation (output) channels
        n_divs=n_divs,  # note the use of n_divs here
    )
    # Only the output shape is inspected. Skipping autograd keeps the
    # full-resolution UNet activations from being retained for backward.
    with torch.no_grad():
        y = net(x)
    print(f'{name} Hypercomplex UNet:', y.shape)
    print(sum(p.numel() for p in net.parameters() if p.requires_grad), '\n')

Real Hypercomplex UNet: torch.Size([32, 1, 64, 64])
2569792 

Complex Hypercomplex UNet: torch.Size([32, 1, 64, 64])
1288012 

Quaternion Hypercomplex UNet: torch.Size([32, 1, 64, 64])
646956 

Octonion Hypercomplex UNet: torch.Size([32, 1, 64, 64])
326648 

Sedenion Hypercomplex UNet: torch.Size([32, 1, 64, 64])
166704 

