### Tutorial: Parameterized Hypercomplex Convolutional (PHC) Layer

#### Author: Eleonora Grassucci

In [1]:
# Imports

import numpy as np
import math
import time
import imageio
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.utils.data as Data
from torch.nn import init

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Check the PyTorch version: torch.kron (used by the PHC layer below) is available from 1.8.0
torch.__version__

'1.8.1+cu102'

### Learn the convolution

In [4]:
# Define the PHC class

class PHC(nn.Module):
    '''
    Simple Parameterized Hypercomplex Convolutional (PHC) module.

    The only learnable parameter is ``A`` (shape ``(n, n, n)``); the filter
    components ``S`` are not learned here but passed in on every forward call.
    The convolution weight is assembled as ``sum_i kron(A[i], S[i])``.

    Args:
        n: number of algebra components. It is also the number of input and
            output channels of the convolution, because the assembled weight
            is viewed as ``(n, n, kernel_size, kernel_size)``.
        kernel_size: side length of each square filter component ``S[i]``.
        **kwargs: forwarded to ``nn.Module.__init__``.
    '''

    def __init__(self, n, kernel_size, **kwargs):
        super().__init__(**kwargs)
        self.n = n
        self.kernel_size = kernel_size
        self.A = nn.Parameter(torch.empty((n, n, n)))
        # torch.empty returns uninitialised memory (possibly NaN/inf), so give
        # A well-defined starting values; callers may still re-initialise it.
        nn.init.xavier_uniform_(self.A)

    def forward(self, X, S):
        '''
        Convolve ``X`` with the weight built from ``A`` and ``S``.

        Args:
            X: input batch with ``n`` channels, as required by ``conv2d`` for a
                weight of shape ``(n, n, kernel_size, kernel_size)``.
            S: filter components, indexable as ``S[i]`` for ``i < n``; each
                ``S[i]`` is expected to be ``(kernel_size, kernel_size)`` so that
                ``kron(A[i], S[i])`` is ``(n*kernel_size, n*kernel_size)``.

        Returns:
            The result of ``conv2d(X, H, padding=2, stride=2)``.
        '''
        side = self.n * self.kernel_size
        # Allocate the accumulator next to the parameter so the layer also
        # works after being moved to another device or dtype.
        H = torch.zeros((side, side), dtype=self.A.dtype, device=self.A.device)

        # Sum of Kronecker products. Use self.n, not a notebook-level global.
        for i in range(self.n):
            H = H + torch.kron(self.A[i], S[i])

        # Row-major reinterpretation of the (n*k, n*k) matrix as a conv weight;
        # for n=4, kernel_size=2 this is exactly the former view(4, 4, 2, 2).
        H = H.view(self.n, self.n, self.kernel_size, self.kernel_size)
        return torch.nn.functional.conv2d(X, H, padding=2, stride=2)

In [6]:
# Setup the training set and the test set

batch_size = 1
num_train_examples = 1000
num_test_examples = 1


def make_hamilton_conv_dataset(num_examples):
    '''
    Build random (input, filter, output) samples of a Hamilton-rule convolution.

    For every sample, an input of 64 values (viewed as 1x4x4x4) and 16 filter
    values (four 2x2 components) are drawn as integer-valued floats in
    [-10, 10). The four components are arranged into an 8x8 matrix following
    the Hamilton product rule, viewed as a 4x4x2x2 weight, and convolved with
    the input (padding=2, stride=2).

    Args:
        num_examples: number of samples to generate.

    Returns:
        Float tensor of shape ``(num_examples, 64, 3)`` where
        ``[..., 0]`` is the flattened input,
        ``[:, :16, 1]`` holds the 16 filter values (the remaining 48 are zero),
        ``[..., 2]`` is the flattened convolution output.
    '''
    samples = torch.zeros((num_examples, 64, 3))

    for idx in range(num_examples):
        # Draw the input first and the filter second for every sample.
        x = torch.randint(low=-10, high=10, size=(64, ), dtype=torch.float)
        f = torch.randint(low=-10, high=10, size=(16, ), dtype=torch.float)

        # Four 2x2 filter components: f[0:4], f[4:8], f[8:12], f[12:16]
        f1, f2, f3, f4 = f.view(4, 2, 2)

        # Hamilton product rule
        W = torch.cat([
            torch.cat([f1, -f2, -f3, -f4]),
            torch.cat([f2, f1, -f4, f3]),
            torch.cat([f3, f4, f1, -f2]),
            torch.cat([f4, -f3, f2, f1]),
        ], dim=1)
        W_conv = W.view(4, 4, 2, 2)
        x_conv = x.view(1, 4, 4, 4)

        # Apply convolution to input x_conv with filters W_conv
        y = torch.nn.functional.conv2d(x_conv, W_conv, padding=2, stride=2)

        samples[idx, :, 0] = x
        samples[idx, :16, 1] = f  # the other 48 entries stay zero (padding)
        samples[idx, :, 2] = y.view(64, )

    return samples


