# Gradient check

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import sys
sys.path.append("../Pytorch-Quaternion-Neural-Networks/core_qnn")
from quaternion_layers import *
import quaternion_ops
import functions
from layers import *
import convert_to_quaternion
import random
import os
import copy

In [2]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

torch.manual_seed(0)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
torch.manual_seed(0)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=0)
grayscale = transforms.Grayscale(num_output_channels=1)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)


Files already downloaded and verified
Files already downloaded and verified


In [3]:
inputs, labels = next(iter(trainloader))
gray_image = grayscale(inputs)
inputs = torch.cat([inputs, gray_image],1)
targets = torch.ones(4).long()
criterion = nn.CrossEntropyLoss()

In [219]:
def assemble_weight(r, i, j, k):
    
    weight = torch.cat([torch.cat([r, -i, -j,  -k], dim=1),
                        torch.cat([i,  r, -k,   j], dim=1),
                        torch.cat([j,  k,  r,  -i], dim=1),
                        torch.cat([k, -j,  i,   r], dim=1)], dim = 0)
    
    return weight


def create_dummy_weight(chann):
    
    r, i, j, k = torch.full((chann, size//4), 0.5), torch.full((chann, size//4), 0.25),\
             torch.full((chann, size//4), 1.5), torch.full((chann, size//4), 2.5)

    weight = assemble_weight(r, i, j, k)
    
    return weight, r, i, j, k



In [209]:
size = 32*32*4
real_lin = nn.Linear(size, 8, bias = False)
quat_lin = QLinear(size, 8, bias = False)
parc_lin = QuaternionLinear(size, 8, bias = False)

weight, r, i, j, k = create_dummy_weight(2)

real_lin.weight = nn.Parameter(weight)
quat_lin.weight = nn.Parameter(weight)
                   
parc_lin.r_weight.data = nn.Parameter(r.t())
parc_lin.i_weight.data = nn.Parameter(i.t())
parc_lin.j_weight.data = nn.Parameter(j.t())
parc_lin.k_weight.data = nn.Parameter(k.t())

In [21]:
weight.shape

torch.Size([8, 4096])

In [210]:
x = inputs.view(4, 32*32*4)


set_seed(0)
out_real = real_lin(x)
loss_real = criterion(out_real, targets)
real_grad, = torch.autograd.grad(loss_real, real_lin.weight)

set_seed(0)
out_quat = quat_lin(x)
loss_quat = criterion(out_quat, targets)
quat_grad, = torch.autograd.grad(loss_quat, quat_lin.weight)

set_seed(0)
out_parc = parc_lin(x)
loss_parc = criterion(out_parc, targets)
r_grad, i_grad, j_grad, k_grad = torch.autograd.grad(
    loss_parc, [parc_lin.r_weight, parc_lin.i_weight, parc_lin.j_weight, parc_lin.i_weight]
)

In [211]:
print(
      ((out_parc - out_real).abs() < 1e-6).all().item(),                 # parcollet vs real
      ((out_quat - out_real).abs() < 1e-6).all().item(),                 # real vs ours
      sep = "\n"
)

True
True


In [212]:
parc_weight_grad = torch.cat([torch.cat([r_grad.t(), -i_grad.t(), -j_grad.t(),  -k_grad.t()], dim=1),
                         torch.cat([i_grad.t(),  r_grad.t(), -k_grad.t(),   j_grad.t()], dim=1),
                         torch.cat([j_grad.t(),  k_grad.t(),  r_grad.t(),  -i_grad.t()], dim=1),
                         torch.cat([k_grad.t(), -j_grad.t(),  i_grad.t(),   r_grad.t()], dim=1)], dim = 0)

print(
      ((real_grad - quat_grad).abs() < 1e-6).all().item(),
      ((real_grad - parc_weight_grad).abs() < 1e-6).all().item(),
      sep = "\n"
     )

True
False


In [213]:
r, i, j, k = torch.randn((2, 16//4, 2, 2)), torch.randn((2, 16//4, 2, 2)),\
             torch.randn((2, 16//4, 2, 2)), torch.randn((2, 16//4, 2, 2))

weight = torch.cat([torch.cat([r, -i, -j,  -k], dim=1),
                    torch.cat([i,  r, -k,   j], dim=1),
                    torch.cat([j,  k,  r,  -i], dim=1),
                    torch.cat([k, -j,  i,   r], dim=1)], dim = 0)

In [214]:
parc_weight_grad[0, :10]

tensor([-0.1989, -0.2087, -0.1950, -0.1911, -0.1862, -0.1911, -0.1901, -0.1636,
        -0.1656, -0.1940])

In [215]:
real_grad[0, :10]

tensor([-0.2074, -0.2064, -0.1887, -0.1946, -0.1926, -0.1877, -0.1858, -0.1740,
        -0.1799, -0.1985])

# Gradient check pt 2

In [236]:
x = torch.randn((4, 32, 32))
weight1, r, i, j, k = create_dummy_weight(1)
weight2, r_, i_, j_, k_  = create_dummy_weight(1)

In [237]:
weight1.requires_grad = True
weight2.requires_grad = True

weight_parc_1 = assemble_weight(r, i, j, k)

weight1.requires_grad = True
weight2.requires_grad = True

# first whole tensor as parameter then separate parts
weight_parc_1.requires_grad = True

# does it change anything if we require grad on every part?
r_.requires_grad = True
i_.requires_grad = True
j_.requires_grad = True
k_.requires_grad = True

weight_parc_2 = assemble_weight(r_, i_, j_, k_)

# weight for ours and parc is exactly the same
print((weight2.t() == weight1.t()).all().item())
print((weight_parc_1.t() == weight1.t()).all().item())
print((weight_parc_2.t() == cat_kernels_4_quaternion).all().item())

set_seed(0)
out_parc = torch.matmul(x.view(-1), weight_parc_1.t())
set_seed(0)
out_parc_ = torch.matmul(x.view(-1), weight_parc_2.t())


True
True
True


In [238]:
set_seed(0)
out_real = torch.matmul(x.view(-1), weight1.t())
set_seed(0)
out_quat = torch.matmul(x.view(-1), weight2.t())


In [239]:
print(
      ((out_real - out_quat).abs() < 1e-6).all().item(),
      ((out_real - out_parc).abs() < 1e-6).all().item(),
      ((out_real - out_parc_).abs() < 1e-6).all().item(),
      sep = "\n"
     )

True
True
True


In [240]:
out_real - out_quat

tensor([0., 0., 0., 0.], grad_fn=<SubBackward0>)

In [241]:
out_parc - out_real

tensor([0., 0., 0., 0.], grad_fn=<SubBackward0>)

In [242]:
out_parc - out_parc_

tensor([0., 0., 0., 0.], grad_fn=<SubBackward0>)

In [243]:
loss_real = out_real.mean()
loss_quat = out_quat.mean()
loss_parc = out_parc.mean()
loss_parc_ = out_parc_.mean()

real_grad, = torch.autograd.grad(loss_real, weight1)
quat_grad, = torch.autograd.grad(loss_quat, weight2)
parc_grad, = torch.autograd.grad(loss_parc, weight_parc_1)
parc_grad_r, parc_grad_i, parc_grad_j, parc_grad_k = torch.autograd.grad(loss_parc_, [r_, i_, j_, k_])



In [244]:
parc_grad_ = assemble_weight(parc_grad_r, parc_grad_i, parc_grad_j, parc_grad_k)

print(
      ((real_grad - quat_grad).abs() < 1e-6).all().item(),
      ((real_grad - parc_grad).abs() < 1e-6).all().item(),
      ((parc_grad - parc_grad_).abs() < 1e-6).all().item(),
      sep = "\n"
     )

True
True
False
