Comparison between the Parcollet QNN (P-QNN) and this library (QNN): the operation of a Linear layer with 2 inputs and 2 outputs is written out explicitly in both libraries, and the resulting outputs and weight gradients are printed for comparison.

In [15]:
# Main PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F

In [16]:
# Add the P-QNN sources to the import path
import sys
sys.path.append("../Pytorch-Quaternion-Neural-Networks/core_qnn")

In [17]:
# P-QNN import (this is the function used in the forward pass of QuaternionLinearAutograd)
from quaternion_ops import quaternion_linear

In [18]:
# QNN imports (we only use the basic class here)
from quaternion import QuaternionTensor

In [21]:
def init_values():
    """Build the fixed input batch and quaternion weight components.

    Returns:
        A pair ``(x, (r, i, j, k))``. ``x`` is a 2x4 float tensor with one
        quaternion per row. ``r``, ``i``, ``j`` and ``k`` are 1x2 float
        tensors with ``requires_grad`` enabled, one per quaternion component
        of the weight; the second column is the negation of the first.
    """
    # "Input": 0.1 - 0.2i + 0.3j - 0.4k, -0.5 + 0.6i - 0.7j + 0.8k
    inputs = torch.FloatTensor([[0.1, -0.2, 0.3, -0.4], [-0.5, 0.6, -0.7, 0.8]])

    # "Weight": 1.0 + 0.9i + 0.8j + 0.7k (and the opposite)
    components = tuple(
        torch.FloatTensor([[value, -value]]).requires_grad_(True)
        for value in (1.0, 0.9, 0.8, 0.7)
    )

    return inputs, components

In [22]:
# Multiplication in P-QNN
# Fresh input and weight components; r, i, j, k have requires_grad enabled.
x, (r, i, j, k) = init_values()
# Forward pass through P-QNN's quaternion_linear, with no bias term.
out_pqnn = quaternion_linear(x, r, i, j, k, bias=None)

In [23]:
# Gradient in P-QNN
# Differentiate the summed output with respect to each weight component.
(g_r_pqnn, g_i_pqnn, g_j_pqnn, g_k_pqnn) = torch.autograd.grad(out_pqnn.sum(), [r, i, j, k])

In [24]:
# Multiplication in QNN
# Re-create the tensors so x, r, i, j, k are new objects for this computation.
x, (r, i, j, k) = init_values()
# Concatenate the four weight components along dim 1 and wrap the result.
w = QuaternionTensor(torch.cat([r, i, j, k], dim=1))
# Plain functional linear call without bias.
# NOTE(review): presumably QuaternionTensor customizes how F.linear treats the
# weight (quaternion product) -- confirm in the library.
out_qnn = F.linear(x, w, None)

In [25]:
# Gradient in QNN
# Differentiate the summed output with respect to each weight component.
(g_r_qnn, g_i_qnn, g_j_qnn, g_k_qnn) = torch.autograd.grad(out_qnn.sum(), [r, i, j, k])

In [26]:
# Forward outputs: P-QNN first, then QNN
print(out_pqnn)
print(out_qnn)

tensor([[ 0.3200, -0.3200, -0.6400,  0.6400,  0.6000, -0.6000,  0.1000, -0.1000],
        [-1.0400,  1.0400,  1.2800, -1.2800, -1.4000,  1.4000, -0.6600,  0.6600]],
       grad_fn=<MmBackward>)
tensor([[ 0.3200, -0.3200, -0.6400,  0.6400,  0.6000, -0.6000,  0.1000, -0.1000],
        [-1.0400,  1.0400,  1.2800, -1.2800, -1.4000,  1.4000, -0.6600,  0.6600]],
       grad_fn=<MmBackward>)


In [27]:
# Gradients with respect to the weight component r: P-QNN first, then QNN
print(g_r_pqnn)
print(g_r_qnn)

tensor([[5.9605e-08, 5.9605e-08]])
tensor([[5.9605e-08, 5.9605e-08]])


In [28]:
# Gradients with respect to the weight component i: P-QNN first, then QNN
print(g_i_pqnn)
print(g_i_qnn)

tensor([[-1.6000, -1.6000]])
tensor([[-1.6000, -1.6000]])


In [29]:
# Gradients with respect to the weight component j: P-QNN first, then QNN
print(g_j_pqnn)
print(g_j_qnn)

tensor([[-8.9407e-08, -8.9407e-08]])
tensor([[-8.9407e-08, -8.9407e-08]])


In [30]:
# Gradients with respect to the weight component k: P-QNN first, then QNN
print(g_k_pqnn)
print(g_k_qnn)

tensor([[0., 0.]])
tensor([[0., 0.]])