train_iter = torch.utils.data.DataLoader(
    make_hamilton_conv_dataset(num_train_examples), batch_size=batch_size)
test_iter = torch.utils.data.DataLoader(
    make_hamilton_conv_dataset(num_test_examples), batch_size=batch_size)

In [7]:
# Define training function

def train(net, lr, phm=True, num_epochs=5, data_iter=None):
    '''
    Fit ``net`` with Adam on a squared-error loss and print the loss per epoch.

    Each batch is expected to have shape ``(1, 64, 3)``: channel 0 is the
    flattened 1x4x4x4 input, the first 16 entries of channel 1 are the filter
    values, channel 2 is the flattened 1x4x4x4 target. The hard-coded views
    below only work with one sample per batch.

    Args:
        net: module to train. Called as ``net(X, S)`` when ``phm`` is true,
            otherwise as ``net(X)`` with the flat ``(1, 64)`` input.
        lr: learning rate for Adam.
        phm: whether ``net`` takes the filter components as a second argument.
        num_epochs: number of passes over ``data_iter`` (default 5).
        data_iter: re-iterable source of batches; defaults to the
            notebook-level ``train_iter``.

    Raises:
        ValueError: if ``data_iter`` yields no batches during an epoch.
    '''
    if data_iter is None:
        # Keep the old behaviour: fall back to the notebook-level loader.
        data_iter = train_iter

    # Squared loss
    loss = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    for epoch in range(num_epochs):
        l = None
        for data in data_iter:
            optimizer.zero_grad()
            X = data[:, :, 0]
            S = data[:, :16, 1]
            Y = data[:, :, 2]

            if phm:
                out = net(X.view(1, 4, 4, 4), S.view(4, 2, 2))
            else:
                out = net(X)

            l = loss(out, Y.view(1, 4, 4, 4))
            l.backward()
            optimizer.step()

        if l is None:
            raise ValueError('data_iter yielded no batches')
        # MSELoss already averages over elements, so the last batch's loss is
        # reported as is (same value as before for the supported batch size 1).
        print(f'epoch {epoch + 1}, loss {float(l):.6f}')

In [8]:
# Perform training

# Initialize weights
def weights_init_uniform(m):
    '''
    Fill ``m.A`` in place with samples from U(-0.07, 0.07).

    Intended for ``nn.Module.apply``, which calls the function on every
    submodule as well as on the module itself; modules that have no tensor
    attribute ``A`` are therefore skipped instead of raising AttributeError.

    Args:
        m: the module currently visited by ``apply``.
    '''
    A = getattr(m, 'A', None)
    if not isinstance(A, torch.Tensor):
        return
    # In-place initialisation must not be recorded by autograd.
    with torch.no_grad():
        A.uniform_(-0.07, 0.07)
    
# Number of algebra components
n = 4

# Build the layer and draw its initial weights in one go
# (nn.Module.apply returns the module it was called on)
phc_layer = PHC(n, kernel_size=2).apply(weights_init_uniform)

# Fit the layer with learning rate 0.005
train(phc_layer, lr=0.005)

epoch 1, loss 0.013041
epoch 2, loss 0.000000
epoch 3, loss 0.000000
epoch 4, loss 0.000000
epoch 5, loss 0.000000


In [9]:
# Show every parameter of the layer that requires grad, together with its values
trainable = [
    (name, param.data)
    for name, param in phc_layer.named_parameters()
    if param.requires_grad
]
for name, values in trainable:
    print(name, values)

