In [1]:
import torch
import torch.nn as nn
import numpy as np

import math
import time

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.utils.data as Data
from torch.nn import init


import matplotlib.pyplot as plt
%matplotlib inline

import numpy as np
import imageio

In [2]:
# Display the installed PyTorch version; the saved cell output below records the version this notebook was last run with.
torch.__version__

'1.8.1+cu102'

In [None]:
###############
## PHC LAYER ##
###############

class PHConvLayer(nn.Module):
    '''
    Parameterized Hypercomplex Convolutional (PHC) Layer.

    The convolution kernel is not a stored parameter. Every forward pass builds
    it as a sum of ``n`` Kronecker products taken over the two channel axes:

        weight = sum_i kron(a[i], s[i])

    where ``a[i]`` has shape ``(n, n)`` and ``s[i]`` has shape
    ``(out_features // n, in_features // n, kernel_size, kernel_size)``, so
    ``weight`` has shape ``(out_features, in_features, kernel_size, kernel_size)``.
    The spatial axes of ``s[i]`` are carried through unchanged.

    Args:
        n: number of Kronecker terms; must be a positive divisor of both
            ``in_features`` and ``out_features``.
        in_features: number of input channels.
        out_features: number of output channels.
        kernel_size: side length of the square kernel.
        padding: zero-padding forwarded to ``F.conv2d``.
        stride: stride forwarded to ``F.conv2d``.

    Raises:
        ValueError: if ``n`` is not a positive divisor of ``in_features`` and
            ``out_features``.
    '''

    def __init__(self, n, in_features, out_features, kernel_size, padding=1, stride=1):
        super().__init__()
        # Without this check a non-divisible n silently yields a kernel with
        # n * (out_features // n) output channels instead of out_features.
        if n <= 0 or in_features % n or out_features % n:
            raise ValueError(
                'n={} must be a positive divisor of in_features={} and '
                'out_features={}'.format(n, in_features, out_features))
        self.n = n
        self.in_features = in_features
        self.out_features = out_features
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride

        self.bias = nn.Parameter(torch.empty(out_features))
        self.a = nn.Parameter(init.xavier_uniform_(torch.zeros((n, n, n))))
        self.s = nn.Parameter(init.xavier_uniform_(
            torch.zeros((n, out_features // n, in_features // n, kernel_size, kernel_size))))
        # Placeholder with the full kernel shape; forward() replaces it with the
        # kernel assembled from a and s.
        self.weight = torch.zeros((out_features, in_features, kernel_size, kernel_size))

        self._reset_bias()

    def _reset_bias(self):
        '''Draw the bias from U(-b, b) with b = 1/sqrt(in_features * kernel_size**2).'''
        # in_features * kernel_size**2 is the fan-in of a (out, in, k, k) kernel.
        fan_in = self.in_features * self.kernel_size * self.kernel_size
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -bound, bound)

    def kronecker_product1(self, a, b):
        '''
        Batched Kronecker product over the channel axes (fast, vectorised path).

        Args:
            a: tensor of shape (n, p, q).
            b: tensor of shape (n, O, I, kh, kw), or (O, I, kh, kw) to share one
                filter bank across all n terms (leading axes broadcast).

        Returns:
            Tensor of shape (n, p*O, q*I, kh, kw) whose k-th slice is
            kron(a[k], b[k]) over the two channel axes. Summing over dim 0
            gives the PHC kernel.
        '''
        p, q = a.shape[-2:]
        o_dim, i_dim, kh, kw = b.shape[-4:]
        # a -> (..., p, 1, q, 1, 1, 1) and b -> (..., 1, O, 1, I, kh, kw). Element
        # [..., r, s, c, t, u, v] of the product is a[.., r, c] * b[.., s, t, u, v],
        # so merging (r, s) and (c, t) gives Kronecker rows r*O+s and cols c*I+t.
        prod = a[..., :, None, :, None, None, None] * b[..., None, :, None, :, :, :]
        return prod.reshape(prod.shape[:-6] + (p * o_dim, q * i_dim, kh, kw))

    def kronecker_product2(self):
        '''
        Loop-based reference construction of the PHC kernel (slower).

        Returns:
            Tensor of shape (out_features, in_features, kernel_size, kernel_size)
            equal to ``torch.sum(self.kronecker_product1(self.a, self.s), dim=0)``.
        '''
        # torch.kron left-pads the lower-rank operand with singleton axes, so
        # kron(a[i], s[i]) would expand the spatial axes of s[i]. Appending two
        # trailing singleton axes to a[i] makes it act on the channel axes only.
        # Summing generated terms keeps the result on the parameters' device.
        return sum(
            torch.kron(self.a[i][:, :, None, None], self.s[i])
            for i in range(self.n)
        )

    def forward(self, input):
        '''
        Assemble the kernel from ``a`` and ``s`` and apply a biased 2-D convolution.

        Args:
            input: tensor accepted by ``F.conv2d`` with ``in_features`` channels
                (presumably (N, in_features, H, W) — TODO confirm with callers).

        Returns:
            ``F.conv2d`` output with ``out_features`` channels.
        '''
        # kronecker_product2() builds the same kernel but more slowly.
        self.weight = torch.sum(self.kronecker_product1(self.a, self.s), dim=0)
        # Cast the input to the kernel's tensor type (dtype and device).
        input = input.type(dtype=self.weight.type())
        return F.conv2d(input, weight=self.weight, bias=self.bias,
                        stride=self.stride, padding=self.padding)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)

    def reset_parameters(self) -> None:
        '''Re-initialise a and s (Kaiming uniform) and the bias (fan-in bound).'''
        init.kaiming_uniform_(self.a, a=math.sqrt(5))
        init.kaiming_uniform_(self.s, a=math.sqrt(5))
        self._reset_bias()