# Quaternion neural networks

## Introduction to quaternions

In [1]:
import torch
from quaternion import QuaternionTensor

Quaternions are hypercomplex numbers that extend the well-known complex numbers: each quaternion has one real part and three imaginary parts:


\begin{equation}
q = a + bi + cj + dk
\end{equation}

The imaginary units i, j and k are related by the following identity:

\begin{equation}
i^2 = j^2 = k^2 = ijk = -1
\end{equation}
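
As a short worked step, the products of pairs of units follow directly from this identity. Multiplying it on the right by k, or on the left by i, and using the fact that each unit squares to -1 gives:

\begin{equation}
\begin{aligned}
ijk = -1 &\implies ijk^2 = -k \implies ij = k, \\
ijk = -1 &\implies i^2jk = -i \implies jk = i, \\
ji &= j(jk) = j^2k = -k = -ij.
\end{aligned}
\end{equation}

The last line shows that swapping the order of two units flips the sign of the result, which is why quaternion multiplication is not commutative.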

What makes them a desirable choice for neural networks? Their most important operation, the *Hamilton product*. The product of two quaternions does not commute and is given by the formula:

\begin{equation}
\begin{aligned}
q_1 \otimes q_2 &= (a_1a_2 - b_1b_2 - c_1c_2 - d_1d_2) \\
&+ (a_1b_2 + b_1a_2 + c_1d_2 - d_1c_2)i \\
&+ (a_1c_2 - b_1d_2 + c_1a_2 + d_1b_2)j \\
&+ (a_1d_2 + b_1c_2 - c_1b_2 + d_1a_2)k
\end{aligned}
\end{equation}
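
The formula above can be checked numerically with a few lines of plain Python. This is a minimal sketch, not part of the `quaternion` module: the helper name `hamilton_product` and the representation of a quaternion as a tuple `(a, b, c, d)` are assumptions made only for illustration.

```python
def hamilton_product(q1, q2):
    """Hamilton product of two quaternions given as (a, b, c, d) tuples.

    Hypothetical helper for illustration; it follows the formula above
    term by term.
    """
    a1, b1, c1, d1 = q1
    a2, b2, c2, d2 = q2
    return (
        a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2,  # real part
        a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2,  # i component
        a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2,  # j component
        a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2,  # k component
    )


# The three imaginary units as coefficient tuples.
i = (0, 1, 0, 0)
j = (0, 0, 1, 0)
k = (0, 0, 0, 1)

ij = hamilton_product(i, j)  # equals k
ji = hamilton_product(j, i)  # equals -k, so the product does not commute
ii = hamilton_product(i, i)  # equals -1

print(ij, ji, ii)
```

Swapping the two factors changes the sign of the k component here, which is the non-commutativity mentioned in the text.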


## Quaternions and neural networks

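The Hamilton product lets us express the product of two quaternions as a real-valued matrix multiplication. A quaternion can be rewritten (not uniquely) as the matrix below; multiplying this matrix by the coefficient vector of a second quaternion gives the coefficients of their Hamilton product: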

\begin{equation}
\begin{bmatrix}
a & -b & -c & -d\\
b & a & -d & c \\
c & d & a & -b \\
d & -c & b & a
\end{bmatrix}
\end{equation}
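
To make the link between the matrix above and the Hamilton product concrete, the sketch below builds the matrix for one quaternion and multiplies it by the coefficient vector of another. It uses only plain Python; the names `hamilton_product`, `left_matrix` and `matvec` are hypothetical helpers, and the sample values are chosen only for illustration.

```python
def hamilton_product(q1, q2):
    """Hamilton product of two quaternions given as (a, b, c, d) tuples."""
    a1, b1, c1, d1 = q1
    a2, b2, c2, d2 = q2
    return (
        a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2,
        a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2,
        a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2,
        a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2,
    )


def left_matrix(q):
    """Real 4x4 matrix of q, laid out exactly as in the matrix above."""
    a, b, c, d = q
    return [
        [a, -b, -c, -d],
        [b,  a, -d,  c],
        [c,  d,  a, -b],
        [d, -c,  b,  a],
    ]


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return tuple(sum(x * y for x, y in zip(row, v)) for row in m)


q1 = (0.5, 1.0, 1.5, 2.0)
q2 = (1.0, -2.0, 3.0, 0.5)

via_matrix = matvec(left_matrix(q1), q2)
via_formula = hamilton_product(q1, q2)

print(via_matrix)
print(via_formula)
```

Both routes give the same four coefficients, so a layer that stores its weights in this block pattern applies a Hamilton product using ordinary real-valued matrix arithmetic.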
 

In [2]:
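def assemble_weight(a, b, c, d):
    """Build the real-valued block matrix of a quaternion from its four 2-D components."""
    # Swap the two dimensions of each component.
    a = a.transpose(1, 0)
    b = b.transpose(1, 0)
    c = c.transpose(1, 0)
    d = d.transpose(1, 0)

    # Each inner cat builds one block row; the outer cat stacks the four rows,
    # following the sign pattern of the matrix above.
    return torch.cat([torch.cat([a, -b, -c, -d], dim=1),
                      torch.cat([b,  a, -d,  c], dim=1),
                      torch.cat([c,  d,  a, -b], dim=1),
                      torch.cat([d, -c,  b,  a], dim=1)], dim=0)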

The `quaternion` module handles the quaternion operations so that quaternion tensors can be used much like ordinary real-valued tensors.

In [3]:
chann_in = 4
chann_out = 20
val1 = 0.5
val2 = 1
val3 = 1.5
val4 = 2

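# One constant tensor per quaternion component, each of shape (chann_out, chann_in // 4).
r = torch.full((chann_out, chann_in // 4), val1)
i = torch.full_like(r, val2)
j = torch.full_like(r, val3)
k = torch.full_like(r, val4)

# Build the same quaternion twice: by hand, and through QuaternionTensor.
quaternion_1 = assemble_weight(r, i, j, k)
quaternion_2 = QuaternionTensor([r, i, j, k], real_tensor=True)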

In [4]:
a, _, _, _ = quaternion_2.chunk()

In [7]:
quaternion_2

tensor([[ 0.5000,  0.5000,  0.5000,  0.5000,  0.5000,  0.5000,  0.5000,  0.5000,
          0.5000,  0.5000,  0.5000,  0.5000,  0.5000,  0.5000,  0.5000,  0.5000,
          0.5000,  0.5000,  0.5000,  0.5000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.5000, -1.5000, -1.5000, -1.5000, -1.5000, -1.5000, -1.5000, -1.5000,
         -1.5000, -1.5000, -1.5000, -1.5000, -1.5000, -1.5000, -1.5000, -1.5000,
         -1.5000, -1.5000, -1.5000, -1.5000, -2.0000, -2.0000, -2.0000, -2.0000,
         -2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000,
         -2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000],
        [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000, 

In [9]:
quaternion_2.a

tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000]])