A tensor([[[ 1.0000e+00, -3.4842e-08,  9.3512e-09, -5.0719e-09],
         [-2.0920e-08,  1.0000e+00, -1.4429e-07, -1.2676e-08],
         [-6.9660e-08,  2.9980e-09,  1.0000e+00, -7.5538e-08],
         [-3.6260e-08,  3.3173e-08, -7.2906e-08,  1.0000e+00]],

        [[ 5.1754e-08,  1.0000e+00,  1.1234e-07, -4.6763e-08],
         [-1.0000e+00,  1.3011e-07, -1.1934e-08, -2.5277e-08],
         [-8.2055e-08,  5.1245e-08,  6.2276e-08,  1.0000e+00],
         [ 1.0933e-08, -7.0609e-08, -1.0000e+00,  1.5147e-08]],

        [[ 4.0813e-08,  4.1984e-08,  1.0000e+00, -7.7404e-09],
         [-7.8505e-09,  9.8508e-09, -5.7320e-08, -1.0000e+00],
         [-1.0000e+00,  1.7232e-08,  1.1296e-08, -3.2608e-08],
         [ 4.0355e-08,  1.0000e+00, -5.0336e-08,  6.4603e-09]],

        [[ 3.8183e-08,  8.0717e-08, -2.0061e-08,  1.0000e+00],
         [ 6.4894e-08, -2.9346e-08,  1.0000e+00, -4.8612e-08],
         [-7.3210e-08, -1.0000e+00,  3.0607e-08,  4.8066e-08],
         [-1.0000e+00, -1.3118e-08, -1.4183e-08

In [14]:
# Take a look at the convolution performed on the test set

# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    for batch in test_iter:
        # Distinct names so the imported `F` (torch.nn.functional) is not rebound here.
        x_test = batch[:, :, 0]
        filters_test = batch[:, :16, 1]
        y_test = batch[:, :, 2]

        y_phc = phc_layer(x_test.view(1, 4, 4, 4), filters_test.view(4, 2, 2))

        print('Convolution from test set:\n', y_test.view(1, 4, 4, 4))
        print('Performing convolution learned by PHC:\n', y_phc)


Convolution from test set:
 tensor([[[[   0.,    0.,    0.,    0.],
          [   0.,  -99., -165.,    0.],
          [   0.,    4., -286.,    0.],
          [   0.,    0.,    0.,    0.]],

         [[   0.,    0.,    0.,    0.],
          [   0., -138.,  102.,    0.],
          [   0., -246.,  -23.,    0.],
          [   0.,    0.,    0.,    0.]],

         [[   0.,    0.,    0.,    0.],
          [   0.,  -28.,  352.,    0.],
          [   0.,  250., -142.,    0.],
          [   0.,    0.,    0.,    0.]],

         [[   0.,    0.,    0.,    0.],
          [   0., -183., -367.,    0.],
          [   0.,  -69.,  -77.,    0.],
          [   0.,    0.,    0.,    0.]]]])
Performing convolution learned by PHC:
 tensor([[[[   0.0000,    0.0000,    0.0000,    0.0000],
          [   0.0000,  -99.0000, -165.0000,    0.0000],
          [   0.0000,    4.0000, -286.0000,    0.0000],
          [   0.0000,    0.0000,    0.0000,    0.0000]],

         [[   0.0000,    0.0000,    0.0000,    0.0000],
 

In [19]:
# Check that the PHC layer has learnt the proper algebra for the matrix A

# Sign pattern of the Hamilton product, one row per output component
hamilton_signs = [
    [1, -1, -1, -1],
    [1, 1, -1, 1],
    [1, 1, 1, -1],
    [1, -1, 1, 1],
]
gt = np.array(hamilton_signs)

# Collapse the n learned matrices into one by summing them, then transpose
learned_total = sum(phc_layer.A).T

print('Ground-truth Hamilton rule:\n', gt)
print()
print('Learned A matrices in PHC:\n', phc_layer.A)
print()
print('Learned final A in PHC:\n', learned_total)

Ground-truth Hamilton rule:
 [[ 1 -1 -1 -1]
 [ 1  1 -1  1]
 [ 1  1  1 -1]
 [ 1 -1  1  1]]

Learned A matrices in PHC:
 Parameter containing:
tensor([[[ 1.0000e+00, -3.4842e-08,  9.3512e-09, -5.0719e-09],
         [-2.0920e-08,  1.0000e+00, -1.4429e-07, -1.2676e-08],
         [-6.9660e-08,  2.9980e-09,  1.0000e+00, -7.5538e-08],
         [-3.6260e-08,  3.3173e-08, -7.2906e-08,  1.0000e+00]],

        [[ 5.1754e-08,  1.0000e+00,  1.1234e-07, -4.6763e-08],
         [-1.0000e+00,  1.3011e-07, -1.1934e-08, -2.5277e-08],
         [-8.2055e-08,  5.1245e-08,  6.2276e-08,  1.0000e+00],
         [ 1.0933e-08, -7.0609e-08, -1.0000e+00,  1.5147e-08]],

        [[ 4.0813e-08,  4.1984e-08,  1.0000e+00, -7.7404e-09],
         [-7.8505e-09,  9.8508e-09, -5.7320e-08, -1.0000e+00],
         [-1.0000e+00,  1.7232e-08,  1.1296e-08, -3.2608e-08],
         [ 4.0355e-08,  1.0000e+00, -5.0336e-08,  6.4603e-09]],

        [[ 3.8183e-08,  8.0717e-08, -2.0061e-08,  1.0000e+00],
         [ 6.4894e-08, -2.9346e-08