### Tutorial: Parameterized Hypercomplex Multiplication (PHM) Layer

#### Author: Eleonora Grassucci

Original paper: Beyond Fully-Connected Layers with Quaternions: Parameterization of Hypercomplex Multiplications with 1/n Parameters.

Aston Zhang, Yi Tay, Shuai Zhang, Alvin Chan, Anh Tuan Luu, Siu Cheung Hui, Jie Fu.

[ArXiv link](https://arxiv.org/pdf/2102.08597.pdf).

In [3]:
# Imports

import numpy as np
import math
import time
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.utils.data as Data
from torch.nn import init

In [4]:
# Check Pytorch version: torch.kron is available from 1.8.0
torch.__version__

'1.8.1+cu102'

In [5]:
# Define the PHM class

class PHM(nn.Module):
    '''
    Simple PHM Module, the only parameter is A, since S is passed from the trainset.
    '''

    def __init__(self, n, kernel_size, **kwargs):
        super().__init__(**kwargs)
        self.n = n
        A = torch.empty((n-1, n, n))
        self.A = nn.Parameter(A)
        self.kernel_size = kernel_size

    def forward(self, X, S):
        H = torch.zeros((self.n*self.kernel_size, self.n*self.kernel_size))
        
        # Sum of Kronecker products
        for i in range(n-1):
            H = H + torch.kron(self.A[i], S[i])
        return torch.matmul(X, H.T)

### Learn the Hamilton product between two pure quaternions

A pure quaternion is a quaternion with scalar part equal to 0.

In [6]:
# Setup the training set

x = torch.FloatTensor([0, 1, 2, 3]).view(4, 1) # Scalar part equal to 0
W = torch.FloatTensor([[0,-1,-1,-1], [1,0,-1,1], [1,1,0,-1], [1,-1,1,0]]) # Scalar parts equal to 0

y = torch.matmul(W, x)

num_examples = 1000
batch_size = 1

X = torch.zeros((num_examples, 16))
S = torch.zeros((num_examples, 16))
Y = torch.zeros((num_examples, 16))

for i in range(num_examples):
    x = torch.randint(low=-10, high=10, size=(12, ), dtype=torch.float)
    s = torch.randint(low=-10, high=10, size=(12, ), dtype=torch.float)
    
    s1, s2, s3, s4 = torch.FloatTensor([0]*4), s[0:4], s[4:8], s[8:12]
    s1 = s1.view(2,2)
    s2 = s2.view(2,2)
    s3 = s3.view(2,2)
    s4 = s4.view(2,2)

    s_1 = torch.cat([s1,-s2,-s3,-s4])
    s_2 = torch.cat([s2,s1,-s4,s3])
    s_3 = torch.cat([s3,s4,s1,-s2])
    s_4 = torch.cat([s4,-s3,s2,s1])

    W = torch.cat([s_1,s_2, s_3, s_4], dim=1) 
    x = torch.cat([torch.FloatTensor([0]*4), x])
    s = torch.cat([torch.FloatTensor([0]*4), s])
    x_mult = x.view(2, 8)
    y = torch.matmul(x_mult, W.T)    
    y = y.view(16, )

    X[i, :] = x
    S[i, :] = s
    Y[i, :] = y

X = torch.FloatTensor(X).view(num_examples, 16, 1)
S = torch.FloatTensor(S).view(num_examples, 16, 1)
Y = torch.FloatTensor(Y).view(num_examples, 16, 1)

data = torch.cat([X, S, Y], dim=2)
train_iter = torch.utils.data.DataLoader(data, batch_size=batch_size)

### Setup the test set

num_examples = 1
batch_size = 1

X = torch.zeros((num_examples, 16))
S = torch.zeros((num_examples, 16))
Y = torch.zeros((num_examples, 16))

for i in range(num_examples):
    x = torch.randint(low=-10, high=10, size=(12, ), dtype=torch.float)
    s = torch.randint(low=-10, high=10, size=(12, ), dtype=torch.float)
    
    s1, s2, s3, s4 = torch.FloatTensor([0]*4), s[0:4], s[4:8], s[8:12]
    s1 = s1.view(2,2)
    s2 = s2.view(2,2)
    s3 = s3.view(2,2)
    s4 = s4.view(2,2)

    s_1 = torch.cat([s1,-s2,-s3,-s4])
    s_2 = torch.cat([s2,s1,-s4,s3])
    s_3 = torch.cat([s3,s4,s1,-s2])
    s_4 = torch.cat([s4,-s3,s2,s1])

    W = torch.cat([s_1,s_2, s_3, s_4], dim=1) 
    x = torch.cat([torch.FloatTensor([0]*4), x])
    s = torch.cat([torch.FloatTensor([0]*4), s])
    x_mult = x.view(2, 8)
    y = torch.matmul(x_mult, W.T)    
    y = y.view(16, )

    X[i, :] = x
    S[i, :] = s
    Y[i, :] = y

X = torch.FloatTensor(X).view(num_examples, 16, 1)
S = torch.FloatTensor(S).view(num_examples, 16, 1)
Y = torch.FloatTensor(Y).view(num_examples, 16, 1)

data = torch.cat([X, S, Y], dim=2)
test_iter = torch.utils.data.DataLoader(data, batch_size=batch_size)

In [7]:
# Define training function

def train(net, lr, phm=True):
    # Squared loss
    loss = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    
    for epoch in range(5):
        for data in train_iter:
            optimizer.zero_grad()
            X = data[:, :, 0]
            S = data[:, 4:, 1]
            Y = data[:, :, 2]
            
            if phm:
                out = net(X.view(2, 8), S.view(3, 2, 2))
            else:
                out = net(X)
            
            l = loss(out, Y.view(2, 8))
            l.backward()
            optimizer.step()
        print(f'epoch {epoch + 1}, loss {float(l.sum() / batch_size):.6f}')

In [8]:
# Initialize model parameters
def weights_init_uniform(m):
    m.A.data.uniform_(-0.07, 0.07)
    
# Create layer instance
n = 4
phm_layer = PHM(n, kernel_size=2)
phm_layer.apply(weights_init_uniform)

# Train the model
train(phm_layer, 0.005)

epoch 1, loss 0.021605
epoch 2, loss 0.000000
epoch 3, loss 0.000000
epoch 4, loss 0.000000
epoch 5, loss 0.000000


In [9]:
# Check parameters of the layer require grad
for name, param in phm_layer.named_parameters():
    if param.requires_grad:
        print(name, param.data)

A tensor([[[-6.0884e-08,  1.0000e+00, -1.6100e-08,  2.6916e-08],
         [-1.0000e+00, -1.8684e-08, -2.1245e-08, -8.8355e-08],
         [-1.2780e-08,  1.2693e-07, -3.8119e-08,  1.0000e+00],
         [-1.0182e-07,  4.7619e-08, -1.0000e+00,  3.8946e-08]],

        [[ 1.5405e-08, -3.1784e-08,  1.0000e+00,  2.9003e-08],
         [-3.5486e-08, -3.5375e-08,  3.3766e-08, -1.0000e+00],
         [-1.0000e+00, -2.9093e-08, -5.3595e-08,  3.2789e-08],
         [ 6.2255e-09,  1.0000e+00,  3.7168e-08,  8.2059e-09]],

        [[-3.9100e-08, -5.8766e-09,  2.8090e-09,  1.0000e+00],
         [-1.5466e-07,  5.3471e-08,  1.0000e+00,  3.3222e-08],
         [ 3.3584e-08, -1.0000e+00, -6.5275e-08,  1.9724e-07],
         [-1.0000e+00, -3.0299e-08,  1.3472e-08, -2.8102e-08]]])


In [11]:
# Take a look at the convolution performed on the test set

for data in test_iter:
    X = data[:, :, 0]
    S = data[:, 4:, 1]
    Y = data[:, :, 2]

    y_phm = phm_layer(X.view(2, 8), S.view(3, 2, 2))
    
    print('Hamilton product result from test set:\n', Y.view(2, 8))
    print('Performing Hamilton product learned by PHM:\n', y_phm)

Hamilton product result from test set:
 tensor([[  82., -198.,    2.,   70.,   -4.,   54., -160.,   52.],
        [  51.,   45., -133.,   86., -103.,  225., -125.,   92.]])
Performing Hamilton product learned by PHM:
 tensor([[  82.0000, -198.0000,    2.0000,   70.0000,   -4.0000,   54.0000,
         -160.0000,   52.0000],
        [  51.0000,   45.0000, -133.0000,   86.0000, -103.0000,  225.0001,
         -125.0000,   92.0000]], grad_fn=<MmBackward>)


In [13]:
# Check the PHC layer have learnt the proper algebra for the marix A

W = torch.FloatTensor([[0,-1,-1,-1], [1,0,-1,1], [1,1,0,-1], [1,-1,1,0]])

print('Ground-truth Hamilton product matrix:\n', W)
print()
print('Learned A in PHM:\n', phm_layer.A)
print()
print('Learned A sum in PHM:\n', sum(phm_layer.A).T)

Ground-truth Hamilton product matrix:
 tensor([[ 0., -1., -1., -1.],
        [ 1.,  0., -1.,  1.],
        [ 1.,  1.,  0., -1.],
        [ 1., -1.,  1.,  0.]])

Learned A in PHM:
 Parameter containing:
tensor([[[-6.0884e-08,  1.0000e+00, -1.6100e-08,  2.6916e-08],
         [-1.0000e+00, -1.8684e-08, -2.1245e-08, -8.8355e-08],
         [-1.2780e-08,  1.2693e-07, -3.8119e-08,  1.0000e+00],
         [-1.0182e-07,  4.7619e-08, -1.0000e+00,  3.8946e-08]],

        [[ 1.5405e-08, -3.1784e-08,  1.0000e+00,  2.9003e-08],
         [-3.5486e-08, -3.5375e-08,  3.3766e-08, -1.0000e+00],
         [-1.0000e+00, -2.9093e-08, -5.3595e-08,  3.2789e-08],
         [ 6.2255e-09,  1.0000e+00,  3.7168e-08,  8.2059e-09]],

        [[-3.9100e-08, -5.8766e-09,  2.8090e-09,  1.0000e+00],
         [-1.5466e-07,  5.3471e-08,  1.0000e+00,  3.3222e-08],
         [ 3.3584e-08, -1.0000e+00, -6.5275e-08,  1.9724e-07],
         [-1.0000e+00, -3.0299e-08,  1.3472e-08, -2.8102e-08]]],
       requires_grad=True)

Learned 