## Quaternion PyTorch - Basic mechanisms

```python
import torch
import sys

# Make the local qtorch package (one directory up) importable
sys.path.append("..")
from qtorch import quaternion
```

### 1 - Quaternion tensors

A quaternion number is represented by:

$$
x = a + bi + cj + dk
$$

where $a$, $b$, $c$, and $d$ are real values and $i$, $j$, $k$ are the imaginary units, which satisfy $i^2 = j^2 = k^2 = ijk = -1$. A `QuaternionTensor` extends the standard PyTorch `Tensor` to hold quaternion values. It is initialized from data whose last dimension has size 4 and holds the real and imaginary components $[a, b, c, d]$:

```python
# Simple scalar quaternion: 0.0 + 0.3i + 0.4j + 0.5k
x = quaternion.QuaternionTensor([0.0, 0.3, 0.4, 0.5])
print(x)  # printed as the raw [a, b, c, d] components (TODO: improve printing)
```

```
tensor([0.0000, 0.3000, 0.4000, 0.5000])
```


```python
# Vector of two quaternions: shape (2, 4), one [a, b, c, d] row per quaternion
x = quaternion.QuaternionTensor(torch.rand(2, 4))
print(x)
```

```
tensor([[0.0340, 0.6425, 0.0317, 0.1571],
        [0.7023, 0.0532, 0.7900, 0.1253]])
```
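Each row of the stored data is one quaternion $[a, b, c, d]$, and products of these components follow the unit rules $i^2 = j^2 = k^2 = ijk = -1$ (so $ij = k$ but $ji = -k$). The sketch below encodes those rules as the Hamilton product on plain Python 4-tuples. It is independent of `qtorch`, and the helper `hamilton` is a hypothetical name used only for illustration:

```python
# Plain-Python sketch of quaternion algebra on (a, b, c, d) 4-tuples.
# Independent of qtorch; `hamilton` is a hypothetical helper name.


def hamilton(p, q):
    """Hamilton product p * q of two quaternions given as (a, b, c, d)."""
    a1, b1, c1, d1 = p
    a2, b2, c2, d2 = q
    return (
        a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2,  # real part
        a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2,  # i part
        a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2,  # j part
        a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2,  # k part
    )


i = (0.0, 1.0, 0.0, 0.0)
j = (0.0, 0.0, 1.0, 0.0)
k = (0.0, 0.0, 0.0, 1.0)
minus_one = (-1.0, 0.0, 0.0, 0.0)

# Defining identities: i^2 = j^2 = k^2 = ijk = -1
print(all(hamilton(u, u) == minus_one for u in (i, j, k)))  # True
print(hamilton(hamilton(i, j), k) == minus_one)  # True

# The product is not commutative
print(hamilton(i, j))  # ij = k
print(hamilton(j, i))  # ji = -k
```

Non-commutativity is the main practical difference from complex arithmetic: swapping the operands of a quaternion product generally changes the result.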


The standard quaternion operations can be applied to the tensor (see `QuaternionTensor` for the full list):

```python
# Conjugation
print(x.conj)
```

```
tensor([[ 0.0340, -0.6425, -0.0317, -0.1571],
        [ 0.7023, -0.0532, -0.7900, -0.1253]])
```
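Conjugation negates the three imaginary components, $\bar{q} = a - bi - cj - dk$, which is exactly what the output shows row by row. A plain-Python sketch (the helper `conj` is a hypothetical name, not the `qtorch` implementation) reproduces the first printed row and checks two basic properties:

```python
def conj(q):
    """Quaternion conjugate: keep the real part, negate the i, j, k parts."""
    a, b, c, d = q
    return (a, -b, -c, -d)


q = (0.0340, 0.6425, 0.0317, 0.1571)  # first row of x, as printed above

print(conj(q))  # (0.034, -0.6425, -0.0317, -0.1571)
print(conj(conj(q)) == q)  # True: conjugating twice gives q back

# q + conj(q) keeps twice the real part and cancels the imaginary parts
print(tuple(u + v for u, v in zip(q, conj(q))))
```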


```python
# Element-wise norm (one value per quaternion)
print(x.norm)
```

```
tensor([[0.6631],
        [1.0657]])
```
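The norm of $q = a + bi + cj + dk$ is $\lVert q \rVert = \sqrt{a^2 + b^2 + c^2 + d^2}$, computed once per quaternion, hence the `(2, 1)` output shape. A plain-Python check with a hypothetical `qnorm` helper, using the printed (four-decimal) components, reproduces both values to within about $10^{-4}$; the small gap comes from the inputs themselves being rounded:

```python
import math


def qnorm(q):
    """Euclidean norm of a quaternion (a, b, c, d)."""
    return math.sqrt(sum(component * component for component in q))


rows = [
    (0.0340, 0.6425, 0.0317, 0.1571),  # first row of x, as printed
    (0.7023, 0.0532, 0.7900, 0.1253),  # second row of x, as printed
]
for q in rows:
    print(round(qnorm(q), 4))
```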


```python
# Element-wise angle
print(x.theta)
```

```
tensor([[1.5195],
        [0.8514]])
```
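The printed angles are consistent with the polar form $q = \lVert q \rVert(\cos\theta + \hat{u}\sin\theta)$, where $\theta = \arccos(a / \lVert q \rVert)$ and $\hat{u}$ is the unit vector along the imaginary part. Assuming this is the definition `qtorch` uses (inferred from the numbers, not from the library source), a plain-Python check with a hypothetical `qangle` helper:

```python
import math


def qangle(q):
    """Angle theta = arccos(a / |q|) of a nonzero quaternion (a, b, c, d)."""
    a = q[0]
    norm = math.sqrt(sum(component * component for component in q))
    return math.acos(a / norm)


rows = [
    (0.0340, 0.6425, 0.0317, 0.1571),  # first row of x, as printed
    (0.7023, 0.0532, 0.7900, 0.1253),  # second row of x, as printed
]
for q in rows:
    print(round(qangle(q), 4))
```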


```python
# Quaternion multiplication (Hamilton product)
print(x * x)
```

```
tensor([[-0.4374,  0.0437,  0.0022,  0.0107],
        [-0.1495,  0.0747,  1.1096,  0.1760]])
```
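Squaring has a simple closed form. Writing $q = a + \mathbf{v}$ with vector part $\mathbf{v} = bi + cj + dk$, and using $\mathbf{v}^2 = -(b^2 + c^2 + d^2)$,

$$
q^2 = (a^2 - b^2 - c^2 - d^2) + 2a(bi + cj + dk).
$$

This explains the output: the imaginary components are $2a$ times the original ones, and the real part is negative whenever $b^2 + c^2 + d^2 > a^2$, as in both rows. A plain-Python check against the first printed row (`square` is a hypothetical helper):

```python
def square(q):
    """q * q for q = (a, b, c, d), using q^2 = (a^2 - |v|^2) + 2a v."""
    a, b, c, d = q
    return (a * a - b * b - c * c - d * d, 2 * a * b, 2 * a * c, 2 * a * d)


q = (0.0340, 0.6425, 0.0317, 0.1571)  # first row of x, as printed
print([round(component, 4) for component in square(q)])
```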


```python
# Quaternion matrix multiplication
print(x.t() @ x)
```

```
tensor([[0.4943, 0.0592, 0.5559, 0.0933],
        [0.0592, 0.4157, 0.0624, 0.1076],
        [0.5559, 0.0624, 0.6251, 0.1040],
        [0.0933, 0.1076, 0.1040, 0.0404]])
```
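The printed result is a real 4 × 4 matrix, and it matches the ordinary real-valued product $X^\top X$, where $X$ is the 2 × 4 array of components stored in `x` (one $[a, b, c, d]$ row per quaternion). So, at least for this input, `x.t() @ x` appears to act on the raw real components rather than returning a quaternion-valued product. A plain-Python check using the printed values (rounded to four decimals, so agreement is to about $10^{-4}$):

```python
rows = [
    [0.0340, 0.6425, 0.0317, 0.1571],  # first row of x, as printed
    [0.7023, 0.0532, 0.7900, 0.1253],  # second row of x, as printed
]

# Real-valued (X^T X)[m][n] = sum over rows r of X[r][m] * X[r][n]
gram = [
    [sum(row[m] * row[n] for row in rows) for n in range(4)]
    for m in range(4)
]
for line in gram:
    print([round(value, 4) for value in line])
```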


Importantly, quaternion tensors and real-valued tensors are interoperable: real-valued tensors are cast to quaternion tensors with zero imaginary parts.

```python
# Multiplication by a real-valued tensor (one real scalar per quaternion)
print(x * torch.rand(2))
```

```
tensor([[0.0326, 0.6150, 0.0304, 0.1504],
        [0.3891, 0.0295, 0.4377, 0.0694]])
```
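Casting a real number $r$ to the quaternion $r + 0i + 0j + 0k$ and taking the Hamilton product simply scales every component by $r$, which is why each printed row is the corresponding row of `x` multiplied by a single factor. A plain-Python sketch (the helpers `hamilton` and `from_real` are hypothetical names, not the `qtorch` implementation):

```python
def hamilton(p, q):
    """Hamilton product of two quaternions given as (a, b, c, d)."""
    a1, b1, c1, d1 = p
    a2, b2, c2, d2 = q
    return (
        a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2,
        a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2,
        a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2,
        a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2,
    )


def from_real(r):
    """Cast a real number to a quaternion with zero imaginary parts."""
    return (r, 0.0, 0.0, 0.0)


q = (0.0340, 0.6425, 0.0317, 0.1571)  # first row of x, as printed
r = 0.5  # an arbitrary real scale factor

print(hamilton(q, from_real(r)))  # quaternion product with the cast scalar
print(tuple(r * component for component in q))  # plain component scaling
```

Because a real quaternion commutes with every quaternion, the order of the operands does not matter in this special case.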


### 2 - Quaternion gradients

Gradients can be computed with the standard PyTorch autograd mechanism, for example by wrapping a `QuaternionTensor` in a `torch.nn.Parameter`:

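```python
x = torch.nn.Parameter(quaternion.QuaternionTensor(torch.rand(2, 4)))
# `norm` is accessed as a property, as in the examples above
y = x.norm.sum()
y.backward()
print(x.grad)  # gradient of y with respect to the components of x
```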
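Assuming autograd treats the four stored components as independent real variables, the gradient of $y = \sum \lVert q \rVert$ with respect to each quaternion is $q / \lVert q \rVert$, its unit quaternion. A plain-Python finite-difference sketch of that expectation, independent of both `qtorch` and PyTorch (the helpers `qnorm` and `numerical_grad` are hypothetical names):

```python
import math


def qnorm(q):
    """Euclidean norm of a quaternion (a, b, c, d)."""
    return math.sqrt(sum(component * component for component in q))


def numerical_grad(f, q, eps=1e-6):
    """Central finite-difference gradient of f at q, one component at a time."""
    grad = []
    for n in range(len(q)):
        plus = list(q)
        minus = list(q)
        plus[n] += eps
        minus[n] -= eps
        grad.append((f(plus) - f(minus)) / (2 * eps))
    return grad


q = (0.0340, 0.6425, 0.0317, 0.1571)  # an example quaternion
analytic = [component / qnorm(q) for component in q]  # expected: q / |q|
numeric = numerical_grad(qnorm, q)
print([round(value, 4) for value in analytic])
print([round(value, 4) for value in numeric])
```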