## Quaternion PyTorch - Basic mechanisms

```python
import torch
from qtorch import quaternion
```

### 1 - Quaternion tensors

A quaternion number is represented by:

$$
x = a + bi + cj + dk
$$

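where $a$, $b$, $c$, and $d$ are real numbers and $i$, $j$, $k$ are the imaginary units. Here $a$ is the real part and $b$, $c$, $d$ are the imaginary parts. A `QuaternionTensor` extends the standard PyTorch `torch.Tensor` to hold quaternion values. It is initialized from the real and imaginary components, with the four components $(a, b, c, d)$ stored along the last dimension: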

```python
# Simple scalar quaternion
x = quaternion.QuaternionTensor([0.0, 0.3, 0.4, 0.5])
print(x)  # TODO: improve printing
```

```
tensor([0.0000, 0.3000, 0.4000, 0.5000])
```


```python
# Vector of two quaternions (a 2 x 4 tensor of components)
x = quaternion.QuaternionTensor(torch.rand(2, 4))
print(x)
```

```
tensor([[0.8958, 0.4577, 0.8571, 0.0426],
        [0.0478, 0.4696, 0.6397, 0.5274]])
```
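For reference, the imaginary units obey Hamilton's rules. The product of two quaternions $p = a_1 + b_1 i + c_1 j + d_1 k$ and $q = a_2 + b_2 i + c_2 j + d_2 k$ follows from these rules by distributing terms. These are the textbook definitions, stated independently of `qtorch`:

$$
i^2 = j^2 = k^2 = ijk = -1, \qquad ij = -ji = k, \qquad jk = -kj = i, \qquad ki = -ik = j
$$

$$
\begin{aligned}
pq = {} & (a_1 a_2 - b_1 b_2 - c_1 c_2 - d_1 d_2) + (a_1 b_2 + b_1 a_2 + c_1 d_2 - d_1 c_2)\, i \\
& + (a_1 c_2 - b_1 d_2 + c_1 a_2 + d_1 b_2)\, j + (a_1 d_2 + b_1 c_2 - c_1 b_2 + d_1 a_2)\, k
\end{aligned}
$$

Since $ij \neq ji$, the product is not commutative: in general $pq \neq qp$.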


All standard quaternion operations can be applied to the tensor (see the `QuaternionTensor` class for the full list):

```python
# Conjugation
print(x.conj)
```

```
tensor([[ 0.8958, -0.4577, -0.8571, -0.0426],
        [ 0.0478, -0.4696, -0.6397, -0.5274]])
```
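As the output shows, conjugation keeps the real part and negates the three imaginary parts. Its main use is that a quaternion times its conjugate is a non-negative real number. That in turn gives the inverse of any non-zero quaternion. These are standard identities, stated here independently of `qtorch`:

$$
\bar{x} = a - bi - cj - dk, \qquad x\bar{x} = \bar{x}x = a^2 + b^2 + c^2 + d^2, \qquad x^{-1} = \frac{\bar{x}}{a^2 + b^2 + c^2 + d^2} \quad (x \neq 0)
$$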


```python
# Element-wise norm
print(x.norm)
```

```
tensor([[1.3223],
        [0.9540]])
```
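The printed values agree with the standard quaternion norm. It is computed for each quaternion over its four components, giving one value per row. Checking the first row by hand:

$$
\|x\| = \sqrt{a^2 + b^2 + c^2 + d^2}, \qquad
\sqrt{0.8958^2 + 0.4577^2 + 0.8571^2 + 0.0426^2} = \sqrt{1.7484} \approx 1.3223
$$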


```python
# Element-wise angle
print(x.theta)
```

```
tensor([[0.8265],
        [1.5206]])
```
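The printed angles are consistent with the polar form $x = \|x\|(\cos\theta + \mathbf{u}\sin\theta)$, where $\mathbf{u}$ is a unit pure quaternion, so that $\theta = \arccos(a / \|x\|) \in [0, \pi]$. The exact convention used by `qtorch` is not shown here, so the plain-Python sketch below simply assumes that definition. It recomputes both angles from the printed components of `x`, which agree up to the rounding of those components:

```python
import math

def theta(q):
    """Polar angle of a quaternion (a, b, c, d), assuming theta = arccos(a / |q|)."""
    a, b, c, d = q
    norm = math.sqrt(a * a + b * b + c * c + d * d)
    return math.acos(a / norm)

# Rows of x as printed above (rounded to 4 decimals)
rows = [(0.8958, 0.4577, 0.8571, 0.0426), (0.0478, 0.4696, 0.6397, 0.5274)]
for q in rows:
    print(round(theta(q), 4))
```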


```python
# Quaternion multiplication (Hamilton product)
print(x * x)
```

```
tensor([[-0.1434,  0.8200,  1.5357,  0.0763],
        [-0.9056,  0.0449,  0.0612,  0.0505]])
```
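Below is a plain-Python sketch of the Hamilton product. It is independent of `qtorch` and stores quaternions as `(a, b, c, d)` tuples, mirroring the last dimension of `x`. It reproduces the first row of `x * x` above to within rounding of the printed components, and it shows on the basis units that the product does not commute:

```python
def hamilton(p, q):
    """Hamilton product p * q of quaternions given as (a, b, c, d) tuples."""
    a1, b1, c1, d1 = p
    a2, b2, c2, d2 = q
    return (
        a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2,  # real part
        a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2,  # i
        a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2,  # j
        a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2,  # k
    )

# First row of x as printed above
q = (0.8958, 0.4577, 0.8571, 0.0426)
print([round(v, 4) for v in hamilton(q, q)])

# Non-commutativity on the basis units: i * j = k but j * i = -k
i, j = (0, 1, 0, 0), (0, 0, 1, 0)
print(hamilton(i, j), hamilton(j, i))  # (0, 0, 0, 1) (0, 0, 0, -1)
```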


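```python
# Matrix multiplication. Note: the 4 x 4 result below equals the real-valued
# product X^T X of the underlying 2 x 4 component array, not a quaternion product
print(x.t() @ x)
```

```
tensor([[0.8048, 0.4324, 0.7985, 0.0634],
        [0.4324, 0.4300, 0.6927, 0.2671],
        [0.7985, 0.6927, 1.1439, 0.3739],
        [0.0634, 0.2671, 0.3739, 0.2799]])
```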
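For comparison, under the textbook definition a quaternion matrix product sums Hamilton products while keeping the order of the factors: $(AB)_{ik} = \sum_j A_{ij} B_{jk}$. Treat `x` as a column vector of the two quaternions $x_1$ and $x_2$ printed above. Then $x^\top x$ would be a single quaternion, and it can be assembled from the squares printed by the `x * x` cell:

$$
\begin{aligned}
x^\top x = x_1^2 + x_2^2 &\approx (-0.1434 + 0.8200\,i + 1.5357\,j + 0.0763\,k) + (-0.9056 + 0.0449\,i + 0.0612\,j + 0.0505\,k) \\
&= -1.0490 + 0.8649\,i + 1.5969\,j + 0.1268\,k
\end{aligned}
$$

This differs from the 4 x 4 real matrix returned above.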


Importantly, quaternion tensors are meant to be interoperable with real-valued tensors, with real-valued tensors being cast to quaternion tensors with zero imaginary parts:

```python
# Not working yet :-(
print(x + torch.rand(2))
```
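The intended casting rule can be sketched in plain Python. A real number $r$ is lifted to the quaternion $r + 0i + 0j + 0k$, and addition is then component-wise, so only the real parts change. The helper names `lift` and `add` are hypothetical and not part of `qtorch`:

```python
def lift(r):
    """Hypothetical helper: cast a real number to a quaternion (a, b, c, d) with zero imaginary parts."""
    return (float(r), 0.0, 0.0, 0.0)

def add(p, q):
    """Hypothetical helper: component-wise quaternion addition."""
    return tuple(u + v for u, v in zip(p, q))

# A vector of two quaternions (components along the last axis) plus a real vector of length 2
x = [(0.8958, 0.4577, 0.8571, 0.0426), (0.0478, 0.4696, 0.6397, 0.5274)]
r = [1.0, 2.0]
y = [add(q, lift(s)) for q, s in zip(x, r)]
print(y)
```

In tensor terms, one way to perform the same lift on a real tensor `r` is to stack it with three zero tensors along a new last dimension, e.g. `torch.stack([r, torch.zeros_like(r), torch.zeros_like(r), torch.zeros_like(r)], dim=-1)`.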

### 2 - Quaternion gradients

Gradients can be computed with PyTorch's autograd. As with any PyTorch tensor, the input must have `requires_grad` enabled before `backward()` is called:

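```python
x = quaternion.QuaternionTensor(torch.rand(2, 4))
# Enable gradient tracking; without it, backward() raises
# "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn"
x.requires_grad_()
y = x.norm.sum()
y.backward()
print(x.grad)
```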
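Assume `x.norm` is the Euclidean norm of the four components, as the printed values earlier suggest. For a single quaternion, $\partial \|x\| / \partial (a, b, c, d) = (a, b, c, d) / \|x\|$, so the gradient of `x.norm.sum()` should be each quaternion divided by its own norm. The plain-Python sketch below checks this formula against central finite differences. It does not use autograd or `qtorch` and is only a reference for what `x.grad` should contain:

```python
import math

def norm(q):
    """Euclidean norm of the four components (a, b, c, d)."""
    return math.sqrt(sum(v * v for v in q))

def analytic_grad(q):
    """Gradient of |q| with respect to (a, b, c, d): q / |q| (undefined at q = 0)."""
    n = norm(q)
    return [v / n for v in q]

def numeric_grad(q, h=1e-6):
    """Central finite differences of |q|, perturbing one component at a time."""
    grad = []
    for k in range(len(q)):
        plus = list(q)
        plus[k] += h
        minus = list(q)
        minus[k] -= h
        grad.append((norm(plus) - norm(minus)) / (2 * h))
    return grad

q = [0.8958, 0.4577, 0.8571, 0.0426]
print(analytic_grad(q))
print(numeric_grad(q))
```

Since `y` is a sum over quaternions, each row of the gradient depends only on its own quaternion. Each row is a unit vector.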