In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

In [2]:
class Conv2dSDK(torch.nn.Module):
  """Stride-1 2D convolution evaluated with an SDK-style weight mapping.

  The kernel is zero-padded to every possible offset inside a
  ``pw_height x pw_width`` "parallel window", and the flattened copies are
  stacked into one weight matrix (``weight_map``). One matrix multiply per
  window then yields the whole
  ``(pw_height - k + 1) x (pw_width - k + 1)`` tile of outputs that window
  covers. The result matches ``F.conv2d(x, kernel, padding=padding)``.

  Args:
    kernel: 4D weight tensor indexed as (out_channels, in_channels, k, k).
      The last two dims must be equal.
    pw_width: parallel window width; must be >= k.
    pw_height: parallel window height; must be >= k.
    padding: zeros added on each side of the input (default 1, the value
      that was previously hard-coded).

  Raises:
    ValueError: kernel is not 4D or not square, the window is smaller than
      the kernel, or padding is negative.
  """

  def __init__(self, kernel, pw_width, pw_height, padding=1):
    super().__init__()

    if kernel.dim() != 4:
      raise ValueError("Kernel must be 4D: (out_channels, in_channels, k, k).")

    if kernel.shape[2] != kernel.shape[3]:
      raise ValueError("Kernel is not square. Rectangular Kernel not supported.")

    if pw_height < kernel.shape[2] or pw_width < kernel.shape[3]:
      raise ValueError("Parallel window is smaller than the kernel.")

    if padding < 0:
      raise ValueError("Padding must be non-negative.")

    if pw_height == 3 and pw_width == 3:
      print("WARNING: Parallel window size is 3. Use Conv2dIm2col instead.")

    self.kernel = kernel
    self.out_channels = kernel.shape[0]
    self.in_channels = kernel.shape[1]
    self.kernel_size = kernel.shape[2]

    self.pw_width = pw_width
    self.pw_height = pw_height
    self.padding = padding

    # Shape: (n_offsets * out_channels, in_channels * pw_height * pw_width).
    self.weight_map = torch.nn.Parameter(self._gen_SDK_mapping())

  def _ordered_pairs_sum(self, x):
    """Return an (x+1, 2) tensor of all pairs (a, x - a) for a = 0..x."""
    a = torch.arange(x + 1)
    b = x - a
    pairs = torch.stack((a, b), dim=1)
    return pairs

  def _gen_SDK_mapping(self):
    """Build the stacked, zero-padded, flattened kernel matrix.

    Each (top, left) offset places the kernel inside a window-sized zero
    canvas. Rows are ordered by top offset (outer), then left offset, then
    output channel. ``_forward`` relies on this ordering when it reshapes.
    """
    h_diff = self.pw_height - self.kernel_size
    w_diff = self.pw_width - self.kernel_size

    # Convert to plain Python ints: F.pad expects an int sequence, not 0-d tensors.
    ver_pads = self._ordered_pairs_sum(h_diff).tolist()
    hor_pads = self._ordered_pairs_sum(w_diff).tolist()

    SDK_mapping = []
    for top, bottom in ver_pads:
      for left, right in hor_pads:
        padded_kernel = F.pad(self.kernel, (left, right, top, bottom),
                              mode='constant', value=0)
        SDK_mapping.append(padded_kernel.reshape(self.out_channels, -1))

    return torch.cat(SDK_mapping)

  def _forward(self, x):
    """Convolve ``x`` of shape (num, in_channels, H, W) with the stored kernel.

    Returns:
      Tensor of shape (num, out_channels, H + 2p - k + 1, W + 2p - k + 1).

    Raises:
      ValueError: the padded input is smaller than the kernel.
    """
    num, _, height, width = x.shape
    k = self.kernel_size
    p = self.padding

    # Output size of a stride-1 convolution with `p` zeros on each side.
    out_h = height + 2 * p - k + 1
    out_w = width + 2 * p - k + 1
    if out_h < 1 or out_w < 1:
      raise ValueError("Input is too small for the kernel and padding.")

    # Each window emits a stride_ver x stride_hor tile of outputs, so
    # consecutive windows start that many pixels apart.
    stride_ver = self.pw_height - k + 1
    stride_hor = self.pw_width - k + 1

    # Number of windows needed to cover every output row / column.
    slide_ver = math.ceil(out_h / stride_ver)
    slide_hor = math.ceil(out_w / stride_hor)

    # Extra zeros on the bottom/right so the windows tile the input exactly.
    # The surplus outputs they produce are cropped off at the end.
    extra_ver = slide_ver * stride_ver - out_h
    extra_hor = slide_hor * stride_hor - out_w

    padded_x = F.pad(x, (p, p + extra_hor, p, p + extra_ver),
                     mode='constant', value=0)

    # (num, slide_ver * slide_hor, in_channels * pw_height * pw_width)
    flat_windows = F.unfold(padded_x,
                            kernel_size=(self.pw_height, self.pw_width),
                            stride=(stride_ver, stride_hor)).transpose(1, 2)

    # (num, slide_ver * slide_hor, stride_ver * stride_hor * out_channels)
    lin_out = F.linear(flat_windows, self.weight_map)

    lin_out = lin_out.reshape(num, slide_ver, slide_hor,
                              stride_ver, stride_hor, self.out_channels)
    # -> (num, out_channels, slide_ver, stride_ver, slide_hor, stride_hor), then
    # merge each (window index, in-window offset) pair into one spatial axis.
    lin_out = lin_out.permute(0, 5, 1, 3, 2, 4).reshape(
        num, self.out_channels, slide_ver * stride_ver, slide_hor * stride_hor)

    return lin_out[:, :, :out_h, :out_w]

  def forward(self, input):
    return self._forward(input)

  def string(self):
    return 'testing'

