## Quaternion PyTorch - Training a QNN

In [1]:
import torch
from htorch import quaternion, layers, utils

### Step 1 - Loading the data

We provide a `collate_fn` that converts any standard RGB image dataset to a quaternion format: the RGB values become the three imaginary components and the greyscale version of the image becomes the real component.

In [2]:
from torchvision.datasets import CIFAR10
from torchvision import transforms

In [4]:
# Standard loading for the CIFAR-10 dataset
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
data = CIFAR10(root='data', train=True, download=True, transform=transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data\cifar-10-python.tar.gz
170499072it [06:02, 470083.88it/s]                               
Extracting data\cifar-10-python.tar.gz to data


In [5]:
# Batch the data using a custom collate_fn function to convert to quaternion-valued images
loader = torch.utils.data.DataLoader(data, batch_size=8, shuffle=True,
    collate_fn=utils.convert_data_for_quaternion)

In [6]:
xb, yb = next(iter(loader))
print(xb.shape) # We now have 4 input channels as needed

torch.Size([8, 4, 32, 32])
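To make the conversion concrete, here is a minimal sketch of a collate function with the same effect on shapes. It is not the implementation of `utils.convert_data_for_quaternion`: the name `rgb_batch_to_quaternion`, the BT.601 greyscale weights, and the real-component-first channel order are all assumptions made for illustration.

```python
import torch

# Assumption: ITU-R BT.601 luma weights for the greyscale (real) component.
# htorch may compute the greyscale image differently.
GREY_WEIGHTS = torch.tensor([0.299, 0.587, 0.114])


def rgb_batch_to_quaternion(batch):
    """Collate (image, label) pairs into a [B, 4, H, W] tensor plus labels.

    Hypothetical helper: channel 0 holds the greyscale image (real part),
    channels 1-3 hold R, G, B (imaginary parts).
    """
    images = torch.stack([img for img, _ in batch])        # [B, 3, H, W]
    labels = torch.tensor([label for _, label in batch])   # [B]
    # Weighted sum over the colour axis gives one greyscale channel.
    grey = (images * GREY_WEIGHTS.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
    return torch.cat([grey, images], dim=1), labels


# Usage on two random "images" standing in for CIFAR-10 samples
fake_batch = [(torch.rand(3, 32, 32), 1), (torch.rand(3, 32, 32), 7)]
x_demo, y_demo = rgb_batch_to_quaternion(fake_batch)
print(x_demo.shape)
```

A function of this form can be passed to `DataLoader` through the `collate_fn` argument, exactly as in the cell above.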


### Step 2 - Building the QNN

We use a simple QNN with two convolutional blocks, each followed by a split-ReLU activation and max-pooling.
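A split activation applies an ordinary real-valued function to each of the four quaternion components separately, which is why a plain `torch.nn.ReLU` can be used in the model. The plain-Python sketch below illustrates the idea on a single quaternion; `split_relu` is a hypothetical name used only here.

```python
def split_relu(q):
    """Apply ReLU separately to the real and the three imaginary components."""
    return tuple(max(0.0, component) for component in q)


q = (0.5, -1.0, 2.0, -0.25)   # (real, i, j, k)
print(split_relu(q))           # prints: (0.5, 0.0, 2.0, 0.0)
```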

In [7]:
model = torch.nn.Sequential(
    layers.QConv2d(1, 20, kernel_size=10, bias=True), # We only have 1 channel in terms of quaternions
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2), # Max-pooling is fine here: it is applied to each real-valued channel independently
    layers.QConv2d(20, 20, kernel_size=10, bias=True),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),
    torch.nn.Flatten(),
    layers.QLinear(20, 10),
    layers.QuaternionToReal(10), # Take the absolute value before the softmax
)

In [8]:
# Test the model is working correctly
model(xb).shape

torch.Size([8, 10])
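The input size of the `QLinear` layer follows from the spatial sizes of the feature maps. The sketch below traces them through the network, assuming that `QConv2d` uses stride 1 and no padding (the defaults of `torch.nn.Conv2d`) and that `QLinear` counts its features in quaternions; both are assumptions, although they agree with the output shape above.

```python
def conv_out(size, kernel, stride=1):
    """Output size of an unpadded convolution or pooling window along one axis."""
    return (size - kernel) // stride + 1


size = 32                       # CIFAR-10 images are 32x32
size = conv_out(size, 10)       # first QConv2d   -> 23
size = conv_out(size, 2, 2)     # first max-pool  -> 11
size = conv_out(size, 10)       # second QConv2d  -> 2
size = conv_out(size, 2, 2)     # second max-pool -> 1

quaternion_features = 20 * size * size  # 20 quaternion channels on a 1x1 map
print(size, quaternion_features)        # prints: 1 20
```

With a 1x1 feature map, flattening leaves 20 quaternion features (80 real values), which matches `QLinear(20, 10)`.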

### Step 3 - Training loop

From this point on, everything is standard PyTorch. The loop below follows the official CIFAR-10 tutorial:
https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

In [8]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [9]:
for epoch in range(2):

    running_loss = 0.0
    for i, (inputs, labels) in enumerate(loader):
        # each batch is a pair (inputs, labels); unpacking here avoids
        # shadowing the `data` dataset variable defined earlier

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')

[1,  2000] loss: 1.733
[1,  4000] loss: 1.468
[1,  6000] loss: 1.393
[2,  2000] loss: 1.240
[2,  4000] loss: 1.218
[2,  6000] loss: 1.189
Finished Training


### Converting an existing nn.Module

Using the [torch.fx](https://pytorch.org/docs/stable/fx.html) toolkit, we can also convert an existing PyTorch `nn.Module` into a quaternion-valued one, provided all channel and feature sizes are divisible by 4:

In [15]:
# This model is similar to the previous one, but all channel and feature sizes are multiplied by 4
model = torch.nn.Sequential(
    torch.nn.Conv2d(4, 80, kernel_size=10, bias=True),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),
    torch.nn.Conv2d(80, 80, kernel_size=10, bias=True),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),
    torch.nn.Flatten(),
    torch.nn.Linear(80, 40),
)

In [16]:
# Convert to a QNN and run on the previous set of images
utils.convert_to_quaternion(model)(xb).shape

torch.Size([8, 40])
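Before converting a model, it can help to check the divisibility requirement up front. The sketch below uses `torch.fx.symbolic_trace` to walk the layers of a model and report every `Conv2d` or `Linear` whose sizes are not multiples of 4. It is an illustration only, not part of `htorch` and not how `utils.convert_to_quaternion` works internally; the helper name `find_non_quaternion_layers` is hypothetical.

```python
from typing import List

import torch
from torch import fx


def find_non_quaternion_layers(model: torch.nn.Module) -> List[str]:
    """Return the names of Conv2d/Linear layers whose sizes are not divisible by 4."""
    traced = fx.symbolic_trace(model)
    modules = dict(traced.named_modules())
    offending = []
    for node in traced.graph.nodes:
        if node.op != "call_module":
            continue  # skip placeholders, function calls and the output node
        layer = modules[node.target]
        if isinstance(layer, torch.nn.Conv2d):
            sizes = (layer.in_channels, layer.out_channels)
        elif isinstance(layer, torch.nn.Linear):
            sizes = (layer.in_features, layer.out_features)
        else:
            continue  # activations, pooling, flatten: no size constraint checked
        if any(s % 4 != 0 for s in sizes):
            offending.append(node.target)
    return offending


# Usage: a real-valued model with the same layer sizes as the one above
candidate = torch.nn.Sequential(
    torch.nn.Conv2d(4, 80, kernel_size=10),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(80, 40),
)
print(find_non_quaternion_layers(candidate))  # prints: []
```

An empty list means every checked layer satisfies the size requirement; a model that takes plain 3-channel RGB input would be reported through its first layer.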