## Quaternion PyTorch - Basic mechanisms

```python
import sys

import torch

# Make the local `qtorch` package (one directory up) importable
sys.path.append("..")
from qtorch import quaternion
```

### 1 - Quaternion tensors

A quaternion number is represented by:

$$
x = a + bi + cj + dk
$$

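where $a$, $b$, $c$, and $d$ are real numbers and $i$, $j$, $k$ are the imaginary units, which satisfy $i^2 = j^2 = k^2 = ijk = -1$. Here $a$ is the real part and $b$, $c$, $d$ are the imaginary parts. A `QuaternionTensor` extends the standard PyTorch tensor to hold quaternion values. It is initialized from real-valued data whose last dimension holds the four components $(a, b, c, d)$: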

```python
# Simple scalar quaternion: (a, b, c, d) = (0.0, 0.3, 0.4, 0.5)
x = quaternion.QuaternionTensor([0.0, 0.3, 0.4, 0.5])
x
```

```
real part: tensor([0.])
imaginary part (i): tensor([0.3000])
imaginary part (j): tensor([0.4000])
imaginary part (k): tensor([0.5000])
```

```python
# Vector of two quaternions: each row holds the components (a, b, c, d)
x = quaternion.QuaternionTensor(torch.rand(2, 4))
print(x)
```

```
tensor([[0.4292, 0.2528, 0.2074, 0.1736],
        [0.1889, 0.0865, 0.0153, 0.2673]])
```
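The examples above suggest that the last dimension of the underlying data stores each quaternion as $(a, b, c, d)$, so a $2 \times 4$ tensor is a vector of two quaternions. As a plain-Python sketch of that layout (not the `qtorch` implementation; the helpers `split_components`, `real_part` and `vector_part` are hypothetical), each length-4 row can be read as one quaternion:

```python
from typing import List, Sequence, Tuple

Quaternion = Tuple[float, float, float, float]


def split_components(rows: Sequence[Sequence[float]]) -> List[Quaternion]:
    """Read each length-4 row as one quaternion (a, b, c, d).

    Hypothetical helper mirroring the assumed QuaternionTensor layout:
    real part first, then the i, j and k parts.
    """
    quats = []
    for row in rows:
        if len(row) != 4:
            raise ValueError("each quaternion needs exactly 4 components")
        a, b, c, d = row
        quats.append((float(a), float(b), float(c), float(d)))
    return quats


def real_part(q: Quaternion) -> float:
    """The scalar component a."""
    return q[0]


def vector_part(q: Quaternion) -> Tuple[float, float, float]:
    """The imaginary components (b, c, d)."""
    return q[1], q[2], q[3]


# A 2 x 4 array read as a vector of two quaternions
rows = [[0.4292, 0.2528, 0.2074, 0.1736],
        [0.1889, 0.0865, 0.0153, 0.2673]]
qs = split_components(rows)
print(len(qs))             # 2
print(real_part(qs[0]))    # 0.4292
print(vector_part(qs[1]))  # (0.0865, 0.0153, 0.2673)
```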


Standard quaternion operations can be applied directly to the tensor (see `QuaternionTensor` for the full list):

```python
# Conjugation
print(x.conj)
```

```
tensor([[ 0.4292, -0.2528, -0.2074, -0.1736],
        [ 0.1889, -0.0865, -0.0153, -0.2673]])
```
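Conjugation keeps the real part and negates the three imaginary parts, $\bar{x} = a - bi - cj - dk$, which is what the output above shows. A plain-Python sketch (the function name `conj` is ours, not part of `qtorch`):

```python
from typing import Tuple

Quaternion = Tuple[float, float, float, float]


def conj(q: Quaternion) -> Quaternion:
    """Quaternion conjugate: keep the real part, negate the i, j, k parts."""
    a, b, c, d = q
    return (a, -b, -c, -d)


q = (0.4292, 0.2528, 0.2074, 0.1736)
print(conj(q))             # (0.4292, -0.2528, -0.2074, -0.1736)
print(conj(conj(q)) == q)  # True: conjugation is an involution
```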


```python
# Element-wise norm
print(x.norm)
```

```
tensor([[0.5668],
        [0.3389]])
```
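The norm of a quaternion is $|x| = \sqrt{a^2 + b^2 + c^2 + d^2}$, equivalently $\sqrt{x\bar{x}}$. Recomputing it from the rounded components printed earlier recovers the two values above; this is a plain-Python sketch with a hypothetical `qnorm` helper:

```python
import math
from typing import Tuple

Quaternion = Tuple[float, float, float, float]


def qnorm(q: Quaternion) -> float:
    """Euclidean norm of the four components, sqrt(a^2 + b^2 + c^2 + d^2)."""
    return math.sqrt(sum(component * component for component in q))


print(round(qnorm((0.4292, 0.2528, 0.2074, 0.1736)), 4))  # 0.5668
print(round(qnorm((0.1889, 0.0865, 0.0153, 0.2673)), 4))  # 0.3389
```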


```python
# Element-wise angle
print(x.theta)
```

```
tensor([[0.7118],
        [0.9795]])
```
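The printed angles are consistent with the polar form $x = |x|(\cos\theta + \mathbf{u}\sin\theta)$, where $\mathbf{u}$ is a unit pure quaternion and $\theta = \arccos(a / |x|) \in [0, \pi]$. Assuming this is the definition `x.theta` uses, the plain-Python sketch below (hypothetical `qtheta` helper) reproduces the values above to within the rounding of the printed inputs. It uses `atan2(|v|, a)`, which gives the same angle as the arccos form but stays accurate when $a / |x|$ is close to $\pm 1$:

```python
import math
from typing import Tuple

Quaternion = Tuple[float, float, float, float]


def qtheta(q: Quaternion) -> float:
    """Angle theta of the polar form x = |x| (cos(theta) + u sin(theta)).

    Assumed definition: theta = arccos(a / |x|), computed here as
    atan2(|v|, a) with v = (b, c, d); the result lies in [0, pi].
    """
    a, b, c, d = q
    vector_norm = math.sqrt(b * b + c * c + d * d)
    return math.atan2(vector_norm, a)


print(qtheta((0.4292, 0.2528, 0.2074, 0.1736)))  # close to 0.7118 above
print(qtheta((0.1889, 0.0865, 0.0153, 0.2673)))  # close to 0.9795 above
```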


```python
# Quaternion multiplication (Hamilton product)
print(x * x)
```

```
tensor([[ 0.0471,  0.2170,  0.1780,  0.1490],
        [-0.0435,  0.0327,  0.0058,  0.1010]])
```
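The Hamilton product of $x_1 = a_1 + b_1 i + c_1 j + d_1 k$ and $x_2 = a_2 + b_2 i + c_2 j + d_2 k$ follows from expanding the product with $i^2 = j^2 = k^2 = ijk = -1$. The plain-Python sketch below (hypothetical `hamilton` helper, not the `qtorch` code) writes the expansion out component by component; squaring the first row of `x` with it matches the first row above up to the rounding of the printed inputs:

```python
from typing import Tuple

Quaternion = Tuple[float, float, float, float]


def hamilton(p: Quaternion, q: Quaternion) -> Quaternion:
    """Hamilton product p * q, expanded using i^2 = j^2 = k^2 = ijk = -1."""
    a1, b1, c1, d1 = p
    a2, b2, c2, d2 = q
    return (
        a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2,  # real part
        a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2,  # i part
        a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2,  # j part
        a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2,  # k part
    )


i, j, k = (0.0, 1.0, 0.0, 0.0), (0.0, 0.0, 1.0, 0.0), (0.0, 0.0, 0.0, 1.0)
print(hamilton(i, j))  # (0.0, 0.0, 0.0, 1.0): i * j = k
print(hamilton(j, i))  # (0.0, 0.0, 0.0, -1.0): j * i = -k, not commutative

x0 = (0.4292, 0.2528, 0.2074, 0.1736)
print(hamilton(x0, x0))  # close to the first row of x * x above
```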


```python
# Quaternion matrix multiplication
print(x.t() @ x)
```

```
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x4 and 2x4)
```
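As recorded, this cell fails. The message comes from PyTorch's real matrix multiplication, which suggests the product was evaluated on the underlying real $2 \times 4$ tensors rather than on quaternion entries. Conceptually, a quaternion matrix product follows the usual shape rules with each scalar product replaced by a Hamilton product, so the order of the factors matters. The plain-Python sketch below (hypothetical helpers `hamilton`, `qadd`, `qmatmul` and `transpose`; not the `qtorch` implementation) treats a quaternion vector as a $2 \times 1$ column, so that $x^\top x = x_1^2 + x_2^2$:

```python
from typing import List, Tuple

Quaternion = Tuple[float, float, float, float]
QMatrix = List[List[Quaternion]]


def hamilton(p: Quaternion, q: Quaternion) -> Quaternion:
    """Hamilton product p * q (i^2 = j^2 = k^2 = ijk = -1)."""
    a1, b1, c1, d1 = p
    a2, b2, c2, d2 = q
    return (
        a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2,
        a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2,
        a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2,
        a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2,
    )


def qadd(p: Quaternion, q: Quaternion) -> Quaternion:
    """Component-wise quaternion addition."""
    return (p[0] + q[0], p[1] + q[1], p[2] + q[2], p[3] + q[3])


def qmatmul(A: QMatrix, B: QMatrix) -> QMatrix:
    """(n x m) @ (m x p): entries of A always multiply from the left."""
    m = len(B)
    if any(len(row) != m for row in A):
        raise ValueError("inner dimensions do not match")
    result: QMatrix = []
    for row in A:
        out_row = []
        for col in range(len(B[0])):
            acc: Quaternion = (0.0, 0.0, 0.0, 0.0)
            for t in range(m):
                acc = qadd(acc, hamilton(row[t], B[t][col]))
            out_row.append(acc)
        result.append(out_row)
    return result


def transpose(A: QMatrix) -> QMatrix:
    """Plain transpose, without conjugating the entries."""
    return [list(col) for col in zip(*A)]


i, j = (0.0, 1.0, 0.0, 0.0), (0.0, 0.0, 1.0, 0.0)
column = [[i], [j]]  # 2 x 1 column vector of quaternions
print(qmatmul(transpose(column), column))  # [[(-2.0, 0.0, 0.0, 0.0)]]: i^2 + j^2
```

For quaternion vectors, as for complex ones, the conjugate transpose is often used instead, since $\bar{x}_1 x_1 + \bar{x}_2 x_2 = |x_1|^2 + |x_2|^2$ is real and non-negative; which convention `qtorch` intends for `x.t()` is not shown here.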

Importantly, quaternion tensors and real-valued tensors are interoperable: real-valued tensors are cast to quaternion tensors with zero imaginary parts.

```python
# Multiplication by a real-valued tensor (cast to quaternions with zero imaginary parts)
print(x * torch.rand(2))
```
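A real number $r$ corresponds to the quaternion $r + 0i + 0j + 0k$, and the Hamilton product with it reduces to scaling all four components; since real quaternions commute with every quaternion, $r x = x r$. The plain-Python sketch below illustrates that promotion for a single value (hypothetical helpers `promote` and `hamilton`); how `qtorch` broadcasts the two entries of `torch.rand(2)` against the two quaternions of `x` is assumed, not shown:

```python
from typing import Tuple

Quaternion = Tuple[float, float, float, float]


def promote(r: float) -> Quaternion:
    """Cast a real value to a quaternion with zero imaginary parts."""
    return (float(r), 0.0, 0.0, 0.0)


def hamilton(p: Quaternion, q: Quaternion) -> Quaternion:
    """Hamilton product p * q (i^2 = j^2 = k^2 = ijk = -1)."""
    a1, b1, c1, d1 = p
    a2, b2, c2, d2 = q
    return (
        a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2,
        a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2,
        a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2,
        a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2,
    )


x0 = (0.4292, 0.2528, 0.2074, 0.1736)
r = 0.5
scaled = tuple(r * c for c in x0)  # plain component-wise scaling

print(hamilton(x0, promote(r)) == scaled)  # True: x * r scales x
print(hamilton(promote(r), x0) == scaled)  # True: r * x == x * r
```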

### 2 - Quaternion gradients

Gradients can be computed with the standard PyTorch autograd machinery:

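```python
# Wrap a quaternion tensor as a learnable parameter
x = torch.nn.Parameter(quaternion.QuaternionTensor(torch.rand(2, 4)))
# Sum of the quaternion norms (`norm` is accessed as a property, as above)
y = x.norm.sum()
y.backward()
print(x.grad)
```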
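For $y = \sum_n |x_n|$, each quaternion receives the gradient $\partial y / \partial x_n = x_n / |x_n|$ (component-wise, for $x_n \neq 0$). Assuming `qtorch` differentiates the norm like a real Euclidean norm over the four components, each row of `x.grad` should therefore be the corresponding quaternion scaled to unit length. The plain-Python sketch below (hypothetical helpers) checks that formula against central finite differences:

```python
import math
from typing import List, Tuple

Quaternion = Tuple[float, float, float, float]


def qnorm(q: Quaternion) -> float:
    """Euclidean norm over the four components."""
    return math.sqrt(sum(c * c for c in q))


def loss(qs: List[Quaternion]) -> float:
    """y = sum of the quaternion norms, mirroring the cell above."""
    return sum(qnorm(q) for q in qs)


def analytic_grad(qs: List[Quaternion]) -> List[Quaternion]:
    """dy/dq_n = q_n / |q_n| (undefined for q_n = 0)."""
    grads = []
    for q in qs:
        n = qnorm(q)
        grads.append((q[0] / n, q[1] / n, q[2] / n, q[3] / n))
    return grads


def numeric_grad(qs: List[Quaternion], eps: float = 1e-6) -> List[Quaternion]:
    """Central finite differences, perturbing one component at a time."""
    grads = []
    for idx, q in enumerate(qs):
        g = []
        for comp in range(4):
            plus, minus = list(q), list(q)
            plus[comp] += eps
            minus[comp] -= eps
            qs_plus = qs[:idx] + [tuple(plus)] + qs[idx + 1:]
            qs_minus = qs[:idx] + [tuple(minus)] + qs[idx + 1:]
            g.append((loss(qs_plus) - loss(qs_minus)) / (2 * eps))
        grads.append((g[0], g[1], g[2], g[3]))
    return grads


xs = [(0.4292, 0.2528, 0.2074, 0.1736), (0.1889, 0.0865, 0.0153, 0.2673)]
max_err = max(
    abs(u - v)
    for ga, gn in zip(analytic_grad(xs), numeric_grad(xs))
    for u, v in zip(ga, gn)
)
print(max_err < 1e-6)  # True
```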