# Gradient check

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import sys
sys.path.append("../Pytorch-Quaternion-Neural-Networks/core_qnn")
from quaternion_layers import *
import quaternion_ops
import functions
from layers import *
import convert_to_quaternion
import random
import os

In [2]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

torch.manual_seed(0)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
torch.manual_seed(0)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=0)
grayscale = transforms.Grayscale(num_output_channels=1)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)


Files already downloaded and verified
Files already downloaded and verified


In [3]:
inputs, labels = next(iter(trainloader))
gray_image = grayscale(inputs)
inputs = torch.cat([inputs, gray_image],1)
targets = torch.ones(4).long()
criterion = nn.CrossEntropyLoss()

In [4]:
size = 32*32*4
real_lin = nn.Linear(size, 8, bias = False)
quat_lin = QLinear(size, 8, bias = False)
parc_lin = QuaternionLinear(size, 8, bias = False)

r, i, j, k = torch.full((2, size//4), 0.5), torch.full((2, size//4), 0.25),\
             torch.full((2, size//4), 1.5), torch.full((2, size//4), 2.5)

weight = torch.cat([torch.cat([r, -i, -j,  -k], dim=1),
                    torch.cat([i,  r, -k,   j], dim=1),
                    torch.cat([j,  k,  r,  -i], dim=1),
                    torch.cat([k, -j,  i,   r], dim=1)], dim = 0)

real_lin.weight = nn.Parameter(weight)
quat_lin.weight = nn.Parameter(weight)
                   
parc_lin.r_weight.data = nn.Parameter(r.t())
parc_lin.i_weight.data = nn.Parameter(i.t())
parc_lin.j_weight.data = nn.Parameter(j.t())
parc_lin.k_weight.data = nn.Parameter(k.t())

In [5]:
x = inputs.view(4, 32*32*4)

set_seed(0)
out_real = real_lin(x)

optimizer_real = optim.SGD(real_lin.parameters(), lr=0.1, momentum=0)
loss_real = criterion(out_real, targets)
loss_real.backward()
optimizer_real.step()

set_seed(0)
out_quat = quat_lin(x)

optimizer_real.zero_grad() 
optimizer_quat = optim.SGD(quat_lin.parameters(), lr=0.1, momentum=0)
loss_quat = criterion(out_quat, targets)
loss_quat.backward()
optimizer_quat.step()

set_seed(0)
out_parc = parc_lin(x)

optimizer_quat.zero_grad() 
optimizer_parc = optim.SGD(parc_lin.parameters(), lr=0.1, momentum=0)
loss_parc = criterion(out_parc, targets)
loss_parc.backward()
optimizer_parc.step()

In [6]:
print(
      (out_parc == out_real).all().item(),                 # parcollet vs real
      (out_real == out_quat).all().item(),                 # real vs ours
      sep = "\n"
)

True
False


In [7]:
parc_weight_grad = torch.cat([torch.cat([parc_lin.r_weight.t(), -parc_lin.i_weight.t(), -parc_lin.j_weight.t(),  -parc_lin.k_weight.t()], dim=1),
                         torch.cat([parc_lin.i_weight.t(),  parc_lin.r_weight.t(), -parc_lin.k_weight.t(),   parc_lin.j_weight.t()], dim=1),
                         torch.cat([parc_lin.j_weight.t(),  parc_lin.k_weight.t(),  parc_lin.r_weight.t(),  -parc_lin.i_weight.t()], dim=1),
                         torch.cat([parc_lin.k_weight.t(), -parc_lin.j_weight.t(),  parc_lin.i_weight.t(),   parc_lin.r_weight.t()], dim=1)], dim = 0)

print(
      (real_lin.weight.grad == quat_lin.weight.grad).all().item(),
      (real_lin.weight.grad == parc_weight_grad).all().item(),
      sep = "\n"
     )

True
False