In [32]:
class Conv2dSDK_QR(Conv2dSDK):
  """Conv2dSDK whose weight map is replaced by a rank-truncated SVD approximation.

  After the parent builds the full weight map ``W``, it is factorised as
  ``W ~= Q @ R`` with ``Q = U_r diag(S_r)`` and ``R = Vh_r``. ``weight_map`` is
  then re-created from that product.

  Attributes:
    rank: number of singular values kept.
    original_weight_map: detached copy of the full-rank ``W``. It is a plain
      tensor attribute, not a registered parameter or buffer.
    Q, R: low-rank factors, registered as parameters on ``W``'s device/dtype.
    weight_map: Parameter initialised to ``Q @ R``. This is the value the
      inherited forward pass reads.

  Raises:
    ValueError: rank < 1.

  NOTE(review): the inherited forward reads only ``self.weight_map``. ``Q`` and
  ``R`` are not used there, so they get no gradients, and training updates the
  full-size ``weight_map`` independently of them. Confirm this is intended.
  """

  def __init__(self, kernel, pw_width, pw_height, rank):
    if rank < 1:
      raise ValueError("rank must be at least 1.")
    super().__init__(kernel, pw_width, pw_height)
    self.rank = rank
    Q, R = self._SVD()
    full_map = self.weight_map
    # detach().clone() copies without autograd history and avoids the
    # torch.tensor(tensor) copy-construct UserWarning.
    self.original_weight_map = full_map.detach().clone()
    # np.linalg.svd ran on a CPU copy, so move the factors back to W's device/dtype.
    self.Q = torch.nn.Parameter(torch.tensor(Q, dtype=full_map.dtype, device=full_map.device))
    self.R = torch.nn.Parameter(torch.tensor(R, dtype=full_map.dtype, device=full_map.device))
    # Detach so the new Parameter is an independent leaf, not part of Q @ R's graph.
    self.weight_map = torch.nn.Parameter(torch.matmul(self.Q, self.R).detach())

  def _SVD(self):
    """Rank-truncated SVD of the current weight map.

    Returns:
      ``(Q, R)`` as NumPy arrays, with ``Q = U[:, :rank] @ diag(S[:rank])`` and
      ``R = Vh[:rank, :]``. ``Q @ R`` is the best rank-``rank`` approximation
      of ``self.weight_map`` in Frobenius norm. If ``rank`` exceeds
      ``min(weight_map.shape)``, the slices keep every singular value.
    """
    u, s, vh = np.linalg.svd(self.weight_map.cpu().detach().numpy(), full_matrices=False)
    # Scaling the columns of U by S equals U @ diag(S) without building the diagonal matrix.
    Q = u[:, :self.rank] * s[:self.rank]
    R = vh[:self.rank, :]
    return Q, R

In [17]:
def test_script(img_num=1, img_width=8, input_channel=32, kernel_size=3,
                output_channel=64, rank=40, pw_size=3, seed=0):
  """Compare Conv2dSDK_QR against F.conv2d(padding=1) on random data.

  Builds a random image batch and kernel from a local, seeded generator, so
  runs are reproducible without touching the global RNG. It then runs the
  rank-``rank`` SDK layer with a ``pw_size`` x ``pw_size`` parallel window.
  ``padding=1`` matches the padding Conv2dSDK applies by default.

  When ``rank`` is below the weight map's full rank, the layer is only a
  low-rank approximation, so the difference is expected to be nonzero.

  Returns:
    The L1 norm of (SDK output - reference conv output). It is also printed.
  """
  gen = torch.Generator().manual_seed(seed)
  img = torch.randn(img_num, input_channel, img_width, img_width, generator=gen)
  kernel = torch.randn(output_channel, input_channel, kernel_size, kernel_size,
                       generator=gen)

  my_conv1 = Conv2dSDK_QR(rank=rank, kernel=kernel, pw_width=pw_size, pw_height=pw_size)
  # Pure evaluation: no need to build an autograd graph.
  with torch.no_grad():
    lin_out = my_conv1(img)
    output = F.conv2d(img, kernel, padding=1)

  # Size of the discrepancy between the two operations.
  l1_norm = torch.norm(lin_out - output, p=1)
  print(l1_norm)
  return l1_norm